1#![allow(dead_code)]
6
7use std::fs::{self, OpenOptions};
44use std::io::Write;
45use std::path::{Path, PathBuf};
46use std::sync::{Arc, Mutex};
47
48use chrono::Utc;
49use serde::{Deserialize, Serialize};
50use thiserror::Error;
51
/// Outcome of checking an outbound network call against the network policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Decision {
    /// The call may proceed.
    Allow,
    /// The call is blocked.
    Deny,
    /// No rule settled the call; the decision is deferred to the caller
    /// (e.g. an interactive prompt).
    Prompt,
}
62
63impl Decision {
64 #[must_use]
66 pub fn as_str(self) -> &'static str {
67 match self {
68 Self::Allow => "Allow",
69 Self::Deny => "Deny",
70 Self::Prompt => "Prompt",
71 }
72 }
73
74 #[must_use]
77 pub fn parse(value: &str) -> Self {
78 match value.trim().to_ascii_lowercase().as_str() {
79 "allow" => Self::Allow,
80 "deny" | "block" => Self::Deny,
81 _ => Self::Prompt,
82 }
83 }
84}
85
/// Host allow/deny rules for outbound network calls.
///
/// (De)serializable with serde; presumably loaded from a TOML config section
/// (see `DecisionToml`) — confirm against the config loader. `deny` rules are
/// checked before `allow` rules, and hosts matching neither get `default`.
/// A rule is either an exact host name or a subdomain wildcard written as
/// `.example.com` or `*.example.com`; wildcards match strict subdomains only,
/// not `example.com` itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkPolicy {
    /// Decision for hosts matching no rule; `prompt` when omitted.
    #[serde(default = "default_decision")]
    pub default: DecisionToml,
    /// Allowed host rules; empty when omitted.
    #[serde(default)]
    pub allow: Vec<String>,
    /// Denied host rules; empty when omitted. Takes precedence over `allow`.
    #[serde(default)]
    pub deny: Vec<String>,
    /// Whether decisions should be audited; `true` when omitted.
    #[serde(default = "default_audit")]
    pub audit: bool,
}
105
106fn default_decision() -> DecisionToml {
107 DecisionToml::Prompt
108}
109
/// Serde fallback for the `audit` field of `NetworkPolicy`: auditing is on
/// unless explicitly disabled.
fn default_audit() -> bool {
    const AUDIT_ON_BY_DEFAULT: bool = true;
    AUDIT_ON_BY_DEFAULT
}
113
114impl Default for NetworkPolicy {
115 fn default() -> Self {
116 Self {
117 default: DecisionToml::Prompt,
118 allow: Vec::new(),
119 deny: Vec::new(),
120 audit: true,
121 }
122 }
123}
124
/// Serde-facing mirror of [`Decision`], spelled `allow`, `deny` and `prompt`
/// in serialized form (`rename_all = "lowercase"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum DecisionToml {
    Allow,
    Deny,
    Prompt,
}
135
136impl From<DecisionToml> for Decision {
137 fn from(value: DecisionToml) -> Self {
138 match value {
139 DecisionToml::Allow => Self::Allow,
140 DecisionToml::Deny => Self::Deny,
141 DecisionToml::Prompt => Self::Prompt,
142 }
143 }
144}
145
146impl From<Decision> for DecisionToml {
147 fn from(value: Decision) -> Self {
148 match value {
149 Decision::Allow => Self::Allow,
150 Decision::Deny => Self::Deny,
151 Decision::Prompt => Self::Prompt,
152 }
153 }
154}
155
156impl NetworkPolicy {
157 #[must_use]
163 pub fn decide(&self, host: &str) -> Decision {
164 let normalized = normalize_host(host);
165 if normalized.is_empty() {
166 return self.default.into();
169 }
170 if self
171 .deny
172 .iter()
173 .any(|entry| host_matches(entry, &normalized))
174 {
175 return Decision::Deny;
176 }
177 if self
178 .allow
179 .iter()
180 .any(|entry| host_matches(entry, &normalized))
181 {
182 return Decision::Allow;
183 }
184 self.default.into()
185 }
186
187 pub fn add_allow(&mut self, host: &str) {
190 let normalized = normalize_host(host);
191 if normalized.is_empty() {
192 return;
193 }
194 if !self
195 .allow
196 .iter()
197 .any(|existing| normalize_host(existing) == normalized)
198 {
199 self.allow.push(normalized);
200 }
201 }
202
203 #[must_use]
205 pub fn audit_enabled(&self) -> bool {
206 self.audit
207 }
208}
209
/// Canonicalizes a host name or host rule for comparison.
///
/// Surrounding whitespace and trailing dots are removed and ASCII letters are
/// lowercased. A leading `*.` wildcard is rewritten to the `.suffix` form, so
/// `*.Example.com.` and `.example.com` normalize to the same value.
fn normalize_host(host: &str) -> String {
    let mut canonical = host.trim().trim_end_matches('.').to_ascii_lowercase();
    if canonical.starts_with("*.") {
        // Drop only the `*`, keeping the dot: "*.example.com" -> ".example.com".
        canonical.replace_range(..1, "");
    }
    canonical
}
223
224fn host_matches(entry: &str, normalized_host: &str) -> bool {
226 let entry_norm = normalize_host(entry);
227 if let Some(suffix) = entry_norm.strip_prefix('.') {
228 if suffix.is_empty() {
231 return false;
232 }
233 normalized_host.ends_with(&format!(".{suffix}"))
234 } else {
235 entry_norm == normalized_host
236 }
237}
238
/// Appends one line per network decision to an audit log file.
#[derive(Debug, Clone)]
pub struct NetworkAuditor {
    /// Log file that records are appended to.
    path: PathBuf,
    /// When `false`, `record` writes nothing.
    enabled: bool,
}
245
246impl NetworkAuditor {
247 #[must_use]
249 pub fn new(path: PathBuf, enabled: bool) -> Self {
250 Self { path, enabled }
251 }
252
253 #[must_use]
256 pub fn default_path(enabled: bool) -> Option<Self> {
257 Some(Self::new(
258 zagens_config::user_data_path_or_relative("audit.log"),
259 enabled,
260 ))
261 }
262
263 pub fn record(&self, host: &str, tool: &str, decision_label: &str) {
266 if !self.enabled {
267 return;
268 }
269 if let Err(err) = self.try_record(host, tool, decision_label) {
270 eprintln!("network audit write failed: {err}");
271 }
272 }
273
274 fn try_record(&self, host: &str, tool: &str, decision_label: &str) -> std::io::Result<()> {
275 if let Some(parent) = self.path.parent() {
276 fs::create_dir_all(parent)?;
277 }
278 let mut file = OpenOptions::new()
279 .create(true)
280 .append(true)
281 .open(&self.path)?;
282 writeln!(
283 file,
284 "{ts} network {host} {tool} {decision}",
285 ts = Utc::now().to_rfc3339(),
286 host = sanitize_field(host),
287 tool = sanitize_field(tool),
288 decision = decision_label,
289 )
290 }
291
292 #[must_use]
294 pub fn path(&self) -> &Path {
295 &self.path
296 }
297}
298
/// Makes `s` safe to embed as one field of a whitespace-separated audit record.
///
/// Whitespace (including newlines) and all other control characters are
/// replaced with `_`. As a result, a field cannot split into several columns,
/// start a new record, or carry terminal escape sequences into the log. An
/// empty input becomes `-` so the record keeps its column count.
fn sanitize_field(s: &str) -> String {
    if s.is_empty() {
        return String::from("-");
    }
    s.chars()
        .map(|c| if c.is_whitespace() || c.is_control() { '_' } else { c })
        .collect()
}
305
/// In-memory, session-scoped record of hosts that were approved or denied
/// (for example in answer to a prompt).
///
/// Clones share the same underlying sets through `Arc<Mutex<_>>`.
#[derive(Debug, Default, Clone)]
pub struct NetworkSessionCache {
    inner: Arc<Mutex<NetworkSessionCacheInner>>,
}
312
/// State guarded by the `NetworkSessionCache` mutex. Hosts are stored in
/// normalized form, and `approve`/`deny` move a host between the two sets
/// rather than keeping it in both.
#[derive(Debug, Default)]
struct NetworkSessionCacheInner {
    /// Hosts approved for this session.
    approved: std::collections::HashSet<String>,
    /// Hosts denied for this session.
    denied: std::collections::HashSet<String>,
}
318
319impl NetworkSessionCache {
320 #[must_use]
322 pub fn new() -> Self {
323 Self::default()
324 }
325
326 #[must_use]
328 pub fn is_approved(&self, host: &str) -> bool {
329 let normalized = normalize_host(host);
330 self.inner
331 .lock()
332 .map(|guard| guard.approved.contains(&normalized))
333 .unwrap_or(false)
334 }
335
336 #[must_use]
338 pub fn is_denied(&self, host: &str) -> bool {
339 let normalized = normalize_host(host);
340 self.inner
341 .lock()
342 .map(|guard| guard.denied.contains(&normalized))
343 .unwrap_or(false)
344 }
345
346 pub fn approve(&self, host: &str) {
348 let normalized = normalize_host(host);
349 if let Ok(mut guard) = self.inner.lock() {
350 guard.denied.remove(&normalized);
351 guard.approved.insert(normalized);
352 }
353 }
354
355 pub fn deny(&self, host: &str) {
357 let normalized = normalize_host(host);
358 if let Ok(mut guard) = self.inner.lock() {
359 guard.approved.remove(&normalized);
360 guard.denied.insert(normalized);
361 }
362 }
363}
364
/// Error describing a network call that the network policy blocked.
/// The wrapped string is the blocked host.
#[derive(Debug, Clone, Error)]
#[error("network call to '{0}' blocked by network policy")]
pub struct NetworkDenied(pub String);
369
370impl NetworkDenied {
371 #[must_use]
373 pub fn host(&self) -> &str {
374 &self.0
375 }
376}
377
/// Single decision point for outbound network calls. It combines a static
/// [`NetworkPolicy`], a session cache of user answers, and an optional auditor.
///
/// Clones share the session cache (it is `Arc`-backed) but each clone owns its
/// own copy of `policy`.
#[derive(Debug, Clone)]
pub struct NetworkPolicyDecider {
    /// Static rules, consulted when the session cache has no answer.
    policy: NetworkPolicy,
    /// Session approvals/denials, consulted before `policy`.
    cache: NetworkSessionCache,
    /// Destination for audit records; `None` disables auditing.
    auditor: Option<NetworkAuditor>,
}
388
389impl NetworkPolicyDecider {
390 #[must_use]
392 pub fn new(policy: NetworkPolicy, auditor: Option<NetworkAuditor>) -> Self {
393 Self {
394 policy,
395 cache: NetworkSessionCache::new(),
396 auditor,
397 }
398 }
399
400 #[must_use]
403 pub fn with_default_audit(policy: NetworkPolicy) -> Self {
404 let audit_enabled = policy.audit_enabled();
405 let auditor = if audit_enabled {
406 NetworkAuditor::default_path(true)
407 } else {
408 None
409 };
410 Self::new(policy, auditor)
411 }
412
413 #[must_use]
415 pub fn policy(&self) -> &NetworkPolicy {
416 &self.policy
417 }
418
419 #[must_use]
421 pub fn cache(&self) -> &NetworkSessionCache {
422 &self.cache
423 }
424
425 #[must_use]
431 pub fn evaluate(&self, host: &str, tool: &str) -> Decision {
432 let normalized = normalize_host(host);
433 if normalized.is_empty() {
434 return self.policy.default.into();
435 }
436 if self.cache.is_denied(&normalized) {
437 self.audit_record(&normalized, tool, "Deny");
438 return Decision::Deny;
439 }
440 if self.cache.is_approved(&normalized) {
441 self.audit_record(&normalized, tool, "Allow");
442 return Decision::Allow;
443 }
444 let decision = self.policy.decide(&normalized);
445 match decision {
446 Decision::Allow => self.audit_record(&normalized, tool, "Allow"),
447 Decision::Deny => self.audit_record(&normalized, tool, "Deny"),
448 Decision::Prompt => {}
449 }
450 decision
451 }
452
453 pub fn approve_session(&self, host: &str, tool: &str) {
456 self.cache.approve(host);
457 self.audit_record(host, tool, "Prompt-Approved");
458 }
459
460 pub fn deny_session(&self, host: &str, tool: &str) {
462 self.cache.deny(host);
463 self.audit_record(host, tool, "Prompt-Denied");
464 }
465
466 pub fn approve_persistent(&mut self, host: &str, tool: &str) -> &NetworkPolicy {
470 self.policy.add_allow(host);
471 self.cache.approve(host);
472 self.audit_record(host, tool, "Prompt-Approved");
473 &self.policy
474 }
475
476 fn audit_record(&self, host: &str, tool: &str, label: &str) {
477 if let Some(auditor) = self.auditor.as_ref() {
478 auditor.record(host, tool, label);
479 }
480 }
481}
482
483#[must_use]
486pub fn host_from_url(url: &str) -> Option<String> {
487 let parsed = reqwest::Url::parse(url.trim()).ok()?;
488 parsed.host_str().map(str::to_ascii_lowercase)
489}
490
#[cfg(test)]
mod tests {
    // Unit tests for policy matching, the audit log format, the session cache,
    // and the decider that ties them together.
    use super::*;
    use tempfile::tempdir;

