systemprompt_api/services/gateway/
audit.rs1use std::sync::Arc;
2use std::time::Instant;
3
4use anyhow::Result;
5use bytes::Bytes;
6use serde_json::Value;
7use systemprompt_ai::models::ai_request_record::AiRequestRecord;
8use systemprompt_ai::repository::ai_requests::UpdateCompletionParams;
9use systemprompt_ai::repository::{
10 AiRequestPayloadRepository, AiRequestRepository, InsertToolCallParams, UpsertPayloadParams,
11};
12use systemprompt_database::DbPool;
13use systemprompt_identifiers::{AiRequestId, SessionId, TenantId, TraceId, UserId};
14
15use super::pricing::{self, ModelPricing};
16
/// Largest body, in bytes, that `slice_payload` stores in full; bigger bodies
/// are reduced to a head/tail excerpt.
const PAYLOAD_CAP_BYTES: usize = 256 * 1024;
/// Bytes kept from each end of an oversized body when `slice_payload` builds
/// its excerpt.
const EXCERPT_BYTES: usize = 8 * 1024;
19
20#[derive(Debug, Clone)]
21pub struct GatewayRequestContext {
22 pub ai_request_id: AiRequestId,
23 pub user_id: UserId,
24 pub tenant_id: Option<TenantId>,
25 pub session_id: Option<SessionId>,
26 pub trace_id: Option<TraceId>,
27 pub provider: String,
28 pub model: String,
29 pub max_tokens: Option<u32>,
30 pub is_streaming: bool,
31}
32
/// Token counts captured from an upstream response.
///
/// Both counts default to zero.
#[derive(Clone, Copy, Debug, Default)]
pub struct CapturedUsage {
    /// Prompt-side tokens reported by the provider.
    pub input_tokens: u32,
    /// Completion-side tokens reported by the provider.
    pub output_tokens: u32,
}
38
/// A tool invocation captured from an upstream response, to be stored in the
/// audit log.
#[derive(Clone, Debug)]
pub struct CapturedToolUse {
    /// Tool-call id as reported in the response (presumably provider-assigned).
    pub ai_tool_call_id: String,
    /// Name of the tool that was called.
    pub tool_name: String,
    /// Raw tool input text; capped in size before it is persisted.
    pub tool_input: String,
}
45
46#[derive(Clone, Debug)]
47pub struct GatewayAudit {
48 requests: Arc<AiRequestRepository>,
49 payloads: Arc<AiRequestPayloadRepository>,
50 pub ctx: GatewayRequestContext,
51 pricing: ModelPricing,
52 started_at: Instant,
53}
54
55impl GatewayAudit {
56 pub fn new(
57 db: &DbPool,
58 ctx: GatewayRequestContext,
59 ) -> Result<Self, systemprompt_ai::error::RepositoryError> {
60 let requests = Arc::new(AiRequestRepository::new(db)?);
61 let payloads = Arc::new(AiRequestPayloadRepository::new(db)?);
62 let pricing = pricing::lookup(&ctx.provider, &ctx.model);
63 Ok(Self {
64 requests,
65 payloads,
66 ctx,
67 pricing,
68 started_at: Instant::now(),
69 })
70 }
71
72 pub async fn open(&self, request_body: &Bytes) -> Result<()> {
73 let record =
74 AiRequestRecord::builder(self.ctx.ai_request_id.as_str(), self.ctx.user_id.clone())
75 .provider(self.ctx.provider.clone())
76 .model(self.ctx.model.clone())
77 .streaming(self.ctx.is_streaming);
78 let record = if let Some(t) = &self.ctx.tenant_id {
79 record.tenant_id(t.clone())
80 } else {
81 record
82 };
83 let record = if let Some(s) = &self.ctx.session_id {
84 record.session_id(s.clone())
85 } else {
86 record
87 };
88 let record = if let Some(t) = &self.ctx.trace_id {
89 record.trace_id(t.clone())
90 } else {
91 record
92 };
93 let record = if let Some(mt) = self.ctx.max_tokens {
94 record.max_tokens(mt)
95 } else {
96 record
97 };
98 let record = record.build()?;
99
100 self.requests
101 .insert_with_id(&self.ctx.ai_request_id, &record)
102 .await?;
103
104 let (body_json, excerpt, truncated, bytes) = slice_payload(request_body);
105 if let Err(e) = self
106 .payloads
107 .upsert_request(
108 &self.ctx.ai_request_id,
109 UpsertPayloadParams {
110 body: body_json.as_ref(),
111 excerpt: excerpt.as_deref(),
112 truncated,
113 bytes: Some(bytes),
114 },
115 )
116 .await
117 {
118 tracing::warn!(error = %e, ai_request_id = %self.ctx.ai_request_id, "payload insert (request) failed");
119 }
120 Ok(())
121 }
122
123 pub async fn complete(
124 &self,
125 usage: CapturedUsage,
126 tool_calls: Vec<CapturedToolUse>,
127 response_body: &Bytes,
128 ) -> Result<()> {
129 let latency_ms = self.started_at.elapsed().as_millis().min(i32::MAX as u128) as i32;
130 let cost =
131 pricing::cost_microdollars(self.pricing, usage.input_tokens, usage.output_tokens);
132
133 self.requests
134 .update_completion(UpdateCompletionParams {
135 id: self.ctx.ai_request_id.clone(),
136 tokens_used: (usage.input_tokens + usage.output_tokens) as i32,
137 input_tokens: usage.input_tokens as i32,
138 output_tokens: usage.output_tokens as i32,
139 cost_microdollars: cost,
140 latency_ms,
141 })
142 .await?;
143
144 for (idx, tool) in tool_calls.iter().enumerate() {
145 let seq = idx as i32 + 1;
146 let trimmed = truncate_for_tool_input(&tool.tool_input);
147 if let Err(e) = self
148 .requests
149 .insert_tool_call(InsertToolCallParams {
150 request_id: &self.ctx.ai_request_id,
151 ai_tool_call_id: &tool.ai_tool_call_id,
152 tool_name: &tool.tool_name,
153 tool_input: &trimmed,
154 sequence_number: seq,
155 })
156 .await
157 {
158 tracing::warn!(error = %e, seq, "tool_call insert failed");
159 }
160 }
161
162 let (body_json, excerpt, truncated, bytes) = slice_payload(response_body);
163 if let Err(e) = self
164 .payloads
165 .upsert_response(
166 &self.ctx.ai_request_id,
167 UpsertPayloadParams {
168 body: body_json.as_ref(),
169 excerpt: excerpt.as_deref(),
170 truncated,
171 bytes: Some(bytes),
172 },
173 )
174 .await
175 {
176 tracing::warn!(error = %e, ai_request_id = %self.ctx.ai_request_id, "payload insert (response) failed");
177 }
178
179 tracing::info!(
180 ai_request_id = %self.ctx.ai_request_id,
181 user_id = %self.ctx.user_id,
182 provider = %self.ctx.provider,
183 model = %self.ctx.model,
184 input_tokens = usage.input_tokens,
185 output_tokens = usage.output_tokens,
186 cost_microdollars = cost,
187 latency_ms,
188 tool_calls = tool_calls.len(),
189 "Gateway audit: request completed"
190 );
191 Ok(())
192 }
193
194 pub async fn fail(&self, error: &str) -> Result<()> {
195 if let Err(e) = self
196 .requests
197 .update_error(&self.ctx.ai_request_id, error)
198 .await
199 {
200 tracing::warn!(error = %e, "audit fail update failed");
201 }
202 tracing::info!(
203 ai_request_id = %self.ctx.ai_request_id,
204 user_id = %self.ctx.user_id,
205 provider = %self.ctx.provider,
206 model = %self.ctx.model,
207 error,
208 "Gateway audit: request failed"
209 );
210 Ok(())
211 }
212}
213
214fn slice_payload(bytes: &Bytes) -> (Option<Value>, Option<String>, bool, i32) {
215 let len = bytes.len();
216 let len_i32 = len.min(i32::MAX as usize) as i32;
217 if len <= PAYLOAD_CAP_BYTES {
218 serde_json::from_slice::<Value>(bytes).map_or_else(
219 |_| {
220 let excerpt = String::from_utf8_lossy(bytes).to_string();
221 (None, Some(excerpt), false, len_i32)
222 },
223 |v| (Some(v), None, false, len_i32),
224 )
225 } else {
226 let head_len = EXCERPT_BYTES.min(len);
227 let head = String::from_utf8_lossy(&bytes[..head_len]).to_string();
228 let tail_start = len.saturating_sub(EXCERPT_BYTES);
229 let tail = String::from_utf8_lossy(&bytes[tail_start..]).to_string();
230 let excerpt = format!("{head}\n...<truncated {} bytes>...\n{tail}", len - head_len);
231 (None, Some(excerpt), true, len_i32)
232 }
233}
234
/// Caps tool input text at 64 KiB for storage.
///
/// Inputs within the cap are returned unchanged. Longer inputs keep their
/// longest prefix that ends on a UTF-8 character boundary at or before the
/// cap, followed by a marker giving the number of bytes dropped.
fn truncate_for_tool_input(input: &str) -> String {
    const TOOL_INPUT_CAP: usize = 64 * 1024;
    if input.len() <= TOOL_INPUT_CAP {
        return input.to_string();
    }
    // Slicing a `str` inside a multi-byte character panics, so back off to the
    // previous boundary. Index 0 is always a boundary, so this cannot underflow.
    let mut cut = TOOL_INPUT_CAP;
    while !input.is_char_boundary(cut) {
        cut -= 1;
    }
    let head = &input[..cut];
    format!("{head}...<truncated {} bytes>", input.len() - cut)
}