Skip to main content

venice_e2ee_proxy/
config.rs

1//! Configuration loading and validation.
2//!
3//! This module provides a typed representation of the proxy configuration,
4//! default values, validation, and redacted handling for the Venice API key.
5
6use std::{path::Path, time::Duration};
7
8use axum::http::HeaderName;
9use figment::{
10    Figment,
11    providers::{Env, Format, Toml},
12};
13use secrecy::{ExposeSecret, SecretString};
14use serde::{Deserialize, Deserializer, de};
15use thiserror::Error;
16use tracing_subscriber::EnvFilter;
17
/// Top-level proxy configuration.
///
/// Every section falls back to its `Default` implementation when it is
/// omitted (`#[serde(default)]`), and unrecognized keys are rejected during
/// deserialization (`deny_unknown_fields`).
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct ProxyConfig {
    /// HTTP listener settings (`[server]`).
    pub server: ServerConfig,
    /// Tracing filter settings (`[logging]`).
    pub logging: LoggingConfig,
    /// Venice upstream API settings, including the API key (`[venice]`).
    pub venice: VeniceConfig,
    /// Proxy instance key generation settings (`[keys]`).
    pub keys: KeysConfig,
    /// Session lifetime and identifier-resolution settings (`[session]`).
    pub session: SessionConfig,
    /// Attestation verification policy (`[attestation]`).
    pub attestation: AttestationConfig,
    /// E2EE codec settings (`[e2ee]`).
    pub e2ee: E2eeConfig,
    /// Tool-call emulation settings (`[tools]`).
    pub tools: ToolsConfig,
}
31
32impl ProxyConfig {
33    /// Prefix used when loading configuration overrides from environment variables.
34    pub const ENV_PREFIX: &'static str = "VENICE_E2EE_PROXY__";
35
36    /// Loads configuration from a TOML file with environment overrides.
37    pub fn load_from_path(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
38        Self::from_figment(
39            Figment::new()
40                .merge(Toml::file(path.as_ref()))
41                .merge(Self::env_provider()),
42        )
43    }
44
45    /// Parses TOML configuration and validates the result.
46    pub fn from_toml_str(contents: &str) -> Result<Self, ConfigError> {
47        Self::from_figment(Figment::new().merge(Toml::string(contents)))
48    }
49
50    /// Builds the environment provider used to overlay nested config values.
51    fn env_provider() -> Env {
52        Env::prefixed(Self::ENV_PREFIX).split("__")
53    }
54
55    /// Extracts proxy configuration from a Figment provider and validates it before returning.
56    fn from_figment(figment: Figment) -> Result<Self, ConfigError> {
57        let config: Self = figment.extract()?;
58        config.validate()?;
59        Ok(config)
60    }
61
62    /// Validates a fully materialized configuration.
63    pub fn validate(&self) -> Result<(), ConfigError> {
64        validate_non_empty("server.host", &self.server.host)?;
65        validate_non_empty("logging.level", &self.logging.level)?;
66        validate_env_filter("logging.level", &self.logging.level)?;
67        validate_http_url("venice.base_url", &self.venice.base_url, false)?;
68        validate_duration_non_zero("venice.request_timeout", self.venice.request_timeout)?;
69
70        validate_duration_non_zero("session.idle_ttl", self.session.idle_ttl)?;
71        validate_duration_non_zero("session.max_ttl", self.session.max_ttl)?;
72        if self.session.idle_ttl > self.session.max_ttl {
73            return Err(ConfigError::invalid(
74                "session.idle_ttl",
75                "must be less than or equal to session.max_ttl",
76            ));
77        }
78        if self.session.max_requests == 0 {
79            return Err(ConfigError::invalid(
80                "session.max_requests",
81                "must be greater than zero",
82            ));
83        }
84        validate_header_name(
85            "session.headers.incoming_session_id",
86            &self.session.headers.incoming_session_id,
87        )?;
88
89        validate_http_url("attestation.pccs_url", &self.attestation.pccs_url, true)?;
90
91        validate_non_empty("e2ee.hkdf_info", &self.e2ee.hkdf_info)?;
92
93        if self.tools.tool_call_max_bytes == 0 {
94            return Err(ConfigError::invalid(
95                "tools.tool_call_max_bytes",
96                "must be greater than zero",
97            ));
98        }
99        validate_duration_non_zero(
100            "tools.tool_call_marker_timeout",
101            self.tools.tool_call_marker_timeout,
102        )?;
103
104        Ok(())
105    }
106
107    /// Returns the configured Venice API key.
108    pub fn venice_api_key(&self) -> Result<&SecretString, ConfigError> {
109        if self.venice.api_key.expose_secret().trim().is_empty() {
110            return Err(ConfigError::MissingApiKey);
111        }
112        Ok(&self.venice.api_key)
113    }
114}
115
/// HTTP listener configuration for the local proxy server.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[serde(default, deny_unknown_fields)]
pub struct ServerConfig {
    /// Listener host; must not be empty or whitespace-only. Defaults to `0.0.0.0`.
    pub host: String,
    /// Listener port. Defaults to `8080`.
    pub port: u16,
}
123
124impl Default for ServerConfig {
125    /// Returns the default listener binding used when server config is omitted.
126    fn default() -> Self {
127        Self {
128            host: "0.0.0.0".to_owned(),
129            port: 8080,
130        }
131    }
132}
133
/// Tracing configuration for proxy logs.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[serde(default, deny_unknown_fields)]
pub struct LoggingConfig {
    /// Tracing filter directive. Accepts simple levels like `info` or full
    /// `tracing_subscriber::EnvFilter` directives like
    /// `venice_e2ee_proxy=debug,tower_http=warn`.
    ///
    /// Validation rejects empty or whitespace-only values and directives that
    /// `EnvFilter` cannot parse. Defaults to `info`.
    pub level: String,
}
143
144impl Default for LoggingConfig {
145    /// Returns the default tracing filter for proxy logs.
146    fn default() -> Self {
147        Self {
148            level: "info".to_owned(),
149        }
150    }
151}
152
/// Venice upstream API client configuration.
///
/// `PartialEq`/`Eq` are not derived here, unlike the other sections; the
/// struct holds a `SecretString`.
#[derive(Debug, Clone, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct VeniceConfig {
    /// Base URL of the Venice API. Must be non-empty and start with `http://`
    /// or `https://`. Defaults to `https://api.venice.ai/api/v1`.
    pub base_url: String,
    /// Venice API key. Defaults to an empty secret; read it through
    /// `ProxyConfig::venice_api_key`, which reports a blank key as missing.
    pub api_key: SecretString,
    /// Upstream request timeout, written as a human-readable duration string
    /// such as `"30s"`. Must be greater than zero. Defaults to 30 seconds.
    #[serde(deserialize_with = "deserialize_duration")]
    pub request_timeout: Duration,
}
162
163impl Default for VeniceConfig {
164    /// Returns default Venice API endpoint and timeout settings with an empty API key.
165    fn default() -> Self {
166        Self {
167            base_url: "https://api.venice.ai/api/v1".to_owned(),
168            api_key: SecretString::default(),
169            request_timeout: Duration::from_secs(30),
170        }
171    }
172}
173
/// Proxy instance key generation configuration.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[serde(default, deny_unknown_fields)]
pub struct KeysConfig {
    /// Whether a proxy instance key is generated at startup. Defaults to `true`.
    pub generate_proxy_instance_key_on_startup: bool,
}
180
181impl Default for KeysConfig {
182    /// Returns the default key policy that generates a proxy instance key at startup.
183    fn default() -> Self {
184        Self {
185            generate_proxy_instance_key_on_startup: true,
186        }
187    }
188}
189
/// Session lifetime, reuse, and identifier-resolution configuration.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[serde(default, deny_unknown_fields)]
pub struct SessionConfig {
    /// Idle time-to-live, written as a human-readable duration string.
    /// Must be greater than zero and no larger than `max_ttl`.
    /// Defaults to 600 seconds.
    #[serde(deserialize_with = "deserialize_duration")]
    pub idle_ttl: Duration,
    /// Maximum time-to-live, written as a human-readable duration string.
    /// Must be greater than zero. Defaults to 1800 seconds.
    #[serde(deserialize_with = "deserialize_duration")]
    pub max_ttl: Duration,
    /// Per-session request budget. Must be greater than zero. Defaults to `100`.
    pub max_requests: u64,
    /// Fallback strategy used when a request does not include a session
    /// identifier. Defaults to `request`.
    pub fallback_scope: SessionFallbackScope,
    /// Header names used to resolve session identifiers.
    pub headers: SessionHeadersConfig,
}
202
203impl Default for SessionConfig {
204    /// Returns the default session TTLs, request budget, fallback behavior, and headers.
205    fn default() -> Self {
206        Self {
207            idle_ttl: Duration::from_secs(600),
208            max_ttl: Duration::from_secs(1_800),
209            max_requests: 100,
210            fallback_scope: SessionFallbackScope::Request,
211            headers: SessionHeadersConfig::default(),
212        }
213    }
214}
215
/// Header names used to resolve stable agent session identifiers.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[serde(default, deny_unknown_fields)]
pub struct SessionHeadersConfig {
    /// Name of the incoming session-id header. Must be non-empty and a valid
    /// HTTP header name. Defaults to `X-Venice-Proxy-Session-Id`.
    pub incoming_session_id: String,
}
222
223impl Default for SessionHeadersConfig {
224    /// Returns the default incoming session-id header.
225    fn default() -> Self {
226        Self {
227            incoming_session_id: "X-Venice-Proxy-Session-Id".to_owned(),
228        }
229    }
230}
231
/// Fallback strategy used when a request does not include a session identifier.
///
/// Variants deserialize from their snake_case names: `agent`, `request`,
/// and `disabled`.
#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SessionFallbackScope {
    /// NOTE(review): presumably scopes the fallback session to the calling
    /// agent; the behavior is implemented outside this module. Confirm there.
    Agent,
    /// Default variant. NOTE(review): presumably scopes the fallback session
    /// to a single request. Confirm against the session manager.
    #[default]
    Request,
    /// NOTE(review): presumably turns the fallback off. Confirm against the
    /// session manager.
    Disabled,
}
241
/// Attestation verification policy for Venice model-key evidence.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[serde(default, deny_unknown_fields)]
pub struct AttestationConfig {
    /// Attestation strategy. Defaults to `independent`, the only variant.
    pub mode: AttestationMode,
    /// Defaults to `false`. NOTE(review): presumably requires TDX evidence
    /// when set; the check lives outside this module.
    pub require_tdx: bool,
    /// Policy for how NVIDIA attestation payloads are required or ignored.
    /// Defaults to `never`.
    pub require_nvidia: NvidiaRequirement,
    /// Defaults to `false`. NOTE(review): presumably permits debug-mode
    /// evidence when set; confirm against the verifier.
    pub allow_debug: bool,
    /// PCCS URL. May be empty (the default); otherwise it must start with
    /// `http://` or `https://`.
    pub pccs_url: String,
}
252
253impl Default for AttestationConfig {
254    /// Returns the default attestation policy used when attestation requirements are not configured.
255    fn default() -> Self {
256        Self {
257            mode: AttestationMode::Independent,
258            require_tdx: false,
259            require_nvidia: NvidiaRequirement::Never,
260            allow_debug: false,
261            pccs_url: String::new(),
262        }
263    }
264}
265
/// Attestation strategy exposed in proxy metadata and config.
///
/// Variants deserialize from their snake_case names.
#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum AttestationMode {
    /// The only mode currently defined, and the default; written as
    /// `independent` in config.
    #[default]
    Independent,
}
273
274impl AttestationMode {
275    /// Returns the lowercase metadata/header value for this attestation mode.
276    pub fn as_str(self) -> &'static str {
277        match self {
278            Self::Independent => "independent",
279        }
280    }
281}
282
/// Policy for how NVIDIA attestation payloads are required or ignored.
///
/// Variants deserialize from their snake_case names: `required`,
/// `when_present`, and `never`.
#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum NvidiaRequirement {
    /// NOTE(review): presumably the payload must be present and verified;
    /// enforcement lives outside this module.
    Required,
    /// NOTE(review): presumably the payload is verified only if supplied;
    /// confirm against the verifier.
    WhenPresent,
    /// Default variant. NOTE(review): presumably the payload is ignored;
    /// confirm against the verifier.
    #[default]
    Never,
}
292
/// E2EE codec configuration for request encryption and response decryption.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[serde(default, deny_unknown_fields)]
pub struct E2eeConfig {
    /// HKDF context string. Must not be empty or whitespace-only.
    /// Defaults to `ecdsa_encryption`.
    pub hkdf_info: String,
    /// Encrypted-response policy flag. Defaults to `true`.
    /// NOTE(review): presumably rejects unencrypted response content when
    /// set; confirm against the codec.
    pub require_encrypted_response_content: bool,
}
300
301impl Default for E2eeConfig {
302    /// Returns the default HKDF context and encrypted-response policy.
303    fn default() -> Self {
304        Self {
305            hkdf_info: "ecdsa_encryption".to_owned(),
306            require_encrypted_response_content: true,
307        }
308    }
309}
310
/// Tool-call emulation configuration for OpenAI-style function calls.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[serde(default, deny_unknown_fields)]
pub struct ToolsConfig {
    /// Whether tool handling is enabled. Defaults to `true`.
    pub enabled: bool,
    /// Tool handling mode. Defaults to `emulated`.
    pub mode: ToolMode,
    /// Retry limit. Defaults to `2`.
    /// NOTE(review): what exactly is retried is decided outside this module.
    pub max_retries: u32,
    /// Size limit in bytes for a tool call. Must be greater than zero.
    /// Defaults to `65_536`.
    pub tool_call_max_bytes: usize,
    /// Tool-call marker timeout, written as a human-readable duration string.
    /// Must be greater than zero. Defaults to 30 seconds.
    #[serde(deserialize_with = "deserialize_duration")]
    pub tool_call_marker_timeout: Duration,
    /// Defaults to `true`. NOTE(review): presumably validates tool-call
    /// arguments against the tool's JSON schema; confirm against the tool layer.
    pub validate_json_schema: bool,
}
323
324impl Default for ToolsConfig {
325    /// Returns the default tool-emulation limits and retry policy.
326    fn default() -> Self {
327        Self {
328            enabled: true,
329            mode: ToolMode::Emulated,
330            max_retries: 2,
331            tool_call_max_bytes: 65_536,
332            tool_call_marker_timeout: Duration::from_secs(30),
333            validate_json_schema: true,
334        }
335    }
336}
337
/// Tool handling mode used by the proxy.
///
/// Variants deserialize from their snake_case names: `emulated` and `none`.
#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ToolMode {
    /// Default mode; written as `emulated` in config.
    #[default]
    Emulated,
    /// Written as `none` in config. This is a `ToolMode` variant, not
    /// `Option::None`.
    None,
}
346
347impl ToolMode {
348    /// Returns the lowercase metadata/header value for this tool mode.
349    pub fn as_str(self) -> &'static str {
350        match self {
351            Self::Emulated => "emulated",
352            Self::None => "none",
353        }
354    }
355}
356
/// Errors returned while loading or validating proxy configuration.
#[derive(Debug, Error)]
pub enum ConfigError {
    /// Figment could not load or deserialize the configuration. The error is
    /// boxed and exposed as this variant's source.
    #[error("failed to load config: {0}")]
    Figment(#[source] Box<figment::Error>),
    /// The Venice API key is empty or whitespace-only; returned by
    /// `ProxyConfig::venice_api_key`.
    #[error("Venice API key is not configured")]
    MissingApiKey,
    /// A configuration value failed validation.
    #[error("invalid config value for {field}: {message}")]
    InvalidValue {
        /// Dotted path of the offending field, e.g. `venice.base_url`.
        field: &'static str,
        /// Human-readable description of the violated rule.
        message: String,
    },
}
370
371impl From<figment::Error> for ConfigError {
372    /// Converts Figment extraction failures into configuration errors.
373    fn from(error: figment::Error) -> Self {
374        Self::Figment(Box::new(error))
375    }
376}
377
378impl ConfigError {
379    /// Creates an invalid-value error for a named configuration field.
380    fn invalid(field: &'static str, message: impl Into<String>) -> Self {
381        Self::InvalidValue {
382            field,
383            message: message.into(),
384        }
385    }
386}
387
388/// Validates that a string configuration field contains non-whitespace text.
389fn validate_non_empty(field: &'static str, value: &str) -> Result<(), ConfigError> {
390    if value.trim().is_empty() {
391        return Err(ConfigError::invalid(field, "must not be empty"));
392    }
393    Ok(())
394}
395
396/// Validates that a string configuration field is an HTTP(S) URL, optionally allowing empty values.
397fn validate_http_url(
398    field: &'static str,
399    value: &str,
400    allow_empty: bool,
401) -> Result<(), ConfigError> {
402    let value = value.trim();
403    if value.is_empty() {
404        if allow_empty {
405            return Ok(());
406        }
407        return Err(ConfigError::invalid(field, "must not be empty"));
408    }
409
410    if !(value.starts_with("https://") || value.starts_with("http://")) {
411        return Err(ConfigError::invalid(
412            field,
413            "must start with http:// or https://",
414        ));
415    }
416
417    Ok(())
418}
419
420/// Validates that a string configuration field can be used as an HTTP header name.
421fn validate_header_name(field: &'static str, value: &str) -> Result<(), ConfigError> {
422    validate_non_empty(field, value)?;
423    HeaderName::from_bytes(value.as_bytes())
424        .map_err(|_| ConfigError::invalid(field, "must be a valid HTTP header name"))?;
425    Ok(())
426}
427
428/// Validates that a duration configuration field is greater than zero.
429fn validate_duration_non_zero(field: &'static str, value: Duration) -> Result<(), ConfigError> {
430    if value == Duration::ZERO {
431        return Err(ConfigError::invalid(field, "must be greater than zero"));
432    }
433    Ok(())
434}
435
436/// Deserializes human-readable duration strings into [`Duration`] values.
437fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
438where
439    D: Deserializer<'de>,
440{
441    let value = String::deserialize(deserializer)?;
442    humantime::parse_duration(&value).map_err(de::Error::custom)
443}
444
445/// Validates that a string configuration field is accepted by `tracing_subscriber::EnvFilter`.
446fn validate_env_filter(field: &'static str, value: &str) -> Result<(), ConfigError> {
447    let value = value.trim();
448    if value.is_empty() {
449        return Ok(());
450    }
451
452    EnvFilter::try_new(value).map_err(|source| {
453        ConfigError::invalid(
454            field,
455            format!("must be a valid tracing env filter: {source}"),
456        )
457    })?;
458    Ok(())
459}
460
/// Unit tests for configuration defaults, TOML loading, validation, and
/// API-key redaction.
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every field of `config` equals the documented default and
    /// that the resulting configuration passes validation.
    fn assert_default_config_values(config: &ProxyConfig) {
        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.server.port, 8080);
        assert_eq!(config.logging.level, "info");
        assert_eq!(config.venice.base_url, "https://api.venice.ai/api/v1");
        assert_eq!(config.venice.api_key.expose_secret(), "");
        assert_eq!(config.venice.request_timeout, Duration::from_secs(30));
        assert!(config.keys.generate_proxy_instance_key_on_startup);
        assert_eq!(config.session.idle_ttl, Duration::from_secs(600));
        assert_eq!(config.session.max_ttl, Duration::from_secs(1_800));
        assert_eq!(config.session.max_requests, 100);
        assert_eq!(config.session.fallback_scope, SessionFallbackScope::Request);
        assert_eq!(
            config.session.headers.incoming_session_id,
            "X-Venice-Proxy-Session-Id"
        );
        assert_eq!(config.attestation.mode, AttestationMode::Independent);
        assert!(!config.attestation.require_tdx);
        assert_eq!(config.attestation.require_nvidia, NvidiaRequirement::Never);
        assert!(!config.attestation.allow_debug);
        assert_eq!(config.attestation.pccs_url, "");
        assert_eq!(config.e2ee.hkdf_info, "ecdsa_encryption");
        assert!(config.e2ee.require_encrypted_response_content);
        assert!(config.tools.enabled);
        assert_eq!(config.tools.mode, ToolMode::Emulated);
        assert_eq!(config.tools.max_retries, 2);
        assert_eq!(config.tools.tool_call_max_bytes, 65_536);
        assert_eq!(
            config.tools.tool_call_marker_timeout,
            Duration::from_secs(30)
        );
        assert!(config.tools.validate_json_schema);