    // Builds a policy from a default and rule lists. `audit` is false because
    // these tests construct any auditor explicitly.
    fn mk(default: Decision, allow: &[&str], deny: &[&str]) -> NetworkPolicy {
        NetworkPolicy {
            default: default.into(),
            allow: allow.iter().map(|s| (*s).to_string()).collect(),
            deny: deny.iter().map(|s| (*s).to_string()).collect(),
            audit: false,
        }
    }

    // An exact allow rule overrides a Deny default.
    #[test]
    fn exact_match_in_allow_returns_allow() {
        let p = mk(Decision::Deny, &["api.deepseek.com"], &[]);
        assert_eq!(p.decide("api.deepseek.com"), Decision::Allow);
    }

    // Hosts matching no rule fall back to the configured default.
    #[test]
    fn unknown_host_returns_default() {
        let p = mk(Decision::Deny, &["api.deepseek.com"], &[]);
        assert_eq!(p.decide("evil.example.com"), Decision::Deny);

        let p2 = mk(Decision::Prompt, &[], &[]);
        assert_eq!(p2.decide("anything.example"), Decision::Prompt);
    }

    // A host listed in both lists is denied.
    #[test]
    fn deny_wins_precedence() {
        let p = mk(Decision::Prompt, &["api.example.com"], &["api.example.com"]);
        assert_eq!(p.decide("api.example.com"), Decision::Deny);
    }

