diff --git a/implicit/nearest_neighbours.py b/implicit/nearest_neighbours.py index 29e50c2..c20600a 100644 --- a/implicit/nearest_neighbours.py +++ b/implicit/nearest_neighbours.py @@ -72,6 +72,12 @@ def recommend( if filter_items is not None and items is not None: raise ValueError("Can't specify both filter_items and items") + # Preserve the caller-requested result count: when filter_items is + # set we over-fetch by len(filter_items) so that the post-filter + # slice still has enough rows to satisfy N, but the final result + # must be truncated to the originally-requested N (not the + # inflated value). See https://github.com/benfred/implicit/issues/736 + requested_n = N if filter_items is not None: N += len(filter_items) elif items is not None: @@ -91,7 +97,7 @@ def recommend( if filter_items is not None: mask = np.isin(ids, filter_items, invert=True) - ids, scores = ids[mask][:N], scores[mask][:N] + ids, scores = ids[mask][:requested_n], scores[mask][:requested_n] elif items is not None: mask = np.isin(ids, items)