1use std::{path::Path, time::Duration};
7
8use axum::http::HeaderName;
9use figment::{
10 Figment,
11 providers::{Env, Format, Toml},
12};
13use secrecy::{ExposeSecret, SecretString};
14use serde::{Deserialize, Deserializer, de};
15use thiserror::Error;
16use tracing_subscriber::EnvFilter;
17
18#[derive(Debug, Clone, Default, Deserialize)]
20#[serde(default, deny_unknown_fields)]
21pub struct ProxyConfig {
22 pub server: ServerConfig,
23 pub logging: LoggingConfig,
24 pub venice: VeniceConfig,
25 pub keys: KeysConfig,
26 pub session: SessionConfig,
27 pub attestation: AttestationConfig,
28 pub e2ee: E2eeConfig,
29 pub tools: ToolsConfig,
30}
31
32impl ProxyConfig {
33 pub const ENV_PREFIX: &'static str = "VENICE_E2EE_PROXY__";
35
36 pub fn load_from_path(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
38 Self::from_figment(
39 Figment::new()
40 .merge(Toml::file(path.as_ref()))
41 .merge(Self::env_provider()),
42 )
43 }
44
45 pub fn from_toml_str(contents: &str) -> Result<Self, ConfigError> {
47 Self::from_figment(Figment::new().merge(Toml::string(contents)))
48 }
49
50 fn env_provider() -> Env {
52 Env::prefixed(Self::ENV_PREFIX).split("__")
53 }
54
55 fn from_figment(figment: Figment) -> Result<Self, ConfigError> {
57 let config: Self = figment.extract()?;
58 config.validate()?;
59 Ok(config)
60 }
61
62 pub fn validate(&self) -> Result<(), ConfigError> {
64 validate_non_empty("server.host", &self.server.host)?;
65 validate_non_empty("logging.level", &self.logging.level)?;
66 validate_env_filter("logging.level", &self.logging.level)?;
67 validate_http_url("venice.base_url", &self.venice.base_url, false)?;
68 validate_duration_non_zero("venice.request_timeout", self.venice.request_timeout)?;
69
70 validate_duration_non_zero("session.idle_ttl", self.session.idle_ttl)?;
71 validate_duration_non_zero("session.max_ttl", self.session.max_ttl)?;
72 if self.session.idle_ttl > self.session.max_ttl {
73 return Err(ConfigError::invalid(
74 "session.idle_ttl",
75 "must be less than or equal to session.max_ttl",
76 ));
77 }
78 if self.session.max_requests == 0 {
79 return Err(ConfigError::invalid(
80 "session.max_requests",
81 "must be greater than zero",
82 ));
83 }
84 validate_header_name(
85 "session.headers.incoming_session_id",
86 &self.session.headers.incoming_session_id,
87 )?;
88
89 validate_http_url("attestation.pccs_url", &self.attestation.pccs_url, true)?;
90
91 validate_non_empty("e2ee.hkdf_info", &self.e2ee.hkdf_info)?;
92
93 if self.tools.tool_call_max_bytes == 0 {
94 return Err(ConfigError::invalid(
95 "tools.tool_call_max_bytes",
96 "must be greater than zero",
97 ));
98 }
99 validate_duration_non_zero(
100 "tools.tool_call_marker_timeout",
101 self.tools.tool_call_marker_timeout,
102 )?;
103
104 Ok(())
105 }
106
107 pub fn venice_api_key(&self) -> Result<&SecretString, ConfigError> {
109 if self.venice.api_key.expose_secret().trim().is_empty() {
110 return Err(ConfigError::MissingApiKey);
111 }
112 Ok(&self.venice.api_key)
113 }
114}
115
116#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
118#[serde(default, deny_unknown_fields)]
119pub struct ServerConfig {
120 pub host: String,
121 pub port: u16,
122}
123
124impl Default for ServerConfig {
125 fn default() -> Self {
127 Self {
128 host: "0.0.0.0".to_owned(),
129 port: 8080,
130 }
131 }
132}
133
134#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
136#[serde(default, deny_unknown_fields)]
137pub struct LoggingConfig {
138 pub level: String,
142}
143
144impl Default for LoggingConfig {
145 fn default() -> Self {
147 Self {
148 level: "info".to_owned(),
149 }
150 }
151}
152
153#[derive(Debug, Clone, Deserialize)]
155#[serde(default, deny_unknown_fields)]
156pub struct VeniceConfig {
157 pub base_url: String,
158 pub api_key: SecretString,
159 #[serde(deserialize_with = "deserialize_duration")]
160 pub request_timeout: Duration,
161}
162
163impl Default for VeniceConfig {
164 fn default() -> Self {
166 Self {
167 base_url: "https://api.venice.ai/api/v1".to_owned(),
168 api_key: SecretString::default(),
169 request_timeout: Duration::from_secs(30),
170 }
171 }
172}
173
174#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
176#[serde(default, deny_unknown_fields)]
177pub struct KeysConfig {
178 pub generate_proxy_instance_key_on_startup: bool,
179}
180
181impl Default for KeysConfig {
182 fn default() -> Self {
184 Self {
185 generate_proxy_instance_key_on_startup: true,
186 }
187 }
188}
189
190#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
192#[serde(default, deny_unknown_fields)]
193pub struct SessionConfig {
194 #[serde(deserialize_with = "deserialize_duration")]
195 pub idle_ttl: Duration,
196 #[serde(deserialize_with = "deserialize_duration")]
197 pub max_ttl: Duration,
198 pub max_requests: u64,
199 pub fallback_scope: SessionFallbackScope,
200 pub headers: SessionHeadersConfig,
201}
202
203impl Default for SessionConfig {
204 fn default() -> Self {
206 Self {
207 idle_ttl: Duration::from_secs(600),
208 max_ttl: Duration::from_secs(1_800),
209 max_requests: 100,
210 fallback_scope: SessionFallbackScope::Request,
211 headers: SessionHeadersConfig::default(),
212 }
213 }
214}
215
216#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
218#[serde(default, deny_unknown_fields)]
219pub struct SessionHeadersConfig {
220 pub incoming_session_id: String,
221}
222
223impl Default for SessionHeadersConfig {
224 fn default() -> Self {
226 Self {
227 incoming_session_id: "X-Venice-Proxy-Session-Id".to_owned(),
228 }
229 }
230}
231
232#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
234#[serde(rename_all = "snake_case")]
235pub enum SessionFallbackScope {
236 Agent,
237 #[default]
238 Request,
239 Disabled,
240}
241
242#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
244#[serde(default, deny_unknown_fields)]
245pub struct AttestationConfig {
246 pub mode: AttestationMode,
247 pub require_tdx: bool,
248 pub require_nvidia: NvidiaRequirement,
249 pub allow_debug: bool,
250 pub pccs_url: String,
251}
252
253impl Default for AttestationConfig {
254 fn default() -> Self {
256 Self {
257 mode: AttestationMode::Independent,
258 require_tdx: false,
259 require_nvidia: NvidiaRequirement::Never,
260 allow_debug: false,
261 pccs_url: String::new(),
262 }
263 }
264}
265
266#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
268#[serde(rename_all = "snake_case")]
269pub enum AttestationMode {
270 #[default]
271 Independent,
272}
273
274impl AttestationMode {
275 pub fn as_str(self) -> &'static str {
277 match self {
278 Self::Independent => "independent",
279 }
280 }
281}
282
283#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
285#[serde(rename_all = "snake_case")]
286pub enum NvidiaRequirement {
287 Required,
288 WhenPresent,
289 #[default]
290 Never,
291}
292
293#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
295#[serde(default, deny_unknown_fields)]
296pub struct E2eeConfig {
297 pub hkdf_info: String,
298 pub require_encrypted_response_content: bool,
299}
300
301impl Default for E2eeConfig {
302 fn default() -> Self {
304 Self {
305 hkdf_info: "ecdsa_encryption".to_owned(),
306 require_encrypted_response_content: true,
307 }
308 }
309}
310
311#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
313#[serde(default, deny_unknown_fields)]
314pub struct ToolsConfig {
315 pub enabled: bool,
316 pub mode: ToolMode,
317 pub max_retries: u32,
318 pub tool_call_max_bytes: usize,
319 #[serde(deserialize_with = "deserialize_duration")]
320 pub tool_call_marker_timeout: Duration,
321 pub validate_json_schema: bool,
322}
323
324impl Default for ToolsConfig {
325 fn default() -> Self {
327 Self {
328 enabled: true,
329 mode: ToolMode::Emulated,
330 max_retries: 2,
331 tool_call_max_bytes: 65_536,
332 tool_call_marker_timeout: Duration::from_secs(30),
333 validate_json_schema: true,
334 }
335 }
336}
337
338#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
340#[serde(rename_all = "snake_case")]
341pub enum ToolMode {
342 #[default]
343 Emulated,
344 None,
345}
346
347impl ToolMode {
348 pub fn as_str(self) -> &'static str {
350 match self {
351 Self::Emulated => "emulated",
352 Self::None => "none",
353 }
354 }
355}
356
357#[derive(Debug, Error)]
359pub enum ConfigError {
360 #[error("failed to load config: {0}")]
361 Figment(#[source] Box<figment::Error>),
362 #[error("Venice API key is not configured")]
363 MissingApiKey,
364 #[error("invalid config value for {field}: {message}")]
365 InvalidValue {
366 field: &'static str,
367 message: String,
368 },
369}
370
371impl From<figment::Error> for ConfigError {
372 fn from(error: figment::Error) -> Self {
374 Self::Figment(Box::new(error))
375 }
376}
377
378impl ConfigError {
379 fn invalid(field: &'static str, message: impl Into<String>) -> Self {
381 Self::InvalidValue {
382 field,
383 message: message.into(),
384 }
385 }
386}
387
388fn validate_non_empty(field: &'static str, value: &str) -> Result<(), ConfigError> {
390 if value.trim().is_empty() {
391 return Err(ConfigError::invalid(field, "must not be empty"));
392 }
393 Ok(())
394}
395
396fn validate_http_url(
398 field: &'static str,
399 value: &str,
400 allow_empty: bool,
401) -> Result<(), ConfigError> {
402 let value = value.trim();
403 if value.is_empty() {
404 if allow_empty {
405 return Ok(());
406 }
407 return Err(ConfigError::invalid(field, "must not be empty"));
408 }
409
410 if !(value.starts_with("https://") || value.starts_with("http://")) {
411 return Err(ConfigError::invalid(
412 field,
413 "must start with http:// or https://",
414 ));
415 }
416
417 Ok(())
418}
419
420fn validate_header_name(field: &'static str, value: &str) -> Result<(), ConfigError> {
422 validate_non_empty(field, value)?;
423 HeaderName::from_bytes(value.as_bytes())
424 .map_err(|_| ConfigError::invalid(field, "must be a valid HTTP header name"))?;
425 Ok(())
426}
427
428fn validate_duration_non_zero(field: &'static str, value: Duration) -> Result<(), ConfigError> {
430 if value == Duration::ZERO {
431 return Err(ConfigError::invalid(field, "must be greater than zero"));
432 }
433 Ok(())
434}
435
436fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
438where
439 D: Deserializer<'de>,
440{
441 let value = String::deserialize(deserializer)?;
442 humantime::parse_duration(&value).map_err(de::Error::custom)
443}
444
445fn validate_env_filter(field: &'static str, value: &str) -> Result<(), ConfigError> {
447 let value = value.trim();
448 if value.is_empty() {
449 return Ok(());
450 }
451
452 EnvFilter::try_new(value).map_err(|source| {
453 ConfigError::invalid(
454 field,
455 format!("must be a valid tracing env filter: {source}"),
456 )
457 })?;
458 Ok(())
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads `toml` and asserts that validation rejected it for exactly `field`.
    fn assert_invalid_field(toml: &str, field: &str, reason: &str) {
        let err = ProxyConfig::from_toml_str(toml).expect_err(reason);
        assert!(
            matches!(&err, ConfigError::InvalidValue { field: actual, .. } if *actual == field),
            "expected InvalidValue for `{field}`, got {err:?}"
        );
    }

    /// Asserts that every setting holds its built-in default and that the
    /// resulting config passes validation.
    fn assert_default_config_values(config: &ProxyConfig) {
        let ProxyConfig {
            server,
            logging,
            venice,
            keys,
            session,
            attestation,
            e2ee,
            tools,
        } = config;

        // [server] and [logging]
        assert_eq!(server.host, "0.0.0.0");
        assert_eq!(server.port, 8080);
        assert_eq!(logging.level, "info");

        // [venice] and [keys]
        assert_eq!(venice.base_url, "https://api.venice.ai/api/v1");
        assert_eq!(venice.api_key.expose_secret(), "");
        assert_eq!(venice.request_timeout, Duration::from_secs(30));
        assert!(keys.generate_proxy_instance_key_on_startup);

        // [session]
        assert_eq!(session.idle_ttl, Duration::from_secs(600));
        assert_eq!(session.max_ttl, Duration::from_secs(1_800));
        assert_eq!(session.max_requests, 100);
        assert_eq!(session.fallback_scope, SessionFallbackScope::Request);
        assert_eq!(
            session.headers.incoming_session_id,
            "X-Venice-Proxy-Session-Id"
        );

        // [attestation]
        assert_eq!(attestation.mode, AttestationMode::Independent);
        assert!(!attestation.require_tdx);
        assert_eq!(attestation.require_nvidia, NvidiaRequirement::Never);
        assert!(!attestation.allow_debug);
        assert_eq!(attestation.pccs_url, "");

        // [e2ee]
        assert_eq!(e2ee.hkdf_info, "ecdsa_encryption");
        assert!(e2ee.require_encrypted_response_content);

        // [tools]
        assert!(tools.enabled);
        assert_eq!(tools.mode, ToolMode::Emulated);
        assert_eq!(tools.max_retries, 2);
        assert_eq!(tools.tool_call_max_bytes, 65_536);
        assert_eq!(tools.tool_call_marker_timeout, Duration::from_secs(30));
        assert!(tools.validate_json_schema);

        config.validate().expect("default config is valid");
    }

    #[test]
    fn default_config_matches_expected_values() {
        assert_default_config_values(&ProxyConfig::default());
    }

    #[test]
    fn checked_in_default_config_matches_code_defaults() {
        let checked_in = include_str!("../config/default.toml");
        let config =
            ProxyConfig::from_toml_str(checked_in).expect("checked-in default config should load");
        assert_default_config_values(&config);
    }

    #[test]
    fn toml_config_applies_defaults_for_missing_sections() {
        let partial = r#"
            [server]
            host = "0.0.0.0"
            port = 8080

            [tools]
            enabled = false
            mode = "none"
        "#;
        let config =
            ProxyConfig::from_toml_str(partial).expect("partial config should load with defaults");

        // Explicitly provided values.
        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.server.port, 8080);
        assert!(!config.tools.enabled);
        assert_eq!(config.tools.mode, ToolMode::None);

        // Values filled in from defaults.
        assert_eq!(config.logging.level, "info");
        assert_eq!(config.venice.api_key.expose_secret(), "");
        assert_eq!(config.venice.request_timeout, Duration::from_secs(30));
        assert_eq!(config.tools.tool_call_max_bytes, 65_536);
    }

    #[test]
    fn validation_rejects_invalid_values() {
        assert_invalid_field(
            r#"
            [venice]
            base_url = "not-valid-url"
            "#,
            "venice.base_url",
            "invalid base URL should be rejected",
        );
        assert_invalid_field(
            r#"
            [venice]
            request_timeout = "0s"
            "#,
            "venice.request_timeout",
            "zero Venice timeout should be rejected",
        );
    }

    #[test]
    fn logging_config_accepts_level_or_env_filter_and_rejects_invalid_filters() {
        for accepted in ["debug", "venice_e2ee_proxy=debug,tower_http=warn"] {
            let toml = format!("\n[logging]\nlevel = {accepted:?}\n");
            let config = ProxyConfig::from_toml_str(&toml).expect("logging config should load");
            assert_eq!(config.logging.level, accepted);
        }

        for blank in ["", " "] {
            let toml = format!("\n[logging]\nlevel = {blank:?}\n");
            assert_invalid_field(
                &toml,
                "logging.level",
                "empty logging level should be rejected",
            );
        }

        assert_invalid_field(
            r#"
            [logging]
            level = "venice_e2ee_proxy=[debug"
            "#,
            "logging.level",
            "invalid tracing env filter should be rejected",
        );
    }

    #[test]
    fn removed_tool_marker_options_are_rejected_as_unknown_fields() {
        let toml = "[tools]\nmarker_start = \"<tool_call>\"\n";
        let err = ProxyConfig::from_toml_str(toml)
            .expect_err("removed tools.marker_start option should be rejected");
        assert!(matches!(err, ConfigError::Figment(_)));
    }

    #[test]
    fn missing_api_key_is_reported() {
        let err = ProxyConfig::default()
            .venice_api_key()
            .expect_err("missing API key should be reported");

        assert!(matches!(err, ConfigError::MissingApiKey));
        assert_eq!(err.to_string(), "Venice API key is not configured");
    }

    #[test]
    fn api_key_debug_output_is_redacted() {
        const SECRET: &str = "super-secret-test-key";
        let config = ProxyConfig::from_toml_str(
            r#"
            [venice]
            api_key = "super-secret-test-key"
            "#,
        )
        .expect("config should load");
        let key = config.venice_api_key().expect("test key should load");

        assert_eq!(key.expose_secret(), SECRET);
        assert!(!format!("{key:?}").contains(SECRET));
        assert!(!format!("{config:?}").contains(SECRET));
    }
}