// skill_runtime/search_config.rs

1//! Configuration schema for RAG search pipeline
2//!
3//! Provides comprehensive configuration for embedding providers, vector backends,
4//! hybrid retrieval, reranking, and context compression.
5
6use anyhow::{Context, Result};
7use serde::{Deserialize, Serialize};
8use std::path::PathBuf;
9
10/// Root search configuration
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct SearchConfig {
13    /// Vector store backend
14    #[serde(default)]
15    pub backend: BackendConfig,
16
17    /// Embedding configuration
18    #[serde(default)]
19    pub embedding: EmbeddingConfig,
20
21    /// Retrieval configuration
22    #[serde(default)]
23    pub retrieval: RetrievalConfig,
24
25    /// Reranker configuration
26    #[serde(default)]
27    pub reranker: RerankerConfig,
28
29    /// Context compression configuration
30    #[serde(default)]
31    pub context: ContextConfig,
32
33    /// File-based vector store configuration (if backend = "file")
34    #[serde(default)]
35    pub file: Option<FileConfig>,
36
37    /// Qdrant-specific configuration (if backend = "qdrant")
38    #[serde(default)]
39    pub qdrant: Option<QdrantConfig>,
40
41    /// Index configuration
42    #[serde(default)]
43    pub index: IndexConfig,
44
45    /// AI-powered example generation during ingestion
46    #[serde(default)]
47    pub ai_ingestion: AiIngestionConfig,
48}
49
50impl Default for SearchConfig {
51    fn default() -> Self {
52        Self {
53            backend: BackendConfig::default(),
54            embedding: EmbeddingConfig::default(),
55            retrieval: RetrievalConfig::default(),
56            reranker: RerankerConfig::default(),
57            context: ContextConfig::default(),
58            file: None,
59            qdrant: None,
60            index: IndexConfig::default(),
61            ai_ingestion: AiIngestionConfig::default(),
62        }
63    }
64}
65
66impl SearchConfig {
67    /// Load config from TOML file
68    pub fn from_toml_file(path: &std::path::Path) -> Result<Self> {
69        let content = std::fs::read_to_string(path)
70            .with_context(|| format!("Failed to read config file: {}", path.display()))?;
71        Self::from_toml(&content)
72    }
73
74    /// Parse from TOML string
75    ///
76    /// Supports both wrapped format (with `[search]` section) and unwrapped format.
77    pub fn from_toml(content: &str) -> Result<Self> {
78        // Check if the TOML uses wrapped format (has [search] section)
79        let is_wrapped = content.contains("[search]") || content.contains("[search.");
80
81        if is_wrapped {
82            // Wrapped format (with [search] section)
83            #[derive(Deserialize)]
84            struct Wrapper {
85                #[serde(default)]
86                search: Option<SearchConfig>,
87            }
88
89            let wrapper: Wrapper = toml::from_str(content)
90                .context("Failed to parse TOML config (wrapped format)")?;
91
92            Ok(wrapper.search.unwrap_or_default())
93        } else {
94            // Unwrapped format (direct sections like [embedding], [backend], etc.)
95            toml::from_str::<SearchConfig>(content)
96                .context("Failed to parse TOML config (unwrapped format)")
97        }
98    }
99
100    /// Apply environment variable overrides
101    pub fn with_env_overrides(mut self) -> Self {
102        // Backend
103        if let Ok(val) = std::env::var("SKILL_SEARCH_BACKEND") {
104            self.backend.backend_type = val.parse().unwrap_or_default();
105        }
106
107        // Embedding
108        if let Ok(val) = std::env::var("SKILL_EMBEDDING_PROVIDER") {
109            self.embedding.provider = val;
110        }
111        if let Ok(val) = std::env::var("SKILL_EMBEDDING_MODEL") {
112            self.embedding.model = val;
113        }
114        if let Ok(val) = std::env::var("SKILL_EMBEDDING_DIMENSIONS") {
115            if let Ok(dims) = val.parse() {
116                self.embedding.dimensions = dims;
117            }
118        }
119
120        // Retrieval
121        if let Ok(val) = std::env::var("SKILL_SEARCH_ENABLE_HYBRID") {
122            self.retrieval.enable_hybrid = val.parse().unwrap_or(true);
123        }
124        if let Ok(val) = std::env::var("SKILL_SEARCH_DENSE_WEIGHT") {
125            if let Ok(weight) = val.parse() {
126                self.retrieval.dense_weight = weight;
127            }
128        }
129        if let Ok(val) = std::env::var("SKILL_SEARCH_TOP_K") {
130            if let Ok(k) = val.parse() {
131                self.retrieval.final_k = k;
132            }
133        }
134
135        // Reranker
136        if let Ok(val) = std::env::var("SKILL_RERANKER_ENABLED") {
137            self.reranker.enabled = val.parse().unwrap_or(false);
138        }
139        if let Ok(val) = std::env::var("SKILL_RERANKER_MODEL") {
140            self.reranker.model = val;
141        }
142
143        // Context
144        if let Ok(val) = std::env::var("SKILL_CONTEXT_MAX_TOKENS") {
145            if let Ok(tokens) = val.parse() {
146                self.context.max_total_tokens = tokens;
147            }
148        }
149
150        // Qdrant
151        if let Ok(url) = std::env::var("QDRANT_URL") {
152            let qdrant = self.qdrant.get_or_insert_with(QdrantConfig::default);
153            qdrant.url = url;
154        }
155        if let Ok(key) = std::env::var("QDRANT_API_KEY") {
156            let qdrant = self.qdrant.get_or_insert_with(QdrantConfig::default);
157            qdrant.api_key = Some(key);
158        }
159
160        // AI Ingestion
161        if let Ok(val) = std::env::var("SKILL_AI_INGESTION_ENABLED") {
162            self.ai_ingestion.enabled = val.parse().unwrap_or(false);
163        }
164        if let Ok(val) = std::env::var("SKILL_AI_INGESTION_PROVIDER") {
165            self.ai_ingestion.provider = val.parse().unwrap_or_default();
166        }
167        if let Ok(val) = std::env::var("SKILL_AI_INGESTION_MODEL") {
168            self.ai_ingestion.model = val;
169        }
170        if let Ok(val) = std::env::var("SKILL_AI_EXAMPLES_PER_TOOL") {
171            if let Ok(n) = val.parse() {
172                self.ai_ingestion.examples_per_tool = n;
173            }
174        }
175        if let Ok(val) = std::env::var("OLLAMA_HOST") {
176            self.ai_ingestion.ollama.host = val;
177        }
178        if let Ok(_) = std::env::var("OPENAI_API_KEY") {
179            self.ai_ingestion.openai.api_key_env = Some("OPENAI_API_KEY".to_string());
180        }
181        if let Ok(_) = std::env::var("ANTHROPIC_API_KEY") {
182            self.ai_ingestion.anthropic.api_key_env = Some("ANTHROPIC_API_KEY".to_string());
183        }
184
185        self
186    }
187
188    /// Validate configuration
189    pub fn validate(&self) -> Result<()> {
190        // Validate embedding dimensions
191        if self.embedding.dimensions == 0 {
192            anyhow::bail!("Embedding dimensions must be > 0");
193        }
194
195        // Validate retrieval weights
196        if self.retrieval.enable_hybrid {
197            let total_weight = self.retrieval.dense_weight + self.retrieval.sparse_weight;
198            if (total_weight - 1.0).abs() > 0.01 {
199                anyhow::bail!("Dense and sparse weights should sum to 1.0");
200            }
201        }
202
203        // Validate retrieval k values
204        if self.retrieval.final_k > self.retrieval.rerank_k {
205            anyhow::bail!("final_k cannot be greater than rerank_k");
206        }
207        if self.retrieval.rerank_k > self.retrieval.first_stage_k {
208            anyhow::bail!("rerank_k cannot be greater than first_stage_k");
209        }
210
211        // Validate context tokens
212        if self.context.max_tokens_per_result > self.context.max_total_tokens {
213            anyhow::bail!("max_tokens_per_result cannot exceed max_total_tokens");
214        }
215
216        // Validate File config if using File backend
217        if matches!(self.backend.backend_type, BackendType::File) {
218            // File config is optional (uses default ~/.skill-engine/vectors/store.bin)
219            // No validation needed
220        }
221
222        // Validate Qdrant config if using Qdrant backend
223        if matches!(self.backend.backend_type, BackendType::Qdrant) {
224            if self.qdrant.is_none() {
225                anyhow::bail!("Qdrant configuration required when backend = 'qdrant'");
226            }
227        }
228
229        // Validate AI ingestion config
230        if self.ai_ingestion.enabled {
231            if self.ai_ingestion.examples_per_tool == 0 {
232                anyhow::bail!("examples_per_tool must be > 0 when AI ingestion is enabled");
233            }
234            if self.ai_ingestion.timeout_secs == 0 {
235                anyhow::bail!("timeout_secs must be > 0 when AI ingestion is enabled");
236            }
237        }
238
239        Ok(())
240    }
241}
242
/// Vector store backend type.
///
/// Serialized with `rename_all = "lowercase"`, so the config values are
/// `"file"`, `"inmemory"`, and `"qdrant"` (see also the more lenient
/// `FromStr` impl used for env-var overrides).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum BackendType {
    /// File-based vector store (default) - persistent local storage
    #[default]
    File,
    /// In-memory vector store - fast but no persistence
    InMemory,
    /// Qdrant vector database - production-grade with Docker
    Qdrant,
}
255
256impl std::str::FromStr for BackendType {
257    type Err = anyhow::Error;
258
259    fn from_str(s: &str) -> Result<Self, Self::Err> {
260        match s.to_lowercase().as_str() {
261            "file" => Ok(Self::File),
262            "in-memory" | "inmemory" | "memory" => Ok(Self::InMemory),
263            "qdrant" => Ok(Self::Qdrant),
264            _ => anyhow::bail!("Unknown backend type: {}. Options: file, in-memory, qdrant", s),
265        }
266    }
267}
268
269/// Backend configuration
270#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct BackendConfig {
272    /// Backend type
273    #[serde(default, rename = "type")]
274    pub backend_type: BackendType,
275}
276
277impl Default for BackendConfig {
278    fn default() -> Self {
279        Self {
280            backend_type: BackendType::default(),
281        }
282    }
283}
284
/// Embedding configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Embedding provider (fastembed, openai, ollama); default: "fastembed"
    #[serde(default = "default_embedding_provider")]
    pub provider: String,

    /// Model name (default: "all-minilm")
    #[serde(default = "default_embedding_model")]
    pub model: String,

    /// Embedding dimensions (default: 384); must be > 0, enforced by
    /// `SearchConfig::validate`
    #[serde(default = "default_embedding_dimensions")]
    pub dimensions: usize,

    /// Batch size for embedding generation (default: 32)
    #[serde(default = "default_batch_size")]
    pub batch_size: usize,

    /// OpenAI API key (if provider = "openai"); absent => None
    pub openai_api_key: Option<String>,

    /// Ollama host (if provider = "ollama"); absent => None
    pub ollama_host: Option<String>,
}

