//! Gateway request auditing.
//!
//! Path: `systemprompt_api/services/gateway/audit.rs`
1#[path = "audit_internal/payload.rs"]
2mod payload;
3
4use payload::{slice_payload, truncate_for_tool_input};
5
6use std::sync::Arc;
7use std::time::Instant;
8
9use anyhow::Result;
10use bytes::Bytes;
11use systemprompt_ai::models::ai_request_record::AiRequestRecord;
12use systemprompt_ai::repository::ai_requests::UpdateCompletionParams;
13use systemprompt_ai::repository::{
14    AiRequestPayloadRepository, AiRequestRepository, InsertToolCallParams, UpsertPayloadParams,
15};
16use systemprompt_database::DbPool;
17use systemprompt_identifiers::{AiRequestId, SessionId, TenantId, TraceId, UserId};
18
19use super::captures::{CapturedToolUse, CapturedUsage};
20use super::models::AnthropicGatewayRequest;
21use super::pricing;
22use std::sync::Mutex;
23
/// Immutable per-request metadata captured when a gateway request is opened.
///
/// Carried through the whole audit lifecycle (`open` -> `complete`/`fail`)
/// so every persisted record and log line is attributed to the same request.
#[derive(Debug, Clone)]
pub struct GatewayRequestContext {
    // Identifier of the audited AI-request row.
    pub ai_request_id: AiRequestId,
    // User on whose behalf the request was made.
    pub user_id: UserId,
    // Tenant scope, when multi-tenancy applies.
    pub tenant_id: Option<TenantId>,
    // Session the request belongs to, if any.
    pub session_id: Option<SessionId>,
    // Distributed-trace correlation id, if any.
    pub trace_id: Option<TraceId>,
    // Upstream provider name; used together with the model for pricing lookup.
    pub provider: String,
    // Model requested by the caller; the model actually served may differ
    // (see `GatewayAudit::set_served_model`).
    pub model: String,
    // Caller-supplied token cap, if one was set.
    pub max_tokens: Option<u32>,
    // Whether the response is streamed.
    pub is_streaming: bool,
}
36
/// Persists the audit trail of a single gateway request: the request row,
/// request/response payloads, flattened messages, tool calls, and completion
/// metrics (tokens, cost, latency).
#[allow(missing_debug_implementations)]
pub struct GatewayAudit {
    requests: Arc<AiRequestRepository>,
    payloads: Arc<AiRequestPayloadRepository>,
    // Request metadata fixed at construction time.
    pub ctx: GatewayRequestContext,
    // Model actually served when it differs from `ctx.model`; set mid-flight.
    // Guarded by a std Mutex — guards are never held across an `.await`.
    served_model: Mutex<Option<String>>,
    // Start of the request; used to compute latency in `complete`.
    started_at: Instant,
}
45
46impl GatewayAudit {
47    pub fn new(
48        db: &DbPool,
49        ctx: GatewayRequestContext,
50    ) -> Result<Self, systemprompt_ai::error::RepositoryError> {
51        let requests = Arc::new(AiRequestRepository::new(db)?);
52        let payloads = Arc::new(AiRequestPayloadRepository::new(db)?);
53        Ok(Self {
54            requests,
55            payloads,
56            ctx,
57            served_model: Mutex::new(None),
58            started_at: Instant::now(),
59        })
60    }
61
62    pub async fn set_served_model(&self, model: &str) {
63        if model.is_empty() || model == self.ctx.model {
64            return;
65        }
66        if let Ok(mut slot) = self.served_model.lock() {
67            *slot = Some(model.to_string());
68        }
69        if let Err(e) = self
70            .requests
71            .update_model(&self.ctx.ai_request_id, model)
72            .await
73        {
74            tracing::warn!(error = %e, "update_model failed");
75        }
76    }
77
78    fn effective_model(&self) -> String {
79        self.served_model
80            .lock()
81            .map_err(|e| {
82                tracing::warn!(error = %e, "served_model mutex poisoned");
83                e
84            })
85            .ok()
86            .and_then(|s| s.clone())
87            .unwrap_or_else(|| self.ctx.model.clone())
88    }
89
90    fn build_record(&self) -> Result<AiRequestRecord> {
91        let mut record =
92            AiRequestRecord::builder(self.ctx.ai_request_id.clone(), self.ctx.user_id.clone())
93                .provider(self.ctx.provider.clone())
94                .model(self.ctx.model.clone())
95                .streaming(self.ctx.is_streaming);
96        if let Some(t) = &self.ctx.tenant_id {
97            record = record.tenant_id(t.clone());
98        }
99        if let Some(s) = &self.ctx.session_id {
100            record = record.session_id(s.clone());
101        }
102        if let Some(t) = &self.ctx.trace_id {
103            record = record.trace_id(t.clone());
104        }
105        if let Some(mt) = self.ctx.max_tokens {
106            record = record.max_tokens(mt);
107        }
108        record.build().map_err(anyhow::Error::from)
109    }
110
111    pub async fn open(
112        &self,
113        request: &AnthropicGatewayRequest,
114        request_body: &Bytes,
115    ) -> Result<()> {
116        let record = self.build_record()?;
117
118        self.requests
119            .insert_with_id(&self.ctx.ai_request_id, &record)
120            .await?;
121
122        let (body_json, excerpt, truncated, bytes) = slice_payload(request_body);
123        if let Err(e) = self
124            .payloads
125            .upsert_request(
126                &self.ctx.ai_request_id,
127                UpsertPayloadParams {
128                    body: body_json.as_ref(),
129                    excerpt: excerpt.as_deref(),
130                    truncated,
131                    bytes: Some(bytes),
132                },
133            )
134            .await
135        {
136            tracing::warn!(error = %e, ai_request_id = %self.ctx.ai_request_id, "payload insert (request) failed");
137        }
138
139        self.persist_request_messages(request).await;
140        Ok(())
141    }
142
143    async fn persist_request_messages(&self, request: &AnthropicGatewayRequest) {
144        let mut seq = 0i32;
145        if let Some(system) = request.system.as_ref() {
146            if let Some(text) = super::flatten::flatten_system_prompt(system) {
147                if let Err(e) = self
148                    .requests
149                    .insert_message(&self.ctx.ai_request_id, "system", &text, seq)
150                    .await
151                {
152                    tracing::warn!(error = %e, "insert system message failed");
153                }
154                seq += 1;
155            }
156        }
157        for msg in &request.messages {
158            let text = super::flatten::flatten_message_content(&msg.content);
159            if let Err(e) = self
160                .requests
161                .insert_message(&self.ctx.ai_request_id, &msg.role, &text, seq)
162                .await
163            {
164                tracing::warn!(error = %e, seq, "insert message failed");
165            }
166            seq += 1;
167        }
168    }
169
170    pub async fn complete(
171        &self,
172        usage: CapturedUsage,
173        tool_calls: Vec<CapturedToolUse>,
174        response_body: &Bytes,
175    ) -> Result<()> {
176        let latency_ms = self.started_at.elapsed().as_millis().min(i32::MAX as u128) as i32;
177        let effective_model = self.effective_model();
178        let pricing_rates = pricing::lookup(&self.ctx.provider, &effective_model);
179        let cost =
180            pricing::cost_microdollars(pricing_rates, usage.input_tokens, usage.output_tokens);
181
182        self.requests
183            .update_completion(UpdateCompletionParams {
184                id: self.ctx.ai_request_id.clone(),
185                tokens_used: (usage.input_tokens + usage.output_tokens) as i32,
186                input_tokens: usage.input_tokens as i32,
187                output_tokens: usage.output_tokens as i32,
188                cost_microdollars: cost,
189                latency_ms,
190            })
191            .await?;
192
193        self.persist_tool_calls(&tool_calls).await;
194
195        let (body_json, excerpt, truncated, bytes) = slice_payload(response_body);
196        if let Err(e) = self
197            .payloads
198            .upsert_response(
199                &self.ctx.ai_request_id,
200                UpsertPayloadParams {
201                    body: body_json.as_ref(),
202                    excerpt: excerpt.as_deref(),
203                    truncated,
204                    bytes: Some(bytes),
205                },
206            )
207            .await
208        {
209            tracing::warn!(error = %e, ai_request_id = %self.ctx.ai_request_id, "payload insert (response) failed");
210        }
211
212        if let Some(assistant_text) = super::parse::extract_assistant_text(response_body) {
213            if let Err(e) = self
214                .requests
215                .add_response_message(&self.ctx.ai_request_id, &assistant_text)
216                .await
217            {
218                tracing::warn!(error = %e, "assistant response message insert failed");
219            }
220        }
221
222        tracing::info!(
223            ai_request_id = %self.ctx.ai_request_id,
224            user_id = %self.ctx.user_id,
225            provider = %self.ctx.provider,
226            model = %effective_model,
227            input_tokens = usage.input_tokens,
228            output_tokens = usage.output_tokens,
229            cost_microdollars = cost,
230            latency_ms,
231            tool_calls = tool_calls.len(),
232            "Gateway audit: request completed"
233        );
234        Ok(())
235    }
236
237    async fn persist_tool_calls(&self, tool_calls: &[CapturedToolUse]) {
238        for (idx, tool) in tool_calls.iter().enumerate() {
239            let seq = idx as i32 + 1;
240            let trimmed = truncate_for_tool_input(&tool.tool_input);
241            if let Err(e) = self
242                .requests
243                .insert_tool_call(InsertToolCallParams {
244                    request_id: &self.ctx.ai_request_id,
245                    ai_tool_call_id: &tool.ai_tool_call_id,
246                    tool_name: &tool.tool_name,
247                    tool_input: &trimmed,
248                    sequence_number: seq,
249                })
250                .await
251            {
252                tracing::warn!(error = %e, seq, "tool_call insert failed");
253            }
254        }
255    }
256
257    pub async fn fail(&self, error: &str) -> Result<()> {
258        if let Err(e) = self
259            .requests
260            .update_error(&self.ctx.ai_request_id, error)
261            .await
262        {
263            tracing::warn!(error = %e, "audit fail update failed");
264        }
265        tracing::info!(
266            ai_request_id = %self.ctx.ai_request_id,
267            user_id = %self.ctx.user_id,
268            provider = %self.ctx.provider,
269            model = %self.ctx.model,
270            error,
271            "Gateway audit: request failed"
272        );
273        Ok(())
274    }
275}