// simstring_rust/python/mod.rs

1use crate::{
2    database::{Database, HashDb},
3    extractors::{CharacterNgrams, FeatureExtractor, WordNgrams},
4    measures::{Cosine, Dice, ExactMatch, Jaccard, Measure, Overlap},
5    search::{SearchError as RustSearchError, Searcher as RustSearcher},
6};
7use pyo3::create_exception;
8use pyo3::prelude::*;
9use std::sync::Arc;
10
11create_exception!(simstring_rust, SearchError, pyo3::exceptions::PyValueError);
12
13// Wrapper for FeatureExtractor trait as I can't find any direct translation.
14#[derive(Clone)]
15enum PyFeatureExtractor {
16    Character(CharacterNgrams),
17    Word(WordNgrams),
18}
19
20impl FeatureExtractor for PyFeatureExtractor {
21    fn features(&self, text: &str, interner: &mut lasso::Rodeo) -> Vec<lasso::Spur> {
22        match self {
23            PyFeatureExtractor::Character(e) => e.features(text, interner),
24            PyFeatureExtractor::Word(e) => e.features(text, interner),
25        }
26    }
27}
28
29#[pyclass(name = "CharacterNgrams")]
30#[derive(Clone)]
31struct PyCharacterNgrams(CharacterNgrams);
32
33#[pymethods]
34impl PyCharacterNgrams {
35    #[new]
36    fn new(n: usize, endmarker: &str) -> Self {
37        Self(CharacterNgrams::new(n, endmarker))
38    }
39}
40
41#[pyclass(name = "WordNgrams")]
42#[derive(Clone)]
43struct PyWordNgrams(WordNgrams);
44
45#[pymethods]
46impl PyWordNgrams {
47    #[new]
48    fn new(n: usize, splitter: &str, padder: &str) -> Self {
49        Self(WordNgrams::new(n, splitter, padder))
50    }
51}
52
53// Wrapper for Measure trait
54#[derive(Clone, Copy)]
55enum PyMeasure {
56    Cosine,
57    Dice,
58    ExactMatch,
59    Jaccard,
60    Overlap,
61}
62
63impl Measure for PyMeasure {
64    fn min_feature_size(&self, query_size: usize, alpha: f64) -> usize {
65        match self {
66            PyMeasure::Cosine => Cosine.min_feature_size(query_size, alpha),
67            PyMeasure::Dice => Dice.min_feature_size(query_size, alpha),
68            PyMeasure::ExactMatch => ExactMatch.min_feature_size(query_size, alpha),
69            PyMeasure::Jaccard => Jaccard.min_feature_size(query_size, alpha),
70            PyMeasure::Overlap => Overlap.min_feature_size(query_size, alpha),
71        }
72    }
73
74    fn max_feature_size(&self, query_size: usize, alpha: f64, db: &dyn Database) -> usize {
75        match self {
76            PyMeasure::Cosine => Cosine.max_feature_size(query_size, alpha, db),
77            PyMeasure::Dice => Dice.max_feature_size(query_size, alpha, db),
78            PyMeasure::ExactMatch => ExactMatch.max_feature_size(query_size, alpha, db),
79            PyMeasure::Jaccard => Jaccard.max_feature_size(query_size, alpha, db),
80            PyMeasure::Overlap => Overlap.max_feature_size(query_size, alpha, db),
81        }
82    }
83
84    fn minimum_common_feature_count(&self, query_size: usize, y_size: usize, alpha: f64) -> usize {
85        match self {
86            PyMeasure::Cosine => Cosine.minimum_common_feature_count(query_size, y_size, alpha),
87            PyMeasure::Dice => Dice.minimum_common_feature_count(query_size, y_size, alpha),
88            PyMeasure::ExactMatch => {
89                ExactMatch.minimum_common_feature_count(query_size, y_size, alpha)
90            }
91            PyMeasure::Jaccard => Jaccard.minimum_common_feature_count(query_size, y_size, alpha),
92            PyMeasure::Overlap => Overlap.minimum_common_feature_count(query_size, y_size, alpha),
93        }
94    }
95
96    fn similarity(&self, x: &[lasso::Spur], y: &[lasso::Spur]) -> f64 {
97        match self {
98            PyMeasure::Cosine => Cosine.similarity(x, y),
99            PyMeasure::Dice => Dice.similarity(x, y),
100            PyMeasure::ExactMatch => ExactMatch.similarity(x, y),
101            PyMeasure::Jaccard => Jaccard.similarity(x, y),
102            PyMeasure::Overlap => Overlap.similarity(x, y),
103        }
104    }
105}
106
107#[pyclass(name = "Cosine")]
108#[derive(Clone, Copy)]
109struct PyCosine;
110#[pymethods]
111impl PyCosine {
112    #[new]
113    fn new() -> Self {
114        PyCosine
115    }
116}
117
118#[pyclass(name = "Dice")]
119#[derive(Clone, Copy)]
120struct PyDice;
121#[pymethods]
122impl PyDice {
123    #[new]
124    fn new() -> Self {
125        PyDice
126    }
127}
128
129#[pyclass(name = "ExactMatch")]
130#[derive(Clone, Copy)]
131struct PyExactMatch;
132#[pymethods]
133impl PyExactMatch {
134    #[new]
135    fn new() -> Self {
136        PyExactMatch
137    }
138}
139
140#[pyclass(name = "Jaccard")]
141#[derive(Clone, Copy)]
142struct PyJaccard;
143#[pymethods]
144impl PyJaccard {
145    #[new]
146    fn new() -> Self {
147        PyJaccard
148    }
149}
150
151#[pyclass(name = "Overlap")]
152#[derive(Clone, Copy)]
153struct PyOverlap;
154#[pymethods]
155impl PyOverlap {
156    #[new]
157    fn new() -> Self {
158        PyOverlap
159    }
160}
161
162#[pyclass(name = "HashDb")]
163struct PyHashDb {
164    db: HashDb,
165}
166
167#[pymethods]
168impl PyHashDb {
169    #[new]
170    fn new(extractor: &Bound<'_, PyAny>) -> PyResult<Self> {
171        let py_feature_extractor =
172            if let Ok(char_ngram) = extractor.extract::<PyRef<PyCharacterNgrams>>() {
173                PyFeatureExtractor::Character(char_ngram.0.clone())
174            } else if let Ok(word_ngram) = extractor.extract::<PyRef<PyWordNgrams>>() {
175                PyFeatureExtractor::Word(word_ngram.0.clone())
176            } else {
177                return Err(pyo3::exceptions::PyTypeError::new_err(
178                    "Extractor must be CharacterNgrams or WordNgrams",
179                ));
180            };
181
182        let db = HashDb::new(Arc::new(py_feature_extractor));
183        Ok(Self { db })
184    }
185
186    fn insert(&mut self, text: String) {
187        self.db.insert(text);
188    }
189
190    fn clear(&mut self) {
191        self.db.clear();
192    }
193
194    fn __len__(&self) -> usize {
195        self.db.strings.len()
196    }
197}
198
199#[pyclass(name = "Searcher")]
200struct PySearcher {
201    db: Py<PyHashDb>,
202    measure: PyMeasure,
203}
204
205#[pymethods]
206impl PySearcher {
207    #[new]
208    fn new(db: Py<PyHashDb>, measure: &Bound<'_, PyAny>) -> PyResult<Self> {
209        let py_measure = if measure.is_instance_of::<PyCosine>() {
210            PyMeasure::Cosine
211        } else if measure.is_instance_of::<PyDice>() {
212            PyMeasure::Dice
213        } else if measure.is_instance_of::<PyExactMatch>() {
214            PyMeasure::ExactMatch
215        } else if measure.is_instance_of::<PyJaccard>() {
216            PyMeasure::Jaccard
217        } else if measure.is_instance_of::<PyOverlap>() {
218            PyMeasure::Overlap
219        } else {
220            return Err(pyo3::exceptions::PyTypeError::new_err(
221                "Measure must be one of Cosine, Dice, Jaccard, Overlap, ExactMatch",
222            ));
223        };
224        Ok(Self {
225            db,
226            measure: py_measure,
227        })
228    }
229
230    fn search<'py>(
231        &self,
232        py: Python<'py>,
233        query_string: &str,
234        alpha: f64,
235    ) -> PyResult<Vec<String>> {
236        let db_borrow = self.db.borrow(py);
237        let searcher = RustSearcher::new(&db_borrow.db, self.measure);
238        searcher.search(query_string, alpha).map_err(|e| match e {
239            RustSearchError::InvalidThreshold(val) => {
240                SearchError::new_err(format!("Invalid threshold: {val}"))
241            }
242        })
243    }
244
245    fn ranked_search<'py>(
246        &self,
247        py: Python<'py>,
248        query_string: &str,
249        alpha: f64,
250    ) -> PyResult<Vec<(String, f64)>> {
251        let db_borrow = self.db.borrow(py);
252        let searcher = RustSearcher::new(&db_borrow.db, self.measure);
253        searcher
254            .ranked_search(query_string, alpha)
255            .map_err(|e| match e {
256                RustSearchError::InvalidThreshold(val) => {
257                    SearchError::new_err(format!("Invalid threshold: {val}"))
258                }
259            })
260    }
261}
262
263#[pymodule]
264fn simstring_rust(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
265    // Database submodule
266    let database_module = PyModule::new(py, "database")?;
267    database_module.add_class::<PyHashDb>()?;
268    m.add_submodule(&database_module)?;
269
270    // Extractors submodule
271    let extractors_module = PyModule::new(py, "extractors")?;
272    extractors_module.add_class::<PyCharacterNgrams>()?;
273    extractors_module.add_class::<PyWordNgrams>()?;
274    m.add_submodule(&extractors_module)?;
275
276    // Measures submodule
277    let measures_module = PyModule::new(py, "measures")?;
278    measures_module.add_class::<PyCosine>()?;
279    measures_module.add_class::<PyDice>()?;
280    measures_module.add_class::<PyJaccard>()?;
281    measures_module.add_class::<PyOverlap>()?;
282    measures_module.add_class::<PyExactMatch>()?;
283    m.add_submodule(&measures_module)?;
284
285    // Searcher submodule
286    let searcher_module = PyModule::new(py, "searcher")?;
287    searcher_module.add_class::<PySearcher>()?;
288    m.add_submodule(&searcher_module)?;
289
290    // errors submodule
291    let errors_module = PyModule::new(py, "errors")?;
292    errors_module.add("SearchError", py.get_type::<SearchError>())?;
293    m.add_submodule(&errors_module)?;
294
295    // Add modules to sys.modules to allow direct import
296    let sys = PyModule::import(py, "sys")?;
297    let modules = sys
298        .getattr("modules")?
299        .downcast_into::<pyo3::types::PyDict>()?;
300    modules.set_item("simstring_rust.database", database_module)?;
301    modules.set_item("simstring_rust.extractors", extractors_module)?;
302    modules.set_item("simstring_rust.measures", measures_module)?;
303    modules.set_item("simstring_rust.searcher", searcher_module)?;
304    modules.set_item("simstring_rust.errors", errors_module)?;
305
306    Ok(())
307}