sevensense_embedding/infrastructure/
model_manager.rs1use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8use std::sync::Arc;
9
10use parking_lot::RwLock;
11use sha2::{Digest, Sha256};
12use thiserror::Error;
13use tracing::{debug, info, instrument, warn};
14
15use super::onnx_inference::OnnxInference;
16use crate::domain::entities::{EmbeddingModel, ModelVersion};
17
18#[derive(Debug, Error)]
20pub enum ModelError {
21 #[error("Model not found: {0}")]
23 NotFound(String),
24
25 #[error("Failed to load model: {0}")]
27 LoadFailed(String),
28
29 #[error("Checksum mismatch for model {model}: expected {expected}, got {actual}")]
31 ChecksumMismatch {
32 model: String,
34 expected: String,
36 actual: String,
38 },
39
40 #[error("Model initialization failed: {0}")]
42 InitializationFailed(String),
43
44 #[error("IO error: {0}")]
46 Io(#[from] std::io::Error),
47
48 #[error("ONNX Runtime error: {0}")]
50 OnnxRuntime(String),
51
52 #[error("Model not ready: {0}")]
54 NotReady(String),
55}
56
57#[derive(Debug, Clone)]
59pub struct ModelConfig {
60 pub model_dir: PathBuf,
62
63 pub intra_op_threads: usize,
65
66 pub inter_op_threads: usize,
68
69 pub verify_checksums: bool,
71
72 pub execution_providers: Vec<ExecutionProvider>,
74
75 pub max_cached_sessions: usize,
77}
78
79impl Default for ModelConfig {
80 fn default() -> Self {
81 Self {
82 model_dir: PathBuf::from("models"),
83 intra_op_threads: num_cpus::get().min(4),
84 inter_op_threads: 1,
85 verify_checksums: true,
86 execution_providers: vec![
87 ExecutionProvider::Cuda { device_id: 0 },
88 ExecutionProvider::CoreML,
89 ExecutionProvider::Cpu,
90 ],
91 max_cached_sessions: 4,
92 }
93 }
94}
95
/// Hardware backend an inference session may run on.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExecutionProvider {
    /// Plain CPU execution.
    Cpu,

    /// NVIDIA CUDA on the given device index.
    Cuda {
        /// CUDA device ordinal.
        device_id: i32,
    },

    /// Apple CoreML.
    CoreML,

    /// DirectML on the given device index.
    DirectML {
        /// DirectML adapter index.
        device_id: i32,
    },
}
117
118pub struct ModelManager {
123 sessions: RwLock<HashMap<String, Arc<OnnxInference>>>,
125
126 models: RwLock<HashMap<String, EmbeddingModel>>,
128
129 active_version: RwLock<ModelVersion>,
131
132 config: ModelConfig,
134}
135
136impl ModelManager {
137 pub fn new(config: ModelConfig) -> Result<Self, ModelError> {
143 if !config.model_dir.exists() {
145 std::fs::create_dir_all(&config.model_dir)?;
146 debug!(path = ?config.model_dir, "Created model directory");
147 }
148
149 Ok(Self {
150 sessions: RwLock::new(HashMap::new()),
151 models: RwLock::new(HashMap::new()),
152 active_version: RwLock::new(ModelVersion::perch_v2_base()),
153 config,
154 })
155 }
156
157 pub fn with_defaults() -> Result<Self, ModelError> {
159 Self::new(ModelConfig::default())
160 }
161
162 #[instrument(skip(self), fields(model = %name))]
172 pub fn load_model(&self, name: &str) -> Result<Arc<OnnxInference>, ModelError> {
173 let version = self.active_version.read().clone();
174 let version_key = version.full_version();
175
176 {
178 let sessions = self.sessions.read();
179 if let Some(session) = sessions.get(&version_key) {
180 debug!("Using cached session for {}", version_key);
181 return Ok(Arc::clone(session));
182 }
183 }
184
185 let model_path = self.resolve_model_path(name, &version)?;
187
188 if self.config.verify_checksums {
190 if let Some(model) = self.models.read().get(&version_key) {
191 if !model.checksum.is_empty() {
192 self.verify_checksum(&model_path, &model.checksum)?;
193 }
194 }
195 }
196
197 info!(path = ?model_path, "Loading model");
199 let session = self.create_session(&model_path)?;
200 let session = Arc::new(session);
201
202 {
204 let mut sessions = self.sessions.write();
205
206 while sessions.len() >= self.config.max_cached_sessions {
208 if let Some(key) = sessions.keys().next().cloned() {
209 sessions.remove(&key);
210 debug!("Evicted cached session: {}", key);
211 }
212 }
213
214 sessions.insert(version_key.clone(), Arc::clone(&session));
215 }
216
217 {
219 let mut models = self.models.write();
220 if let Some(model) = models.get_mut(&version_key) {
221 model.mark_active();
222 }
223 }
224
225 info!(version = %version_key, "Model loaded successfully");
226 Ok(session)
227 }
228
229 pub fn verify_checksum(&self, path: &Path, expected: &str) -> Result<bool, ModelError> {
235 let actual = self.compute_checksum(path)?;
236
237 if actual != expected {
238 return Err(ModelError::ChecksumMismatch {
239 model: path.display().to_string(),
240 expected: expected.to_string(),
241 actual,
242 });
243 }
244
245 debug!(path = ?path, "Checksum verified");
246 Ok(true)
247 }
248
249 fn compute_checksum(&self, path: &Path) -> Result<String, ModelError> {
251 let mut file = std::fs::File::open(path)?;
252 let mut hasher = Sha256::new();
253 std::io::copy(&mut file, &mut hasher)?;
254 let hash = hasher.finalize();
255 Ok(hex::encode(hash))
256 }
257
258 #[instrument(skip(self, new_path), fields(model = %name, path = ?new_path))]
269 pub fn hot_swap(&self, name: &str, new_path: &Path) -> Result<(), ModelError> {
270 info!("Attempting hot-swap to new model");
272 let new_session = self.create_session(new_path)?;
273
274 let checksum = self.compute_checksum(new_path)?;
276
277 let old_version = self.active_version.read().clone();
279 let new_version = ModelVersion::new(
280 name,
281 &old_version.version, "hot-swap",
283 );
284 let version_key = new_version.full_version();
285
286 {
288 let mut sessions = self.sessions.write();
289 sessions.insert(version_key.clone(), Arc::new(new_session));
290 }
291
292 {
294 let mut models = self.models.write();
295 let mut model = EmbeddingModel::new(
296 name.to_string(),
297 new_version.clone(),
298 checksum,
299 );
300 model.model_path = Some(new_path.to_string_lossy().to_string());
301 model.mark_active();
302 models.insert(version_key, model);
303 }
304
305 *self.active_version.write() = new_version.clone();
307
308 info!(
309 old_version = %old_version,
310 new_version = %new_version,
311 "Hot-swap completed successfully"
312 );
313
314 Ok(())
315 }
316
317 pub async fn get_inference(&self) -> Result<Arc<OnnxInference>, ModelError> {
323 let version = self.active_version.read().clone();
324 self.load_model(&version.name)
325 }
326
327 #[must_use]
329 pub fn current_version(&self) -> ModelVersion {
330 self.active_version.read().clone()
331 }
332
333 pub fn set_active_version(&self, version: ModelVersion) {
335 *self.active_version.write() = version;
336 }
337
338 pub async fn is_ready(&self) -> bool {
340 let version_key = self.active_version.read().full_version();
341 self.sessions.read().contains_key(&version_key)
342 }
343
344 #[must_use]
346 pub fn get_model(&self, version_key: &str) -> Option<EmbeddingModel> {
347 self.models.read().get(version_key).cloned()
348 }
349
350 #[must_use]
352 pub fn list_models(&self) -> Vec<EmbeddingModel> {
353 self.models.read().values().cloned().collect()
354 }
355
356 pub fn clear_cache(&self) {
358 self.sessions.write().clear();
359 info!("Cleared model session cache");
360 }
361
362 fn resolve_model_path(&self, name: &str, version: &ModelVersion) -> Result<PathBuf, ModelError> {
364 let candidates = vec![
366 self.config.model_dir.join(format!("{}.onnx", version.full_version())),
367 self.config.model_dir.join(format!("{}_{}.onnx", name, version.version)),
368 self.config.model_dir.join(format!("{}.onnx", name)),
369 self.config.model_dir.join(format!("{}/{}.onnx", name, version.version)),
370 ];
371
372 for path in &candidates {
373 if path.exists() {
374 return Ok(path.clone());
375 }
376 }
377
378 let version_key = version.full_version();
380 if let Some(model) = self.models.read().get(&version_key) {
381 if let Some(ref path_str) = model.model_path {
382 let path = PathBuf::from(path_str);
383 if path.exists() {
384 return Ok(path);
385 }
386 }
387 }
388
389 Err(ModelError::NotFound(format!(
390 "Model {} not found in {:?}. Tried: {:?}",
391 name, self.config.model_dir, candidates
392 )))
393 }
394
395 fn create_session(&self, path: &Path) -> Result<OnnxInference, ModelError> {
397 OnnxInference::new(
398 path,
399 self.config.intra_op_threads,
400 self.config.inter_op_threads,
401 &self.config.execution_providers,
402 )
403 .map_err(|e| ModelError::LoadFailed(e.to_string()))
404 }
405
406 pub fn register_model(&self, model: EmbeddingModel) {
408 let version_key = model.version.full_version();
409 self.models.write().insert(version_key, model);
410 }
411
412 pub fn unload_model(&self, version_key: &str) -> bool {
414 let removed = self.sessions.write().remove(version_key).is_some();
415 if removed {
416 info!(version = %version_key, "Unloaded model from cache");
417 }
418 removed
419 }
420}
421
422impl std::fmt::Debug for ModelManager {
423 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
424 f.debug_struct("ModelManager")
425 .field("model_dir", &self.config.model_dir)
426 .field("active_version", &*self.active_version.read())
427 .field("cached_sessions", &self.sessions.read().len())
428 .finish()
429 }
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::{tempdir, TempDir};

    /// SHA-256 of the empty byte string, as lowercase hex.
    const EMPTY_SHA256: &str = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";

    /// Creates a manager rooted in a fresh temporary directory.
    ///
    /// The `TempDir` is returned so the directory stays alive for the whole test.
    fn manager_in_tempdir() -> (TempDir, ModelManager) {
        let dir = tempdir().unwrap();
        let config = ModelConfig {
            model_dir: dir.path().to_path_buf(),
            ..Default::default()
        };
        let manager = ModelManager::new(config).unwrap();
        (dir, manager)
    }

    /// Writes `contents` to `dir/name` and returns the file's path.
    fn write_file(dir: &Path, name: &str, contents: &[u8]) -> PathBuf {
        let path = dir.join(name);
        let mut file = std::fs::File::create(&path).unwrap();
        file.write_all(contents).unwrap();
        path
    }

    #[test]
    fn test_model_config_default() {
        let config = ModelConfig::default();
        assert!(config.intra_op_threads > 0);
        assert!(config.intra_op_threads <= 4);
        assert!(config.verify_checksums);
        assert_eq!(config.execution_providers.last(), Some(&ExecutionProvider::Cpu));
    }

    #[test]
    fn test_model_manager_creation() {
        let (_dir, manager) = manager_in_tempdir();
        assert_eq!(
            manager.current_version().full_version(),
            ModelVersion::perch_v2_base().full_version()
        );
    }

    #[test]
    fn test_model_manager_creates_missing_dir() {
        let dir = tempdir().unwrap();
        let model_dir = dir.path().join("nested").join("models");
        let config = ModelConfig {
            model_dir: model_dir.clone(),
            ..Default::default()
        };
        assert!(ModelManager::new(config).is_ok());
        assert!(model_dir.is_dir());
    }

    #[test]
    fn test_checksum_computation() {
        let (dir, manager) = manager_in_tempdir();
        let file_path = write_file(dir.path(), "test.bin", b"test content");

        let checksum = manager.compute_checksum(&file_path).unwrap();
        assert_eq!(checksum.len(), 64); // SHA-256 = 32 bytes = 64 hex chars
        assert!(checksum
            .chars()
            .all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase()));
        assert_eq!(checksum, manager.compute_checksum(&file_path).unwrap());

        let empty = write_file(dir.path(), "empty.bin", b"");
        assert_eq!(manager.compute_checksum(&empty).unwrap(), EMPTY_SHA256);
    }

    #[test]
    fn test_verify_checksum_roundtrip_and_errors() {
        let (dir, manager) = manager_in_tempdir();
        let path = write_file(dir.path(), "model.onnx", b"");

        assert!(manager.verify_checksum(&path, EMPTY_SHA256).unwrap());

        let wrong = "0".repeat(64);
        match manager.verify_checksum(&path, &wrong) {
            Err(ModelError::ChecksumMismatch { expected, actual, .. }) => {
                assert_eq!(expected, wrong);
                assert_eq!(actual, EMPTY_SHA256);
            }
            other => panic!("expected ChecksumMismatch, got {:?}", other),
        }

        let missing = dir.path().join("missing.onnx");
        assert!(matches!(
            manager.verify_checksum(&missing, EMPTY_SHA256),
            Err(ModelError::Io(_))
        ));
    }

    #[test]
    fn test_load_model_missing_file_is_not_found() {
        let (_dir, manager) = manager_in_tempdir();
        assert!(matches!(
            manager.load_model("does-not-exist"),
            Err(ModelError::NotFound(_))
        ));
    }

    #[test]
    fn test_model_version_key() {
        let version = ModelVersion::perch_v2_base();
        assert_eq!(version.full_version(), "perch-v2-2.0.0-base");
    }

    #[test]
    fn test_register_model() {
        let (_dir, manager) = manager_in_tempdir();

        let model = EmbeddingModel::perch_v2_default();
        let version_key = model.version.full_version();

        manager.register_model(model);

        assert!(manager.get_model(&version_key).is_some());
        assert!(manager.get_model("no-such-version").is_none());
        assert_eq!(manager.list_models().len(), 1);
    }

    #[test]
    fn test_clear_cache() {
        let (_dir, manager) = manager_in_tempdir();

        manager.clear_cache();

        assert!(format!("{:?}", manager).contains("cached_sessions: 0"));
        assert!(!manager.unload_model(&manager.current_version().full_version()));
    }
}