Skip to main content

systemprompt_api/services/gateway/
policy.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3use std::time::{Duration, Instant};
4
5use anyhow::Result;
6use serde::{Deserialize, Serialize};
7use systemprompt_ai::repository::AiGatewayPolicyRepository;
8use systemprompt_database::DbPool;
9use systemprompt_identifiers::TenantId;
10
11#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
12pub struct QuotaWindow {
13    pub window_seconds: i32,
14    pub max_requests: Option<i64>,
15    pub max_input_tokens: Option<i64>,
16    pub max_output_tokens: Option<i64>,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize, Default)]
20pub struct SafetyConfig {
21    #[serde(default)]
22    pub scanners: Vec<String>,
23    #[serde(default)]
24    pub block_categories: Vec<String>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize, Default)]
28pub struct GatewayPolicySpec {
29    #[serde(default)]
30    pub allowed_models: Option<Vec<String>>,
31    #[serde(default)]
32    pub max_input_tokens_per_call: Option<u32>,
33    #[serde(default)]
34    pub max_tool_depth: Option<u32>,
35    #[serde(default)]
36    pub quota_windows: Vec<QuotaWindow>,
37    #[serde(default)]
38    pub safety: SafetyConfig,
39}
40
41impl GatewayPolicySpec {
42    pub fn permissive() -> Self {
43        Self::default()
44    }
45
46    pub fn model_allowed(&self, model: &str) -> bool {
47        self.allowed_models
48            .as_deref()
49            .is_none_or(|list| list.iter().any(|m| m == model))
50    }
51}
52
/// How long a merged policy stays fresh in the resolver cache (one minute).
const CACHE_TTL: Duration = Duration::new(60, 0);
54
55#[derive(Clone)]
56pub struct PolicyResolver {
57    repo: Arc<AiGatewayPolicyRepository>,
58    cache: Arc<RwLock<HashMap<String, CachedEntry>>>,
59}
60
61impl std::fmt::Debug for PolicyResolver {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        f.debug_struct("PolicyResolver").finish()
64    }
65}
66
67#[derive(Clone)]
68struct CachedEntry {
69    spec: GatewayPolicySpec,
70    fetched_at: Instant,
71}
72
73impl PolicyResolver {
74    pub fn new(db: &DbPool) -> Result<Self> {
75        Ok(Self {
76            repo: Arc::new(
77                AiGatewayPolicyRepository::new(db)
78                    .map_err(|e| anyhow::anyhow!("policy repo init: {e}"))?,
79            ),
80            cache: Arc::new(RwLock::new(HashMap::new())),
81        })
82    }
83
84    pub async fn resolve(&self, tenant_id: Option<&TenantId>) -> GatewayPolicySpec {
85        let key = tenant_id
86            .map(|t| t.as_str().to_string())
87            .unwrap_or_default();
88
89        if let Ok(cache) = self.cache.read() {
90            if let Some(entry) = cache.get(&key) {
91                if entry.fetched_at.elapsed() < CACHE_TTL {
92                    return entry.spec.clone();
93                }
94            }
95        }
96
97        let rows = match self.repo.find_for_tenant(tenant_id).await {
98            Ok(r) => r,
99            Err(e) => {
100                tracing::warn!(error = %e, "policy resolve DB error — falling back to permissive");
101                return GatewayPolicySpec::permissive();
102            },
103        };
104
105        let spec = merge(rows);
106        if let Ok(mut cache) = self.cache.write() {
107            cache.insert(
108                key,
109                CachedEntry {
110                    spec: spec.clone(),
111                    fetched_at: Instant::now(),
112                },
113            );
114        }
115        spec
116    }
117}
118
119fn merge(rows: Vec<systemprompt_ai::GatewayPolicyRow>) -> GatewayPolicySpec {
120    let mut merged = GatewayPolicySpec::permissive();
121    for row in rows {
122        let Ok(spec) = serde_json::from_value::<GatewayPolicySpec>(row.spec) else {
123            tracing::warn!(policy_id = %row.id, name = %row.name, "policy spec JSON malformed — skipped");
124            continue;
125        };
126        if spec.allowed_models.is_some() {
127            merged.allowed_models = spec.allowed_models;
128        }
129        if spec.max_input_tokens_per_call.is_some() {
130            merged.max_input_tokens_per_call = spec.max_input_tokens_per_call;
131        }
132        if spec.max_tool_depth.is_some() {
133            merged.max_tool_depth = spec.max_tool_depth;
134        }
135        if !spec.quota_windows.is_empty() {
136            merged.quota_windows = spec.quota_windows;
137        }
138        if !spec.safety.scanners.is_empty() || !spec.safety.block_categories.is_empty() {
139            merged.safety = spec.safety;
140        }
141    }
142    merged
143}