1use zeph_llm::any::AnyProvider;
12use zeph_llm::claude::ClaudeProvider;
13use zeph_llm::compatible::CompatibleProvider;
14use zeph_llm::gemini::GeminiProvider;
15use zeph_llm::http::llm_client;
16use zeph_llm::ollama::OllamaProvider;
17use zeph_llm::openai::OpenAiProvider;
18
19use crate::agent::state::ProviderConfigSnapshot;
20use crate::config::{Config, ProviderEntry, ProviderKind};
21
22#[derive(Debug, thiserror::Error)]
29pub enum BootstrapError {
30 #[error("config error: {0}")]
32 Config(#[from] crate::config::ConfigError),
33 #[error("provider error: {0}")]
35 Provider(String),
36 #[error("memory error: {0}")]
38 Memory(String),
39 #[error("vault init error: {0}")]
41 VaultInit(crate::vault::AgeVaultError),
42 #[error("I/O error: {0}")]
44 Io(#[from] std::io::Error),
45}
46
47pub fn build_provider_for_switch(
58 entry: &ProviderEntry,
59 snapshot: &ProviderConfigSnapshot,
60) -> Result<AnyProvider, BootstrapError> {
61 use zeph_common::secret::Secret;
62 let mut config = Config::default();
66 config.secrets.claude_api_key = snapshot.claude_api_key.as_deref().map(Secret::new);
67 config.secrets.openai_api_key = snapshot.openai_api_key.as_deref().map(Secret::new);
68 config.secrets.gemini_api_key = snapshot.gemini_api_key.as_deref().map(Secret::new);
69 config.secrets.compatible_api_keys = snapshot
70 .compatible_api_keys
71 .iter()
72 .map(|(k, v)| (k.clone(), Secret::new(v.as_str())))
73 .collect();
74 config.timeouts.llm_request_timeout_secs = snapshot.llm_request_timeout_secs;
75 config
76 .llm
77 .embedding_model
78 .clone_from(&snapshot.embedding_model);
79 build_provider_from_entry(entry, &config)
80}
81
/// Dispatch to the kind-specific constructor for `entry.provider_type`.
///
/// # Errors
///
/// Returns [`BootstrapError::Provider`] when the selected provider cannot be
/// built (missing API key, missing required fields, or a provider kind whose
/// cargo feature is not compiled in).
pub fn build_provider_from_entry(
    entry: &ProviderEntry,
    config: &Config,
) -> Result<AnyProvider, BootstrapError> {
    match entry.provider_type {
        // Ollama construction is infallible: every field has a default.
        ProviderKind::Ollama => Ok(build_ollama_provider(entry, config)),
        ProviderKind::Claude => build_claude_provider(entry, config),
        ProviderKind::OpenAi => build_openai_provider(entry, config),
        ProviderKind::Gemini => build_gemini_provider(entry, config),
        ProviderKind::Compatible => build_compatible_provider(entry, config),
        // Candle only exists when compiled with the `candle` feature;
        // otherwise selecting it is a configuration error at runtime.
        #[cfg(feature = "candle")]
        ProviderKind::Candle => build_candle_provider(entry, config),
        #[cfg(not(feature = "candle"))]
        ProviderKind::Candle => Err(BootstrapError::Provider(
            "candle feature is not enabled".into(),
        )),
    }
}
109
110fn build_ollama_provider(entry: &ProviderEntry, config: &Config) -> AnyProvider {
111 let base_url = entry
112 .base_url
113 .as_deref()
114 .unwrap_or("http://localhost:11434");
115 let model = entry.model.as_deref().unwrap_or("qwen3:8b").to_owned();
116 let embed = entry
117 .embedding_model
118 .clone()
119 .unwrap_or_else(|| config.llm.embedding_model.clone());
120 let mut provider = OllamaProvider::new(base_url, model, embed);
121 if let Some(ref vm) = entry.vision_model {
122 provider = provider.with_vision_model(vm.clone());
123 }
124 if config.mcp.forward_output_schema {
125 tracing::debug!(
126 "mcp.forward_output_schema is enabled but Ollama does not support \
127 output schema forwarding; setting ignored for this provider"
128 );
129 }
130 AnyProvider::Ollama(provider)
131}
132
133fn build_claude_provider(
134 entry: &ProviderEntry,
135 config: &Config,
136) -> Result<AnyProvider, BootstrapError> {
137 let api_key = config
138 .secrets
139 .claude_api_key
140 .as_ref()
141 .ok_or_else(|| BootstrapError::Provider("ZEPH_CLAUDE_API_KEY not found in vault".into()))?
142 .expose()
143 .to_owned();
144 let model = entry
145 .model
146 .clone()
147 .unwrap_or_else(|| "claude-haiku-4-5-20251001".to_owned());
148 let max_tokens = entry.max_tokens.unwrap_or(4096);
149 let provider = ClaudeProvider::new(api_key, model, max_tokens)
150 .with_client(llm_client(config.timeouts.llm_request_timeout_secs))
151 .with_extended_context(entry.enable_extended_context)
152 .with_thinking_opt(entry.thinking.clone())
153 .map_err(|e| BootstrapError::Provider(format!("invalid thinking config: {e}")))?
154 .with_server_compaction(entry.server_compaction)
155 .with_prompt_cache_ttl(entry.prompt_cache_ttl)
156 .with_output_schema_forwarding(
157 config.mcp.forward_output_schema,
158 config.mcp.output_schema_hint_bytes,
159 config.mcp.max_description_bytes,
160 );
161 tracing::info!(
162 forward = config.mcp.forward_output_schema,
163 "mcp.output_schema.forwarding_configured"
164 );
165 Ok(AnyProvider::Claude(provider))
166}
167
168fn build_openai_provider(
169 entry: &ProviderEntry,
170 config: &Config,
171) -> Result<AnyProvider, BootstrapError> {
172 let api_key = config
173 .secrets
174 .openai_api_key
175 .as_ref()
176 .ok_or_else(|| BootstrapError::Provider("ZEPH_OPENAI_API_KEY not found in vault".into()))?
177 .expose()
178 .to_owned();
179 let base_url = entry
180 .base_url
181 .clone()
182 .unwrap_or_else(|| "https://api.openai.com/v1".to_owned());
183 let model = entry
184 .model
185 .clone()
186 .unwrap_or_else(|| "gpt-4o-mini".to_owned());
187 let max_tokens = entry.max_tokens.unwrap_or(4096);
188 Ok(AnyProvider::OpenAi(
189 OpenAiProvider::new(
190 api_key,
191 base_url,
192 model,
193 max_tokens,
194 entry.embedding_model.clone(),
195 entry.reasoning_effort.clone(),
196 )
197 .with_client(llm_client(config.timeouts.llm_request_timeout_secs))
198 .with_output_schema_forwarding(
199 config.mcp.forward_output_schema,
200 config.mcp.output_schema_hint_bytes,
201 config.mcp.max_description_bytes,
202 ),
203 ))
204}
205
206fn build_gemini_provider(
207 entry: &ProviderEntry,
208 config: &Config,
209) -> Result<AnyProvider, BootstrapError> {
210 let api_key = config
211 .secrets
212 .gemini_api_key
213 .as_ref()
214 .ok_or_else(|| BootstrapError::Provider("ZEPH_GEMINI_API_KEY not found in vault".into()))?
215 .expose()
216 .to_owned();
217 let model = entry
218 .model
219 .clone()
220 .unwrap_or_else(|| "gemini-2.0-flash".to_owned());
221 let max_tokens = entry.max_tokens.unwrap_or(8192);
222 let base_url = entry
223 .base_url
224 .clone()
225 .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_owned());
226 let mut provider = GeminiProvider::new(api_key, model, max_tokens)
227 .with_base_url(base_url)
228 .with_client(llm_client(config.timeouts.llm_request_timeout_secs));
229 if let Some(ref em) = entry.embedding_model {
230 provider = provider.with_embedding_model(em.clone());
231 }
232 if let Some(level) = entry.thinking_level {
233 provider = provider.with_thinking_level(level);
234 }
235 if let Some(budget) = entry.thinking_budget {
236 provider = provider
237 .with_thinking_budget(budget)
238 .map_err(|e| BootstrapError::Provider(e.to_string()))?;
239 }
240 if let Some(include) = entry.include_thoughts {
241 provider = provider.with_include_thoughts(include);
242 }
243 if config.mcp.forward_output_schema {
244 tracing::debug!(
245 "mcp.forward_output_schema is enabled but Gemini does not support \
246 output schema forwarding; setting ignored for this provider"
247 );
248 }
249 Ok(AnyProvider::Gemini(provider))
250}
251
252fn build_compatible_provider(
253 entry: &ProviderEntry,
254 config: &Config,
255) -> Result<AnyProvider, BootstrapError> {
256 let name = entry.name.as_deref().ok_or_else(|| {
257 BootstrapError::Provider(
258 "compatible provider requires 'name' field in [[llm.providers]]".into(),
259 )
260 })?;
261 let base_url = entry.base_url.clone().ok_or_else(|| {
262 BootstrapError::Provider(format!("compatible provider '{name}' requires 'base_url'"))
263 })?;
264 let model = entry.model.clone().unwrap_or_default();
265 let api_key = entry.api_key.clone().unwrap_or_else(|| {
266 config
267 .secrets
268 .compatible_api_keys
269 .get(name)
270 .map(|s| s.expose().to_owned())
271 .unwrap_or_default()
272 });
273 let max_tokens = entry.max_tokens.unwrap_or(4096);
274 let provider = CompatibleProvider::new(
275 name.to_owned(),
276 api_key,
277 base_url,
278 model,
279 max_tokens,
280 entry.embedding_model.clone(),
281 )
282 .with_output_schema_forwarding(
283 config.mcp.forward_output_schema,
284 config.mcp.output_schema_hint_bytes,
285 config.mcp.max_description_bytes,
286 );
287 tracing::info!(
288 forward = config.mcp.forward_output_schema,
289 provider = name,
290 "mcp.output_schema.forwarding_configured"
291 );
292 Ok(AnyProvider::Compatible(provider))
293}
294
/// Build an in-process candle provider from the entry's `candle` section.
///
/// # Errors
///
/// Returns [`BootstrapError::Provider`] when the `candle` section is missing,
/// the requested device cannot be initialised, or model loading fails.
#[cfg(feature = "candle")]
fn build_candle_provider(
    entry: &ProviderEntry,
    config: &Config,
) -> Result<AnyProvider, BootstrapError> {
    let candle = entry.candle.as_ref().ok_or_else(|| {
        BootstrapError::Provider(
            "candle provider requires 'candle' section in [[llm.providers]]".into(),
        )
    })?;
    // "local" loads weights from a filesystem path; any OTHER value (including
    // typos) is treated as a HuggingFace repo — NOTE(review): consider
    // rejecting unknown `source` strings explicitly instead.
    let source = match candle.source.as_str() {
        "local" => zeph_llm::candle_provider::loader::ModelSource::Local {
            path: std::path::PathBuf::from(&candle.local_path),
        },
        _ => zeph_llm::candle_provider::loader::ModelSource::HuggingFace {
            // Fall back to the globally-configured model when the entry has none.
            repo_id: entry
                .model
                .clone()
                .unwrap_or_else(|| config.llm.effective_model().to_owned()),
            filename: candle.filename.clone(),
        },
    };
    let template =
        zeph_llm::candle_provider::template::ChatTemplate::parse_str(&candle.chat_template);
    // Sampling parameters copied verbatim from config; max_tokens is capped
    // by the config type itself via capped_max_tokens().
    let gen_config = zeph_llm::candle_provider::generate::GenerationConfig {
        temperature: candle.generation.temperature,
        top_p: candle.generation.top_p,
        top_k: candle.generation.top_k,
        max_tokens: candle.generation.capped_max_tokens(),
        seed: candle.generation.seed,
        repeat_penalty: candle.generation.repeat_penalty,
        repeat_last_n: candle.generation.repeat_last_n,
    };
    let device = select_device(&candle.device)?;
    // Clamp to at least 1s so a zero in config cannot yield an instant timeout.
    let inference_timeout = std::time::Duration::from_secs(candle.inference_timeout_secs.max(1));
    zeph_llm::candle_provider::CandleProvider::new_with_timeout(
        &source,
        template,
        gen_config,
        candle.embedding_repo.as_deref(),
        candle.hf_token.as_deref(),
        device,
        inference_timeout,
    )
    .map(AnyProvider::Candle)
    .map_err(|e| BootstrapError::Provider(e.to_string()))
}
344
/// Resolve a device preference string (`"metal"`, `"cuda"`, `"auto"`, or
/// anything else) to a concrete candle compute device.
///
/// `"auto"` probes Metal first, then CUDA, then falls back to CPU. Any
/// unrecognised preference deliberately selects CPU (behavior pinned by the
/// unit tests below).
///
/// # Errors
///
/// Returns [`BootstrapError::Provider`] when an explicitly requested backend
/// (`"metal"`/`"cuda"`) is not compiled in or fails to initialise.
#[cfg(feature = "candle")]
pub fn select_device(
    preference: &str,
) -> Result<zeph_llm::candle_provider::Device, BootstrapError> {
    match preference {
        "metal" => {
            #[cfg(feature = "metal")]
            return zeph_llm::candle_provider::Device::new_metal(0)
                .map_err(|e| BootstrapError::Provider(e.to_string()));
            // Explicit request for a backend that was not compiled in is an error.
            #[cfg(not(feature = "metal"))]
            return Err(BootstrapError::Provider(
                "candle compiled without metal feature".into(),
            ));
        }
        "cuda" => {
            #[cfg(feature = "cuda")]
            return zeph_llm::candle_provider::Device::new_cuda(0)
                .map_err(|e| BootstrapError::Provider(e.to_string()));
            #[cfg(not(feature = "cuda"))]
            return Err(BootstrapError::Provider(
                "candle compiled without cuda feature".into(),
            ));
        }
        "auto" => {
            // Best-effort probing: initialisation failures fall through to CPU.
            #[cfg(feature = "metal")]
            if let Ok(device) = zeph_llm::candle_provider::Device::new_metal(0) {
                return Ok(device);
            }
            #[cfg(feature = "cuda")]
            if let Ok(device) = zeph_llm::candle_provider::Device::new_cuda(0) {
                return Ok(device);
            }
            Ok(zeph_llm::candle_provider::Device::Cpu)
        }
        _ => Ok(zeph_llm::candle_provider::Device::Cpu),
    }
}
392
393#[must_use]
400pub fn effective_embedding_model(config: &Config) -> String {
401 if let Some(m) = config
403 .llm
404 .providers
405 .iter()
406 .find(|e| e.embed)
407 .and_then(|e| e.embedding_model.as_ref())
408 {
409 return m.clone();
410 }
411 if let Some(m) = config
413 .llm
414 .providers
415 .first()
416 .and_then(|e| e.embedding_model.as_ref())
417 {
418 return m.clone();
419 }
420 config.llm.embedding_model.clone()
421}
422
423#[must_use]
433pub fn stable_skill_embedding_model(config: &Config) -> String {
434 let embed_entry = config.llm.providers.iter().find(|e| e.embed).or_else(|| {
436 config
437 .llm
438 .providers
439 .iter()
440 .find(|e| e.embedding_model.is_some())
441 });
442
443 if let Some(entry) = embed_entry {
444 if let Some(em) = entry.embedding_model.as_ref().filter(|s| !s.is_empty()) {
446 return em.clone();
447 }
448 if let Some(m) = entry.model.as_ref().filter(|s| !s.is_empty()) {
449 return m.clone();
450 }
451 }
452
453 effective_embedding_model(config)
455}
456
#[cfg(test)]
mod tests {
    #[cfg(feature = "candle")]
    use super::select_device;

    // --- select_device: only compiled with the `candle` feature ---

    #[cfg(feature = "candle")]
    #[test]
    fn select_device_cpu_default() {
        let device = select_device("cpu").unwrap();
        assert!(matches!(device, zeph_llm::candle_provider::Device::Cpu));
    }

    // Unknown preference strings deliberately fall back to CPU rather than error.
    #[cfg(feature = "candle")]
    #[test]
    fn select_device_unknown_defaults_to_cpu() {
        let device = select_device("unknown").unwrap();
        assert!(matches!(device, zeph_llm::candle_provider::Device::Cpu));
    }

    // Explicitly requesting a backend that was not compiled in must error.
    #[cfg(all(feature = "candle", not(feature = "metal")))]
    #[test]
    fn select_device_metal_without_feature_errors() {
        let result = select_device("metal");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("metal feature"));
    }

    #[cfg(all(feature = "candle", not(feature = "cuda")))]
    #[test]
    fn select_device_cuda_without_feature_errors() {
        let result = select_device("cuda");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cuda feature"));
    }

    // "auto" may land on any backend depending on build features and hardware.
    #[cfg(feature = "candle")]
    #[test]
    fn select_device_auto_fallback() {
        let device = select_device("auto").unwrap();
        assert!(matches!(
            device,
            zeph_llm::candle_provider::Device::Cpu
                | zeph_llm::candle_provider::Device::Cuda(_)
                | zeph_llm::candle_provider::Device::Metal(_)
        ));
    }

    // --- embedding-model resolution ---

    use super::{effective_embedding_model, stable_skill_embedding_model};
    use crate::config::{Config, ProviderKind};
    use zeph_config::providers::ProviderEntry;

    // Helper: a minimal Ollama entry with only the fields under test populated.
    fn make_provider_entry(
        embed: bool,
        model: Option<&str>,
        embedding_model: Option<&str>,
    ) -> ProviderEntry {
        ProviderEntry {
            provider_type: ProviderKind::Ollama,
            embed,
            model: model.map(str::to_owned),
            embedding_model: embedding_model.map(str::to_owned),
            ..ProviderEntry::default()
        }
    }

    #[test]
    fn stable_skill_embedding_model_prefers_embedding_model_field() {
        let mut config = Config::default();
        config.llm.providers = vec![make_provider_entry(
            true,
            Some("chat-model"),
            Some("embed-v2"),
        )];
        assert_eq!(stable_skill_embedding_model(&config), "embed-v2");
    }

    // When embedding_model is absent, the entry's `model` field is used instead.
    #[test]
    fn stable_skill_embedding_model_falls_back_to_model_field() {
        let mut config = Config::default();
        config.llm.providers = vec![make_provider_entry(
            true,
            Some("nomic-embed-text-v2-moe:latest"),
            None,
        )];
        assert_eq!(
            stable_skill_embedding_model(&config),
            "nomic-embed-text-v2-moe:latest"
        );
    }

    #[test]
    fn stable_skill_embedding_model_finds_embed_flag_entry() {
        let mut config = Config::default();
        config.llm.providers = vec![
            make_provider_entry(false, Some("chat-model"), None),
            make_provider_entry(true, Some("embed-model"), Some("text-embed-3")),
        ];
        assert_eq!(stable_skill_embedding_model(&config), "text-embed-3");
    }

    #[test]
    fn stable_skill_embedding_model_falls_back_to_effective_when_no_embed_entry() {
        let mut config = Config::default();
        config.llm.embedding_model = "global-embed-model".to_owned();
        config.llm.providers = vec![make_provider_entry(false, Some("chat"), None)];
        assert_eq!(
            stable_skill_embedding_model(&config),
            effective_embedding_model(&config)
        );
    }
}