// systemprompt_api/services/gateway/audit/mod.rs

1mod message_text;
2pub mod payload;
3
4use message_text::flatten_message_content;
5use payload::{slice_payload, truncate_for_tool_input};
6
7use std::sync::Arc;
8use std::time::Instant;
9
10use anyhow::Result;
11use bytes::Bytes;
12use systemprompt_ai::models::ai_request_record::AiRequestRecord;
13use systemprompt_ai::repository::ai_requests::UpdateCompletionParams;
14use systemprompt_ai::repository::{
15    AiRequestPayloadRepository, AiRequestRepository, InsertToolCallParams, UpsertPayloadParams,
16};
17use systemprompt_database::DbPool;
18use systemprompt_identifiers::{
19    AiRequestId, ContextId, GatewayConversationId, SessionId, TraceId, UserId,
20};
21
22use super::captures::{CapturedToolUse, CapturedUsage};
23use super::pricing;
24use super::protocol::canonical::{CanonicalRequest, Role};
25use super::protocol::canonical_response::CanonicalResponse;
26use std::sync::Mutex;
27
28#[derive(Debug, Clone)]
29pub struct GatewayRequestContext {
30    pub ai_request_id: AiRequestId,
31    pub user_id: UserId,
32    pub session_id: Option<SessionId>,
33    pub context_id: ContextId,
34    pub gateway_conversation_id: Option<GatewayConversationId>,
35    pub trace_id: Option<TraceId>,
36    pub provider: String,
37    pub model: String,
38    pub max_tokens: Option<u32>,
39    pub is_streaming: bool,
40    pub wire_protocol: String,
41}
42
43#[expect(
44    missing_debug_implementations,
45    reason = "service type holds repository clients that intentionally do not implement Debug"
46)]
47pub struct GatewayAudit {
48    requests: Arc<AiRequestRepository>,
49    payloads: Arc<AiRequestPayloadRepository>,
50    pub ctx: GatewayRequestContext,
51    served_model: Mutex<Option<String>>,
52    started_at: Instant,
53}
54
55impl GatewayAudit {
56    pub fn new(
57        db: &DbPool,
58        ctx: GatewayRequestContext,
59    ) -> Result<Self, systemprompt_ai::error::RepositoryError> {
60        let requests = Arc::new(AiRequestRepository::new(db)?);
61        let payloads = Arc::new(AiRequestPayloadRepository::new(db)?);
62        Ok(Self {
63            requests,
64            payloads,
65            ctx,
66            served_model: Mutex::new(None),
67            started_at: Instant::now(),
68        })
69    }
70
71    pub async fn set_served_model(&self, model: &str) {
72        if model.is_empty() || model == self.ctx.model {
73            return;
74        }
75        if let Ok(mut slot) = self.served_model.lock() {
76            *slot = Some(model.to_owned());
77        }
78        if let Err(e) = self
79            .requests
80            .update_model(&self.ctx.ai_request_id, model)
81            .await
82        {
83            tracing::warn!(error = %e, "update_model failed");
84        }
85    }
86
87    fn effective_model(&self) -> String {
88        self.served_model
89            .lock()
90            .map_err(|e| {
91                tracing::warn!(error = %e, "served_model mutex poisoned");
92                e
93            })
94            .ok()
95            .and_then(|s| s.clone())
96            .unwrap_or_else(|| self.ctx.model.clone())
97    }
98
99    fn build_record(&self) -> Result<AiRequestRecord> {
100        let mut record =
101            AiRequestRecord::builder(self.ctx.ai_request_id.clone(), self.ctx.user_id.clone())
102                .provider(self.ctx.provider.clone())
103                .model(self.ctx.model.clone())
104                .streaming(self.ctx.is_streaming);
105        if let Some(s) = &self.ctx.session_id {
106            record = record.session_id(s.clone());
107        }
108        record = record.context_id(self.ctx.context_id.clone());
109        if let Some(g) = &self.ctx.gateway_conversation_id {
110            record = record.gateway_conversation_id(g.clone());
111        }
112        if let Some(t) = &self.ctx.trace_id {
113            record = record.trace_id(t.clone());
114        }
115        if let Some(mt) = self.ctx.max_tokens {
116            record = record.max_tokens(mt);
117        }
118        record.build().map_err(anyhow::Error::from)
119    }
120
121    pub async fn open(&self, request: &CanonicalRequest, request_body: &Bytes) -> Result<()> {
122        let record = self.build_record()?;
123
124        self.requests
125            .insert_with_id(&self.ctx.ai_request_id, &record)
126            .await?;
127
128        let (body_json, excerpt, truncated, bytes) = slice_payload(request_body);
129        if let Err(e) = self
130            .payloads
131            .upsert_request(
132                &self.ctx.ai_request_id,
133                UpsertPayloadParams {
134                    body: body_json.as_ref(),
135                    excerpt: excerpt.as_deref(),
136                    truncated,
137                    bytes: Some(bytes),
138                },
139            )
140            .await
141        {
142            tracing::warn!(error = %e, ai_request_id = %self.ctx.ai_request_id, "payload insert (request) failed");
143        }
144
145        self.persist_request_messages(request).await;
146        Ok(())
147    }
148
149    async fn persist_request_messages(&self, request: &CanonicalRequest) {
150        let mut seq = 0i32;
151        if let Some(system) = &request.system {
152            if !system.is_empty() {
153                if let Err(e) = self
154                    .requests
155                    .insert_message(&self.ctx.ai_request_id, "system", system, seq)
156                    .await
157                {
158                    tracing::warn!(error = %e, "insert system message failed");
159                }
160                seq += 1;
161            }
162        }
163        for msg in &request.messages {
164            let role = match msg.role {
165                Role::System => "system",
166                Role::User => "user",
167                Role::Assistant => "assistant",
168                Role::Tool => "tool",
169            };
170            let text = flatten_message_content(&msg.content);
171            if let Err(e) = self
172                .requests
173                .insert_message(&self.ctx.ai_request_id, role, &text, seq)
174                .await
175            {
176                tracing::warn!(error = %e, seq, "insert message failed");
177            }
178            seq += 1;
179        }
180    }
181
182    pub async fn complete(
183        &self,
184        usage: CapturedUsage,
185        tool_calls: Vec<CapturedToolUse>,
186        response: &CanonicalResponse,
187        response_body: &Bytes,
188    ) -> Result<()> {
189        let latency_ms = self.started_at.elapsed().as_millis().min(i32::MAX as u128) as i32;
190        let effective_model = self.effective_model();
191        let profile = systemprompt_config::ProfileBootstrap::get().ok();
192        let gateway = profile.as_ref().and_then(|p| p.gateway.as_ref());
193        let pricing_rates = pricing::resolve(&self.ctx.provider, &effective_model, gateway);
194        let cost =
195            pricing::cost_microdollars(pricing_rates, usage.input_tokens, usage.output_tokens);
196
197        self.requests
198            .update_completion(UpdateCompletionParams {
199                id: self.ctx.ai_request_id.clone(),
200                tokens_used: (usage.input_tokens + usage.output_tokens) as i32,
201                input_tokens: usage.input_tokens as i32,
202                output_tokens: usage.output_tokens as i32,
203                cost_microdollars: cost,
204                latency_ms,
205            })
206            .await?;
207
208        self.persist_tool_calls(&tool_calls).await;
209
210        let (body_json, excerpt, truncated, bytes) = slice_payload(response_body);
211        if let Err(e) = self
212            .payloads
213            .upsert_response(
214                &self.ctx.ai_request_id,
215                UpsertPayloadParams {
216                    body: body_json.as_ref(),
217                    excerpt: excerpt.as_deref(),
218                    truncated,
219                    bytes: Some(bytes),
220                },
221            )
222            .await
223        {
224            tracing::warn!(error = %e, ai_request_id = %self.ctx.ai_request_id, "payload insert (response) failed");
225        }
226
227        if let Some(assistant_text) = super::parse::extract_assistant_text(response) {
228            if let Err(e) = self
229                .requests
230                .add_response_message(&self.ctx.ai_request_id, &assistant_text)
231                .await
232            {
233                tracing::warn!(error = %e, "assistant response message insert failed");
234            }
235        }
236
237        tracing::info!(
238            ai_request_id = %self.ctx.ai_request_id,
239            user_id = %self.ctx.user_id,
240            provider = %self.ctx.provider,
241            model = %effective_model,
242            wire_protocol = %self.ctx.wire_protocol,
243            input_tokens = usage.input_tokens,
244            output_tokens = usage.output_tokens,
245            cost_microdollars = cost,
246            latency_ms,
247            tool_calls = tool_calls.len(),
248            "Gateway audit: request completed"
249        );
250        Ok(())
251    }
252
253    async fn persist_tool_calls(&self, tool_calls: &[CapturedToolUse]) {
254        for (idx, tool) in tool_calls.iter().enumerate() {
255            let seq = idx as i32 + 1;
256            let trimmed = truncate_for_tool_input(&tool.tool_input);
257            if let Err(e) = self
258                .requests
259                .insert_tool_call(InsertToolCallParams {
260                    request_id: &self.ctx.ai_request_id,
261                    ai_tool_call_id: &tool.ai_tool_call_id,
262                    tool_name: &tool.tool_name,
263                    tool_input: &trimmed,
264                    sequence_number: seq,
265                })
266                .await
267            {
268                tracing::warn!(error = %e, seq, "tool_call insert failed");
269            }
270        }
271    }
272
273    pub async fn fail(&self, error: &str) -> Result<()> {
274        if let Err(e) = self
275            .requests
276            .update_error(&self.ctx.ai_request_id, error)
277            .await
278        {
279            tracing::warn!(error = %e, "audit fail update failed");
280        }
281        tracing::info!(
282            ai_request_id = %self.ctx.ai_request_id,
283            user_id = %self.ctx.user_id,
284            provider = %self.ctx.provider,
285            model = %self.ctx.model,
286            error,
287            "Gateway audit: request failed"
288        );
289        Ok(())
290    }
291}