1use crate::{
11 config::CloningConfig,
12 thread_safety::{CacheStats, ModelCache},
13 Error, Result,
14};
15use serde::{Deserialize, Serialize};
16use std::collections::{HashMap, VecDeque};
17use std::sync::Arc;
18use std::time::{Duration, Instant, SystemTime};
19use tokio::sync::{Mutex, RwLock, Semaphore};
20use tracing::{debug, error, info, trace, warn};
21
/// Central coordinator for model loading: caching, background preloading,
/// usage-pattern analysis, and memory-pressure-aware loading strategies.
pub struct ModelLoadingManager {
    // Shared cache of loaded models keyed by model id.
    cache: Arc<ModelCache<Box<dyn ModelInterface>>>,
    // Background preloader that warms models ahead of demand.
    preloader: Arc<ModelPreloader>,
    // Records accesses and predicts likely next models.
    usage_analyzer: Arc<UsagePatternAnalyzer>,
    // Tracks per-model memory usage and overall pressure level.
    memory_manager: Arc<ModelMemoryManager>,
    // Runtime-tunable configuration (adjusted by `optimize_cache`).
    config: Arc<RwLock<ModelLoadingConfig>>,
    // Aggregate hit/miss/load-time metrics.
    metrics: Arc<RwLock<LoadingMetrics>>,
}
37
/// Tunable parameters controlling model caching, preloading, and
/// strategy selection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelLoadingConfig {
    /// Maximum number of models held in the cache (passed to `ModelCache::new`).
    pub max_cached_models: usize,
    /// Upper bound on cache memory in MB.
    /// NOTE(review): not referenced in this file — confirm enforcement elsewhere.
    pub max_cache_memory_mb: usize,
    /// Target number of models to preload; adjusted by `optimize_cache`.
    pub preload_count: usize,
    /// Interval between preload passes.
    /// NOTE(review): not referenced in this file — confirm the driving loop.
    pub preload_interval: Duration,
    /// Models estimated larger than this (MB) use the memory-mapped strategy.
    pub memory_map_threshold_mb: usize,
    /// When true, `load_model` schedules predictive preloads after each load.
    pub enable_predictive_loading: bool,
    /// Time budget for `warm_up_model` before it errors out.
    pub warming_timeout: Duration,
    /// Allows the compressed loading strategy under high memory pressure.
    pub enable_compression: bool,
}
58
59impl Default for ModelLoadingConfig {
60 fn default() -> Self {
61 Self {
62 max_cached_models: 50,
63 max_cache_memory_mb: 2048, preload_count: 5,
65 preload_interval: Duration::from_secs(300), memory_map_threshold_mb: 100,
67 enable_predictive_loading: true,
68 warming_timeout: Duration::from_secs(30),
69 enable_compression: false,
70 }
71 }
72}
73
/// Contract every loadable model must satisfy so the manager can cache,
/// warm, and account for it across threads.
pub trait ModelInterface: Send + Sync {
    /// In-memory size of the model in bytes.
    fn size_bytes(&self) -> usize;
    /// Identifier describing the model's type/category.
    fn model_type(&self) -> &str;
    /// Runs the model's warm-up routine; returns an error on failure.
    fn warm_up(&mut self) -> Result<()>;
    /// True once the model is ready to serve requests.
    fn is_ready(&self) -> bool;
    /// Version string for the model.
    fn version(&self) -> String;
}
87
/// Bookkeeping record attached to a loaded (or stored) model.
#[derive(Debug, Clone)]
pub struct ModelMetadata {
    // Unique identifier of the model.
    pub model_id: String,
    // Type/category tag (mirrors `ModelInterface::model_type`).
    pub model_type: String,
    // Size of the model in bytes.
    pub size_bytes: usize,
    // Version string of the model.
    pub version: String,
    // When the model was last accessed.
    pub last_accessed: Instant,
    // Total number of times the model has been accessed.
    pub access_count: u64,
    // How long the initial load took.
    pub load_time: Duration,
    // Whether the model is backed by a memory mapping.
    pub memory_mapped: bool,
    // Whether the stored representation is compressed.
    pub compressed: bool,
}
101
/// Queues preload requests and runs them as bounded background tasks.
pub struct ModelPreloader {
    // Pending requests, kept ordered highest-priority-first.
    preload_queue: Arc<RwLock<VecDeque<PreloadRequest>>>,
    // Caps the number of concurrently running preload tasks.
    preload_semaphore: Arc<Semaphore>,
    // Join handles of in-flight preload tasks, keyed by model id.
    active_tasks: Arc<RwLock<HashMap<String, tokio::task::JoinHandle<()>>>>,
}
111
/// A single queued request to preload a model.
#[derive(Debug, Clone)]
pub struct PreloadRequest {
    // Model to preload.
    pub model_id: String,
    // Queue ordering priority (higher runs first).
    pub priority: PreloadPriority,
    // When the request was enqueued.
    pub requested_at: Instant,
    // Estimated model size in bytes (currently a fixed placeholder).
    pub estimated_size: usize,
}
120
/// Priority of a preload request; the derived `Ord` (via the explicit
/// discriminants) drives queue ordering in `schedule_preload`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PreloadPriority {
    Low = 1,
    Medium = 2,
    High = 3,
    Critical = 4,
}
129
/// Learns model co-occurrence from access history to drive predictive
/// preloading.
pub struct UsagePatternAnalyzer {
    // Bounded FIFO of recent accesses (capped at `config.max_history_size`).
    access_history: Arc<RwLock<VecDeque<AccessRecord>>>,
    // model_id -> (related model_id -> normalized co-occurrence score in [0,1]).
    co_occurrence: Arc<RwLock<HashMap<String, HashMap<String, f32>>>>,
    // Per-model temporal access windows.
    // NOTE(review): never written in this file — confirm who populates it.
    temporal_patterns: Arc<RwLock<HashMap<String, Vec<TimeWindow>>>>,
    // Analysis tuning parameters.
    config: AnalysisConfig,
}
141
/// One recorded model access, used to build usage patterns.
#[derive(Debug, Clone)]
pub struct AccessRecord {
    // Which model was accessed.
    pub model_id: String,
    // Wall-clock time of the access.
    pub access_time: SystemTime,
    // Owning session id (currently a placeholder value).
    pub session_id: String,
    // How long the access lasted (currently a placeholder value).
    pub access_duration: Duration,
}
150
/// Probability of access within a recurring weekly time slot.
/// NOTE(review): not constructed anywhere in this file — confirm producer.
#[derive(Debug, Clone)]
pub struct TimeWindow {
    // Hour of day (0-23).
    pub hour: u8,
    // Day of week (presumably 0-6 — confirm convention with producer).
    pub day_of_week: u8,
    // Estimated probability of access in this window.
    pub access_probability: f32,
}
158
/// Tuning knobs for `UsagePatternAnalyzer`.
#[derive(Debug, Clone)]
pub struct AnalysisConfig {
    /// Maximum number of access records retained in history.
    pub max_history_size: usize,
    /// Time span considered when analyzing patterns.
    /// NOTE(review): not referenced in this file — confirm usage.
    pub analysis_window: Duration,
    /// Minimum co-occurrence score considered significant.
    /// NOTE(review): not referenced in this file — confirm usage.
    pub co_occurrence_threshold: f32,
    /// Minimum confidence for a prediction to be acted upon.
    /// NOTE(review): not referenced in this file — confirm usage.
    pub prediction_confidence_threshold: f32,
}
167
168impl Default for AnalysisConfig {
169 fn default() -> Self {
170 Self {
171 max_history_size: 10000,
172 analysis_window: Duration::from_secs(7 * 24 * 60 * 60), co_occurrence_threshold: 0.3,
174 prediction_confidence_threshold: 0.7,
175 }
176 }
177}
178
/// Tracks per-model memory usage and classifies overall memory pressure.
pub struct ModelMemoryManager {
    // model_id -> bytes currently attributed to that model.
    memory_usage: Arc<RwLock<HashMap<String, usize>>>,
    // Models backed by `MmapModel` buffers.
    memory_mapped_models: Arc<RwLock<HashMap<String, Arc<MmapModel>>>>,
    // Models stored in compressed form.
    compressed_cache: Arc<RwLock<HashMap<String, CompressedModel>>>,
    // Current/peak usage and derived pressure level.
    pressure_monitor: Arc<Mutex<MemoryPressureMonitor>>,
}
190
/// Model backed by a raw byte buffer.
///
/// NOTE(review): despite the name, this holds an in-memory `Vec<u8>`;
/// no actual memory mapping is visible here — confirm intended design.
pub struct MmapModel {
    // Raw model bytes.
    pub data: Vec<u8>,
    // Bookkeeping for this model.
    pub metadata: ModelMetadata,
}
196
/// Model stored in compressed form to reduce resident memory.
#[derive(Debug, Clone)]
pub struct CompressedModel {
    // Compressed payload bytes.
    pub compressed_data: Vec<u8>,
    // compressed / original size ratio.
    pub compression_ratio: f32,
    // Uncompressed size in bytes.
    pub original_size: usize,
    // Bookkeeping for this model.
    pub metadata: ModelMetadata,
}
205
/// Mutable state behind `ModelMemoryManager::pressure_monitor`.
#[derive(Debug)]
pub struct MemoryPressureMonitor {
    // Most recent usage sample, in MB.
    pub current_usage_mb: usize,
    // Highest usage observed so far, in MB.
    pub peak_usage_mb: usize,
    // Level derived from `current_usage_mb` in `update_memory_pressure`.
    pub pressure_level: MemoryPressureLevel,
    // When cleanup last ran.
    // NOTE(review): never updated in this file — confirm who writes it.
    pub last_cleanup: Instant,
}
214
215impl Default for MemoryPressureMonitor {
216 fn default() -> Self {
217 Self {
218 current_usage_mb: 0,
219 peak_usage_mb: 0,
220 pressure_level: MemoryPressureLevel::Low,
221 last_cleanup: Instant::now(),
222 }
223 }
224}
225
/// Coarse memory-pressure classification; thresholds are applied in
/// `ModelMemoryManager::update_memory_pressure`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryPressureLevel {
    Low,
    Medium,
    High,
    Critical,
}
234
235impl Default for MemoryPressureLevel {
236 fn default() -> Self {
237 MemoryPressureLevel::Low
238 }
239}
240
/// Aggregate counters describing loading and caching behavior.
#[derive(Debug, Default, Clone)]
pub struct LoadingMetrics {
    // Number of loads served from cache.
    pub cache_hits: u64,
    // Number of loads that missed the cache.
    pub cache_misses: u64,
    // Completed preloads.
    pub preload_successes: u64,
    // Failed preloads.
    pub preload_failures: u64,
    // Fraction of predictive preloads that were later used.
    // NOTE(review): never updated in this file — confirm who writes it.
    pub prediction_accuracy: f32,
    // Running average of full load times.
    pub average_load_time: Duration,
    // Memory efficiency metric.
    // NOTE(review): never updated in this file — confirm semantics.
    pub memory_efficiency: f32,
    // Total models loaded from scratch (cache misses that completed).
    pub total_models_loaded: u64,
    // Models evicted from the cache.
    pub models_evicted: u64,
}
254
255impl ModelLoadingManager {
256 pub fn new(config: ModelLoadingConfig) -> Self {
258 let cache = Arc::new(ModelCache::new(
259 config.max_cached_models,
260 10, ));
262
263 Self {
264 cache,
265 preloader: Arc::new(ModelPreloader::new(5)), usage_analyzer: Arc::new(UsagePatternAnalyzer::new(AnalysisConfig::default())),
267 memory_manager: Arc::new(ModelMemoryManager::new()),
268 config: Arc::new(RwLock::new(config)),
269 metrics: Arc::new(RwLock::new(LoadingMetrics::default())),
270 }
271 }
272
273 pub async fn load_model<T>(&self, model_id: &str) -> Result<Arc<T>>
275 where
276 T: ModelInterface + 'static,
277 {
278 let start_time = Instant::now();
279
280 self.usage_analyzer.record_access(model_id).await;
282
283 if let Some(model) = self.try_load_from_cache(model_id).await? {
285 let mut metrics = self.metrics.write().await;
286 metrics.cache_hits += 1;
287 return Ok(model);
288 }
289
290 let mut metrics = self.metrics.write().await;
292 metrics.cache_misses += 1;
293 drop(metrics);
294
295 let loading_strategy = self.determine_loading_strategy(model_id).await?;
297
298 let model = match loading_strategy {
300 LoadingStrategy::Direct => self.load_direct(model_id).await?,
301 LoadingStrategy::MemoryMapped => self.load_memory_mapped(model_id).await?,
302 LoadingStrategy::Compressed => self.load_compressed(model_id).await?,
303 LoadingStrategy::Streaming => self.load_streaming(model_id).await?,
304 };
305
306 self.warm_up_model(&model).await?;
308
309 let load_time = start_time.elapsed();
311 let mut metrics = self.metrics.write().await;
312 metrics.total_models_loaded += 1;
313 metrics.average_load_time = Duration::from_nanos(
314 (metrics.average_load_time.as_nanos() as u64 + load_time.as_nanos() as u64) / 2,
315 );
316
317 if self.config.read().await.enable_predictive_loading {
319 self.trigger_predictive_preloading(model_id).await?;
320 }
321
322 Ok(model)
323 }
324
325 async fn try_load_from_cache<T>(&self, model_id: &str) -> Result<Option<Arc<T>>>
327 where
328 T: ModelInterface + 'static,
329 {
330 Ok(None)
333 }
334
335 async fn determine_loading_strategy(&self, model_id: &str) -> Result<LoadingStrategy> {
337 let config = self.config.read().await;
338 let memory_pressure = self.memory_manager.get_memory_pressure().await;
339
340 let estimated_size_mb = self.estimate_model_size(model_id).await?;
342
343 match (estimated_size_mb, memory_pressure) {
344 (size, MemoryPressureLevel::Critical) if size > 50 => Ok(LoadingStrategy::Streaming),
345 (size, _) if size > config.memory_map_threshold_mb => Ok(LoadingStrategy::MemoryMapped),
346 (_, MemoryPressureLevel::High) if config.enable_compression => {
347 Ok(LoadingStrategy::Compressed)
348 }
349 _ => Ok(LoadingStrategy::Direct),
350 }
351 }
352
353 async fn estimate_model_size(&self, model_id: &str) -> Result<usize> {
355 match model_id {
357 id if id.contains("large") => Ok(500), id if id.contains("base") => Ok(100), id if id.contains("small") => Ok(25), _ => Ok(100), }
362 }
363
364 async fn load_direct<T>(&self, model_id: &str) -> Result<Arc<T>>
366 where
367 T: ModelInterface + 'static,
368 {
369 Err(Error::Validation(
371 "Direct loading not implemented".to_string(),
372 ))
373 }
374
375 async fn load_memory_mapped<T>(&self, model_id: &str) -> Result<Arc<T>>
377 where
378 T: ModelInterface + 'static,
379 {
380 Err(Error::Validation(
382 "Memory-mapped loading not implemented".to_string(),
383 ))
384 }
385
386 async fn load_compressed<T>(&self, model_id: &str) -> Result<Arc<T>>
388 where
389 T: ModelInterface + 'static,
390 {
391 Err(Error::Validation(
393 "Compressed loading not implemented".to_string(),
394 ))
395 }
396
397 async fn load_streaming<T>(&self, model_id: &str) -> Result<Arc<T>>
399 where
400 T: ModelInterface + 'static,
401 {
402 Err(Error::Validation(
404 "Streaming loading not implemented".to_string(),
405 ))
406 }
407
408 async fn warm_up_model<T>(&self, model: &Arc<T>) -> Result<()>
410 where
411 T: ModelInterface + 'static,
412 {
413 let config = self.config.read().await;
414 let timeout = config.warming_timeout;
415
416 match tokio::time::timeout(timeout, async {
418 Ok(())
420 })
421 .await
422 {
423 Ok(result) => result,
424 Err(_) => {
425 warn!("Model warm-up timed out after {:?}", timeout);
426 Err(Error::Validation("Model warm-up timeout".to_string()))
427 }
428 }
429 }
430
431 async fn trigger_predictive_preloading(&self, current_model: &str) -> Result<()> {
433 let predictions = self
434 .usage_analyzer
435 .predict_next_models(current_model, 3)
436 .await?;
437
438 for (model_id, confidence) in predictions {
439 if confidence > 0.7 {
440 self.preloader
441 .schedule_preload(model_id, PreloadPriority::High)
442 .await?;
443 }
444 }
445
446 Ok(())
447 }
448
449 pub async fn get_metrics(&self) -> LoadingMetrics {
451 self.metrics.read().await.clone()
452 }
453
454 pub async fn optimize_cache(&self) -> Result<()> {
456 let patterns = self.usage_analyzer.analyze_patterns().await?;
458
459 let mut config = self.config.write().await;
461
462 let metrics = self.metrics.read().await;
464 let hit_rate =
465 metrics.cache_hits as f32 / (metrics.cache_hits + metrics.cache_misses) as f32;
466
467 if hit_rate < 0.8 {
468 config.preload_count = (config.preload_count as f32 * 1.2).min(20.0) as usize;
469 } else if hit_rate > 0.95 {
470 config.preload_count = (config.preload_count as f32 * 0.9).max(3.0) as usize;
471 }
472
473 info!(
474 "Cache optimization completed. Hit rate: {:.2}%, Preload count: {}",
475 hit_rate * 100.0,
476 config.preload_count
477 );
478
479 Ok(())
480 }
481}
482
/// How a model is brought into memory (chosen by
/// `ModelLoadingManager::determine_loading_strategy`).
#[derive(Debug, Clone, Copy)]
pub enum LoadingStrategy {
    // Load fully into memory.
    Direct,
    // Back the model with a memory mapping.
    MemoryMapped,
    // Keep a compressed representation.
    Compressed,
    // Stream the model on demand.
    Streaming,
}
495
496impl ModelPreloader {
497 pub fn new(max_concurrent: usize) -> Self {
499 Self {
500 preload_queue: Arc::new(RwLock::new(VecDeque::new())),
501 preload_semaphore: Arc::new(Semaphore::new(max_concurrent)),
502 active_tasks: Arc::new(RwLock::new(HashMap::new())),
503 }
504 }
505
506 pub async fn schedule_preload(
508 &self,
509 model_id: String,
510 priority: PreloadPriority,
511 ) -> Result<()> {
512 let request = PreloadRequest {
513 model_id: model_id.clone(),
514 priority,
515 requested_at: Instant::now(),
516 estimated_size: 100 * 1024 * 1024, };
518
519 let mut queue = self.preload_queue.write().await;
520
521 let insert_pos = queue
523 .iter()
524 .position(|req| req.priority < priority)
525 .unwrap_or(queue.len());
526
527 queue.insert(insert_pos, request);
528
529 debug!(
530 "Scheduled preload for model {} with priority {:?}",
531 model_id, priority
532 );
533 Ok(())
534 }
535
536 pub async fn process_queue(&self) -> Result<()> {
538 let request = {
539 let mut queue = self.preload_queue.write().await;
540 queue.pop_front()
541 };
542
543 if let Some(req) = request {
544 let _permit =
545 self.preload_semaphore.acquire().await.map_err(|_| {
546 Error::Validation("Failed to acquire preload permit".to_string())
547 })?;
548
549 let model_id = req.model_id.clone();
551 let handle = tokio::spawn(async move {
552 tokio::time::sleep(Duration::from_millis(100)).await; trace!("Completed preload for model {}", model_id);
555 });
556
557 let mut active_tasks = self.active_tasks.write().await;
559 active_tasks.insert(req.model_id, handle);
560 }
561
562 Ok(())
563 }
564}
565
566impl UsagePatternAnalyzer {
567 pub fn new(config: AnalysisConfig) -> Self {
569 Self {
570 access_history: Arc::new(RwLock::new(VecDeque::new())),
571 co_occurrence: Arc::new(RwLock::new(HashMap::new())),
572 temporal_patterns: Arc::new(RwLock::new(HashMap::new())),
573 config,
574 }
575 }
576
577 pub async fn record_access(&self, model_id: &str) {
579 let record = AccessRecord {
580 model_id: model_id.to_string(),
581 access_time: SystemTime::now(),
582 session_id: "session_placeholder".to_string(), access_duration: Duration::from_millis(100), };
585
586 let mut history = self.access_history.write().await;
587 history.push_back(record);
588
589 while history.len() > self.config.max_history_size {
591 history.pop_front();
592 }
593 }
594
595 pub async fn predict_next_models(
597 &self,
598 current_model: &str,
599 count: usize,
600 ) -> Result<Vec<(String, f32)>> {
601 let co_occurrence = self.co_occurrence.read().await;
602
603 if let Some(related_models) = co_occurrence.get(current_model) {
604 let mut predictions: Vec<_> = related_models
605 .iter()
606 .map(|(model, score)| (model.clone(), *score))
607 .collect();
608
609 predictions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
610 predictions.truncate(count);
611
612 Ok(predictions)
613 } else {
614 Ok(Vec::new())
615 }
616 }
617
618 pub async fn analyze_patterns(&self) -> Result<PatternAnalysisResult> {
620 let history = self.access_history.read().await;
621
622 let mut co_occurrence_map = HashMap::new();
624 let window_size = 5; for window in history.iter().collect::<Vec<_>>().windows(window_size) {
627 for (i, record) in window.iter().enumerate() {
628 for other in window.iter().skip(i + 1) {
629 let entry = co_occurrence_map
630 .entry(record.model_id.clone())
631 .or_insert_with(HashMap::new);
632
633 *entry.entry(other.model_id.clone()).or_insert(0.0) += 1.0;
634 }
635 }
636 }
637
638 for (_, related) in co_occurrence_map.iter_mut() {
640 let max_score = related.values().fold(0.0f32, |a, b| a.max(*b));
641 if max_score > 0.0 {
642 for score in related.values_mut() {
643 *score /= max_score;
644 }
645 }
646 }
647
648 let mut co_occurrence = self.co_occurrence.write().await;
650 *co_occurrence = co_occurrence_map;
651
652 Ok(PatternAnalysisResult {
653 total_accesses: history.len(),
654 unique_models: co_occurrence.len(),
655 average_co_occurrence_strength: 0.5, })
657 }
658}
659
/// Summary returned by `UsagePatternAnalyzer::analyze_patterns`.
#[derive(Debug, Clone)]
pub struct PatternAnalysisResult {
    // Number of access records analyzed.
    pub total_accesses: usize,
    // Number of distinct models appearing in the co-occurrence matrix.
    pub unique_models: usize,
    // Mean normalized co-occurrence score.
    pub average_co_occurrence_strength: f32,
}
667
668impl Default for ModelMemoryManager {
669 fn default() -> Self {
670 Self::new()
671 }
672}
673
674impl ModelMemoryManager {
675 pub fn new() -> Self {
677 Self {
678 memory_usage: Arc::new(RwLock::new(HashMap::new())),
679 memory_mapped_models: Arc::new(RwLock::new(HashMap::new())),
680 compressed_cache: Arc::new(RwLock::new(HashMap::new())),
681 pressure_monitor: Arc::new(Mutex::new(MemoryPressureMonitor::default())),
682 }
683 }
684
685 pub async fn get_memory_pressure(&self) -> MemoryPressureLevel {
687 let monitor = self.pressure_monitor.lock().await;
688 monitor.pressure_level
689 }
690
691 pub async fn update_memory_pressure(&self, current_usage_mb: usize) {
693 let mut monitor = self.pressure_monitor.lock().await;
694 monitor.current_usage_mb = current_usage_mb;
695 monitor.peak_usage_mb = monitor.peak_usage_mb.max(current_usage_mb);
696
697 monitor.pressure_level = match current_usage_mb {
698 usage if usage < 512 => MemoryPressureLevel::Low,
699 usage if usage < 1024 => MemoryPressureLevel::Medium,
700 usage if usage < 1536 => MemoryPressureLevel::High,
701 _ => MemoryPressureLevel::Critical,
702 };
703 }
704}
705
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built manager reports zeroed metrics.
    #[tokio::test]
    async fn test_model_loading_manager_creation() {
        let manager = ModelLoadingManager::new(ModelLoadingConfig::default());
        let metrics = manager.get_metrics().await;
        assert_eq!(metrics.cache_hits, 0);
        assert_eq!(metrics.cache_misses, 0);
    }

    /// Scheduling a preload request succeeds.
    #[tokio::test]
    async fn test_preloader_scheduling() {
        let preloader = ModelPreloader::new(5);
        assert!(preloader
            .schedule_preload("test_model".to_string(), PreloadPriority::High)
            .await
            .is_ok());
    }

    /// Recorded accesses are reflected in the analysis summary.
    #[tokio::test]
    async fn test_usage_pattern_analyzer() {
        let analyzer = UsagePatternAnalyzer::new(AnalysisConfig::default());
        for id in ["model_a", "model_b", "model_a"] {
            analyzer.record_access(id).await;
        }

        let patterns = analyzer
            .analyze_patterns()
            .await
            .expect("pattern analysis should succeed");
        assert_eq!(patterns.total_accesses, 3);
    }

    /// Pressure moves from Low to High as reported usage grows.
    #[tokio::test]
    async fn test_memory_pressure_monitoring() {
        let memory_manager = ModelMemoryManager::new();
        assert_eq!(
            memory_manager.get_memory_pressure().await,
            MemoryPressureLevel::Low
        );

        memory_manager.update_memory_pressure(1200).await;
        assert_eq!(
            memory_manager.get_memory_pressure().await,
            MemoryPressureLevel::High
        );
    }

    /// Sanity-checks the documented configuration defaults.
    #[tokio::test]
    async fn test_loading_config_defaults() {
        let config = ModelLoadingConfig::default();
        assert_eq!(config.max_cached_models, 50);
        assert_eq!(config.max_cache_memory_mb, 2048);
        assert_eq!(config.preload_count, 5);
        assert!(config.enable_predictive_loading);
    }
}