From ba9527abc07bc9b1300317f74a233edf8c6abe51 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Sun, 31 May 2020 17:01:11 +0200 Subject: [PATCH] Support typos with a levenshtein automata --- Cargo.lock | 10 ++++++ Cargo.toml | 1 + src/bin/search.rs | 77 ++++++++++++++++++++++++++++++----------------- 3 files changed, 60 insertions(+), 28 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e5f05b909..91f07e95f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -361,6 +361,15 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +[[package]] +name = "levenshtein_automata" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f44db4199cdb049b494a92d105acbfa43c25b3925e33803923ba9580b7bc9e1a" +dependencies = [ + "fst", +] + [[package]] name = "libc" version = "0.2.70" @@ -403,6 +412,7 @@ dependencies = [ "fxhash", "heed", "jemallocator", + "levenshtein_automata", "memmap", "oxidized-mtbl", "rayon", diff --git a/Cargo.toml b/Cargo.toml index 45e71778f..f6eeb778d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ fst = "0.4.3" fxhash = "0.2.1" heed = { version = "0.8.0", default-features = false, features = ["lmdb"] } jemallocator = "0.3.2" +levenshtein_automata = { version = "0.2.0", features = ["fst_automaton"] } memmap = "0.7.0" oxidized-mtbl = { git = "https://github.com/Kerollmops/oxidized-mtbl.git", rev = "8918476" } rayon = "1.3.0" diff --git a/src/bin/search.rs b/src/bin/search.rs index aea12610c..9a25309d3 100644 --- a/src/bin/search.rs +++ b/src/bin/search.rs @@ -3,8 +3,10 @@ use std::path::PathBuf; use std::time::Instant; use cow_utils::CowUtils; +use fst::{Streamer, IntoStreamer}; use heed::types::*; use heed::{EnvOpenOptions, Database}; +use levenshtein_automata::LevenshteinAutomatonBuilder; use roaring::RoaringBitmap; use structopt::StructOpt; @@ -38,42 +40,61 @@ fn main() -> anyhow::Result<()> { let 
documents: Database<OwnedType<BEU32>, ByteSlice> = env.create_database(Some("documents"))?; let rtxn = env.read_txn()?; - - let before = Instant::now(); - let mut result: Option<RoaringBitmap> = None; - for word in alphanumeric_tokens(&opt.query) { - let word = word.cow_to_lowercase(); - match postings_ids.get(&rtxn, &word)? { - Some(ids) => { - let before = Instant::now(); - let right = RoaringBitmap::deserialize_from(ids)?; - eprintln!("deserialized bitmap for {:?} took {:.02?}", word, before.elapsed()); - result = match result.take() { - Some(mut left) => { - let before = Instant::now(); - let left_len = left.len(); - left.intersect_with(&right); - eprintln!("intersect between {:?} and {:?} took {:.02?}", - left_len, right.len(), before.elapsed()); - Some(left) - }, - None => Some(right), - }; - }, - None => result = Some(RoaringBitmap::default()), - } - } - let headers = match main.get::<_, Str, ByteSlice>(&rtxn, "headers")? { Some(headers) => headers, None => return Ok(()), }; + let fst = match main.get::<_, Str, ByteSlice>(&rtxn, "words-fst")? { + Some(bytes) => fst::Set::new(bytes)?, + None => return Ok(()), + }; + + // Building this factory is not free. + let lev_0_builder = LevenshteinAutomatonBuilder::new(0, true); + let lev_1_builder = LevenshteinAutomatonBuilder::new(1, true); + let lev_2_builder = LevenshteinAutomatonBuilder::new(2, true); + + let dfas = alphanumeric_tokens(&opt.query).map(|word| { + let word = word.cow_to_lowercase(); + match word.len() { + 0..=4 => lev_0_builder.build_dfa(&word), + 5..=8 => lev_1_builder.build_dfa(&word), + _ => lev_2_builder.build_dfa(&word), + } + }); + + let before = Instant::now(); + let mut intersect_result: Option<RoaringBitmap> = None; + for dfa in dfas { + let mut union_result = RoaringBitmap::default(); + let mut stream = fst.search(dfa).into_stream(); + while let Some(word) = stream.next() { + let word = std::str::from_utf8(word)?; + if let Some(ids) = postings_ids.get(&rtxn, word)? 
{ + let right = RoaringBitmap::deserialize_from(ids)?; + union_result.union_with(&right); + } + } + + intersect_result = match intersect_result.take() { + Some(mut left) => { + let before = Instant::now(); + let left_len = left.len(); + left.intersect_with(&union_result); + eprintln!("intersect between {:?} and {:?} took {:.02?}", + left_len, union_result.len(), before.elapsed()); + Some(left) + }, + None => Some(union_result), + }; + } + let mut stdout = io::stdout(); stdout.write_all(&headers)?; - let total_length = result.as_ref().map_or(0, |x| x.len()); - for id in result.unwrap_or_default().iter().take(20) { + let total_length = intersect_result.as_ref().map_or(0, |x| x.len()); + for id in intersect_result.unwrap_or_default().iter().take(20) { if let Some(content) = documents.get(&rtxn, &BEU32::new(id))? { stdout.write_all(&content)?; }