// zeph_core/provider_factory.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Pure provider factory helpers: build `AnyProvider` instances from config entries.
5//!
6//! This module contains configuration-to-provider transformation functions that are
7//! used by internal `zeph-core` subsystems (skills, tools, autodream, session config).
8//! They are intentionally separated from bootstrap orchestration logic so that provider
9//! construction can be reasoned about and tested independently of startup sequencing.
10
11use zeph_llm::any::AnyProvider;
12use zeph_llm::claude::ClaudeProvider;
13use zeph_llm::compatible::CompatibleProvider;
14use zeph_llm::gemini::GeminiProvider;
15use zeph_llm::http::llm_client;
16use zeph_llm::ollama::OllamaProvider;
17use zeph_llm::openai::OpenAiProvider;
18
19use crate::agent::state::ProviderConfigSnapshot;
20use crate::config::{Config, ProviderEntry, ProviderKind};
21
/// Error type for provider construction failures.
///
/// String-based variants flatten the error chain intentionally: bootstrap errors are
/// terminal (the application exits), so downcasting is not needed at this stage.
/// If a future phase requires programmatic retry on specific failures, expand these
/// variants into typed sub-errors.
#[derive(Debug, thiserror::Error)]
pub enum BootstrapError {
    /// Configuration validation failed.
    ///
    /// Converted automatically from [`crate::config::ConfigError`] via `#[from]`.
    #[error("config error: {0}")]
    Config(#[from] crate::config::ConfigError),
    /// Provider construction failed (missing secrets, unsupported kind, etc.).
    #[error("provider error: {0}")]
    Provider(String),
    /// Memory subsystem initialization failed.
    #[error("memory error: {0}")]
    Memory(String),
    /// Age vault initialization failed.
    ///
    /// NOTE(review): unlike `Config` and `Io`, this variant has no `#[from]`, so
    /// callers must wrap the inner error explicitly — confirm that is intentional.
    #[error("vault init error: {0}")]
    VaultInit(crate::vault::AgeVaultError),
    /// I/O error during bootstrap.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
}
46
47/// Build an `AnyProvider` from a `ProviderEntry` using a runtime config snapshot.
48///
49/// Called by the `/provider <name>` slash command to switch providers at runtime without
50/// requiring the full `Config`. Router and Orchestrator provider kinds are not supported
51/// for runtime switching — they require the full provider pool to be re-initialized.
52///
53/// # Errors
54///
55/// Returns `BootstrapError::Provider` when the provider kind is unsupported for runtime
56/// switching, a required secret is missing, or the entry is misconfigured.
57pub fn build_provider_for_switch(
58    entry: &ProviderEntry,
59    snapshot: &ProviderConfigSnapshot,
60) -> Result<AnyProvider, BootstrapError> {
61    use zeph_common::secret::Secret;
62    // Reconstruct a minimal Config from the snapshot so we can reuse build_provider_from_entry.
63    // Only fields read by build_provider_from_entry are populated; everything else uses defaults.
64    // Secrets are stored as plain strings in the snapshot because Secret does not implement Clone.
65    let mut config = Config::default();
66    config.secrets.claude_api_key = snapshot.claude_api_key.as_deref().map(Secret::new);
67    config.secrets.openai_api_key = snapshot.openai_api_key.as_deref().map(Secret::new);
68    config.secrets.gemini_api_key = snapshot.gemini_api_key.as_deref().map(Secret::new);
69    config.secrets.compatible_api_keys = snapshot
70        .compatible_api_keys
71        .iter()
72        .map(|(k, v)| (k.clone(), Secret::new(v.as_str())))
73        .collect();
74    config.timeouts.llm_request_timeout_secs = snapshot.llm_request_timeout_secs;
75    config
76        .llm
77        .embedding_model
78        .clone_from(&snapshot.embedding_model);
79    build_provider_from_entry(entry, &config)
80}
81
/// Build an `AnyProvider` from a unified `ProviderEntry` (new `[[llm.providers]]` format).
///
/// All provider-specific fields come from `entry`; the global `config` is used only for
/// secrets and timeout settings.
///
/// # Errors
///
/// Returns `BootstrapError::Provider` when a required secret is missing or an entry is
/// misconfigured (e.g. compatible provider without a name).
pub fn build_provider_from_entry(
    entry: &ProviderEntry,
    config: &Config,
) -> Result<AnyProvider, BootstrapError> {
    match entry.provider_type {
        // Ollama needs no API key, so its builder is infallible.
        ProviderKind::Ollama => Ok(build_ollama_provider(entry, config)),
        ProviderKind::Claude => build_claude_provider(entry, config),
        ProviderKind::OpenAi => build_openai_provider(entry, config),
        ProviderKind::Gemini => build_gemini_provider(entry, config),
        ProviderKind::Compatible => build_compatible_provider(entry, config),
        // Candle is feature-gated: without the feature the variant still exists in the
        // enum, so it must be matched and rejected at runtime with a clear message.
        #[cfg(feature = "candle")]
        ProviderKind::Candle => build_candle_provider(entry, config),
        #[cfg(not(feature = "candle"))]
        ProviderKind::Candle => Err(BootstrapError::Provider(
            "candle feature is not enabled".into(),
        )),
    }
}
109
110fn build_ollama_provider(entry: &ProviderEntry, config: &Config) -> AnyProvider {
111    let base_url = entry
112        .base_url
113        .as_deref()
114        .unwrap_or("http://localhost:11434");
115    let model = entry.model.as_deref().unwrap_or("qwen3:8b").to_owned();
116    let embed = entry
117        .embedding_model
118        .clone()
119        .unwrap_or_else(|| config.llm.embedding_model.clone());
120    let mut provider = OllamaProvider::new(base_url, model, embed);
121    if let Some(ref vm) = entry.vision_model {
122        provider = provider.with_vision_model(vm.clone());
123    }
124    if config.mcp.forward_output_schema {
125        tracing::debug!(
126            "mcp.forward_output_schema is enabled but Ollama does not support \
127             output schema forwarding; setting ignored for this provider"
128        );
129    }
130    AnyProvider::Ollama(provider)
131}
132
133fn build_claude_provider(
134    entry: &ProviderEntry,
135    config: &Config,
136) -> Result<AnyProvider, BootstrapError> {
137    let api_key = config
138        .secrets
139        .claude_api_key
140        .as_ref()
141        .ok_or_else(|| BootstrapError::Provider("ZEPH_CLAUDE_API_KEY not found in vault".into()))?
142        .expose()
143        .to_owned();
144    let model = entry
145        .model
146        .clone()
147        .unwrap_or_else(|| "claude-haiku-4-5-20251001".to_owned());
148    let max_tokens = entry.max_tokens.unwrap_or(4096);
149    let provider = ClaudeProvider::new(api_key, model, max_tokens)
150        .with_client(llm_client(config.timeouts.llm_request_timeout_secs))
151        .with_extended_context(entry.enable_extended_context)
152        .with_thinking_opt(entry.thinking.clone())
153        .map_err(|e| BootstrapError::Provider(format!("invalid thinking config: {e}")))?
154        .with_server_compaction(entry.server_compaction)
155        .with_prompt_cache_ttl(entry.prompt_cache_ttl)
156        .with_output_schema_forwarding(
157            config.mcp.forward_output_schema,
158            config.mcp.output_schema_hint_bytes,
159            config.mcp.max_description_bytes,
160        );
161    tracing::info!(
162        forward = config.mcp.forward_output_schema,
163        "mcp.output_schema.forwarding_configured"
164    );
165    Ok(AnyProvider::Claude(provider))
166}
167
168fn build_openai_provider(
169    entry: &ProviderEntry,
170    config: &Config,
171) -> Result<AnyProvider, BootstrapError> {
172    let api_key = config
173        .secrets
174        .openai_api_key
175        .as_ref()
176        .ok_or_else(|| BootstrapError::Provider("ZEPH_OPENAI_API_KEY not found in vault".into()))?
177        .expose()
178        .to_owned();
179    let base_url = entry
180        .base_url
181        .clone()
182        .unwrap_or_else(|| "https://api.openai.com/v1".to_owned());
183    let model = entry
184        .model
185        .clone()
186        .unwrap_or_else(|| "gpt-4o-mini".to_owned());
187    let max_tokens = entry.max_tokens.unwrap_or(4096);
188    Ok(AnyProvider::OpenAi(
189        OpenAiProvider::new(
190            api_key,
191            base_url,
192            model,
193            max_tokens,
194            entry.embedding_model.clone(),
195            entry.reasoning_effort.clone(),
196        )
197        .with_client(llm_client(config.timeouts.llm_request_timeout_secs))
198        .with_output_schema_forwarding(
199            config.mcp.forward_output_schema,
200            config.mcp.output_schema_hint_bytes,
201            config.mcp.max_description_bytes,
202        ),
203    ))
204}
205
206fn build_gemini_provider(
207    entry: &ProviderEntry,
208    config: &Config,
209) -> Result<AnyProvider, BootstrapError> {
210    let api_key = config
211        .secrets
212        .gemini_api_key
213        .as_ref()
214        .ok_or_else(|| BootstrapError::Provider("ZEPH_GEMINI_API_KEY not found in vault".into()))?
215        .expose()
216        .to_owned();
217    let model = entry
218        .model
219        .clone()
220        .unwrap_or_else(|| "gemini-2.0-flash".to_owned());
221    let max_tokens = entry.max_tokens.unwrap_or(8192);
222    let base_url = entry
223        .base_url
224        .clone()
225        .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_owned());
226    let mut provider = GeminiProvider::new(api_key, model, max_tokens)
227        .with_base_url(base_url)
228        .with_client(llm_client(config.timeouts.llm_request_timeout_secs));
229    if let Some(ref em) = entry.embedding_model {
230        provider = provider.with_embedding_model(em.clone());
231    }
232    if let Some(level) = entry.thinking_level {
233        provider = provider.with_thinking_level(level);
234    }
235    if let Some(budget) = entry.thinking_budget {
236        provider = provider
237            .with_thinking_budget(budget)
238            .map_err(|e| BootstrapError::Provider(e.to_string()))?;
239    }
240    if let Some(include) = entry.include_thoughts {
241        provider = provider.with_include_thoughts(include);
242    }
243    if config.mcp.forward_output_schema {
244        tracing::debug!(
245            "mcp.forward_output_schema is enabled but Gemini does not support \
246             output schema forwarding; setting ignored for this provider"
247        );
248    }
249    Ok(AnyProvider::Gemini(provider))
250}
251
252fn build_compatible_provider(
253    entry: &ProviderEntry,
254    config: &Config,
255) -> Result<AnyProvider, BootstrapError> {
256    let name = entry.name.as_deref().ok_or_else(|| {
257        BootstrapError::Provider(
258            "compatible provider requires 'name' field in [[llm.providers]]".into(),
259        )
260    })?;
261    let base_url = entry.base_url.clone().ok_or_else(|| {
262        BootstrapError::Provider(format!("compatible provider '{name}' requires 'base_url'"))
263    })?;
264    let model = entry.model.clone().unwrap_or_default();
265    let api_key = entry.api_key.clone().unwrap_or_else(|| {
266        config
267            .secrets
268            .compatible_api_keys
269            .get(name)
270            .map(|s| s.expose().to_owned())
271            .unwrap_or_default()
272    });
273    let max_tokens = entry.max_tokens.unwrap_or(4096);
274    let provider = CompatibleProvider::new(
275        name.to_owned(),
276        api_key,
277        base_url,
278        model,
279        max_tokens,
280        entry.embedding_model.clone(),
281    )
282    .with_output_schema_forwarding(
283        config.mcp.forward_output_schema,
284        config.mcp.output_schema_hint_bytes,
285        config.mcp.max_description_bytes,
286    );
287    tracing::info!(
288        forward = config.mcp.forward_output_schema,
289        provider = name,
290        "mcp.output_schema.forwarding_configured"
291    );
292    Ok(AnyProvider::Compatible(provider))
293}
294
/// Build a local Candle inference provider from the entry's `candle` sub-table.
#[cfg(feature = "candle")]
fn build_candle_provider(
    entry: &ProviderEntry,
    config: &Config,
) -> Result<AnyProvider, BootstrapError> {
    // A candle entry must carry its own `candle` table with model/runtime settings.
    let candle = entry.candle.as_ref().ok_or_else(|| {
        BootstrapError::Provider(
            "candle provider requires 'candle' section in [[llm.providers]]".into(),
        )
    })?;
    // "local" loads weights from `local_path`; ANY other source string (including
    // typos) falls through to a HuggingFace download. NOTE(review): confirm this
    // catch-all is intentional rather than an explicit "hf" match with an error arm.
    let source = match candle.source.as_str() {
        "local" => zeph_llm::candle_provider::loader::ModelSource::Local {
            path: std::path::PathBuf::from(&candle.local_path),
        },
        _ => zeph_llm::candle_provider::loader::ModelSource::HuggingFace {
            // Repo id: entry's model, falling back to the globally effective model.
            repo_id: entry
                .model
                .clone()
                .unwrap_or_else(|| config.llm.effective_model().to_owned()),
            filename: candle.filename.clone(),
        },
    };
    // Parse the configured chat template string into the provider's representation.
    let template =
        zeph_llm::candle_provider::template::ChatTemplate::parse_str(&candle.chat_template);
    // Sampling parameters copied verbatim from config; max_tokens is capped upstream.
    let gen_config = zeph_llm::candle_provider::generate::GenerationConfig {
        temperature: candle.generation.temperature,
        top_p: candle.generation.top_p,
        top_k: candle.generation.top_k,
        max_tokens: candle.generation.capped_max_tokens(),
        seed: candle.generation.seed,
        repeat_penalty: candle.generation.repeat_penalty,
        repeat_last_n: candle.generation.repeat_last_n,
    };
    let device = select_device(&candle.device)?;
    // Floor at 1s so that inference_timeout_secs = 0 does not cause every request to
    // immediately time out.
    let inference_timeout = std::time::Duration::from_secs(candle.inference_timeout_secs.max(1));
    zeph_llm::candle_provider::CandleProvider::new_with_timeout(
        &source,
        template,
        gen_config,
        candle.embedding_repo.as_deref(),
        candle.hf_token.as_deref(),
        device,
        inference_timeout,
    )
    .map(AnyProvider::Candle)
    .map_err(|e| BootstrapError::Provider(e.to_string()))
}
344
/// Select the candle compute device based on a string preference.
///
/// Resolution order: `"metal"` → Metal GPU (requires `metal` feature),
/// `"cuda"` → CUDA GPU (requires `cuda` feature), `"auto"` → best available,
/// anything else → CPU.
///
/// # Errors
///
/// Returns `BootstrapError::Provider` when the requested device is not available (e.g.
/// `"metal"` requested but compiled without the `metal` feature).
#[cfg(feature = "candle")]
pub fn select_device(
    preference: &str,
) -> Result<zeph_llm::candle_provider::Device, BootstrapError> {
    match preference {
        // Explicit Metal request: hard error (not a silent CPU fallback) when the
        // binary was compiled without the feature, or when device init fails.
        "metal" => {
            #[cfg(feature = "metal")]
            return zeph_llm::candle_provider::Device::new_metal(0)
                .map_err(|e| BootstrapError::Provider(e.to_string()));
            #[cfg(not(feature = "metal"))]
            return Err(BootstrapError::Provider(
                "candle compiled without metal feature".into(),
            ));
        }
        // Explicit CUDA request: same hard-error policy as "metal".
        "cuda" => {
            #[cfg(feature = "cuda")]
            return zeph_llm::candle_provider::Device::new_cuda(0)
                .map_err(|e| BootstrapError::Provider(e.to_string()));
            #[cfg(not(feature = "cuda"))]
            return Err(BootstrapError::Provider(
                "candle compiled without cuda feature".into(),
            ));
        }
        // "auto": try Metal, then CUDA, each only if compiled in; failures here are
        // tolerated and we quietly fall back to CPU.
        "auto" => {
            #[cfg(feature = "metal")]
            if let Ok(device) = zeph_llm::candle_provider::Device::new_metal(0) {
                return Ok(device);
            }
            #[cfg(feature = "cuda")]
            if let Ok(device) = zeph_llm::candle_provider::Device::new_cuda(0) {
                return Ok(device);
            }
            Ok(zeph_llm::candle_provider::Device::Cpu)
        }
        // Unknown preference strings (including "cpu") resolve to CPU.
        _ => Ok(zeph_llm::candle_provider::Device::Cpu),
    }
}
392
393/// Determine the effective embedding model name for the memory subsystem.
394///
395/// Resolution order:
396/// 1. `embedding_model` from the `[[llm.providers]]` entry marked `embed = true`
397/// 2. `embedding_model` from the first entry in `[[llm.providers]]`
398/// 3. `[llm] embedding_model` global fallback
399#[must_use]
400pub fn effective_embedding_model(config: &Config) -> String {
401    // Prefer a dedicated embed provider.
402    if let Some(m) = config
403        .llm
404        .providers
405        .iter()
406        .find(|e| e.embed)
407        .and_then(|e| e.embedding_model.as_ref())
408    {
409        return m.clone();
410    }
411    // Fall back to the first provider's embedding model.
412    if let Some(m) = config
413        .llm
414        .providers
415        .first()
416        .and_then(|e| e.embedding_model.as_ref())
417    {
418        return m.clone();
419    }
420    config.llm.embedding_model.clone()
421}
422
423/// Resolve the stable embedding model name for skill-matcher collection versioning.
424///
425/// This uses the same entry resolution as the embedding provider itself: the entry
426/// with `embed = true`, preferring its `embedding_model` field and falling back to
427/// its `model` field. Using the actual provider's model name prevents the
428/// `model_has_changed` check in [`zeph_memory::embedding_registry`] from triggering
429/// false positives that would rebuild the `zeph_skills` collection on every startup.
430///
431/// Falls back to [`effective_embedding_model`] when no dedicated embed entry exists.
432#[must_use]
433pub fn stable_skill_embedding_model(config: &Config) -> String {
434    // Find the dedicated embed entry (same lookup as `create_embedding_provider`).
435    let embed_entry = config.llm.providers.iter().find(|e| e.embed).or_else(|| {
436        config
437            .llm
438            .providers
439            .iter()
440            .find(|e| e.embedding_model.is_some())
441    });
442
443    if let Some(entry) = embed_entry {
444        // Prefer the explicit `embedding_model` field; fall back to the `model` field.
445        if let Some(em) = entry.embedding_model.as_ref().filter(|s| !s.is_empty()) {
446            return em.clone();
447        }
448        if let Some(m) = entry.model.as_ref().filter(|s| !s.is_empty()) {
449            return m.clone();
450        }
451    }
452
453    // No dedicated embed entry — fall back to the general embedding model resolution.
454    effective_embedding_model(config)
455}
456
#[cfg(test)]
mod tests {
    // All module imports consolidated here (previously some sat mid-module, after the
    // candle tests). `ProviderEntry` now comes through the crate-level re-export used
    // by the rest of this file (`crate::config`) instead of `zeph_config::providers`,
    // so tests and production code name the same path.
    #[cfg(feature = "candle")]
    use super::select_device;
    use super::{effective_embedding_model, stable_skill_embedding_model};
    use crate::config::{Config, ProviderEntry, ProviderKind};

