// systemprompt_api/services/gateway/audit.rs
use std::sync::Arc;
2use std::time::Instant;
3
4use anyhow::Result;
5use bytes::Bytes;
6use serde_json::Value;
7use systemprompt_ai::models::ai_request_record::AiRequestRecord;
8use systemprompt_ai::repository::ai_requests::UpdateCompletionParams;
9use systemprompt_ai::repository::{
10 AiRequestPayloadRepository, AiRequestRepository, InsertToolCallParams, UpsertPayloadParams,
11};
12use systemprompt_database::DbPool;
13use systemprompt_identifiers::{AiRequestId, SessionId, TenantId, TraceId, UserId};
14
15use super::models::AnthropicGatewayRequest;
16use super::pricing;
17use std::sync::Mutex;
18
/// Payloads up to this size are stored as parsed JSON; larger bodies keep only
/// a head/tail excerpt (see `slice_payload`).
const PAYLOAD_CAP_BYTES: usize = 256 * 1024;
/// Size of each of the head and tail excerpts kept for oversized payloads.
const EXCERPT_BYTES: usize = 8 * 1024;
21
/// Per-request metadata captured when a gateway call begins, used to attribute
/// the audit record (who made the request, against which provider and model).
#[derive(Debug, Clone)]
pub struct GatewayRequestContext {
    /// Identifier of the audit row for this AI request.
    pub ai_request_id: AiRequestId,
    /// User on whose behalf the request is made.
    pub user_id: UserId,
    /// Tenant, when the request is tenant-scoped.
    pub tenant_id: Option<TenantId>,
    /// Session the request belongs to, if any.
    pub session_id: Option<SessionId>,
    /// Distributed-trace id, if one was propagated.
    pub trace_id: Option<TraceId>,
    /// Upstream provider name (also used for pricing lookup).
    pub provider: String,
    /// Model requested by the caller; the served model may differ.
    pub model: String,
    /// Caller-supplied max output tokens, if any.
    pub max_tokens: Option<u32>,
    /// Whether the response is streamed.
    pub is_streaming: bool,
}
34
/// Token counts captured from the provider's response.
#[derive(Debug, Clone, Copy, Default)]
pub struct CapturedUsage {
    /// Prompt tokens consumed.
    pub input_tokens: u32,
    /// Completion tokens produced.
    pub output_tokens: u32,
}
40
/// A single tool invocation observed in the provider's response.
#[derive(Debug, Clone)]
pub struct CapturedToolUse {
    /// Provider-assigned id of the tool call.
    pub ai_tool_call_id: String,
    /// Name of the tool invoked.
    pub tool_name: String,
    /// Raw tool input text; truncated before persistence (see
    /// `truncate_for_tool_input`).
    pub tool_input: String,
}
47
/// Persists the audit trail of a single gateway request: the request row,
/// raw payloads, flattened messages, tool calls, and completion metrics.
#[allow(missing_debug_implementations)]
pub struct GatewayAudit {
    /// Repository for request rows, messages and tool calls.
    requests: Arc<AiRequestRepository>,
    /// Repository for raw request/response payloads.
    payloads: Arc<AiRequestPayloadRepository>,
    /// Immutable per-request context captured at construction.
    pub ctx: GatewayRequestContext,
    /// Model the provider actually served, when it differs from the request.
    served_model: Mutex<Option<String>>,
    /// Construction time, used to compute latency at completion.
    started_at: Instant,
}
56
57impl GatewayAudit {
58 pub fn new(
59 db: &DbPool,
60 ctx: GatewayRequestContext,
61 ) -> Result<Self, systemprompt_ai::error::RepositoryError> {
62 let requests = Arc::new(AiRequestRepository::new(db)?);
63 let payloads = Arc::new(AiRequestPayloadRepository::new(db)?);
64 Ok(Self {
65 requests,
66 payloads,
67 ctx,
68 served_model: Mutex::new(None),
69 started_at: Instant::now(),
70 })
71 }
72
73 pub async fn set_served_model(&self, model: &str) {
74 if model.is_empty() || model == self.ctx.model {
75 return;
76 }
77 if let Ok(mut slot) = self.served_model.lock() {
78 *slot = Some(model.to_string());
79 }
80 if let Err(e) = self
81 .requests
82 .update_model(&self.ctx.ai_request_id, model)
83 .await
84 {
85 tracing::warn!(error = %e, "update_model failed");
86 }
87 }
88
89 fn effective_model(&self) -> String {
90 self.served_model
91 .lock()
92 .ok()
93 .and_then(|s| s.clone())
94 .unwrap_or_else(|| self.ctx.model.clone())
95 }
96
97 fn build_record(&self) -> Result<AiRequestRecord> {
98 let mut record =
99 AiRequestRecord::builder(self.ctx.ai_request_id.as_str(), self.ctx.user_id.clone())
100 .provider(self.ctx.provider.clone())
101 .model(self.ctx.model.clone())
102 .streaming(self.ctx.is_streaming);
103 if let Some(t) = &self.ctx.tenant_id {
104 record = record.tenant_id(t.clone());
105 }
106 if let Some(s) = &self.ctx.session_id {
107 record = record.session_id(s.clone());
108 }
109 if let Some(t) = &self.ctx.trace_id {
110 record = record.trace_id(t.clone());
111 }
112 if let Some(mt) = self.ctx.max_tokens {
113 record = record.max_tokens(mt);
114 }
115 record.build().map_err(anyhow::Error::from)
116 }
117
118 pub async fn open(
119 &self,
120 request: &AnthropicGatewayRequest,
121 request_body: &Bytes,
122 ) -> Result<()> {
123 let record = self.build_record()?;
124
125 self.requests
126 .insert_with_id(&self.ctx.ai_request_id, &record)
127 .await?;
128
129 let (body_json, excerpt, truncated, bytes) = slice_payload(request_body);
130 if let Err(e) = self
131 .payloads
132 .upsert_request(
133 &self.ctx.ai_request_id,
134 UpsertPayloadParams {
135 body: body_json.as_ref(),
136 excerpt: excerpt.as_deref(),
137 truncated,
138 bytes: Some(bytes),
139 },
140 )
141 .await
142 {
143 tracing::warn!(error = %e, ai_request_id = %self.ctx.ai_request_id, "payload insert (request) failed");
144 }
145
146 self.persist_request_messages(request).await;
147 Ok(())
148 }
149
150 async fn persist_request_messages(&self, request: &AnthropicGatewayRequest) {
151 let mut seq = 0i32;
152 if let Some(system) = request.system.as_ref() {
153 if let Some(text) = super::flatten::flatten_system_prompt(system) {
154 if let Err(e) = self
155 .requests
156 .insert_message(&self.ctx.ai_request_id, "system", &text, seq)
157 .await
158 {
159 tracing::warn!(error = %e, "insert system message failed");
160 }
161 seq += 1;
162 }
163 }
164 for msg in &request.messages {
165 let text = super::flatten::flatten_message_content(&msg.content);
166 if let Err(e) = self
167 .requests
168 .insert_message(&self.ctx.ai_request_id, &msg.role, &text, seq)
169 .await
170 {
171 tracing::warn!(error = %e, seq, "insert message failed");
172 }
173 seq += 1;
174 }
175 }
176
177 pub async fn complete(
178 &self,
179 usage: CapturedUsage,
180 tool_calls: Vec<CapturedToolUse>,
181 response_body: &Bytes,
182 ) -> Result<()> {
183 let latency_ms = self.started_at.elapsed().as_millis().min(i32::MAX as u128) as i32;
184 let effective_model = self.effective_model();
185 let pricing_rates = pricing::lookup(&self.ctx.provider, &effective_model);
186 let cost =
187 pricing::cost_microdollars(pricing_rates, usage.input_tokens, usage.output_tokens);
188
189 self.requests
190 .update_completion(UpdateCompletionParams {
191 id: self.ctx.ai_request_id.clone(),
192 tokens_used: (usage.input_tokens + usage.output_tokens) as i32,
193 input_tokens: usage.input_tokens as i32,
194 output_tokens: usage.output_tokens as i32,
195 cost_microdollars: cost,
196 latency_ms,
197 })
198 .await?;
199
200 self.persist_tool_calls(&tool_calls).await;
201
202 let (body_json, excerpt, truncated, bytes) = slice_payload(response_body);
203 if let Err(e) = self
204 .payloads
205 .upsert_response(
206 &self.ctx.ai_request_id,
207 UpsertPayloadParams {
208 body: body_json.as_ref(),
209 excerpt: excerpt.as_deref(),
210 truncated,
211 bytes: Some(bytes),
212 },
213 )
214 .await
215 {
216 tracing::warn!(error = %e, ai_request_id = %self.ctx.ai_request_id, "payload insert (response) failed");
217 }
218
219 if let Some(assistant_text) = super::parse::extract_assistant_text(response_body) {
220 if let Err(e) = self
221 .requests
222 .add_response_message(&self.ctx.ai_request_id, &assistant_text)
223 .await
224 {
225 tracing::warn!(error = %e, "assistant response message insert failed");
226 }
227 }
228
229 tracing::info!(
230 ai_request_id = %self.ctx.ai_request_id,
231 user_id = %self.ctx.user_id,
232 provider = %self.ctx.provider,
233 model = %effective_model,
234 input_tokens = usage.input_tokens,
235 output_tokens = usage.output_tokens,
236 cost_microdollars = cost,
237 latency_ms,
238 tool_calls = tool_calls.len(),
239 "Gateway audit: request completed"
240 );
241 Ok(())
242 }
243
244 async fn persist_tool_calls(&self, tool_calls: &[CapturedToolUse]) {
245 for (idx, tool) in tool_calls.iter().enumerate() {
246 let seq = idx as i32 + 1;
247 let trimmed = truncate_for_tool_input(&tool.tool_input);
248 if let Err(e) = self
249 .requests
250 .insert_tool_call(InsertToolCallParams {
251 request_id: &self.ctx.ai_request_id,
252 ai_tool_call_id: &tool.ai_tool_call_id,
253 tool_name: &tool.tool_name,
254 tool_input: &trimmed,
255 sequence_number: seq,
256 })
257 .await
258 {
259 tracing::warn!(error = %e, seq, "tool_call insert failed");
260 }
261 }
262 }
263
264 pub async fn fail(&self, error: &str) -> Result<()> {
265 if let Err(e) = self
266 .requests
267 .update_error(&self.ctx.ai_request_id, error)
268 .await
269 {
270 tracing::warn!(error = %e, "audit fail update failed");
271 }
272 tracing::info!(
273 ai_request_id = %self.ctx.ai_request_id,
274 user_id = %self.ctx.user_id,
275 provider = %self.ctx.provider,
276 model = %self.ctx.model,
277 error,
278 "Gateway audit: request failed"
279 );
280 Ok(())
281 }
282}
283
284fn slice_payload(bytes: &Bytes) -> (Option<Value>, Option<String>, bool, i32) {
285 let len = bytes.len();
286 let len_i32 = len.min(i32::MAX as usize) as i32;
287 if len <= PAYLOAD_CAP_BYTES {
288 serde_json::from_slice::<Value>(bytes).map_or_else(
289 |_| {
290 let excerpt = String::from_utf8_lossy(bytes).to_string();
291 (None, Some(excerpt), false, len_i32)
292 },
293 |v| (Some(v), None, false, len_i32),
294 )
295 } else {
296 let head_len = EXCERPT_BYTES.min(len);
297 let head = String::from_utf8_lossy(&bytes[..head_len]).to_string();
298 let tail_start = len.saturating_sub(EXCERPT_BYTES);
299 let tail = String::from_utf8_lossy(&bytes[tail_start..]).to_string();
300 let excerpt = format!("{head}\n...<truncated {} bytes>...\n{tail}", len - head_len);
301 (None, Some(excerpt), true, len_i32)
302 }
303}
304
/// Caps a tool-call input at 64 KiB, appending a marker with the number of
/// bytes dropped.
///
/// The cut point is backed off to a UTF-8 character boundary so slicing can
/// never panic on multi-byte input (the old `&input[..TOOL_INPUT_CAP]` would
/// panic whenever byte 65536 fell inside a multi-byte character).
fn truncate_for_tool_input(input: &str) -> String {
    const TOOL_INPUT_CAP: usize = 64 * 1024;
    if input.len() <= TOOL_INPUT_CAP {
        return input.to_string();
    }
    // Back off to the nearest char boundary at or below the cap; byte 0 is
    // always a boundary, so this loop terminates.
    let mut cut = TOOL_INPUT_CAP;
    while !input.is_char_boundary(cut) {
        cut -= 1;
    }
    let head = &input[..cut];
    format!("{head}...<truncated {} bytes>", input.len() - cut)
}