// systemprompt_api/services/gateway/audit.rs
#[path = "audit_internal/payload.rs"]
mod payload;

use payload::{slice_payload, truncate_for_tool_input};
5
6use std::sync::Arc;
7use std::time::Instant;
8
9use anyhow::Result;
10use bytes::Bytes;
11use systemprompt_ai::models::ai_request_record::AiRequestRecord;
12use systemprompt_ai::repository::ai_requests::UpdateCompletionParams;
13use systemprompt_ai::repository::{
14 AiRequestPayloadRepository, AiRequestRepository, InsertToolCallParams, UpsertPayloadParams,
15};
16use systemprompt_database::DbPool;
17use systemprompt_identifiers::{AiRequestId, SessionId, TenantId, TraceId, UserId};
18
19use super::captures::{CapturedToolUse, CapturedUsage};
20use super::models::AnthropicGatewayRequest;
21use super::pricing;
22use std::sync::Mutex;
23
/// Per-request metadata captured when a gateway request is opened.
///
/// Drives the audit record inserted at `open`, pricing lookup at `complete`,
/// and the structured log fields emitted on completion/failure.
#[derive(Debug, Clone)]
pub struct GatewayRequestContext {
    /// Identifier of the audit row created for this AI request.
    pub ai_request_id: AiRequestId,
    /// User on whose behalf the request is made.
    pub user_id: UserId,
    /// Tenant scope, when the request is tenant-bound.
    pub tenant_id: Option<TenantId>,
    /// Session the request belongs to, if any.
    pub session_id: Option<SessionId>,
    /// Distributed trace id, if one was propagated.
    pub trace_id: Option<TraceId>,
    /// Upstream provider name (used for pricing lookup).
    pub provider: String,
    /// Model requested by the caller; the provider may serve a different one
    /// (see `GatewayAudit::set_served_model`).
    pub model: String,
    /// Caller-specified max output tokens, if any.
    pub max_tokens: Option<u32>,
    /// Whether the response is streamed.
    pub is_streaming: bool,
}
36
/// Persists the audit trail for a single gateway request: the request record,
/// raw request/response payloads, flattened messages, tool calls, and
/// completion metrics (tokens, cost, latency).
// NOTE(review): Debug is suppressed — presumably the repository types do not
// implement Debug; confirm before removing the allow.
#[allow(missing_debug_implementations)]
pub struct GatewayAudit {
    // Repository for ai_request rows, messages, and tool calls.
    requests: Arc<AiRequestRepository>,
    // Repository for raw request/response payload bodies.
    payloads: Arc<AiRequestPayloadRepository>,
    /// Context captured at request start; public so handlers can read ids.
    pub ctx: GatewayRequestContext,
    // Model actually served by the provider, when it differs from the
    // requested one (set via `set_served_model`).
    served_model: Mutex<Option<String>>,
    // Start time used to compute completion latency.
    started_at: Instant,
}
45
46impl GatewayAudit {
47 pub fn new(
48 db: &DbPool,
49 ctx: GatewayRequestContext,
50 ) -> Result<Self, systemprompt_ai::error::RepositoryError> {
51 let requests = Arc::new(AiRequestRepository::new(db)?);
52 let payloads = Arc::new(AiRequestPayloadRepository::new(db)?);
53 Ok(Self {
54 requests,
55 payloads,
56 ctx,
57 served_model: Mutex::new(None),
58 started_at: Instant::now(),
59 })
60 }
61
62 pub async fn set_served_model(&self, model: &str) {
63 if model.is_empty() || model == self.ctx.model {
64 return;
65 }
66 if let Ok(mut slot) = self.served_model.lock() {
67 *slot = Some(model.to_string());
68 }
69 if let Err(e) = self
70 .requests
71 .update_model(&self.ctx.ai_request_id, model)
72 .await
73 {
74 tracing::warn!(error = %e, "update_model failed");
75 }
76 }
77
78 fn effective_model(&self) -> String {
79 self.served_model
80 .lock()
81 .map_err(|e| {
82 tracing::warn!(error = %e, "served_model mutex poisoned");
83 e
84 })
85 .ok()
86 .and_then(|s| s.clone())
87 .unwrap_or_else(|| self.ctx.model.clone())
88 }
89
90 fn build_record(&self) -> Result<AiRequestRecord> {
91 let mut record =
92 AiRequestRecord::builder(self.ctx.ai_request_id.clone(), self.ctx.user_id.clone())
93 .provider(self.ctx.provider.clone())
94 .model(self.ctx.model.clone())
95 .streaming(self.ctx.is_streaming);
96 if let Some(t) = &self.ctx.tenant_id {
97 record = record.tenant_id(t.clone());
98 }
99 if let Some(s) = &self.ctx.session_id {
100 record = record.session_id(s.clone());
101 }
102 if let Some(t) = &self.ctx.trace_id {
103 record = record.trace_id(t.clone());
104 }
105 if let Some(mt) = self.ctx.max_tokens {
106 record = record.max_tokens(mt);
107 }
108 record.build().map_err(anyhow::Error::from)
109 }
110
111 pub async fn open(
112 &self,
113 request: &AnthropicGatewayRequest,
114 request_body: &Bytes,
115 ) -> Result<()> {
116 let record = self.build_record()?;
117
118 self.requests
119 .insert_with_id(&self.ctx.ai_request_id, &record)
120 .await?;
121
122 let (body_json, excerpt, truncated, bytes) = slice_payload(request_body);
123 if let Err(e) = self
124 .payloads
125 .upsert_request(
126 &self.ctx.ai_request_id,
127 UpsertPayloadParams {
128 body: body_json.as_ref(),
129 excerpt: excerpt.as_deref(),
130 truncated,
131 bytes: Some(bytes),
132 },
133 )
134 .await
135 {
136 tracing::warn!(error = %e, ai_request_id = %self.ctx.ai_request_id, "payload insert (request) failed");
137 }
138
139 self.persist_request_messages(request).await;
140 Ok(())
141 }
142
143 async fn persist_request_messages(&self, request: &AnthropicGatewayRequest) {
144 let mut seq = 0i32;
145 if let Some(system) = request.system.as_ref() {
146 if let Some(text) = super::flatten::flatten_system_prompt(system) {
147 if let Err(e) = self
148 .requests
149 .insert_message(&self.ctx.ai_request_id, "system", &text, seq)
150 .await
151 {
152 tracing::warn!(error = %e, "insert system message failed");
153 }
154 seq += 1;
155 }
156 }
157 for msg in &request.messages {
158 let text = super::flatten::flatten_message_content(&msg.content);
159 if let Err(e) = self
160 .requests
161 .insert_message(&self.ctx.ai_request_id, &msg.role, &text, seq)
162 .await
163 {
164 tracing::warn!(error = %e, seq, "insert message failed");
165 }
166 seq += 1;
167 }
168 }
169
170 pub async fn complete(
171 &self,
172 usage: CapturedUsage,
173 tool_calls: Vec<CapturedToolUse>,
174 response_body: &Bytes,
175 ) -> Result<()> {
176 let latency_ms = self.started_at.elapsed().as_millis().min(i32::MAX as u128) as i32;
177 let effective_model = self.effective_model();
178 let pricing_rates = pricing::lookup(&self.ctx.provider, &effective_model);
179 let cost =
180 pricing::cost_microdollars(pricing_rates, usage.input_tokens, usage.output_tokens);
181
182 self.requests
183 .update_completion(UpdateCompletionParams {
184 id: self.ctx.ai_request_id.clone(),
185 tokens_used: (usage.input_tokens + usage.output_tokens) as i32,
186 input_tokens: usage.input_tokens as i32,
187 output_tokens: usage.output_tokens as i32,
188 cost_microdollars: cost,
189 latency_ms,
190 })
191 .await?;
192
193 self.persist_tool_calls(&tool_calls).await;
194
195 let (body_json, excerpt, truncated, bytes) = slice_payload(response_body);
196 if let Err(e) = self
197 .payloads
198 .upsert_response(
199 &self.ctx.ai_request_id,
200 UpsertPayloadParams {
201 body: body_json.as_ref(),
202 excerpt: excerpt.as_deref(),
203 truncated,
204 bytes: Some(bytes),
205 },
206 )
207 .await
208 {
209 tracing::warn!(error = %e, ai_request_id = %self.ctx.ai_request_id, "payload insert (response) failed");
210 }
211
212 if let Some(assistant_text) = super::parse::extract_assistant_text(response_body) {
213 if let Err(e) = self
214 .requests
215 .add_response_message(&self.ctx.ai_request_id, &assistant_text)
216 .await
217 {
218 tracing::warn!(error = %e, "assistant response message insert failed");
219 }
220 }
221
222 tracing::info!(
223 ai_request_id = %self.ctx.ai_request_id,
224 user_id = %self.ctx.user_id,
225 provider = %self.ctx.provider,
226 model = %effective_model,
227 input_tokens = usage.input_tokens,
228 output_tokens = usage.output_tokens,
229 cost_microdollars = cost,
230 latency_ms,
231 tool_calls = tool_calls.len(),
232 "Gateway audit: request completed"
233 );
234 Ok(())
235 }
236
237 async fn persist_tool_calls(&self, tool_calls: &[CapturedToolUse]) {
238 for (idx, tool) in tool_calls.iter().enumerate() {
239 let seq = idx as i32 + 1;
240 let trimmed = truncate_for_tool_input(&tool.tool_input);
241 if let Err(e) = self
242 .requests
243 .insert_tool_call(InsertToolCallParams {
244 request_id: &self.ctx.ai_request_id,
245 ai_tool_call_id: &tool.ai_tool_call_id,
246 tool_name: &tool.tool_name,
247 tool_input: &trimmed,
248 sequence_number: seq,
249 })
250 .await
251 {
252 tracing::warn!(error = %e, seq, "tool_call insert failed");
253 }
254 }
255 }
256
257 pub async fn fail(&self, error: &str) -> Result<()> {
258 if let Err(e) = self
259 .requests
260 .update_error(&self.ctx.ai_request_id, error)
261 .await
262 {
263 tracing::warn!(error = %e, "audit fail update failed");
264 }
265 tracing::info!(
266 ai_request_id = %self.ctx.ai_request_id,
267 user_id = %self.ctx.user_id,
268 provider = %self.ctx.provider,
269 model = %self.ctx.model,
270 error,
271 "Gateway audit: request failed"
272 );
273 Ok(())
274 }
275}