1use crate::{
2 database::{Database, HashDb},
3 extractors::{CharacterNgrams, FeatureExtractor, WordNgrams},
4 measures::{Cosine, Dice, ExactMatch, Jaccard, Measure, Overlap},
5 search::{SearchError as RustSearchError, Searcher as RustSearcher},
6};
7use pyo3::create_exception;
8use pyo3::prelude::*;
9use std::sync::Arc;
10
// Defines the Python exception type `SearchError`, deriving from Python's
// `ValueError`, with `simstring_rust` given as its module name. The searcher
// bindings below raise it when the Rust search reports an invalid threshold.
create_exception!(simstring_rust, SearchError, pyo3::exceptions::PyValueError);
12
13#[derive(Clone)]
15enum PyFeatureExtractor {
16 Character(CharacterNgrams),
17 Word(WordNgrams),
18}
19
20impl FeatureExtractor for PyFeatureExtractor {
21 fn features(&self, text: &str, interner: &mut lasso::Rodeo) -> Vec<lasso::Spur> {
22 match self {
23 PyFeatureExtractor::Character(e) => e.features(text, interner),
24 PyFeatureExtractor::Word(e) => e.features(text, interner),
25 }
26 }
27}
28
29#[pyclass(name = "CharacterNgrams")]
30#[derive(Clone)]
31struct PyCharacterNgrams(CharacterNgrams);
32
33#[pymethods]
34impl PyCharacterNgrams {
35 #[new]
36 fn new(n: usize, endmarker: &str) -> Self {
37 Self(CharacterNgrams::new(n, endmarker))
38 }
39}
40
41#[pyclass(name = "WordNgrams")]
42#[derive(Clone)]
43struct PyWordNgrams(WordNgrams);
44
45#[pymethods]
46impl PyWordNgrams {
47 #[new]
48 fn new(n: usize, splitter: &str, padder: &str) -> Self {
49 Self(WordNgrams::new(n, splitter, padder))
50 }
51}
52
53#[derive(Clone, Copy)]
55enum PyMeasure {
56 Cosine,
57 Dice,
58 ExactMatch,
59 Jaccard,
60 Overlap,
61}
62
63impl Measure for PyMeasure {
64 fn min_feature_size(&self, query_size: usize, alpha: f64) -> usize {
65 match self {
66 PyMeasure::Cosine => Cosine.min_feature_size(query_size, alpha),
67 PyMeasure::Dice => Dice.min_feature_size(query_size, alpha),
68 PyMeasure::ExactMatch => ExactMatch.min_feature_size(query_size, alpha),
69 PyMeasure::Jaccard => Jaccard.min_feature_size(query_size, alpha),
70 PyMeasure::Overlap => Overlap.min_feature_size(query_size, alpha),
71 }
72 }
73
74 fn max_feature_size(&self, query_size: usize, alpha: f64, db: &dyn Database) -> usize {
75 match self {
76 PyMeasure::Cosine => Cosine.max_feature_size(query_size, alpha, db),
77 PyMeasure::Dice => Dice.max_feature_size(query_size, alpha, db),
78 PyMeasure::ExactMatch => ExactMatch.max_feature_size(query_size, alpha, db),
79 PyMeasure::Jaccard => Jaccard.max_feature_size(query_size, alpha, db),
80 PyMeasure::Overlap => Overlap.max_feature_size(query_size, alpha, db),
81 }
82 }
83
84 fn minimum_common_feature_count(&self, query_size: usize, y_size: usize, alpha: f64) -> usize {
85 match self {
86 PyMeasure::Cosine => Cosine.minimum_common_feature_count(query_size, y_size, alpha),
87 PyMeasure::Dice => Dice.minimum_common_feature_count(query_size, y_size, alpha),
88 PyMeasure::ExactMatch => {
89 ExactMatch.minimum_common_feature_count(query_size, y_size, alpha)
90 }
91 PyMeasure::Jaccard => Jaccard.minimum_common_feature_count(query_size, y_size, alpha),
92 PyMeasure::Overlap => Overlap.minimum_common_feature_count(query_size, y_size, alpha),
93 }
94 }
95
96 fn similarity(&self, x: &[lasso::Spur], y: &[lasso::Spur]) -> f64 {
97 match self {
98 PyMeasure::Cosine => Cosine.similarity(x, y),
99 PyMeasure::Dice => Dice.similarity(x, y),
100 PyMeasure::ExactMatch => ExactMatch.similarity(x, y),
101 PyMeasure::Jaccard => Jaccard.similarity(x, y),
102 PyMeasure::Overlap => Overlap.similarity(x, y),
103 }
104 }
105}
106
107#[pyclass(name = "Cosine")]
108#[derive(Clone, Copy)]
109struct PyCosine;
110#[pymethods]
111impl PyCosine {
112 #[new]
113 fn new() -> Self {
114 PyCosine
115 }
116}
117
118#[pyclass(name = "Dice")]
119#[derive(Clone, Copy)]
120struct PyDice;
121#[pymethods]
122impl PyDice {
123 #[new]
124 fn new() -> Self {
125 PyDice
126 }
127}
128
129#[pyclass(name = "ExactMatch")]
130#[derive(Clone, Copy)]
131struct PyExactMatch;
132#[pymethods]
133impl PyExactMatch {
134 #[new]
135 fn new() -> Self {
136 PyExactMatch
137 }
138}
139
140#[pyclass(name = "Jaccard")]
141#[derive(Clone, Copy)]
142struct PyJaccard;
143#[pymethods]
144impl PyJaccard {
145 #[new]
146 fn new() -> Self {
147 PyJaccard
148 }
149}
150
151#[pyclass(name = "Overlap")]
152#[derive(Clone, Copy)]
153struct PyOverlap;
154#[pymethods]
155impl PyOverlap {
156 #[new]
157 fn new() -> Self {
158 PyOverlap
159 }
160}
161
162#[pyclass(name = "HashDb")]
163struct PyHashDb {
164 db: HashDb,
165}
166
167#[pymethods]
168impl PyHashDb {
169 #[new]
170 fn new(extractor: &Bound<'_, PyAny>) -> PyResult<Self> {
171 let py_feature_extractor =
172 if let Ok(char_ngram) = extractor.extract::<PyRef<PyCharacterNgrams>>() {
173 PyFeatureExtractor::Character(char_ngram.0.clone())
174 } else if let Ok(word_ngram) = extractor.extract::<PyRef<PyWordNgrams>>() {
175 PyFeatureExtractor::Word(word_ngram.0.clone())
176 } else {
177 return Err(pyo3::exceptions::PyTypeError::new_err(
178 "Extractor must be CharacterNgrams or WordNgrams",
179 ));
180 };
181
182 let db = HashDb::new(Arc::new(py_feature_extractor));
183 Ok(Self { db })
184 }
185
186 fn insert(&mut self, text: String) {
187 self.db.insert(text);
188 }
189
190 fn clear(&mut self) {
191 self.db.clear();
192 }
193
194 fn __len__(&self) -> usize {
195 self.db.strings.len()
196 }
197}
198
199#[pyclass(name = "Searcher")]
200struct PySearcher {
201 db: Py<PyHashDb>,
202 measure: PyMeasure,
203}
204
205#[pymethods]
206impl PySearcher {
207 #[new]
208 fn new(db: Py<PyHashDb>, measure: &Bound<'_, PyAny>) -> PyResult<Self> {
209 let py_measure = if measure.is_instance_of::<PyCosine>() {
210 PyMeasure::Cosine
211 } else if measure.is_instance_of::<PyDice>() {
212 PyMeasure::Dice
213 } else if measure.is_instance_of::<PyExactMatch>() {
214 PyMeasure::ExactMatch
215 } else if measure.is_instance_of::<PyJaccard>() {
216 PyMeasure::Jaccard
217 } else if measure.is_instance_of::<PyOverlap>() {
218 PyMeasure::Overlap
219 } else {
220 return Err(pyo3::exceptions::PyTypeError::new_err(
221 "Measure must be one of Cosine, Dice, Jaccard, Overlap, ExactMatch",
222 ));
223 };
224 Ok(Self {
225 db,
226 measure: py_measure,
227 })
228 }
229
230 fn search<'py>(
231 &self,
232 py: Python<'py>,
233 query_string: &str,
234 alpha: f64,
235 ) -> PyResult<Vec<String>> {
236 let db_borrow = self.db.borrow(py);
237 let searcher = RustSearcher::new(&db_borrow.db, self.measure);
238 searcher.search(query_string, alpha).map_err(|e| match e {
239 RustSearchError::InvalidThreshold(val) => {
240 SearchError::new_err(format!("Invalid threshold: {val}"))
241 }
242 })
243 }
244
245 fn ranked_search<'py>(
246 &self,
247 py: Python<'py>,
248 query_string: &str,
249 alpha: f64,
250 ) -> PyResult<Vec<(String, f64)>> {
251 let db_borrow = self.db.borrow(py);
252 let searcher = RustSearcher::new(&db_borrow.db, self.measure);
253 searcher
254 .ranked_search(query_string, alpha)
255 .map_err(|e| match e {
256 RustSearchError::InvalidThreshold(val) => {
257 SearchError::new_err(format!("Invalid threshold: {val}"))
258 }
259 })
260 }
261}
262
263#[pymodule]
264fn simstring_rust(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
265 let database_module = PyModule::new(py, "database")?;
267 database_module.add_class::<PyHashDb>()?;
268 m.add_submodule(&database_module)?;
269
270 let extractors_module = PyModule::new(py, "extractors")?;
272 extractors_module.add_class::<PyCharacterNgrams>()?;
273 extractors_module.add_class::<PyWordNgrams>()?;
274 m.add_submodule(&extractors_module)?;
275
276 let measures_module = PyModule::new(py, "measures")?;
278 measures_module.add_class::<PyCosine>()?;
279 measures_module.add_class::<PyDice>()?;
280 measures_module.add_class::<PyJaccard>()?;
281 measures_module.add_class::<PyOverlap>()?;
282 measures_module.add_class::<PyExactMatch>()?;
283 m.add_submodule(&measures_module)?;
284
285 let searcher_module = PyModule::new(py, "searcher")?;
287 searcher_module.add_class::<PySearcher>()?;
288 m.add_submodule(&searcher_module)?;
289
290 let errors_module = PyModule::new(py, "errors")?;
292 errors_module.add("SearchError", py.get_type::<SearchError>())?;
293 m.add_submodule(&errors_module)?;
294
295 let sys = PyModule::import(py, "sys")?;
297 let modules = sys
298 .getattr("modules")?
299 .downcast_into::<pyo3::types::PyDict>()?;
300 modules.set_item("simstring_rust.database", database_module)?;
301 modules.set_item("simstring_rust.extractors", extractors_module)?;
302 modules.set_item("simstring_rust.measures", measures_module)?;
303 modules.set_item("simstring_rust.searcher", searcher_module)?;
304 modules.set_item("simstring_rust.errors", errors_module)?;
305
306 Ok(())
307}