Skip to main content

venice_e2ee_proxy/
config.rs

1//! Configuration loading and validation.
2//!
3//! This module provides a typed representation of the proxy configuration,
4//! default values, validation, and redacted handling for the Venice API key.
5
6use std::{path::Path, time::Duration};
7
8use axum::http::HeaderName;
9use figment::{
10    Figment,
11    providers::{Env, Format, Toml},
12};
13use secrecy::{ExposeSecret, SecretString};
14use serde::{Deserialize, Deserializer, de};
15use thiserror::Error;
16use tracing_subscriber::EnvFilter;
17
18/// Top-level proxy configuration.
19#[derive(Debug, Clone, Default, Deserialize)]
20#[serde(default, deny_unknown_fields)]
21pub struct ProxyConfig {
22    pub server: ServerConfig,
23    pub logging: LoggingConfig,
24    pub venice: VeniceConfig,
25    pub keys: KeysConfig,
26    pub session: SessionConfig,
27    pub attestation: AttestationConfig,
28    pub e2ee: E2eeConfig,
29    pub tools: ToolsConfig,
30}
31
32impl ProxyConfig {
33    /// Prefix used when loading configuration overrides from environment variables.
34    pub const ENV_PREFIX: &'static str = "VENICE_E2EE_PROXY__";
35
36    /// Loads configuration from a TOML file with environment overrides.
37    pub fn load_from_path(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
38        Self::from_figment(
39            Figment::new()
40                .merge(Toml::file(path.as_ref()))
41                .merge(Self::env_provider()),
42        )
43    }
44
45    /// Parses TOML configuration and validates the result.
46    pub fn from_toml_str(contents: &str) -> Result<Self, ConfigError> {
47        Self::from_figment(Figment::new().merge(Toml::string(contents)))
48    }
49
50    /// Builds the environment provider used to overlay nested config values.
51    fn env_provider() -> Env {
52        Env::prefixed(Self::ENV_PREFIX).split("__")
53    }
54
55    /// Extracts proxy configuration from a Figment provider and validates it before returning.
56    fn from_figment(figment: Figment) -> Result<Self, ConfigError> {
57        let config: Self = figment.extract()?;
58        config.validate()?;
59        Ok(config)
60    }
61
62    /// Validates a fully materialized configuration.
63    pub fn validate(&self) -> Result<(), ConfigError> {
64        validate_non_empty("server.host", &self.server.host)?;
65        validate_non_empty("logging.level", &self.logging.level)?;
66        validate_env_filter("logging.level", &self.logging.level)?;
67        validate_http_url("venice.base_url", &self.venice.base_url, false)?;
68        validate_duration_non_zero("venice.request_timeout", self.venice.request_timeout)?;
69
70        validate_duration_non_zero("session.idle_ttl", self.session.idle_ttl)?;
71        validate_duration_non_zero("session.max_ttl", self.session.max_ttl)?;
72        if self.session.idle_ttl > self.session.max_ttl {
73            return Err(ConfigError::invalid(
74                "session.idle_ttl",
75                "must be less than or equal to session.max_ttl",
76            ));
77        }
78        if self.session.max_requests == 0 {
79            return Err(ConfigError::invalid(
80                "session.max_requests",
81                "must be greater than zero",
82            ));
83        }
84        validate_header_name("session.headers.preferred", &self.session.headers.preferred)?;
85        validate_header_name(
86            "session.headers.open_webui",
87            &self.session.headers.open_webui,
88        )?;
89
90        validate_http_url("attestation.pccs_url", &self.attestation.pccs_url, true)?;
91
92        validate_non_empty("e2ee.hkdf_info", &self.e2ee.hkdf_info)?;
93
94        if self.tools.tool_call_max_bytes == 0 {
95            return Err(ConfigError::invalid(
96                "tools.tool_call_max_bytes",
97                "must be greater than zero",
98            ));
99        }
100        validate_duration_non_zero(
101            "tools.tool_call_marker_timeout",
102            self.tools.tool_call_marker_timeout,
103        )?;
104
105        Ok(())
106    }
107
108    /// Returns the configured Venice API key.
109    pub fn venice_api_key(&self) -> Result<&SecretString, ConfigError> {
110        if self.venice.api_key.expose_secret().trim().is_empty() {
111            return Err(ConfigError::MissingApiKey);
112        }
113        Ok(&self.venice.api_key)
114    }
115}
116
117/// HTTP listener configuration for the local proxy server.
118#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
119#[serde(default, deny_unknown_fields)]
120pub struct ServerConfig {
121    pub host: String,
122    pub port: u16,
123}
124
125impl Default for ServerConfig {
126    /// Returns the default listener binding used when server config is omitted.
127    fn default() -> Self {
128        Self {
129            host: "0.0.0.0".to_owned(),
130            port: 8080,
131        }
132    }
133}
134
135/// Tracing configuration for proxy logs.
136#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
137#[serde(default, deny_unknown_fields)]
138pub struct LoggingConfig {
139    /// Tracing filter directive. Accepts simple levels like `info` or full
140    /// `tracing_subscriber::EnvFilter` directives like
141    /// `venice_e2ee_proxy=debug,tower_http=warn`.
142    pub level: String,
143}
144
145impl Default for LoggingConfig {
146    /// Returns the default tracing filter for proxy logs.
147    fn default() -> Self {
148        Self {
149            level: "info".to_owned(),
150        }
151    }
152}
153
154/// Venice upstream API client configuration.
155#[derive(Debug, Clone, Deserialize)]
156#[serde(default, deny_unknown_fields)]
157pub struct VeniceConfig {
158    pub base_url: String,
159    pub api_key: SecretString,
160    #[serde(deserialize_with = "deserialize_duration")]
161    pub request_timeout: Duration,
162}
163
164impl Default for VeniceConfig {
165    /// Returns default Venice API endpoint and timeout settings with an empty API key.
166    fn default() -> Self {
167        Self {
168            base_url: "https://api.venice.ai/api/v1".to_owned(),
169            api_key: SecretString::default(),
170            request_timeout: Duration::from_secs(30),
171        }
172    }
173}
174
175/// Proxy instance key generation configuration.
176#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
177#[serde(default, deny_unknown_fields)]
178pub struct KeysConfig {
179    pub generate_proxy_instance_key_on_startup: bool,
180}
181
182impl Default for KeysConfig {
183    /// Returns the default key policy that generates a proxy instance key at startup.
184    fn default() -> Self {
185        Self {
186            generate_proxy_instance_key_on_startup: true,
187        }
188    }
189}
190
191/// Session lifetime, reuse, and identifier-resolution configuration.
192#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
193#[serde(default, deny_unknown_fields)]
194pub struct SessionConfig {
195    #[serde(deserialize_with = "deserialize_duration")]
196    pub idle_ttl: Duration,
197    #[serde(deserialize_with = "deserialize_duration")]
198    pub max_ttl: Duration,
199    pub max_requests: u64,
200    pub fallback_scope: SessionFallbackScope,
201    pub headers: SessionHeadersConfig,
202}
203
204impl Default for SessionConfig {
205    /// Returns the default session TTLs, request budget, fallback behavior, and headers.
206    fn default() -> Self {
207        Self {
208            idle_ttl: Duration::from_secs(600),
209            max_ttl: Duration::from_secs(1_800),
210            max_requests: 100,
211            fallback_scope: SessionFallbackScope::Request,
212            headers: SessionHeadersConfig::default(),
213        }
214    }
215}
216
217/// Header names used to resolve stable agent session identifiers.
218#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
219#[serde(default, deny_unknown_fields)]
220pub struct SessionHeadersConfig {
221    pub preferred: String,
222    pub open_webui: String,
223}
224
225impl Default for SessionHeadersConfig {
226    /// Returns the default preferred and Open WebUI compatibility session headers.
227    fn default() -> Self {
228        Self {
229            preferred: "X-Venice-Proxy-Session-Id".to_owned(),
230            open_webui: "X-OpenWebUI-Chat-Id".to_owned(),
231        }
232    }
233}
234
235/// Fallback strategy used when a request does not include a session identifier.
236#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
237#[serde(rename_all = "snake_case")]
238pub enum SessionFallbackScope {
239    Agent,
240    #[default]
241    Request,
242    Disabled,
243}
244
245/// Attestation verification policy for Venice model-key evidence.
246#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
247#[serde(default, deny_unknown_fields)]
248pub struct AttestationConfig {
249    pub mode: AttestationMode,
250    pub require_tdx: bool,
251    pub require_nvidia: NvidiaRequirement,
252    pub allow_debug: bool,
253    pub pccs_url: String,
254}
255
256impl Default for AttestationConfig {
257    /// Returns the default attestation policy used when attestation requirements are not configured.
258    fn default() -> Self {
259        Self {
260            mode: AttestationMode::Independent,
261            require_tdx: false,
262            require_nvidia: NvidiaRequirement::Never,
263            allow_debug: false,
264            pccs_url: String::new(),
265        }
266    }
267}
268
269/// Attestation strategy exposed in proxy metadata and config.
270#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
271#[serde(rename_all = "snake_case")]
272pub enum AttestationMode {
273    #[default]
274    Independent,
275}
276
277impl AttestationMode {
278    /// Returns the lowercase metadata/header value for this attestation mode.
279    pub fn as_str(self) -> &'static str {
280        match self {
281            Self::Independent => "independent",
282        }
283    }
284}
285
286/// Policy for how NVIDIA attestation payloads are required or ignored.
287#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
288#[serde(rename_all = "snake_case")]
289pub enum NvidiaRequirement {
290    Required,
291    WhenPresent,
292    #[default]
293    Never,
294}
295
296/// E2EE codec configuration for request encryption and response decryption.
297#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
298#[serde(default, deny_unknown_fields)]
299pub struct E2eeConfig {
300    pub hkdf_info: String,
301    pub require_encrypted_response_content: bool,
302}
303
304impl Default for E2eeConfig {
305    /// Returns the default HKDF context and encrypted-response policy.
306    fn default() -> Self {
307        Self {
308            hkdf_info: "ecdsa_encryption".to_owned(),
309            require_encrypted_response_content: true,
310        }
311    }
312}
313
314/// Tool-call emulation configuration for OpenAI-style function calls.
315#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
316#[serde(default, deny_unknown_fields)]
317pub struct ToolsConfig {
318    pub enabled: bool,
319    pub mode: ToolMode,
320    pub max_retries: u32,
321    pub tool_call_max_bytes: usize,
322    #[serde(deserialize_with = "deserialize_duration")]
323    pub tool_call_marker_timeout: Duration,
324    pub validate_json_schema: bool,
325}
326
327impl Default for ToolsConfig {
328    /// Returns the default tool-emulation limits and retry policy.
329    fn default() -> Self {
330        Self {
331            enabled: true,
332            mode: ToolMode::Emulated,
333            max_retries: 2,
334            tool_call_max_bytes: 65_536,
335            tool_call_marker_timeout: Duration::from_secs(30),
336            validate_json_schema: true,
337        }
338    }
339}
340
341/// Tool handling mode used by the proxy.
342#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
343#[serde(rename_all = "snake_case")]
344pub enum ToolMode {
345    #[default]
346    Emulated,
347    None,
348}
349
350impl ToolMode {
351    /// Returns the lowercase metadata/header value for this tool mode.
352    pub fn as_str(self) -> &'static str {
353        match self {
354            Self::Emulated => "emulated",
355            Self::None => "none",
356        }
357    }
358}
359
360/// Errors returned while loading or validating proxy configuration.
361#[derive(Debug, Error)]
362pub enum ConfigError {
363    #[error("failed to load config: {0}")]
364    Figment(#[source] Box<figment::Error>),
365    #[error("Venice API key is not configured")]
366    MissingApiKey,
367    #[error("invalid config value for {field}: {message}")]
368    InvalidValue {
369        field: &'static str,
370        message: String,
371    },
372}
373
374impl From<figment::Error> for ConfigError {
375    /// Converts Figment extraction failures into configuration errors.
376    fn from(error: figment::Error) -> Self {
377        Self::Figment(Box::new(error))
378    }
379}
380
381impl ConfigError {
382    /// Creates an invalid-value error for a named configuration field.
383    fn invalid(field: &'static str, message: impl Into<String>) -> Self {
384        Self::InvalidValue {
385            field,
386            message: message.into(),
387        }
388    }
389}
390
391/// Validates that a string configuration field contains non-whitespace text.
392fn validate_non_empty(field: &'static str, value: &str) -> Result<(), ConfigError> {
393    if value.trim().is_empty() {
394        return Err(ConfigError::invalid(field, "must not be empty"));
395    }
396    Ok(())
397}
398
399/// Validates that a string configuration field is an HTTP(S) URL, optionally allowing empty values.
400fn validate_http_url(
401    field: &'static str,
402    value: &str,
403    allow_empty: bool,
404) -> Result<(), ConfigError> {
405    let value = value.trim();
406    if value.is_empty() {
407        if allow_empty {
408            return Ok(());
409        }
410        return Err(ConfigError::invalid(field, "must not be empty"));
411    }
412
413    if !(value.starts_with("https://") || value.starts_with("http://")) {
414        return Err(ConfigError::invalid(
415            field,
416            "must start with http:// or https://",
417        ));
418    }
419
420    Ok(())
421}
422
423/// Validates that a string configuration field can be used as an HTTP header name.
424fn validate_header_name(field: &'static str, value: &str) -> Result<(), ConfigError> {
425    validate_non_empty(field, value)?;
426    HeaderName::from_bytes(value.as_bytes())
427        .map_err(|_| ConfigError::invalid(field, "must be a valid HTTP header name"))?;
428    Ok(())
429}
430
431/// Validates that a duration configuration field is greater than zero.
432fn validate_duration_non_zero(field: &'static str, value: Duration) -> Result<(), ConfigError> {
433    if value == Duration::ZERO {
434        return Err(ConfigError::invalid(field, "must be greater than zero"));
435    }
436    Ok(())
437}
438
439/// Deserializes human-readable duration strings into [`Duration`] values.
440fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
441where
442    D: Deserializer<'de>,
443{
444    let value = String::deserialize(deserializer)?;
445    humantime::parse_duration(&value).map_err(de::Error::custom)
446}
447
448/// Validates that a string configuration field is accepted by `tracing_subscriber::EnvFilter`.
449fn validate_env_filter(field: &'static str, value: &str) -> Result<(), ConfigError> {
450    let value = value.trim();
451    if value.is_empty() {
452        return Ok(());
453    }
454
455    EnvFilter::try_new(value).map_err(|source| {
456        ConfigError::invalid(
457            field,
458            format!("must be a valid tracing env filter: {source}"),
459        )
460    })?;
461    Ok(())
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts every field matches the code defaults and that the config validates.
    fn assert_default_config_values(config: &ProxyConfig) {
        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.server.port, 8080);
        assert_eq!(config.logging.level, "info");
        assert_eq!(config.venice.base_url, "https://api.venice.ai/api/v1");
        assert_eq!(config.venice.api_key.expose_secret(), "");
        assert_eq!(config.venice.request_timeout, Duration::from_secs(30));
        assert!(config.keys.generate_proxy_instance_key_on_startup);
        assert_eq!(config.session.idle_ttl, Duration::from_secs(600));
        assert_eq!(config.session.max_ttl, Duration::from_secs(1_800));
        assert_eq!(config.session.max_requests, 100);
        assert_eq!(config.session.fallback_scope, SessionFallbackScope::Request);
        assert_eq!(
            config.session.headers.preferred,
            "X-Venice-Proxy-Session-Id"
        );
        assert_eq!(config.session.headers.open_webui, "X-OpenWebUI-Chat-Id");
        assert_eq!(config.attestation.mode, AttestationMode::Independent);
        assert!(!config.attestation.require_tdx);
        assert_eq!(config.attestation.require_nvidia, NvidiaRequirement::Never);
        assert!(!config.attestation.allow_debug);
        assert_eq!(config.attestation.pccs_url, "");
        assert_eq!(config.e2ee.hkdf_info, "ecdsa_encryption");
        assert!(config.e2ee.require_encrypted_response_content);
        assert!(config.tools.enabled);
        assert_eq!(config.tools.mode, ToolMode::Emulated);
        assert_eq!(config.tools.max_retries, 2);
        assert_eq!(config.tools.tool_call_max_bytes, 65_536);
        assert_eq!(
            config.tools.tool_call_marker_timeout,
            Duration::from_secs(30)
        );
        assert!(config.tools.validate_json_schema);

        config.validate().expect("default config is valid");
    }

    /// Asserts that `err` is an `InvalidValue` for exactly `expected_field`.
    fn assert_invalid_field(err: ConfigError, expected_field: &str) {
        match err {
            ConfigError::InvalidValue { field, .. } => assert_eq!(field, expected_field),
            other => panic!("expected InvalidValue for {expected_field}, got {other:?}"),
        }
    }

    #[test]
    fn default_config_matches_expected_values() {
        let config = ProxyConfig::default();

        assert_default_config_values(&config);
    }

    #[test]
    fn checked_in_default_config_matches_code_defaults() {
        let config = ProxyConfig::from_toml_str(include_str!("../config/default.toml"))
            .expect("checked-in default config should load");

        assert_default_config_values(&config);
    }

    #[test]
    fn toml_config_applies_defaults_for_missing_sections() {
        let config = ProxyConfig::from_toml_str(
            r#"
            [server]
            host = "0.0.0.0"
            port = 8080

            [tools]
            enabled = false
            mode = "none"
            "#,
        )
        .expect("partial config should load with defaults");

        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.server.port, 8080);
        assert_eq!(config.logging.level, "info");
        assert_eq!(config.venice.api_key.expose_secret(), "");
        assert_eq!(config.venice.request_timeout, Duration::from_secs(30));
        assert!(!config.tools.enabled);
        assert_eq!(config.tools.mode, ToolMode::None);
        assert_eq!(config.tools.tool_call_max_bytes, 65_536);
    }

    #[test]
    fn validation_rejects_invalid_values() {
        let err = ProxyConfig::from_toml_str(
            r#"
            [venice]
            base_url = "not-valid-url"
            "#,
        )
        .expect_err("invalid base URL should be rejected");

        assert!(matches!(
            err,
            ConfigError::InvalidValue {
                field: "venice.base_url",
                ..
            }
        ));

        let err = ProxyConfig::from_toml_str(
            r#"
            [venice]
            request_timeout = "0s"
            "#,
        )
        .expect_err("zero Venice timeout should be rejected");

        assert!(matches!(
            err,
            ConfigError::InvalidValue {
                field: "venice.request_timeout",
                ..
            }
        ));
    }

    #[test]
    fn logging_config_accepts_level_or_env_filter_and_rejects_invalid_filters() {
        let config = ProxyConfig::from_toml_str(
            r#"
            [logging]
            level = "debug"
            "#,
        )
        .expect("logging level config should load");
        assert_eq!(config.logging.level, "debug");

        let config = ProxyConfig::from_toml_str(
            r#"
            [logging]
            level = "venice_e2ee_proxy=debug,tower_http=warn"
            "#,
        )
        .expect("logging env filter config should load");
        assert_eq!(
            config.logging.level,
            "venice_e2ee_proxy=debug,tower_http=warn"
        );

        for level in ["", "   "] {
            let err = ProxyConfig::from_toml_str(&format!(
                r#"
                [logging]
                level = {level:?}
                "#
            ))
            .expect_err("empty logging level should be rejected");
            assert!(matches!(
                err,
                ConfigError::InvalidValue {
                    field: "logging.level",
                    ..
                }
            ));
        }

        let err = ProxyConfig::from_toml_str(
            r#"
            [logging]
            level = "venice_e2ee_proxy=[debug"
            "#,
        )
        .expect_err("invalid tracing env filter should be rejected");
        assert!(matches!(
            err,
            ConfigError::InvalidValue {
                field: "logging.level",
                ..
            }
        ));
    }

    #[test]
    fn session_idle_ttl_must_not_exceed_max_ttl() {
        let err = ProxyConfig::from_toml_str(
            r#"
            [session]
            idle_ttl = "2h"
            max_ttl = "30m"
            "#,
        )
        .expect_err("idle_ttl above max_ttl should be rejected");

        assert_invalid_field(err, "session.idle_ttl");
    }

    #[test]
    fn zero_limits_are_rejected() {
        let err = ProxyConfig::from_toml_str("[session]\nmax_requests = 0\n")
            .expect_err("zero session request budget should be rejected");
        assert_invalid_field(err, "session.max_requests");

        let err = ProxyConfig::from_toml_str("[tools]\ntool_call_max_bytes = 0\n")
            .expect_err("zero tool call size limit should be rejected");
        assert_invalid_field(err, "tools.tool_call_max_bytes");

        let err = ProxyConfig::from_toml_str("[tools]\ntool_call_marker_timeout = \"0s\"\n")
            .expect_err("zero tool call marker timeout should be rejected");
        assert_invalid_field(err, "tools.tool_call_marker_timeout");
    }

    #[test]
    fn invalid_session_header_names_are_rejected() {
        let err = ProxyConfig::from_toml_str("[session.headers]\npreferred = \"Bad Header\"\n")
            .expect_err("header name containing a space should be rejected");
        assert_invalid_field(err, "session.headers.preferred");

        let err = ProxyConfig::from_toml_str("[session.headers]\nopen_webui = \"\"\n")
            .expect_err("empty header name should be rejected");
        assert_invalid_field(err, "session.headers.open_webui");
    }

    #[test]
    fn malformed_durations_and_unknown_sections_fail_extraction() {
        let err = ProxyConfig::from_toml_str("[session]\nidle_ttl = \"ten minutes\"\n")
            .expect_err("unparsable duration should be rejected");
        assert!(matches!(err, ConfigError::Figment(_)));

        let err = ProxyConfig::from_toml_str("[unknown_section]\nkey = 1\n")
            .expect_err("unknown top-level section should be rejected");
        assert!(matches!(err, ConfigError::Figment(_)));
    }

    #[test]
    fn removed_tool_marker_options_are_rejected_as_unknown_fields() {
        let err = ProxyConfig::from_toml_str("[tools]\nmarker_start = \"<tool_call>\"\n")
            .expect_err("removed tools.marker_start option should be rejected");
        assert!(matches!(err, ConfigError::Figment(_)));
    }

    #[test]
    fn missing_api_key_is_reported() {
        let config = ProxyConfig::default();
        let err = config
            .venice_api_key()
            .expect_err("missing API key should be reported");

        assert!(matches!(err, ConfigError::MissingApiKey));
        assert_eq!(err.to_string(), "Venice API key is not configured");
    }

    #[test]
    fn api_key_debug_output_is_redacted() {
        let config = ProxyConfig::from_toml_str(
            r#"
            [venice]
            api_key = "super-secret-test-key"
            "#,
        )
        .expect("config should load");
        let key = config.venice_api_key().expect("test key should load");

        assert_eq!(key.expose_secret(), "super-secret-test-key");
        assert!(!format!("{key:?}").contains("super-secret-test-key"));
        assert!(!format!("{config:?}").contains("super-secret-test-key"));
    }
}