// zeph_core/provider_factory.rs
use zeph_llm::any::AnyProvider;
12use zeph_llm::claude::ClaudeProvider;
13use zeph_llm::compatible::CompatibleProvider;
14use zeph_llm::gemini::GeminiProvider;
15use zeph_llm::http::llm_client;
16use zeph_llm::ollama::OllamaProvider;
17use zeph_llm::openai::OpenAiProvider;
18
19use crate::agent::state::ProviderConfigSnapshot;
20use crate::config::{Config, ProviderEntry, ProviderKind};
21
/// Errors that can occur while bootstrapping the agent runtime:
/// loading configuration, constructing an LLM provider, initializing
/// memory, unlocking the vault, or plain I/O.
#[derive(Debug, thiserror::Error)]
pub enum BootstrapError {
    /// Configuration loading or validation failed.
    #[error("config error: {0}")]
    Config(#[from] crate::config::ConfigError),
    /// Provider construction failed (missing secret, invalid option, ...).
    #[error("provider error: {0}")]
    Provider(String),
    /// Memory subsystem failure, carried as a message.
    #[error("memory error: {0}")]
    Memory(String),
    /// Age vault initialization failed.
    /// NOTE(review): deliberately no `#[from]` here, unlike `Config`/`Io` —
    /// presumably conversion is explicit at call sites; confirm before adding one.
    #[error("vault init error: {0}")]
    VaultInit(crate::vault::AgeVaultError),
    /// Underlying I/O failure.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
}
46
47pub fn build_provider_for_switch(
58 entry: &ProviderEntry,
59 snapshot: &ProviderConfigSnapshot,
60) -> Result<AnyProvider, BootstrapError> {
61 use zeph_common::secret::Secret;
62 let mut config = Config::default();
66 config.secrets.claude_api_key = snapshot.claude_api_key.as_deref().map(Secret::new);
67 config.secrets.openai_api_key = snapshot.openai_api_key.as_deref().map(Secret::new);
68 config.secrets.gemini_api_key = snapshot.gemini_api_key.as_deref().map(Secret::new);
69 config.secrets.compatible_api_keys = snapshot
70 .compatible_api_keys
71 .iter()
72 .map(|(k, v)| (k.clone(), Secret::new(v.as_str())))
73 .collect();
74 config.timeouts.llm_request_timeout_secs = snapshot.llm_request_timeout_secs;
75 config
76 .llm
77 .embedding_model
78 .clone_from(&snapshot.embedding_model);
79 build_provider_from_entry(entry, &config)
80}
81
82#[allow(clippy::too_many_lines)]
/// Constructs an [`AnyProvider`] from a single `[[llm.providers]]` entry,
/// reading secrets, timeouts, and MCP schema-forwarding settings from `config`.
///
/// # Errors
///
/// Returns [`BootstrapError::Provider`] when a required secret or field is
/// missing, or when provider-specific options (thinking config, thinking
/// budget) are invalid.
#[allow(clippy::too_many_lines)] // one match arm per provider kind; splitting would obscure the dispatch
pub fn build_provider_from_entry(
    entry: &ProviderEntry,
    config: &Config,
) -> Result<AnyProvider, BootstrapError> {
    match entry.provider_type {
        ProviderKind::Ollama => {
            // Local Ollama daemon: no API key; defaults for URL and model.
            let base_url = entry
                .base_url
                .as_deref()
                .unwrap_or("http://localhost:11434");
            let model = entry.model.as_deref().unwrap_or("qwen3:8b").to_owned();
            // Per-entry embedding model wins over the global llm default.
            let embed = entry
                .embedding_model
                .clone()
                .unwrap_or_else(|| config.llm.embedding_model.clone());
            let mut provider = OllamaProvider::new(base_url, model, embed);
            if let Some(ref vm) = entry.vision_model {
                provider = provider.with_vision_model(vm.clone());
            }
            // Schema forwarding is a no-op for Ollama; surface that at debug level.
            if config.mcp.forward_output_schema {
                tracing::debug!(
                    "mcp.forward_output_schema is enabled but Ollama does not support \
                    output schema forwarding; setting ignored for this provider"
                );
            }
            Ok(AnyProvider::Ollama(provider))
        }
        ProviderKind::Claude => {
            // API key comes from the vault-backed secrets; required.
            let api_key = config
                .secrets
                .claude_api_key
                .as_ref()
                .ok_or_else(|| {
                    BootstrapError::Provider("ZEPH_CLAUDE_API_KEY not found in vault".into())
                })?
                .expose()
                .to_owned();
            let model = entry
                .model
                .clone()
                .unwrap_or_else(|| "claude-haiku-4-5-20251001".to_owned());
            let max_tokens = entry.max_tokens.unwrap_or(4096);
            // Builder chain: the thinking config is the only fallible step.
            let provider = ClaudeProvider::new(api_key, model, max_tokens)
                .with_client(llm_client(config.timeouts.llm_request_timeout_secs))
                .with_extended_context(entry.enable_extended_context)
                .with_thinking_opt(entry.thinking.clone())
                .map_err(|e| BootstrapError::Provider(format!("invalid thinking config: {e}")))?
                .with_server_compaction(entry.server_compaction)
                .with_prompt_cache_ttl(entry.prompt_cache_ttl)
                .with_output_schema_forwarding(
                    config.mcp.forward_output_schema,
                    config.mcp.output_schema_hint_bytes,
                    config.mcp.max_description_bytes,
                );
            tracing::info!(
                forward = config.mcp.forward_output_schema,
                "mcp.output_schema.forwarding_configured"
            );
            Ok(AnyProvider::Claude(provider))
        }
        ProviderKind::OpenAi => {
            // API key comes from the vault-backed secrets; required.
            let api_key = config
                .secrets
                .openai_api_key
                .as_ref()
                .ok_or_else(|| {
                    BootstrapError::Provider("ZEPH_OPENAI_API_KEY not found in vault".into())
                })?
                .expose()
                .to_owned();
            let base_url = entry
                .base_url
                .clone()
                .unwrap_or_else(|| "https://api.openai.com/v1".to_owned());
            let model = entry
                .model
                .clone()
                .unwrap_or_else(|| "gpt-4o-mini".to_owned());
            let max_tokens = entry.max_tokens.unwrap_or(4096);
            Ok(AnyProvider::OpenAi(
                OpenAiProvider::new(
                    api_key,
                    base_url,
                    model,
                    max_tokens,
                    entry.embedding_model.clone(),
                    entry.reasoning_effort.clone(),
                )
                .with_client(llm_client(config.timeouts.llm_request_timeout_secs))
                .with_output_schema_forwarding(
                    config.mcp.forward_output_schema,
                    config.mcp.output_schema_hint_bytes,
                    config.mcp.max_description_bytes,
                ),
            ))
        }
        ProviderKind::Gemini => {
            // API key comes from the vault-backed secrets; required.
            let api_key = config
                .secrets
                .gemini_api_key
                .as_ref()
                .ok_or_else(|| {
                    BootstrapError::Provider("ZEPH_GEMINI_API_KEY not found in vault".into())
                })?
                .expose()
                .to_owned();
            let model = entry
                .model
                .clone()
                .unwrap_or_else(|| "gemini-2.0-flash".to_owned());
            // Note: Gemini defaults to 8192 max tokens, unlike the 4096
            // default used by Claude/OpenAI/compatible entries.
            let max_tokens = entry.max_tokens.unwrap_or(8192);
            let base_url = entry
                .base_url
                .clone()
                .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_owned());
            let mut provider = GeminiProvider::new(api_key, model, max_tokens)
                .with_base_url(base_url)
                .with_client(llm_client(config.timeouts.llm_request_timeout_secs));
            // Optional tuning knobs; only the thinking budget can fail.
            if let Some(ref em) = entry.embedding_model {
                provider = provider.with_embedding_model(em.clone());
            }
            if let Some(level) = entry.thinking_level {
                provider = provider.with_thinking_level(level);
            }
            if let Some(budget) = entry.thinking_budget {
                provider = provider
                    .with_thinking_budget(budget)
                    .map_err(|e| BootstrapError::Provider(e.to_string()))?;
            }
            if let Some(include) = entry.include_thoughts {
                provider = provider.with_include_thoughts(include);
            }
            // Schema forwarding is a no-op for Gemini; surface that at debug level.
            if config.mcp.forward_output_schema {
                tracing::debug!(
                    "mcp.forward_output_schema is enabled but Gemini does not support \
                    output schema forwarding; setting ignored for this provider"
                );
            }
            Ok(AnyProvider::Gemini(provider))
        }
        ProviderKind::Compatible => {
            // OpenAI-compatible endpoint: 'name' and 'base_url' are mandatory.
            let name = entry.name.as_deref().ok_or_else(|| {
                BootstrapError::Provider(
                    "compatible provider requires 'name' field in [[llm.providers]]".into(),
                )
            })?;
            let base_url = entry.base_url.clone().ok_or_else(|| {
                BootstrapError::Provider(format!(
                    "compatible provider '{name}' requires 'base_url'"
                ))
            })?;
            let model = entry.model.clone().unwrap_or_default();
            // Key precedence: inline entry key, then per-name vault key,
            // then empty string (some local endpoints accept no key).
            let api_key = entry.api_key.clone().unwrap_or_else(|| {
                config
                    .secrets
                    .compatible_api_keys
                    .get(name)
                    .map(|s| s.expose().to_owned())
                    .unwrap_or_default()
            });
            let max_tokens = entry.max_tokens.unwrap_or(4096);
            let provider = CompatibleProvider::new(
                name.to_owned(),
                api_key,
                base_url,
                model,
                max_tokens,
                entry.embedding_model.clone(),
            )
            .with_output_schema_forwarding(
                config.mcp.forward_output_schema,
                config.mcp.output_schema_hint_bytes,
                config.mcp.max_description_bytes,
            );
            tracing::info!(
                forward = config.mcp.forward_output_schema,
                provider = name,
                "mcp.output_schema.forwarding_configured"
            );
            Ok(AnyProvider::Compatible(provider))
        }
        #[cfg(feature = "candle")]
        ProviderKind::Candle => {
            // In-process inference; the dedicated [candle] section is mandatory.
            let candle = entry.candle.as_ref().ok_or_else(|| {
                BootstrapError::Provider(
                    "candle provider requires 'candle' section in [[llm.providers]]".into(),
                )
            })?;
            // "local" loads from disk; any other source value means HuggingFace.
            let source = match candle.source.as_str() {
                "local" => zeph_llm::candle_provider::loader::ModelSource::Local {
                    path: std::path::PathBuf::from(&candle.local_path),
                },
                _ => zeph_llm::candle_provider::loader::ModelSource::HuggingFace {
                    repo_id: entry
                        .model
                        .clone()
                        .unwrap_or_else(|| config.llm.effective_model().to_owned()),
                    filename: candle.filename.clone(),
                },
            };
            let template =
                zeph_llm::candle_provider::template::ChatTemplate::parse_str(&candle.chat_template);
            let gen_config = zeph_llm::candle_provider::generate::GenerationConfig {
                temperature: candle.generation.temperature,
                top_p: candle.generation.top_p,
                top_k: candle.generation.top_k,
                max_tokens: candle.generation.capped_max_tokens(),
                seed: candle.generation.seed,
                repeat_penalty: candle.generation.repeat_penalty,
                repeat_last_n: candle.generation.repeat_last_n,
            };
            let device = select_device(&candle.device)?;
            // Clamp to at least 1s so a zero in config can't disable inference.
            let inference_timeout =
                std::time::Duration::from_secs(candle.inference_timeout_secs.max(1));
            zeph_llm::candle_provider::CandleProvider::new_with_timeout(
                &source,
                template,
                gen_config,
                candle.embedding_repo.as_deref(),
                candle.hf_token.as_deref(),
                device,
                inference_timeout,
            )
            .map(AnyProvider::Candle)
            .map_err(|e| BootstrapError::Provider(e.to_string()))
        }
        #[cfg(not(feature = "candle"))]
        ProviderKind::Candle => Err(BootstrapError::Provider(
            "candle feature is not enabled".into(),
        )),
    }
}
326
327#[cfg(feature = "candle")]
338pub fn select_device(
339 preference: &str,
340) -> Result<zeph_llm::candle_provider::Device, BootstrapError> {
341 match preference {
342 "metal" => {
343 #[cfg(feature = "metal")]
344 return zeph_llm::candle_provider::Device::new_metal(0)
345 .map_err(|e| BootstrapError::Provider(e.to_string()));
346 #[cfg(not(feature = "metal"))]
347 return Err(BootstrapError::Provider(
348 "candle compiled without metal feature".into(),
349 ));
350 }
351 "cuda" => {
352 #[cfg(feature = "cuda")]
353 return zeph_llm::candle_provider::Device::new_cuda(0)
354 .map_err(|e| BootstrapError::Provider(e.to_string()));
355 #[cfg(not(feature = "cuda"))]
356 return Err(BootstrapError::Provider(
357 "candle compiled without cuda feature".into(),
358 ));
359 }
360 "auto" => {
361 #[cfg(feature = "metal")]
362 if let Ok(device) = zeph_llm::candle_provider::Device::new_metal(0) {
363 return Ok(device);
364 }
365 #[cfg(feature = "cuda")]
366 if let Ok(device) = zeph_llm::candle_provider::Device::new_cuda(0) {
367 return Ok(device);
368 }
369 Ok(zeph_llm::candle_provider::Device::Cpu)
370 }
371 _ => Ok(zeph_llm::candle_provider::Device::Cpu),
372 }
373}
374
375#[must_use]
382pub fn effective_embedding_model(config: &Config) -> String {
383 if let Some(m) = config
385 .llm
386 .providers
387 .iter()
388 .find(|e| e.embed)
389 .and_then(|e| e.embedding_model.as_ref())
390 {
391 return m.clone();
392 }
393 if let Some(m) = config
395 .llm
396 .providers
397 .first()
398 .and_then(|e| e.embedding_model.as_ref())
399 {
400 return m.clone();
401 }
402 config.llm.embedding_model.clone()
403}
404
#[cfg(test)]
mod tests {
    // All tests here exercise `select_device`, which only exists when the
    // candle feature is compiled in; each test is gated accordingly.
    #[cfg(feature = "candle")]
    use super::select_device;

