systemprompt_api/services/gateway/
audit.rs1#[path = "audit_internal/payload.rs"]
2mod payload;
3
4use payload::{slice_payload, truncate_for_tool_input};
5
6use std::sync::Arc;
7use std::time::Instant;
8
9use anyhow::Result;
10use bytes::Bytes;
11use systemprompt_ai::models::ai_request_record::AiRequestRecord;
12use systemprompt_ai::repository::ai_requests::UpdateCompletionParams;
13use systemprompt_ai::repository::{
14 AiRequestPayloadRepository, AiRequestRepository, InsertToolCallParams, UpsertPayloadParams,
15};
16use systemprompt_database::DbPool;
17use systemprompt_identifiers::{
18 AiRequestId, ContextId, GatewayConversationId, SessionId, TenantId, TraceId, UserId,
19};
20
21use super::captures::{CapturedToolUse, CapturedUsage};
22use super::pricing;
23use super::protocol::canonical::{CanonicalContent, CanonicalRequest, Role};
24use super::protocol::canonical_response::CanonicalResponse;
25use std::sync::Mutex;
26
27#[derive(Debug, Clone)]
28pub struct GatewayRequestContext {
29 pub ai_request_id: AiRequestId,
30 pub user_id: UserId,
31 pub tenant_id: Option<TenantId>,
32 pub session_id: Option<SessionId>,
33 pub context_id: ContextId,
34 pub gateway_conversation_id: Option<GatewayConversationId>,
35 pub trace_id: Option<TraceId>,
36 pub provider: String,
37 pub model: String,
38 pub max_tokens: Option<u32>,
39 pub is_streaming: bool,
40 pub wire_protocol: String,
41}
42
43#[allow(missing_debug_implementations)]
44pub struct GatewayAudit {
45 requests: Arc<AiRequestRepository>,
46 payloads: Arc<AiRequestPayloadRepository>,
47 pub ctx: GatewayRequestContext,
48 served_model: Mutex<Option<String>>,
49 started_at: Instant,
50}
51
52impl GatewayAudit {
53 pub fn new(
54 db: &DbPool,
55 ctx: GatewayRequestContext,
56 ) -> Result<Self, systemprompt_ai::error::RepositoryError> {
57 let requests = Arc::new(AiRequestRepository::new(db)?);
58 let payloads = Arc::new(AiRequestPayloadRepository::new(db)?);
59 Ok(Self {
60 requests,
61 payloads,
62 ctx,
63 served_model: Mutex::new(None),
64 started_at: Instant::now(),
65 })
66 }
67
68 pub async fn set_served_model(&self, model: &str) {
69 if model.is_empty() || model == self.ctx.model {
70 return;
71 }
72 if let Ok(mut slot) = self.served_model.lock() {
73 *slot = Some(model.to_string());
74 }
75 if let Err(e) = self
76 .requests
77 .update_model(&self.ctx.ai_request_id, model)
78 .await
79 {
80 tracing::warn!(error = %e, "update_model failed");
81 }
82 }
83
84 fn effective_model(&self) -> String {
85 self.served_model
86 .lock()
87 .map_err(|e| {
88 tracing::warn!(error = %e, "served_model mutex poisoned");
89 e
90 })
91 .ok()
92 .and_then(|s| s.clone())
93 .unwrap_or_else(|| self.ctx.model.clone())
94 }
95
96 fn build_record(&self) -> Result<AiRequestRecord> {
97 let mut record =
98 AiRequestRecord::builder(self.ctx.ai_request_id.clone(), self.ctx.user_id.clone())
99 .provider(self.ctx.provider.clone())
100 .model(self.ctx.model.clone())
101 .streaming(self.ctx.is_streaming);
102 if let Some(t) = &self.ctx.tenant_id {
103 record = record.tenant_id(t.clone());
104 }
105 if let Some(s) = &self.ctx.session_id {
106 record = record.session_id(s.clone());
107 }
108 record = record.context_id(self.ctx.context_id.clone());
109 if let Some(g) = &self.ctx.gateway_conversation_id {
110 record = record.gateway_conversation_id(g.clone());
111 }
112 if let Some(t) = &self.ctx.trace_id {
113 record = record.trace_id(t.clone());
114 }
115 if let Some(mt) = self.ctx.max_tokens {
116 record = record.max_tokens(mt);
117 }
118 record.build().map_err(anyhow::Error::from)
119 }
120
121 pub async fn open(&self, request: &CanonicalRequest, request_body: &Bytes) -> Result<()> {
122 let record = self.build_record()?;
123
124 self.requests
125 .insert_with_id(&self.ctx.ai_request_id, &record)
126 .await?;
127
128 let (body_json, excerpt, truncated, bytes) = slice_payload(request_body);
129 if let Err(e) = self
130 .payloads
131 .upsert_request(
132 &self.ctx.ai_request_id,
133 UpsertPayloadParams {
134 body: body_json.as_ref(),
135 excerpt: excerpt.as_deref(),
136 truncated,
137 bytes: Some(bytes),
138 },
139 )
140 .await
141 {
142 tracing::warn!(error = %e, ai_request_id = %self.ctx.ai_request_id, "payload insert (request) failed");
143 }
144
145 self.persist_request_messages(request).await;
146 Ok(())
147 }
148
149 async fn persist_request_messages(&self, request: &CanonicalRequest) {
150 let mut seq = 0i32;
151 if let Some(system) = &request.system {
152 if !system.is_empty() {
153 if let Err(e) = self
154 .requests
155 .insert_message(&self.ctx.ai_request_id, "system", system, seq)
156 .await
157 {
158 tracing::warn!(error = %e, "insert system message failed");
159 }
160 seq += 1;
161 }
162 }
163 for msg in &request.messages {
164 let role = match msg.role {
165 Role::System => "system",
166 Role::User => "user",
167 Role::Assistant => "assistant",
168 Role::Tool => "tool",
169 };
170 let text = flatten_message_content(&msg.content);
171 if let Err(e) = self
172 .requests
173 .insert_message(&self.ctx.ai_request_id, role, &text, seq)
174 .await
175 {
176 tracing::warn!(error = %e, seq, "insert message failed");
177 }
178 seq += 1;
179 }
180 }
181
182 pub async fn complete(
183 &self,
184 usage: CapturedUsage,
185 tool_calls: Vec<CapturedToolUse>,
186 response: &CanonicalResponse,
187 response_body: &Bytes,
188 ) -> Result<()> {
189 let latency_ms = self.started_at.elapsed().as_millis().min(i32::MAX as u128) as i32;
190 let effective_model = self.effective_model();
191 let profile = systemprompt_config::ProfileBootstrap::get().ok();
192 let gateway = profile.as_ref().and_then(|p| p.gateway.as_ref());
193 let pricing_rates = pricing::resolve(&self.ctx.provider, &effective_model, gateway);
194 let cost =
195 pricing::cost_microdollars(pricing_rates, usage.input_tokens, usage.output_tokens);
196
197 self.requests
198 .update_completion(UpdateCompletionParams {
199 id: self.ctx.ai_request_id.clone(),
200 tokens_used: (usage.input_tokens + usage.output_tokens) as i32,
201 input_tokens: usage.input_tokens as i32,
202 output_tokens: usage.output_tokens as i32,
203 cost_microdollars: cost,
204 latency_ms,
205 })
206 .await?;
207
208 self.persist_tool_calls(&tool_calls).await;
209
210 let (body_json, excerpt, truncated, bytes) = slice_payload(response_body);
211 if let Err(e) = self
212 .payloads
213 .upsert_response(
214 &self.ctx.ai_request_id,
215 UpsertPayloadParams {
216 body: body_json.as_ref(),
217 excerpt: excerpt.as_deref(),
218 truncated,
219 bytes: Some(bytes),
220 },
221 )
222 .await
223 {
224 tracing::warn!(error = %e, ai_request_id = %self.ctx.ai_request_id, "payload insert (response) failed");
225 }
226
227 if let Some(assistant_text) = super::parse::extract_assistant_text(response) {
228 if let Err(e) = self
229 .requests
230 .add_response_message(&self.ctx.ai_request_id, &assistant_text)
231 .await
232 {
233 tracing::warn!(error = %e, "assistant response message insert failed");
234 }
235 }
236
237 tracing::info!(
238 ai_request_id = %self.ctx.ai_request_id,
239 user_id = %self.ctx.user_id,
240 provider = %self.ctx.provider,
241 model = %effective_model,
242 wire_protocol = %self.ctx.wire_protocol,
243 input_tokens = usage.input_tokens,
244 output_tokens = usage.output_tokens,
245 cost_microdollars = cost,
246 latency_ms,
247 tool_calls = tool_calls.len(),
248 "Gateway audit: request completed"
249 );
250 Ok(())
251 }
252
253 async fn persist_tool_calls(&self, tool_calls: &[CapturedToolUse]) {
254 for (idx, tool) in tool_calls.iter().enumerate() {
255 let seq = idx as i32 + 1;
256 let trimmed = truncate_for_tool_input(&tool.tool_input);
257 if let Err(e) = self
258 .requests
259 .insert_tool_call(InsertToolCallParams {
260 request_id: &self.ctx.ai_request_id,
261 ai_tool_call_id: &tool.ai_tool_call_id,
262 tool_name: &tool.tool_name,
263 tool_input: &trimmed,
264 sequence_number: seq,
265 })
266 .await
267 {
268 tracing::warn!(error = %e, seq, "tool_call insert failed");
269 }
270 }
271 }
272
273 pub async fn fail(&self, error: &str) -> Result<()> {
274 if let Err(e) = self
275 .requests
276 .update_error(&self.ctx.ai_request_id, error)
277 .await
278 {
279 tracing::warn!(error = %e, "audit fail update failed");
280 }
281 tracing::info!(
282 ai_request_id = %self.ctx.ai_request_id,
283 user_id = %self.ctx.user_id,
284 provider = %self.ctx.provider,
285 model = %self.ctx.model,
286 error,
287 "Gateway audit: request failed"
288 );
289 Ok(())
290 }
291}
292
293fn flatten_message_content(parts: &[CanonicalContent]) -> String {
294 let mut out = String::new();
295 for part in parts {
296 match part {
297 CanonicalContent::Text(t) => push_with_sep(&mut out, t),
298 CanonicalContent::Thinking { text, .. } => push_with_sep(&mut out, text),
299 CanonicalContent::ToolUse { name, input, .. } => {
300 push_with_sep(&mut out, &format!("[tool_use:{name} {input}]"));
301 },
302 CanonicalContent::ToolResult { content, .. } => {
303 for inner in content {
304 if let CanonicalContent::Text(t) = inner {
305 push_with_sep(&mut out, t);
306 }
307 }
308 },
309 CanonicalContent::Image(_) => {},
310 }
311 }
312 out
313}
314
/// Appends `fragment` to `out`, inserting a newline first when `out` already
/// has content. An empty `fragment` leaves `out` untouched.
fn push_with_sep(out: &mut String, fragment: &str) {
    if !fragment.is_empty() {
        let needs_separator = !out.is_empty();
        if needs_separator {
            out.push('\n');
        }
        out.push_str(fragment);
    }
}