    // A subdomain deny rule beats an exact allow rule.
    #[test]
    fn deny_wins_with_subdomain_rules() {
        let p = mk(Decision::Allow, &["api.example.com"], &[".example.com"]);
        assert_eq!(p.decide("api.example.com"), Decision::Deny);
    }

    // `.example.com` covers subdomains at any depth, but not the apex domain.
    #[test]
    fn subdomain_wildcard_matches_subdomain_only() {
        let p = mk(Decision::Deny, &[".example.com"], &[]);
        assert_eq!(p.decide("api.example.com"), Decision::Allow);
        assert_eq!(p.decide("a.b.example.com"), Decision::Allow);
        assert_eq!(p.decide("example.com"), Decision::Deny);
    }

    // `*.example.com` behaves like `.example.com`.
    #[test]
    fn star_dot_subdomain_alias_is_accepted() {
        let p = mk(Decision::Deny, &["*.example.com"], &[]);
        assert_eq!(p.decide("api.example.com"), Decision::Allow);
        assert_eq!(p.decide("example.com"), Decision::Deny);
    }

    #[test]
    fn host_match_is_case_insensitive() {
        let p = mk(Decision::Deny, &["API.DeepSeek.com"], &[]);
        assert_eq!(p.decide("api.deepseek.com"), Decision::Allow);
    }

