Skip to main content

systemprompt_api/services/gateway/
quota.rs

1use anyhow::Result;
2use chrono::{DateTime, TimeZone, Utc};
3use systemprompt_ai::repository::{
4    AiQuotaBucketRepository, IncrementParams, QuotaBucketDelta, QuotaBucketState,
5};
6use systemprompt_database::DbPool;
7use systemprompt_identifiers::UserId;
8
9use super::policy::QuotaWindow;
10
11#[derive(Debug, Clone, Copy)]
12pub struct QuotaDecision {
13    pub allow: bool,
14    pub window_seconds: i32,
15    pub limit_requests: Option<i64>,
16    pub limit_input_tokens: Option<i64>,
17    pub limit_output_tokens: Option<i64>,
18    pub state: QuotaBucketState,
19}
20
21pub async fn precheck_and_reserve(
22    db: &DbPool,
23    user_id: &UserId,
24    windows: &[QuotaWindow],
25) -> Result<Option<QuotaDecision>> {
26    if windows.is_empty() {
27        return Ok(None);
28    }
29    let repo =
30        AiQuotaBucketRepository::new(db).map_err(|e| anyhow::anyhow!("quota repo init: {e}"))?;
31
32    let now = Utc::now();
33    for window in windows {
34        let window_start = align_window(now, window.window_seconds);
35        let state = repo
36            .increment(IncrementParams {
37                user_id,
38                window_seconds: window.window_seconds,
39                window_start,
40                delta: QuotaBucketDelta {
41                    requests: 1,
42                    input_tokens: 0,
43                    output_tokens: 0,
44                },
45            })
46            .await?;
47
48        if let Some(max) = window.max_requests {
49            if state.requests > max {
50                return Ok(Some(QuotaDecision {
51                    allow: false,
52                    window_seconds: window.window_seconds,
53                    limit_requests: Some(max),
54                    limit_input_tokens: window.max_input_tokens,
55                    limit_output_tokens: window.max_output_tokens,
56                    state,
57                }));
58            }
59        }
60    }
61    Ok(None)
62}
63
64#[derive(Debug)]
65pub struct PostUpdateParams<'a> {
66    pub user_id: &'a UserId,
67    pub windows: &'a [QuotaWindow],
68    pub input_tokens: u32,
69    pub output_tokens: u32,
70}
71
72pub async fn post_update_tokens(db: &DbPool, params: PostUpdateParams<'_>) {
73    if params.windows.is_empty() {
74        return;
75    }
76    let repo = match AiQuotaBucketRepository::new(db) {
77        Ok(r) => r,
78        Err(e) => {
79            tracing::warn!(error = %e, "quota repo init failed in post_update");
80            return;
81        },
82    };
83    let now = Utc::now();
84    for window in params.windows {
85        let window_start = align_window(now, window.window_seconds);
86        if let Err(e) = repo
87            .increment(IncrementParams {
88                user_id: params.user_id,
89                window_seconds: window.window_seconds,
90                window_start,
91                delta: QuotaBucketDelta {
92                    requests: 0,
93                    input_tokens: i64::from(params.input_tokens),
94                    output_tokens: i64::from(params.output_tokens),
95                },
96            })
97            .await
98        {
99            tracing::warn!(error = %e, window_seconds = window.window_seconds, "quota post_update failed");
100        }
101    }
102}
103
104fn align_window(now: DateTime<Utc>, window_seconds: i32) -> DateTime<Utc> {
105    let secs = now.timestamp();
106    let w = i64::from(window_seconds.max(1));
107    let aligned = (secs / w) * w;
108    Utc.timestamp_opt(aligned, 0).single().unwrap_or(now)
109}