1use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::Arc;
6
7use sha2::{Digest, Sha256};
8use thiserror::Error;
9use uuid::Uuid;
10use vigil_audit::{ApprovalTargetContext, EngineDegradedPayload, Ledger, Result as AuditResult};
11use vigil_policy::{
12 DescriptorState, PolicyAction, PolicyContext, PolicyDecision, PolicyEngine, PolicyError,
13};
14use vigil_types::{ApprovalRequest, DecisionKind, DecisionRecord, EffectVector, ToolInvocation};
15
16use crate::extract::{
17 BrowserActionExtractor, EffectExtractor, EmailExtractor, PathExtractor, SecretRefExtractor,
18 ShellExtractor, SqlExtractor, UrlExtractor,
19};
20use crate::preflight::{run_preflight, EngineStatusReport, PreflightError};
21use crate::scorer::{DescriptorOracle, DescriptorStatus, RiskScorer};
22
23#[derive(Debug, Error)]
25#[non_exhaustive]
26pub enum FirewallError {
27 #[error("policy: {0}")]
29 Policy(#[from] PolicyError),
30
31 #[error("audit: {0}")]
33 Audit(#[from] vigil_audit::AuditError),
34
35 #[error(
39 "config: `allowed_scopes` must not reuse reserved key `allowed_hosts` \
40 (host allowlist is managed via `FirewallConfig::allowed_hosts`)"
41 )]
42 ReservedScopeKey,
43
44 #[error("preflight scan failed: {reason}")]
52 PreflightScanFailed {
53 reason: String,
55 },
56}
57
/// Static configuration for a [`Firewall`], captured when it is constructed.
#[derive(Debug, Clone)]
pub struct FirewallConfig {
    /// Project root directories. They seed the path extractor and the risk scorer,
    /// and are exposed to policy under the `project_roots` root set.
    pub project_roots: Vec<String>,
    /// Allowed hosts. They seed the risk scorer and are exposed to policy under
    /// the `allowed_hosts` allowlist.
    pub allowed_hosts: Vec<String>,
    /// Additional named allowlists exposed to policy. The key `allowed_hosts` is
    /// reserved; using it makes `Firewall::evaluate` fail with
    /// `FirewallError::ReservedScopeKey`.
    pub allowed_scopes: HashMap<String, Vec<String>>,
    /// TTL handed to the ledger whenever an approval request is created.
    pub approval_ttl_secs: u64,
    /// Threshold forwarded to the PII preflight step (`run_preflight`); its exact
    /// meaning is defined there.
    pub long_text_threshold: usize,
}

impl FirewallConfig {
    /// Default value of [`FirewallConfig::approval_ttl_secs`].
    const DEFAULT_APPROVAL_TTL_SECS: u64 = 300;
    /// Default value of [`FirewallConfig::long_text_threshold`].
    const DEFAULT_LONG_TEXT_THRESHOLD: usize = 100;
}

impl Default for FirewallConfig {
    /// No roots, hosts or extra scopes; approval TTL 300 and long-text threshold 100.
    fn default() -> Self {
        FirewallConfig {
            approval_ttl_secs: Self::DEFAULT_APPROVAL_TTL_SECS,
            long_text_threshold: Self::DEFAULT_LONG_TEXT_THRESHOLD,
            project_roots: Vec::new(),
            allowed_hosts: Vec::new(),
            allowed_scopes: HashMap::default(),
        }
    }
}
105
/// Describes whether an invocation is OAuth-backed and, if so, which scopes it requests.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum OAuthScopeContext {
    /// The invocation is not OAuth-backed.
    NonOauth,
    /// The invocation requests these OAuth scopes.
    Scopes(Vec<String>),
}