    #[cfg(feature = "candle")]
    #[test]
    fn select_device_cpu_default() {
        let device = select_device("cpu").unwrap();
        assert!(matches!(device, zeph_llm::candle_provider::Device::Cpu));
    }

    #[cfg(feature = "candle")]
    #[test]
    fn select_device_unknown_defaults_to_cpu() {
        let device = select_device("unknown").unwrap();
        assert!(matches!(device, zeph_llm::candle_provider::Device::Cpu));
    }

    #[cfg(all(feature = "candle", not(feature = "metal")))]
    #[test]
    fn select_device_metal_without_feature_errors() {
        let result = select_device("metal");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("metal feature"));
    }

    #[cfg(all(feature = "candle", not(feature = "cuda")))]
    #[test]
    fn select_device_cuda_without_feature_errors() {
        let result = select_device("cuda");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cuda feature"));
    }

    #[cfg(feature = "candle")]
    #[test]
    fn select_device_auto_fallback() {
        let device = select_device("auto").unwrap();
        assert!(matches!(
            device,
            zeph_llm::candle_provider::Device::Cpu
                | zeph_llm::candle_provider::Device::Cuda(_)
                | zeph_llm::candle_provider::Device::Metal(_)
        ));
    }

    /// Build a minimal Ollama `ProviderEntry` carrying only the fields these tests read.
    fn make_provider_entry(
        embed: bool,
        model: Option<&str>,
        embedding_model: Option<&str>,
    ) -> ProviderEntry {
        ProviderEntry {
            provider_type: ProviderKind::Ollama,
            embed,
            model: model.map(str::to_owned),
            embedding_model: embedding_model.map(str::to_owned),
            ..ProviderEntry::default()
        }
    }

    #[test]
    fn stable_skill_embedding_model_prefers_embedding_model_field() {
        let mut config = Config::default();
        config.llm.providers = vec![make_provider_entry(
            true,
            Some("chat-model"),
            Some("embed-v2"),
        )];
        assert_eq!(stable_skill_embedding_model(&config), "embed-v2");
    }

    #[test]
    fn stable_skill_embedding_model_falls_back_to_model_field() {
        let mut config = Config::default();
        config.llm.providers = vec![make_provider_entry(
            true,
            Some("nomic-embed-text-v2-moe:latest"),
            None,
        )];
        assert_eq!(
            stable_skill_embedding_model(&config),
            "nomic-embed-text-v2-moe:latest"
        );
    }

    #[test]
    fn stable_skill_embedding_model_finds_embed_flag_entry() {
        let mut config = Config::default();
        config.llm.providers = vec![
            make_provider_entry(false, Some("chat-model"), None),
            make_provider_entry(true, Some("embed-model"), Some("text-embed-3")),
        ];
        assert_eq!(stable_skill_embedding_model(&config), "text-embed-3");
    }

    #[test]
    fn stable_skill_embedding_model_falls_back_to_effective_when_no_embed_entry() {
        let mut config = Config::default();
        config.llm.embedding_model = "global-embed-model".to_owned();
        // No embed=true entry, no embedding_model field set — falls back to effective_embedding_model.
        config.llm.providers = vec![make_provider_entry(false, Some("chat"), None)];
        assert_eq!(
            stable_skill_embedding_model(&config),
            effective_embedding_model(&config)
        );
    }
}