    // A fully-qualified host with a trailing dot matches the dotless rule.
    #[test]
    fn trailing_dot_is_ignored() {
        let p = mk(Decision::Deny, &["api.deepseek.com"], &[]);
        assert_eq!(p.decide("api.deepseek.com."), Decision::Allow);
    }

    // Empty or whitespace-only hosts skip rule matching entirely.
    #[test]
    fn empty_host_uses_default() {
        let p = mk(Decision::Deny, &["api.example.com"], &[]);
        assert_eq!(p.decide(""), Decision::Deny);
        assert_eq!(p.decide(" "), Decision::Deny);
    }

    // `add_allow` stores the normalized form and skips equivalent duplicates.
    #[test]
    fn add_allow_dedupes_case_insensitively() {
        let mut p = mk(Decision::Deny, &[], &[]);
        p.add_allow("Example.COM");
        p.add_allow("example.com");
        assert_eq!(p.allow.len(), 1);
        assert_eq!(p.allow[0], "example.com");
    }

    // Ports and paths are dropped, the host is lowercased, and unparsable input yields None.
    #[test]
    fn host_from_url_extracts_host() {
        assert_eq!(
            host_from_url("https://api.deepseek.com/health"),
            Some("api.deepseek.com".to_string())
        );
        assert_eq!(
            host_from_url("http://Example.COM:8080/x"),
            Some("example.com".to_string())
        );
        assert_eq!(host_from_url("not a url"), None);
    }