// serde `default = "..."` callbacks; kept as free functions because serde
// requires a path to a fn rather than a value.
fn default_embedding_provider() -> String { "fastembed".to_string() }
fn default_embedding_model() -> String { "all-minilm".to_string() }
fn default_embedding_dimensions() -> usize { 384 }
fn default_batch_size() -> usize { 32 }

impl Default for EmbeddingConfig {
    // Mirrors the serde defaults so a missing `[embedding]` section and
    // `EmbeddingConfig::default()` agree.
    fn default() -> Self {
        Self {
            provider: default_embedding_provider(),
            model: default_embedding_model(),
            dimensions: default_embedding_dimensions(),
            batch_size: default_batch_size(),
            openai_api_key: None,
            ollama_host: None,
        }
    }
}
328
/// Retrieval configuration.
///
/// Controls the multi-stage retrieval funnel:
/// `first_stage_k` candidates -> `rerank_k` reranked -> `final_k` returned.
/// `SearchConfig::validate` enforces that the funnel narrows.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalConfig {
    /// Enable hybrid (dense + sparse) search (default: true)
    #[serde(default = "default_enable_hybrid")]
    pub enable_hybrid: bool,

    /// Weight for dense (vector) search (default: 0.7)
    #[serde(default = "default_dense_weight")]
    pub dense_weight: f32,

    /// Weight for sparse (BM25) search (default: 0.3); dense + sparse must
    /// sum to ~1.0 when hybrid search is enabled (see `SearchConfig::validate`)
    #[serde(default = "default_sparse_weight")]
    pub sparse_weight: f32,

    /// Number of results for first stage retrieval (default: 100)
    #[serde(default = "default_first_stage_k")]
    pub first_stage_k: usize,

    /// Number of results to rerank (default: 20)
    #[serde(default = "default_rerank_k")]
    pub rerank_k: usize,

    /// Final number of results to return (default: 5)
    #[serde(default = "default_final_k")]
    pub final_k: usize,

    /// Fusion method for hybrid search (default: reciprocal_rank)
    #[serde(default)]
    pub fusion_method: FusionMethod,

    /// RRF k parameter (for reciprocal rank fusion; default: 60.0)
    #[serde(default = "default_rrf_k")]
    pub rrf_k: f32,
}

