Skip to main content

zeph_core/
provider_factory.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Pure provider factory helpers: build `AnyProvider` instances from config entries.
//!
//! This module contains configuration-to-provider transformation functions that are
//! used by internal `zeph-core` subsystems (skills, tools, autodream, session config).
//! They are intentionally separated from bootstrap orchestration logic so that provider
//! construction can be reasoned about and tested independently of startup sequencing.

11use zeph_llm::any::AnyProvider;
12use zeph_llm::claude::ClaudeProvider;
13use zeph_llm::compatible::CompatibleProvider;
14use zeph_llm::gemini::GeminiProvider;
15use zeph_llm::http::llm_client;
16use zeph_llm::ollama::OllamaProvider;
17use zeph_llm::openai::OpenAiProvider;
18
19use crate::agent::state::ProviderConfigSnapshot;
20use crate::config::{Config, ProviderEntry, ProviderKind};
21
22/// Error type for provider construction failures.
23///
24/// String-based variants flatten the error chain intentionally: bootstrap errors are
25/// terminal (the application exits), so downcasting is not needed at this stage.
26/// If a future phase requires programmatic retry on specific failures, expand these
27/// variants into typed sub-errors.
28#[derive(Debug, thiserror::Error)]
29pub enum BootstrapError {
30    /// Configuration validation failed.
31    #[error("config error: {0}")]
32    Config(#[from] crate::config::ConfigError),
33    /// Provider construction failed (missing secrets, unsupported kind, etc.).
34    #[error("provider error: {0}")]
35    Provider(String),
36    /// Memory subsystem initialization failed.
37    #[error("memory error: {0}")]
38    Memory(String),
39    /// Age vault initialization failed.
40    #[error("vault init error: {0}")]
41    VaultInit(crate::vault::AgeVaultError),
42    /// I/O error during bootstrap.
43    #[error("I/O error: {0}")]
44    Io(#[from] std::io::Error),
45}
46
47/// Build an `AnyProvider` from a `ProviderEntry` using a runtime config snapshot.
48///
49/// Called by the `/provider <name>` slash command to switch providers at runtime without
50/// requiring the full `Config`. Router and Orchestrator provider kinds are not supported
51/// for runtime switching — they require the full provider pool to be re-initialized.
52///
53/// # Errors
54///
55/// Returns `BootstrapError::Provider` when the provider kind is unsupported for runtime
56/// switching, a required secret is missing, or the entry is misconfigured.
57pub fn build_provider_for_switch(
58    entry: &ProviderEntry,
59    snapshot: &ProviderConfigSnapshot,
60) -> Result<AnyProvider, BootstrapError> {
61    use zeph_common::secret::Secret;
62    // Reconstruct a minimal Config from the snapshot so we can reuse build_provider_from_entry.
63    // Only fields read by build_provider_from_entry are populated; everything else uses defaults.
64    // Secrets are stored as plain strings in the snapshot because Secret does not implement Clone.
65    let mut config = Config::default();
66    config.secrets.claude_api_key = snapshot.claude_api_key.as_deref().map(Secret::new);
67    config.secrets.openai_api_key = snapshot.openai_api_key.as_deref().map(Secret::new);
68    config.secrets.gemini_api_key = snapshot.gemini_api_key.as_deref().map(Secret::new);
69    config.secrets.compatible_api_keys = snapshot
70        .compatible_api_keys
71        .iter()
72        .map(|(k, v)| (k.clone(), Secret::new(v.as_str())))
73        .collect();
74    config.timeouts.llm_request_timeout_secs = snapshot.llm_request_timeout_secs;
75    config
76        .llm
77        .embedding_model
78        .clone_from(&snapshot.embedding_model);
79    build_provider_from_entry(entry, &config)
80}
81
82/// Build an `AnyProvider` from a unified `ProviderEntry` (new `[[llm.providers]]` format).
83///
84/// All provider-specific fields come from `entry`; the global `config` is used only for
85/// secrets and timeout settings.
86///
87/// # Errors
88///
89/// Returns `BootstrapError::Provider` when a required secret is missing or an entry is
90/// misconfigured (e.g. compatible provider without a name).
91#[allow(clippy::too_many_lines)]
92pub fn build_provider_from_entry(
93    entry: &ProviderEntry,
94    config: &Config,
95) -> Result<AnyProvider, BootstrapError> {
96    match entry.provider_type {
97        ProviderKind::Ollama => {
98            let base_url = entry
99                .base_url
100                .as_deref()
101                .unwrap_or("http://localhost:11434");
102            let model = entry.model.as_deref().unwrap_or("qwen3:8b").to_owned();
103            let embed = entry
104                .embedding_model
105                .clone()
106                .unwrap_or_else(|| config.llm.embedding_model.clone());
107            let mut provider = OllamaProvider::new(base_url, model, embed);
108            if let Some(ref vm) = entry.vision_model {
109                provider = provider.with_vision_model(vm.clone());
110            }
111            Ok(AnyProvider::Ollama(provider))
112        }
113        ProviderKind::Claude => {
114            let api_key = config
115                .secrets
116                .claude_api_key
117                .as_ref()
118                .ok_or_else(|| {
119                    BootstrapError::Provider("ZEPH_CLAUDE_API_KEY not found in vault".into())
120                })?
121                .expose()
122                .to_owned();
123            let model = entry
124                .model
125                .clone()
126                .unwrap_or_else(|| "claude-haiku-4-5-20251001".to_owned());
127            let max_tokens = entry.max_tokens.unwrap_or(4096);
128            let provider = ClaudeProvider::new(api_key, model, max_tokens)
129                .with_client(llm_client(config.timeouts.llm_request_timeout_secs))
130                .with_extended_context(entry.enable_extended_context)
131                .with_thinking_opt(entry.thinking.clone())
132                .map_err(|e| BootstrapError::Provider(format!("invalid thinking config: {e}")))?
133                .with_server_compaction(entry.server_compaction);
134            Ok(AnyProvider::Claude(provider))
135        }
136        ProviderKind::OpenAi => {
137            let api_key = config
138                .secrets
139                .openai_api_key
140                .as_ref()
141                .ok_or_else(|| {
142                    BootstrapError::Provider("ZEPH_OPENAI_API_KEY not found in vault".into())
143                })?
144                .expose()
145                .to_owned();
146            let base_url = entry
147                .base_url
148                .clone()
149                .unwrap_or_else(|| "https://api.openai.com/v1".to_owned());
150            let model = entry
151                .model
152                .clone()
153                .unwrap_or_else(|| "gpt-4o-mini".to_owned());
154            let max_tokens = entry.max_tokens.unwrap_or(4096);
155            Ok(AnyProvider::OpenAi(
156                OpenAiProvider::new(
157                    api_key,
158                    base_url,
159                    model,
160                    max_tokens,
161                    entry.embedding_model.clone(),
162                    entry.reasoning_effort.clone(),
163                )
164                .with_client(llm_client(config.timeouts.llm_request_timeout_secs)),
165            ))
166        }
167        ProviderKind::Gemini => {
168            let api_key = config
169                .secrets
170                .gemini_api_key
171                .as_ref()
172                .ok_or_else(|| {
173                    BootstrapError::Provider("ZEPH_GEMINI_API_KEY not found in vault".into())
174                })?
175                .expose()
176                .to_owned();
177            let model = entry
178                .model
179                .clone()
180                .unwrap_or_else(|| "gemini-2.0-flash".to_owned());
181            let max_tokens = entry.max_tokens.unwrap_or(8192);
182            let base_url = entry
183                .base_url
184                .clone()
185                .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_owned());
186            let mut provider = GeminiProvider::new(api_key, model, max_tokens)
187                .with_base_url(base_url)
188                .with_client(llm_client(config.timeouts.llm_request_timeout_secs));
189            if let Some(ref em) = entry.embedding_model {
190                provider = provider.with_embedding_model(em.clone());
191            }
192            if let Some(level) = entry.thinking_level {
193                provider = provider.with_thinking_level(level);
194            }
195            if let Some(budget) = entry.thinking_budget {
196                provider = provider
197                    .with_thinking_budget(budget)
198                    .map_err(|e| BootstrapError::Provider(e.to_string()))?;
199            }
200            if let Some(include) = entry.include_thoughts {
201                provider = provider.with_include_thoughts(include);
202            }
203            Ok(AnyProvider::Gemini(provider))
204        }
205        ProviderKind::Compatible => {
206            let name = entry.name.as_deref().ok_or_else(|| {
207                BootstrapError::Provider(
208                    "compatible provider requires 'name' field in [[llm.providers]]".into(),
209                )
210            })?;
211            let base_url = entry.base_url.clone().ok_or_else(|| {
212                BootstrapError::Provider(format!(
213                    "compatible provider '{name}' requires 'base_url'"
214                ))
215            })?;
216            let model = entry.model.clone().unwrap_or_default();
217            let api_key = entry.api_key.clone().unwrap_or_else(|| {
218                config
219                    .secrets
220                    .compatible_api_keys
221                    .get(name)
222                    .map(|s| s.expose().to_owned())
223                    .unwrap_or_default()
224            });
225            let max_tokens = entry.max_tokens.unwrap_or(4096);
226            Ok(AnyProvider::Compatible(CompatibleProvider::new(
227                name.to_owned(),
228                api_key,
229                base_url,
230                model,
231                max_tokens,
232                entry.embedding_model.clone(),
233            )))
234        }
235        #[cfg(feature = "candle")]
236        ProviderKind::Candle => {
237            let candle = entry.candle.as_ref().ok_or_else(|| {
238                BootstrapError::Provider(
239                    "candle provider requires 'candle' section in [[llm.providers]]".into(),
240                )
241            })?;
242            let source = match candle.source.as_str() {
243                "local" => zeph_llm::candle_provider::loader::ModelSource::Local {
244                    path: std::path::PathBuf::from(&candle.local_path),
245                },
246                _ => zeph_llm::candle_provider::loader::ModelSource::HuggingFace {
247                    repo_id: entry
248                        .model
249                        .clone()
250                        .unwrap_or_else(|| config.llm.effective_model().to_owned()),
251                    filename: candle.filename.clone(),
252                },
253            };
254            let template =
255                zeph_llm::candle_provider::template::ChatTemplate::parse_str(&candle.chat_template);
256            let gen_config = zeph_llm::candle_provider::generate::GenerationConfig {
257                temperature: candle.generation.temperature,
258                top_p: candle.generation.top_p,
259                top_k: candle.generation.top_k,
260                max_tokens: candle.generation.capped_max_tokens(),
261                seed: candle.generation.seed,
262                repeat_penalty: candle.generation.repeat_penalty,
263                repeat_last_n: candle.generation.repeat_last_n,
264            };
265            let device = select_device(&candle.device)?;
266            // Floor at 1s so that inference_timeout_secs = 0 does not cause every request to
267            // immediately time out.
268            let inference_timeout =
269                std::time::Duration::from_secs(candle.inference_timeout_secs.max(1));
270            zeph_llm::candle_provider::CandleProvider::new_with_timeout(
271                &source,
272                template,
273                gen_config,
274                candle.embedding_repo.as_deref(),
275                candle.hf_token.as_deref(),
276                device,
277                inference_timeout,
278            )
279            .map(AnyProvider::Candle)
280            .map_err(|e| BootstrapError::Provider(e.to_string()))
281        }
282        #[cfg(not(feature = "candle"))]
283        ProviderKind::Candle => Err(BootstrapError::Provider(
284            "candle feature is not enabled".into(),
285        )),
286    }
287}
288
/// Resolve a candle compute device from a textual preference.
///
/// * `"metal"` — Metal GPU 0; requires the `metal` feature.
/// * `"cuda"` — CUDA GPU 0; requires the `cuda` feature.
/// * `"auto"` — the first GPU that opens successfully, trying Metal then CUDA (each only when
///   its feature is compiled in), falling back to CPU.
/// * any other value — CPU (matching is exact and case-sensitive).
///
/// # Errors
///
/// Returns `BootstrapError::Provider` when `"metal"` or `"cuda"` is requested but the
/// matching feature is not compiled in, or when opening that explicitly requested device
/// fails.
#[cfg(feature = "candle")]
pub fn select_device(
    preference: &str,
) -> Result<zeph_llm::candle_provider::Device, BootstrapError> {
    use zeph_llm::candle_provider::Device;

    match preference {
        "cuda" => {
            #[cfg(feature = "cuda")]
            return Device::new_cuda(0).map_err(|err| BootstrapError::Provider(err.to_string()));
            #[cfg(not(feature = "cuda"))]
            return Err(BootstrapError::Provider(
                "candle compiled without cuda feature".into(),
            ));
        }
        "metal" => {
            #[cfg(feature = "metal")]
            return Device::new_metal(0).map_err(|err| BootstrapError::Provider(err.to_string()));
            #[cfg(not(feature = "metal"))]
            return Err(BootstrapError::Provider(
                "candle compiled without metal feature".into(),
            ));
        }
        "auto" => {
            // Best-effort probe: a GPU that fails to open is skipped, not reported.
            #[cfg(feature = "metal")]
            if let Ok(gpu) = Device::new_metal(0) {
                return Ok(gpu);
            }
            #[cfg(feature = "cuda")]
            if let Ok(gpu) = Device::new_cuda(0) {
                return Ok(gpu);
            }
            Ok(Device::Cpu)
        }
        _ => Ok(Device::Cpu),
    }
}
336
337/// Determine the effective embedding model name for the memory subsystem.
338///
339/// Resolution order:
340/// 1. `embedding_model` from the `[[llm.providers]]` entry marked `embed = true`
341/// 2. `embedding_model` from the first entry in `[[llm.providers]]`
342/// 3. `[llm] embedding_model` global fallback
343#[must_use]
344pub fn effective_embedding_model(config: &Config) -> String {
345    // Prefer a dedicated embed provider.
346    if let Some(m) = config
347        .llm
348        .providers
349        .iter()
350        .find(|e| e.embed)
351        .and_then(|e| e.embedding_model.as_ref())
352    {
353        return m.clone();
354    }
355    // Fall back to the first provider's embedding model.
356    if let Some(m) = config
357        .llm
358        .providers
359        .first()
360        .and_then(|e| e.embedding_model.as_ref())
361    {
362        return m.clone();
363    }
364    config.llm.embedding_model.clone()
365}
366
#[cfg(test)]
mod tests {
    // Every test here exercises `select_device`, which only exists with the `candle` feature.
    #[cfg(feature = "candle")]
    use super::select_device;
    #[cfg(feature = "candle")]
    use zeph_llm::candle_provider::Device;

    #[cfg(feature = "candle")]
    #[test]
    fn select_device_cpu_default() {
        assert!(matches!(select_device("cpu"), Ok(Device::Cpu)));
    }

    #[cfg(feature = "candle")]
    #[test]
    fn select_device_unknown_defaults_to_cpu() {
        assert!(matches!(select_device("unknown"), Ok(Device::Cpu)));
    }

    #[cfg(all(feature = "candle", not(feature = "metal")))]
    #[test]
    fn select_device_metal_without_feature_errors() {
        let err = select_device("metal").expect_err("metal must be rejected without the feature");
        assert!(err.to_string().contains("metal feature"));
    }

    #[cfg(all(feature = "candle", not(feature = "cuda")))]
    #[test]
    fn select_device_cuda_without_feature_errors() {
        let err = select_device("cuda").expect_err("cuda must be rejected without the feature");
        assert!(err.to_string().contains("cuda feature"));
    }

    #[cfg(feature = "candle")]
    #[test]
    fn select_device_auto_fallback() {
        let device = select_device("auto").expect("auto must always resolve to a device");
        assert!(matches!(
            device,
            Device::Cpu | Device::Cuda(_) | Device::Metal(_)
        ));
    }
}