Skip to main content

shuru_proxy/
config.rs

1use std::collections::HashMap;
2
3/// Configuration for the proxy engine.
4#[derive(Debug, Clone, Default)]
5pub struct ProxyConfig {
6    /// Secrets to inject. Key is the env var name visible to the guest.
7    /// The guest gets a random placeholder token; the proxy substitutes
8    /// the real value only when the request targets an allowed host.
9    pub secrets: HashMap<String, SecretConfig>,
10    /// Network access rules.
11    pub network: NetworkConfig,
12}
13
/// A secret that the proxy injects into HTTP requests.
///
/// `Debug` is implemented by hand so that an inline [`value`](Self::value)
/// is never written to logs or panic messages.
#[derive(Clone)]
pub struct SecretConfig {
    /// Host environment variable to read the real value from.
    pub from: String,
    /// Domain patterns where this secret may be sent (e.g., "api.openai.com").
    /// The proxy only substitutes the placeholder on requests to these hosts.
    pub hosts: Vec<String>,
    /// If set, use this value directly instead of reading from the host env var.
    pub value: Option<String>,
}

impl std::fmt::Debug for SecretConfig {
    /// Formats the config like the derived `Debug` would, except that the
    /// contents of `value` are replaced by a fixed marker. Whether a value
    /// is present (`Some`/`None`) is still visible for diagnostics.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("SecretConfig")
            .field("from", &self.from)
            .field("hosts", &self.hosts)
            // Never print the real secret; only reveal whether one is set.
            .field("value", &self.value.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
25
/// Policy describing which domains may be reached.
#[derive(Clone, Debug, Default)]
pub struct NetworkConfig {
    /// Domain patterns that are permitted; an empty list permits everything.
    /// Entries may be exact names ("registry.npmjs.org") or use a leading
    /// wildcard label ("*.openai.com").
    pub allow: Vec<String>,
}
33
34impl ProxyConfig {
35    /// Check if a domain is allowed by the network policy.
36    /// Empty allowlist means all domains are allowed.
37    pub fn is_domain_allowed(&self, domain: &str) -> bool {
38        if self.network.allow.is_empty() {
39            return true;
40        }
41        self.network
42            .allow
43            .iter()
44            .any(|pattern| domain_matches(pattern, domain))
45    }
46
47    /// Get all secret placeholder→real value mappings for a given domain.
48    pub fn secrets_for_domain(
49        &self,
50        domain: &str,
51        placeholders: &HashMap<String, String>,
52    ) -> Vec<(String, String)> {
53        let mut substitutions = Vec::new();
54        for (name, secret) in &self.secrets {
55            if secret
56                .hosts
57                .iter()
58                .any(|pattern| domain_matches(pattern, domain))
59            {
60                if let Some(placeholder) = placeholders.get(name) {
61                    let real_value = secret
62                        .value
63                        .clone()
64                        .or_else(|| std::env::var(&secret.from).ok());
65                    if let Some(real_value) = real_value {
66                        substitutions.push((placeholder.clone(), real_value));
67                    }
68                }
69            }
70        }
71        substitutions
72    }
73}
74
/// Simple wildcard domain matching, ignoring ASCII case.
///
/// "*.example.com" matches "api.example.com" (and deeper subdomains such as
/// "deep.api.example.com") but not "example.com" itself.
/// "example.com" matches exactly "example.com".
///
/// DNS names are case-insensitive, so "API.Example.COM" is treated the same
/// as "api.example.com" for both exact and wildcard patterns.
fn domain_matches(pattern: &str, domain: &str) -> bool {
    match pattern.strip_prefix("*.") {
        Some(suffix) => {
            let domain = domain.as_bytes();
            let suffix = suffix.as_bytes();
            // The domain must be strictly longer than the suffix so there is
            // room for the '.' that separates the subdomain from the suffix.
            if domain.len() <= suffix.len() {
                return false;
            }
            let (head, tail) = domain.split_at(domain.len() - suffix.len());
            // Requiring the separator stops "notexample.com" from matching
            // "*.example.com".
            head.last() == Some(&b'.') && tail.eq_ignore_ascii_case(suffix)
        }
        None => pattern.eq_ignore_ascii_case(domain),
    }
}
85
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks exact patterns and leading-wildcard patterns against matching
    /// and non-matching domains.
    #[test]
    fn test_domain_matching() {
        // (pattern, domain, expected result)
        let cases = [
            ("example.com", "example.com", true),
            ("example.com", "api.example.com", false),
            ("*.example.com", "api.example.com", true),
            ("*.example.com", "deep.api.example.com", true),
            ("*.example.com", "example.com", false),
            ("*.example.com", "notexample.com", false),
        ];
        for &(pattern, domain, expected) in &cases {
            assert_eq!(
                domain_matches(pattern, domain),
                expected,
                "pattern {:?} against domain {:?}",
                pattern,
                domain
            );
        }
    }
}