use zeph_llm::any::AnyProvider;
use zeph_llm::claude::ClaudeProvider;
use zeph_llm::compatible::CompatibleProvider;
use zeph_llm::gemini::GeminiProvider;
use zeph_llm::http::llm_client;
use zeph_llm::ollama::OllamaProvider;
use zeph_llm::openai::OpenAiProvider;

use crate::agent::state::ProviderConfigSnapshot;
use crate::config::{Config, ProviderEntry, ProviderKind};

/// Errors that can occur while bootstrapping the agent: loading config,
/// constructing LLM providers, initializing memory, or opening the vault.
#[derive(Debug, thiserror::Error)]
pub enum BootstrapError {
    /// Configuration failed to load or validate.
    #[error("config error: {0}")]
    Config(#[from] crate::config::ConfigError),
    /// An LLM provider could not be constructed (e.g. missing API key or
    /// required field); carries a pre-formatted message.
    #[error("provider error: {0}")]
    Provider(String),
    /// Memory subsystem failure, carried as a pre-formatted message.
    #[error("memory error: {0}")]
    Memory(String),
    // NOTE(review): unlike Config/Io this variant has no `#[from]` —
    // presumably deliberate so vault-error conversions stay explicit at the
    // call site; confirm before adding one.
    /// Age vault initialization failed.
    #[error("vault init error: {0}")]
    VaultInit(crate::vault::AgeVaultError),
    /// Underlying I/O failure.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
}
46
47pub fn build_provider_for_switch(
58 entry: &ProviderEntry,
59 snapshot: &ProviderConfigSnapshot,
60) -> Result<AnyProvider, BootstrapError> {
61 use zeph_common::secret::Secret;
62 let mut config = Config::default();
66 config.secrets.claude_api_key = snapshot.claude_api_key.as_deref().map(Secret::new);
67 config.secrets.openai_api_key = snapshot.openai_api_key.as_deref().map(Secret::new);
68 config.secrets.gemini_api_key = snapshot.gemini_api_key.as_deref().map(Secret::new);
69 config.secrets.compatible_api_keys = snapshot
70 .compatible_api_keys
71 .iter()
72 .map(|(k, v)| (k.clone(), Secret::new(v.as_str())))
73 .collect();
74 config.timeouts.llm_request_timeout_secs = snapshot.llm_request_timeout_secs;
75 config
76 .llm
77 .embedding_model
78 .clone_from(&snapshot.embedding_model);
79 build_provider_from_entry(entry, &config)
80}
81
82#[allow(clippy::too_many_lines)]
92pub fn build_provider_from_entry(
93 entry: &ProviderEntry,
94 config: &Config,
95) -> Result<AnyProvider, BootstrapError> {
96 match entry.provider_type {
97 ProviderKind::Ollama => {
98 let base_url = entry
99 .base_url
100 .as_deref()
101 .unwrap_or("http://localhost:11434");
102 let model = entry.model.as_deref().unwrap_or("qwen3:8b").to_owned();
103 let embed = entry
104 .embedding_model
105 .clone()
106 .unwrap_or_else(|| config.llm.embedding_model.clone());
107 let mut provider = OllamaProvider::new(base_url, model, embed);
108 if let Some(ref vm) = entry.vision_model {
109 provider = provider.with_vision_model(vm.clone());
110 }
111 Ok(AnyProvider::Ollama(provider))
112 }
113 ProviderKind::Claude => {
114 let api_key = config
115 .secrets
116 .claude_api_key
117 .as_ref()
118 .ok_or_else(|| {
119 BootstrapError::Provider("ZEPH_CLAUDE_API_KEY not found in vault".into())
120 })?
121 .expose()
122 .to_owned();
123 let model = entry
124 .model
125 .clone()
126 .unwrap_or_else(|| "claude-haiku-4-5-20251001".to_owned());
127 let max_tokens = entry.max_tokens.unwrap_or(4096);
128 let provider = ClaudeProvider::new(api_key, model, max_tokens)
129 .with_client(llm_client(config.timeouts.llm_request_timeout_secs))
130 .with_extended_context(entry.enable_extended_context)
131 .with_thinking_opt(entry.thinking.clone())
132 .map_err(|e| BootstrapError::Provider(format!("invalid thinking config: {e}")))?
133 .with_server_compaction(entry.server_compaction);
134 Ok(AnyProvider::Claude(provider))
135 }
136 ProviderKind::OpenAi => {
137 let api_key = config
138 .secrets
139 .openai_api_key
140 .as_ref()
141 .ok_or_else(|| {
142 BootstrapError::Provider("ZEPH_OPENAI_API_KEY not found in vault".into())
143 })?
144 .expose()
145 .to_owned();
146 let base_url = entry
147 .base_url
148 .clone()
149 .unwrap_or_else(|| "https://api.openai.com/v1".to_owned());
150 let model = entry
151 .model
152 .clone()
153 .unwrap_or_else(|| "gpt-4o-mini".to_owned());
154 let max_tokens = entry.max_tokens.unwrap_or(4096);
155 Ok(AnyProvider::OpenAi(
156 OpenAiProvider::new(
157 api_key,
158 base_url,
159 model,
160 max_tokens,
161 entry.embedding_model.clone(),
162 entry.reasoning_effort.clone(),
163 )
164 .with_client(llm_client(config.timeouts.llm_request_timeout_secs)),
165 ))
166 }
167 ProviderKind::Gemini => {
168 let api_key = config
169 .secrets
170 .gemini_api_key
171 .as_ref()
172 .ok_or_else(|| {
173 BootstrapError::Provider("ZEPH_GEMINI_API_KEY not found in vault".into())
174 })?
175 .expose()
176 .to_owned();
177 let model = entry
178 .model
179 .clone()
180 .unwrap_or_else(|| "gemini-2.0-flash".to_owned());
181 let max_tokens = entry.max_tokens.unwrap_or(8192);
182 let base_url = entry
183 .base_url
184 .clone()
185 .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_owned());
186 let mut provider = GeminiProvider::new(api_key, model, max_tokens)
187 .with_base_url(base_url)
188 .with_client(llm_client(config.timeouts.llm_request_timeout_secs));
189 if let Some(ref em) = entry.embedding_model {
190 provider = provider.with_embedding_model(em.clone());
191 }
192 if let Some(level) = entry.thinking_level {
193 provider = provider.with_thinking_level(level);
194 }
195 if let Some(budget) = entry.thinking_budget {
196 provider = provider
197 .with_thinking_budget(budget)
198 .map_err(|e| BootstrapError::Provider(e.to_string()))?;
199 }
200 if let Some(include) = entry.include_thoughts {
201 provider = provider.with_include_thoughts(include);
202 }
203 Ok(AnyProvider::Gemini(provider))
204 }
205 ProviderKind::Compatible => {
206 let name = entry.name.as_deref().ok_or_else(|| {
207 BootstrapError::Provider(
208 "compatible provider requires 'name' field in [[llm.providers]]".into(),
209 )
210 })?;
211 let base_url = entry.base_url.clone().ok_or_else(|| {
212 BootstrapError::Provider(format!(
213 "compatible provider '{name}' requires 'base_url'"
214 ))
215 })?;
216 let model = entry.model.clone().unwrap_or_default();
217 let api_key = entry.api_key.clone().unwrap_or_else(|| {
218 config
219 .secrets
220 .compatible_api_keys
221 .get(name)
222 .map(|s| s.expose().to_owned())
223 .unwrap_or_default()
224 });
225 let max_tokens = entry.max_tokens.unwrap_or(4096);
226 Ok(AnyProvider::Compatible(CompatibleProvider::new(
227 name.to_owned(),
228 api_key,
229 base_url,
230 model,
231 max_tokens,
232 entry.embedding_model.clone(),
233 )))
234 }
235 #[cfg(feature = "candle")]
236 ProviderKind::Candle => {
237 let candle = entry.candle.as_ref().ok_or_else(|| {
238 BootstrapError::Provider(
239 "candle provider requires 'candle' section in [[llm.providers]]".into(),
240 )
241 })?;
242 let source = match candle.source.as_str() {
243 "local" => zeph_llm::candle_provider::loader::ModelSource::Local {
244 path: std::path::PathBuf::from(&candle.local_path),
245 },
246 _ => zeph_llm::candle_provider::loader::ModelSource::HuggingFace {
247 repo_id: entry
248 .model
249 .clone()
250 .unwrap_or_else(|| config.llm.effective_model().to_owned()),
251 filename: candle.filename.clone(),
252 },
253 };
254 let template =
255 zeph_llm::candle_provider::template::ChatTemplate::parse_str(&candle.chat_template);
256 let gen_config = zeph_llm::candle_provider::generate::GenerationConfig {
257 temperature: candle.generation.temperature,
258 top_p: candle.generation.top_p,
259 top_k: candle.generation.top_k,
260 max_tokens: candle.generation.capped_max_tokens(),
261 seed: candle.generation.seed,
262 repeat_penalty: candle.generation.repeat_penalty,
263 repeat_last_n: candle.generation.repeat_last_n,
264 };
265 let device = select_device(&candle.device)?;
266 let inference_timeout =
269 std::time::Duration::from_secs(candle.inference_timeout_secs.max(1));
270 zeph_llm::candle_provider::CandleProvider::new_with_timeout(
271 &source,
272 template,
273 gen_config,
274 candle.embedding_repo.as_deref(),
275 candle.hf_token.as_deref(),
276 device,
277 inference_timeout,
278 )
279 .map(AnyProvider::Candle)
280 .map_err(|e| BootstrapError::Provider(e.to_string()))
281 }
282 #[cfg(not(feature = "candle"))]
283 ProviderKind::Candle => Err(BootstrapError::Provider(
284 "candle feature is not enabled".into(),
285 )),
286 }
287}
288
#[cfg(feature = "candle")]
/// Resolve a device-preference string to a candle `Device`.
///
/// Recognized preferences: `"metal"`, `"cuda"`, `"auto"`; anything else
/// (including `"cpu"`) maps to the CPU device. `"auto"` tries Metal first,
/// then CUDA — each only when the corresponding cargo feature is compiled
/// in — and silently falls back to CPU.
///
/// # Errors
///
/// Returns [`BootstrapError::Provider`] when `"metal"`/`"cuda"` is requested
/// explicitly but the matching feature is not compiled in, or when device
/// initialization itself fails.
pub fn select_device(
    preference: &str,
) -> Result<zeph_llm::candle_provider::Device, BootstrapError> {
    match preference {
        "metal" => {
            // Explicit request: hard error if the feature is absent.
            #[cfg(feature = "metal")]
            return zeph_llm::candle_provider::Device::new_metal(0)
                .map_err(|e| BootstrapError::Provider(e.to_string()));
            #[cfg(not(feature = "metal"))]
            return Err(BootstrapError::Provider(
                "candle compiled without metal feature".into(),
            ));
        }
        "cuda" => {
            // Explicit request: hard error if the feature is absent.
            #[cfg(feature = "cuda")]
            return zeph_llm::candle_provider::Device::new_cuda(0)
                .map_err(|e| BootstrapError::Provider(e.to_string()));
            #[cfg(not(feature = "cuda"))]
            return Err(BootstrapError::Provider(
                "candle compiled without cuda feature".into(),
            ));
        }
        "auto" => {
            // Best-effort probing: init failures are ignored, not surfaced.
            #[cfg(feature = "metal")]
            if let Ok(device) = zeph_llm::candle_provider::Device::new_metal(0) {
                return Ok(device);
            }
            #[cfg(feature = "cuda")]
            if let Ok(device) = zeph_llm::candle_provider::Device::new_cuda(0) {
                return Ok(device);
            }
            Ok(zeph_llm::candle_provider::Device::Cpu)
        }
        // "cpu" and any unrecognized string default to CPU.
        _ => Ok(zeph_llm::candle_provider::Device::Cpu),
    }
}
336
337#[must_use]
344pub fn effective_embedding_model(config: &Config) -> String {
345 if let Some(m) = config
347 .llm
348 .providers
349 .iter()
350 .find(|e| e.embed)
351 .and_then(|e| e.embedding_model.as_ref())
352 {
353 return m.clone();
354 }
355 if let Some(m) = config
357 .llm
358 .providers
359 .first()
360 .and_then(|e| e.embedding_model.as_ref())
361 {
362 return m.clone();
363 }
364 config.llm.embedding_model.clone()
365}
366
#[cfg(test)]
mod tests {
    #[cfg(feature = "candle")]
    use super::select_device;

