1use crate::{BackendType, VvLlmError};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::{collections::HashMap, fs, path::Path};
5
/// One entry of the `endpoints` list held by `LlmSettings`.
///
/// Model bindings refer to an endpoint by `id`; model resolution in
/// `LlmSettings` looks an endpoint up by comparing that field for equality.
/// Only `id` is required in the serialized form: every other field carries
/// `#[serde(default)]` and falls back to its `Default` value when absent.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EndpointConfig {
    /// Identifier that `EndpointBinding::endpoint_id` values are matched against.
    pub id: String,
    /// Not interpreted in this file; presumably the base URL of the API — confirm with consumers.
    #[serde(default)]
    pub api_base: Option<String>,
    /// Not interpreted in this file; presumably the secret used to authenticate — confirm with consumers.
    #[serde(default)]
    pub api_key: Option<String>,
    /// Not interpreted in this file; presumably a provider-side organization id.
    #[serde(default)]
    pub organization: Option<String>,
    /// Free-form endpoint kind; the accepted values are not defined in this file.
    #[serde(default)]
    pub endpoint_type: Option<String>,
    /// Not interpreted in this file; presumably a cloud region name.
    #[serde(default)]
    pub region: Option<String>,
    /// Optional flag; `None` when the key is absent. Its effect is decided by consumers.
    #[serde(default)]
    pub is_bedrock: Option<bool>,
    /// Optional flag; `None` when the key is absent. Its effect is decided by consumers.
    #[serde(default)]
    pub is_vertex: Option<bool>,
    /// Arbitrary JSON; defaults to `Value::Null` when the key is absent.
    #[serde(default)]
    pub credentials: Value,
    /// Read from a literal `extra` key. Unlike the other config structs in this
    /// file this map is NOT `#[serde(flatten)]`, so unknown sibling keys are not
    /// collected here.
    #[serde(default)]
    pub extra: HashMap<String, Value>,
}
28
/// Configuration of a single backend: its models plus any unrecognized keys.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct BackendConfig {
    /// Models of this backend, keyed by name. Model resolution in `LlmSettings`
    /// tries the map key first and then falls back to comparing `ModelConfig::id`.
    /// Empty when the key is absent.
    #[serde(default)]
    pub models: HashMap<String, ModelConfig>,
    /// Every other key of the backend object, preserved as raw JSON
    /// (`#[serde(flatten)]`).
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
36
/// Description of one model inside a `BackendConfig`.
///
/// Only `id` is required in the serialized form; the remaining named fields
/// default when absent, and unrecognized keys are collected into `extra`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Model identifier. Used as the fallback lookup key during resolution and
    /// as the default model id when a binding does not override it.
    pub id: String,
    /// Endpoint bindings in declaration order; resolution uses the first one
    /// whose `enabled()` is true.
    #[serde(default)]
    pub endpoints: Vec<EndpointBinding>,
    /// Not interpreted in this file; presumably the context window size in tokens.
    #[serde(default)]
    pub context_length: Option<u32>,
    /// Not interpreted in this file; presumably the output limit in tokens.
    #[serde(default)]
    pub max_output_tokens: Option<u32>,
    /// Optional capability flag; `None` when the key is absent.
    #[serde(default)]
    pub function_call_available: Option<bool>,
    /// Optional capability flag; `None` when the key is absent.
    #[serde(default)]
    pub response_format_available: Option<bool>,
    /// Optional capability flag; `None` when the key is absent.
    #[serde(default)]
    pub native_multimodal: Option<bool>,
    /// Free-form protocol name; the accepted values are not defined in this file.
    #[serde(default)]
    pub protocol: Option<String>,
    /// Raw JSON, stored as-is; its schema is not defined in this file.
    #[serde(default)]
    pub request_mapping: Option<Value>,
    /// Raw JSON, stored as-is; its schema is not defined in this file.
    #[serde(default)]
    pub response_mapping: Option<Value>,
    /// Every other key of the model object, preserved as raw JSON
    /// (`#[serde(flatten)]`).
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
61
/// Root of the settings document: endpoints plus three families of backends.
///
/// Every field is optional in the serialized form, so an empty JSON object
/// deserializes successfully.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct LlmSettings {
    /// Stored under the upper-case key `VERSION` in the serialized form.
    #[serde(rename = "VERSION", default)]
    pub version: Option<String>,
    /// All known endpoints; model bindings reference them by `EndpointConfig::id`.
    #[serde(default)]
    pub endpoints: Vec<EndpointConfig>,
    /// Backends searched by `resolve_chat_model`, keyed by backend name.
    #[serde(default)]
    pub backends: HashMap<String, BackendConfig>,
    /// Backends searched by `resolve_embedding_model`, keyed by backend name.
    #[serde(default)]
    pub embedding_backends: HashMap<String, BackendConfig>,
    /// Backends searched by `resolve_rerank_model`, keyed by backend name.
    #[serde(default)]
    pub rerank_backends: HashMap<String, BackendConfig>,
    /// Every other top-level key, preserved as raw JSON (`#[serde(flatten)]`).
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
77
/// Result of resolving a `(backend, model)` pair against `LlmSettings`.
///
/// All fields are owned copies, independent of the settings they came from.
#[derive(Debug, Clone, PartialEq)]
pub struct ResolvedModelConfig {
    /// Backend name exactly as it was passed to the resolver.
    pub backend: String,
    /// Clone of the matched model configuration.
    pub model: ModelConfig,
    /// Model id to use with `endpoint`: the selected binding's override when it
    /// has one, otherwise `model.id`.
    pub model_id: String,
    /// Clone of the endpoint referenced by the selected binding.
    pub endpoint: EndpointConfig,
}
85
/// Link from a model to one endpoint.
///
/// The representation is `#[serde(untagged)]`, so variants are tried in
/// declaration order: a bare string becomes `Id`, while an object carrying
/// `endpoint_id` becomes `Config`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum EndpointBinding {
    /// Short form: just the endpoint id. Always enabled, no model id override.
    Id(String),
    /// Long form with optional per-binding settings.
    Config {
        /// Id of the referenced endpoint.
        endpoint_id: String,
        /// Model id to use on this endpoint instead of the model's own id.
        #[serde(default)]
        model_id: Option<String>,
        /// Whether the binding may be selected; absent is treated as enabled.
        #[serde(default)]
        enabled: Option<bool>,
        /// Every other key of the binding object, preserved as raw JSON.
        #[serde(flatten)]
        extra: HashMap<String, Value>,
    },
}
100
101impl EndpointBinding {
102 pub fn endpoint_id(&self) -> &str {
103 match self {
104 Self::Id(endpoint_id) => endpoint_id,
105 Self::Config { endpoint_id, .. } => endpoint_id,
106 }
107 }
108
109 pub fn model_id<'a>(&'a self, default_model_id: &'a str) -> &'a str {
110 match self {
111 Self::Id(_) => default_model_id,
112 Self::Config { model_id, .. } => model_id.as_deref().unwrap_or(default_model_id),
113 }
114 }
115
116 pub fn enabled(&self) -> bool {
117 match self {
118 Self::Id(_) => true,
119 Self::Config { enabled, .. } => enabled.unwrap_or(true),
120 }
121 }
122}
123
124impl LlmSettings {
125 pub fn from_json_str(raw: &str) -> Result<Self, VvLlmError> {
126 Ok(serde_json::from_str(raw)?)
127 }
128
129 pub fn from_json_file(path: impl AsRef<Path>) -> Result<Self, VvLlmError> {
130 let raw = fs::read_to_string(path.as_ref())
131 .map_err(|error| VvLlmError::Configuration(error.to_string()))?;
132 Self::from_json_str(&raw)
133 }
134
135 pub fn resolve_chat_model(
136 &self,
137 backend: BackendType,
138 model_id: &str,
139 ) -> Result<ResolvedModelConfig, VvLlmError> {
140 self.resolve_model_in_map(&self.backends, backend.as_str(), model_id)
141 }
142
143 pub fn resolve_embedding_model(
144 &self,
145 backend: &str,
146 model_id: &str,
147 ) -> Result<ResolvedModelConfig, VvLlmError> {
148 self.resolve_model_in_map(&self.embedding_backends, backend, model_id)
149 }
150
151 pub fn resolve_rerank_model(
152 &self,
153 backend: &str,
154 model_id: &str,
155 ) -> Result<ResolvedModelConfig, VvLlmError> {
156 self.resolve_model_in_map(&self.rerank_backends, backend, model_id)
157 }
158
159 fn resolve_model_in_map(
160 &self,
161 map: &HashMap<String, BackendConfig>,
162 backend: &str,
163 model_id: &str,
164 ) -> Result<ResolvedModelConfig, VvLlmError> {
165 let backend_config = map.get(backend).ok_or_else(|| VvLlmError::ModelNotFound {
166 backend: backend.to_string(),
167 model: model_id.to_string(),
168 })?;
169 let model = backend_config
170 .models
171 .get(model_id)
172 .or_else(|| {
173 backend_config
174 .models
175 .values()
176 .find(|model| model.id == model_id)
177 })
178 .ok_or_else(|| VvLlmError::ModelNotFound {
179 backend: backend.to_string(),
180 model: model_id.to_string(),
181 })?;
182 let binding = model
183 .endpoints
184 .iter()
185 .find(|binding| binding.enabled())
186 .ok_or_else(|| {
187 VvLlmError::Configuration(format!("model {model_id} has no enabled endpoints"))
188 })?;
189 let endpoint_id = binding.endpoint_id();
190 let endpoint = self
191 .endpoints
192 .iter()
193 .find(|endpoint| endpoint.id == endpoint_id)
194 .ok_or_else(|| VvLlmError::EndpointNotFound(endpoint_id.to_string()))?;
195
196 Ok(ResolvedModelConfig {
197 backend: backend.to_string(),
198 model: model.clone(),
199 model_id: binding.model_id(&model.id).to_string(),
200 endpoint: endpoint.clone(),
201 })
202 }
203}