Skip to main content

zagens_runtime_adapters/tools/
network_gate.rs

1//! Shared outbound network policy + SSRF helpers for model-visible tools (D16 E1-a6).
2
3use std::net::IpAddr;
4
5use crate::network_policy::{Decision, NetworkPolicy, NetworkPolicyDecider, host_from_url};
6
/// Reason an outbound network call was refused by the policy gate.
///
/// Both variants record the target `host` and the name of the `tool` that
/// attempted the call.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum NetworkGateError {
    /// Policy evaluation returned a `Deny` decision for `host`.
    Denied {
        host: String,
        tool: String,
    },
    /// Policy evaluation returned a `Prompt` decision for `host`; the call
    /// needs approval before it can proceed.
    PromptRequired {
        host: String,
        tool: String,
    },
}
13
14impl NetworkGateError {
15    #[must_use]
16    pub fn denial_message(&self) -> String {
17        match self {
18            Self::Denied { host, .. } => {
19                format!("network call to '{host}' blocked by network policy")
20            }
21            Self::PromptRequired { host, .. } => format!(
22                "network call to '{host}' requires approval; \
23                 re-run after `/network allow {host}` or set network.default = \"allow\" in config"
24            ),
25        }
26    }
27}
28
29/// Evaluate per-domain policy for a known host. No-op when `decider` is `None`.
30pub fn check_host_policy(
31    decider: Option<&NetworkPolicyDecider>,
32    tool_name: &str,
33    host: &str,
34) -> Result<(), NetworkGateError> {
35    let Some(decider) = decider else {
36        return Ok(());
37    };
38    match decider.evaluate(host, tool_name) {
39        Decision::Allow => Ok(()),
40        Decision::Deny => Err(NetworkGateError::Denied {
41            host: host.to_string(),
42            tool: tool_name.to_string(),
43        }),
44        Decision::Prompt => Err(NetworkGateError::PromptRequired {
45            host: host.to_string(),
46            tool: tool_name.to_string(),
47        }),
48    }
49}
50
51/// Extract host from `url` and evaluate policy. Returns the host when present.
52pub fn check_url_policy(
53    decider: Option<&NetworkPolicyDecider>,
54    tool_name: &str,
55    url: &str,
56) -> Result<Option<String>, NetworkGateError> {
57    let Some(host) = host_from_url(url) else {
58        return Ok(None);
59    };
60    check_host_policy(decider, tool_name, &host)?;
61    Ok(Some(host))
62}
63
64/// Evaluate a static [`NetworkPolicy`] (no session cache) for a known host.
65pub fn check_host_with_policy(
66    policy: &NetworkPolicy,
67    tool_name: &str,
68    host: &str,
69) -> Result<(), NetworkGateError> {
70    match policy.decide(host) {
71        Decision::Allow => Ok(()),
72        Decision::Deny => Err(NetworkGateError::Denied {
73            host: host.to_string(),
74            tool: tool_name.to_string(),
75        }),
76        Decision::Prompt => Err(NetworkGateError::PromptRequired {
77            host: host.to_string(),
78            tool: tool_name.to_string(),
79        }),
80    }
81}
82
83/// Decision for `host` against a static policy (no session cache).
84#[must_use]
85pub fn host_policy_decision(policy: &NetworkPolicy, host: &str) -> Decision {
86    policy.decide(host)
87}
88
/// True when `url` uses http/https (case-insensitive on scheme).
///
/// Surrounding whitespace is ignored. URL schemes are case-insensitive
/// (RFC 3986 §3.1), so `HTTPS://example.com` is accepted just like
/// `https://example.com`. Only the scheme and `//` separator are checked; the
/// rest of the URL is not validated.
#[must_use]
pub fn is_http_url(url: &str) -> bool {
    let trimmed = url.trim();
    ["http://", "https://"].iter().any(|&prefix| {
        // `get` returns `None` when the input is too short or the cut would
        // split a multi-byte character, so non-ASCII input cannot panic here.
        matches!(trimmed.get(..prefix.len()), Some(head) if head.eq_ignore_ascii_case(prefix))
    })
}
95
/// Check if an IP address is loopback, private, link-local, cloud-metadata,
/// multicast, or reserved — SSRF prevention for LLM-initiated fetches.
///
/// IPv4 ranges rejected: loopback, RFC 1918 private, link-local (which also
/// covers the `169.254.169.254` metadata endpoint), multicast, broadcast,
/// `0.0.0.0/8`, carrier-grade NAT `100.64.0.0/10`, benchmarking
/// `198.18.0.0/15`, and the reserved `240.0.0.0/4` block.
///
/// IPv6: `::`, loopback, multicast, unique-local `fc00::/7`, link-local
/// `fe80::/10`, and deprecated site-local `fec0::/10` are rejected. Every
/// IPv4-mapped address (`::ffff:0:0/96`) is rejected regardless of the
/// embedded IPv4 address. IPv4-compatible (`::/96`) and NAT64 well-known-prefix
/// (`64:ff9b::/96`) addresses are judged by their embedded IPv4 address.
#[must_use]
pub fn is_restricted_ip(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            let octets = v4.octets();
            v4.is_loopback()
                || v4.is_private()
                || v4.is_link_local()
                || v4.is_multicast()
                || v4.is_broadcast()
                || v4.is_unspecified()
                // 0.0.0.0/8 "this network": never a legitimate destination, and
                // 0.0.0.0 itself reaches the local host on common stacks.
                || octets[0] == 0
                // 100.64.0.0/10 carrier-grade NAT shared address space.
                || matches!(octets, [100, 64..=127, ..])
                // Cloud metadata endpoint (explicit, although link-local covers it).
                || *v4 == std::net::Ipv4Addr::new(169, 254, 169, 254)
                // 198.18.0.0/15 benchmarking.
                || matches!(octets, [198, 18..=19, ..])
                // 240.0.0.0/4 reserved (includes the broadcast address).
                || octets[0] >= 240
        }
        IpAddr::V6(v6) => {
            let octets = v6.octets();
            // Reject `::` and the whole IPv4-mapped range outright (fail closed)
            // rather than trusting the embedded IPv4 address.
            if v6.is_unspecified()
                || matches!(octets, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, ..])
            {
                return true;
            }
            // IPv4-compatible (`::/96`, deprecated but still parseable, e.g.
            // `::127.0.0.1`) and NAT64 well-known-prefix (`64:ff9b::/96`, which a
            // NAT64 gateway forwards to the embedded IPv4) addresses carry an IPv4
            // address in the low 32 bits; judge them by that address. `::1` has
            // the IPv4-compatible prefix but is IPv6 loopback, handled below.
            let ipv4_compatible = octets[..12] == [0u8; 12] && !v6.is_loopback();
            let nat64 = matches!(v6.segments(), [0x64, 0xff9b, 0, 0, 0, 0, _, _]);
            if ipv4_compatible || nat64 {
                let embedded =
                    std::net::Ipv4Addr::new(octets[12], octets[13], octets[14], octets[15]);
                return is_restricted_ip(&IpAddr::V4(embedded));
            }
            let first = v6.segments()[0];
            v6.is_loopback()
                || v6.is_multicast()
                // fc00::/7 unique local addresses.
                || matches!(first, 0xfc00..=0xfdff)
                // fe80::/10 link-local.
                || matches!(first, 0xfe80..=0xfebf)
                // fec0::/10 site-local (deprecated, but may still route privately).
                || matches!(first, 0xfec0..=0xfeff)
        }
    }
}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::network_policy::{Decision, NetworkPolicy, NetworkPolicyDecider};

    /// Parse an address literal; every literal in this module is well-formed.
    fn ip(text: &str) -> std::net::IpAddr {
        text.parse().unwrap()
    }

    /// Assert that every address in `addrs` is classified as restricted.
    fn assert_restricted(addrs: &[&str]) {
        for addr in addrs {
            assert!(is_restricted_ip(&ip(addr)), "{addr} should be restricted");
        }
    }

    /// Assert that every address in `addrs` is classified as unrestricted.
    fn assert_unrestricted(addrs: &[&str]) {
        for addr in addrs {
            assert!(!is_restricted_ip(&ip(addr)), "{addr} should be allowed");
        }
    }

    /// Allow-by-default policy whose deny list contains only `example.com`.
    fn policy_denying_example_com() -> NetworkPolicy {
        NetworkPolicy {
            default: Decision::Allow.into(),
            allow: vec![],
            deny: vec!["example.com".into()],
            audit: false,
        }
    }

    #[test]
    fn restricted_ip_detects_loopback() {
        assert_restricted(&["127.0.0.1", "::1"]);
    }

    #[test]
    fn restricted_ip_detects_private_ranges() {
        assert_restricted(&["10.0.0.1", "172.16.0.1", "192.168.1.1"]);
    }

    #[test]
    fn restricted_ip_detects_metadata_and_cgnat() {
        assert_restricted(&["169.254.169.254", "100.64.0.1"]);
        assert_unrestricted(&["100.63.0.1"]);
    }

    #[test]
    fn restricted_ip_detects_link_local() {
        assert_restricted(&["169.254.1.1"]);
    }

    #[test]
    fn restricted_ip_detects_ipv6_ula() {
        assert_restricted(&["fc00::1", "fd12:3456::1"]);
    }

    #[test]
    fn restricted_ip_detects_unspecified() {
        assert_restricted(&["::"]);
    }

    #[test]
    fn restricted_ip_allows_public() {
        assert_unrestricted(&["1.1.1.1", "93.184.216.34", "2606:4700::1"]);
    }

    #[test]
    fn restricted_ip_detects_ipv4_mapped_private() {
        assert_restricted(&["::ffff:10.0.0.1", "::ffff:169.254.169.254"]);
    }

    #[test]
    fn restricted_ip_detects_ipv4_compatible_private() {
        assert_restricted(&["::127.0.0.1", "::10.0.0.1", "::169.254.169.254"]);
    }

    #[test]
    fn check_url_policy_denies_blocked_host() {
        let decider = NetworkPolicyDecider::with_default_audit(policy_denying_example_com());
        let outcome =
            check_url_policy(Some(&decider), "fetch_url", "https://example.com/private");
        assert!(matches!(outcome, Err(NetworkGateError::Denied { .. })));
    }

    #[test]
    fn check_host_with_policy_denies_blocked_host() {
        let outcome =
            check_host_with_policy(&policy_denying_example_com(), "skills_install", "example.com");
        assert!(matches!(outcome, Err(NetworkGateError::Denied { .. })));
    }

    #[test]
    fn check_host_policy_allows_when_decider_missing() {
        assert_eq!(check_host_policy(None, "fetch_url", "example.com"), Ok(()));
    }
}