impl OAuthScopeContext {
    /// Converts into the `requested_scopes` value placed on the policy context.
    ///
    /// Returns `None` for a non-OAuth invocation and `Some(scopes)` otherwise.
    /// The scope list is moved, not copied.
    fn into_policy_requested_scopes(self) -> Option<Vec<String>> {
        match self {
            Self::Scopes(requested) => Some(requested),
            Self::NonOauth => None,
        }
    }
}
132
133#[derive(Debug, Clone)]
140#[non_exhaustive]
141pub enum FirewallOutcome {
142 Allowed {
144 decision: DecisionRecord,
146 effects: EffectVector,
148 },
149 Denied {
151 decision: DecisionRecord,
153 effects: EffectVector,
155 },
156 Approve {
158 decision: DecisionRecord,
160 effects: EffectVector,
162 approval: ApprovalRequest,
164 },
165}
166
167impl FirewallOutcome {
168 pub fn decision_kind(&self) -> DecisionKind {
170 match self {
171 FirewallOutcome::Allowed { .. } => DecisionKind::Allow,
172 FirewallOutcome::Denied { .. } => DecisionKind::Deny,
173 FirewallOutcome::Approve { .. } => DecisionKind::Approve,
174 }
175 }
176}
177
178pub struct Firewall {
180 ledger: Arc<Ledger>,
181 policy: PolicyEngine,
182 scorer: RiskScorer,
183 extractors: Vec<Box<dyn EffectExtractor>>,
184 config: FirewallConfig,
185 scanner: Arc<dyn crate::preflight::PiiScanner>,
188 audit_persist_failures: Arc<crate::preflight::AuditPersistCounter>,
190}
191
192impl std::fmt::Debug for Firewall {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 f.debug_struct("Firewall")
195 .field("policy_rule_count", &self.policy.len())
196 .field("extractor_count", &self.extractors.len())
197 .field("config", &self.config)
198 .field(
199 "audit_persist_failures",
200 &self
201 .audit_persist_failures
202 .load(std::sync::atomic::Ordering::Relaxed),
203 )
204 .finish()
205 }
206}
207
208impl Firewall {
209 pub fn new(ledger: Arc<Ledger>, policy: PolicyEngine, config: FirewallConfig) -> Self {
212 Self::with_scanner(
213 ledger,
214 policy,
215 config,
216 crate::preflight::default_scanner_arc(),
217 )
218 }
219
220 pub fn with_scanner(
223 ledger: Arc<Ledger>,
224 policy: PolicyEngine,
225 config: FirewallConfig,
226 scanner: Arc<dyn crate::preflight::PiiScanner>,
227 ) -> Self {
228 let roots: Vec<PathBuf> = config.project_roots.iter().map(PathBuf::from).collect();
229 let scorer = RiskScorer::new(config.allowed_hosts.clone(), config.project_roots.clone());
230 let extractors: Vec<Box<dyn EffectExtractor>> = vec![
231 Box::new(PathExtractor::new(roots)),
232 Box::new(UrlExtractor),
233 Box::new(SqlExtractor),
234 Box::new(ShellExtractor),
235 Box::new(EmailExtractor),
236 Box::new(SecretRefExtractor),
237 Box::new(BrowserActionExtractor),
238 ];
239 Self {
240 ledger,
241 policy,
242 scorer,
243 extractors,
244 config,
245 scanner,
246 audit_persist_failures: Arc::new(crate::preflight::AuditPersistCounter::new(0)),
247 }
248 }
249
250 pub fn audit_persist_failures(&self) -> u64 {
255 self.audit_persist_failures
256 .load(std::sync::atomic::Ordering::Relaxed)
257 }
258
259 pub fn evaluate(
276 &self,
277 call: &ToolInvocation,
278 oracle: &dyn DescriptorOracle,
279 scope_ctx: OAuthScopeContext,
280 ) -> Result<FirewallOutcome, FirewallError> {
281 const RESERVED_ALLOWLIST_KEYS: &[&str] = &["allowed_hosts"];
288 if self
289 .config
290 .allowed_scopes
291 .keys()
292 .any(|k| RESERVED_ALLOWLIST_KEYS.contains(&k.as_str()))
293 {
294 return Err(FirewallError::ReservedScopeKey);
295 }
296
297 let mut effects = EffectVector::default();
299 for ex in &self.extractors {
300 ex.extract(call, &mut effects);
301 }
302 dedup_effects(&mut effects);
303
304 let descriptor = oracle.status(&call.server_id, &call.tool_name, &call.descriptor_hash);
306
307 let (risk_score, score_reasons) = self.scorer.score(&effects, descriptor);
309
310 let preflight = run_preflight(
320 self.scanner.as_ref(),
321 &self.ledger,
322 &self.audit_persist_failures,
323 &call.session_id,
324 &call.args,
325 self.config.long_text_threshold,
326 )
327 .map_err(|e| match e {
328 PreflightError::ScanFailed { reason } => FirewallError::PreflightScanFailed { reason },
329 })?;
330
331 let base_risk = risk_score;
333 let pii_delta = preflight.risk_delta;
334 let risk_with_pii = (base_risk as u32).saturating_add(pii_delta).min(100) as u8;
335
336 #[allow(unreachable_patterns)]
340 let descriptor_state = match descriptor {
341 DescriptorStatus::ApprovedStable => DescriptorState::ApprovedStable,
342 DescriptorStatus::FirstSeen => DescriptorState::FirstSeen,
343 DescriptorStatus::Drifted => DescriptorState::Drifted,
344 _ => DescriptorState::FirstSeen,
346 };
347 let mut ctx = PolicyContext {
351 risk_score: risk_with_pii,
354 descriptor: descriptor_state,
355 requested_scopes: scope_ctx.into_policy_requested_scopes(),
358 pii_findings: preflight.pii_summary.clone(),
360 ..Default::default()
361 };
362 ctx.roots
363 .insert("project_roots".into(), self.config.project_roots.clone());
364 ctx.allowlists
366 .insert("allowed_hosts".into(), self.config.allowed_hosts.clone());
367 for (k, v) in &self.config.allowed_scopes {
372 ctx.allowlists.insert(k.clone(), v.clone());
373 }
374 let pdec: PolicyDecision = self.policy.evaluate(&effects, &ctx)?;
375
376 let preflight_reason = format!(
382 "preflight: base_risk={} pii_delta={} final={} labels={}",
383 base_risk,
384 pii_delta,
385 risk_with_pii,
386 if preflight.pii_summary.is_empty() {
387 "(none)".to_string()
388 } else {
389 preflight.counts_csv()
390 }
391 );
392 let mut decision_reasons = merge_reasons(&score_reasons, &pdec.reasons);
393 decision_reasons.push(preflight_reason);
394
395 let degraded_status = match preflight.engine_status {
404 EngineStatusReport::DegradedTimeout | EngineStatusReport::DegradedError => {
405 let stable = preflight.engine_status.stable_code();
406 decision_reasons.push(format!("engine.status={stable}"));
407 Some(preflight.engine_status)
408 }
409 EngineStatusReport::Ok | EngineStatusReport::Unsupported => None,
410 };
411
412 let decision_id = Uuid::new_v4().to_string();
414 let decision = DecisionRecord {
415 decision_id: decision_id.clone(),
416 invocation_id: call.invocation_id.clone(),
417 decision: map_action(pdec.action),
418 risk_score: risk_with_pii,
419 reasons: decision_reasons,
420 policy_ids: pdec.policy_ids.clone(),
421 created_at: now_secs(),
422 };
423 let _ = self
424 .ledger
425 .record_decision(&call.session_id, &decision, &effects)?;
426
427 if let Some(status) = degraded_status {
436 let payload = EngineDegradedPayload {
437 engine_id: "firewall_preflight_scanner".to_string(),
438 status: status.stable_code().to_string(),
439 reason_code: status.stable_code().to_string(),
440 budget_ms: None,
441 elapsed_ms: None,
442 fail_closed_decision: "fall_back_hard_only".to_string(),
443 decision_id: decision_id.clone(),
444 };
445 if self
446 .ledger
447 .record_engine_degraded(&call.session_id, &payload)
448 .is_err()
449 {
450 self.audit_persist_failures
451 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
452 }
453 }
454
455 match pdec.action {
457 PolicyAction::Allow => Ok(FirewallOutcome::Allowed { decision, effects }),
458 PolicyAction::Deny => Ok(FirewallOutcome::Denied { decision, effects }),
459 PolicyAction::Approve => {
460 let (title, summary) = summarize(call, &effects, &decision);
461 let args_hash = compute_args_hash(&call.args)?;
463 let ctx = ApprovalTargetContext {
464 server_id: Some(&call.server_id),
465 tool_name: Some(&call.tool_name),
466 args_hash: Some(&args_hash),
467 };
468 let approval: AuditResult<ApprovalRequest> = self.ledger.create_approval(
469 &call.session_id,
470 &decision,
471 &effects,
472 &title,
473 &summary,
474 self.config.approval_ttl_secs,
475 ctx,
476 );
477 let approval = approval?;
478 Ok(FirewallOutcome::Approve {
479 decision,
480 effects,
481 approval,
482 })
483 }
484 _ => Ok(FirewallOutcome::Denied { decision, effects }),
486 }
487 }
488}
489
490fn dedup_effects(e: &mut EffectVector) {
491 let mut seen = std::collections::HashSet::new();
493 e.effects.retain(|k| seen.insert(*k));
494 e.paths_read.sort();
495 e.paths_read.dedup();
496 e.paths_write.sort();
497 e.paths_write.dedup();
498 e.network_hosts.sort();
499 e.network_hosts.dedup();
500 e.secret_refs.sort();
501 e.secret_refs.dedup();
502 e.recipients.sort();
503 e.recipients.dedup();
504}
505
506fn map_action(a: PolicyAction) -> DecisionKind {
507 match a {
508 PolicyAction::Allow => DecisionKind::Allow,
509 PolicyAction::Deny => DecisionKind::Deny,
510 PolicyAction::Approve => DecisionKind::Approve,
511 _ => DecisionKind::Deny,
513 }
514}
515
/// Concatenates the scorer's reasons and the policy's reasons into one owned list.
///
/// Order is preserved: all `score` entries come first, then all `policy` entries.
/// Duplicates are kept.
fn merge_reasons(score: &[String], policy: &[String]) -> Vec<String> {
    score.iter().chain(policy).cloned().collect()
}
522
523fn summarize(
524 call: &ToolInvocation,
525 effects: &EffectVector,
526 dec: &DecisionRecord,
527) -> (String, String) {
528 let title = format!("{} on {}", call.tool_name, call.server_id);
529 let mut parts = Vec::new();
530 parts.push(format!("risk {}/100", dec.risk_score));
531 if !effects.paths_write.is_empty() {
532 parts.push(format!("writes: {}", effects.paths_write.join(", ")));
533 }
534 if !effects.paths_read.is_empty() {
535 parts.push(format!("reads: {}", effects.paths_read.len()));
536 }
537 if !effects.network_hosts.is_empty() {
538 parts.push(format!("hosts: {}", effects.network_hosts.join(", ")));
539 }
540 if !effects.secret_refs.is_empty() {
541 parts.push(format!("secrets: {}", effects.secret_refs.join(", ")));
542 }
543 if !effects.recipients.is_empty() {
544 parts.push(format!("recipients: {}", effects.recipients.len()));
545 }
546 (title, parts.join(" | "))
547}
548
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reads earlier than the epoch. Saturates at
/// `i64::MAX` instead of silently wrapping negative, which a plain `as` cast
/// would do for values above `i64::MAX`.
fn now_secs() -> i64 {
    use std::convert::TryFrom;
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| i64::try_from(d.as_secs()).unwrap_or(i64::MAX))
        .unwrap_or(0)
}
556
557pub(crate) fn compute_args_hash(args: &serde_json::Value) -> Result<String, FirewallError> {
560 let bytes = serde_jcs::to_vec(args)
561 .map_err(|e| FirewallError::Audit(vigil_audit::AuditError::Json(e)))?;
562 let mut h = Sha256::new();
563 h.update(&bytes);
564 Ok(hex::encode(h.finalize()))
565}