    // Each record is one whitespace-separated line: `<ts> network <host> <tool> <decision>`.
    #[test]
    fn auditor_writes_one_line_per_call() {
        let dir = tempdir().expect("tempdir");
        let path = dir.path().join("audit.log");
        let auditor = NetworkAuditor::new(path.clone(), true);
        auditor.record("api.example.com", "fetch_url", "Allow");
        auditor.record("evil.example.com", "fetch_url", "Deny");
        let body = std::fs::read_to_string(&path).expect("read");
        let lines: Vec<&str> = body.lines().collect();
        assert_eq!(lines.len(), 2);
        for line in &lines {
            let parts: Vec<&str> = line.split_whitespace().collect();
            assert!(parts.len() >= 5, "line shape: {line}");
            assert_eq!(parts[1], "network");
        }
        assert!(lines[0].contains("api.example.com"));
        assert!(lines[0].ends_with("Allow"));
        assert!(lines[1].contains("evil.example.com"));
        assert!(lines[1].ends_with("Deny"));
    }

    // A disabled auditor leaves no (non-empty) log behind.
    #[test]
    fn auditor_disabled_writes_nothing() {
        let dir = tempdir().expect("tempdir");
        let path = dir.path().join("audit.log");
        let auditor = NetworkAuditor::new(path.clone(), false);
        auditor.record("api.example.com", "fetch_url", "Allow");
        assert!(!path.exists() || std::fs::read_to_string(&path).unwrap().is_empty());
    }

    // A session approval turns a Prompt into Allow without touching the policy.
    #[test]
    fn session_cache_short_circuits_evaluate() {
        let policy = mk(Decision::Prompt, &[], &[]);
        let decider = NetworkPolicyDecider::new(policy, None);
        assert_eq!(
            decider.evaluate("api.example.com", "fetch_url"),
            Decision::Prompt
        );
        decider.approve_session("api.example.com", "fetch_url");
        assert_eq!(
            decider.evaluate("api.example.com", "fetch_url"),
            Decision::Allow
        );
    }

    // A persistent approval lands in the policy's allow list.
    #[test]
    fn approve_persistent_writes_back_to_policy() {
        let policy = mk(Decision::Prompt, &[], &[]);
        let mut decider = NetworkPolicyDecider::new(policy, None);
        decider.approve_persistent("api.example.com", "fetch_url");
        assert!(
            decider
                .policy()
                .allow
                .iter()
                .any(|h| h == "api.example.com")
        );
        assert_eq!(
            decider.evaluate("api.example.com", "fetch_url"),
            Decision::Allow
        );
    }

    // A session denial overrides an Allow default.
    #[test]
    fn deny_session_blocks_subsequent_evaluate() {
        let policy = mk(Decision::Allow, &[], &[]);
        let decider = NetworkPolicyDecider::new(policy, None);
        decider.deny_session("evil.example.com", "fetch_url");
        assert_eq!(
            decider.evaluate("evil.example.com", "fetch_url"),
            Decision::Deny
        );
    }

    // `evaluate` writes one audit line per Allow/Deny result, in call order.
    #[test]
    fn audit_records_terminal_decisions_through_decider() {
        let dir = tempdir().expect("tempdir");
        let auditor = NetworkAuditor::new(dir.path().join("audit.log"), true);
        let policy = mk(Decision::Deny, &["api.deepseek.com"], &[]);
        let decider = NetworkPolicyDecider::new(policy, Some(auditor));

        let allow = decider.evaluate("api.deepseek.com", "fetch_url");
        let deny = decider.evaluate("evil.example.com", "fetch_url");
        assert_eq!(allow, Decision::Allow);
        assert_eq!(deny, Decision::Deny);

        let body = std::fs::read_to_string(dir.path().join("audit.log")).expect("read");
        let lines: Vec<&str> = body.lines().collect();
        assert_eq!(lines.len(), 2);
        assert!(lines[0].ends_with("Allow"));
        assert!(lines[1].ends_with("Deny"));
    }

    // Parsing is case-insensitive, accepts `block`, and defaults to Prompt.
    #[test]
    fn decision_parse_unknown_falls_back_to_prompt() {
        assert_eq!(Decision::parse("allow"), Decision::Allow);
        assert_eq!(Decision::parse("Deny"), Decision::Deny);
        assert_eq!(Decision::parse("BLOCK"), Decision::Deny);
        assert_eq!(Decision::parse("prompt"), Decision::Prompt);
        assert_eq!(Decision::parse("garbage"), Decision::Prompt);
    }

    // The error exposes the blocked host and names it in its message.
    #[test]
    fn network_denied_carries_host() {
        let err = NetworkDenied("api.example.com".to_string());
        assert_eq!(err.host(), "api.example.com");
        assert!(format!("{err}").contains("api.example.com"));
    }
}