In machine learning, for example in Natural Language Processing, a beam search is often used to predict the next objects to add to a sequence and to rank them. A key part of the beam search is the top-k score metric, which is effectively: given a list of N probability scores, return the k highest-scoring items. This is as simple as sorting the list and then taking the top values.
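As a minimal sketch of that one-dimensional top-k step (the function names `top_k` and `top_k_partial` are illustrative, not from any library, and 1 ≤ k ≤ N is assumed), a full sort and a partial selection with NumPy could look like this:

```python
import numpy as np

def top_k(scores, k):
    """Indices and values of the k largest scores, best first (full sort)."""
    scores = np.asarray(scores)
    order = np.argsort(-scores)[:k]
    return order, scores[order]

def top_k_partial(scores, k):
    """Same result, but only partially orders the array before sorting k items."""
    scores = np.asarray(scores)
    part = np.argpartition(-scores, k - 1)[:k]   # k largest, in arbitrary order
    order = part[np.argsort(-scores[part])]      # sort just those k
    return order, scores[order]

scores = [0.1, 0.7, 0.05, 0.9, 0.3]
idx, vals = top_k(scores, k=2)
print(idx.tolist(), vals.tolist())  # [3, 1] [0.9, 0.7]
```

The partial variant avoids sorting all N scores, which only matters when N is much larger than k.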
Referring to a visual example of a beam search, https://www.researchgate.net/figure/A-partially-completed-beam-search-procedure-with-a-beam-width-of-5-for-an-example-input_fig2_317377611 (in that case k=5, and a 'top' score is a minimal value): at each iteration, each node selects the top k items from its list of N choices, resulting in k^2 total potential paths. From these paths, the top k overall are kept, and these form the nodes for the next iteration. The linked example shows only the filtered nodes at each time-step. https://d2l.ai/_images/beam-search.svg expands the case of k=2, N=5 comprehensively.
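One iteration of that procedure can be sketched as follows. This is an illustration under stated assumptions, not the implementation behind either linked figure: scores are cumulative log-probabilities where higher is better (the opposite of the first figure's convention), and the name `beam_step` and its arguments are hypothetical.

```python
import numpy as np

def beam_step(beam_scores, step_log_probs, k):
    """One beam-search iteration.

    beam_scores:    shape (k,), cumulative log-probability of each kept path.
    step_log_probs: shape (k, N), log-probabilities of the N choices per path.
    Returns (parent, choice, score) for the k best extended paths, best first.
    """
    # Each node keeps its own top-k choices, giving k * k candidates.
    local = np.argsort(-step_log_probs, axis=1)[:, :k]
    cand = beam_scores[:, None] + np.take_along_axis(step_log_probs, local, axis=1)
    # Filter the k * k candidates down to the k best overall.
    flat = np.argsort(-cand, axis=None)[:k]
    parent, col = np.unravel_index(flat, cand.shape)
    return parent, local[parent, col], cand[parent, col]

# k=2 beams, N=5 choices per beam
beam_scores = np.log([0.5, 0.4])
step = np.log([[0.10, 0.25, 0.40, 0.15, 0.10],
               [0.05, 0.05, 0.10, 0.60, 0.20]])
parent, choice, score = beam_step(beam_scores, step, k=2)
print(parent.tolist(), choice.tolist())  # [1, 0] [3, 2]
```

Here the best extension comes from the second beam (0.4 × 0.6 = 0.24), ahead of the first beam's best (0.5 × 0.4 = 0.20).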
Now imagine that, instead of optimizing one choice from N for each branch/node, you had to choose multiple values: when exploring from a node, you have a matrix of choices of shape (N, q), from which you want to select q values, one from each column. To find the highest-scoring sets of choices, you then need to consider combinations of the values in these columns. For example, for a matrix of choices with N=5, q=4:
    +---+--------+--------+--------+--------+
    | N | q0     | q1     | q2     | q3     |
    +---+--------+--------+--------+--------+
    | 0 | 0.9763 | 0.0791 | 0.1530 | 0.5565 |
    | 1 | 0.1560 | 0.1014 | 0.6932 | 0.7551 |
    | 2 | 0.8142 | 0.9494 | 0.4582 | 0.4411 |
    | 3 | 0.3807 | 0.2403 | 0.6897 | 0.7356 |
    | 4 | 0.0156 | 0.9419 | 0.9568 | 0.2266 |
    +---+--------+--------+--------+--------+
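If k=5, this top-k function should return the following (sum, then the row index chosen in each of the columns q0 to q3):

    3.6376  [0, 2, 4, 1]
    3.6301  [0, 4, 4, 1]
    3.6181  [0, 2, 4, 3]
    3.6106  [0, 4, 4, 3]
    3.4755  [2, 2, 4, 1]

which are the largest possible sums, using one value from each column.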
Solving this for arbitrary N and q, the naive approach would be to calculate all N^q sums, sort them, then take the top k results. The first step of optimization would be to sort each column, then only calculate the combinations of sums from the top k values in each column, reducing the complexity to k^q.
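Both baselines can be sketched directly, using the N=5, q=4 matrix from the table above. The function names `topk_sums_bruteforce` and `topk_sums_pruned` are hypothetical and exist only for this illustration:

```python
import numpy as np
from itertools import product

def topk_sums_bruteforce(matrix, k):
    """Evaluate all N**q ways to pick one row per column; keep the k best."""
    n, q = matrix.shape
    cols = np.arange(q)
    scored = [(float(matrix[list(rows), cols].sum()), rows)
              for rows in product(range(n), repeat=q)]
    scored.sort(reverse=True)
    return scored[:k]

def topk_sums_pruned(matrix, k):
    """Combine only the k best rows of each column: at most k**q sums."""
    n, q = matrix.shape
    cols = np.arange(q)
    best_rows = np.argsort(-matrix, axis=0)[:k]   # shape (min(k, n), q)
    scored = []
    for ranks in product(range(best_rows.shape[0]), repeat=q):
        rows = tuple(int(r) for r in best_rows[list(ranks), cols])
        scored.append((float(matrix[list(rows), cols].sum()), rows))
    scored.sort(reverse=True)
    return scored[:k]

matrix = np.array([
    [0.9763, 0.0791, 0.1530, 0.5565],
    [0.1560, 0.1014, 0.6932, 0.7551],
    [0.8142, 0.9494, 0.4582, 0.4411],
    [0.3807, 0.2403, 0.6897, 0.7356],
    [0.0156, 0.9419, 0.9568, 0.2266],
])
for total, rows in topk_sums_bruteforce(matrix, k=5):
    print(round(total, 4), rows)  # best line: 3.6376 (0, 2, 4, 1)
```

The pruned version is safe because a combination that uses a value ranked k-th or worse in some column is beaten by at least k combinations that swap in a better value from that column.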
However, given this function to find top scores must be called k times every time-step of the beam-search, every possible speedup is vital if one wishes to scale to high k or high q. The best solution I’ve come up with (condensed to a minimum example, assuming matrix is a NumPy array of shape (N, q), and taking q to be 4):
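    import numpy as np
    from itertools import combinations


    class Beamsearch:
        def __init__(self, klen, q=4):
            self.klen = klen
            self.width = q
            self.wdth = list(range(q))  # column indices 0..q-1
            # Precompute every rank-offset vector whose entries sum to p, for p < klen.
            self.combis = []
            for lens in range(klen):
                self.combis.extend(self.partition(lens, q))

        def partition(self, N, size):
            # Stars and bars: all ways to write N as an ordered sum of
            # `size` non-negative integers.
            n = N + size - 1
            for splits in combinations(range(n), size - 1):
                yield [s1 - s0 - 1 for s0, s1 in zip((-1,) + splits, splits + (n,))]

        def getkmaxscores(self, matrix):
            # Row indices of each column, sorted by descending score.
            # Assumes klen <= N, so every rank offset is a valid row position.
            matrix_argsort = np.argsort(-matrix, axis=0)
            sums = []
            for comb in self.combis:
                # Row selected in each column for this rank-offset vector.
                midxs = matrix_argsort[comb, self.wdth]
                # Scalar sum, so the (sum, indices) tuples can be sorted.
                msum = (float(matrix[midxs, self.wdth].sum()), midxs.tolist())
                sums.append(msum)
            sums.sort(reverse=True)
            return sums[:self.klen]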
This method creates partitions of each integer p, 0 ≤ p < k, into a given width q (that is, ordered lists of q non-negative integers summing to p), e.g. for q=4:
    p0: [0, 0, 0, 0]
    p1: [0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]
    p2: [0, 0, 0, 2], [0, 0, 1, 1], [0, 0, 2, 0], [0, 1, 0, 1], [0, 1, 1, 0],
        [0, 2, 0, 0], [1, 0, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [2, 0, 0, 0]
etc.
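The lists above can be reproduced, and their sizes tallied, with a small standalone sketch. It mirrors the `partition` method from the code above under the hypothetical name `compositions`, and compares the counts against `math.comb` by direct enumeration:

```python
from itertools import combinations
from math import comb

def compositions(p, q):
    """Yield every ordered list of q non-negative integers summing to p."""
    n = p + q - 1
    # Stars and bars: choose positions for the q - 1 bars among n slots.
    for bars in combinations(range(n), q - 1):
        yield [b1 - b0 - 1 for b0, b1 in zip((-1,) + bars, bars + (n,))]

q = 4
counts = [sum(1 for _ in compositions(p, q)) for p in range(6)]
print(counts)           # [1, 4, 10, 20, 35, 56]
print(sum(counts[:5]))  # 70 index vectors in total for k=5 (p = 0..4)
```

For q=4, each count equals `comb(p + 3, 3)`, and the running total over p = 0..k-1 equals `comb(k + 3, 4)`.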
These are then used to index the argsorted input matrix, to select each combination for summation. The length of p_i in the case q=4 follows the tetrahedral (triangular pyramidal) numbers 1, 4, 10, 20, ... (https://oeis.org/A000292). This reduces the search space to the sum of the lengths of p_0 ... p_(k-1) (the code iterates over range(klen)), which is the binomial coefficient C(k+3, 4) = k(k+1)(k+2)(k+3)/24 (https://oeis.org/A000332). This is a vast improvement over the k^4 solution for small k (for k < 18, this is no more than k^3), but it still grows on the order of k^4. Does there exist a solution to the arbitrary case with complexity below O(k^q)?