// serde default callbacks.
fn default_enable_hybrid() -> bool { true }
fn default_dense_weight() -> f32 { 0.7 }
fn default_sparse_weight() -> f32 { 0.3 }
fn default_first_stage_k() -> usize { 100 }
fn default_rerank_k() -> usize { 20 }
fn default_final_k() -> usize { 5 }
fn default_rrf_k() -> f32 { 60.0 }

impl Default for RetrievalConfig {
    // Mirrors the serde defaults.
    fn default() -> Self {
        Self {
            enable_hybrid: default_enable_hybrid(),
            dense_weight: default_dense_weight(),
            sparse_weight: default_sparse_weight(),
            first_stage_k: default_first_stage_k(),
            rerank_k: default_rerank_k(),
            final_k: default_final_k(),
            fusion_method: FusionMethod::default(),
            rrf_k: default_rrf_k(),
        }
    }
}
387
/// Fusion method for combining search results.
///
/// Serialized with `rename_all = "snake_case"`: `"reciprocal_rank"`,
/// `"weighted_sum"`, `"max_score"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FusionMethod {
    /// Reciprocal Rank Fusion (default); tuned by `RetrievalConfig::rrf_k`
    #[default]
    ReciprocalRank,
    /// Weighted sum of normalized scores
    WeightedSum,
    /// Take maximum score
    MaxScore,
}
400
/// Reranker configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankerConfig {
    /// Enable reranking (default: false)
    #[serde(default)]
    pub enabled: bool,

    /// Reranker provider (fastembed, cohere); default: "fastembed"
    #[serde(default = "default_reranker_provider")]
    pub provider: String,

    /// Reranker model (default: "bge-reranker-base")
    #[serde(default = "default_reranker_model")]
    pub model: String,

    /// Maximum documents to rerank (default: 50)
    #[serde(default = "default_max_rerank_docs")]
    pub max_documents: usize,

    /// Cohere API key (if provider = "cohere"); absent => None
    pub cohere_api_key: Option<String>,
}

