sevensense_embedding/infrastructure/
model_manager.rs

1//! Model management for ONNX embedding models.
2//!
3//! Provides thread-safe loading, caching, and hot-swapping of
4//! Perch 2.0 ONNX models for embedding generation.
5
6use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8use std::sync::Arc;
9
10use parking_lot::RwLock;
11use sha2::{Digest, Sha256};
12use thiserror::Error;
13use tracing::{debug, info, instrument, warn};
14
15use super::onnx_inference::OnnxInference;
16use crate::domain::entities::{EmbeddingModel, ModelVersion};
17
/// Errors that can occur during model management.
///
/// Every fallible `ModelManager` method in this file returns this type.
/// The `#[error(...)]` strings are the user-visible `Display` output.
#[derive(Debug, Error)]
pub enum ModelError {
    /// No model file could be located; the payload is a human-readable
    /// description (in this file it lists the paths that were tried).
    #[error("Model not found: {0}")]
    NotFound(String),

    /// An inference session could not be created from a model file; the
    /// payload is the stringified underlying error.
    #[error("Failed to load model: {0}")]
    LoadFailed(String),

    /// The digest computed for a model file differs from the expected one.
    #[error("Checksum mismatch for model {model}: expected {expected}, got {actual}")]
    ChecksumMismatch {
        /// Identifier of the model that failed verification.
        /// NOTE: `ModelManager::verify_checksum` fills this with the file
        /// path it hashed, not with a logical model name.
        model: String,
        /// Checksum the caller expected (hex string)
        expected: String,
        /// Checksum actually computed from the file (hex string)
        actual: String,
    },

    /// Model initialization failed.
    /// Not constructed anywhere in this file — presumably raised by other
    /// modules; confirm before relying on it.
    #[error("Model initialization failed: {0}")]
    InitializationFailed(String),

    /// IO error; created automatically from `std::io::Error` by `?`
    /// thanks to `#[from]`.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// ONNX Runtime error (payload is the runtime's message).
    /// Not constructed anywhere in this file.
    #[error("ONNX Runtime error: {0}")]
    OnnxRuntime(String),

    /// Model not ready.
    /// Not constructed anywhere in this file.
    #[error("Model not ready: {0}")]
    NotReady(String),
}
56
/// Configuration for the model manager.
///
/// All fields are public so callers can use struct-update syntax on top of
/// `ModelConfig::default()`.
#[derive(Debug, Clone)]
pub struct ModelConfig {
    /// Directory containing model files; `ModelManager::new` creates it
    /// when it does not exist, and path resolution searches inside it.
    pub model_dir: PathBuf,

    /// Number of threads for intra-op parallelism
    /// (forwarded unchanged to `OnnxInference::new`).
    pub intra_op_threads: usize,

    /// Number of threads for inter-op parallelism
    /// (forwarded unchanged to `OnnxInference::new`).
    pub inter_op_threads: usize,

    /// Whether to verify model checksums on load. Verification only
    /// happens when registered metadata for the version carries a
    /// non-empty checksum.
    pub verify_checksums: bool,

    /// Execution providers in priority order
    /// (forwarded as a slice to `OnnxInference::new`).
    pub execution_providers: Vec<ExecutionProvider>,

    /// Maximum number of cached sessions; `ModelManager::load_model`
    /// evicts cached sessions once the cache has reached this size.
    pub max_cached_sessions: usize,
}
78
79impl Default for ModelConfig {
80    fn default() -> Self {
81        Self {
82            model_dir: PathBuf::from("models"),
83            intra_op_threads: num_cpus::get().min(4),
84            inter_op_threads: 1,
85            verify_checksums: true,
86            execution_providers: vec![
87                ExecutionProvider::Cuda { device_id: 0 },
88                ExecutionProvider::CoreML,
89                ExecutionProvider::Cpu,
90            ],
91            max_cached_sessions: 4,
92        }
93    }
94}
95
/// Execution provider for ONNX Runtime.
///
/// `ModelConfig::execution_providers` holds these in priority order and
/// the list is handed to `OnnxInference::new` when a session is created.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExecutionProvider {
    /// CPU execution
    Cpu,

    /// NVIDIA CUDA execution
    Cuda {
        /// GPU device ID (the default configuration uses device 0)
        device_id: i32,
    },

    /// Apple CoreML execution
    CoreML,

    /// DirectML execution (Windows)
    DirectML {
        /// Device ID
        device_id: i32,
    },
}
117
/// Thread-safe model session manager with caching and hot-swap support.
///
/// Manages the lifecycle of ONNX models used for embedding generation,
/// including loading, caching, and version management.
///
/// Each piece of mutable state sits behind its own `parking_lot::RwLock`,
/// so all methods take `&self`. Both maps are keyed by the string returned
/// from `ModelVersion::full_version()`.
pub struct ModelManager {
    /// Cached model sessions, keyed by full version string; values are
    /// shared with callers through `Arc`.
    sessions: RwLock<HashMap<String, Arc<OnnxInference>>>,