    // "cpu" is handled by the catch-all arm and must always succeed.
    #[cfg(feature = "candle")]
    #[test]
    fn select_device_cpu_default() {
        let device = select_device("cpu").unwrap();
        assert!(matches!(device, zeph_llm::candle_provider::Device::Cpu));
    }

    // Unrecognized preferences degrade to CPU rather than erroring.
    #[cfg(feature = "candle")]
    #[test]
    fn select_device_unknown_defaults_to_cpu() {
        let device = select_device("unknown").unwrap();
        assert!(matches!(device, zeph_llm::candle_provider::Device::Cpu));
    }

    // Explicitly requesting metal without the feature must error, not fall back.
    #[cfg(all(feature = "candle", not(feature = "metal")))]
    #[test]
    fn select_device_metal_without_feature_errors() {
        let result = select_device("metal");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("metal feature"));
    }

    // Explicitly requesting cuda without the feature must error, not fall back.
    #[cfg(all(feature = "candle", not(feature = "cuda")))]
    #[test]
    fn select_device_cuda_without_feature_errors() {
        let result = select_device("cuda");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cuda feature"));
    }

    // "auto" never errors: whatever backend it probes, some device comes back.
    #[cfg(feature = "candle")]
    #[test]
    fn select_device_auto_fallback() {
        let device = select_device("auto").unwrap();
        assert!(matches!(
            device,
            zeph_llm::candle_provider::Device::Cpu
                | zeph_llm::candle_provider::Device::Cuda(_)
                | zeph_llm::candle_provider::Device::Metal(_)
        ));
    }
}