// serde default callbacks.
fn default_reranker_provider() -> String { "fastembed".to_string() }
fn default_reranker_model() -> String { "bge-reranker-base".to_string() }
fn default_max_rerank_docs() -> usize { 50 }

impl Default for RerankerConfig {
    // Mirrors the serde defaults (reranking disabled out of the box).
    fn default() -> Self {
        Self {
            enabled: false,
            provider: default_reranker_provider(),
            model: default_reranker_model(),
            max_documents: default_max_rerank_docs(),
            cohere_api_key: None,
        }
    }
}
439
/// Context compression configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextConfig {
    /// Maximum tokens per result (default: 200); must not exceed
    /// `max_total_tokens` (enforced by `SearchConfig::validate`)
    #[serde(default = "default_max_tokens_per_result")]
    pub max_tokens_per_result: usize,

    /// Maximum total tokens (default: 800)
    #[serde(default = "default_max_total_tokens")]
    pub max_total_tokens: usize,

    /// Include code examples in output (default: false)
    #[serde(default)]
    pub include_examples: bool,

    /// Compression strategy (default: template)
    #[serde(default)]
    pub compression: CompressionStrategy,
}

// serde default callbacks.
fn default_max_tokens_per_result() -> usize { 200 }
fn default_max_total_tokens() -> usize { 800 }

impl Default for ContextConfig {
    // Mirrors the serde defaults.
    fn default() -> Self {
        Self {
            max_tokens_per_result: default_max_tokens_per_result(),
            max_total_tokens: default_max_total_tokens(),
            include_examples: false,
            compression: CompressionStrategy::default(),
        }
    }
}
473
/// Compression strategy.
///
/// Serialized with `rename_all = "lowercase"`: `"extractive"`, `"template"`,
/// `"progressive"`, `"none"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CompressionStrategy {
    /// Keep first sentence + parameters
    Extractive,
    /// Template-based structured format (default)
    #[default]
    Template,
    /// Progressive detail based on rank
    Progressive,
    /// No compression
    None,
}
488
/// File-based vector store configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileConfig {
    /// Custom storage directory (defaults to ~/.skill-engine/vectors/store.bin)
    pub storage_path: Option<PathBuf>,

    /// Distance metric for similarity calculation
    // NOTE(review): `#[serde(default)]` here uses `DistanceMetric::default()`,
    // while `FileConfig::default()` below pins `Cosine` explicitly. Confirm
    // `DistanceMetric`'s `Default` is `Cosine` so the two paths agree.
    #[serde(default)]
    pub distance_metric: crate::vector_store::DistanceMetric,
}

