1use crate::config::{Model, ModelCapability, SandboxProfile, Slm};
33use std::collections::HashMap;
34use thiserror::Error;
35
/// Errors reported while building or querying a `ModelCatalog`.
#[derive(Debug, Error)]
pub enum ModelCatalogError {
    /// The requested model ID is not present in the catalog.
    #[error("Model not found: {id}")]
    ModelNotFound { id: String },

    /// The configuration is inconsistent (duplicate model ID, or an agent
    /// mapping that references an unknown model). `validate_model_access`
    /// also returns this variant when an agent may not use a model.
    #[error("Invalid model configuration: {reason}")]
    InvalidConfig { reason: String },

    /// The named sandbox profile is not defined in the configuration.
    #[error("Sandbox profile not found: {profile}")]
    SandboxProfileNotFound { profile: String },

    /// The given agent has no models associated with it.
    // NOTE(review): not constructed anywhere in this file — confirm it is used elsewhere.
    #[error("Agent has no associated models: {agent_id}")]
    NoModelsForAgent { agent_id: String },
}
51
/// In-memory catalog of the models and sandbox profiles taken from an `Slm` configuration.
#[derive(Debug, Clone)]
pub struct ModelCatalog {
    /// Every known model, keyed by model ID.
    models: HashMap<String, Model>,
    /// Per-agent allow-lists: agent ID -> IDs of the models that agent may use.
    agent_model_maps: HashMap<String, Vec<String>>,
    /// Sandbox profiles, keyed by profile name.
    sandbox_profiles: HashMap<String, SandboxProfile>,
    /// Name of the profile returned by `get_default_sandbox_profile`.
    default_sandbox_profile: String,
    /// Copied from the configuration's `allow_runtime_overrides` flag.
    allow_runtime_overrides: bool,
}
70
71impl ModelCatalog {
72 pub fn new(slm_config: Slm) -> Result<Self, ModelCatalogError> {
79 if !slm_config
81 .sandbox_profiles
82 .contains_key(&slm_config.default_sandbox_profile)
83 {
84 return Err(ModelCatalogError::SandboxProfileNotFound {
85 profile: slm_config.default_sandbox_profile,
86 });
87 }
88
89 let mut models = HashMap::new();
91 for model in slm_config.model_allow_lists.global_models {
92 if models.insert(model.id.clone(), model.clone()).is_some() {
93 return Err(ModelCatalogError::InvalidConfig {
94 reason: format!("Duplicate model ID: {}", model.id),
95 });
96 }
97 }
98
99 for (agent_id, model_ids) in &slm_config.model_allow_lists.agent_model_maps {
101 for model_id in model_ids {
102 if !models.contains_key(model_id) {
103 return Err(ModelCatalogError::InvalidConfig {
104 reason: format!(
105 "Agent '{}' references non-existent model: {}",
106 agent_id, model_id
107 ),
108 });
109 }
110 }
111 }
112
113 Ok(Self {
114 models,
115 agent_model_maps: slm_config.model_allow_lists.agent_model_maps,
116 sandbox_profiles: slm_config.sandbox_profiles,
117 default_sandbox_profile: slm_config.default_sandbox_profile,
118 allow_runtime_overrides: slm_config.model_allow_lists.allow_runtime_overrides,
119 })
120 }
121
122 pub fn get_model(&self, model_id: &str) -> Option<&Model> {
126 self.models.get(model_id)
127 }
128
129 pub fn list_models(&self) -> Vec<&Model> {
131 self.models.values().collect()
132 }
133
134 pub fn get_models_for_agent(&self, agent_id: &str) -> Vec<&Model> {
139 if let Some(model_ids) = self.agent_model_maps.get(agent_id) {
140 model_ids
141 .iter()
142 .filter_map(|id| self.models.get(id))
143 .collect()
144 } else {
145 self.list_models()
147 }
148 }
149
150 pub fn get_models_with_capability(&self, capability: &ModelCapability) -> Vec<&Model> {
152 self.models
153 .values()
154 .filter(|model| model.capabilities.contains(capability))
155 .collect()
156 }
157
158 pub fn get_default_sandbox_profile(&self) -> Option<&SandboxProfile> {
160 self.sandbox_profiles.get(&self.default_sandbox_profile)
161 }
162
163 pub fn get_sandbox_profile(&self, profile_name: &str) -> Option<&SandboxProfile> {
165 self.sandbox_profiles.get(profile_name)
166 }
167
168 pub fn list_sandbox_profiles(&self) -> Vec<(&String, &SandboxProfile)> {
170 self.sandbox_profiles.iter().collect()
171 }
172
173 pub fn allows_runtime_overrides(&self) -> bool {
175 self.allow_runtime_overrides
176 }
177
178 pub fn get_model_requirements(
180 &self,
181 model_id: &str,
182 ) -> Option<&crate::config::ModelResourceRequirements> {
183 self.get_model(model_id)
184 .map(|model| &model.resource_requirements)
185 }
186
187 pub fn find_best_model_for_requirements(
192 &self,
193 required_capabilities: &[ModelCapability],
194 max_memory_mb: Option<u64>,
195 agent_id: Option<&str>,
196 ) -> Option<&Model> {
197 let candidate_models = if let Some(agent_id) = agent_id {
198 self.get_models_for_agent(agent_id)
199 } else {
200 self.list_models()
201 };
202
203 candidate_models
204 .into_iter()
205 .filter(|model| {
206 required_capabilities
208 .iter()
209 .all(|cap| model.capabilities.contains(cap))
210 })
211 .filter(|model| {
212 if let Some(max_memory) = max_memory_mb {
214 model.resource_requirements.min_memory_mb <= max_memory
215 } else {
216 true
217 }
218 })
219 .min_by_key(|model| model.resource_requirements.min_memory_mb)
220 }
221
222 pub fn validate_model_access(
224 &self,
225 model_id: &str,
226 agent_id: &str,
227 ) -> Result<(), ModelCatalogError> {
228 if !self.models.contains_key(model_id) {
230 return Err(ModelCatalogError::ModelNotFound {
231 id: model_id.to_string(),
232 });
233 }
234
235 let agent_models = self.get_models_for_agent(agent_id);
237 if !agent_models.iter().any(|model| model.id == model_id) {
238 return Err(ModelCatalogError::InvalidConfig {
239 reason: format!(
240 "Agent '{}' does not have access to model '{}'",
241 agent_id, model_id
242 ),
243 });
244 }
245
246 Ok(())
247 }
248
249 pub fn get_statistics(&self) -> CatalogStatistics {
251 let total_models = self.models.len();
252 let models_with_gpu = self
253 .models
254 .values()
255 .filter(|model| model.resource_requirements.gpu_requirements.is_some())
256 .count();
257
258 let mut capability_counts = HashMap::new();
259 for model in self.models.values() {
260 for capability in &model.capabilities {
261 *capability_counts.entry(capability.clone()).or_insert(0) += 1;
262 }
263 }
264
265 CatalogStatistics {
266 total_models,
267 models_with_gpu,
268 total_agents_with_mappings: self.agent_model_maps.len(),
269 total_sandbox_profiles: self.sandbox_profiles.len(),
270 capability_counts,
271 }
272 }
273}
274
/// Aggregate counts describing a `ModelCatalog`, as produced by `get_statistics`.
#[derive(Debug, Clone)]
pub struct CatalogStatistics {
    /// Number of models in the catalog.
    pub total_models: usize,
    /// Number of models whose resource requirements include GPU requirements.
    pub models_with_gpu: usize,
    /// Number of agents that have an explicit model allow-list.
    pub total_agents_with_mappings: usize,
    /// Number of sandbox profiles defined.
    pub total_sandbox_profiles: usize,
    /// For each capability, how many models list it.
    pub capability_counts: HashMap<ModelCapability, usize>,
}
284
285#[cfg(test)]
286mod tests {
287 use super::*;
288 use crate::config::{
289 GpuRequirements, Model, ModelAllowListConfig, ModelCapability, ModelProvider,
290 ModelResourceRequirements, SandboxProfile,
291 };
292 use std::collections::HashMap;
293 use std::path::PathBuf;
294
295 fn create_test_model(id: &str, capabilities: Vec<ModelCapability>) -> Model {
296 Model {
297 id: id.to_string(),
298 name: format!("Test Model {}", id),
299 provider: ModelProvider::LocalFile {
300 file_path: PathBuf::from("/tmp/test.gguf"),
301 },
302 capabilities,
303 resource_requirements: ModelResourceRequirements {
304 min_memory_mb: 1024,
305 preferred_cpu_cores: 2.0,
306 gpu_requirements: None,
307 },
308 }
309 }
310
311 fn create_test_model_with_memory(
312 id: &str,
313 capabilities: Vec<ModelCapability>,
314 memory_mb: u64,
315 ) -> Model {
316 Model {
317 id: id.to_string(),
318 name: format!("Test Model {}", id),
319 provider: ModelProvider::LocalFile {
320 file_path: PathBuf::from("/tmp/test.gguf"),
321 },
322 capabilities,
323 resource_requirements: ModelResourceRequirements {
324 min_memory_mb: memory_mb,
325 preferred_cpu_cores: 2.0,
326 gpu_requirements: None,
327 },
328 }
329 }
330
331 fn create_test_model_with_gpu(
332 id: &str,
333 capabilities: Vec<ModelCapability>,
334 gpu_vram_mb: u64,
335 ) -> Model {
336 Model {
337 id: id.to_string(),
338 name: format!("Test Model {}", id),
339 provider: ModelProvider::LocalFile {
340 file_path: PathBuf::from("/tmp/test.gguf"),
341 },
342 capabilities,
343 resource_requirements: ModelResourceRequirements {
344 min_memory_mb: 1024,
345 preferred_cpu_cores: 2.0,
346 gpu_requirements: Some(GpuRequirements {
347 min_vram_mb: gpu_vram_mb,
348 compute_capability: "7.0".to_string(),
349 }),
350 },
351 }
352 }
353
354 fn create_test_slm_config() -> Slm {
355 let mut sandbox_profiles = HashMap::new();
356 sandbox_profiles.insert("secure".to_string(), SandboxProfile::secure_default());
357 sandbox_profiles.insert("standard".to_string(), SandboxProfile::standard_default());
358
359 let models = vec![
360 create_test_model("model1", vec![ModelCapability::TextGeneration]),
361 create_test_model("model2", vec![ModelCapability::CodeGeneration]),
362 ];
363
364 let mut agent_model_maps = HashMap::new();
365 agent_model_maps.insert("agent1".to_string(), vec!["model1".to_string()]);
366
367 Slm {
368 enabled: true,
369 model_allow_lists: ModelAllowListConfig {
370 global_models: models,
371 agent_model_maps,
372 allow_runtime_overrides: false,
373 },
374 sandbox_profiles,
375 default_sandbox_profile: "secure".to_string(),
376 }
377 }
378
379 fn create_complex_slm_config() -> Slm {
380 let mut sandbox_profiles = HashMap::new();
381 sandbox_profiles.insert("secure".to_string(), SandboxProfile::secure_default());
382 sandbox_profiles.insert("standard".to_string(), SandboxProfile::standard_default());
383
384 let models = vec![
385 create_test_model_with_memory(
386 "small_model",
387 vec![ModelCapability::TextGeneration],
388 512,
389 ),
390 create_test_model_with_memory(
391 "medium_model",
392 vec![ModelCapability::TextGeneration, ModelCapability::Reasoning],
393 1024,
394 ),
395 create_test_model_with_memory(
396 "large_model",
397 vec![
398 ModelCapability::TextGeneration,
399 ModelCapability::CodeGeneration,
400 ],
401 2048,
402 ),
403 create_test_model_with_gpu(
404 "gpu_model",
405 vec![ModelCapability::TextGeneration, ModelCapability::Embeddings],
406 4096,
407 ),
408 create_test_model(
409 "multi_cap_model",
410 vec![
411 ModelCapability::TextGeneration,
412 ModelCapability::CodeGeneration,
413 ModelCapability::Reasoning,
414 ModelCapability::ToolUse,
415 ],
416 ),
417 ];
418
419 let mut agent_model_maps = HashMap::new();
420 agent_model_maps.insert(
421 "text_agent".to_string(),
422 vec!["small_model".to_string(), "medium_model".to_string()],
423 );
424 agent_model_maps.insert(
425 "code_agent".to_string(),
426 vec!["large_model".to_string(), "multi_cap_model".to_string()],
427 );
428 agent_model_maps.insert(
429 "restricted_agent".to_string(),
430 vec!["small_model".to_string()],
431 );
432
433 Slm {
434 enabled: true,
435 model_allow_lists: ModelAllowListConfig {
436 global_models: models,
437 agent_model_maps,
438 allow_runtime_overrides: true,
439 },
440 sandbox_profiles,
441 default_sandbox_profile: "secure".to_string(),
442 }
443 }
444
/// The basic fixture yields a catalog holding exactly its two models.
#[test]
fn test_catalog_creation() {
    let catalog = ModelCatalog::new(create_test_slm_config()).unwrap();

    assert_eq!(catalog.list_models().len(), 2);
    for known_id in ["model1", "model2"].iter() {
        assert!(catalog.get_model(known_id).is_some());
    }
    assert!(catalog.get_model("nonexistent").is_none());
}
455
/// The complex fixture loads all five models and keeps runtime overrides enabled.
#[test]
fn test_catalog_creation_with_complex_config() {
    let catalog = ModelCatalog::new(create_complex_slm_config()).unwrap();

    let model_count = catalog.list_models().len();
    assert_eq!(model_count, 5);
    assert!(catalog.allows_runtime_overrides());
}
464
/// A mapped agent sees only its allow-list; an unmapped agent sees every model.
#[test]
fn test_agent_model_access() {
    let catalog = ModelCatalog::new(create_test_slm_config()).unwrap();

    let mapped = catalog.get_models_for_agent("agent1");
    let mapped_ids: Vec<&str> = mapped.iter().map(|m| m.id.as_str()).collect();
    assert_eq!(mapped_ids, vec!["model1"]);

    // "agent2" has no mapping, so it falls back to the full catalog.
    let unmapped = catalog.get_models_for_agent("agent2");
    assert_eq!(unmapped.len(), 2);
}
478
/// Each mapped agent in the complex fixture gets exactly its allow-listed
/// models; an unmapped agent gets all five.
#[test]
fn test_agent_model_access_complex() {
    let catalog = ModelCatalog::new(create_complex_slm_config()).unwrap();

    // Owned IDs of the models visible to an agent.
    let ids_for = |agent: &str| -> Vec<String> {
        catalog
            .get_models_for_agent(agent)
            .iter()
            .map(|m| m.id.clone())
            .collect()
    };

    let text_ids = ids_for("text_agent");
    assert_eq!(text_ids.len(), 2);
    assert!(text_ids.iter().any(|id| id == "small_model"));
    assert!(text_ids.iter().any(|id| id == "medium_model"));

    let code_ids = ids_for("code_agent");
    assert_eq!(code_ids.len(), 2);
    assert!(code_ids.iter().any(|id| id == "large_model"));
    assert!(code_ids.iter().any(|id| id == "multi_cap_model"));

    let restricted_ids = ids_for("restricted_agent");
    assert_eq!(restricted_ids, vec!["small_model".to_string()]);

    let unmapped_ids = ids_for("unmapped_agent");
    assert_eq!(unmapped_ids.len(), 5);
}
507
/// Capability lookups on the basic fixture return exactly the matching model.
#[test]
fn test_capability_filtering() {
    let catalog = ModelCatalog::new(create_test_slm_config()).unwrap();

    // Owned IDs of the models offering a capability.
    let ids_with = |capability: ModelCapability| -> Vec<String> {
        catalog
            .get_models_with_capability(&capability)
            .iter()
            .map(|m| m.id.clone())
            .collect()
    };

    assert_eq!(
        ids_with(ModelCapability::TextGeneration),
        vec!["model1".to_string()]
    );
    assert_eq!(
        ids_with(ModelCapability::CodeGeneration),
        vec!["model2".to_string()]
    );
    assert!(ids_with(ModelCapability::Embeddings).is_empty());
}
525
/// Capability lookups on the complex fixture return the expected number of models.
#[test]
fn test_capability_filtering_complex() {
    let catalog = ModelCatalog::new(create_complex_slm_config()).unwrap();

    let expected_counts = [
        (ModelCapability::TextGeneration, 5),
        (ModelCapability::CodeGeneration, 2),
        (ModelCapability::Reasoning, 2),
        (ModelCapability::ToolUse, 1),
        (ModelCapability::Embeddings, 1),
    ];

    for (capability, expected) in expected_counts.iter() {
        let found = catalog.get_models_with_capability(capability);
        assert_eq!(found.len(), *expected);
    }
}
551
/// The default and named sandbox profiles are retrievable; unknown names are not.
#[test]
fn test_sandbox_profile_access() {
    let catalog = ModelCatalog::new(create_test_slm_config()).unwrap();

    assert!(catalog.get_default_sandbox_profile().is_some());

    for defined in ["secure", "standard"].iter() {
        assert!(catalog.get_sandbox_profile(defined).is_some());
    }
    assert!(catalog.get_sandbox_profile("nonexistent").is_none());

    let profile_count = catalog.list_sandbox_profiles().len();
    assert_eq!(profile_count, 2);
}
575
/// Resource requirements are exposed for known models and absent for unknown ones.
#[test]
fn test_model_requirements_access() {
    let catalog = ModelCatalog::new(create_complex_slm_config()).unwrap();

    let small = catalog
        .get_model_requirements("small_model")
        .expect("small_model should be in the catalog");
    assert_eq!(small.min_memory_mb, 512);

    let gpu = catalog
        .get_model_requirements("gpu_model")
        .expect("gpu_model should be in the catalog");
    assert!(gpu.gpu_requirements.is_some());

    assert!(catalog.get_model_requirements("nonexistent").is_none());
}
594
/// Without an agent filter, the cheapest model (by minimum memory) that
/// satisfies the capabilities and memory cap is selected.
#[test]
fn test_find_best_model_for_requirements() {
    let catalog = ModelCatalog::new(create_complex_slm_config()).unwrap();

    // ID of the best catalog-wide match, if any.
    let best_id = |capabilities: &[ModelCapability], max_memory: Option<u64>| -> Option<String> {
        catalog
            .find_best_model_for_requirements(capabilities, max_memory, None)
            .map(|m| m.id.clone())
    };

    // Smallest text-generation model.
    assert_eq!(
        best_id(&[ModelCapability::TextGeneration], None).as_deref(),
        Some("small_model")
    );

    // Smallest code-generation model.
    assert_eq!(
        best_id(&[ModelCapability::CodeGeneration], None).as_deref(),
        Some("multi_cap_model")
    );

    // Only one model offers all three capabilities.
    assert_eq!(
        best_id(
            &[
                ModelCapability::TextGeneration,
                ModelCapability::Reasoning,
                ModelCapability::ToolUse,
            ],
            None,
        )
        .as_deref(),
        Some("multi_cap_model")
    );

    // A 1000 MB cap leaves only the 512 MB model.
    assert_eq!(
        best_id(&[ModelCapability::TextGeneration], Some(1000)).as_deref(),
        Some("small_model")
    );

    // A 100 MB cap excludes every model.
    assert!(best_id(&[ModelCapability::TextGeneration], Some(100)).is_none());
}
651
/// With an agent filter, selection is limited to that agent's allow-list.
#[test]
fn test_find_best_model_for_agent() {
    let catalog = ModelCatalog::new(create_complex_slm_config()).unwrap();

    // ID of the best match for one capability within an agent's allow-list.
    let best_for = |capability: ModelCapability, agent: &str| -> Option<String> {
        catalog
            .find_best_model_for_requirements(&[capability], None, Some(agent))
            .map(|m| m.id.clone())
    };

    assert_eq!(
        best_for(ModelCapability::TextGeneration, "text_agent").as_deref(),
        Some("small_model")
    );
    assert_eq!(
        best_for(ModelCapability::CodeGeneration, "code_agent").as_deref(),
        Some("multi_cap_model")
    );
    assert_eq!(
        best_for(ModelCapability::TextGeneration, "restricted_agent").as_deref(),
        Some("small_model")
    );

    // The restricted agent's only model cannot generate code.
    assert!(best_for(ModelCapability::CodeGeneration, "restricted_agent").is_none());
}
694
/// Access validation accepts allow-listed and unmapped agents, and rejects
/// unknown models and models outside an agent's allow-list with the
/// expected error variant.
#[test]
fn test_validate_model_access() {
    let catalog = ModelCatalog::new(create_complex_slm_config()).unwrap();

    // A model on the agent's allow-list is accepted.
    assert!(catalog
        .validate_model_access("small_model", "text_agent")
        .is_ok());

    // A known model outside the allow-list must fail with `InvalidConfig`.
    // Matching exhaustively makes a wrong variant fail the test instead of
    // silently skipping the inner assertion.
    match catalog.validate_model_access("large_model", "text_agent") {
        Err(ModelCatalogError::InvalidConfig { reason }) => {
            assert!(reason.contains("does not have access"));
        }
        other => panic!("expected InvalidConfig, got {:?}", other),
    }

    // An unknown model must fail with `ModelNotFound` carrying its ID.
    match catalog.validate_model_access("nonexistent_model", "text_agent") {
        Err(ModelCatalogError::ModelNotFound { id }) => {
            assert_eq!(id, "nonexistent_model");
        }
        other => panic!("expected ModelNotFound, got {:?}", other),
    }

    // Agents without a mapping may use any model in the catalog.
    assert!(catalog
        .validate_model_access("large_model", "unmapped_agent")
        .is_ok());
}
722
/// Statistics for the complex fixture report the expected totals and
/// per-capability counts.
#[test]
fn test_catalog_statistics() {
    let catalog = ModelCatalog::new(create_complex_slm_config()).unwrap();
    let stats = catalog.get_statistics();

    assert_eq!(stats.total_models, 5);
    assert_eq!(stats.models_with_gpu, 1);
    assert_eq!(stats.total_agents_with_mappings, 3);
    assert_eq!(stats.total_sandbox_profiles, 2);

    let expected_counts = [
        (ModelCapability::TextGeneration, 5),
        (ModelCapability::CodeGeneration, 2),
        (ModelCapability::Reasoning, 2),
        (ModelCapability::ToolUse, 1),
        (ModelCapability::Embeddings, 1),
    ];
    for (capability, expected) in expected_counts.iter() {
        assert_eq!(
            stats.capability_counts.get(capability).copied(),
            Some(*expected)
        );
    }
}
742
/// A default sandbox profile that is not defined is rejected at construction.
#[test]
fn test_validation_errors() {
    let config = Slm {
        default_sandbox_profile: "nonexistent".to_string(),
        ..create_test_slm_config()
    };

    let outcome = ModelCatalog::new(config);
    assert!(matches!(
        outcome,
        Err(ModelCatalogError::SandboxProfileNotFound { .. })
    ));
}
755
/// Two global models sharing an ID make the configuration invalid.
#[test]
fn test_validation_duplicate_model_ids() {
    let mut config = create_test_slm_config();

    // "model1" already exists in the fixture.
    let duplicate = create_test_model("model1", vec![ModelCapability::Reasoning]);
    config.model_allow_lists.global_models.push(duplicate);

    let outcome = ModelCatalog::new(config);
    assert!(matches!(
        outcome,
        Err(ModelCatalogError::InvalidConfig { .. })
    ));
}
775
/// An agent allow-list that names an unknown model makes the configuration invalid.
#[test]
fn test_validation_invalid_agent_model_mapping() {
    let mut config = create_test_slm_config();

    let dangling_ids = vec!["nonexistent_model".to_string()];
    config
        .model_allow_lists
        .agent_model_maps
        .insert("invalid_agent".to_string(), dangling_ids);

    let outcome = ModelCatalog::new(config);
    assert!(matches!(
        outcome,
        Err(ModelCatalogError::InvalidConfig { .. })
    ));
}
792
/// A configuration with no models and no mappings yields an empty but usable catalog.
#[test]
fn test_empty_catalog() {
    let mut config = create_test_slm_config();
    let allow_lists = &mut config.model_allow_lists;
    allow_lists.global_models.clear();
    allow_lists.agent_model_maps.clear();

    let catalog = ModelCatalog::new(config).unwrap();

    assert_eq!(catalog.list_models().len(), 0);
    assert_eq!(catalog.get_models_for_agent("any_agent").len(), 0);

    let text_models = catalog.get_models_with_capability(&ModelCapability::TextGeneration);
    assert_eq!(text_models.len(), 0);

    let stats = catalog.get_statistics();
    assert_eq!(stats.total_models, 0);
    assert_eq!(stats.models_with_gpu, 0);
    assert_eq!(stats.total_agents_with_mappings, 0);
}
815
/// Models backed by each provider variant (local file, HuggingFace, OpenAI)
/// can be loaded into the same catalog.
#[test]
fn test_model_provider_types() {
    let text_only = || vec![ModelCapability::TextGeneration];

    let provider_models = vec![
        Model {
            id: "local".to_string(),
            name: "Local Model".to_string(),
            provider: ModelProvider::LocalFile {
                file_path: PathBuf::from("/models/local.gguf"),
            },
            capabilities: text_only(),
            resource_requirements: ModelResourceRequirements {
                min_memory_mb: 1024,
                preferred_cpu_cores: 2.0,
                gpu_requirements: None,
            },
        },
        Model {
            id: "huggingface".to_string(),
            name: "HuggingFace Model".to_string(),
            provider: ModelProvider::HuggingFace {
                model_path: "microsoft/DialoGPT-medium".to_string(),
            },
            capabilities: text_only(),
            resource_requirements: ModelResourceRequirements {
                min_memory_mb: 2048,
                preferred_cpu_cores: 4.0,
                gpu_requirements: Some(GpuRequirements {
                    min_vram_mb: 4096,
                    compute_capability: "7.0".to_string(),
                }),
            },
        },
        Model {
            id: "openai".to_string(),
            name: "OpenAI Model".to_string(),
            provider: ModelProvider::OpenAI {
                model_name: "gpt-3.5-turbo".to_string(),
            },
            capabilities: vec![ModelCapability::TextGeneration, ModelCapability::Reasoning],
            resource_requirements: ModelResourceRequirements {
                min_memory_mb: 0,
                preferred_cpu_cores: 0.0,
                gpu_requirements: None,
            },
        },
    ];

    let mut config = create_test_slm_config();
    config.model_allow_lists.global_models = provider_models;
    // The fixture's agent mapping references models that no longer exist.
    config.model_allow_lists.agent_model_maps.clear();

    let catalog = ModelCatalog::new(config).unwrap();
    assert_eq!(catalog.list_models().len(), 3);

    for expected_id in ["local", "huggingface", "openai"].iter() {
        assert!(catalog.get_model(expected_id).is_some());
    }
}
875
/// The catalog reports whichever `allow_runtime_overrides` value it was configured with.
#[test]
fn test_runtime_overrides_setting() {
    for &enabled in [true, false].iter() {
        let mut config = create_test_slm_config();
        config.model_allow_lists.allow_runtime_overrides = enabled;

        let catalog = ModelCatalog::new(config).unwrap();
        assert_eq!(catalog.allows_runtime_overrides(), enabled);
    }
}
892}