Skip to main content

zeph_core/
provider_factory.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Pure provider factory helpers: build `AnyProvider` instances from config entries.
5//!
6//! This module contains configuration-to-provider transformation functions that are
7//! used by internal `zeph-core` subsystems (skills, tools, autodream, session config).
8//! They are intentionally separated from bootstrap orchestration logic so that provider
9//! construction can be reasoned about and tested independently of startup sequencing.
10
11use zeph_llm::any::AnyProvider;
12use zeph_llm::claude::ClaudeProvider;
13#[cfg(feature = "cocoon")]
14use zeph_llm::cocoon::{CocoonClient, CocoonProvider};
15use zeph_llm::compatible::CompatibleProvider;
16use zeph_llm::gemini::GeminiProvider;
17#[cfg(feature = "gonka")]
18use zeph_llm::gonka::endpoints::{EndpointPool, GonkaEndpoint};
19#[cfg(feature = "gonka")]
20use zeph_llm::gonka::{GonkaProvider, RequestSigner};
21use zeph_llm::http::llm_client;
22use zeph_llm::ollama::OllamaProvider;
23use zeph_llm::openai::OpenAiProvider;
24#[cfg(feature = "gonka")]
25use zeroize::Zeroizing;
26
27use crate::agent::state::ProviderConfigSnapshot;
28use crate::config::{Config, ProviderEntry, ProviderKind};
29
/// Error type for provider construction failures.
///
/// String-based variants flatten the error chain intentionally: bootstrap errors are
/// terminal (the application exits), so downcasting is not needed at this stage.
/// If a future phase requires programmatic retry on specific failures, expand these
/// variants into typed sub-errors.
///
/// The provider builders in this module report every failure through
/// [`BootstrapError::Provider`].
#[derive(Debug, thiserror::Error)]
pub enum BootstrapError {
    /// Configuration validation failed.
    ///
    /// `#[from]` lets `?` convert a `crate::config::ConfigError` directly into this variant.
    #[error("config error: {0}")]
    Config(#[from] crate::config::ConfigError),
    /// Provider construction failed (missing secrets, unsupported kind, etc.).
    #[error("provider error: {0}")]
    Provider(String),
    /// Memory subsystem initialization failed.
    #[error("memory error: {0}")]
    Memory(String),
    /// Age vault initialization failed.
    // NOTE(review): unlike `Config` and `Io`, this variant carries neither `#[from]` nor
    // `#[source]`, so `std::error::Error::source()` returns `None` for it and callers must
    // wrap the vault error explicitly — confirm this is intended.
    #[error("vault init error: {0}")]
    VaultInit(crate::vault::AgeVaultError),
    /// I/O error during bootstrap.
    ///
    /// `#[from]` lets `?` convert a `std::io::Error` directly into this variant.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
}
54
55/// Build an `AnyProvider` from a `ProviderEntry` using a runtime config snapshot.
56///
57/// Called by the `/provider <name>` slash command to switch providers at runtime without
58/// requiring the full `Config`. Router and Orchestrator provider kinds are not supported
59/// for runtime switching — they require the full provider pool to be re-initialized.
60///
61/// # Errors
62///
63/// Returns `BootstrapError::Provider` when the provider kind is unsupported for runtime
64/// switching, a required secret is missing, or the entry is misconfigured.
65pub fn build_provider_for_switch(
66    entry: &ProviderEntry,
67    snapshot: &ProviderConfigSnapshot,
68) -> Result<AnyProvider, BootstrapError> {
69    use zeph_common::secret::Secret;
70    // Reconstruct a minimal Config from the snapshot so we can reuse build_provider_from_entry.
71    // Only fields read by build_provider_from_entry are populated; everything else uses defaults.
72    // Secrets are stored as plain strings in the snapshot because Secret does not implement Clone.
73    let mut config = Config::default();
74    config.secrets.claude_api_key = snapshot.claude_api_key.as_deref().map(Secret::new);
75    config.secrets.openai_api_key = snapshot.openai_api_key.as_deref().map(Secret::new);
76    config.secrets.gemini_api_key = snapshot.gemini_api_key.as_deref().map(Secret::new);
77    config.secrets.compatible_api_keys = snapshot
78        .compatible_api_keys
79        .iter()
80        .map(|(k, v)| (k.clone(), Secret::new(v.as_str())))
81        .collect();
82    config.secrets.gonka_private_key = snapshot
83        .gonka_private_key
84        .as_ref()
85        .map(|z| Secret::new(z.as_str()));
86    config.secrets.gonka_address = snapshot.gonka_address.as_deref().map(Secret::new);
87    config.secrets.cocoon_access_hash = snapshot.cocoon_access_hash.as_deref().map(Secret::new);
88    config.timeouts.llm_request_timeout_secs = snapshot.llm_request_timeout_secs;
89    config
90        .llm
91        .embedding_model
92        .clone_from(&snapshot.embedding_model);
93    build_provider_from_entry(entry, &config)
94}
95
96/// Build an `AnyProvider` from a unified `ProviderEntry` (new `[[llm.providers]]` format).
97///
98/// All provider-specific fields come from `entry`; the global `config` is used only for
99/// secrets and timeout settings.
100///
101/// # Errors
102///
103/// Returns `BootstrapError::Provider` when a required secret is missing or an entry is
104/// misconfigured (e.g. compatible provider without a name).
105pub fn build_provider_from_entry(
106    entry: &ProviderEntry,
107    config: &Config,
108) -> Result<AnyProvider, BootstrapError> {
109    match entry.provider_type {
110        ProviderKind::Ollama => Ok(build_ollama_provider(entry, config)),
111        ProviderKind::Claude => build_claude_provider(entry, config),
112        ProviderKind::OpenAi => build_openai_provider(entry, config),
113        ProviderKind::Gemini => build_gemini_provider(entry, config),
114        ProviderKind::Compatible => build_compatible_provider(entry, config),
115        #[cfg(feature = "candle")]
116        ProviderKind::Candle => build_candle_provider(entry, config),
117        #[cfg(not(feature = "candle"))]
118        ProviderKind::Candle => Err(BootstrapError::Provider(
119            "candle feature is not enabled".into(),
120        )),
121        #[cfg(feature = "gonka")]
122        ProviderKind::Gonka => build_gonka_provider(entry, config),
123        #[cfg(not(feature = "gonka"))]
124        ProviderKind::Gonka => Err(BootstrapError::Provider(
125            "gonka feature is not enabled; rebuild with --features gonka".into(),
126        )),
127        #[cfg(feature = "cocoon")]
128        ProviderKind::Cocoon => build_cocoon_provider(entry, config),
129        #[cfg(not(feature = "cocoon"))]
130        ProviderKind::Cocoon => Err(BootstrapError::Provider(
131            "cocoon feature is not enabled; rebuild with --features cocoon".into(),
132        )),
133        _ => Err(BootstrapError::Provider(format!(
134            "unknown provider kind: {:?}",
135            entry.provider_type
136        ))),
137    }
138}
139
140fn build_ollama_provider(entry: &ProviderEntry, config: &Config) -> AnyProvider {
141    let base_url = entry
142        .base_url
143        .as_deref()
144        .unwrap_or("http://localhost:11434");
145    let model = entry.model.as_deref().unwrap_or("qwen3:8b").to_owned();
146    let embed = entry
147        .embedding_model
148        .clone()
149        .unwrap_or_else(|| config.llm.embedding_model.clone());
150    let mut provider = OllamaProvider::new(base_url, model, embed);
151    if let Some(ref vm) = entry.vision_model {
152        provider = provider.with_vision_model(vm.clone());
153    }
154    if config.mcp.forward_output_schema {
155        tracing::debug!(
156            "mcp.forward_output_schema is enabled but Ollama does not support \
157             output schema forwarding; setting ignored for this provider"
158        );
159    }
160    AnyProvider::Ollama(provider)
161}
162
163fn build_claude_provider(
164    entry: &ProviderEntry,
165    config: &Config,
166) -> Result<AnyProvider, BootstrapError> {
167    let api_key = config
168        .secrets
169        .claude_api_key
170        .as_ref()
171        .ok_or_else(|| BootstrapError::Provider("ZEPH_CLAUDE_API_KEY not found in vault".into()))?
172        .expose()
173        .to_owned();
174    let model = entry
175        .model
176        .clone()
177        .unwrap_or_else(|| "claude-haiku-4-5-20251001".to_owned());
178    let max_tokens = entry.max_tokens.unwrap_or(4096);
179    let provider = ClaudeProvider::new(api_key, model, max_tokens)
180        .with_client(llm_client(config.timeouts.llm_request_timeout_secs))
181        .with_extended_context(entry.enable_extended_context)
182        .with_thinking_opt(entry.thinking.clone())
183        .map_err(|e| BootstrapError::Provider(format!("invalid thinking config: {e}")))?
184        .with_server_compaction(entry.server_compaction)
185        .with_prompt_cache_ttl(entry.prompt_cache_ttl)
186        .with_output_schema_forwarding(
187            config.mcp.forward_output_schema,
188            config.mcp.output_schema_hint_bytes,
189            config.mcp.max_description_bytes,
190        );
191    tracing::info!(
192        forward = config.mcp.forward_output_schema,
193        "mcp.output_schema.forwarding_configured"
194    );
195    Ok(AnyProvider::Claude(provider))
196}
197
198fn build_openai_provider(
199    entry: &ProviderEntry,
200    config: &Config,
201) -> Result<AnyProvider, BootstrapError> {
202    let api_key = config
203        .secrets
204        .openai_api_key
205        .as_ref()
206        .ok_or_else(|| BootstrapError::Provider("ZEPH_OPENAI_API_KEY not found in vault".into()))?
207        .expose()
208        .to_owned();
209    let base_url = entry
210        .base_url
211        .clone()
212        .unwrap_or_else(|| "https://api.openai.com/v1".to_owned());
213    let model = entry
214        .model
215        .clone()
216        .unwrap_or_else(|| "gpt-4o-mini".to_owned());
217    let max_tokens = entry.max_tokens.unwrap_or(4096);
218    Ok(AnyProvider::OpenAi(
219        OpenAiProvider::new(zeph_llm::OpenAiConfig {
220            api_key,
221            base_url,
222            model,
223            max_tokens,
224            embedding_model: entry.embedding_model.clone(),
225            reasoning_effort: entry.reasoning_effort.clone(),
226        })
227        .with_client(llm_client(config.timeouts.llm_request_timeout_secs))
228        .with_output_schema_forwarding(
229            config.mcp.forward_output_schema,
230            config.mcp.output_schema_hint_bytes,
231            config.mcp.max_description_bytes,
232        ),
233    ))
234}
235
236fn build_gemini_provider(
237    entry: &ProviderEntry,
238    config: &Config,
239) -> Result<AnyProvider, BootstrapError> {
240    let api_key = config
241        .secrets
242        .gemini_api_key
243        .as_ref()
244        .ok_or_else(|| BootstrapError::Provider("ZEPH_GEMINI_API_KEY not found in vault".into()))?
245        .expose()
246        .to_owned();
247    let model = entry
248        .model
249        .clone()
250        .unwrap_or_else(|| "gemini-2.0-flash".to_owned());
251    let max_tokens = entry.max_tokens.unwrap_or(8192);
252    let base_url = entry
253        .base_url
254        .clone()
255        .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_owned());
256    let mut provider = GeminiProvider::new(api_key, model, max_tokens)
257        .with_base_url(base_url)
258        .with_client(llm_client(config.timeouts.llm_request_timeout_secs));
259    if let Some(ref em) = entry.embedding_model {
260        provider = provider.with_embedding_model(em.clone());
261    }
262    if let Some(level) = entry.thinking_level {
263        provider = provider.with_thinking_level(level);
264    }
265    if let Some(budget) = entry.thinking_budget {
266        provider = provider
267            .with_thinking_budget(budget)
268            .map_err(|e| BootstrapError::Provider(e.to_string()))?;
269    }
270    if let Some(include) = entry.include_thoughts {
271        provider = provider.with_include_thoughts(include);
272    }
273    if config.mcp.forward_output_schema {
274        tracing::debug!(
275            "mcp.forward_output_schema is enabled but Gemini does not support \
276             output schema forwarding; setting ignored for this provider"
277        );
278    }
279    Ok(AnyProvider::Gemini(provider))
280}
281
282fn build_compatible_provider(
283    entry: &ProviderEntry,
284    config: &Config,
285) -> Result<AnyProvider, BootstrapError> {
286    let name = entry.name.as_deref().ok_or_else(|| {
287        BootstrapError::Provider(
288            "compatible provider requires 'name' field in [[llm.providers]]".into(),
289        )
290    })?;
291    let base_url = entry.base_url.clone().ok_or_else(|| {
292        BootstrapError::Provider(format!("compatible provider '{name}' requires 'base_url'"))
293    })?;
294    let model = entry.model.clone().unwrap_or_default();
295    let api_key = entry.api_key.clone().unwrap_or_else(|| {
296        config
297            .secrets
298            .compatible_api_keys
299            .get(name)
300            .map(|s| s.expose().to_owned())
301            .unwrap_or_default()
302    });
303    let max_tokens = entry.max_tokens.unwrap_or(4096);
304    let provider = CompatibleProvider::new(zeph_llm::CompatibleConfig {
305        provider_name: name.to_owned(),
306        api_key,
307        base_url,
308        model,
309        max_tokens,
310        embedding_model: entry.embedding_model.clone(),
311    })
312    .with_output_schema_forwarding(
313        config.mcp.forward_output_schema,
314        config.mcp.output_schema_hint_bytes,
315        config.mcp.max_description_bytes,
316    );
317    tracing::info!(
318        forward = config.mcp.forward_output_schema,
319        provider = name,
320        "mcp.output_schema.forwarding_configured"
321    );
322    Ok(AnyProvider::Compatible(provider))
323}
324
/// Build a Gonka provider from `entry` (requires the `gonka` feature).
///
/// Steps, in order:
/// 1. Read the hex private key from `ZEPH_GONKA_PRIVATE_KEY`. The copy is held in a
///    `Zeroizing` buffer.
/// 2. Derive a `RequestSigner` using the entry's effective chain prefix.
/// 3. If `ZEPH_GONKA_ADDRESS` is set, require it to equal the derived address, compared
///    case-insensitively. Otherwise log the derived address.
/// 4. Require at least one node in `gonka_nodes` and build an endpoint pool from them.
/// 5. Assemble the provider. Defaults are model `gpt-4o` and 4096 max tokens; the global
///    request timeout is applied.
///
/// # Errors
///
/// Returns `BootstrapError::Provider` in these cases:
/// - the private key is missing or invalid;
/// - the configured address does not match the derived one;
/// - no nodes are configured;
/// - the endpoint pool cannot be built.
#[cfg(feature = "gonka")]
fn build_gonka_provider(
    entry: &ProviderEntry,
    config: &Config,
) -> Result<AnyProvider, BootstrapError> {
    let _span = tracing::info_span!("core.provider_factory.build_gonka").entered();

    let Some(key_secret) = config.secrets.gonka_private_key.as_ref() else {
        return Err(BootstrapError::Provider(
            "ZEPH_GONKA_PRIVATE_KEY not found in vault; set it with: zeph vault set ZEPH_GONKA_PRIVATE_KEY <hex>".into(),
        ));
    };
    let private_key_hex: Zeroizing<String> = Zeroizing::new(key_secret.expose().to_owned());

    let prefix = entry.effective_gonka_chain_prefix().to_owned();
    let signer = RequestSigner::from_hex(&private_key_hex, &prefix)
        .map_err(|e| BootstrapError::Provider(format!("invalid Gonka private key: {e}")))?;

    match config.secrets.gonka_address.as_ref() {
        Some(expected) => {
            let configured = expected.expose().to_lowercase();
            let derived = signer.address().to_lowercase();
            if configured != derived {
                return Err(BootstrapError::Provider(format!(
                    "ZEPH_GONKA_ADDRESS does not match address derived from private key \
                     (configured: {configured}, derived: {derived})"
                )));
            }
        }
        None => {
            tracing::info!(
                address = signer.address(),
                "Gonka: using address derived from private key (ZEPH_GONKA_ADDRESS not set)"
            );
        }
    }

    if entry.gonka_nodes.is_empty() {
        return Err(BootstrapError::Provider(
            "Gonka provider entry must have at least one node in gonka_nodes".into(),
        ));
    }

    let endpoints: Vec<GonkaEndpoint> = entry
        .gonka_nodes
        .iter()
        .map(|node| GonkaEndpoint {
            base_url: node.url.clone(),
            address: node.address.clone(),
        })
        .collect();
    let pool = EndpointPool::new(endpoints).map_err(|e| {
        BootstrapError::Provider(format!("failed to build Gonka endpoint pool: {e}"))
    })?;

    let settings = zeph_llm::gonka::GonkaConfig {
        signer: std::sync::Arc::new(signer),
        pool: std::sync::Arc::new(pool),
        model: entry.model.clone().unwrap_or_else(|| "gpt-4o".to_owned()),
        max_tokens: entry.max_tokens.unwrap_or(4096),
        embedding_model: entry.embedding_model.clone(),
        timeout: std::time::Duration::from_secs(config.timeouts.llm_request_timeout_secs),
    };

    Ok(AnyProvider::Gonka(GonkaProvider::new(settings)))
}
400
/// Return `true` when `url` is an `http`/`https` URL whose host is a loopback literal:
/// `localhost`, `127.0.0.1`, or `[::1]`.
///
/// The host is extracted from the URL authority rather than matched by string prefix.
/// Look-alike URLs are therefore reported as non-loopback, for example
/// `http://localhost.example.com`, `http://localhost@example.com` or
/// `http://127.0.0.1.nip.io`. Scheme and host are compared case-insensitively.
// Not gated on `cocoon` so it stays unit-testable; silence dead-code when that feature is off.
#[cfg_attr(not(feature = "cocoon"), allow(dead_code))]
fn is_loopback_url(url: &str) -> bool {
    let lowered = url.to_ascii_lowercase();
    let Some(rest) = lowered
        .strip_prefix("http://")
        .or_else(|| lowered.strip_prefix("https://"))
    else {
        return false;
    };

    // The authority ends at the first path, query, or fragment delimiter.
    let authority = rest.split(['/', '?', '#']).next().unwrap_or_default();
    // Drop any `userinfo@` prefix so it cannot masquerade as the host.
    let host_and_port = authority.rsplit_once('@').map_or(authority, |(_, host)| host);

    let (host, after_host) = if host_and_port.starts_with('[') {
        // Bracketed IPv6 literal: the host runs through the closing bracket.
        match host_and_port.find(']') {
            Some(end) => host_and_port.split_at(end + 1),
            None => return false,
        }
    } else {
        host_and_port.split_at(host_and_port.find(':').unwrap_or(host_and_port.len()))
    };

    // Only an optional `:port` may follow the host (rejects e.g. `[::1].evil`).
    if !(after_host.is_empty() || after_host.starts_with(':')) {
        return false;
    }
    matches!(host, "localhost" | "127.0.0.1" | "[::1]")
}