impl Default for FileConfig {
    // Not derived: the metric is pinned to Cosine explicitly rather than
    // relying on `DistanceMetric::default()` (declared in another module).
    fn default() -> Self {
        Self {
            storage_path: None,
            distance_metric: crate::vector_store::DistanceMetric::Cosine,
        }
    }
}
508
/// Qdrant-specific configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QdrantConfig {
    /// Qdrant URL (default: "http://localhost:6334")
    #[serde(default = "default_qdrant_url")]
    pub url: String,

    /// API key (optional, for Qdrant Cloud); absent => None
    pub api_key: Option<String>,

    /// Collection name (default: "skill-tools")
    #[serde(default = "default_collection_name")]
    pub collection: String,

    /// Enable TLS (default: false)
    #[serde(default)]
    pub tls: bool,
}

// serde default callbacks.
fn default_qdrant_url() -> String { "http://localhost:6334".to_string() }
fn default_collection_name() -> String { "skill-tools".to_string() }

impl Default for QdrantConfig {
    // Mirrors the serde defaults.
    fn default() -> Self {
        Self {
            url: default_qdrant_url(),
            api_key: None,
            collection: default_collection_name(),
            tls: false,
        }
    }
}
541
/// Index configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfig {
    /// Index directory path (optional; no default path is set here)
    pub path: Option<PathBuf>,

    /// Index on startup (default: true)
    #[serde(default = "default_index_on_startup")]
    pub index_on_startup: bool,

    /// Watch for skill changes (default: false)
    #[serde(default)]
    pub watch_for_changes: bool,
}

// serde default callback.
fn default_index_on_startup() -> bool { true }

impl Default for IndexConfig {
    // Mirrors the serde defaults.
    fn default() -> Self {
        Self {
            path: None,
            index_on_startup: default_index_on_startup(),
            watch_for_changes: false,
        }
    }
}
568
569// =============================================================================
570// AI Ingestion Configuration
571// =============================================================================
572
/// LLM provider for AI-powered example generation.
///
/// Serialized with `rename_all = "lowercase"`: `"ollama"`, `"openai"`,
/// `"anthropic"`. The `FromStr` impl additionally accepts `"claude"` for
/// Anthropic.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AiProvider {
    /// Ollama for local inference (default)
    #[default]
    Ollama,
    /// OpenAI API
    OpenAi,
    /// Anthropic Claude API
    Anthropic,
}
585
586impl std::str::FromStr for AiProvider {
587    type Err = anyhow::Error;
588
589    fn from_str(s: &str) -> Result<Self, Self::Err> {
590        match s.to_lowercase().as_str() {
591            "ollama" => Ok(Self::Ollama),
592            "openai" => Ok(Self::OpenAi),
593            "anthropic" | "claude" => Ok(Self::Anthropic),
594            _ => anyhow::bail!("Unknown AI provider: {}. Options: ollama, openai, anthropic", s),
595        }
596    }
597}
598
599impl std::fmt::Display for AiProvider {
600    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
601        match self {
602            AiProvider::Ollama => write!(f, "ollama"),
603            AiProvider::OpenAi => write!(f, "openai"),
604            AiProvider::Anthropic => write!(f, "anthropic"),
605        }
606    }
607}
608
/// AI-powered ingestion configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AiIngestionConfig {
    /// Enable AI example generation during skill indexing (default: false)
    #[serde(default)]
    pub enabled: bool,

    /// Number of examples to generate per tool (default: 5); must be > 0
    /// when `enabled` (enforced by `SearchConfig::validate`)
    #[serde(default = "default_examples_per_tool")]
    pub examples_per_tool: usize,

    /// LLM provider for generation (default: ollama)
    #[serde(default)]
    pub provider: AiProvider,

    /// Model name (provider-specific); default: "llama3.2".
    // NOTE(review): because this default is non-empty, `get_model()` only
    // falls back to the per-provider model when callers explicitly clear
    // this field — confirm that is the intended precedence.
    #[serde(default = "default_ai_model")]
    pub model: String,

    /// Validate generated examples against tool schema (default: true)
    #[serde(default = "default_validate_examples")]
    pub validate_examples: bool,

    /// Stream generation progress to terminal/MCP (default: true)
    #[serde(default = "default_stream_progress")]
    pub stream_progress: bool,

    /// Cache generated examples (skip regeneration if tool unchanged;
    /// default: true)
    #[serde(default = "default_cache_examples")]
    pub cache_examples: bool,

    /// Timeout per tool generation in seconds (default: 30); must be > 0
    /// when `enabled`
    #[serde(default = "default_timeout_secs")]
    pub timeout_secs: u64,

    /// Ollama-specific configuration
    #[serde(default)]
    pub ollama: OllamaLlmConfig,

    /// OpenAI-specific configuration
    #[serde(default)]
    pub openai: OpenAiLlmConfig,

    /// Anthropic-specific configuration
    #[serde(default)]
    pub anthropic: AnthropicLlmConfig,
}

