Skip to main content

tirith_core/
policy_client.rs

1use std::fmt;
2use std::time::Duration;
3
/// Errors that can occur when fetching remote policy.
///
/// The type derives `Clone`, `PartialEq` and `Eq` in addition to `Debug` so
/// callers and tests can compare results directly (for example with
/// `assert_eq!`) and keep a copy of an error without re-creating it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolicyFetchError {
    /// Network-level error (DNS, connection refused, timeout, etc.).
    ///
    /// The payload is a human-readable description of the failure.
    NetworkError(String),
    /// Authentication failure (401/403). Always treated as fatal.
    ///
    /// The payload is the HTTP status code that was returned.
    AuthError(u16),
    /// Server returned an error status code.
    ///
    /// The payload is a human-readable description of the failure.
    ServerError(String),
    /// Response body could not be read or is not valid YAML.
    ///
    /// The payload is a human-readable description of the failure.
    InvalidResponse(String),
}
16
17impl fmt::Display for PolicyFetchError {
18    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19        match self {
20            PolicyFetchError::NetworkError(msg) => write!(f, "network error: {msg}"),
21            PolicyFetchError::AuthError(code) => write!(f, "authentication failed (HTTP {code})"),
22            PolicyFetchError::ServerError(msg) => write!(f, "server error: {msg}"),
23            PolicyFetchError::InvalidResponse(msg) => write!(f, "invalid response: {msg}"),
24        }
25    }
26}
27
28/// Fetch remote policy YAML from the policy server.
29///
30/// Uses 5s connect timeout and 10s total timeout. The server endpoint
31/// is `{url}/api/policy/fetch` and requires Bearer token authentication.
32pub fn fetch_remote_policy(url: &str, api_key: &str) -> Result<String, PolicyFetchError> {
33    // SSRF protection: validate the URL before connecting
34    if let Err(reason) = crate::url_validate::validate_server_url(url) {
35        return Err(PolicyFetchError::NetworkError(reason));
36    }
37
38    let client = reqwest::blocking::Client::builder()
39        .connect_timeout(Duration::from_secs(5))
40        .timeout(Duration::from_secs(10))
41        .build()
42        .map_err(|e| PolicyFetchError::NetworkError(e.to_string()))?;
43
44    let endpoint = format!("{}/api/policy/fetch", url.trim_end_matches('/'));
45    let resp = client
46        .get(&endpoint)
47        .header("Authorization", format!("Bearer {api_key}"))
48        .send()
49        .map_err(|e| PolicyFetchError::NetworkError(e.to_string()))?;
50
51    match resp.status().as_u16() {
52        200 => resp
53            .text()
54            .map_err(|e| PolicyFetchError::InvalidResponse(e.to_string())),
55        401 | 403 => Err(PolicyFetchError::AuthError(resp.status().as_u16())),
56        404 => Err(PolicyFetchError::ServerError(
57            "no active policy found".into(),
58        )),
59        s => Err(PolicyFetchError::ServerError(format!(
60            "server returned HTTP {s}"
61        ))),
62    }
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant renders as its fixed prefix followed by the payload.
    #[test]
    fn test_policy_fetch_error_display() {
        let expectations = [
            (
                PolicyFetchError::NetworkError("timeout".into()),
                "network error: timeout",
            ),
            (
                PolicyFetchError::AuthError(401),
                "authentication failed (HTTP 401)",
            ),
            (
                PolicyFetchError::ServerError("internal error".into()),
                "server error: internal error",
            ),
            (
                PolicyFetchError::InvalidResponse("bad body".into()),
                "invalid response: bad body",
            ),
        ];

        for (error, rendered) in expectations {
            assert_eq!(format!("{error}"), rendered);
        }
    }

    /// A fetch against an address that is not expected to answer must fail
    /// with `NetworkError`. Both a URL-validation rejection and a failed
    /// connection attempt are reported through that variant.
    #[test]
    fn test_fetch_invalid_url_returns_network_error() {
        let outcome = fetch_remote_policy("http://192.0.2.1:1", "test-key");
        assert!(outcome.is_err());

        if let Err(other) = outcome {
            if !matches!(other, PolicyFetchError::NetworkError(_)) {
                panic!("expected NetworkError, got: {other}");
            }
        }
    }
}