1use zeph_llm::any::AnyProvider;
5use zeph_llm::router::triage::{ComplexityTier, TriageRouter};
6
/// Errors that can occur while bootstrapping the agent's LLM stack.
#[derive(Debug, thiserror::Error)]
pub enum BootstrapError {
    /// Configuration loading/validation failed (auto-converted via `#[from]`).
    #[error("config error: {0}")]
    Config(#[from] crate::config::ConfigError),
    /// Provider construction failed; carries a human-readable reason.
    #[error("provider error: {0}")]
    Provider(String),
    /// Memory subsystem failure, stringified at the call site.
    #[error("memory error: {0}")]
    Memory(String),
    /// Vault initialization failed.
    /// NOTE(review): unlike `Config`/`Io` there is no `#[from]` here, so
    /// callers must wrap `AgeVaultError` manually — confirm the asymmetry
    /// is intentional (e.g. to avoid a conflicting `From` impl elsewhere).
    #[error("vault init error: {0}")]
    VaultInit(crate::vault::AgeVaultError),
    /// Underlying I/O failure (auto-converted via `#[from]`).
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
}
26use zeph_llm::claude::ClaudeProvider;
27use zeph_llm::compatible::CompatibleProvider;
28use zeph_llm::gemini::GeminiProvider;
29use zeph_llm::http::llm_client;
30use zeph_llm::ollama::OllamaProvider;
31use zeph_llm::openai::OpenAiProvider;
32use zeph_llm::router::cascade::ClassifierMode;
33use zeph_llm::router::{BanditRouterConfig, CascadeRouterConfig, RouterProvider};
34
35use crate::agent::state::ProviderConfigSnapshot;
36use crate::config::{Config, LlmRoutingStrategy, ProviderEntry, ProviderKind};
37
/// Builds the agent's primary LLM provider from `config`.
///
/// Thin public entry point; all pool/routing selection logic lives in
/// `create_provider_from_pool`.
///
/// # Errors
/// Returns [`BootstrapError`] when no provider in the pool can be initialized.
pub fn create_provider(config: &Config) -> Result<AnyProvider, BootstrapError> {
    create_provider_from_pool(config)
}
41
42fn build_cascade_router_config(
43 cascade_cfg: &crate::config::CascadeConfig,
44 config: &Config,
45) -> CascadeRouterConfig {
46 let classifier_mode = match cascade_cfg.classifier_mode {
47 crate::config::CascadeClassifierMode::Heuristic => ClassifierMode::Heuristic,
48 crate::config::CascadeClassifierMode::Judge => ClassifierMode::Judge,
49 };
50 let raw_threshold = cascade_cfg.quality_threshold;
52 let quality_threshold = if raw_threshold.is_finite() {
53 raw_threshold.clamp(0.0, 1.0)
54 } else {
55 tracing::warn!(
56 raw_threshold,
57 "cascade quality_threshold is non-finite, defaulting to 0.5"
58 );
59 0.5
60 };
61 if (quality_threshold - raw_threshold).abs() > f64::EPSILON {
62 tracing::warn!(
63 raw_threshold,
64 clamped = quality_threshold,
65 "cascade quality_threshold out of range [0.0, 1.0], clamped"
66 );
67 }
68 let window_size = cascade_cfg.window_size.max(1);
70 if window_size != cascade_cfg.window_size {
71 tracing::warn!(
72 raw = cascade_cfg.window_size,
73 "cascade window_size=0 is invalid, clamped to 1"
74 );
75 }
76 let summary_provider = if classifier_mode == ClassifierMode::Judge {
78 if let Some(model_spec) = config.llm.summary_model.as_deref() {
79 match create_summary_provider(model_spec, config) {
80 Ok(p) => Some(p),
81 Err(e) => {
82 tracing::warn!(
83 error = %e,
84 "cascade: failed to build judge provider, falling back to heuristic"
85 );
86 None
87 }
88 }
89 } else {
90 tracing::warn!(
91 "cascade: classifier_mode=judge requires [llm] summary_model to \
92 be configured; falling back to heuristic"
93 );
94 None
95 }
96 } else {
97 None
98 };
99 CascadeRouterConfig {
100 quality_threshold,
101 max_escalations: cascade_cfg.max_escalations,
102 classifier_mode,
103 window_size,
104 max_cascade_tokens: cascade_cfg.max_cascade_tokens,
105 summary_provider,
106 cost_tiers: cascade_cfg.cost_tiers.clone(),
107 }
108}
109
110pub fn create_named_provider(name: &str, config: &Config) -> Result<AnyProvider, BootstrapError> {
114 let entry = config
115 .llm
116 .providers
117 .iter()
118 .find(|e| e.effective_name() == name || e.provider_type.as_str() == name)
119 .ok_or_else(|| {
120 BootstrapError::Provider(format!("provider '{name}' not found in [[llm.providers]]"))
121 })?;
122 build_provider_from_entry(entry, config)
123}
124
125pub fn create_summary_provider(
132 model_spec: &str,
133 config: &Config,
134) -> Result<AnyProvider, BootstrapError> {
135 if let Some(entry) = config
137 .llm
138 .providers
139 .iter()
140 .find(|e| e.effective_name() == model_spec || e.provider_type.as_str() == model_spec)
141 {
142 return build_provider_from_entry(entry, config);
143 }
144
145 if let Some(((_, model), entry)) = model_spec.split_once('/').and_then(|(b, m)| {
147 config
148 .llm
149 .providers
150 .iter()
151 .find(|e| e.provider_type.as_str() == b || e.effective_name() == b)
152 .map(|e| ((b, m), e))
153 }) {
154 let mut cloned = entry.clone();
155 cloned.model = Some(model.to_owned());
156 cloned.max_tokens = Some(cloned.max_tokens.unwrap_or(4096).min(4096));
158 return build_provider_from_entry(&cloned, config);
159 }
160
161 Err(BootstrapError::Provider(format!(
162 "summary_model '{model_spec}' not found in [[llm.providers]]. \
163 Use a provider name or 'type/model' shorthand (e.g. 'ollama/qwen3:1.7b')."
164 )))
165}
166
/// Selects a candle compute device from a user preference string.
///
/// Accepted values: `"metal"`, `"cuda"`, `"auto"`; anything else yields CPU.
/// `"auto"` probes Metal first, then CUDA, and falls back to CPU when neither
/// initializes.
///
/// # Errors
/// [`BootstrapError::Provider`] when an explicitly requested backend is not
/// compiled in, or its device initialization fails.
#[cfg(feature = "candle")]
pub fn select_device(
    preference: &str,
) -> Result<zeph_llm::candle_provider::Device, BootstrapError> {
    match preference {
        "metal" => {
            #[cfg(feature = "metal")]
            return zeph_llm::candle_provider::Device::new_metal(0)
                .map_err(|e| BootstrapError::Provider(e.to_string()));
            // Explicit request without the compiled feature is a hard error.
            #[cfg(not(feature = "metal"))]
            return Err(BootstrapError::Provider(
                "candle compiled without metal feature".into(),
            ));
        }
        "cuda" => {
            #[cfg(feature = "cuda")]
            return zeph_llm::candle_provider::Device::new_cuda(0)
                .map_err(|e| BootstrapError::Provider(e.to_string()));
            #[cfg(not(feature = "cuda"))]
            return Err(BootstrapError::Provider(
                "candle compiled without cuda feature".into(),
            ));
        }
        "auto" => {
            // Best-effort probing: backend init failures are silently skipped
            // here (unlike the explicit branches above).
            #[cfg(feature = "metal")]
            if let Ok(device) = zeph_llm::candle_provider::Device::new_metal(0) {
                return Ok(device);
            }
            #[cfg(feature = "cuda")]
            if let Ok(device) = zeph_llm::candle_provider::Device::new_cuda(0) {
                return Ok(device);
            }
            Ok(zeph_llm::candle_provider::Device::Cpu)
        }
        // Unknown preference strings quietly default to CPU.
        _ => Ok(zeph_llm::candle_provider::Device::Cpu),
    }
}
204
/// Assembles a `CandleProvider` from a model source and the `[candle]` config.
///
/// Parses the chat template, translates generation settings into the
/// provider-layer `GenerationConfig`, and resolves the compute device from
/// `device_pref` via [`select_device`].
///
/// # Errors
/// [`BootstrapError::Provider`] when device selection or model construction
/// fails.
#[cfg(feature = "candle")]
fn build_candle_provider(
    source: &zeph_llm::candle_provider::loader::ModelSource,
    candle_cfg: &crate::config::CandleConfig,
    device_pref: &str,
) -> Result<AnyProvider, BootstrapError> {
    let template =
        zeph_llm::candle_provider::template::ChatTemplate::parse_str(&candle_cfg.chat_template);
    // capped_max_tokens() applies the config-level ceiling to max_tokens;
    // all other generation knobs are passed through unchanged.
    let gen_config = zeph_llm::candle_provider::generate::GenerationConfig {
        temperature: candle_cfg.generation.temperature,
        top_p: candle_cfg.generation.top_p,
        top_k: candle_cfg.generation.top_k,
        max_tokens: candle_cfg.generation.capped_max_tokens(),
        seed: candle_cfg.generation.seed,
        repeat_penalty: candle_cfg.generation.repeat_penalty,
        repeat_last_n: candle_cfg.generation.repeat_last_n,
    };
    let device = select_device(device_pref)?;
    zeph_llm::candle_provider::CandleProvider::new(
        source,
        template,
        gen_config,
        candle_cfg.embedding_repo.as_deref(),
        candle_cfg.hf_token.as_deref(),
        device,
    )
    .map(AnyProvider::Candle)
    .map_err(|e| BootstrapError::Provider(e.to_string()))
}
234
235pub fn build_provider_for_switch(
246 entry: &ProviderEntry,
247 snapshot: &ProviderConfigSnapshot,
248) -> Result<AnyProvider, BootstrapError> {
249 use zeph_common::secret::Secret;
250 let mut config = Config::default();
254 config.secrets.claude_api_key = snapshot.claude_api_key.as_deref().map(Secret::new);
255 config.secrets.openai_api_key = snapshot.openai_api_key.as_deref().map(Secret::new);
256 config.secrets.gemini_api_key = snapshot.gemini_api_key.as_deref().map(Secret::new);
257 config.secrets.compatible_api_keys = snapshot
258 .compatible_api_keys
259 .iter()
260 .map(|(k, v)| (k.clone(), Secret::new(v.as_str())))
261 .collect();
262 config.timeouts.llm_request_timeout_secs = snapshot.llm_request_timeout_secs;
263 config
264 .llm
265 .embedding_model
266 .clone_from(&snapshot.embedding_model);
267 build_provider_from_entry(entry, &config)
268}
269
/// Constructs a concrete provider from a single `[[llm.providers]]` entry.
///
/// API keys come from `config.secrets`; a missing key for a cloud provider is
/// a hard error. Defaults are applied for model names, base URLs, and token
/// limits when the entry omits them.
///
/// # Errors
/// [`BootstrapError::Provider`] for missing secrets, missing required entry
/// fields (`name`/`base_url` for compatible, `candle` section for candle), or
/// invalid provider-specific options (thinking config, thinking budget).
#[allow(clippy::too_many_lines)]
pub fn build_provider_from_entry(
    entry: &ProviderEntry,
    config: &Config,
) -> Result<AnyProvider, BootstrapError> {
    match entry.provider_type {
        ProviderKind::Ollama => {
            // Local Ollama needs no API key; defaults target a local daemon.
            let base_url = entry
                .base_url
                .as_deref()
                .unwrap_or("http://localhost:11434");
            let model = entry.model.as_deref().unwrap_or("qwen3:8b").to_owned();
            // Entry-level embedding model wins over the global [llm] one.
            let embed = entry
                .embedding_model
                .clone()
                .unwrap_or_else(|| config.llm.embedding_model.clone());
            let tool_use = entry.tool_use;
            let mut provider = OllamaProvider::new(base_url, model, embed).with_tool_use(tool_use);
            if let Some(ref vm) = entry.vision_model {
                provider = provider.with_vision_model(vm.clone());
            }
            Ok(AnyProvider::Ollama(provider))
        }
        ProviderKind::Claude => {
            let api_key = config
                .secrets
                .claude_api_key
                .as_ref()
                .ok_or_else(|| {
                    BootstrapError::Provider("ZEPH_CLAUDE_API_KEY not found in vault".into())
                })?
                .expose()
                .to_owned();
            let model = entry
                .model
                .clone()
                .unwrap_or_else(|| "claude-haiku-4-5-20251001".to_owned());
            let max_tokens = entry.max_tokens.unwrap_or(4096);
            // with_thinking_opt validates the entry's thinking config and can
            // fail; the rest of the builder chain is infallible.
            let provider = ClaudeProvider::new(api_key, model, max_tokens)
                .with_client(llm_client(config.timeouts.llm_request_timeout_secs))
                .with_extended_context(entry.enable_extended_context)
                .with_thinking_opt(entry.thinking.clone())
                .map_err(|e| BootstrapError::Provider(format!("invalid thinking config: {e}")))?
                .with_server_compaction(entry.server_compaction);
            Ok(AnyProvider::Claude(provider))
        }
        ProviderKind::OpenAi => {
            let api_key = config
                .secrets
                .openai_api_key
                .as_ref()
                .ok_or_else(|| {
                    BootstrapError::Provider("ZEPH_OPENAI_API_KEY not found in vault".into())
                })?
                .expose()
                .to_owned();
            // base_url is overridable so this arm also serves OpenAI-style
            // gateways/proxies.
            let base_url = entry
                .base_url
                .clone()
                .unwrap_or_else(|| "https://api.openai.com/v1".to_owned());
            let model = entry
                .model
                .clone()
                .unwrap_or_else(|| "gpt-4o-mini".to_owned());
            let max_tokens = entry.max_tokens.unwrap_or(4096);
            Ok(AnyProvider::OpenAi(
                OpenAiProvider::new(
                    api_key,
                    base_url,
                    model,
                    max_tokens,
                    entry.embedding_model.clone(),
                    entry.reasoning_effort.clone(),
                )
                .with_client(llm_client(config.timeouts.llm_request_timeout_secs)),
            ))
        }
        ProviderKind::Gemini => {
            let api_key = config
                .secrets
                .gemini_api_key
                .as_ref()
                .ok_or_else(|| {
                    BootstrapError::Provider("ZEPH_GEMINI_API_KEY not found in vault".into())
                })?
                .expose()
                .to_owned();
            let model = entry
                .model
                .clone()
                .unwrap_or_else(|| "gemini-2.0-flash".to_owned());
            // NOTE: Gemini default cap (8192) is higher than the other
            // providers' 4096.
            let max_tokens = entry.max_tokens.unwrap_or(8192);
            let base_url = entry
                .base_url
                .clone()
                .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_owned());
            let mut provider = GeminiProvider::new(api_key, model, max_tokens)
                .with_base_url(base_url)
                .with_client(llm_client(config.timeouts.llm_request_timeout_secs));
            if let Some(ref em) = entry.embedding_model {
                provider = provider.with_embedding_model(em.clone());
            }
            if let Some(level) = entry.thinking_level {
                provider = provider.with_thinking_level(level);
            }
            // with_thinking_budget validates the budget and can fail.
            if let Some(budget) = entry.thinking_budget {
                provider = provider
                    .with_thinking_budget(budget)
                    .map_err(|e| BootstrapError::Provider(e.to_string()))?;
            }
            if let Some(include) = entry.include_thoughts {
                provider = provider.with_include_thoughts(include);
            }
            Ok(AnyProvider::Gemini(provider))
        }
        ProviderKind::Compatible => {
            // Generic OpenAI-compatible endpoint: name and base_url are
            // mandatory; the API key may be inline or come from the vault's
            // per-name compatible_api_keys map (empty string if neither).
            let name = entry.name.as_deref().ok_or_else(|| {
                BootstrapError::Provider(
                    "compatible provider requires 'name' field in [[llm.providers]]".into(),
                )
            })?;
            let base_url = entry.base_url.clone().ok_or_else(|| {
                BootstrapError::Provider(format!(
                    "compatible provider '{name}' requires 'base_url'"
                ))
            })?;
            let model = entry.model.clone().unwrap_or_default();
            let api_key = entry.api_key.clone().unwrap_or_else(|| {
                config
                    .secrets
                    .compatible_api_keys
                    .get(name)
                    .map(|s| s.expose().to_owned())
                    .unwrap_or_default()
            });
            let max_tokens = entry.max_tokens.unwrap_or(4096);
            Ok(AnyProvider::Compatible(CompatibleProvider::new(
                name.to_owned(),
                api_key,
                base_url,
                model,
                max_tokens,
                entry.embedding_model.clone(),
            )))
        }
        #[cfg(feature = "candle")]
        ProviderKind::Candle => {
            let candle = entry.candle.as_ref().ok_or_else(|| {
                BootstrapError::Provider(
                    "candle provider requires 'candle' section in [[llm.providers]]".into(),
                )
            })?;
            // "local" loads from disk; any other source string is treated as
            // HuggingFace, with the entry's model (or the global effective
            // model) as the repo id.
            let source = match candle.source.as_str() {
                "local" => zeph_llm::candle_provider::loader::ModelSource::Local {
                    path: std::path::PathBuf::from(&candle.local_path),
                },
                _ => zeph_llm::candle_provider::loader::ModelSource::HuggingFace {
                    repo_id: entry
                        .model
                        .clone()
                        .unwrap_or_else(|| config.llm.effective_model().to_owned()),
                    filename: candle.filename.clone(),
                },
            };
            // Adapt the per-entry candle section into the standalone
            // CandleConfig shape build_candle_provider expects.
            let candle_cfg_adapter = crate::config::CandleConfig {
                source: candle.source.clone(),
                local_path: candle.local_path.clone(),
                filename: candle.filename.clone(),
                chat_template: candle.chat_template.clone(),
                device: candle.device.clone(),
                embedding_repo: candle.embedding_repo.clone(),
                hf_token: candle.hf_token.clone(),
                generation: candle.generation.clone(),
            };
            build_candle_provider(&source, &candle_cfg_adapter, &candle.device)
        }
        #[cfg(not(feature = "candle"))]
        ProviderKind::Candle => Err(BootstrapError::Provider(
            "candle feature is not enabled".into(),
        )),
    }
}
461
/// Core provider-selection logic behind [`create_provider`].
///
/// With an empty pool, falls back to a plain Ollama provider built from the
/// legacy top-level `[llm]` settings. Otherwise dispatches on
/// `config.llm.routing` to build either a single provider or a routed
/// multi-provider wrapper (EMA, Thompson, cascade, bandit, triage).
///
/// # Errors
/// [`BootstrapError::Provider`] when no usable provider can be initialized.
#[allow(clippy::too_many_lines)]
fn create_provider_from_pool(config: &Config) -> Result<AnyProvider, BootstrapError> {
    let pool = &config.llm.providers;

    // Legacy path: no [[llm.providers]] entries configured at all.
    if pool.is_empty() {
        let base_url = config.llm.effective_base_url();
        let model = config.llm.effective_model();
        let embed = &config.llm.embedding_model;
        return Ok(AnyProvider::Ollama(OllamaProvider::new(
            base_url,
            model.to_owned(),
            embed.clone(),
        )));
    }

    match config.llm.routing {
        LlmRoutingStrategy::None => build_single_provider_from_pool(pool, config),
        LlmRoutingStrategy::Ema => {
            let providers = build_all_pool_providers(pool, config)?;
            // Clamp alpha into (0, 1]; zero would freeze the moving average.
            // NOTE(review): a NaN alpha survives f64::clamp unchanged —
            // confirm upstream config validation rejects non-finite values.
            let raw_alpha = config.llm.router_ema_alpha;
            let alpha = raw_alpha.clamp(f64::MIN_POSITIVE, 1.0);
            if (alpha - raw_alpha).abs() > f64::EPSILON {
                tracing::warn!(
                    raw_alpha,
                    clamped = alpha,
                    "router_ema_alpha out of range [MIN_POSITIVE, 1.0], clamped"
                );
            }
            Ok(AnyProvider::Router(Box::new(
                RouterProvider::new(providers).with_ema(alpha, config.llm.router_reorder_interval),
            )))
        }
        LlmRoutingStrategy::Thompson => {
            let providers = build_all_pool_providers(pool, config)?;
            // Optional persistence for the Thompson sampler's state.
            let state_path = config
                .llm
                .router
                .as_ref()
                .and_then(|r| r.thompson_state_path.as_deref())
                .map(std::path::Path::new);
            Ok(AnyProvider::Router(Box::new(
                RouterProvider::new(providers).with_thompson(state_path),
            )))
        }
        LlmRoutingStrategy::Cascade => {
            let providers = build_all_pool_providers(pool, config)?;
            // Missing [llm.router.cascade] section falls back to defaults.
            let cascade_cfg = config
                .llm
                .router
                .as_ref()
                .and_then(|r| r.cascade.clone())
                .unwrap_or_default();
            let router_cascade_cfg = build_cascade_router_config(&cascade_cfg, config);
            Ok(AnyProvider::Router(Box::new(
                RouterProvider::new(providers).with_cascade(router_cascade_cfg),
            )))
        }
        LlmRoutingStrategy::Bandit => {
            let providers = build_all_pool_providers(pool, config)?;
            let bandit_cfg = config
                .llm
                .router
                .as_ref()
                .and_then(|r| r.bandit.clone())
                .unwrap_or_default();
            let state_path = bandit_cfg.state_path.as_deref().map(std::path::Path::new);
            let router_bandit_cfg = BanditRouterConfig {
                alpha: bandit_cfg.alpha,
                dim: bandit_cfg.dim,
                cost_weight: bandit_cfg.cost_weight.clamp(0.0, 1.0),
                decay_factor: bandit_cfg.decay_factor,
                // NOTE(review): warmup_queries is hardcoded to 0 instead of
                // being taken from bandit_cfg — confirm this is intentional
                // and not a dropped `bandit_cfg.warmup_queries`.
                warmup_queries: 0, embedding_timeout_ms: bandit_cfg.embedding_timeout_ms,
                cache_size: bandit_cfg.cache_size,
                memory_confidence_threshold: bandit_cfg.memory_confidence_threshold.clamp(0.0, 1.0),
            };
            // Optional embedding provider for contextual features; any
            // failure degrades gracefully to Thompson sampling in the router.
            let embed_provider = if bandit_cfg.embedding_provider.is_empty() {
                None
            } else if let Some(entry) = pool
                .iter()
                .find(|e| e.effective_name() == bandit_cfg.embedding_provider)
            {
                match build_provider_from_entry(entry, config) {
                    Ok(p) => Some(p),
                    Err(e) => {
                        tracing::warn!(
                            provider = %bandit_cfg.embedding_provider,
                            error = %e,
                            "bandit: embedding provider failed to init, bandit will use Thompson fallback"
                        );
                        None
                    }
                }
            } else {
                tracing::warn!(
                    provider = %bandit_cfg.embedding_provider,
                    "bandit: embedding_provider not found in [[llm.providers]], \
                     bandit will use Thompson fallback"
                );
                None
            };
            Ok(AnyProvider::Router(Box::new(
                RouterProvider::new(providers).with_bandit(
                    router_bandit_cfg,
                    state_path,
                    embed_provider,
                ),
            )))
        }
        LlmRoutingStrategy::Task => {
            // Placeholder strategy: degrade to the single-provider path.
            tracing::warn!(
                "routing = \"task\" is not yet implemented; \
                 falling back to single provider from pool"
            );
            build_single_provider_from_pool(pool, config)
        }
        LlmRoutingStrategy::Triage => build_triage_provider(pool, config),
    }
}
590
591fn build_all_pool_providers(
594 pool: &[ProviderEntry],
595 config: &Config,
596) -> Result<Vec<AnyProvider>, BootstrapError> {
597 let mut providers = Vec::new();
598 for entry in pool {
599 match build_provider_from_entry(entry, config) {
600 Ok(p) => providers.push(p),
601 Err(e) => {
602 tracing::warn!(
603 provider = entry.name.as_deref().unwrap_or("?"),
604 error = %e,
605 "skipping pool provider during routing initialization"
606 );
607 }
608 }
609 }
610 if providers.is_empty() {
611 return Err(BootstrapError::Provider(
612 "routing enabled but no providers in [[llm.providers]] could be initialized".into(),
613 ));
614 }
615 Ok(providers)
616}
617
/// Builds a complexity-triage router: a cheap triage provider classifies each
/// query into a tier (simple/medium/complex/expert), and each tier maps to a
/// provider named in `[llm.complexity_routing]`.
///
/// Falls back to the single-provider path when no tier providers resolve, or
/// (with `bypass_single_provider`) when every configured tier names the same
/// pool entry.
///
/// # Errors
/// [`BootstrapError::Provider`] when `[llm.complexity_routing]` is missing or
/// the triage provider itself cannot be built.
fn build_triage_provider(
    pool: &[crate::config::ProviderEntry],
    config: &crate::config::Config,
) -> Result<AnyProvider, BootstrapError> {
    let cr = config.llm.complexity_routing.as_ref().ok_or_else(|| {
        BootstrapError::Provider(
            "routing = \"triage\" requires [llm.complexity_routing] section".into(),
        )
    })?;

    // The triage provider defaults to the first pool entry when unspecified.
    let default_triage_name = pool
        .first()
        .map(crate::config::ProviderEntry::effective_name)
        .unwrap_or_default();
    let triage_prov_name = cr
        .triage_provider
        .as_deref()
        .unwrap_or(default_triage_name.as_str());
    let triage_provider = create_named_provider(triage_prov_name, config).map_err(|e| {
        BootstrapError::Provider(format!(
            "triage_provider '{triage_prov_name}' not found in [[llm.providers]]: {e}"
        ))
    })?;

    // Static tier -> configured provider-name table; unset tiers are skipped.
    let tier_config: [(ComplexityTier, Option<&str>); 4] = [
        (ComplexityTier::Simple, cr.tiers.simple.as_deref()),
        (ComplexityTier::Medium, cr.tiers.medium.as_deref()),
        (ComplexityTier::Complex, cr.tiers.complex.as_deref()),
        (ComplexityTier::Expert, cr.tiers.expert.as_deref()),
    ];

    let mut tier_providers: Vec<(ComplexityTier, AnyProvider)> = Vec::new();
    let mut tier_config_names: Vec<&str> = Vec::new();
    for (tier, maybe_name) in &tier_config {
        let Some(name) = maybe_name else { continue };
        match create_named_provider(name, config) {
            Ok(p) => {
                tier_providers.push((*tier, p));
                tier_config_names.push(name);
            }
            Err(e) => {
                // A bad tier mapping disables that tier, not the whole router.
                tracing::warn!(
                    tier = tier.as_str(),
                    provider = name,
                    error = %e,
                    "triage: skipping tier provider (not found in pool)"
                );
            }
        }
    }

    if tier_providers.is_empty() {
        tracing::warn!(
            "triage routing: no tier providers configured, \
             falling back to single provider"
        );
        return build_single_provider_from_pool(pool, config);
    }

    // Bypass: if every resolved tier names the same config entry, triage only
    // adds latency. NOTE(review): the bypass builds the pool's *default*
    // provider, which may differ from the shared tier entry — confirm this is
    // the intended behavior.
    if cr.bypass_single_provider
        && let Some(first_name) = tier_config_names
            .first()
            .copied()
            .filter(|&n| tier_config_names.iter().all(|m| *m == n))
    {
        tracing::debug!(
            provider = first_name,
            "triage routing: all tiers map to same config entry, bypassing triage"
        );
        return build_single_provider_from_pool(pool, config);
    }

    let router = TriageRouter::new(
        triage_provider,
        tier_providers,
        cr.triage_timeout_secs,
        cr.max_triage_tokens,
    );
    Ok(AnyProvider::Triage(Box::new(router)))
}
710
711fn build_single_provider_from_pool(
713 pool: &[ProviderEntry],
714 config: &Config,
715) -> Result<AnyProvider, BootstrapError> {
716 let primary_idx = pool.iter().position(|e| e.default).unwrap_or(0);
717 let primary = &pool[primary_idx];
718 match build_provider_from_entry(primary, config) {
719 Ok(p) => Ok(p),
720 Err(e) => {
721 let name = primary.name.as_deref().unwrap_or("primary");
722 tracing::warn!(provider = name, error = %e, "primary provider failed, trying next");
723 for (i, entry) in pool.iter().enumerate() {
724 if i == primary_idx {
725 continue;
726 }
727 match build_provider_from_entry(entry, config) {
728 Ok(p) => return Ok(p),
729 Err(e2) => {
730 tracing::warn!(
731 provider = entry.name.as_deref().unwrap_or("?"),
732 error = %e2,
733 "fallback provider failed"
734 );
735 }
736 }
737 }
738 Err(BootstrapError::Provider(format!(
739 "all providers in [[llm.providers]] failed to initialize; first error: {e}"
740 )))
741 }
742 }
743}