// serde default callbacks.
fn default_examples_per_tool() -> usize { 5 }
fn default_ai_model() -> String { "llama3.2".to_string() }
fn default_validate_examples() -> bool { true }
fn default_stream_progress() -> bool { true }
fn default_cache_examples() -> bool { true }
fn default_timeout_secs() -> u64 { 30 }

impl Default for AiIngestionConfig {
    // Mirrors the serde defaults (generation disabled out of the box).
    fn default() -> Self {
        Self {
            enabled: false,
            examples_per_tool: default_examples_per_tool(),
            provider: AiProvider::default(),
            model: default_ai_model(),
            validate_examples: default_validate_examples(),
            stream_progress: default_stream_progress(),
            cache_examples: default_cache_examples(),
            timeout_secs: default_timeout_secs(),
            ollama: OllamaLlmConfig::default(),
            openai: OpenAiLlmConfig::default(),
            anthropic: AnthropicLlmConfig::default(),
        }
    }
}
681
682impl AiIngestionConfig {
683    /// Get the model name for the current provider
684    pub fn get_model(&self) -> &str {
685        if !self.model.is_empty() {
686            return &self.model;
687        }
688        match self.provider {
689            AiProvider::Ollama => &self.ollama.model,
690            AiProvider::OpenAi => &self.openai.model,
691            AiProvider::Anthropic => &self.anthropic.model,
692        }
693    }
694}
695
/// Ollama LLM configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaLlmConfig {
    /// Ollama API host (default: "http://localhost:11434")
    #[serde(default = "default_ollama_host")]
    pub host: String,

    /// Model to use (if not set in parent config); default: "llama3.2"
    #[serde(default = "default_ollama_model")]
    pub model: String,
}

// serde default callbacks.
fn default_ollama_host() -> String { "http://localhost:11434".to_string() }
fn default_ollama_model() -> String { "llama3.2".to_string() }

impl Default for OllamaLlmConfig {
    // Mirrors the serde defaults.
    fn default() -> Self {
        Self {
            host: default_ollama_host(),
            model: default_ollama_model(),
        }
    }
}
719
/// OpenAI LLM configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAiLlmConfig {
    /// API key environment variable name (default: OPENAI_API_KEY; set by
    /// `SearchConfig::with_env_overrides` when the variable is present)
    #[serde(default)]
    pub api_key_env: Option<String>,

    /// Model to use (if not set in parent config); default: "gpt-4o-mini"
    #[serde(default = "default_openai_llm_model")]
    pub model: String,

    /// Max tokens for completion (default: 2048)
    #[serde(default = "default_openai_max_tokens")]
    pub max_tokens: u32,

    /// Temperature for generation (default: 0.7)
    #[serde(default = "default_temperature")]
    pub temperature: f32,
}

// serde default callbacks; `default_temperature` is shared with
// `AnthropicLlmConfig`.
fn default_openai_llm_model() -> String { "gpt-4o-mini".to_string() }
fn default_openai_max_tokens() -> u32 { 2048 }
fn default_temperature() -> f32 { 0.7 }

