Skip to main content

systemprompt_api/services/gateway/
policy.rs

//! Resolution and caching of the effective gateway policy.
//!
//! [`PolicyResolver`] loads the global policy rows, merges them into a single
//! [`GatewayPolicySpec`], and caches the result for a short TTL. A DB error
//! degrades to a permissive policy (which is not cached, so the next call
//! retries the DB), and a row whose spec is malformed is skipped while the
//! remaining rows are still merged; neither case fails the request.
7
8use std::sync::{Arc, RwLock};
9use std::time::{Duration, Instant};
10
11use anyhow::Result;
12use systemprompt_ai::repository::AiGatewayPolicyRepository;
13use systemprompt_database::DbPool;
14
// The gateway-policy spec types are owned by `systemprompt-ai` so that the
// version-controlled `services/gateway/policies.yaml` and the persisted
// `ai_gateway_policies.spec` column share one schema. They are re-exported
// here so existing `super::policy::{...}` call sites are unaffected.
19pub use systemprompt_ai::{GatewayPolicySpec, QuotaWindow, SafetyConfig};
20
/// How long a merged policy is served from the in-process cache before the
/// next `PolicyResolver::resolve` call reloads it from the database.
const CACHE_TTL: Duration = Duration::new(60, 0);
22
23#[derive(Clone)]
24pub struct PolicyResolver {
25    repo: Arc<AiGatewayPolicyRepository>,
26    cache: Arc<RwLock<Option<CachedEntry>>>,
27}
28
29impl std::fmt::Debug for PolicyResolver {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("PolicyResolver").finish()
32    }
33}
34
35#[derive(Clone)]
36struct CachedEntry {
37    spec: GatewayPolicySpec,
38    fetched_at: Instant,
39}
40
41impl PolicyResolver {
42    pub fn new(db: &DbPool) -> Result<Self> {
43        Ok(Self {
44            repo: Arc::new(
45                AiGatewayPolicyRepository::new(db)
46                    .map_err(|e| anyhow::anyhow!("policy repo init: {e}"))?,
47            ),
48            cache: Arc::new(RwLock::new(None)),
49        })
50    }
51
52    pub async fn resolve(&self) -> GatewayPolicySpec {
53        if let Ok(cache) = self.cache.read() {
54            if let Some(entry) = cache.as_ref() {
55                if entry.fetched_at.elapsed() < CACHE_TTL {
56                    return entry.spec.clone();
57                }
58            }
59        }
60
61        let rows = match self.repo.find_for_global().await {
62            Ok(r) => r,
63            Err(e) => {
64                tracing::warn!(error = %e, "policy resolve DB error — falling back to permissive");
65                return GatewayPolicySpec::permissive();
66            },
67        };
68
69        let spec = merge(rows);
70        if let Ok(mut cache) = self.cache.write() {
71            *cache = Some(CachedEntry {
72                spec: spec.clone(),
73                fetched_at: Instant::now(),
74            });
75        }
76        spec
77    }
78}
79
80fn merge(rows: Vec<systemprompt_ai::GatewayPolicyRow>) -> GatewayPolicySpec {
81    let mut merged = GatewayPolicySpec::permissive();
82    for row in rows {
83        let Ok(spec) = serde_json::from_value::<GatewayPolicySpec>(row.spec) else {
84            tracing::warn!(policy_id = %row.id, name = %row.name, "policy spec JSON malformed — skipped");
85            continue;
86        };
87        if spec.max_input_tokens_per_call.is_some() {
88            merged.max_input_tokens_per_call = spec.max_input_tokens_per_call;
89        }
90        if spec.max_tool_depth.is_some() {
91            merged.max_tool_depth = spec.max_tool_depth;
92        }
93        if !spec.quota_windows.is_empty() {
94            merged.quota_windows = spec.quota_windows;
95        }
96        if !spec.safety.scanners.is_empty() || !spec.safety.block_categories.is_empty() {
97            merged.safety = spec.safety;
98        }
99    }
100    merged
101}