    /// Model metadata, keyed by full version string. Populated by
    /// `register_model` and `hot_swap`.
    models: RwLock<HashMap<String, EmbeddingModel>>,

    /// Currently active model version; determines which cache entry
    /// `load_model`, `get_inference` and `is_ready` look at.
    active_version: RwLock<ModelVersion>,

    /// Configuration (immutable after construction)
    config: ModelConfig,
}
135
136impl ModelManager {
137    /// Create a new model manager with the given configuration.
138    ///
139    /// # Errors
140    ///
141    /// Returns an error if the model directory doesn't exist and can't be created.
142    pub fn new(config: ModelConfig) -> Result<Self, ModelError> {
143        // Ensure model directory exists
144        if !config.model_dir.exists() {
145            std::fs::create_dir_all(&config.model_dir)?;
146            debug!(path = ?config.model_dir, "Created model directory");
147        }
148
149        Ok(Self {
150            sessions: RwLock::new(HashMap::new()),
151            models: RwLock::new(HashMap::new()),
152            active_version: RwLock::new(ModelVersion::perch_v2_base()),
153            config,
154        })
155    }
156
157    /// Create with default configuration
158    pub fn with_defaults() -> Result<Self, ModelError> {
159        Self::new(ModelConfig::default())
160    }
161
162    /// Load a model from a file.
163    ///
164    /// # Arguments
165    ///
166    /// * `name` - Model name (e.g., "perch-v2")
167    ///
168    /// # Errors
169    ///
170    /// Returns an error if the model file doesn't exist or fails to load.
171    #[instrument(skip(self), fields(model = %name))]
172    pub fn load_model(&self, name: &str) -> Result<Arc<OnnxInference>, ModelError> {
173        let version = self.active_version.read().clone();
174        let version_key = version.full_version();
175
176        // Check cache first
177        {
178            let sessions = self.sessions.read();
179            if let Some(session) = sessions.get(&version_key) {
180                debug!("Using cached session for {}", version_key);
181                return Ok(Arc::clone(session));
182            }
183        }
184
185        // Resolve model path
186        let model_path = self.resolve_model_path(name, &version)?;
187
188        // Verify checksum if configured
189        if self.config.verify_checksums {
190            if let Some(model) = self.models.read().get(&version_key) {
191                if !model.checksum.is_empty() {
192                    self.verify_checksum(&model_path, &model.checksum)?;
193                }
194            }
195        }
196
197        // Create new session
198        info!(path = ?model_path, "Loading model");
199        let session = self.create_session(&model_path)?;
200        let session = Arc::new(session);
201
202        // Cache the session
203        {
204            let mut sessions = self.sessions.write();
205
206            // Evict old sessions if at capacity
207            while sessions.len() >= self.config.max_cached_sessions {
208                if let Some(key) = sessions.keys().next().cloned() {
209                    sessions.remove(&key);
210                    debug!("Evicted cached session: {}", key);
211                }
212            }
213
214            sessions.insert(version_key.clone(), Arc::clone(&session));
215        }
216
217        // Update model metadata
218        {
219            let mut models = self.models.write();
220            if let Some(model) = models.get_mut(&version_key) {
221                model.mark_active();
222            }
223        }
224
225        info!(version = %version_key, "Model loaded successfully");
226        Ok(session)
227    }
228
229    /// Verify the checksum of a model file.
230    ///
231    /// # Errors
232    ///
233    /// Returns an error if the checksum doesn't match.
234    pub fn verify_checksum(&self, path: &Path, expected: &str) -> Result<bool, ModelError> {
235        let actual = self.compute_checksum(path)?;
236
237        if actual != expected {
238            return Err(ModelError::ChecksumMismatch {
239                model: path.display().to_string(),
240                expected: expected.to_string(),
241                actual,
242            });
243        }
244
245        debug!(path = ?path, "Checksum verified");
246        Ok(true)
247    }
248
249    /// Compute the SHA-256 checksum of a file.
250    fn compute_checksum(&self, path: &Path) -> Result<String, ModelError> {
251        let mut file = std::fs::File::open(path)?;
252        let mut hasher = Sha256::new();
253        std::io::copy(&mut file, &mut hasher)?;
254        let hash = hasher.finalize();
255        Ok(hex::encode(hash))
256    }
257
258    /// Hot-swap to a new model version without restart.
259    ///
260    /// # Arguments
261    ///
262    /// * `name` - Model name
263    /// * `new_path` - Path to the new model file
264    ///
265    /// # Errors
266    ///
267    /// Returns an error if the new model fails to load.
268    #[instrument(skip(self, new_path), fields(model = %name, path = ?new_path))]
269    pub fn hot_swap(&self, name: &str, new_path: &Path) -> Result<(), ModelError> {
270        // Validate the new model can be loaded
271        info!("Attempting hot-swap to new model");
272        let new_session = self.create_session(new_path)?;
273
274        // Compute checksum for the new model
275        let checksum = self.compute_checksum(new_path)?;
276
277        // Create new version
278        let old_version = self.active_version.read().clone();
279        let new_version = ModelVersion::new(
280            name,
281            &old_version.version, // Keep same semantic version
282            "hot-swap",
283        );
284        let version_key = new_version.full_version();
285
286        // Update sessions cache
287        {
288            let mut sessions = self.sessions.write();
289            sessions.insert(version_key.clone(), Arc::new(new_session));
290        }
291
292        // Update model metadata
293        {
294            let mut models = self.models.write();
295            let mut model = EmbeddingModel::new(
296                name.to_string(),
297                new_version.clone(),
298                checksum,
299            );
300            model.model_path = Some(new_path.to_string_lossy().to_string());
301            model.mark_active();
302            models.insert(version_key, model);
303        }
304
305        // Update active version
306        *self.active_version.write() = new_version.clone();
307
308        info!(
309            old_version = %old_version,
310            new_version = %new_version,
311            "Hot-swap completed successfully"
312        );
313
314        Ok(())
315    }
316
317    /// Get the ONNX inference engine for the current model.
318    ///
319    /// # Errors
320    ///
321    /// Returns an error if no model is loaded.
322    pub async fn get_inference(&self) -> Result<Arc<OnnxInference>, ModelError> {
323        let version = self.active_version.read().clone();
324        self.load_model(&version.name)
325    }
326
327    /// Get the currently active model version.
328    #[must_use]
329    pub fn current_version(&self) -> ModelVersion {
330        self.active_version.read().clone()
331    }
332
333    /// Set the active model version.
334    pub fn set_active_version(&self, version: ModelVersion) {
335        *self.active_version.write() = version;
336    }
337
338    /// Check if a model is loaded and ready.
339    pub async fn is_ready(&self) -> bool {
340        let version_key = self.active_version.read().full_version();
341        self.sessions.read().contains_key(&version_key)
342    }
343
344    /// Get model metadata for a version.
345    #[must_use]
346    pub fn get_model(&self, version_key: &str) -> Option<EmbeddingModel> {
347        self.models.read().get(version_key).cloned()
348    }
349
350    /// List all loaded models.
351    #[must_use]
352    pub fn list_models(&self) -> Vec<EmbeddingModel> {
353        self.models.read().values().cloned().collect()
354    }
355
356    /// Clear all cached sessions.
357    pub fn clear_cache(&self) {
358        self.sessions.write().clear();
359        info!("Cleared model session cache");
360    }
361
362    /// Resolve the path to a model file.
363    fn resolve_model_path(&self, name: &str, version: &ModelVersion) -> Result<PathBuf, ModelError> {
364        // Try various naming conventions
365        let candidates = vec![
366            self.config.model_dir.join(format!("{}.onnx", version.full_version())),
367            self.config.model_dir.join(format!("{}_{}.onnx", name, version.version)),
368            self.config.model_dir.join(format!("{}.onnx", name)),
369            self.config.model_dir.join(format!("{}/{}.onnx", name, version.version)),
370        ];
371
372        for path in &candidates {
373            if path.exists() {
374                return Ok(path.clone());
375            }
376        }
377
378        // Also check if there's a model metadata entry with a path
379        let version_key = version.full_version();
380        if let Some(model) = self.models.read().get(&version_key) {
381            if let Some(ref path_str) = model.model_path {
382                let path = PathBuf::from(path_str);
383                if path.exists() {
384                    return Ok(path);
385                }
386            }
387        }
388
389        Err(ModelError::NotFound(format!(
390            "Model {} not found in {:?}. Tried: {:?}",
391            name, self.config.model_dir, candidates
392        )))
393    }
394
395    /// Create an ONNX inference session from a model file.
396    fn create_session(&self, path: &Path) -> Result<OnnxInference, ModelError> {
397        OnnxInference::new(
398            path,
399            self.config.intra_op_threads,
400            self.config.inter_op_threads,
401            &self.config.execution_providers,
402        )
403        .map_err(|e| ModelError::LoadFailed(e.to_string()))
404    }
405
406    /// Register a model without loading it.
407    pub fn register_model(&self, model: EmbeddingModel) {
408        let version_key = model.version.full_version();
409        self.models.write().insert(version_key, model);
410    }
411
412    /// Unload a specific model version from cache.
413    pub fn unload_model(&self, version_key: &str) -> bool {
414        let removed = self.sessions.write().remove(version_key).is_some();
415        if removed {
416            info!(version = %version_key, "Unloaded model from cache");
417        }
418        removed
419    }
420}
421
422impl std::fmt::Debug for ModelManager {
423    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
424        f.debug_struct("ModelManager")
425            .field("model_dir", &self.config.model_dir)
426            .field("active_version", &*self.active_version.read())
427            .field("cached_sessions", &self.sessions.read().len())
428            .finish()
429    }
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::tempdir;

    /// Build a manager rooted at `dir`, leaving every other setting at its
    /// default value.
    fn manager_in(dir: &Path) -> Result<ModelManager, ModelError> {
        ModelManager::new(ModelConfig {
            model_dir: dir.to_path_buf(),
            ..Default::default()
        })
    }

    /// The default configuration uses at least one thread and verifies
    /// checksums.
    #[test]
    fn test_model_config_default() {
        let defaults = ModelConfig::default();
        assert!(defaults.intra_op_threads > 0);
        assert!(defaults.verify_checksums);
    }

    /// Constructing a manager over an existing directory succeeds.
    #[test]
    fn test_model_manager_creation() {
        let tmp = tempdir().unwrap();
        assert!(manager_in(tmp.path()).is_ok());
    }

    /// A computed checksum is a 64-character hex string.
    #[test]
    fn test_checksum_computation() {
        let tmp = tempdir().unwrap();
        let payload_path = tmp.path().join("test.bin");

        let mut payload = std::fs::File::create(&payload_path).unwrap();
        payload.write_all(b"test content").unwrap();

        let manager = manager_in(tmp.path()).unwrap();

        let digest = manager.compute_checksum(&payload_path).unwrap();
        assert!(!digest.is_empty());
        assert_eq!(digest.len(), 64); // SHA-256 hex length
    }

    /// The base Perch v2 version renders to the expected cache key.
    #[test]
    fn test_model_version_key() {
        let base = ModelVersion::perch_v2_base();
        assert_eq!(base.full_version(), "perch-v2-2.0.0-base");
    }

    /// Registered metadata can be read back by its version key.
    #[test]
    fn test_register_model() {
        let tmp = tempdir().unwrap();
        let manager = manager_in(tmp.path()).unwrap();

        let model = EmbeddingModel::perch_v2_default();
        let key = model.version.full_version();

        manager.register_model(model);

        assert!(manager.get_model(&key).is_some());
    }

    /// Clearing an empty cache is a no-op.
    #[test]
    fn test_clear_cache() {
        let tmp = tempdir().unwrap();
        let manager = manager_in(tmp.path()).unwrap();

        manager.clear_cache();
        // Should not panic
    }
}