impl Default for OpenAiLlmConfig {
    // Mirrors the serde defaults.
    fn default() -> Self {
        Self {
            api_key_env: None,
            model: default_openai_llm_model(),
            max_tokens: default_openai_max_tokens(),
            temperature: default_temperature(),
        }
    }
}
754
/// Anthropic Claude LLM configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnthropicLlmConfig {
    /// API key environment variable name (default: ANTHROPIC_API_KEY; set by
    /// `SearchConfig::with_env_overrides` when the variable is present)
    #[serde(default)]
    pub api_key_env: Option<String>,

    /// Model to use (if not set in parent config);
    /// default: "claude-3-haiku-20240307"
    #[serde(default = "default_anthropic_model")]
    pub model: String,

    /// Max tokens for completion (default: 2048)
    #[serde(default = "default_anthropic_max_tokens")]
    pub max_tokens: u32,

    /// Temperature for generation (default: 0.7; shares the callback with
    /// `OpenAiLlmConfig`)
    #[serde(default = "default_temperature")]
    pub temperature: f32,
}

// serde default callbacks.
fn default_anthropic_model() -> String { "claude-3-haiku-20240307".to_string() }
fn default_anthropic_max_tokens() -> u32 { 2048 }

impl Default for AnthropicLlmConfig {
    // Mirrors the serde defaults.
    fn default() -> Self {
        Self {
            api_key_env: None,
            model: default_anthropic_model(),
            max_tokens: default_anthropic_max_tokens(),
            temperature: default_temperature(),
        }
    }
}
788
#[cfg(test)]
mod tests {
    // Unit tests covering: defaults, TOML parsing (wrapped and unwrapped
    // formats, minimal and empty inputs), validation invariants, string
    // parsing of enums, env-var overrides, and AI-ingestion behavior.
    use super::*;

