// systemprompt_api/services/gateway/policy.rs

use std::sync::{Arc, PoisonError, RwLock};
use std::time::{Duration, Instant};

use anyhow::Result;
use systemprompt_ai::repository::AiGatewayPolicyRepository;
use systemprompt_database::DbPool;

// The gateway-policy spec types are owned by `systemprompt-ai` so the
// version-controlled `services/ai/gateway-policies.yaml` and the persisted
// `ai_gateway_policies.spec` column share one schema. Re-exported here so
// existing `super::policy::{...}` call sites are unaffected.
pub use systemprompt_ai::{GatewayPolicySpec, QuotaWindow, SafetyConfig};

/// How long a merged policy is served from the in-process cache before
/// `PolicyResolver::resolve` queries the database again.
const CACHE_TTL: Duration = Duration::from_secs(60);

16#[derive(Clone)]
17pub struct PolicyResolver {
18    repo: Arc<AiGatewayPolicyRepository>,
19    cache: Arc<RwLock<Option<CachedEntry>>>,
20}
21
22impl std::fmt::Debug for PolicyResolver {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        f.debug_struct("PolicyResolver").finish()
25    }
26}
27
28#[derive(Clone)]
29struct CachedEntry {
30    spec: GatewayPolicySpec,
31    fetched_at: Instant,
32}
33
34impl PolicyResolver {
35    pub fn new(db: &DbPool) -> Result<Self> {
36        Ok(Self {
37            repo: Arc::new(
38                AiGatewayPolicyRepository::new(db)
39                    .map_err(|e| anyhow::anyhow!("policy repo init: {e}"))?,
40            ),
41            cache: Arc::new(RwLock::new(None)),
42        })
43    }
44
45    pub async fn resolve(&self) -> GatewayPolicySpec {
46        if let Ok(cache) = self.cache.read() {
47            if let Some(entry) = cache.as_ref() {
48                if entry.fetched_at.elapsed() < CACHE_TTL {
49                    return entry.spec.clone();
50                }
51            }
52        }
53
54        let rows = match self.repo.find_for_global().await {
55            Ok(r) => r,
56            Err(e) => {
57                tracing::warn!(error = %e, "policy resolve DB error — falling back to permissive");
58                return GatewayPolicySpec::permissive();
59            },
60        };
61
62        let spec = merge(rows);
63        if let Ok(mut cache) = self.cache.write() {
64            *cache = Some(CachedEntry {
65                spec: spec.clone(),
66                fetched_at: Instant::now(),
67            });
68        }
69        spec
70    }
71}
72
73fn merge(rows: Vec<systemprompt_ai::GatewayPolicyRow>) -> GatewayPolicySpec {
74    let mut merged = GatewayPolicySpec::permissive();
75    for row in rows {
76        let Ok(spec) = serde_json::from_value::<GatewayPolicySpec>(row.spec) else {
77            tracing::warn!(policy_id = %row.id, name = %row.name, "policy spec JSON malformed — skipped");
78            continue;
79        };
80        if spec.allowed_models.is_some() {
81            merged.allowed_models = spec.allowed_models;
82        }
83        if spec.max_input_tokens_per_call.is_some() {
84            merged.max_input_tokens_per_call = spec.max_input_tokens_per_call;
85        }
86        if spec.max_tool_depth.is_some() {
87            merged.max_tool_depth = spec.max_tool_depth;
88        }
89        if !spec.quota_windows.is_empty() {
90            merged.quota_windows = spec.quota_windows;
91        }
92        if !spec.safety.scanners.is_empty() || !spec.safety.block_categories.is_empty() {
93            merged.safety = spec.safety;
94        }
95    }
96    merged
97}