// walrus_model/config.rs
//! Provider configuration.
//!
//! `ProviderConfig` describes remote providers — the API protocol is taken
//! from the `standard` field or auto-detected from `base_url` via
//! `effective_standard()`. Local models use the built-in registry instead.
//! `Loader` selects which mistralrs builder to use for local models.
6
7use anyhow::{Result, bail};
8use compact_str::CompactString;
9use serde::{Deserialize, Serialize};
10use std::collections::BTreeMap;
11
/// API protocol standard for remote providers.
///
/// Only two wire formats exist: OpenAI-compatible and Anthropic.
/// Defaults to `OpenAI` when omitted in config.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")] // serialized/parsed as "openai" / "anthropic"
pub enum ApiStandard {
    /// OpenAI-compatible chat completions API (covers DeepSeek, Grok, Qwen, Kimi, Ollama, etc.).
    #[default]
    OpenAI,
    /// Anthropic Messages API.
    Anthropic,
}
25
/// Remote provider configuration.
///
/// Any model name is valid — the `standard` field (or auto-detection from
/// `base_url`) determines which API protocol to use. Local models are handled
/// by the built-in registry, not by this config.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ProviderConfig {
    /// Model identifier sent to the remote API.
    pub model: CompactString,
    /// API key for remote providers. Supports `${ENV_VAR}` expansion at the
    /// daemon layer.
    ///
    /// May be omitted for keyless endpoints (e.g. a local Ollama server),
    /// in which case `base_url` must be set — enforced by `validate()`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Base URL for the remote provider endpoint.
    ///
    /// A URL containing "anthropic" switches the effective protocol to the
    /// Anthropic Messages API even if `standard` is left at its default.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
    /// API protocol standard. Defaults to OpenAI if omitted.
    #[serde(default)]
    pub standard: ApiStandard,
}
46
47impl ProviderConfig {
48    /// Resolve the effective API standard.
49    ///
50    /// Returns `Anthropic` if the field is explicitly set to `Anthropic`,
51    /// or if `base_url` contains "anthropic". Otherwise `OpenAI`.
52    pub fn effective_standard(&self) -> ApiStandard {
53        if self.standard == ApiStandard::Anthropic {
54            return ApiStandard::Anthropic;
55        }
56        if let Some(url) = &self.base_url
57            && url.contains("anthropic")
58        {
59            return ApiStandard::Anthropic;
60        }
61        ApiStandard::OpenAI
62    }
63
64    /// Validate field combinations.
65    ///
66    /// Called on startup and on provider add/reload.
67    pub fn validate(&self) -> Result<()> {
68        if self.model.is_empty() {
69            bail!("model is required");
70        }
71        // Remote providers: api_key is required unless base_url is set
72        // (e.g. Ollama which is keyless with a local base_url).
73        if self.api_key.is_none() && self.base_url.is_none() {
74            bail!(
75                "remote provider '{}' requires api_key or base_url",
76                self.model
77            );
78        }
79        Ok(())
80    }
81}
82
83/// Selects which mistralrs model builder to use for local inference.
84///
85/// Defaults to `Text` when omitted in config.
86#[derive(Debug, Serialize, Deserialize, Clone, Copy, Default)]
87#[serde(rename_all = "snake_case")]
88pub enum Loader {
89    /// `TextModelBuilder` — standard text models.
90    #[default]
91    Text,
92    /// `GgufModelBuilder` — GGUF quantized models.
93    Gguf,
94    /// `VisionModelBuilder` — vision-language models.
95    Vision,
96}
97
/// Custom HuggingFace model configuration for local inference.
///
/// Allows users to run any HuggingFace model not in the built-in registry.
/// Memory and quantization are auto-selected at runtime based on system RAM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HfModelConfig {
    /// HuggingFace repository ID (e.g. `"myorg/MyModel-7B-GGUF"`).
    pub model_id: String,
    /// Which mistralrs builder to use. Defaults to `Loader::Text` when
    /// omitted in config.
    #[serde(default)]
    pub loader: Loader,
    /// Explicit GGUF filename (required for GGUF models without a standard
    /// naming convention).
    // NOTE(review): presumably only consulted when `loader == Loader::Gguf` —
    // confirm against the loading code.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub gguf_file: Option<String>,
    /// Custom chat template override.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub chat_template: Option<String>,
}
117
/// Model provider configuration for the daemon.
///
/// Remote providers are configured in `providers`. Local models come from the
/// built-in registry or from user-defined `models`. The active model name
/// lives in `[walrus].model`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ModelConfig {
    /// Optional embedding model name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub embedding: Option<CompactString>,
    /// Remote providers, keyed by user-chosen provider name (local models
    /// come from the built-in registry). BTreeMap keeps config serialization
    /// deterministically ordered.
    #[serde(default)]
    pub providers: BTreeMap<CompactString, ProviderConfig>,
    /// Custom HuggingFace models for local inference, keyed by model name.
    #[serde(default)]
    pub models: BTreeMap<CompactString, HfModelConfig>,
}
135
// MemoryThreshold is generated by build.rs from registry TOML files.
// Splices the generated source into this module at compile time; requires
// the build script to emit `quantization.rs` into OUT_DIR.
include!(concat!(env!("OUT_DIR"), "/quantization.rs"));