    #[test]
    fn test_default_config() {
        // The out-of-the-box configuration must be internally consistent.
        let config = SearchConfig::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.embedding.provider, "fastembed");
        assert_eq!(config.embedding.dimensions, 384);
        assert!(config.retrieval.enable_hybrid);
        assert!(!config.reranker.enabled);
    }

    #[test]
    fn test_parse_toml() {
        // Wrapped format: all sections nested under [search].
        let toml = r#"
[search]
backend = { type = "qdrant" }

[search.embedding]
provider = "openai"
model = "text-embedding-3-small"
dimensions = 1536

[search.retrieval]
enable_hybrid = true
dense_weight = 0.8
sparse_weight = 0.2
final_k = 10

[search.reranker]
enabled = true
model = "bge-reranker-large"

[search.context]
max_total_tokens = 1000
compression = "progressive"

[search.qdrant]
url = "http://qdrant:6334"
collection = "my-tools"
"#;

        let config = SearchConfig::from_toml(toml).unwrap();

        assert!(matches!(config.backend.backend_type, BackendType::Qdrant));
        assert_eq!(config.embedding.provider, "openai");
        assert_eq!(config.embedding.dimensions, 1536);
        assert!((config.retrieval.dense_weight - 0.8).abs() < 0.001);
        assert_eq!(config.retrieval.final_k, 10);
        assert!(config.reranker.enabled);
        assert_eq!(config.reranker.model, "bge-reranker-large");
        assert!(matches!(config.context.compression, CompressionStrategy::Progressive));
        assert_eq!(config.qdrant.as_ref().unwrap().url, "http://qdrant:6334");
    }

    #[test]
    fn test_validation_weights() {
        let mut config = SearchConfig::default();
        config.retrieval.dense_weight = 0.5;
        config.retrieval.sparse_weight = 0.3; // Sum is 0.8, not 1.0

        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validation_k_values() {
        let mut config = SearchConfig::default();
        config.retrieval.final_k = 50;
        config.retrieval.rerank_k = 20; // final_k > rerank_k

        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validation_qdrant_required() {
        // The qdrant section is mandatory when the qdrant backend is chosen.
        let mut config = SearchConfig::default();
        config.backend.backend_type = BackendType::Qdrant;
        config.qdrant = None;

        assert!(config.validate().is_err());
    }

    #[test]
    fn test_backend_type_from_str() {
        assert!(matches!("in-memory".parse::<BackendType>().unwrap(), BackendType::InMemory));
        assert!(matches!("inmemory".parse::<BackendType>().unwrap(), BackendType::InMemory));
        assert!(matches!("qdrant".parse::<BackendType>().unwrap(), BackendType::Qdrant));
        assert!("invalid".parse::<BackendType>().is_err());
    }

    #[test]
    fn test_env_overrides() {
        // NOTE(review): this test mutates process-global environment state and
        // Rust runs tests in parallel by default; it can race with any other
        // test reading these variables. No other test currently does —
        // confirm before adding one, or serialize access with a lock.
        std::env::set_var("SKILL_SEARCH_BACKEND", "qdrant");
        std::env::set_var("SKILL_EMBEDDING_DIMENSIONS", "768");
        std::env::set_var("SKILL_RERANKER_ENABLED", "true");
        std::env::set_var("QDRANT_URL", "http://custom:6334");

        let config = SearchConfig::default().with_env_overrides();

        assert!(matches!(config.backend.backend_type, BackendType::Qdrant));
        assert_eq!(config.embedding.dimensions, 768);
        assert!(config.reranker.enabled);
        assert_eq!(config.qdrant.as_ref().unwrap().url, "http://custom:6334");

        // Clean up
        std::env::remove_var("SKILL_SEARCH_BACKEND");
        std::env::remove_var("SKILL_EMBEDDING_DIMENSIONS");
        std::env::remove_var("SKILL_RERANKER_ENABLED");
        std::env::remove_var("QDRANT_URL");
    }

    #[test]
    fn test_minimal_toml() {
        // A bare [search] header parses to an all-defaults, valid config.
        let toml = r#"
[search]
"#;

        let config = SearchConfig::from_toml(toml).unwrap();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_empty_file() {
        // An empty document takes the unwrapped path and yields defaults.
        let toml = "";
        let config = SearchConfig::from_toml(toml).unwrap();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_ai_ingestion_defaults() {
        let config = AiIngestionConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.examples_per_tool, 5);
        assert!(matches!(config.provider, AiProvider::Ollama));
        assert_eq!(config.model, "llama3.2");
        assert!(config.validate_examples);
        assert!(config.stream_progress);
        assert!(config.cache_examples);
        assert_eq!(config.timeout_secs, 30);
    }

    #[test]
    fn test_ai_provider_from_str() {
        assert!(matches!("ollama".parse::<AiProvider>().unwrap(), AiProvider::Ollama));
        assert!(matches!("openai".parse::<AiProvider>().unwrap(), AiProvider::OpenAi));
        assert!(matches!("anthropic".parse::<AiProvider>().unwrap(), AiProvider::Anthropic));
        assert!(matches!("claude".parse::<AiProvider>().unwrap(), AiProvider::Anthropic));
        assert!("invalid".parse::<AiProvider>().is_err());
    }

    #[test]
    fn test_ai_ingestion_toml_parsing() {
        // Unwrapped format exercising the [ai_ingestion] section and its
        // nested provider sub-table.
        let toml = r#"
[ai_ingestion]
enabled = true
examples_per_tool = 3
provider = "openai"
model = "gpt-4o"
validate_examples = false
stream_progress = true
timeout_secs = 60

[ai_ingestion.openai]
model = "gpt-4o-mini"
max_tokens = 4096
temperature = 0.5
"#;

        let config: SearchConfig = toml::from_str(toml).unwrap();
        assert!(config.ai_ingestion.enabled);
        assert_eq!(config.ai_ingestion.examples_per_tool, 3);
        assert!(matches!(config.ai_ingestion.provider, AiProvider::OpenAi));
        assert_eq!(config.ai_ingestion.model, "gpt-4o");
        assert!(!config.ai_ingestion.validate_examples);
        assert_eq!(config.ai_ingestion.timeout_secs, 60);
        assert_eq!(config.ai_ingestion.openai.model, "gpt-4o-mini");
        assert_eq!(config.ai_ingestion.openai.max_tokens, 4096);
        assert!((config.ai_ingestion.openai.temperature - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_ai_ingestion_validation() {
        let mut config = SearchConfig::default();
        config.ai_ingestion.enabled = true;
        config.ai_ingestion.examples_per_tool = 0;

        assert!(config.validate().is_err());

        config.ai_ingestion.examples_per_tool = 5;
        config.ai_ingestion.timeout_secs = 0;

        assert!(config.validate().is_err());
    }

    #[test]
    fn test_ai_ingestion_get_model() {
        let mut config = AiIngestionConfig::default();

        // Default model from provider config (only reachable with empty model)
        config.model = String::new();
        config.provider = AiProvider::Ollama;
        assert_eq!(config.get_model(), "llama3.2");

        config.provider = AiProvider::OpenAi;
        assert_eq!(config.get_model(), "gpt-4o-mini");

        config.provider = AiProvider::Anthropic;
        assert_eq!(config.get_model(), "claude-3-haiku-20240307");

        // Override with explicit model
        config.model = "custom-model".to_string();
        assert_eq!(config.get_model(), "custom-model");
    }
}