    // An explicit "cpu" preference maps straight to the CPU device.
    #[cfg(feature = "candle")]
    #[test]
    fn select_device_cpu_default() {
        let device = select_device("cpu").unwrap();
        assert!(matches!(device, zeph_llm::candle_provider::Device::Cpu));
    }

    // Unrecognized preference strings fall back to CPU rather than erroring.
    #[cfg(feature = "candle")]
    #[test]
    fn select_device_unknown_defaults_to_cpu() {
        let device = select_device("unknown").unwrap();
        assert!(matches!(device, zeph_llm::candle_provider::Device::Cpu));
    }

    // Explicitly requesting metal without the feature compiled in is an error.
    #[cfg(all(feature = "candle", not(feature = "metal")))]
    #[test]
    fn select_device_metal_without_feature_errors() {
        let result = select_device("metal");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("metal feature"));
    }

    // Explicitly requesting cuda without the feature compiled in is an error.
    #[cfg(all(feature = "candle", not(feature = "cuda")))]
    #[test]
    fn select_device_cuda_without_feature_errors() {
        let result = select_device("cuda");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cuda feature"));
    }

    // "auto" never errors: whichever backend probes successfully (or CPU) is
    // acceptable, depending on the machine running the tests.
    #[cfg(feature = "candle")]
    #[test]
    fn select_device_auto_fallback() {
        let device = select_device("auto").unwrap();
        assert!(matches!(
            device,
            zeph_llm::candle_provider::Device::Cpu
                | zeph_llm::candle_provider::Device::Cuda(_)
                | zeph_llm::candle_provider::Device::Metal(_)
        ));
    }
}