From 11e2a2c1aabbb8897f9d49f48f31071a5c7378bb Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Tue, 12 Dec 2023 12:08:09 +0100 Subject: [PATCH] Fix geosort bug --- milli/src/search/new/geo_sort.rs | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/milli/src/search/new/geo_sort.rs b/milli/src/search/new/geo_sort.rs index b2e3a2f3d..5f5ceb379 100644 --- a/milli/src/search/new/geo_sort.rs +++ b/milli/src/search/new/geo_sort.rs @@ -107,12 +107,16 @@ impl GeoSort { /// Refill the internal buffer of cached docids based on the strategy. /// Drop the rtree if we don't need it anymore. - fn fill_buffer(&mut self, ctx: &mut SearchContext) -> Result<()> { + fn fill_buffer( + &mut self, + ctx: &mut SearchContext, + geo_candidates: &RoaringBitmap, + ) -> Result<()> { debug_assert!(self.field_ids.is_some(), "fill_buffer can't be called without the lat&lng"); debug_assert!(self.cached_sorted_docids.is_empty()); // lazily initialize the rtree if needed by the strategy, and cache it in `self.rtree` - let rtree = if self.strategy.use_rtree(self.geo_candidates.len() as usize) { + let rtree = if self.strategy.use_rtree(geo_candidates.len() as usize) { if let Some(rtree) = self.rtree.as_ref() { // get rtree from cache Some(rtree) @@ -131,7 +135,7 @@ impl GeoSort { if self.ascending { let point = lat_lng_to_xyz(&self.point); for point in rtree.nearest_neighbor_iter(&point) { - if self.geo_candidates.contains(point.data.0) { + if geo_candidates.contains(point.data.0) { self.cached_sorted_docids.push_back(point.data); if self.cached_sorted_docids.len() >= cache_size { break; @@ -143,7 +147,7 @@ impl GeoSort { // and we insert the points in reverse order they get reversed when emptying the cache later on let point = lat_lng_to_xyz(&opposite_of(self.point)); for point in rtree.nearest_neighbor_iter(&point) { - if self.geo_candidates.contains(point.data.0) { + if geo_candidates.contains(point.data.0) { 
self.cached_sorted_docids.push_front(point.data); if self.cached_sorted_docids.len() >= cache_size { break; @@ -155,8 +159,7 @@ impl GeoSort { // the iterative version let [lat, lng] = self.field_ids.unwrap(); - let mut documents = self - .geo_candidates + let mut documents = geo_candidates .iter() .map(|id| -> Result<_> { Ok((id, geo_value(id, lat, lng, ctx.index, ctx.txn)?)) }) .collect::<Result<Vec<_>>>()?; @@ -216,9 +219,10 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort { assert!(self.query.is_none()); self.query = Some(query.clone()); - self.geo_candidates &= universe; - if self.geo_candidates.is_empty() { + let geo_candidates = &self.geo_candidates & universe; + + if geo_candidates.is_empty() { return Ok(()); } @@ -226,7 +230,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort { let lat = fid_map.id("_geo.lat").expect("geo candidates but no fid for lat"); let lng = fid_map.id("_geo.lng").expect("geo candidates but no fid for lng"); self.field_ids = Some([lat, lng]); - self.fill_buffer(ctx)?; + self.fill_buffer(ctx, &geo_candidates)?; Ok(()) } @@ -238,9 +242,10 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort { universe: &RoaringBitmap, ) -> Result<Option<RankingRuleOutput<Q>>> { let query = self.query.as_ref().unwrap().clone(); - self.geo_candidates &= universe; - if self.geo_candidates.is_empty() { + let geo_candidates = &self.geo_candidates & universe; + + if geo_candidates.is_empty() { return Ok(Some(RankingRuleOutput { query, candidates: universe.clone(), @@ -261,7 +266,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort { } }; while let Some((id, point)) = next(&mut self.cached_sorted_docids) { - if self.geo_candidates.contains(id) { + if geo_candidates.contains(id) { return Ok(Some(RankingRuleOutput { query, candidates: RoaringBitmap::from_iter([id]), @@ -276,7 +281,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort { // if we got out of this loop it means we've exhausted our 
cache. // we need to refill it and run the function again. - self.fill_buffer(ctx)?; + self.fill_buffer(ctx, &geo_candidates)?; self.next_bucket(ctx, logger, universe) }