1use std::env;
2use std::fs;
3use std::path::Path;
4
5use serde::{Deserialize, Serialize};
6
7use crate::error::{AgnoError, Result};
8
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
10pub struct ServerConfig {
11 pub host: String,
12 pub port: u16,
13 #[serde(default = "default_tls")]
14 pub tls_enabled: bool,
15}
16
/// Serde default for `ServerConfig::tls_enabled`: TLS stays off unless configured.
fn default_tls() -> bool {
    bool::default()
}
20
21#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
22pub struct SecurityConfig {
23 #[serde(default)]
24 pub allowed_origins: Vec<String>,
25 #[serde(default)]
26 pub allowed_tenants: Vec<String>,
27 #[serde(default = "default_encryption_required")]
28 pub encryption_required: bool,
29}
30
31impl Default for SecurityConfig {
32 fn default() -> Self {
33 Self {
34 allowed_origins: Vec::new(),
35 allowed_tenants: Vec::new(),
36 encryption_required: default_encryption_required(),
37 }
38 }
39}
40
/// Serde default for `SecurityConfig::encryption_required`: encryption is
/// required unless the config file explicitly sets the field to `false`.
fn default_encryption_required() -> bool {
    const REQUIRED: bool = true;
    REQUIRED
}
44
45#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
46pub struct TelemetryConfig {
47 #[serde(default = "default_sample_rate")]
48 pub sample_rate: f32,
49 #[serde(default)]
50 pub endpoint: Option<String>,
51 #[serde(default = "default_retention_hours")]
52 pub retention_hours: u32,
53}
54
55impl Default for TelemetryConfig {
56 fn default() -> Self {
57 Self {
58 sample_rate: default_sample_rate(),
59 endpoint: None,
60 retention_hours: default_retention_hours(),
61 }
62 }
63}
64
/// Serde default for `TelemetryConfig::sample_rate`.
fn default_sample_rate() -> f32 {
    const FULL_SAMPLING: f32 = 1.0;
    FULL_SAMPLING
}
68
/// Serde default for `TelemetryConfig::retention_hours`: three days.
fn default_retention_hours() -> u32 {
    const HOURS_PER_DAY: u32 = 24;
    3 * HOURS_PER_DAY
}
72
73#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
74pub struct DeploymentConfig {
75 #[serde(default = "default_replicas")]
76 pub replicas: u16,
77 #[serde(default = "default_max_concurrency")]
78 pub max_concurrency: u32,
79 #[serde(default)]
80 pub autoscale: bool,
81 #[serde(default)]
82 pub container_image: Option<String>,
83}
84
85impl Default for DeploymentConfig {
86 fn default() -> Self {
87 Self {
88 replicas: default_replicas(),
89 max_concurrency: default_max_concurrency(),
90 autoscale: false,
91 container_image: None,
92 }
93 }
94}
95
/// Serde default for `DeploymentConfig::replicas`: a single replica.
fn default_replicas() -> u16 {
    const SINGLE_REPLICA: u16 = 1;
    SINGLE_REPLICA
}
99
/// Serde default for `DeploymentConfig::max_concurrency`.
fn default_max_concurrency() -> u32 {
    const CONCURRENCY_LIMIT: u32 = 32;
    CONCURRENCY_LIMIT
}
103
104#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
105pub struct ModelConfig {
106 pub provider: String,
107 pub model: String,
108 #[serde(default)]
109 pub api_key: Option<String>,
110 #[serde(default)]
111 pub base_url: Option<String>,
112 #[serde(default)]
113 pub organization: Option<String>,
114 #[serde(default)]
115 pub stream: bool,
116 #[serde(default)]
117 pub openai: ProviderConfig,
118 #[serde(default)]
119 pub anthropic: ProviderConfig,
120 #[serde(default)]
121 pub gemini: ProviderConfig,
122 #[serde(default)]
123 pub cohere: ProviderConfig,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
127pub struct ProviderConfig {
128 #[serde(default)]
129 pub api_key: Option<String>,
130 #[serde(default)]
131 pub endpoint: Option<String>,
132 #[serde(default)]
133 pub organization: Option<String>,
134}
135
136#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
137#[serde(rename_all = "lowercase")]
138pub enum StorageBackend {
139 File,
140 Sqlite,
141}
142
143impl Default for StorageBackend {
144 fn default() -> Self {
145 StorageBackend::File
146 }
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
150pub struct StorageConfig {
151 #[serde(default)]
152 pub backend: StorageBackend,
153 #[serde(default = "default_storage_path")]
154 pub file_path: String,
155 #[serde(default)]
156 pub database_url: Option<String>,
157}
158
159impl Default for StorageConfig {
160 fn default() -> Self {
161 Self {
162 backend: StorageBackend::default(),
163 file_path: default_storage_path(),
164 database_url: None,
165 }
166 }
167}
168
/// Serde default for `StorageConfig::file_path`.
fn default_storage_path() -> String {
    String::from("conversation.jsonl")
}
172
173#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
174pub struct AppConfig {
175 pub server: ServerConfig,
176 #[serde(default)]
177 pub security: SecurityConfig,
178 #[serde(default)]
179 pub telemetry: TelemetryConfig,
180 #[serde(default)]
181 pub deployment: DeploymentConfig,
182 pub model: ModelConfig,
183 #[serde(default)]
184 pub storage: StorageConfig,
185}
186
187impl Default for AppConfig {
188 fn default() -> Self {
189 Self {
190 server: ServerConfig {
191 host: "0.0.0.0".into(),
192 port: 8080,
193 tls_enabled: default_tls(),
194 },
195 security: SecurityConfig {
196 allowed_origins: vec![],
197 allowed_tenants: vec![],
198 encryption_required: default_encryption_required(),
199 },
200 telemetry: TelemetryConfig {
201 sample_rate: default_sample_rate(),
202 endpoint: None,
203 retention_hours: default_retention_hours(),
204 },
205 deployment: DeploymentConfig {
206 replicas: default_replicas(),
207 max_concurrency: default_max_concurrency(),
208 autoscale: false,
209 container_image: None,
210 },
211 model: ModelConfig {
212 provider: "stub".into(),
213 model: "stub-model".into(),
214 api_key: None,
215 base_url: None,
216 organization: None,
217 stream: false,
218 openai: ProviderConfig::default(),
219 anthropic: ProviderConfig::default(),
220 gemini: ProviderConfig::default(),
221 cohere: ProviderConfig::default(),
222 },
223 storage: StorageConfig::default(),
224 }
225 }
226}
227
228impl AppConfig {
229 pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
230 let raw = fs::read_to_string(path)?;
231 let cfg: Self = toml::from_str(&raw)
232 .map_err(|err| AgnoError::Protocol(format!("Failed to parse configuration: {err}")))?;
233 Ok(cfg)
234 }
235
236 pub fn from_env_or_file(path: impl AsRef<Path>) -> Result<Self> {
237 let mut cfg = Self::from_file(path)?;
238 if let Ok(host) = env::var("AGNO_HOST") {
239 cfg.server.host = host;
240 }
241 if let Ok(port) = env::var("AGNO_PORT") {
242 if let Ok(parsed) = port.parse::<u16>() {
243 cfg.server.port = parsed;
244 }
245 }
246 if let Ok(key) = env::var("AGNO_API_KEY") {
247 cfg.model.api_key = Some(key);
248 }
249 if let Ok(openai_key) = env::var("AGNO_OPENAI_API_KEY") {
250 cfg.model.openai.api_key = Some(openai_key);
251 }
252 if let Ok(openai_endpoint) = env::var("AGNO_OPENAI_ENDPOINT") {
253 cfg.model.openai.endpoint = Some(openai_endpoint);
254 }
255 if let Ok(openai_org) = env::var("AGNO_OPENAI_ORG") {
256 cfg.model.openai.organization = Some(openai_org);
257 }
258 if let Ok(anthropic_key) = env::var("AGNO_ANTHROPIC_API_KEY") {
259 cfg.model.anthropic.api_key = Some(anthropic_key);
260 }
261 if let Ok(anthropic_endpoint) = env::var("AGNO_ANTHROPIC_ENDPOINT") {
262 cfg.model.anthropic.endpoint = Some(anthropic_endpoint);
263 }
264 if let Ok(gemini_key) = env::var("AGNO_GEMINI_API_KEY") {
265 cfg.model.gemini.api_key = Some(gemini_key);
266 }
267 if let Ok(gemini_endpoint) = env::var("AGNO_GEMINI_ENDPOINT") {
268 cfg.model.gemini.endpoint = Some(gemini_endpoint);
269 }
270 if let Ok(cohere_key) = env::var("AGNO_COHERE_API_KEY") {
271 cfg.model.cohere.api_key = Some(cohere_key);
272 }
273 if let Ok(cohere_endpoint) = env::var("AGNO_COHERE_ENDPOINT") {
274 cfg.model.cohere.endpoint = Some(cohere_endpoint);
275 }
276 if let Ok(stream) = env::var("AGNO_STREAMING") {
277 if let Ok(parsed) = stream.parse::<bool>() {
278 cfg.model.stream = parsed;
279 }
280 }
281 if let Ok(sample) = env::var("AGNO_TELEMETRY_SAMPLE") {
282 if let Ok(parsed) = sample.parse::<f32>() {
283 cfg.telemetry.sample_rate = parsed.clamp(0.01, 1.0);
284 }
285 }
286 if let Ok(backend) = env::var("AGNO_STORAGE_BACKEND") {
287 cfg.storage.backend = match backend.to_ascii_lowercase().as_str() {
288 "sqlite" => StorageBackend::Sqlite,
289 _ => StorageBackend::File,
290 };
291 }
292 if let Ok(path) = env::var("AGNO_STORAGE_PATH") {
293 cfg.storage.file_path = path;
294 }
295 if let Ok(url) = env::var("AGNO_DATABASE_URL") {
296 cfg.storage.database_url = Some(url);
297 }
298 Ok(cfg)
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::ffi::OsString;
    use std::io::Write;
    use std::sync::{Mutex, MutexGuard, PoisonError};
    use tempfile::NamedTempFile;

    /// Serializes tests that touch process-wide environment variables. The
    /// default test harness runs tests on several threads, and
    /// `from_env_or_file` reads every `AGNO_*` variable.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Minimal config containing only the required `[server]` and `[model]` tables.
    const BASE_CONFIG: &str =
        "[server]\nhost='127.0.0.1'\nport=9000\n[model]\nprovider='openai'\nmodel='gpt-4'";

    /// Holds `ENV_LOCK` for its lifetime and restores every touched variable
    /// on drop. A failing assertion therefore cannot leak overrides into
    /// other tests.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<OsString>)>,
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        /// Applies `vars`: `Some(value)` sets the variable, `None` removes it.
        fn apply(vars: &[(&'static str, Option<&str>)]) -> Self {
            // A test that panicked while holding the lock poisons it; the
            // protected data is `()`, so recovering the guard is always sound.
            let lock = ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner);
            let saved = vars
                .iter()
                .map(|&(name, value)| {
                    let previous = env::var_os(name);
                    match value {
                        Some(value) => env::set_var(name, value),
                        None => env::remove_var(name),
                    }
                    (name, previous)
                })
                .collect();
            Self { saved, _lock: lock }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Runs before the `_lock` field is dropped, so restoration happens
            // while the lock is still held. Reverse order means a name listed
            // twice ends at its original value.
            for (name, previous) in self.saved.drain(..).rev() {
                match previous {
                    Some(value) => env::set_var(name, value),
                    None => env::remove_var(name),
                }
            }
        }
    }

    /// Writes `contents` to a fresh temporary config file.
    fn write_config(contents: &str) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, "{contents}").unwrap();
        file
    }

    #[test]
    fn loads_and_overrides() {
        let file = write_config(BASE_CONFIG);
        let _env = EnvGuard::apply(&[("AGNO_PORT", Some("9100")), ("AGNO_HOST", None)]);

        let cfg = AppConfig::from_env_or_file(file.path()).unwrap();

        assert_eq!(cfg.server.port, 9100);
        assert_eq!(cfg.server.host, "127.0.0.1");
        assert_eq!(cfg.model.provider, "openai");
    }

    #[test]
    fn overrides_storage_backend() {
        let file = write_config(&format!(
            "{BASE_CONFIG}\n[storage]\nbackend='file'\nfile_path='transcript.jsonl'"
        ));
        let _env = EnvGuard::apply(&[
            ("AGNO_STORAGE_BACKEND", Some("sqlite")),
            ("AGNO_DATABASE_URL", Some("sqlite::memory:")),
            ("AGNO_STORAGE_PATH", None),
        ]);

        let cfg = AppConfig::from_env_or_file(file.path()).unwrap();

        assert_eq!(cfg.storage.backend, StorageBackend::Sqlite);
        assert_eq!(
            cfg.storage.database_url,
            Some("sqlite::memory:".to_string())
        );
        assert_eq!(cfg.storage.file_path, "transcript.jsonl");
    }

    #[test]
    fn overrides_streaming_flag() {
        let file = write_config(BASE_CONFIG);
        let _env = EnvGuard::apply(&[("AGNO_STREAMING", Some("true"))]);

        let cfg = AppConfig::from_env_or_file(file.path()).unwrap();

        assert!(cfg.model.stream);
    }

    #[test]
    fn clamps_telemetry_sample_override() {
        let file = write_config(BASE_CONFIG);
        let _env = EnvGuard::apply(&[("AGNO_TELEMETRY_SAMPLE", Some("5"))]);

        let cfg = AppConfig::from_env_or_file(file.path()).unwrap();

        assert_eq!(cfg.telemetry.sample_rate, 1.0);
    }
}