1use std::{path::Path, time::Duration};
7
8use axum::http::HeaderName;
9use figment::{
10 Figment,
11 providers::{Env, Format, Toml},
12};
13use secrecy::{ExposeSecret, SecretString};
14use serde::{Deserialize, Deserializer, de};
15use thiserror::Error;
16use tracing_subscriber::EnvFilter;
17
18#[derive(Debug, Clone, Default, Deserialize)]
20#[serde(default, deny_unknown_fields)]
21pub struct ProxyConfig {
22 pub server: ServerConfig,
23 pub logging: LoggingConfig,
24 pub venice: VeniceConfig,
25 pub keys: KeysConfig,
26 pub session: SessionConfig,
27 pub attestation: AttestationConfig,
28 pub e2ee: E2eeConfig,
29 pub tools: ToolsConfig,
30}
31
32impl ProxyConfig {
33 pub const ENV_PREFIX: &'static str = "VENICE_E2EE_PROXY__";
35
36 pub fn load_from_path(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
38 Self::from_figment(
39 Figment::new()
40 .merge(Toml::file(path.as_ref()))
41 .merge(Self::env_provider()),
42 )
43 }
44
45 pub fn from_toml_str(contents: &str) -> Result<Self, ConfigError> {
47 Self::from_figment(Figment::new().merge(Toml::string(contents)))
48 }
49
50 fn env_provider() -> Env {
52 Env::prefixed(Self::ENV_PREFIX).split("__")
53 }
54
55 fn from_figment(figment: Figment) -> Result<Self, ConfigError> {
57 let config: Self = figment.extract()?;
58 config.validate()?;
59 Ok(config)
60 }
61
62 pub fn validate(&self) -> Result<(), ConfigError> {
64 validate_non_empty("server.host", &self.server.host)?;
65 validate_non_empty("logging.level", &self.logging.level)?;
66 validate_env_filter("logging.level", &self.logging.level)?;
67 validate_http_url("venice.base_url", &self.venice.base_url, false)?;
68 validate_duration_non_zero("venice.request_timeout", self.venice.request_timeout)?;
69
70 validate_duration_non_zero("session.idle_ttl", self.session.idle_ttl)?;
71 validate_duration_non_zero("session.max_ttl", self.session.max_ttl)?;
72 if self.session.idle_ttl > self.session.max_ttl {
73 return Err(ConfigError::invalid(
74 "session.idle_ttl",
75 "must be less than or equal to session.max_ttl",
76 ));
77 }
78 if self.session.max_requests == 0 {
79 return Err(ConfigError::invalid(
80 "session.max_requests",
81 "must be greater than zero",
82 ));
83 }
84 validate_header_name("session.headers.preferred", &self.session.headers.preferred)?;
85 validate_header_name(
86 "session.headers.open_webui",
87 &self.session.headers.open_webui,
88 )?;
89
90 validate_http_url("attestation.pccs_url", &self.attestation.pccs_url, true)?;
91
92 validate_non_empty("e2ee.hkdf_info", &self.e2ee.hkdf_info)?;
93
94 if self.tools.tool_call_max_bytes == 0 {
95 return Err(ConfigError::invalid(
96 "tools.tool_call_max_bytes",
97 "must be greater than zero",
98 ));
99 }
100 validate_duration_non_zero(
101 "tools.tool_call_marker_timeout",
102 self.tools.tool_call_marker_timeout,
103 )?;
104
105 Ok(())
106 }
107
108 pub fn venice_api_key(&self) -> Result<&SecretString, ConfigError> {
110 if self.venice.api_key.expose_secret().trim().is_empty() {
111 return Err(ConfigError::MissingApiKey);
112 }
113 Ok(&self.venice.api_key)
114 }
115}
116
117#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
119#[serde(default, deny_unknown_fields)]
120pub struct ServerConfig {
121 pub host: String,
122 pub port: u16,
123}
124
125impl Default for ServerConfig {
126 fn default() -> Self {
128 Self {
129 host: "0.0.0.0".to_owned(),
130 port: 8080,
131 }
132 }
133}
134
135#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
137#[serde(default, deny_unknown_fields)]
138pub struct LoggingConfig {
139 pub level: String,
143}
144
145impl Default for LoggingConfig {
146 fn default() -> Self {
148 Self {
149 level: "info".to_owned(),
150 }
151 }
152}
153
154#[derive(Debug, Clone, Deserialize)]
156#[serde(default, deny_unknown_fields)]
157pub struct VeniceConfig {
158 pub base_url: String,
159 pub api_key: SecretString,
160 #[serde(deserialize_with = "deserialize_duration")]
161 pub request_timeout: Duration,
162}
163
164impl Default for VeniceConfig {
165 fn default() -> Self {
167 Self {
168 base_url: "https://api.venice.ai/api/v1".to_owned(),
169 api_key: SecretString::default(),
170 request_timeout: Duration::from_secs(30),
171 }
172 }
173}
174
175#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
177#[serde(default, deny_unknown_fields)]
178pub struct KeysConfig {
179 pub generate_proxy_instance_key_on_startup: bool,
180}
181
182impl Default for KeysConfig {
183 fn default() -> Self {
185 Self {
186 generate_proxy_instance_key_on_startup: true,
187 }
188 }
189}
190
191#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
193#[serde(default, deny_unknown_fields)]
194pub struct SessionConfig {
195 #[serde(deserialize_with = "deserialize_duration")]
196 pub idle_ttl: Duration,
197 #[serde(deserialize_with = "deserialize_duration")]
198 pub max_ttl: Duration,
199 pub max_requests: u64,
200 pub fallback_scope: SessionFallbackScope,
201 pub headers: SessionHeadersConfig,
202}
203
204impl Default for SessionConfig {
205 fn default() -> Self {
207 Self {
208 idle_ttl: Duration::from_secs(600),
209 max_ttl: Duration::from_secs(1_800),
210 max_requests: 100,
211 fallback_scope: SessionFallbackScope::Request,
212 headers: SessionHeadersConfig::default(),
213 }
214 }
215}
216
217#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
219#[serde(default, deny_unknown_fields)]
220pub struct SessionHeadersConfig {
221 pub preferred: String,
222 pub open_webui: String,
223}
224
225impl Default for SessionHeadersConfig {
226 fn default() -> Self {
228 Self {
229 preferred: "X-Venice-Proxy-Session-Id".to_owned(),
230 open_webui: "X-OpenWebUI-Chat-Id".to_owned(),
231 }
232 }
233}
234
235#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
237#[serde(rename_all = "snake_case")]
238pub enum SessionFallbackScope {
239 Agent,
240 #[default]
241 Request,
242 Disabled,
243}
244
245#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
247#[serde(default, deny_unknown_fields)]
248pub struct AttestationConfig {
249 pub mode: AttestationMode,
250 pub require_tdx: bool,
251 pub require_nvidia: NvidiaRequirement,
252 pub allow_debug: bool,
253 pub pccs_url: String,
254}
255
256impl Default for AttestationConfig {
257 fn default() -> Self {
259 Self {
260 mode: AttestationMode::Independent,
261 require_tdx: false,
262 require_nvidia: NvidiaRequirement::Never,
263 allow_debug: false,
264 pccs_url: String::new(),
265 }
266 }
267}
268
269#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
271#[serde(rename_all = "snake_case")]
272pub enum AttestationMode {
273 #[default]
274 Independent,
275}
276
277impl AttestationMode {
278 pub fn as_str(self) -> &'static str {
280 match self {
281 Self::Independent => "independent",
282 }
283 }
284}
285
286#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
288#[serde(rename_all = "snake_case")]
289pub enum NvidiaRequirement {
290 Required,
291 WhenPresent,
292 #[default]
293 Never,
294}
295
296#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
298#[serde(default, deny_unknown_fields)]
299pub struct E2eeConfig {
300 pub hkdf_info: String,
301 pub require_encrypted_response_content: bool,
302}
303
304impl Default for E2eeConfig {
305 fn default() -> Self {
307 Self {
308 hkdf_info: "ecdsa_encryption".to_owned(),
309 require_encrypted_response_content: true,
310 }
311 }
312}
313
314#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
316#[serde(default, deny_unknown_fields)]
317pub struct ToolsConfig {
318 pub enabled: bool,
319 pub mode: ToolMode,
320 pub max_retries: u32,
321 pub tool_call_max_bytes: usize,
322 #[serde(deserialize_with = "deserialize_duration")]
323 pub tool_call_marker_timeout: Duration,
324 pub validate_json_schema: bool,
325}
326
327impl Default for ToolsConfig {
328 fn default() -> Self {
330 Self {
331 enabled: true,
332 mode: ToolMode::Emulated,
333 max_retries: 2,
334 tool_call_max_bytes: 65_536,
335 tool_call_marker_timeout: Duration::from_secs(30),
336 validate_json_schema: true,
337 }
338 }
339}
340
341#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
343#[serde(rename_all = "snake_case")]
344pub enum ToolMode {
345 #[default]
346 Emulated,
347 None,
348}
349
350impl ToolMode {
351 pub fn as_str(self) -> &'static str {
353 match self {
354 Self::Emulated => "emulated",
355 Self::None => "none",
356 }
357 }
358}
359
360#[derive(Debug, Error)]
362pub enum ConfigError {
363 #[error("failed to load config: {0}")]
364 Figment(#[source] Box<figment::Error>),
365 #[error("Venice API key is not configured")]
366 MissingApiKey,
367 #[error("invalid config value for {field}: {message}")]
368 InvalidValue {
369 field: &'static str,
370 message: String,
371 },
372}
373
374impl From<figment::Error> for ConfigError {
375 fn from(error: figment::Error) -> Self {
377 Self::Figment(Box::new(error))
378 }
379}
380
381impl ConfigError {
382 fn invalid(field: &'static str, message: impl Into<String>) -> Self {
384 Self::InvalidValue {
385 field,
386 message: message.into(),
387 }
388 }
389}
390
391fn validate_non_empty(field: &'static str, value: &str) -> Result<(), ConfigError> {
393 if value.trim().is_empty() {
394 return Err(ConfigError::invalid(field, "must not be empty"));
395 }
396 Ok(())
397}
398
399fn validate_http_url(
401 field: &'static str,
402 value: &str,
403 allow_empty: bool,
404) -> Result<(), ConfigError> {
405 let value = value.trim();
406 if value.is_empty() {
407 if allow_empty {
408 return Ok(());
409 }
410 return Err(ConfigError::invalid(field, "must not be empty"));
411 }
412
413 if !(value.starts_with("https://") || value.starts_with("http://")) {
414 return Err(ConfigError::invalid(
415 field,
416 "must start with http:// or https://",
417 ));
418 }
419
420 Ok(())
421}
422
423fn validate_header_name(field: &'static str, value: &str) -> Result<(), ConfigError> {
425 validate_non_empty(field, value)?;
426 HeaderName::from_bytes(value.as_bytes())
427 .map_err(|_| ConfigError::invalid(field, "must be a valid HTTP header name"))?;
428 Ok(())
429}
430
431fn validate_duration_non_zero(field: &'static str, value: Duration) -> Result<(), ConfigError> {
433 if value == Duration::ZERO {
434 return Err(ConfigError::invalid(field, "must be greater than zero"));
435 }
436 Ok(())
437}
438
439fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
441where
442 D: Deserializer<'de>,
443{
444 let value = String::deserialize(deserializer)?;
445 humantime::parse_duration(&value).map_err(de::Error::custom)
446}
447
448fn validate_env_filter(field: &'static str, value: &str) -> Result<(), ConfigError> {
450 let value = value.trim();
451 if value.is_empty() {
452 return Ok(());
453 }
454
455 EnvFilter::try_new(value).map_err(|source| {
456 ConfigError::invalid(
457 field,
458 format!("must be a valid tracing env filter: {source}"),
459 )
460 })?;
461 Ok(())
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every field of `config` holds its built-in default, and
    /// that those defaults pass validation.
    fn assert_default_config_values(config: &ProxyConfig) {
        let ProxyConfig {
            server,
            logging,
            venice,
            keys,
            session,
            attestation,
            e2ee,
            tools,
        } = config;

        assert_eq!(server.host, "0.0.0.0");
        assert_eq!(server.port, 8080);
        assert_eq!(logging.level, "info");

        assert_eq!(venice.base_url, "https://api.venice.ai/api/v1");
        assert_eq!(venice.api_key.expose_secret(), "");
        assert_eq!(venice.request_timeout, Duration::from_secs(30));

        assert!(keys.generate_proxy_instance_key_on_startup);

        assert_eq!(session.idle_ttl, Duration::from_secs(600));
        assert_eq!(session.max_ttl, Duration::from_secs(1_800));
        assert_eq!(session.max_requests, 100);
        assert_eq!(session.fallback_scope, SessionFallbackScope::Request);
        assert_eq!(
            &session.headers,
            &SessionHeadersConfig {
                preferred: "X-Venice-Proxy-Session-Id".to_owned(),
                open_webui: "X-OpenWebUI-Chat-Id".to_owned(),
            }
        );

        assert_eq!(
            attestation,
            &AttestationConfig {
                mode: AttestationMode::Independent,
                require_tdx: false,
                require_nvidia: NvidiaRequirement::Never,
                allow_debug: false,
                pccs_url: String::new(),
            }
        );

        assert_eq!(
            e2ee,
            &E2eeConfig {
                hkdf_info: "ecdsa_encryption".to_owned(),
                require_encrypted_response_content: true,
            }
        );

        assert_eq!(
            tools,
            &ToolsConfig {
                enabled: true,
                mode: ToolMode::Emulated,
                max_retries: 2,
                tool_call_max_bytes: 65_536,
                tool_call_marker_timeout: Duration::from_secs(30),
                validate_json_schema: true,
            }
        );

        config.validate().expect("default config is valid");
    }

    /// Builds a one-section TOML document `[logging] level = <level>`.
    fn logging_toml(level: &str) -> String {
        format!("[logging]\nlevel = {level:?}\n")
    }

    #[test]
    fn default_config_matches_expected_values() {
        assert_default_config_values(&ProxyConfig::default());
    }

    #[test]
    fn checked_in_default_config_matches_code_defaults() {
        let loaded = ProxyConfig::from_toml_str(include_str!("../config/default.toml"))
            .expect("checked-in default config should load");
        assert_default_config_values(&loaded);
    }

    #[test]
    fn toml_config_applies_defaults_for_missing_sections() {
        const PARTIAL: &str = r#"
[server]
host = "0.0.0.0"
port = 8080

[tools]
enabled = false
mode = "none"
"#;
        let config =
            ProxyConfig::from_toml_str(PARTIAL).expect("partial config should load with defaults");

        // Explicitly configured values.
        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.server.port, 8080);
        assert!(!config.tools.enabled);
        assert_eq!(config.tools.mode, ToolMode::None);

        // Values filled in from defaults.
        assert_eq!(config.logging.level, "info");
        assert_eq!(config.venice.api_key.expose_secret(), "");
        assert_eq!(config.venice.request_timeout, Duration::from_secs(30));
        assert_eq!(config.tools.tool_call_max_bytes, 65_536);
    }

    #[test]
    fn validation_rejects_invalid_values() {
        let cases = [
            ("[venice]\nbase_url = \"not-valid-url\"\n", "venice.base_url"),
            ("[venice]\nrequest_timeout = \"0s\"\n", "venice.request_timeout"),
        ];

        for (toml, expected_field) in cases {
            let err = ProxyConfig::from_toml_str(toml)
                .expect_err("invalid venice value should be rejected");
            assert!(
                matches!(&err, ConfigError::InvalidValue { field, .. } if *field == expected_field),
                "{toml:?} produced unexpected error: {err}"
            );
        }
    }

    #[test]
    fn logging_config_accepts_level_or_env_filter_and_rejects_invalid_filters() {
        for accepted in ["debug", "venice_e2ee_proxy=debug,tower_http=warn"] {
            let config = ProxyConfig::from_toml_str(&logging_toml(accepted))
                .expect("logging config should load");
            assert_eq!(config.logging.level, accepted);
        }

        // Blank values and malformed filter directives are both rejected.
        for rejected in ["", " ", "venice_e2ee_proxy=[debug"] {
            let err = ProxyConfig::from_toml_str(&logging_toml(rejected))
                .expect_err("invalid logging level should be rejected");
            assert!(
                matches!(
                    &err,
                    ConfigError::InvalidValue {
                        field: "logging.level",
                        ..
                    }
                ),
                "{rejected:?} produced unexpected error: {err}"
            );
        }
    }

    #[test]
    fn removed_tool_marker_options_are_rejected_as_unknown_fields() {
        let outcome = ProxyConfig::from_toml_str("[tools]\nmarker_start = \"<tool_call>\"\n");
        let err = outcome.expect_err("removed tools.marker_start option should be rejected");
        assert!(matches!(err, ConfigError::Figment(_)));
    }

    #[test]
    fn missing_api_key_is_reported() {
        let config = ProxyConfig::default();
        let err = config
            .venice_api_key()
            .expect_err("missing API key should be reported");

        assert!(matches!(err, ConfigError::MissingApiKey));
        assert_eq!(err.to_string(), "Venice API key is not configured");
    }

    #[test]
    fn api_key_debug_output_is_redacted() {
        const SECRET: &str = "super-secret-test-key";

        let config = ProxyConfig::from_toml_str(&format!("[venice]\napi_key = {SECRET:?}\n"))
            .expect("config should load");
        let key = config.venice_api_key().expect("test key should load");

        assert_eq!(key.expose_secret(), SECRET);
        for rendered in [format!("{key:?}"), format!("{config:?}")] {
            assert!(!rendered.contains(SECRET), "secret leaked: {rendered}");
        }
    }
}