Skip to main content

systemprompt_api/services/gateway/
quota.rs

1use anyhow::Result;
2use chrono::{DateTime, TimeZone, Utc};
3use systemprompt_ai::repository::{
4    AiQuotaBucketRepository, IncrementParams, QuotaBucketDelta, QuotaBucketState,
5};
6use systemprompt_database::DbPool;
7use systemprompt_identifiers::{TenantId, UserId};
8
9use super::policy::QuotaWindow;
10
11#[derive(Debug, Clone, Copy)]
12pub struct QuotaDecision {
13    pub allow: bool,
14    pub window_seconds: i32,
15    pub limit_requests: Option<i64>,
16    pub limit_input_tokens: Option<i64>,
17    pub limit_output_tokens: Option<i64>,
18    pub state: QuotaBucketState,
19}
20
21pub async fn precheck_and_reserve(
22    db: &DbPool,
23    tenant_id: Option<&TenantId>,
24    user_id: &UserId,
25    windows: &[QuotaWindow],
26) -> Result<Option<QuotaDecision>> {
27    if windows.is_empty() {
28        return Ok(None);
29    }
30    let repo =
31        AiQuotaBucketRepository::new(db).map_err(|e| anyhow::anyhow!("quota repo init: {e}"))?;
32
33    let now = Utc::now();
34    for window in windows {
35        let window_start = align_window(now, window.window_seconds);
36        let state = repo
37            .increment(IncrementParams {
38                tenant_id,
39                user_id,
40                window_seconds: window.window_seconds,
41                window_start,
42                delta: QuotaBucketDelta {
43                    requests: 1,
44                    input_tokens: 0,
45                    output_tokens: 0,
46                },
47            })
48            .await?;
49
50        if let Some(max) = window.max_requests {
51            if state.requests > max {
52                return Ok(Some(QuotaDecision {
53                    allow: false,
54                    window_seconds: window.window_seconds,
55                    limit_requests: Some(max),
56                    limit_input_tokens: window.max_input_tokens,
57                    limit_output_tokens: window.max_output_tokens,
58                    state,
59                }));
60            }
61        }
62    }
63    Ok(None)
64}
65
66#[derive(Debug)]
67pub struct PostUpdateParams<'a> {
68    pub tenant_id: Option<&'a TenantId>,
69    pub user_id: &'a UserId,
70    pub windows: &'a [QuotaWindow],
71    pub input_tokens: u32,
72    pub output_tokens: u32,
73}
74
75pub async fn post_update_tokens(db: &DbPool, params: PostUpdateParams<'_>) {
76    if params.windows.is_empty() {
77        return;
78    }
79    let repo = match AiQuotaBucketRepository::new(db) {
80        Ok(r) => r,
81        Err(e) => {
82            tracing::warn!(error = %e, "quota repo init failed in post_update");
83            return;
84        },
85    };
86    let now = Utc::now();
87    for window in params.windows {
88        let window_start = align_window(now, window.window_seconds);
89        if let Err(e) = repo
90            .increment(IncrementParams {
91                tenant_id: params.tenant_id,
92                user_id: params.user_id,
93                window_seconds: window.window_seconds,
94                window_start,
95                delta: QuotaBucketDelta {
96                    requests: 0,
97                    input_tokens: i64::from(params.input_tokens),
98                    output_tokens: i64::from(params.output_tokens),
99                },
100            })
101            .await
102        {
103            tracing::warn!(error = %e, window_seconds = window.window_seconds, "quota post_update failed");
104        }
105    }
106}
107
108fn align_window(now: DateTime<Utc>, window_seconds: i32) -> DateTime<Utc> {
109    let secs = now.timestamp();
110    let w = i64::from(window_seconds.max(1));
111    let aligned = (secs / w) * w;
112    Utc.timestamp_opt(aligned, 0).single().unwrap_or(now)
113}