// synaptic_middleware/ssrf_guard.rs
use std::collections::HashSet;
2use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
3
4use async_trait::async_trait;
5use serde_json::Value;
6use synaptic_core::SynapticError;
7
8use crate::{AgentMiddleware, ToolCallRequest, ToolCaller};
9
/// Policy settings consumed by [`SsrfGuardMiddleware`].
#[derive(Debug, Clone)]
pub struct SsrfGuardConfig {
    /// When `true`, URLs whose host is a private/loopback/link-local IP or a
    /// well-known internal hostname are rejected.
    pub block_private: bool,
    /// Hosts that are rejected unless they also appear in `allowlist`
    /// (the allowlist is consulted first).
    pub blocklist: HashSet<String>,
    /// Hosts that are accepted without any further checks.
    pub allowlist: HashSet<String>,
    /// JSON object keys whose string values are always validated as URLs,
    /// whatever the value looks like.
    pub url_keys: Vec<String>,
}
22
23impl Default for SsrfGuardConfig {
24 fn default() -> Self {
25 Self {
26 block_private: true,
27 blocklist: HashSet::new(),
28 allowlist: HashSet::new(),
29 url_keys: vec![
30 "url".to_string(),
31 "uri".to_string(),
32 "endpoint".to_string(),
33 "base_url".to_string(),
34 "webhook_url".to_string(),
35 ],
36 }
37 }
38}
39
40pub struct SsrfGuardMiddleware {
45 config: SsrfGuardConfig,
46}
47
48impl SsrfGuardMiddleware {
49 pub fn new(config: SsrfGuardConfig) -> Self {
50 Self { config }
51 }
52
53 fn check_url(&self, url: &str) -> Result<(), String> {
55 let host = extract_host(url).ok_or_else(|| format!("invalid URL: {}", url))?;
57
58 if self.config.allowlist.contains(&host) {
60 return Ok(());
61 }
62
63 if self.config.blocklist.contains(&host) {
65 return Err(format!("host '{}' is blocklisted", host));
66 }
67
68 if self.config.block_private {
70 if let Ok(ip) = host.parse::<IpAddr>() {
71 if is_private_ip(&ip) {
72 return Err(format!(
73 "access to private/loopback address {} is blocked",
74 ip
75 ));
76 }
77 }
78
79 let lower = host.to_lowercase();
81 if lower == "localhost"
82 || lower == "0.0.0.0"
83 || lower.ends_with(".local")
84 || lower.ends_with(".internal")
85 || lower == "metadata.google.internal"
86 || lower == "169.254.169.254"
87 {
89 return Err(format!("access to private host '{}' is blocked", host));
90 }
91 }
92
93 Ok(())
94 }
95
96 fn scan_args(&self, args: &Value) -> Result<(), String> {
98 match args {
99 Value::Object(map) => {
100 for (key, value) in map {
101 if self.config.url_keys.iter().any(|k| k == key) {
102 if let Some(url) = value.as_str() {
103 self.check_url(url)?;
104 }
105 }
106 self.scan_args(value)?;
108 }
109 }
110 Value::Array(arr) => {
111 for item in arr {
112 self.scan_args(item)?;
113 }
114 }
115 Value::String(s) => {
116 if (s.starts_with("http://") || s.starts_with("https://")) && s.len() < 2048 {
118 self.check_url(s)?;
119 }
120 }
121 _ => {}
122 }
123 Ok(())
124 }
125}
126
127#[async_trait]
128impl AgentMiddleware for SsrfGuardMiddleware {
129 async fn wrap_tool_call(
130 &self,
131 request: ToolCallRequest,
132 next: &dyn ToolCaller,
133 ) -> Result<Value, SynapticError> {
134 if let Err(reason) = self.scan_args(&request.call.arguments) {
136 return Err(SynapticError::Security(format!(
137 "SSRF blocked: {} (tool: {})",
138 reason, request.call.name
139 )));
140 }
141
142 next.call(request).await
143 }
144}
145
/// Extracts the host component from an `http://` or `https://` URL.
///
/// The scheme is matched case-insensitively. The authority ends at the first
/// `/`, `?` or `#`; any `userinfo@` prefix and `:port` suffix are removed, and
/// a bracketed IPv6 literal is returned without its brackets so that it can be
/// parsed as an IP address.
///
/// Returns `None` when the scheme is not HTTP(S), the host is empty, an IPv6
/// literal is malformed, or the authority contains a backslash.
fn extract_host(url: &str) -> Option<String> {
    let rest = ["https://", "http://"].iter().find_map(|scheme| {
        let head = url.get(..scheme.len())?;
        if head.eq_ignore_ascii_case(scheme) {
            Some(&url[scheme.len()..])
        } else {
            None
        }
    })?;

    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()?;

    // URL parsers disagree on whether '\' terminates the authority, so the
    // host a client would actually contact is ambiguous; refuse to guess.
    if authority.contains('\\') {
        return None;
    }

    // Everything up to the last '@' is userinfo, not the host
    // ("http://public.example@127.0.0.1/" connects to 127.0.0.1).
    let host_port = authority.rsplit('@').next()?;

    let host = match host_port.strip_prefix('[') {
        Some(bracketed) => {
            // IPv6 literal: "[addr]" optionally followed by ":port".
            let (addr, tail) = bracketed.split_once(']')?;
            if !(tail.is_empty() || tail.starts_with(':')) {
                return None;
            }
            addr
        }
        None => host_port.split(':').next()?,
    };

    if host.is_empty() {
        None
    } else {
        Some(host.to_string())
    }
}
160
/// Reports whether `ip` must not be reached from a tool call: loopback,
/// private, link-local, CGNAT, broadcast or unspecified IPv4 addresses, and
/// loopback, unspecified, unique-local or link-local IPv6 addresses.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are classified by the IPv4
/// address they embed, so `::ffff:127.0.0.1` is treated like `127.0.0.1`.
fn is_private_ip(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_private_v4(v4),
        IpAddr::V6(v6) => match v4_mapped(v6) {
            Some(v4) => is_private_v4(&v4),
            None => v6.is_loopback() || v6.is_unspecified() || is_v6_private(v6),
        },
    }
}

