// systemprompt_api/services/gateway/audit.rs

1use std::sync::Arc;
2use std::time::Instant;
3
4use anyhow::Result;
5use bytes::Bytes;
6use serde_json::Value;
7use systemprompt_ai::models::ai_request_record::AiRequestRecord;
8use systemprompt_ai::repository::ai_requests::UpdateCompletionParams;
9use systemprompt_ai::repository::{
10    AiRequestPayloadRepository, AiRequestRepository, InsertToolCallParams, UpsertPayloadParams,
11};
12use systemprompt_database::DbPool;
13use systemprompt_identifiers::{AiRequestId, SessionId, TenantId, TraceId, UserId};
14
15use super::models::AnthropicGatewayRequest;
16use super::pricing;
17use std::sync::Mutex;
18
/// Bodies up to this size are persisted as structured JSON; larger ones are excerpted.
const PAYLOAD_CAP_BYTES: usize = 256 * 1024;
/// Bytes retained from each of the head and the tail when excerpting an oversized body.
const EXCERPT_BYTES: usize = 8 * 1024;
21
/// Per-request identity and routing metadata captured at gateway entry and
/// threaded through every audit write for the request.
#[derive(Debug, Clone)]
pub struct GatewayRequestContext {
    pub ai_request_id: AiRequestId,
    pub user_id: UserId,
    // Optional correlation identifiers; absent when the caller did not supply them.
    pub tenant_id: Option<TenantId>,
    pub session_id: Option<SessionId>,
    pub trace_id: Option<TraceId>,
    pub provider: String,
    /// Model the caller requested; the provider may serve a different one
    /// (see `GatewayAudit::set_served_model`).
    pub model: String,
    pub max_tokens: Option<u32>,
    pub is_streaming: bool,
}
34
/// Token counts reported by the provider for one completed request.
#[derive(Debug, Clone, Copy, Default)]
pub struct CapturedUsage {
    pub input_tokens: u32,
    pub output_tokens: u32,
}
40
/// One tool invocation observed in the provider's response, captured for the
/// audit trail.
#[derive(Debug, Clone)]
pub struct CapturedToolUse {
    /// Provider-assigned identifier for the tool call.
    pub ai_tool_call_id: String,
    pub tool_name: String,
    /// Raw tool input; capped before persistence (see `truncate_for_tool_input`).
    pub tool_input: String,
}
47
/// Audit recorder for a single gateway request: opens a DB record, persists
/// request/response payloads and messages, and finalizes with usage/cost.
// NOTE(review): Debug is suppressed — presumably the repository types do not
// implement Debug; confirm before deriving.
#[allow(missing_debug_implementations)]
pub struct GatewayAudit {
    requests: Arc<AiRequestRepository>,
    payloads: Arc<AiRequestPayloadRepository>,
    pub ctx: GatewayRequestContext,
    // Model actually served, when it differs from ctx.model; set at most once
    // from the response stream.
    served_model: Mutex<Option<String>>,
    // Start of the request, used to compute latency in `complete`.
    started_at: Instant,
}
56
57impl GatewayAudit {
58    pub fn new(
59        db: &DbPool,
60        ctx: GatewayRequestContext,
61    ) -> Result<Self, systemprompt_ai::error::RepositoryError> {
62        let requests = Arc::new(AiRequestRepository::new(db)?);
63        let payloads = Arc::new(AiRequestPayloadRepository::new(db)?);
64        Ok(Self {
65            requests,
66            payloads,
67            ctx,
68            served_model: Mutex::new(None),
69            started_at: Instant::now(),
70        })
71    }
72
73    pub async fn set_served_model(&self, model: &str) {
74        if model.is_empty() || model == self.ctx.model {
75            return;
76        }
77        if let Ok(mut slot) = self.served_model.lock() {
78            *slot = Some(model.to_string());
79        }
80        if let Err(e) = self
81            .requests
82            .update_model(&self.ctx.ai_request_id, model)
83            .await
84        {
85            tracing::warn!(error = %e, "update_model failed");
86        }
87    }
88
89    fn effective_model(&self) -> String {
90        self.served_model
91            .lock()
92            .ok()
93            .and_then(|s| s.clone())
94            .unwrap_or_else(|| self.ctx.model.clone())
95    }
96
97    fn build_record(&self) -> Result<AiRequestRecord> {
98        let mut record =
99            AiRequestRecord::builder(self.ctx.ai_request_id.clone(), self.ctx.user_id.clone())
100                .provider(self.ctx.provider.clone())
101                .model(self.ctx.model.clone())
102                .streaming(self.ctx.is_streaming);
103        if let Some(t) = &self.ctx.tenant_id {
104            record = record.tenant_id(t.clone());
105        }
106        if let Some(s) = &self.ctx.session_id {
107            record = record.session_id(s.clone());
108        }
109        if let Some(t) = &self.ctx.trace_id {
110            record = record.trace_id(t.clone());
111        }
112        if let Some(mt) = self.ctx.max_tokens {
113            record = record.max_tokens(mt);
114        }
115        record.build().map_err(anyhow::Error::from)
116    }
117
118    pub async fn open(
119        &self,
120        request: &AnthropicGatewayRequest,
121        request_body: &Bytes,
122    ) -> Result<()> {
123        let record = self.build_record()?;
124
125        self.requests
126            .insert_with_id(&self.ctx.ai_request_id, &record)
127            .await?;
128
129        let (body_json, excerpt, truncated, bytes) = slice_payload(request_body);
130        if let Err(e) = self
131            .payloads
132            .upsert_request(
133                &self.ctx.ai_request_id,
134                UpsertPayloadParams {
135                    body: body_json.as_ref(),
136                    excerpt: excerpt.as_deref(),
137                    truncated,
138                    bytes: Some(bytes),
139                },
140            )
141            .await
142        {
143            tracing::warn!(error = %e, ai_request_id = %self.ctx.ai_request_id, "payload insert (request) failed");
144        }
145
146        self.persist_request_messages(request).await;
147        Ok(())
148    }
149
150    async fn persist_request_messages(&self, request: &AnthropicGatewayRequest) {
151        let mut seq = 0i32;
152        if let Some(system) = request.system.as_ref() {
153            if let Some(text) = super::flatten::flatten_system_prompt(system) {
154                if let Err(e) = self
155                    .requests
156                    .insert_message(&self.ctx.ai_request_id, "system", &text, seq)
157                    .await
158                {
159                    tracing::warn!(error = %e, "insert system message failed");
160                }
161                seq += 1;
162            }
163        }
164        for msg in &request.messages {
165            let text = super::flatten::flatten_message_content(&msg.content);
166            if let Err(e) = self
167                .requests
168                .insert_message(&self.ctx.ai_request_id, &msg.role, &text, seq)
169                .await
170            {
171                tracing::warn!(error = %e, seq, "insert message failed");
172            }
173            seq += 1;
174        }
175    }
176
177    pub async fn complete(
178        &self,
179        usage: CapturedUsage,
180        tool_calls: Vec<CapturedToolUse>,
181        response_body: &Bytes,
182    ) -> Result<()> {
183        let latency_ms = self.started_at.elapsed().as_millis().min(i32::MAX as u128) as i32;
184        let effective_model = self.effective_model();
185        let pricing_rates = pricing::lookup(&self.ctx.provider, &effective_model);
186        let cost =
187            pricing::cost_microdollars(pricing_rates, usage.input_tokens, usage.output_tokens);
188
189        self.requests
190            .update_completion(UpdateCompletionParams {
191                id: self.ctx.ai_request_id.clone(),
192                tokens_used: (usage.input_tokens + usage.output_tokens) as i32,
193                input_tokens: usage.input_tokens as i32,
194                output_tokens: usage.output_tokens as i32,
195                cost_microdollars: cost,
196                latency_ms,
197            })
198            .await?;
199
200        self.persist_tool_calls(&tool_calls).await;
201
202        let (body_json, excerpt, truncated, bytes) = slice_payload(response_body);
203        if let Err(e) = self
204            .payloads
205            .upsert_response(
206                &self.ctx.ai_request_id,
207                UpsertPayloadParams {
208                    body: body_json.as_ref(),
209                    excerpt: excerpt.as_deref(),
210                    truncated,
211                    bytes: Some(bytes),
212                },
213            )
214            .await
215        {
216            tracing::warn!(error = %e, ai_request_id = %self.ctx.ai_request_id, "payload insert (response) failed");
217        }
218
219        if let Some(assistant_text) = super::parse::extract_assistant_text(response_body) {
220            if let Err(e) = self
221                .requests
222                .add_response_message(&self.ctx.ai_request_id, &assistant_text)
223                .await
224            {
225                tracing::warn!(error = %e, "assistant response message insert failed");
226            }
227        }
228
229        tracing::info!(
230            ai_request_id = %self.ctx.ai_request_id,
231            user_id = %self.ctx.user_id,
232            provider = %self.ctx.provider,
233            model = %effective_model,
234            input_tokens = usage.input_tokens,
235            output_tokens = usage.output_tokens,
236            cost_microdollars = cost,
237            latency_ms,
238            tool_calls = tool_calls.len(),
239            "Gateway audit: request completed"
240        );
241        Ok(())
242    }
243
244    async fn persist_tool_calls(&self, tool_calls: &[CapturedToolUse]) {
245        for (idx, tool) in tool_calls.iter().enumerate() {
246            let seq = idx as i32 + 1;
247            let trimmed = truncate_for_tool_input(&tool.tool_input);
248            if let Err(e) = self
249                .requests
250                .insert_tool_call(InsertToolCallParams {
251                    request_id: &self.ctx.ai_request_id,
252                    ai_tool_call_id: &tool.ai_tool_call_id,
253                    tool_name: &tool.tool_name,
254                    tool_input: &trimmed,
255                    sequence_number: seq,
256                })
257                .await
258            {
259                tracing::warn!(error = %e, seq, "tool_call insert failed");
260            }
261        }
262    }
263
264    pub async fn fail(&self, error: &str) -> Result<()> {
265        if let Err(e) = self
266            .requests
267            .update_error(&self.ctx.ai_request_id, error)
268            .await
269        {
270            tracing::warn!(error = %e, "audit fail update failed");
271        }
272        tracing::info!(
273            ai_request_id = %self.ctx.ai_request_id,
274            user_id = %self.ctx.user_id,
275            provider = %self.ctx.provider,
276            model = %self.ctx.model,
277            error,
278            "Gateway audit: request failed"
279        );
280        Ok(())
281    }
282}
283
284fn slice_payload(bytes: &Bytes) -> (Option<Value>, Option<String>, bool, i32) {
285    let len = bytes.len();
286    let len_i32 = len.min(i32::MAX as usize) as i32;
287    if len <= PAYLOAD_CAP_BYTES {
288        serde_json::from_slice::<Value>(bytes).map_or_else(
289            |_| {
290                let excerpt = String::from_utf8_lossy(bytes).to_string();
291                (None, Some(excerpt), false, len_i32)
292            },
293            |v| (Some(v), None, false, len_i32),
294        )
295    } else {
296        let head_len = EXCERPT_BYTES.min(len);
297        let head = String::from_utf8_lossy(&bytes[..head_len]).to_string();
298        let tail_start = len.saturating_sub(EXCERPT_BYTES);
299        let tail = String::from_utf8_lossy(&bytes[tail_start..]).to_string();
300        let excerpt = format!("{head}\n...<truncated {} bytes>...\n{tail}", len - head_len);
301        (None, Some(excerpt), true, len_i32)
302    }
303}
304
/// Caps a tool-call input at 64 KiB, appending a marker with the number of
/// bytes dropped.
///
/// Backs the cut off to the nearest UTF-8 char boundary: the original
/// `&input[..TOOL_INPUT_CAP]` slice panics when byte 65536 falls inside a
/// multi-byte character, and tool inputs are arbitrary text.
fn truncate_for_tool_input(input: &str) -> String {
    const TOOL_INPUT_CAP: usize = 64 * 1024;
    if input.len() <= TOOL_INPUT_CAP {
        return input.to_string();
    }
    // Walk back (at most 3 bytes for UTF-8) to a safe slice point.
    let mut cut = TOOL_INPUT_CAP;
    while !input.is_char_boundary(cut) {
        cut -= 1;
    }
    let head = &input[..cut];
    format!("{head}...<truncated {} bytes>", input.len() - cut)
}