use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
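
// Typical loading sequence, as a sketch (the file name and the exact layering
// shown here are illustrative, not mandated by this module):
//
//     let config = SearchConfig::from_toml_file(std::path::Path::new("search.toml"))?
//         .with_env_overrides();
//     config.validate()?;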

/// Top-level configuration for the search subsystem.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchConfig {
    #[serde(default)]
    pub backend: BackendConfig,

    #[serde(default)]
    pub embedding: EmbeddingConfig,

    #[serde(default)]
    pub retrieval: RetrievalConfig,

    #[serde(default)]
    pub reranker: RerankerConfig,

    #[serde(default)]
    pub context: ContextConfig,

    #[serde(default)]
    pub file: Option<FileConfig>,

    #[serde(default)]
    pub qdrant: Option<QdrantConfig>,

    #[serde(default)]
    pub index: IndexConfig,

    #[serde(default)]
    pub ai_ingestion: AiIngestionConfig,
}

impl Default for SearchConfig {
    fn default() -> Self {
        Self {
            backend: BackendConfig::default(),
            embedding: EmbeddingConfig::default(),
            retrieval: RetrievalConfig::default(),
            reranker: RerankerConfig::default(),
            context: ContextConfig::default(),
            file: None,
            qdrant: None,
            index: IndexConfig::default(),
            ai_ingestion: AiIngestionConfig::default(),
        }
    }
}

impl SearchConfig {
    /// Loads configuration from a TOML file on disk.
    pub fn from_toml_file(path: &std::path::Path) -> Result<Self> {
        let content = std::fs::read_to_string(path)
            .with_context(|| format!("Failed to read config file: {}", path.display()))?;
        Self::from_toml(&content)
    }

    /// Parses configuration from a TOML string, accepting either a bare config
    /// or one wrapped in a top-level `[search]` table.
    pub fn from_toml(content: &str) -> Result<Self> {
        // Heuristic: a `[search]` or `[search.*]` table header signals the
        // wrapped format.
        let is_wrapped = content.contains("[search]") || content.contains("[search.");

        if is_wrapped {
            #[derive(Deserialize)]
            struct Wrapper {
                #[serde(default)]
                search: Option<SearchConfig>,
            }

            let wrapper: Wrapper = toml::from_str(content)
                .context("Failed to parse TOML config (wrapped format)")?;

            Ok(wrapper.search.unwrap_or_default())
        } else {
            toml::from_str::<SearchConfig>(content)
                .context("Failed to parse TOML config (unwrapped format)")
        }
    }

    /// Applies environment-variable overrides on top of the loaded config.
    pub fn with_env_overrides(mut self) -> Self {
        // Backend selection.
        if let Ok(val) = std::env::var("SKILL_SEARCH_BACKEND") {
            self.backend.backend_type = val.parse().unwrap_or_default();
        }

        // Embedding settings.
        if let Ok(val) = std::env::var("SKILL_EMBEDDING_PROVIDER") {
            self.embedding.provider = val;
        }
        if let Ok(val) = std::env::var("SKILL_EMBEDDING_MODEL") {
            self.embedding.model = val;
        }
        if let Ok(val) = std::env::var("SKILL_EMBEDDING_DIMENSIONS") {
            if let Ok(dims) = val.parse() {
                self.embedding.dimensions = dims;
            }
        }

        // Retrieval settings.
        if let Ok(val) = std::env::var("SKILL_SEARCH_ENABLE_HYBRID") {
            self.retrieval.enable_hybrid = val.parse().unwrap_or(true);
        }
        if let Ok(val) = std::env::var("SKILL_SEARCH_DENSE_WEIGHT") {
            if let Ok(weight) = val.parse() {
                self.retrieval.dense_weight = weight;
            }
        }
        if let Ok(val) = std::env::var("SKILL_SEARCH_TOP_K") {
            if let Ok(k) = val.parse() {
                self.retrieval.final_k = k;
            }
        }

        // Reranker settings.
        if let Ok(val) = std::env::var("SKILL_RERANKER_ENABLED") {
            self.reranker.enabled = val.parse().unwrap_or(false);
        }
        if let Ok(val) = std::env::var("SKILL_RERANKER_MODEL") {
            self.reranker.model = val;
        }

        // Context budget.
        if let Ok(val) = std::env::var("SKILL_CONTEXT_MAX_TOKENS") {
            if let Ok(tokens) = val.parse() {
                self.context.max_total_tokens = tokens;
            }
        }

        // Qdrant connection.
        if let Ok(url) = std::env::var("QDRANT_URL") {
            let qdrant = self.qdrant.get_or_insert_with(QdrantConfig::default);
            qdrant.url = url;
        }
        if let Ok(key) = std::env::var("QDRANT_API_KEY") {
            let qdrant = self.qdrant.get_or_insert_with(QdrantConfig::default);
            qdrant.api_key = Some(key);
        }

        // AI ingestion settings.
        if let Ok(val) = std::env::var("SKILL_AI_INGESTION_ENABLED") {
            self.ai_ingestion.enabled = val.parse().unwrap_or(false);
        }
        if let Ok(val) = std::env::var("SKILL_AI_INGESTION_PROVIDER") {
            self.ai_ingestion.provider = val.parse().unwrap_or_default();
        }
        if let Ok(val) = std::env::var("SKILL_AI_INGESTION_MODEL") {
            self.ai_ingestion.model = val;
        }
        if let Ok(val) = std::env::var("SKILL_AI_EXAMPLES_PER_TOOL") {
            if let Ok(n) = val.parse() {
                self.ai_ingestion.examples_per_tool = n;
            }
        }
        if let Ok(val) = std::env::var("OLLAMA_HOST") {
            self.ai_ingestion.ollama.host = val;
        }
        if std::env::var("OPENAI_API_KEY").is_ok() {
            self.ai_ingestion.openai.api_key_env = Some("OPENAI_API_KEY".to_string());
        }
        if std::env::var("ANTHROPIC_API_KEY").is_ok() {
            self.ai_ingestion.anthropic.api_key_env = Some("ANTHROPIC_API_KEY".to_string());
        }

        self
    }

    /// Checks cross-field invariants; returns an error describing the first violation.
    pub fn validate(&self) -> Result<()> {
        if self.embedding.dimensions == 0 {
            anyhow::bail!("Embedding dimensions must be > 0");
        }

        if self.retrieval.enable_hybrid {
            let total_weight = self.retrieval.dense_weight + self.retrieval.sparse_weight;
            if (total_weight - 1.0).abs() > 0.01 {
                anyhow::bail!("Dense and sparse weights should sum to 1.0");
            }
        }

        if self.retrieval.final_k > self.retrieval.rerank_k {
            anyhow::bail!("final_k cannot be greater than rerank_k");
        }
        if self.retrieval.rerank_k > self.retrieval.first_stage_k {
            anyhow::bail!("rerank_k cannot be greater than first_stage_k");
        }

        if self.context.max_tokens_per_result > self.context.max_total_tokens {
            anyhow::bail!("max_tokens_per_result cannot exceed max_total_tokens");
        }

        // The file backend needs no extra checks: a missing FileConfig falls
        // back to defaults.

        if matches!(self.backend.backend_type, BackendType::Qdrant) && self.qdrant.is_none() {
            anyhow::bail!("Qdrant configuration required when backend = 'qdrant'");
        }

        if self.ai_ingestion.enabled {
            if self.ai_ingestion.examples_per_tool == 0 {
                anyhow::bail!("examples_per_tool must be > 0 when AI ingestion is enabled");
            }
            if self.ai_ingestion.timeout_secs == 0 {
                anyhow::bail!("timeout_secs must be > 0 when AI ingestion is enabled");
            }
        }

        Ok(())
    }
}

#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum BackendType {
    #[default]
    File,
    InMemory,
    Qdrant,
}

impl std::str::FromStr for BackendType {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "file" => Ok(Self::File),
            "in-memory" | "inmemory" | "memory" => Ok(Self::InMemory),
            "qdrant" => Ok(Self::Qdrant),
            _ => anyhow::bail!("Unknown backend type: {}. Options: file, in-memory, qdrant", s),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendConfig {
    #[serde(default, rename = "type")]
    pub backend_type: BackendType,
}

impl Default for BackendConfig {
    fn default() -> Self {
        Self {
            backend_type: BackendType::default(),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    #[serde(default = "default_embedding_provider")]
    pub provider: String,

    #[serde(default = "default_embedding_model")]
    pub model: String,

    #[serde(default = "default_embedding_dimensions")]
    pub dimensions: usize,

    #[serde(default = "default_batch_size")]
    pub batch_size: usize,

    pub openai_api_key: Option<String>,

    pub ollama_host: Option<String>,
}

fn default_embedding_provider() -> String { "fastembed".to_string() }
fn default_embedding_model() -> String { "all-minilm".to_string() }
fn default_embedding_dimensions() -> usize { 384 }
fn default_batch_size() -> usize { 32 }

impl Default for EmbeddingConfig {
    fn default() -> Self {
        Self {
            provider: default_embedding_provider(),
            model: default_embedding_model(),
            dimensions: default_embedding_dimensions(),
            batch_size: default_batch_size(),
            openai_api_key: None,
            ollama_host: None,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalConfig {
    #[serde(default = "default_enable_hybrid")]
    pub enable_hybrid: bool,

    #[serde(default = "default_dense_weight")]
    pub dense_weight: f32,

    #[serde(default = "default_sparse_weight")]
    pub sparse_weight: f32,

    #[serde(default = "default_first_stage_k")]
    pub first_stage_k: usize,

    #[serde(default = "default_rerank_k")]
    pub rerank_k: usize,

    #[serde(default = "default_final_k")]
    pub final_k: usize,

    #[serde(default)]
    pub fusion_method: FusionMethod,

    #[serde(default = "default_rrf_k")]
    pub rrf_k: f32,
}

fn default_enable_hybrid() -> bool { true }
fn default_dense_weight() -> f32 { 0.7 }
fn default_sparse_weight() -> f32 { 0.3 }
fn default_first_stage_k() -> usize { 100 }
fn default_rerank_k() -> usize { 20 }
fn default_final_k() -> usize { 5 }
fn default_rrf_k() -> f32 { 60.0 }

impl Default for RetrievalConfig {
    fn default() -> Self {
        Self {
            enable_hybrid: default_enable_hybrid(),
            dense_weight: default_dense_weight(),
            sparse_weight: default_sparse_weight(),
            first_stage_k: default_first_stage_k(),
            rerank_k: default_rerank_k(),
            final_k: default_final_k(),
            fusion_method: FusionMethod::default(),
            rrf_k: default_rrf_k(),
        }
    }
}

#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FusionMethod {
    #[default]
    ReciprocalRank,
    WeightedSum,
    MaxScore,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankerConfig {
    #[serde(default)]
    pub enabled: bool,

    #[serde(default = "default_reranker_provider")]
    pub provider: String,

    #[serde(default = "default_reranker_model")]
    pub model: String,

    #[serde(default = "default_max_rerank_docs")]
    pub max_documents: usize,

    pub cohere_api_key: Option<String>,
}

fn default_reranker_provider() -> String { "fastembed".to_string() }
fn default_reranker_model() -> String { "bge-reranker-base".to_string() }
fn default_max_rerank_docs() -> usize { 50 }

impl Default for RerankerConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            provider: default_reranker_provider(),
            model: default_reranker_model(),
            max_documents: default_max_rerank_docs(),
            cohere_api_key: None,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextConfig {
    #[serde(default = "default_max_tokens_per_result")]
    pub max_tokens_per_result: usize,

    #[serde(default = "default_max_total_tokens")]
    pub max_total_tokens: usize,

    #[serde(default)]
    pub include_examples: bool,

    #[serde(default)]
    pub compression: CompressionStrategy,
}

fn default_max_tokens_per_result() -> usize { 200 }
fn default_max_total_tokens() -> usize { 800 }

impl Default for ContextConfig {
    fn default() -> Self {
        Self {
            max_tokens_per_result: default_max_tokens_per_result(),
            max_total_tokens: default_max_total_tokens(),
            include_examples: false,
            compression: CompressionStrategy::default(),
        }
    }
}

#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CompressionStrategy {
    Extractive,
    #[default]
    Template,
    Progressive,
    None,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileConfig {
    pub storage_path: Option<PathBuf>,

    #[serde(default)]
    pub distance_metric: crate::vector_store::DistanceMetric,
}

impl Default for FileConfig {
    fn default() -> Self {
        Self {
            storage_path: None,
            distance_metric: crate::vector_store::DistanceMetric::Cosine,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QdrantConfig {
    #[serde(default = "default_qdrant_url")]
    pub url: String,

    pub api_key: Option<String>,

    #[serde(default = "default_collection_name")]
    pub collection: String,

    #[serde(default)]
    pub tls: bool,
}

fn default_qdrant_url() -> String { "http://localhost:6334".to_string() }
fn default_collection_name() -> String { "skill-tools".to_string() }

impl Default for QdrantConfig {
    fn default() -> Self {
        Self {
            url: default_qdrant_url(),
            api_key: None,
            collection: default_collection_name(),
            tls: false,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfig {
    pub path: Option<PathBuf>,

    #[serde(default = "default_index_on_startup")]
    pub index_on_startup: bool,

    #[serde(default)]
    pub watch_for_changes: bool,
}

fn default_index_on_startup() -> bool { true }

impl Default for IndexConfig {
    fn default() -> Self {
        Self {
            path: None,
            index_on_startup: default_index_on_startup(),
            watch_for_changes: false,
        }
    }
}

#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AiProvider {
    #[default]
    Ollama,
    OpenAi,
    Anthropic,
}

impl std::str::FromStr for AiProvider {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "ollama" => Ok(Self::Ollama),
            "openai" => Ok(Self::OpenAi),
            "anthropic" | "claude" => Ok(Self::Anthropic),
            _ => anyhow::bail!("Unknown AI provider: {}. Options: ollama, openai, anthropic", s),
        }
    }
}

impl std::fmt::Display for AiProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AiProvider::Ollama => write!(f, "ollama"),
            AiProvider::OpenAi => write!(f, "openai"),
            AiProvider::Anthropic => write!(f, "anthropic"),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AiIngestionConfig {
    #[serde(default)]
    pub enabled: bool,

    #[serde(default = "default_examples_per_tool")]
    pub examples_per_tool: usize,

    #[serde(default)]
    pub provider: AiProvider,

    #[serde(default = "default_ai_model")]
    pub model: String,

    #[serde(default = "default_validate_examples")]
    pub validate_examples: bool,

    #[serde(default = "default_stream_progress")]
    pub stream_progress: bool,

    #[serde(default = "default_cache_examples")]
    pub cache_examples: bool,

    #[serde(default = "default_timeout_secs")]
    pub timeout_secs: u64,

    #[serde(default)]
    pub ollama: OllamaLlmConfig,

    #[serde(default)]
    pub openai: OpenAiLlmConfig,

    #[serde(default)]
    pub anthropic: AnthropicLlmConfig,
}

fn default_examples_per_tool() -> usize { 5 }
fn default_ai_model() -> String { "llama3.2".to_string() }
fn default_validate_examples() -> bool { true }
fn default_stream_progress() -> bool { true }
fn default_cache_examples() -> bool { true }
fn default_timeout_secs() -> u64 { 30 }

impl Default for AiIngestionConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            examples_per_tool: default_examples_per_tool(),
            provider: AiProvider::default(),
            model: default_ai_model(),
            validate_examples: default_validate_examples(),
            stream_progress: default_stream_progress(),
            cache_examples: default_cache_examples(),
            timeout_secs: default_timeout_secs(),
            ollama: OllamaLlmConfig::default(),
            openai: OpenAiLlmConfig::default(),
            anthropic: AnthropicLlmConfig::default(),
        }
    }
}

impl AiIngestionConfig {
    /// Returns the effective model: the explicit top-level `model` if set,
    /// otherwise the configured provider's own default model.
    pub fn get_model(&self) -> &str {
        if !self.model.is_empty() {
            return &self.model;
        }
        match self.provider {
            AiProvider::Ollama => &self.ollama.model,
            AiProvider::OpenAi => &self.openai.model,
            AiProvider::Anthropic => &self.anthropic.model,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaLlmConfig {
    #[serde(default = "default_ollama_host")]
    pub host: String,

    #[serde(default = "default_ollama_model")]
    pub model: String,
}

fn default_ollama_host() -> String { "http://localhost:11434".to_string() }
fn default_ollama_model() -> String { "llama3.2".to_string() }

impl Default for OllamaLlmConfig {
    fn default() -> Self {
        Self {
            host: default_ollama_host(),
            model: default_ollama_model(),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAiLlmConfig {
    #[serde(default)]
    pub api_key_env: Option<String>,

    #[serde(default = "default_openai_llm_model")]
    pub model: String,

    #[serde(default = "default_openai_max_tokens")]
    pub max_tokens: u32,

    #[serde(default = "default_temperature")]
    pub temperature: f32,
}

fn default_openai_llm_model() -> String { "gpt-4o-mini".to_string() }
fn default_openai_max_tokens() -> u32 { 2048 }
fn default_temperature() -> f32 { 0.7 }

impl Default for OpenAiLlmConfig {
    fn default() -> Self {
        Self {
            api_key_env: None,
            model: default_openai_llm_model(),
            max_tokens: default_openai_max_tokens(),
            temperature: default_temperature(),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnthropicLlmConfig {
    #[serde(default)]
    pub api_key_env: Option<String>,

    #[serde(default = "default_anthropic_model")]
    pub model: String,

    #[serde(default = "default_anthropic_max_tokens")]
    pub max_tokens: u32,

    #[serde(default = "default_temperature")]
    pub temperature: f32,
}

fn default_anthropic_model() -> String { "claude-3-haiku-20240307".to_string() }
fn default_anthropic_max_tokens() -> u32 { 2048 }

impl Default for AnthropicLlmConfig {
    fn default() -> Self {
        Self {
            api_key_env: None,
            model: default_anthropic_model(),
            max_tokens: default_anthropic_max_tokens(),
            temperature: default_temperature(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = SearchConfig::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.embedding.provider, "fastembed");
        assert_eq!(config.embedding.dimensions, 384);
        assert!(config.retrieval.enable_hybrid);
        assert!(!config.reranker.enabled);
    }

    #[test]
    fn test_parse_toml() {
        let toml = r#"
[search]
backend = { type = "qdrant" }

[search.embedding]
provider = "openai"
model = "text-embedding-3-small"
dimensions = 1536

[search.retrieval]
enable_hybrid = true
dense_weight = 0.8
sparse_weight = 0.2
final_k = 10

[search.reranker]
enabled = true
model = "bge-reranker-large"

[search.context]
max_total_tokens = 1000
compression = "progressive"

[search.qdrant]
url = "http://qdrant:6334"
collection = "my-tools"
"#;

        let config = SearchConfig::from_toml(toml).unwrap();

        assert!(matches!(config.backend.backend_type, BackendType::Qdrant));
        assert_eq!(config.embedding.provider, "openai");
        assert_eq!(config.embedding.dimensions, 1536);
        assert!((config.retrieval.dense_weight - 0.8).abs() < 0.001);
        assert_eq!(config.retrieval.final_k, 10);
        assert!(config.reranker.enabled);
        assert_eq!(config.reranker.model, "bge-reranker-large");
        assert!(matches!(config.context.compression, CompressionStrategy::Progressive));
        assert_eq!(config.qdrant.as_ref().unwrap().url, "http://qdrant:6334");
    }

    #[test]
    fn test_validation_weights() {
        let mut config = SearchConfig::default();
        config.retrieval.dense_weight = 0.5;
        config.retrieval.sparse_weight = 0.3;

        // 0.5 + 0.3 != 1.0, so hybrid weight validation should fail.
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validation_k_values() {
        let mut config = SearchConfig::default();
        config.retrieval.final_k = 50;
        config.retrieval.rerank_k = 20;

        // final_k must not exceed rerank_k.
        assert!(config.validate().is_err());
    }
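
    // Added sketch: exercises the context-budget rule in validate(); the
    // per-result token cap must fit inside the total budget.
    #[test]
    fn test_validation_context_tokens() {
        let mut config = SearchConfig::default();
        config.context.max_tokens_per_result = 2_000; // exceeds default max_total_tokens (800)
        assert!(config.validate().is_err());
    }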

    #[test]
    fn test_validation_qdrant_required() {
        let mut config = SearchConfig::default();
        config.backend.backend_type = BackendType::Qdrant;
        config.qdrant = None;

        assert!(config.validate().is_err());
    }
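
    // Added sketch: pins the QdrantConfig::default() values so accidental
    // changes to the default url or collection name surface in tests.
    #[test]
    fn test_qdrant_defaults() {
        let q = QdrantConfig::default();
        assert_eq!(q.url, "http://localhost:6334");
        assert_eq!(q.collection, "skill-tools");
        assert!(q.api_key.is_none());
        assert!(!q.tls);
    }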

    #[test]
    fn test_backend_type_from_str() {
        assert!(matches!("in-memory".parse::<BackendType>().unwrap(), BackendType::InMemory));
        assert!(matches!("inmemory".parse::<BackendType>().unwrap(), BackendType::InMemory));
        assert!(matches!("qdrant".parse::<BackendType>().unwrap(), BackendType::Qdrant));
        assert!("invalid".parse::<BackendType>().is_err());
    }
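
    // Added sketch: FusionMethod uses serde rename_all = "snake_case", so
    // "weighted_sum" in TOML should map to FusionMethod::WeightedSum.
    #[test]
    fn test_fusion_method_from_toml() {
        let config = SearchConfig::from_toml("[retrieval]\nfusion_method = \"weighted_sum\"").unwrap();
        assert!(matches!(config.retrieval.fusion_method, FusionMethod::WeightedSum));
    }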

    #[test]
    fn test_env_overrides() {
        std::env::set_var("SKILL_SEARCH_BACKEND", "qdrant");
        std::env::set_var("SKILL_EMBEDDING_DIMENSIONS", "768");
        std::env::set_var("SKILL_RERANKER_ENABLED", "true");
        std::env::set_var("QDRANT_URL", "http://custom:6334");

        let config = SearchConfig::default().with_env_overrides();

        assert!(matches!(config.backend.backend_type, BackendType::Qdrant));
        assert_eq!(config.embedding.dimensions, 768);
        assert!(config.reranker.enabled);
        assert_eq!(config.qdrant.as_ref().unwrap().url, "http://custom:6334");

        // Clean up so other tests see an unmodified environment.
        std::env::remove_var("SKILL_SEARCH_BACKEND");
        std::env::remove_var("SKILL_EMBEDDING_DIMENSIONS");
        std::env::remove_var("SKILL_RERANKER_ENABLED");
        std::env::remove_var("QDRANT_URL");
    }

    #[test]
    fn test_minimal_toml() {
        let toml = r#"
[search]
"#;

        let config = SearchConfig::from_toml(toml).unwrap();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_empty_file() {
        let toml = "";
        let config = SearchConfig::from_toml(toml).unwrap();
        assert!(config.validate().is_ok());
    }
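
    // Added sketch: the unwrapped format (no [search] wrapper) should take the
    // bare-table branch in from_toml and parse to the same structure.
    #[test]
    fn test_parse_unwrapped_toml() {
        let toml = r#"
[embedding]
provider = "ollama"
dimensions = 768

[retrieval]
final_k = 8
"#;
        let config = SearchConfig::from_toml(toml).unwrap();
        assert_eq!(config.embedding.provider, "ollama");
        assert_eq!(config.embedding.dimensions, 768);
        assert_eq!(config.retrieval.final_k, 8);
    }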

    #[test]
    fn test_ai_ingestion_defaults() {
        let config = AiIngestionConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.examples_per_tool, 5);
        assert!(matches!(config.provider, AiProvider::Ollama));
        assert_eq!(config.model, "llama3.2");
        assert!(config.validate_examples);
        assert!(config.stream_progress);
        assert!(config.cache_examples);
        assert_eq!(config.timeout_secs, 30);
    }

    #[test]
    fn test_ai_provider_from_str() {
        assert!(matches!("ollama".parse::<AiProvider>().unwrap(), AiProvider::Ollama));
        assert!(matches!("openai".parse::<AiProvider>().unwrap(), AiProvider::OpenAi));
        assert!(matches!("anthropic".parse::<AiProvider>().unwrap(), AiProvider::Anthropic));
        assert!(matches!("claude".parse::<AiProvider>().unwrap(), AiProvider::Anthropic));
        assert!("invalid".parse::<AiProvider>().is_err());
    }
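
    // Added sketch: Display and FromStr should round-trip for every provider
    // variant, since FromStr accepts each Display spelling.
    #[test]
    fn test_ai_provider_display_round_trip() {
        for provider in [AiProvider::Ollama, AiProvider::OpenAi, AiProvider::Anthropic] {
            let parsed: AiProvider = provider.to_string().parse().unwrap();
            assert_eq!(parsed, provider);
        }
    }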

    #[test]
    fn test_ai_ingestion_toml_parsing() {
        let toml = r#"
[ai_ingestion]
enabled = true
examples_per_tool = 3
provider = "openai"
model = "gpt-4o"
validate_examples = false
stream_progress = true
timeout_secs = 60

[ai_ingestion.openai]
model = "gpt-4o-mini"
max_tokens = 4096
temperature = 0.5
"#;

        let config: SearchConfig = toml::from_str(toml).unwrap();
        assert!(config.ai_ingestion.enabled);
        assert_eq!(config.ai_ingestion.examples_per_tool, 3);
        assert!(matches!(config.ai_ingestion.provider, AiProvider::OpenAi));
        assert_eq!(config.ai_ingestion.model, "gpt-4o");
        assert!(!config.ai_ingestion.validate_examples);
        assert_eq!(config.ai_ingestion.timeout_secs, 60);
        assert_eq!(config.ai_ingestion.openai.model, "gpt-4o-mini");
        assert_eq!(config.ai_ingestion.openai.max_tokens, 4096);
        assert!((config.ai_ingestion.openai.temperature - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_ai_ingestion_validation() {
        let mut config = SearchConfig::default();
        config.ai_ingestion.enabled = true;
        config.ai_ingestion.examples_per_tool = 0;

        assert!(config.validate().is_err());

        config.ai_ingestion.examples_per_tool = 5;
        config.ai_ingestion.timeout_secs = 0;

        assert!(config.validate().is_err());
    }

    #[test]
    fn test_ai_ingestion_get_model() {
        let mut config = AiIngestionConfig::default();

        // With the top-level model cleared, fall back to the per-provider default.
        config.model = String::new();
        config.provider = AiProvider::Ollama;
        assert_eq!(config.get_model(), "llama3.2");

        config.provider = AiProvider::OpenAi;
        assert_eq!(config.get_model(), "gpt-4o-mini");

        config.provider = AiProvider::Anthropic;
        assert_eq!(config.get_model(), "claude-3-haiku-20240307");

        // An explicit top-level model takes precedence over provider defaults.
        config.model = "custom-model".to_string();
        assert_eq!(config.get_model(), "custom-model");
    }
}