        config.validate().expect("default config is valid");
    }

    /// The in-code `Default` implementations produce the expected values.
    #[test]
    fn default_config_matches_expected_values() {
        let config = ProxyConfig::default();

        assert_default_config_values(&config);
    }

    /// The checked-in `config/default.toml` stays in sync with the code defaults.
    #[test]
    fn checked_in_default_config_matches_code_defaults() {
        let config = ProxyConfig::from_toml_str(include_str!("../config/default.toml"))
            .expect("checked-in default config should load");

        assert_default_config_values(&config);
    }

    /// Sections and fields missing from the TOML fall back to their defaults.
    #[test]
    fn toml_config_applies_defaults_for_missing_sections() {
        let config = ProxyConfig::from_toml_str(
            r#"
            [server]
            host = "0.0.0.0"
            port = 8080

            [tools]
            enabled = false
            mode = "none"
            "#,
        )
        .expect("partial config should load with defaults");

        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.server.port, 8080);
        assert_eq!(config.logging.level, "info");
        assert_eq!(config.venice.api_key.expose_secret(), "");
        assert_eq!(config.venice.request_timeout, Duration::from_secs(30));
        assert!(!config.tools.enabled);
        assert_eq!(config.tools.mode, ToolMode::None);
        assert_eq!(config.tools.tool_call_max_bytes, 65_536);
    }

    /// A base URL without an HTTP(S) scheme and a zero request timeout are
    /// both reported as `InvalidValue` for the corresponding field.
    #[test]
    fn validation_rejects_invalid_values() {
        let err = ProxyConfig::from_toml_str(
            r#"
            [venice]
            base_url = "not-valid-url"
            "#,
        )
        .expect_err("invalid base URL should be rejected");

        assert!(matches!(
            err,
            ConfigError::InvalidValue {
                field: "venice.base_url",
                ..
            }
        ));

        let err = ProxyConfig::from_toml_str(
            r#"
            [venice]
            request_timeout = "0s"
            "#,
        )
        .expect_err("zero Venice timeout should be rejected");

        assert!(matches!(
            err,
            ConfigError::InvalidValue {
                field: "venice.request_timeout",
                ..
            }
        ));
    }

    /// `logging.level` accepts a bare level or a full env-filter directive,
    /// and rejects blank values and malformed directives.
    #[test]
    fn logging_config_accepts_level_or_env_filter_and_rejects_invalid_filters() {
        // A bare level name.
        let config = ProxyConfig::from_toml_str(
            r#"
            [logging]
            level = "debug"
            "#,
        )
        .expect("logging level config should load");
        assert_eq!(config.logging.level, "debug");

        // A per-target directive list.
        let config = ProxyConfig::from_toml_str(
            r#"
            [logging]
            level = "venice_e2ee_proxy=debug,tower_http=warn"
            "#,
        )
        .expect("logging env filter config should load");
        assert_eq!(
            config.logging.level,
            "venice_e2ee_proxy=debug,tower_http=warn"
        );

        // Empty and whitespace-only values; `{level:?}` emits the value as a
        // quoted string so the generated TOML stays well-formed.
        for level in ["", "   "] {
            let err = ProxyConfig::from_toml_str(&format!(
                r#"
                [logging]
                level = {level:?}
                "#
            ))
            .expect_err("empty logging level should be rejected");
            assert!(matches!(
                err,
                ConfigError::InvalidValue {
                    field: "logging.level",
                    ..
                }
            ));
        }

        // A directive that `EnvFilter` cannot parse.
        let err = ProxyConfig::from_toml_str(
            r#"
            [logging]
            level = "venice_e2ee_proxy=[debug"
            "#,
        )
        .expect_err("invalid tracing env filter should be rejected");
        assert!(matches!(
            err,
            ConfigError::InvalidValue {
                field: "logging.level",
                ..
            }
        ));
    }

    /// Keys that are no longer part of `ToolsConfig` fail deserialization
    /// because of `deny_unknown_fields`, surfacing as a Figment error.
    #[test]
    fn removed_tool_marker_options_are_rejected_as_unknown_fields() {
        let err = ProxyConfig::from_toml_str("[tools]\nmarker_start = \"<tool_call>\"\n")
            .expect_err("removed tools.marker_start option should be rejected");
        assert!(matches!(err, ConfigError::Figment(_)));
    }

    /// The default (empty) API key is reported as missing with a stable message.
    #[test]
    fn missing_api_key_is_reported() {
        let config = ProxyConfig::default();
        let err = config
            .venice_api_key()
            .expect_err("missing API key should be reported");

        assert!(matches!(err, ConfigError::MissingApiKey));
        assert_eq!(err.to_string(), "Venice API key is not configured");
    }

    /// The API key is readable through `expose_secret` but never appears in
    /// `Debug` output of the key or of the whole configuration.
    #[test]
    fn api_key_debug_output_is_redacted() {
        let config = ProxyConfig::from_toml_str(
            r#"
            [venice]
            api_key = "super-secret-test-key"
            "#,
        )
        .expect("config should load");
        let key = config.venice_api_key().expect("test key should load");

        assert_eq!(key.expose_secret(), "super-secret-test-key");
        assert!(!format!("{key:?}").contains("super-secret-test-key"));
        assert!(!format!("{config:?}").contains("super-secret-test-key"));
    }
}