Skip to main content

yscv_recognize/
recognizer.rs

1use std::fs;
2use std::path::Path;
3
4use yscv_tensor::Tensor;
5
6use super::RecognizeError;
7use super::similarity::cosine_similarity_prevalidated;
8use super::snapshot::{IdentitySnapshot, RecognizerSnapshot};
9use super::types::{IdentityEmbedding, Recognition};
10use super::validate::{validate_embedding, validate_embedding_slice, validate_threshold};
11use super::vp_tree::VpTree;
12
/// Gallery of enrolled identity embeddings plus the settings used to match a
/// query embedding against them.
13#[derive(Debug, Clone)]
14pub struct Recognizer {
    /// Minimum similarity score a candidate must reach to be reported as a match
    /// (compared with `>=` in the recognition paths).
15    threshold: f32,
    /// Enrolled identities, in enrollment order.
16    entries: Vec<IdentityEmbedding>,
    /// Embedding length shared by every entry; `None` while the gallery is empty
    /// (set by the first enrollment, reset when the last entry is removed or on `clear`).
17    embedding_dim: Option<usize>,
    /// Optional VP-tree produced by `build_index` and consulted by `search_indexed`;
    /// `None` when no index is available.
18    index: Option<VpTree>,
19}
20
21impl Recognizer {
22    pub fn new(threshold: f32) -> Result<Self, RecognizeError> {
23        validate_threshold(threshold)?;
24        Ok(Self {
25            threshold,
26            entries: Vec::new(),
27            embedding_dim: None,
28            index: None,
29        })
30    }
31
32    pub fn threshold(&self) -> f32 {
33        self.threshold
34    }
35
36    pub fn set_threshold(&mut self, threshold: f32) -> Result<(), RecognizeError> {
37        validate_threshold(threshold)?;
38        self.threshold = threshold;
39        Ok(())
40    }
41
42    pub fn enroll(
43        &mut self,
44        id: impl Into<String>,
45        embedding: Tensor,
46    ) -> Result<(), RecognizeError> {
47        validate_embedding(&embedding)?;
48        let id = id.into();
49        if self.entries.iter().any(|entry| entry.id == id) {
50            return Err(RecognizeError::DuplicateIdentity { id });
51        }
52        self.enforce_dim(embedding.len())?;
53        self.entries.push(IdentityEmbedding { id, embedding });
54        Ok(())
55    }
56
57    pub fn enroll_or_replace(
58        &mut self,
59        id: impl Into<String>,
60        embedding: Tensor,
61    ) -> Result<(), RecognizeError> {
62        validate_embedding(&embedding)?;
63        self.enforce_dim(embedding.len())?;
64        let id = id.into();
65        if let Some(existing) = self.entries.iter_mut().find(|entry| entry.id == id) {
66            existing.embedding = embedding;
67            return Ok(());
68        }
69        self.entries.push(IdentityEmbedding { id, embedding });
70        Ok(())
71    }
72
73    pub fn remove(&mut self, id: &str) -> bool {
74        if let Some(position) = self.entries.iter().position(|entry| entry.id == id) {
75            self.entries.remove(position);
76            if self.entries.is_empty() {
77                self.embedding_dim = None;
78            }
79            true
80        } else {
81            false
82        }
83    }
84
85    pub fn identities(&self) -> &[IdentityEmbedding] {
86        &self.entries
87    }
88
89    pub fn clear(&mut self) {
90        self.entries.clear();
91        self.embedding_dim = None;
92    }
93
94    pub fn recognize(&self, embedding: &Tensor) -> Result<Recognition, RecognizeError> {
95        validate_embedding(embedding)?;
96        self.recognize_prevalidated(embedding.data())
97    }
98
99    pub fn recognize_slice(&self, embedding: &[f32]) -> Result<Recognition, RecognizeError> {
100        validate_embedding_slice(embedding)?;
101        self.recognize_prevalidated(embedding)
102    }
103
104    fn recognize_prevalidated(&self, embedding: &[f32]) -> Result<Recognition, RecognizeError> {
105        if let Some(expected_dim) = self.embedding_dim {
106            if expected_dim != embedding.len() {
107                return Err(RecognizeError::EmbeddingDimMismatch {
108                    expected: expected_dim,
109                    got: embedding.len(),
110                });
111            }
112        } else {
113            return Ok(Recognition {
114                identity: None,
115                score: 0.0,
116            });
117        }
118
119        let mut best_index = None::<usize>;
120        let mut best_score = -1.0f32;
121        for (index, entry) in self.entries.iter().enumerate() {
122            let score = cosine_similarity_prevalidated(embedding, entry.embedding.data())?;
123            if score > best_score {
124                best_score = score;
125                best_index = Some(index);
126            }
127        }
128
129        if best_score >= self.threshold {
130            Ok(Recognition {
131                identity: best_index.map(|index| self.entries[index].id.clone()),
132                score: best_score,
133            })
134        } else {
135            Ok(Recognition {
136                identity: None,
137                score: best_score,
138            })
139        }
140    }
141
142    pub fn to_snapshot(&self) -> RecognizerSnapshot {
143        let mut identities = Vec::with_capacity(self.entries.len());
144        for entry in &self.entries {
145            identities.push(IdentitySnapshot {
146                id: entry.id.clone(),
147                embedding: entry.embedding.data().to_vec(),
148            });
149        }
150
151        RecognizerSnapshot {
152            threshold: self.threshold,
153            identities,
154        }
155    }
156
157    pub fn from_snapshot(snapshot: RecognizerSnapshot) -> Result<Self, RecognizeError> {
158        let mut recognizer = Self::new(snapshot.threshold)?;
159        for entry in snapshot.identities {
160            let embedding = Tensor::from_vec(vec![entry.embedding.len()], entry.embedding)
161                .map_err(|err| RecognizeError::Serialization {
162                    message: err.to_string(),
163                })?;
164            recognizer.enroll(entry.id, embedding)?;
165        }
166        Ok(recognizer)
167    }
168
169    pub fn to_json_pretty(&self) -> Result<String, RecognizeError> {
170        serde_json::to_string_pretty(&self.to_snapshot()).map_err(|err| {
171            RecognizeError::Serialization {
172                message: err.to_string(),
173            }
174        })
175    }
176
177    pub fn from_json(json: &str) -> Result<Self, RecognizeError> {
178        let snapshot: RecognizerSnapshot =
179            serde_json::from_str(json).map_err(|err| RecognizeError::Serialization {
180                message: err.to_string(),
181            })?;
182        Self::from_snapshot(snapshot)
183    }
184
185    pub fn save_json_file(&self, path: impl AsRef<Path>) -> Result<(), RecognizeError> {
186        let json = self.to_json_pretty()?;
187        fs::write(path, json).map_err(|err| RecognizeError::Io {
188            message: err.to_string(),
189        })
190    }
191
192    pub fn load_json_file(path: impl AsRef<Path>) -> Result<Self, RecognizeError> {
193        let json = fs::read_to_string(path).map_err(|err| RecognizeError::Io {
194            message: err.to_string(),
195        })?;
196        Self::from_json(&json)
197    }
198
199    /// Build a VP-tree index from the current gallery for fast nearest-neighbor search.
200    pub fn build_index(&mut self) {
201        let entries: Vec<(String, Vec<f32>)> = self
202            .entries
203            .iter()
204            .map(|e| (e.id.clone(), e.embedding.data().to_vec()))
205            .collect();
206        self.index = Some(VpTree::build(entries));
207    }
208
209    /// Search using the VP-tree index if available, otherwise fall back to linear scan.
210    ///
211    /// Returns the `k` nearest identities that meet the recognition threshold.
212    pub fn search_indexed(
213        &self,
214        embedding: &Tensor,
215        k: usize,
216    ) -> Result<Vec<Recognition>, RecognizeError> {
217        validate_embedding(embedding)?;
218
219        if let Some(expected_dim) = self.embedding_dim {
220            if expected_dim != embedding.len() {
221                return Err(RecognizeError::EmbeddingDimMismatch {
222                    expected: expected_dim,
223                    got: embedding.len(),
224                });
225            }
226        } else {
227            return Ok(Vec::new());
228        }
229
230        if let Some(ref index) = self.index {
231            let results = index.query(embedding.data(), k);
232            Ok(results
233                .into_iter()
234                .filter_map(|r| {
235                    let score = 1.0 - r.distance;
236                    if score >= self.threshold {
237                        Some(Recognition {
238                            identity: Some(r.id),
239                            score,
240                        })
241                    } else {
242                        None
243                    }
244                })
245                .collect())
246        } else {
247            // Fall back to linear scan: collect all scores, sort, take top k.
248            let mut scored: Vec<(usize, f32)> = Vec::with_capacity(self.entries.len());
249            for (i, entry) in self.entries.iter().enumerate() {
250                let score =
251                    cosine_similarity_prevalidated(embedding.data(), entry.embedding.data())?;
252                scored.push((i, score));
253            }
254            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
255            scored.truncate(k);
256            Ok(scored
257                .into_iter()
258                .filter_map(|(i, score)| {
259                    if score >= self.threshold {
260                        Some(Recognition {
261                            identity: Some(self.entries[i].id.clone()),
262                            score,
263                        })
264                    } else {
265                        None
266                    }
267                })
268                .collect())
269        }
270    }
271
272    fn enforce_dim(&mut self, dim: usize) -> Result<(), RecognizeError> {
273        if let Some(expected_dim) = self.embedding_dim {
274            if expected_dim != dim {
275                return Err(RecognizeError::EmbeddingDimMismatch {
276                    expected: expected_dim,
277                    got: dim,
278                });
279            }
280        } else {
281            self.embedding_dim = Some(dim);
282        }
283        Ok(())
284    }
285}