/// Reports whether an IPv4 address falls in any non-public range that the
/// guard blocks.
fn is_private_v4(ip: &Ipv4Addr) -> bool {
    ip.is_loopback()
        || ip.is_private()
        || ip.is_link_local()
        || is_cgnat(ip)
        || ip.is_broadcast()
        || ip.is_unspecified()
}

/// Returns the embedded IPv4 address when `ip` is IPv4-mapped
/// (`::ffff:0:0/96`), and `None` otherwise.
fn v4_mapped(ip: &Ipv6Addr) -> Option<Ipv4Addr> {
    match ip.segments() {
        [0, 0, 0, 0, 0, 0xFFFF, hi, lo] => {
            let [a, b] = hi.to_be_bytes();
            let [c, d] = lo.to_be_bytes();
            Some(Ipv4Addr::new(a, b, c, d))
        }
        _ => None,
    }
}

/// Reports whether `ip` lies in the carrier-grade NAT range `100.64.0.0/10`.
fn is_cgnat(ip: &Ipv4Addr) -> bool {
    let octets = ip.octets();
    // A /10 prefix fixes the first octet and the top two bits of the second.
    octets[0] == 100 && (octets[1] & 0xC0) == 64
}

/// Reports whether `ip` is unique-local (`fc00::/7`) or link-local
/// (`fe80::/10`).
fn is_v6_private(ip: &Ipv6Addr) -> bool {
    let first = ip.segments()[0];
    (first & 0xFE00) == 0xFC00 || (first & 0xFFC0) == 0xFE80
}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a guard with the default (private-blocking) configuration.
    fn default_guard() -> SsrfGuardMiddleware {
        SsrfGuardMiddleware::new(SsrfGuardConfig::default())
    }

    #[test]
    fn blocks_localhost() {
        let guard = default_guard();
        assert!(guard.check_url("http://localhost/api").is_err());
        assert!(guard.check_url("http://127.0.0.1/api").is_err());
    }

    #[test]
    fn blocks_private_ips() {
        let guard = default_guard();
        assert!(guard.check_url("http://192.168.1.1/api").is_err());
        assert!(guard.check_url("http://10.0.0.1/api").is_err());
        assert!(guard.check_url("http://172.16.0.1/api").is_err());
    }

    #[test]
    fn blocks_cgnat_range() {
        let guard = default_guard();
        assert!(guard.check_url("http://100.64.0.1/api").is_err());
    }

    #[test]
    fn blocks_aws_metadata() {
        let guard = default_guard();
        assert!(guard
            .check_url("http://169.254.169.254/latest/meta-data/")
            .is_err());
    }

    #[test]
    fn blocks_internal_hostnames() {
        let guard = default_guard();
        assert!(guard.check_url("http://db.internal/query").is_err());
        assert!(guard.check_url("http://printer.local/status").is_err());
    }

    #[test]
    fn allows_public_urls() {
        let guard = default_guard();
        assert!(guard.check_url("https://api.openai.com/v1/chat").is_ok());
        assert!(guard.check_url("https://example.com").is_ok());
    }

    #[test]
    fn allows_private_when_blocking_disabled() {
        let mut config = SsrfGuardConfig::default();
        config.block_private = false;
        let guard = SsrfGuardMiddleware::new(config);
        assert!(guard.check_url("http://localhost/api").is_ok());
    }

    #[test]
    fn allowlist_overrides_private() {
        let mut config = SsrfGuardConfig::default();
        config.allowlist.insert("localhost".to_string());
        let guard = SsrfGuardMiddleware::new(config);
        assert!(guard.check_url("http://localhost/api").is_ok());
    }

    #[test]
    fn blocklist_blocks_public() {
        let mut config = SsrfGuardConfig::default();
        config.blocklist.insert("evil.com".to_string());
        let guard = SsrfGuardMiddleware::new(config);
        assert!(guard.check_url("https://evil.com/api").is_err());
    }

    #[test]
    fn scans_nested_args() {
        let guard = default_guard();
        let args = serde_json::json!({
            "config": {
                "url": "http://127.0.0.1/steal"
            }
        });
        assert!(guard.scan_args(&args).is_err());
    }

    #[test]
    fn scans_bare_url_strings_in_arrays() {
        let guard = default_guard();
        let args = serde_json::json!({ "targets": ["http://10.0.0.1/x"] });
        assert!(guard.scan_args(&args).is_err());
    }

    #[test]
    fn extract_host_works() {
        assert_eq!(
            extract_host("https://example.com/path"),
            Some("example.com".to_string())
        );
        assert_eq!(
            extract_host("http://localhost:8080/api"),
            Some("localhost".to_string())
        );
        assert_eq!(extract_host("not-a-url"), None);
    }
}