Skip to main content

zeph_core/
provider_factory.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Pure provider factory helpers: build `AnyProvider` instances from config entries.
5//!
6//! This module contains configuration-to-provider transformation functions that are
7//! used by internal `zeph-core` subsystems (skills, tools, autodream, session config).
8//! They are intentionally separated from bootstrap orchestration logic so that provider
9//! construction can be reasoned about and tested independently of startup sequencing.
10
11use zeph_llm::any::AnyProvider;
12use zeph_llm::claude::ClaudeProvider;
13use zeph_llm::compatible::CompatibleProvider;
14use zeph_llm::gemini::GeminiProvider;
15use zeph_llm::http::llm_client;
16use zeph_llm::ollama::OllamaProvider;
17use zeph_llm::openai::OpenAiProvider;
18
19use crate::agent::state::ProviderConfigSnapshot;
20use crate::config::{Config, ProviderEntry, ProviderKind};
21
22/// Error type for provider construction failures.
23///
24/// String-based variants flatten the error chain intentionally: bootstrap errors are
25/// terminal (the application exits), so downcasting is not needed at this stage.
26/// If a future phase requires programmatic retry on specific failures, expand these
27/// variants into typed sub-errors.
28#[derive(Debug, thiserror::Error)]
29pub enum BootstrapError {
30    /// Configuration validation failed.
31    #[error("config error: {0}")]
32    Config(#[from] crate::config::ConfigError),
33    /// Provider construction failed (missing secrets, unsupported kind, etc.).
34    #[error("provider error: {0}")]
35    Provider(String),
36    /// Memory subsystem initialization failed.
37    #[error("memory error: {0}")]
38    Memory(String),
39    /// Age vault initialization failed.
40    #[error("vault init error: {0}")]
41    VaultInit(crate::vault::AgeVaultError),
42    /// I/O error during bootstrap.
43    #[error("I/O error: {0}")]
44    Io(#[from] std::io::Error),
45}
46
47/// Build an `AnyProvider` from a `ProviderEntry` using a runtime config snapshot.
48///
49/// Called by the `/provider <name>` slash command to switch providers at runtime without
50/// requiring the full `Config`. Router and Orchestrator provider kinds are not supported
51/// for runtime switching — they require the full provider pool to be re-initialized.
52///
53/// # Errors
54///
55/// Returns `BootstrapError::Provider` when the provider kind is unsupported for runtime
56/// switching, a required secret is missing, or the entry is misconfigured.
57pub fn build_provider_for_switch(
58    entry: &ProviderEntry,
59    snapshot: &ProviderConfigSnapshot,
60) -> Result<AnyProvider, BootstrapError> {
61    use zeph_common::secret::Secret;
62    // Reconstruct a minimal Config from the snapshot so we can reuse build_provider_from_entry.
63    // Only fields read by build_provider_from_entry are populated; everything else uses defaults.
64    // Secrets are stored as plain strings in the snapshot because Secret does not implement Clone.
65    let mut config = Config::default();
66    config.secrets.claude_api_key = snapshot.claude_api_key.as_deref().map(Secret::new);
67    config.secrets.openai_api_key = snapshot.openai_api_key.as_deref().map(Secret::new);
68    config.secrets.gemini_api_key = snapshot.gemini_api_key.as_deref().map(Secret::new);
69    config.secrets.compatible_api_keys = snapshot
70        .compatible_api_keys
71        .iter()
72        .map(|(k, v)| (k.clone(), Secret::new(v.as_str())))
73        .collect();
74    config.timeouts.llm_request_timeout_secs = snapshot.llm_request_timeout_secs;
75    config
76        .llm
77        .embedding_model
78        .clone_from(&snapshot.embedding_model);
79    build_provider_from_entry(entry, &config)
80}
81
82/// Build an `AnyProvider` from a unified `ProviderEntry` (new `[[llm.providers]]` format).
83///
84/// All provider-specific fields come from `entry`; the global `config` is used only for
85/// secrets and timeout settings.
86///
87/// # Errors
88///
89/// Returns `BootstrapError::Provider` when a required secret is missing or an entry is
90/// misconfigured (e.g. compatible provider without a name).
91#[allow(clippy::too_many_lines)]
92pub fn build_provider_from_entry(
93    entry: &ProviderEntry,
94    config: &Config,
95) -> Result<AnyProvider, BootstrapError> {
96    match entry.provider_type {
97        ProviderKind::Ollama => {
98            let base_url = entry
99                .base_url
100                .as_deref()
101                .unwrap_or("http://localhost:11434");
102            let model = entry.model.as_deref().unwrap_or("qwen3:8b").to_owned();
103            let embed = entry
104                .embedding_model
105                .clone()
106                .unwrap_or_else(|| config.llm.embedding_model.clone());
107            let mut provider = OllamaProvider::new(base_url, model, embed);
108            if let Some(ref vm) = entry.vision_model {
109                provider = provider.with_vision_model(vm.clone());
110            }
111            if config.mcp.forward_output_schema {
112                tracing::debug!(
113                    "mcp.forward_output_schema is enabled but Ollama does not support \
114                     output schema forwarding; setting ignored for this provider"
115                );
116            }
117            Ok(AnyProvider::Ollama(provider))
118        }
119        ProviderKind::Claude => {
120            let api_key = config
121                .secrets
122                .claude_api_key
123                .as_ref()
124                .ok_or_else(|| {
125                    BootstrapError::Provider("ZEPH_CLAUDE_API_KEY not found in vault".into())
126                })?
127                .expose()
128                .to_owned();
129            let model = entry
130                .model
131                .clone()
132                .unwrap_or_else(|| "claude-haiku-4-5-20251001".to_owned());
133            let max_tokens = entry.max_tokens.unwrap_or(4096);
134            let provider = ClaudeProvider::new(api_key, model, max_tokens)
135                .with_client(llm_client(config.timeouts.llm_request_timeout_secs))
136                .with_extended_context(entry.enable_extended_context)
137                .with_thinking_opt(entry.thinking.clone())
138                .map_err(|e| BootstrapError::Provider(format!("invalid thinking config: {e}")))?
139                .with_server_compaction(entry.server_compaction)
140                .with_prompt_cache_ttl(entry.prompt_cache_ttl)
141                .with_output_schema_forwarding(
142                    config.mcp.forward_output_schema,
143                    config.mcp.output_schema_hint_bytes,
144                    config.mcp.max_description_bytes,
145                );
146            tracing::info!(
147                forward = config.mcp.forward_output_schema,
148                "mcp.output_schema.forwarding_configured"
149            );
150            Ok(AnyProvider::Claude(provider))
151        }
152        ProviderKind::OpenAi => {
153            let api_key = config
154                .secrets
155                .openai_api_key
156                .as_ref()
157                .ok_or_else(|| {
158                    BootstrapError::Provider("ZEPH_OPENAI_API_KEY not found in vault".into())
159                })?
160                .expose()
161                .to_owned();
162            let base_url = entry
163                .base_url
164                .clone()
165                .unwrap_or_else(|| "https://api.openai.com/v1".to_owned());
166            let model = entry
167                .model
168                .clone()
169                .unwrap_or_else(|| "gpt-4o-mini".to_owned());
170            let max_tokens = entry.max_tokens.unwrap_or(4096);
171            Ok(AnyProvider::OpenAi(
172                OpenAiProvider::new(
173                    api_key,
174                    base_url,
175                    model,
176                    max_tokens,
177                    entry.embedding_model.clone(),
178                    entry.reasoning_effort.clone(),
179                )
180                .with_client(llm_client(config.timeouts.llm_request_timeout_secs))
181                .with_output_schema_forwarding(
182                    config.mcp.forward_output_schema,
183                    config.mcp.output_schema_hint_bytes,
184                    config.mcp.max_description_bytes,
185                ),
186            ))
187        }
188        ProviderKind::Gemini => {
189            let api_key = config
190                .secrets
191                .gemini_api_key
192                .as_ref()
193                .ok_or_else(|| {
194                    BootstrapError::Provider("ZEPH_GEMINI_API_KEY not found in vault".into())
195                })?
196                .expose()
197                .to_owned();
198            let model = entry
199                .model
200                .clone()
201                .unwrap_or_else(|| "gemini-2.0-flash".to_owned());
202            let max_tokens = entry.max_tokens.unwrap_or(8192);
203            let base_url = entry
204                .base_url
205                .clone()
206                .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_owned());
207            let mut provider = GeminiProvider::new(api_key, model, max_tokens)
208                .with_base_url(base_url)
209                .with_client(llm_client(config.timeouts.llm_request_timeout_secs));
210            if let Some(ref em) = entry.embedding_model {
211                provider = provider.with_embedding_model(em.clone());
212            }
213            if let Some(level) = entry.thinking_level {
214                provider = provider.with_thinking_level(level);
215            }
216            if let Some(budget) = entry.thinking_budget {
217                provider = provider
218                    .with_thinking_budget(budget)
219                    .map_err(|e| BootstrapError::Provider(e.to_string()))?;
220            }
221            if let Some(include) = entry.include_thoughts {
222                provider = provider.with_include_thoughts(include);
223            }
224            if config.mcp.forward_output_schema {
225                tracing::debug!(
226                    "mcp.forward_output_schema is enabled but Gemini does not support \
227                     output schema forwarding; setting ignored for this provider"
228                );
229            }
230            Ok(AnyProvider::Gemini(provider))
231        }
232        ProviderKind::Compatible => {
233            let name = entry.name.as_deref().ok_or_else(|| {
234                BootstrapError::Provider(
235                    "compatible provider requires 'name' field in [[llm.providers]]".into(),
236                )
237            })?;
238            let base_url = entry.base_url.clone().ok_or_else(|| {
239                BootstrapError::Provider(format!(
240                    "compatible provider '{name}' requires 'base_url'"
241                ))
242            })?;
243            let model = entry.model.clone().unwrap_or_default();
244            let api_key = entry.api_key.clone().unwrap_or_else(|| {
245                config
246                    .secrets
247                    .compatible_api_keys
248                    .get(name)
249                    .map(|s| s.expose().to_owned())
250                    .unwrap_or_default()
251            });
252            let max_tokens = entry.max_tokens.unwrap_or(4096);
253            let provider = CompatibleProvider::new(
254                name.to_owned(),
255                api_key,
256                base_url,
257                model,
258                max_tokens,
259                entry.embedding_model.clone(),
260            )
261            .with_output_schema_forwarding(
262                config.mcp.forward_output_schema,
263                config.mcp.output_schema_hint_bytes,
264                config.mcp.max_description_bytes,
265            );
266            tracing::info!(
267                forward = config.mcp.forward_output_schema,
268                provider = name,
269                "mcp.output_schema.forwarding_configured"
270            );
271            Ok(AnyProvider::Compatible(provider))
272        }
273        #[cfg(feature = "candle")]
274        ProviderKind::Candle => {
275            let candle = entry.candle.as_ref().ok_or_else(|| {
276                BootstrapError::Provider(
277                    "candle provider requires 'candle' section in [[llm.providers]]".into(),
278                )
279            })?;
280            let source = match candle.source.as_str() {
281                "local" => zeph_llm::candle_provider::loader::ModelSource::Local {
282                    path: std::path::PathBuf::from(&candle.local_path),
283                },
284                _ => zeph_llm::candle_provider::loader::ModelSource::HuggingFace {
285                    repo_id: entry
286                        .model
287                        .clone()
288                        .unwrap_or_else(|| config.llm.effective_model().to_owned()),
289                    filename: candle.filename.clone(),
290                },
291            };
292            let template =
293                zeph_llm::candle_provider::template::ChatTemplate::parse_str(&candle.chat_template);
294            let gen_config = zeph_llm::candle_provider::generate::GenerationConfig {
295                temperature: candle.generation.temperature,
296                top_p: candle.generation.top_p,
297                top_k: candle.generation.top_k,
298                max_tokens: candle.generation.capped_max_tokens(),
299                seed: candle.generation.seed,
300                repeat_penalty: candle.generation.repeat_penalty,
301                repeat_last_n: candle.generation.repeat_last_n,
302            };
303            let device = select_device(&candle.device)?;
304            // Floor at 1s so that inference_timeout_secs = 0 does not cause every request to
305            // immediately time out.
306            let inference_timeout =
307                std::time::Duration::from_secs(candle.inference_timeout_secs.max(1));
308            zeph_llm::candle_provider::CandleProvider::new_with_timeout(
309                &source,
310                template,
311                gen_config,
312                candle.embedding_repo.as_deref(),
313                candle.hf_token.as_deref(),
314                device,
315                inference_timeout,
316            )
317            .map(AnyProvider::Candle)
318            .map_err(|e| BootstrapError::Provider(e.to_string()))
319        }
320        #[cfg(not(feature = "candle"))]
321        ProviderKind::Candle => Err(BootstrapError::Provider(
322            "candle feature is not enabled".into(),
323        )),
324    }
325}
326
/// Pick the candle compute device named by `preference`.
///
/// Accepted values:
/// - `"metal"`: Metal GPU 0; needs the `metal` feature.
/// - `"cuda"`: CUDA GPU 0; needs the `cuda` feature.
/// - `"auto"`: the first accelerator that initializes successfully (Metal, then CUDA, each
///   only when its feature is compiled in). Initialization failures are ignored and the
///   CPU is used as the final fallback.
/// - anything else (including `"cpu"`): CPU.
///
/// # Errors
///
/// Returns `BootstrapError::Provider` in two cases:
/// - `"metal"`/`"cuda"` is requested but the matching feature is not compiled in;
/// - initializing the explicitly requested device fails.
#[cfg(feature = "candle")]
pub fn select_device(
    preference: &str,
) -> Result<zeph_llm::candle_provider::Device, BootstrapError> {
    use zeph_llm::candle_provider::Device;

    match preference {
        "metal" => {
            #[cfg(not(feature = "metal"))]
            return Err(BootstrapError::Provider(
                "candle compiled without metal feature".into(),
            ));
            #[cfg(feature = "metal")]
            return Device::new_metal(0).map_err(|err| BootstrapError::Provider(err.to_string()));
        }
        "cuda" => {
            #[cfg(not(feature = "cuda"))]
            return Err(BootstrapError::Provider(
                "candle compiled without cuda feature".into(),
            ));
            #[cfg(feature = "cuda")]
            return Device::new_cuda(0).map_err(|err| BootstrapError::Provider(err.to_string()));
        }
        "auto" => {
            // Probe accelerators in priority order; the first one that initializes wins.
            #[cfg(feature = "metal")]
            if let Ok(gpu) = Device::new_metal(0) {
                return Ok(gpu);
            }
            #[cfg(feature = "cuda")]
            if let Ok(gpu) = Device::new_cuda(0) {
                return Ok(gpu);
            }
            Ok(Device::Cpu)
        }
        // "cpu" and any unrecognized preference.
        _ => Ok(Device::Cpu),
    }
}
374
375/// Determine the effective embedding model name for the memory subsystem.
376///
377/// Resolution order:
378/// 1. `embedding_model` from the `[[llm.providers]]` entry marked `embed = true`
379/// 2. `embedding_model` from the first entry in `[[llm.providers]]`
380/// 3. `[llm] embedding_model` global fallback
381#[must_use]
382pub fn effective_embedding_model(config: &Config) -> String {
383    // Prefer a dedicated embed provider.
384    if let Some(m) = config
385        .llm
386        .providers
387        .iter()
388        .find(|e| e.embed)
389        .and_then(|e| e.embedding_model.as_ref())
390    {
391        return m.clone();
392    }
393    // Fall back to the first provider's embedding model.
394    if let Some(m) = config
395        .llm
396        .providers
397        .first()
398        .and_then(|e| e.embedding_model.as_ref())
399    {
400        return m.clone();
401    }
402    config.llm.embedding_model.clone()
403}
404
#[cfg(test)]
mod tests {
    #[cfg(feature = "candle")]
    use super::select_device;
    #[cfg(feature = "candle")]
    use zeph_llm::candle_provider::Device;

    /// An explicit `"cpu"` preference resolves to the CPU device.
    #[cfg(feature = "candle")]
    #[test]
    fn select_device_cpu_default() {
        let resolved = select_device("cpu").expect("\"cpu\" must always resolve");
        assert!(matches!(resolved, Device::Cpu));
    }

    /// Unrecognized preferences fall back to the CPU rather than erroring.
    #[cfg(feature = "candle")]
    #[test]
    fn select_device_unknown_defaults_to_cpu() {
        let resolved = select_device("unknown").expect("unknown preferences fall back to CPU");
        assert!(matches!(resolved, Device::Cpu));
    }

    /// Requesting Metal without the `metal` feature is a hard error.
    #[cfg(all(feature = "candle", not(feature = "metal")))]
    #[test]
    fn select_device_metal_without_feature_errors() {
        let message = select_device("metal")
            .expect_err("\"metal\" must be rejected without the metal feature")
            .to_string();
        assert!(message.contains("metal feature"), "unexpected error: {message}");
    }

    /// Requesting CUDA without the `cuda` feature is a hard error.
    #[cfg(all(feature = "candle", not(feature = "cuda")))]
    #[test]
    fn select_device_cuda_without_feature_errors() {
        let message = select_device("cuda")
            .expect_err("\"cuda\" must be rejected without the cuda feature")
            .to_string();
        assert!(message.contains("cuda feature"), "unexpected error: {message}");
    }

    /// `"auto"` always succeeds, landing on some supported device.
    #[cfg(feature = "candle")]
    #[test]
    fn select_device_auto_fallback() {
        let resolved = select_device("auto").expect("\"auto\" never fails");
        assert!(matches!(
            resolved,
            Device::Cpu | Device::Cuda(_) | Device::Metal(_)
        ));
    }
}