/// Build a [`CocoonProvider`] from a `[[llm.providers]]` entry.
///
/// Behaviour:
/// - The sidecar URL defaults to `http://localhost:10000`. A warning is logged when its
///   host is not a loopback literal, as decided by `is_loopback_url`.
/// - The access hash is resolved from the age vault when `cocoon_access_hash` is `Some(_)`
///   in the entry.
/// - A non-empty value in the config file itself triggers a warning, because the real
///   hash belongs in the vault.
/// - When `cocoon_health_check` is set, an advisory health probe is spawned on the current
///   Tokio runtime. If no runtime is active, the probe is skipped with a warning instead
///   of panicking.
///
/// # Errors
///
/// Returns [`BootstrapError::Provider`] when the vault key `ZEPH_COCOON_ACCESS_HASH` is
/// expected (field is `Some`) but not present in the resolved secrets.
#[cfg(feature = "cocoon")]
fn build_cocoon_provider(
    entry: &ProviderEntry,
    config: &Config,
) -> Result<AnyProvider, BootstrapError> {
    let _span = tracing::info_span!("core.provider_factory.build_cocoon").entered();

    let base_url = entry
        .cocoon_client_url
        .as_deref()
        .unwrap_or("http://localhost:10000");

    // Validate URL at construction time (MINOR-3): warn if not localhost. The host is
    // parsed from the authority, so look-alikes such as `http://localhost.example.com`
    // no longer slip past a plain prefix check.
    if !is_loopback_url(base_url) {
        tracing::warn!(
            url = base_url,
            "cocoon_client_url points to a non-localhost host; \
             ensure this is intentional (expected sidecar on localhost)"
        );
    }

    if entry
        .cocoon_access_hash
        .as_deref()
        .is_some_and(|v| !v.is_empty())
    {
        tracing::warn!(
            "cocoon_access_hash in config file appears to contain a raw value; \
             this field should be empty — the actual hash must be stored in the vault: \
             zeph vault set ZEPH_COCOON_ACCESS_HASH <hash>"
        );
    }

    // `Some(_)` in the entry (even an empty string) means "an access hash is required".
    let access_hash = if entry.cocoon_access_hash.is_some() {
        let hash = config
            .secrets
            .cocoon_access_hash
            .as_ref()
            .ok_or_else(|| {
                BootstrapError::Provider(
                    "ZEPH_COCOON_ACCESS_HASH not found in vault; set it with: \
                     zeph vault set ZEPH_COCOON_ACCESS_HASH <hash>"
                        .into(),
                )
            })?
            .expose()
            .to_owned();
        Some(hash)
    } else {
        None
    };

    let timeout = std::time::Duration::from_secs(config.timeouts.llm_request_timeout_secs);
    let client = std::sync::Arc::new(CocoonClient::new(base_url, access_hash, timeout));

    if entry.cocoon_health_check {
        // `tokio::spawn` panics outside a runtime. This builder is synchronous and may be
        // reached from non-async contexts, so look the runtime up fallibly.
        match tokio::runtime::Handle::try_current() {
            Ok(runtime) => {
                let probe_client = std::sync::Arc::clone(&client);
                // Fire-and-forget: intentional. The health check is advisory-only; a
                // failure does not block provider construction.
                drop(runtime.spawn(async move {
                    match probe_client.health_check().await {
                        Ok(h) => {
                            tracing::info!(
                                proxy_connected = h.proxy_connected,
                                worker_count = h.worker_count,
                                "cocoon sidecar health check passed"
                            );
                        }
                        Err(e) => {
                            tracing::warn!(
                                error = %e,
                                "cocoon sidecar health check failed; \
                                 inference requests will return LlmError::Unavailable until the sidecar is running"
                            );
                        }
                    }
                }));
            }
            Err(_) => {
                tracing::warn!(
                    "cocoon_health_check is enabled but no Tokio runtime is active; \
                     skipping the sidecar health check"
                );
            }
        }
    }

    let model = entry
        .model
        .clone()
        .unwrap_or_else(|| "Qwen/Qwen3-0.6B".to_owned());
    let max_tokens = entry.max_tokens.unwrap_or(4096);
    let provider = CocoonProvider::new(model, max_tokens, entry.embedding_model.clone(), client);

    Ok(AnyProvider::Cocoon(provider))
}
504
/// Build an in-process candle provider from `entry` (requires the `candle` feature).
///
/// Reads the entry's `candle` section to pick the model source:
/// - `source = "local"` loads from `local_path`;
/// - any other value resolves a HuggingFace source. The repo id is the entry's `model`,
///   or the global effective model when unset, and the file is `filename`.
///
/// The section also configures the chat template, sampling parameters, compute device,
/// and an inference timeout floored at one second.
///
/// # Errors
///
/// Returns `BootstrapError::Provider` when the `candle` section is missing, the device
/// cannot be selected, or `CandleProvider::new_with_timeout` fails.
#[cfg(feature = "candle")]
fn build_candle_provider(
    entry: &ProviderEntry,
    config: &Config,
) -> Result<AnyProvider, BootstrapError> {
    use zeph_llm::candle_provider::generate::GenerationConfig;
    use zeph_llm::candle_provider::loader::ModelSource;
    use zeph_llm::candle_provider::template::ChatTemplate;

    let Some(settings) = entry.candle.as_ref() else {
        return Err(BootstrapError::Provider(
            "candle provider requires 'candle' section in [[llm.providers]]".into(),
        ));
    };

    let source = if settings.source == "local" {
        ModelSource::Local {
            path: std::path::PathBuf::from(&settings.local_path),
        }
    } else {
        ModelSource::HuggingFace {
            repo_id: entry
                .model
                .clone()
                .unwrap_or_else(|| config.llm.effective_model().to_owned()),
            filename: settings.filename.clone(),
        }
    };

    let template = ChatTemplate::parse_str(&settings.chat_template);
    let sampling = &settings.generation;
    let gen_config = GenerationConfig {
        temperature: sampling.temperature,
        top_p: sampling.top_p,
        top_k: sampling.top_k,
        max_tokens: sampling.capped_max_tokens(),
        seed: sampling.seed,
        repeat_penalty: sampling.repeat_penalty,
        repeat_last_n: sampling.repeat_last_n,
    };
    let device = select_device(&settings.device)?;
    // A zero timeout would make every request expire immediately, so floor it at 1s.
    let inference_timeout =
        std::time::Duration::from_secs(settings.inference_timeout_secs.max(1));

    let provider = zeph_llm::candle_provider::CandleProvider::new_with_timeout(
        &source,
        template,
        gen_config,
        settings.embedding_repo.as_deref(),
        settings.hf_token.as_deref(),
        device,
        inference_timeout,
    )
    .map_err(|e| BootstrapError::Provider(e.to_string()))?;
    Ok(AnyProvider::Candle(provider))
}
554
/// Select the candle compute device from a string preference.
///
/// The preference is matched exactly (case-sensitive):
/// - `"cuda"`: CUDA device 0. Errors if the crate was built without the `cuda` feature.
/// - `"metal"`: Metal device 0. Errors if the crate was built without the `metal` feature.
/// - `"auto"`: tries Metal, then CUDA (each only when compiled in), falling back to CPU if
///   neither initializes.
/// - anything else: CPU.
///
/// # Errors
///
/// Returns `BootstrapError::Provider` when an explicitly requested GPU backend is not
/// compiled in or fails to initialize.
#[cfg(feature = "candle")]
pub fn select_device(
    preference: &str,
) -> Result<zeph_llm::candle_provider::Device, BootstrapError> {
    use zeph_llm::candle_provider::Device;

    match preference {
        "cuda" => {
            #[cfg(not(feature = "cuda"))]
            return Err(BootstrapError::Provider(
                "candle compiled without cuda feature".into(),
            ));
            #[cfg(feature = "cuda")]
            return Device::new_cuda(0).map_err(|err| BootstrapError::Provider(err.to_string()));
        }
        "metal" => {
            #[cfg(not(feature = "metal"))]
            return Err(BootstrapError::Provider(
                "candle compiled without metal feature".into(),
            ));
            #[cfg(feature = "metal")]
            return Device::new_metal(0).map_err(|err| BootstrapError::Provider(err.to_string()));
        }
        "auto" => {
            // Best-effort probing: an initialization failure just moves on to the next option.
            #[cfg(feature = "metal")]
            if let Ok(gpu) = Device::new_metal(0) {
                return Ok(gpu);
            }
            #[cfg(feature = "cuda")]
            if let Ok(gpu) = Device::new_cuda(0) {
                return Ok(gpu);
            }
            Ok(Device::Cpu)
        }
        _ => Ok(Device::Cpu),
    }
}
602
603/// Determine the effective embedding model name for the memory subsystem.
604///
605/// Resolution order:
606/// 1. `embedding_model` from the `[[llm.providers]]` entry marked `embed = true`
607/// 2. `embedding_model` from the first entry in `[[llm.providers]]`
608/// 3. `[llm] embedding_model` global fallback
609#[must_use]
610pub fn effective_embedding_model(config: &Config) -> String {
611    // Prefer a dedicated embed provider.
612    if let Some(m) = config
613        .llm
614        .providers
615        .iter()
616        .find(|e| e.embed)
617        .and_then(|e| e.embedding_model.as_ref())
618    {
619        return m.clone();
620    }
621    // Fall back to the first provider's embedding model.
622    if let Some(m) = config
623        .llm
624        .providers
625        .first()
626        .and_then(|e| e.embedding_model.as_ref())
627    {
628        return m.clone();
629    }
630    config.llm.embedding_model.clone()
631}
632
633/// Resolve the stable embedding model name for skill-matcher collection versioning.
634///
635/// This uses the same entry resolution as the embedding provider itself: the entry
636/// with `embed = true`, preferring its `embedding_model` field and falling back to
637/// its `model` field. Using the actual provider's model name prevents the
638/// `model_has_changed` check in [`zeph_memory::embedding_registry`] from triggering
639/// false positives that would rebuild the `zeph_skills` collection on every startup.
640///
641/// Falls back to [`effective_embedding_model`] when no dedicated embed entry exists.
642#[must_use]
643pub fn stable_skill_embedding_model(config: &Config) -> String {
644    // Find the dedicated embed entry (same lookup as `create_embedding_provider`).
645    let embed_entry = config.llm.providers.iter().find(|e| e.embed).or_else(|| {
646        config
647            .llm
648            .providers
649            .iter()
650            .find(|e| e.embedding_model.is_some())
651    });
652
653    if let Some(entry) = embed_entry {
654        // Prefer the explicit `embedding_model` field; fall back to the `model` field.
655        if let Some(em) = entry.embedding_model.as_ref().filter(|s| !s.is_empty()) {
656            return em.clone();
657        }
658        if let Some(m) = entry.model.as_ref().filter(|s| !s.is_empty()) {
659            return m.clone();
660        }
661    }
662
663    // No dedicated embed entry — fall back to the general embedding model resolution.
664    effective_embedding_model(config)
665}
666
667#[cfg(test)]
668mod tests {
669    #[cfg(feature = "candle")]
670    use super::select_device;
671
672    #[cfg(feature = "candle")]
673    #[test]
674    fn select_device_cpu_default() {
675        let device = select_device("cpu").unwrap();
676        assert!(matches!(device, zeph_llm::candle_provider::Device::Cpu));
677    }
678
679    #[cfg(feature = "candle")]
680    #[test]
681    fn select_device_unknown_defaults_to_cpu() {
682        let device = select_device("unknown").unwrap();
683        assert!(matches!(device, zeph_llm::candle_provider::Device::Cpu));
684    }
685
686    #[cfg(all(feature = "candle", not(feature = "metal")))]
687    #[test]
688    fn select_device_metal_without_feature_errors() {
689        let result = select_device("metal");
690        assert!(result.is_err());
691        assert!(result.unwrap_err().to_string().contains("metal feature"));
692    }
693
694    #[cfg(all(feature = "candle", not(feature = "cuda")))]
695    #[test]
696    fn select_device_cuda_without_feature_errors() {
697        let result = select_device("cuda");
698        assert!(result.is_err());
699        assert!(result.unwrap_err().to_string().contains("cuda feature"));
700    }
701
702    #[cfg(feature = "candle")]
703    #[test]
704    fn select_device_auto_fallback() {
705        let device = select_device("auto").unwrap();
706        assert!(matches!(
707            device,
708            zeph_llm::candle_provider::Device::Cpu
709                | zeph_llm::candle_provider::Device::Cuda(_)
710                | zeph_llm::candle_provider::Device::Metal(_)
711        ));
712    }
713
714    #[cfg(any(feature = "gonka", feature = "cocoon"))]
715    use super::build_provider_from_entry;
716    use super::{effective_embedding_model, stable_skill_embedding_model};
717    use crate::config::{Config, ProviderKind};
718    use zeph_config::providers::ProviderEntry;
719
720    #[cfg(feature = "gonka")]
721    mod gonka_tests {
722        use super::*;
723        use zeph_common::secret::Secret;
724        use zeph_config::GonkaNode;
725        use zeph_llm::LlmProvider;
726
727        fn gonka_entry_with_nodes(nodes: Vec<GonkaNode>) -> ProviderEntry {
728            ProviderEntry {
729                provider_type: ProviderKind::Gonka,
730                name: Some("gonka".into()),
731                model: Some("gpt-4o".into()),
732                gonka_nodes: nodes,
733                ..ProviderEntry::default()
734            }
735        }
736
737        fn valid_nodes() -> Vec<GonkaNode> {
738            vec![GonkaNode {
739                url: "https://node1.gonka.ai".into(),
740                address: "gonka1w508d6qejxtdg4y5r3zarvary0c5xw7k2gsyg6".into(),
741                name: Some("node1".into()),
742            }]
743        }
744
745        const VALID_PRIV_KEY: &str =
746            "0000000000000000000000000000000000000000000000000000000000000001";
747
748        #[test]
749        fn build_gonka_provider_missing_key_returns_error() {
750            let entry = gonka_entry_with_nodes(valid_nodes());
751            let config = Config::default();
752            let result = build_provider_from_entry(&entry, &config);
753            assert!(result.is_err());
754            let msg = result.unwrap_err().to_string();
755            assert!(
756                msg.contains("ZEPH_GONKA_PRIVATE_KEY"),
757                "error must mention missing key: {msg}"
758            );
759        }
760
761        #[test]
762        fn build_gonka_provider_empty_nodes_returns_error() {
763            let entry = gonka_entry_with_nodes(vec![]);
764            let mut config = Config::default();
765            config.secrets.gonka_private_key = Some(Secret::new(VALID_PRIV_KEY));
766            let result = build_provider_from_entry(&entry, &config);
767            assert!(result.is_err());
768            let msg = result.unwrap_err().to_string();
769            assert!(
770                msg.contains("gonka_nodes") || msg.contains("node"),
771                "error must mention empty nodes: {msg}"
772            );
773        }
774
775        #[test]
776        fn build_gonka_provider_address_mismatch_returns_error() {
777            let entry = gonka_entry_with_nodes(valid_nodes());
778            let mut config = Config::default();
779            config.secrets.gonka_private_key = Some(Secret::new(VALID_PRIV_KEY));
780            config.secrets.gonka_address =
781                Some(Secret::new("gonka1wrongaddress000000000000000000000000000"));
782            let result = build_provider_from_entry(&entry, &config);
783            assert!(result.is_err());
784            let msg = result.unwrap_err().to_string();
785            assert!(
786                msg.contains("does not match"),
787                "error must mention address mismatch: {msg}"
788            );
789        }
790
791        #[test]
792        fn build_gonka_provider_happy_path() {
793            let entry = gonka_entry_with_nodes(valid_nodes());
794            let mut config = Config::default();
795            config.secrets.gonka_private_key = Some(Secret::new(VALID_PRIV_KEY));
796            let result = build_provider_from_entry(&entry, &config);
797            assert!(result.is_ok(), "expected Ok, got: {:?}", result.err());
798            let provider = result.unwrap();
799            assert_eq!(provider.name(), "gonka");
800        }
801    }
802
803    fn make_provider_entry(
804        embed: bool,
805        model: Option<&str>,
806        embedding_model: Option<&str>,
807    ) -> ProviderEntry {
808        ProviderEntry {
809            provider_type: ProviderKind::Ollama,
810            embed,
811            model: model.map(str::to_owned),
812            embedding_model: embedding_model.map(str::to_owned),
813            ..ProviderEntry::default()
814        }
815    }
816
817    #[test]
818    fn stable_skill_embedding_model_prefers_embedding_model_field() {
819        let mut config = Config::default();
820        config.llm.providers = vec![make_provider_entry(
821            true,
822            Some("chat-model"),
823            Some("embed-v2"),
824        )];
825        assert_eq!(stable_skill_embedding_model(&config), "embed-v2");
826    }
827
828    #[test]
829    fn stable_skill_embedding_model_falls_back_to_model_field() {
830        let mut config = Config::default();
831        config.llm.providers = vec![make_provider_entry(
832            true,
833            Some("nomic-embed-text-v2-moe:latest"),
834            None,
835        )];
836        assert_eq!(
837            stable_skill_embedding_model(&config),
838            "nomic-embed-text-v2-moe:latest"
839        );
840    }
841
842    #[test]
843    fn stable_skill_embedding_model_finds_embed_flag_entry() {
844        let mut config = Config::default();
845        config.llm.providers = vec![
846            make_provider_entry(false, Some("chat-model"), None),
847            make_provider_entry(true, Some("embed-model"), Some("text-embed-3")),
848        ];
849        assert_eq!(stable_skill_embedding_model(&config), "text-embed-3");
850    }
851
852    #[test]
853    fn stable_skill_embedding_model_falls_back_to_effective_when_no_embed_entry() {
854        let mut config = Config::default();
855        config.llm.embedding_model = "global-embed-model".to_owned();
856        // No embed=true entry, no embedding_model field set — falls back to effective_embedding_model.
857        config.llm.providers = vec![make_provider_entry(false, Some("chat"), None)];
858        assert_eq!(
859            stable_skill_embedding_model(&config),
860            effective_embedding_model(&config)
861        );
862    }
863
864    #[cfg(feature = "cocoon")]
865    mod cocoon_tests {
866        use super::*;
867
868        fn cocoon_entry(access_hash: Option<&str>) -> ProviderEntry {
869            ProviderEntry {
870                provider_type: ProviderKind::Cocoon,
871                name: Some("cocoon".into()),
872                model: Some("Qwen/Qwen3-0.6B".into()),
873                cocoon_client_url: Some("http://localhost:10000".into()),
874                cocoon_access_hash: access_hash.map(str::to_owned),
875                cocoon_health_check: false,
876                ..ProviderEntry::default()
877            }
878        }
879
880        /// `cocoon_access_hash = Some("")` sentinel with no vault key must return an error.
881        #[test]
882        fn cocoon_access_hash_gate_vault_miss_errors() {
883            let entry = cocoon_entry(Some(""));
884            let config = Config::default(); // secrets.cocoon_access_hash = None
885            let result = build_provider_from_entry(&entry, &config);
886            assert!(
887                result.is_err(),
888                "expected error when vault key is absent but sentinel is set"
889            );
890            let err_str = result.unwrap_err().to_string();
891            assert!(
892                err_str.contains("ZEPH_COCOON_ACCESS_HASH"),
893                "error should mention the vault key: {err_str}"
894            );
895        }
896
897        /// `cocoon_access_hash = None` must succeed without touching the vault (health check off).
898        #[test]
899        fn cocoon_no_access_hash_gate_succeeds_without_vault() {
900            let entry = cocoon_entry(None);
901            let config = Config::default();
902            let result = build_provider_from_entry(&entry, &config);
903            assert!(
904                result.is_ok(),
905                "expected success when no access hash requested: {:?}",
906                result.err()
907            );
908        }
909    }
910}