Skip to main content

systemprompt_api/services/gateway/
audit.rs

1use std::sync::Arc;
2use std::time::Instant;
3
4use anyhow::Result;
5use bytes::Bytes;
6use serde_json::Value;
7use systemprompt_ai::models::ai_request_record::AiRequestRecord;
8use systemprompt_ai::repository::ai_requests::UpdateCompletionParams;
9use systemprompt_ai::repository::{
10    AiRequestPayloadRepository, AiRequestRepository, InsertToolCallParams, UpsertPayloadParams,
11};
12use systemprompt_database::DbPool;
13use systemprompt_identifiers::{AiRequestId, SessionId, TenantId, TraceId, UserId};
14
15use super::pricing::{self, ModelPricing};
16
/// Largest body size, in bytes, that `slice_payload` will try to keep whole.
/// Bodies above this size are reduced to a head/tail excerpt and flagged as truncated.
const PAYLOAD_CAP_BYTES: usize = 256 * 1024;

/// Number of bytes taken from each end of an oversized body when building its excerpt.
const EXCERPT_BYTES: usize = 8 * 1024;
19
20#[derive(Debug, Clone)]
21pub struct GatewayRequestContext {
22    pub ai_request_id: AiRequestId,
23    pub user_id: UserId,
24    pub tenant_id: Option<TenantId>,
25    pub session_id: Option<SessionId>,
26    pub trace_id: Option<TraceId>,
27    pub provider: String,
28    pub model: String,
29    pub max_tokens: Option<u32>,
30    pub is_streaming: bool,
31}
32
/// Token counts captured for a single request.
///
/// Both counters default to zero.
#[derive(Clone, Copy, Debug, Default)]
pub struct CapturedUsage {
    /// Number of input tokens.
    pub input_tokens: u32,
    /// Number of output tokens.
    pub output_tokens: u32,
}
38
/// One tool invocation captured for a request, persisted by `GatewayAudit::complete`.
#[derive(Clone, Debug)]
pub struct CapturedToolUse {
    /// Identifier of the tool call, stored as `ai_tool_call_id`.
    pub ai_tool_call_id: String,
    /// Name of the tool that was invoked.
    pub tool_name: String,
    /// Tool input as a string; it is size-capped before being stored.
    pub tool_input: String,
}
45
46#[derive(Clone, Debug)]
47pub struct GatewayAudit {
48    requests: Arc<AiRequestRepository>,
49    payloads: Arc<AiRequestPayloadRepository>,
50    pub ctx: GatewayRequestContext,
51    pricing: ModelPricing,
52    started_at: Instant,
53}
54
55impl GatewayAudit {
56    pub fn new(
57        db: &DbPool,
58        ctx: GatewayRequestContext,
59    ) -> Result<Self, systemprompt_ai::error::RepositoryError> {
60        let requests = Arc::new(AiRequestRepository::new(db)?);
61        let payloads = Arc::new(AiRequestPayloadRepository::new(db)?);
62        let pricing = pricing::lookup(&ctx.provider, &ctx.model);
63        Ok(Self {
64            requests,
65            payloads,
66            ctx,
67            pricing,
68            started_at: Instant::now(),
69        })
70    }
71
72    pub async fn open(&self, request_body: &Bytes) -> Result<()> {
73        let record =
74            AiRequestRecord::builder(self.ctx.ai_request_id.as_str(), self.ctx.user_id.clone())
75                .provider(self.ctx.provider.clone())
76                .model(self.ctx.model.clone())
77                .streaming(self.ctx.is_streaming);
78        let record = if let Some(t) = &self.ctx.tenant_id {
79            record.tenant_id(t.clone())
80        } else {
81            record
82        };
83        let record = if let Some(s) = &self.ctx.session_id {
84            record.session_id(s.clone())
85        } else {
86            record
87        };
88        let record = if let Some(t) = &self.ctx.trace_id {
89            record.trace_id(t.clone())
90        } else {
91            record
92        };
93        let record = if let Some(mt) = self.ctx.max_tokens {
94            record.max_tokens(mt)
95        } else {
96            record
97        };
98        let record = record.build()?;
99
100        self.requests
101            .insert_with_id(&self.ctx.ai_request_id, &record)
102            .await?;
103
104        let (body_json, excerpt, truncated, bytes) = slice_payload(request_body);
105        if let Err(e) = self
106            .payloads
107            .upsert_request(
108                &self.ctx.ai_request_id,
109                UpsertPayloadParams {
110                    body: body_json.as_ref(),
111                    excerpt: excerpt.as_deref(),
112                    truncated,
113                    bytes: Some(bytes),
114                },
115            )
116            .await
117        {
118            tracing::warn!(error = %e, ai_request_id = %self.ctx.ai_request_id, "payload insert (request) failed");
119        }
120        Ok(())
121    }
122
123    pub async fn complete(
124        &self,
125        usage: CapturedUsage,
126        tool_calls: Vec<CapturedToolUse>,
127        response_body: &Bytes,
128    ) -> Result<()> {
129        let latency_ms = self.started_at.elapsed().as_millis().min(i32::MAX as u128) as i32;
130        let cost =
131            pricing::cost_microdollars(self.pricing, usage.input_tokens, usage.output_tokens);
132
133        self.requests
134            .update_completion(UpdateCompletionParams {
135                id: self.ctx.ai_request_id.clone(),
136                tokens_used: (usage.input_tokens + usage.output_tokens) as i32,
137                input_tokens: usage.input_tokens as i32,
138                output_tokens: usage.output_tokens as i32,
139                cost_microdollars: cost,
140                latency_ms,
141            })
142            .await?;
143
144        for (idx, tool) in tool_calls.iter().enumerate() {
145            let seq = idx as i32 + 1;
146            let trimmed = truncate_for_tool_input(&tool.tool_input);
147            if let Err(e) = self
148                .requests
149                .insert_tool_call(InsertToolCallParams {
150                    request_id: &self.ctx.ai_request_id,
151                    ai_tool_call_id: &tool.ai_tool_call_id,
152                    tool_name: &tool.tool_name,
153                    tool_input: &trimmed,
154                    sequence_number: seq,
155                })
156                .await
157            {
158                tracing::warn!(error = %e, seq, "tool_call insert failed");
159            }
160        }
161
162        let (body_json, excerpt, truncated, bytes) = slice_payload(response_body);
163        if let Err(e) = self
164            .payloads
165            .upsert_response(
166                &self.ctx.ai_request_id,
167                UpsertPayloadParams {
168                    body: body_json.as_ref(),
169                    excerpt: excerpt.as_deref(),
170                    truncated,
171                    bytes: Some(bytes),
172                },
173            )
174            .await
175        {
176            tracing::warn!(error = %e, ai_request_id = %self.ctx.ai_request_id, "payload insert (response) failed");
177        }
178
179        tracing::info!(
180            ai_request_id = %self.ctx.ai_request_id,
181            user_id = %self.ctx.user_id,
182            provider = %self.ctx.provider,
183            model = %self.ctx.model,
184            input_tokens = usage.input_tokens,
185            output_tokens = usage.output_tokens,
186            cost_microdollars = cost,
187            latency_ms,
188            tool_calls = tool_calls.len(),
189            "Gateway audit: request completed"
190        );
191        Ok(())
192    }
193
194    pub async fn fail(&self, error: &str) -> Result<()> {
195        if let Err(e) = self
196            .requests
197            .update_error(&self.ctx.ai_request_id, error)
198            .await
199        {
200            tracing::warn!(error = %e, "audit fail update failed");
201        }
202        tracing::info!(
203            ai_request_id = %self.ctx.ai_request_id,
204            user_id = %self.ctx.user_id,
205            provider = %self.ctx.provider,
206            model = %self.ctx.model,
207            error,
208            "Gateway audit: request failed"
209        );
210        Ok(())
211    }
212}
213
214fn slice_payload(bytes: &Bytes) -> (Option<Value>, Option<String>, bool, i32) {
215    let len = bytes.len();
216    let len_i32 = len.min(i32::MAX as usize) as i32;
217    if len <= PAYLOAD_CAP_BYTES {
218        serde_json::from_slice::<Value>(bytes).map_or_else(
219            |_| {
220                let excerpt = String::from_utf8_lossy(bytes).to_string();
221                (None, Some(excerpt), false, len_i32)
222            },
223            |v| (Some(v), None, false, len_i32),
224        )
225    } else {
226        let head_len = EXCERPT_BYTES.min(len);
227        let head = String::from_utf8_lossy(&bytes[..head_len]).to_string();
228        let tail_start = len.saturating_sub(EXCERPT_BYTES);
229        let tail = String::from_utf8_lossy(&bytes[tail_start..]).to_string();
230        let excerpt = format!("{head}\n...<truncated {} bytes>...\n{tail}", len - head_len);
231        (None, Some(excerpt), true, len_i32)
232    }
233}
234
/// Caps a tool input string at 64 KiB for storage.
///
/// Inputs of at most 64 KiB are returned unchanged. Longer inputs are cut at
/// the last UTF-8 character boundary at or before the cap and suffixed with
/// `...<truncated N bytes>`, where `N` is the number of bytes dropped.
fn truncate_for_tool_input(input: &str) -> String {
    const TOOL_INPUT_CAP: usize = 64 * 1024;

    if input.len() <= TOOL_INPUT_CAP {
        return input.to_string();
    }

    // Slicing a `str` in the middle of a multi-byte character panics, so back
    // up to the nearest boundary. Offset 0 is always a boundary, so this ends.
    let mut cut = TOOL_INPUT_CAP;
    while !input.is_char_boundary(cut) {
        cut -= 1;
    }

    let head = &input[..cut];
    format!("{head}...<truncated {} bytes>", input.len() - cut)
}