Skip to main content

vv_llm/
settings.rs

1use crate::{BackendType, VvLlmError};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::{collections::HashMap, fs, path::Path};
5
6#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
7pub struct EndpointConfig {
8    pub id: String,
9    #[serde(default)]
10    pub api_base: Option<String>,
11    #[serde(default)]
12    pub api_key: Option<String>,
13    #[serde(default)]
14    pub organization: Option<String>,
15    #[serde(default)]
16    pub endpoint_type: Option<String>,
17    #[serde(default)]
18    pub region: Option<String>,
19    #[serde(default)]
20    pub is_bedrock: Option<bool>,
21    #[serde(default)]
22    pub is_vertex: Option<bool>,
23    #[serde(default)]
24    pub credentials: Value,
25    #[serde(default)]
26    pub extra: HashMap<String, Value>,
27}
28
29#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
30pub struct BackendConfig {
31    #[serde(default)]
32    pub models: HashMap<String, ModelConfig>,
33    #[serde(flatten)]
34    pub extra: HashMap<String, Value>,
35}
36
37#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
38pub struct ModelConfig {
39    pub id: String,
40    #[serde(default)]
41    pub endpoints: Vec<EndpointBinding>,
42    #[serde(default)]
43    pub context_length: Option<u32>,
44    #[serde(default)]
45    pub max_output_tokens: Option<u32>,
46    #[serde(default)]
47    pub function_call_available: Option<bool>,
48    #[serde(default)]
49    pub response_format_available: Option<bool>,
50    #[serde(default)]
51    pub native_multimodal: Option<bool>,
52    #[serde(default)]
53    pub protocol: Option<String>,
54    #[serde(default)]
55    pub request_mapping: Option<Value>,
56    #[serde(default)]
57    pub response_mapping: Option<Value>,
58    #[serde(flatten)]
59    pub extra: HashMap<String, Value>,
60}
61
62#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
63pub struct LlmSettings {
64    #[serde(rename = "VERSION", default)]
65    pub version: Option<String>,
66    #[serde(default)]
67    pub endpoints: Vec<EndpointConfig>,
68    #[serde(default)]
69    pub backends: HashMap<String, BackendConfig>,
70    #[serde(default)]
71    pub embedding_backends: HashMap<String, BackendConfig>,
72    #[serde(default)]
73    pub rerank_backends: HashMap<String, BackendConfig>,
74    #[serde(flatten)]
75    pub extra: HashMap<String, Value>,
76}
77
78#[derive(Debug, Clone, PartialEq)]
79pub struct ResolvedModelConfig {
80    pub backend: String,
81    pub model: ModelConfig,
82    pub model_id: String,
83    pub endpoint: EndpointConfig,
84}
85
86#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
87#[serde(untagged)]
88pub enum EndpointBinding {
89    Id(String),
90    Config {
91        endpoint_id: String,
92        #[serde(default)]
93        model_id: Option<String>,
94        #[serde(default)]
95        enabled: Option<bool>,
96        #[serde(flatten)]
97        extra: HashMap<String, Value>,
98    },
99}
100
101impl EndpointBinding {
102    pub fn endpoint_id(&self) -> &str {
103        match self {
104            Self::Id(endpoint_id) => endpoint_id,
105            Self::Config { endpoint_id, .. } => endpoint_id,
106        }
107    }
108
109    pub fn model_id<'a>(&'a self, default_model_id: &'a str) -> &'a str {
110        match self {
111            Self::Id(_) => default_model_id,
112            Self::Config { model_id, .. } => model_id.as_deref().unwrap_or(default_model_id),
113        }
114    }
115
116    pub fn enabled(&self) -> bool {
117        match self {
118            Self::Id(_) => true,
119            Self::Config { enabled, .. } => enabled.unwrap_or(true),
120        }
121    }
122}
123
124impl LlmSettings {
125    pub fn from_json_str(raw: &str) -> Result<Self, VvLlmError> {
126        Ok(serde_json::from_str(raw)?)
127    }
128
129    pub fn from_json_file(path: impl AsRef<Path>) -> Result<Self, VvLlmError> {
130        let raw = fs::read_to_string(path.as_ref())
131            .map_err(|error| VvLlmError::Configuration(error.to_string()))?;
132        Self::from_json_str(&raw)
133    }
134
135    pub fn resolve_chat_model(
136        &self,
137        backend: BackendType,
138        model_id: &str,
139    ) -> Result<ResolvedModelConfig, VvLlmError> {
140        self.resolve_model_in_map(&self.backends, backend.as_str(), model_id)
141    }
142
143    pub fn resolve_embedding_model(
144        &self,
145        backend: &str,
146        model_id: &str,
147    ) -> Result<ResolvedModelConfig, VvLlmError> {
148        self.resolve_model_in_map(&self.embedding_backends, backend, model_id)
149    }
150
151    pub fn resolve_rerank_model(
152        &self,
153        backend: &str,
154        model_id: &str,
155    ) -> Result<ResolvedModelConfig, VvLlmError> {
156        self.resolve_model_in_map(&self.rerank_backends, backend, model_id)
157    }
158
159    fn resolve_model_in_map(
160        &self,
161        map: &HashMap<String, BackendConfig>,
162        backend: &str,
163        model_id: &str,
164    ) -> Result<ResolvedModelConfig, VvLlmError> {
165        let backend_config = map.get(backend).ok_or_else(|| VvLlmError::ModelNotFound {
166            backend: backend.to_string(),
167            model: model_id.to_string(),
168        })?;
169        let model = backend_config
170            .models
171            .get(model_id)
172            .or_else(|| {
173                backend_config
174                    .models
175                    .values()
176                    .find(|model| model.id == model_id)
177            })
178            .ok_or_else(|| VvLlmError::ModelNotFound {
179                backend: backend.to_string(),
180                model: model_id.to_string(),
181            })?;
182        let binding = model
183            .endpoints
184            .iter()
185            .find(|binding| binding.enabled())
186            .ok_or_else(|| {
187                VvLlmError::Configuration(format!("model {model_id} has no enabled endpoints"))
188            })?;
189        let endpoint_id = binding.endpoint_id();
190        let endpoint = self
191            .endpoints
192            .iter()
193            .find(|endpoint| endpoint.id == endpoint_id)
194            .ok_or_else(|| VvLlmError::EndpointNotFound(endpoint_id.to_string()))?;
195
196        Ok(ResolvedModelConfig {
197            backend: backend.to_string(),
198            model: model.clone(),
199            model_id: binding.model_id(&model.id).to_string(),
200            endpoint: endpoint.clone(),
201        })
202    }
203}