systemprompt_api/services/gateway/
policy.rs1use std::sync::{Arc, RwLock};
2use std::time::{Duration, Instant};
3
4use anyhow::Result;
5use systemprompt_ai::repository::AiGatewayPolicyRepository;
6use systemprompt_database::DbPool;
7
8pub use systemprompt_ai::{GatewayPolicySpec, QuotaWindow, SafetyConfig};
/// How long a merged policy is served from the in-memory cache before the
/// database is queried again (60 seconds).
const CACHE_TTL: Duration = Duration::from_millis(60_000);

16#[derive(Clone)]
17pub struct PolicyResolver {
18 repo: Arc<AiGatewayPolicyRepository>,
19 cache: Arc<RwLock<Option<CachedEntry>>>,
20}
21
22impl std::fmt::Debug for PolicyResolver {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 f.debug_struct("PolicyResolver").finish()
25 }
26}
27
28#[derive(Clone)]
29struct CachedEntry {
30 spec: GatewayPolicySpec,
31 fetched_at: Instant,
32}
33
34impl PolicyResolver {
35 pub fn new(db: &DbPool) -> Result<Self> {
36 Ok(Self {
37 repo: Arc::new(
38 AiGatewayPolicyRepository::new(db)
39 .map_err(|e| anyhow::anyhow!("policy repo init: {e}"))?,
40 ),
41 cache: Arc::new(RwLock::new(None)),
42 })
43 }
44
45 pub async fn resolve(&self) -> GatewayPolicySpec {
46 if let Ok(cache) = self.cache.read() {
47 if let Some(entry) = cache.as_ref() {
48 if entry.fetched_at.elapsed() < CACHE_TTL {
49 return entry.spec.clone();
50 }
51 }
52 }
53
54 let rows = match self.repo.find_for_global().await {
55 Ok(r) => r,
56 Err(e) => {
57 tracing::warn!(error = %e, "policy resolve DB error — falling back to permissive");
58 return GatewayPolicySpec::permissive();
59 },
60 };
61
62 let spec = merge(rows);
63 if let Ok(mut cache) = self.cache.write() {
64 *cache = Some(CachedEntry {
65 spec: spec.clone(),
66 fetched_at: Instant::now(),
67 });
68 }
69 spec
70 }
71}
72
73fn merge(rows: Vec<systemprompt_ai::GatewayPolicyRow>) -> GatewayPolicySpec {
74 let mut merged = GatewayPolicySpec::permissive();
75 for row in rows {
76 let Ok(spec) = serde_json::from_value::<GatewayPolicySpec>(row.spec) else {
77 tracing::warn!(policy_id = %row.id, name = %row.name, "policy spec JSON malformed — skipped");
78 continue;
79 };
80 if spec.max_input_tokens_per_call.is_some() {
81 merged.max_input_tokens_per_call = spec.max_input_tokens_per_call;
82 }
83 if spec.max_tool_depth.is_some() {
84 merged.max_tool_depth = spec.max_tool_depth;
85 }
86 if !spec.quota_windows.is_empty() {
87 merged.quota_windows = spec.quota_windows;
88 }
89 if !spec.safety.scanners.is_empty() || !spec.safety.block_categories.is_empty() {
90 merged.safety = spec.safety;
91 }
92 }
93 merged
94}