// systemprompt_api/services/gateway/audit.rs
1#[path = "audit_internal/payload.rs"]
2mod payload;
3
4use payload::{slice_payload, truncate_for_tool_input};
5
6use std::sync::Arc;
7use std::time::Instant;
8
9use anyhow::Result;
10use bytes::Bytes;
11use systemprompt_ai::models::ai_request_record::AiRequestRecord;
12use systemprompt_ai::repository::ai_requests::UpdateCompletionParams;
13use systemprompt_ai::repository::{
14    AiRequestPayloadRepository, AiRequestRepository, InsertToolCallParams, UpsertPayloadParams,
15};
16use systemprompt_database::DbPool;
17use systemprompt_identifiers::{
18    AiRequestId, ContextId, GatewayConversationId, SessionId, TenantId, TraceId, UserId,
19};
20
21use super::captures::{CapturedToolUse, CapturedUsage};
22use super::pricing;
23use super::protocol::canonical::{CanonicalContent, CanonicalRequest, Role};
24use super::protocol::canonical_response::CanonicalResponse;
25use std::sync::Mutex;
26
/// Per-request identity, routing, and model metadata captured when a
/// gateway request is admitted, then reused for every audit write.
#[derive(Debug, Clone)]
pub struct GatewayRequestContext {
    /// Primary key for all audit rows belonging to this request.
    pub ai_request_id: AiRequestId,
    /// User on whose behalf the request is made.
    pub user_id: UserId,
    /// Owning tenant, when multi-tenancy applies.
    pub tenant_id: Option<TenantId>,
    /// Originating session, if one exists.
    pub session_id: Option<SessionId>,
    /// Execution context the request belongs to.
    pub context_id: ContextId,
    /// Gateway-level conversation grouping, if any.
    pub gateway_conversation_id: Option<GatewayConversationId>,
    /// Distributed-trace correlation id, if one was propagated.
    pub trace_id: Option<TraceId>,
    /// Upstream provider name; also keys the pricing lookup on completion.
    pub provider: String,
    /// Model the client requested; the served model may differ (see
    /// `GatewayAudit::set_served_model`).
    pub model: String,
    /// Client-requested output token cap, if supplied.
    pub max_tokens: Option<u32>,
    /// Whether the response is streamed.
    pub is_streaming: bool,
    /// Wire-protocol label, recorded in the completion log line.
    pub wire_protocol: String,
}
42
/// Persists the audit trail for a single gateway request: the request
/// record, raw request/response payloads, per-message rows, tool calls,
/// and completion metrics (tokens, cost, latency).
#[allow(missing_debug_implementations)]
pub struct GatewayAudit {
    // Repository for ai_request rows (record, messages, tool calls).
    requests: Arc<AiRequestRepository>,
    // Repository for raw request/response payload storage.
    payloads: Arc<AiRequestPayloadRepository>,
    /// Immutable per-request context captured at construction.
    pub ctx: GatewayRequestContext,
    // Model the provider actually served, when it differs from the
    // requested `ctx.model`; written by `set_served_model`.
    served_model: Mutex<Option<String>>,
    // Wall-clock start used to compute completion latency.
    started_at: Instant,
}
51
52impl GatewayAudit {
53    pub fn new(
54        db: &DbPool,
55        ctx: GatewayRequestContext,
56    ) -> Result<Self, systemprompt_ai::error::RepositoryError> {
57        let requests = Arc::new(AiRequestRepository::new(db)?);
58        let payloads = Arc::new(AiRequestPayloadRepository::new(db)?);
59        Ok(Self {
60            requests,
61            payloads,
62            ctx,
63            served_model: Mutex::new(None),
64            started_at: Instant::now(),
65        })
66    }
67
68    pub async fn set_served_model(&self, model: &str) {
69        if model.is_empty() || model == self.ctx.model {
70            return;
71        }
72        if let Ok(mut slot) = self.served_model.lock() {
73            *slot = Some(model.to_string());
74        }
75        if let Err(e) = self
76            .requests
77            .update_model(&self.ctx.ai_request_id, model)
78            .await
79        {
80            tracing::warn!(error = %e, "update_model failed");
81        }
82    }
83
84    fn effective_model(&self) -> String {
85        self.served_model
86            .lock()
87            .map_err(|e| {
88                tracing::warn!(error = %e, "served_model mutex poisoned");
89                e
90            })
91            .ok()
92            .and_then(|s| s.clone())
93            .unwrap_or_else(|| self.ctx.model.clone())
94    }
95
96    fn build_record(&self) -> Result<AiRequestRecord> {
97        let mut record =
98            AiRequestRecord::builder(self.ctx.ai_request_id.clone(), self.ctx.user_id.clone())
99                .provider(self.ctx.provider.clone())
100                .model(self.ctx.model.clone())
101                .streaming(self.ctx.is_streaming);
102        if let Some(t) = &self.ctx.tenant_id {
103            record = record.tenant_id(t.clone());
104        }
105        if let Some(s) = &self.ctx.session_id {
106            record = record.session_id(s.clone());
107        }
108        record = record.context_id(self.ctx.context_id.clone());
109        if let Some(g) = &self.ctx.gateway_conversation_id {
110            record = record.gateway_conversation_id(g.clone());
111        }
112        if let Some(t) = &self.ctx.trace_id {
113            record = record.trace_id(t.clone());
114        }
115        if let Some(mt) = self.ctx.max_tokens {
116            record = record.max_tokens(mt);
117        }
118        record.build().map_err(anyhow::Error::from)
119    }
120
121    pub async fn open(&self, request: &CanonicalRequest, request_body: &Bytes) -> Result<()> {
122        let record = self.build_record()?;
123
124        self.requests
125            .insert_with_id(&self.ctx.ai_request_id, &record)
126            .await?;
127
128        let (body_json, excerpt, truncated, bytes) = slice_payload(request_body);
129        if let Err(e) = self
130            .payloads
131            .upsert_request(
132                &self.ctx.ai_request_id,
133                UpsertPayloadParams {
134                    body: body_json.as_ref(),
135                    excerpt: excerpt.as_deref(),
136                    truncated,
137                    bytes: Some(bytes),
138                },
139            )
140            .await
141        {
142            tracing::warn!(error = %e, ai_request_id = %self.ctx.ai_request_id, "payload insert (request) failed");
143        }
144
145        self.persist_request_messages(request).await;
146        Ok(())
147    }
148
149    async fn persist_request_messages(&self, request: &CanonicalRequest) {
150        let mut seq = 0i32;
151        if let Some(system) = &request.system {
152            if !system.is_empty() {
153                if let Err(e) = self
154                    .requests
155                    .insert_message(&self.ctx.ai_request_id, "system", system, seq)
156                    .await
157                {
158                    tracing::warn!(error = %e, "insert system message failed");
159                }
160                seq += 1;
161            }
162        }
163        for msg in &request.messages {
164            let role = match msg.role {
165                Role::System => "system",
166                Role::User => "user",
167                Role::Assistant => "assistant",
168                Role::Tool => "tool",
169            };
170            let text = flatten_message_content(&msg.content);
171            if let Err(e) = self
172                .requests
173                .insert_message(&self.ctx.ai_request_id, role, &text, seq)
174                .await
175            {
176                tracing::warn!(error = %e, seq, "insert message failed");
177            }
178            seq += 1;
179        }
180    }
181
182    pub async fn complete(
183        &self,
184        usage: CapturedUsage,
185        tool_calls: Vec<CapturedToolUse>,
186        response: &CanonicalResponse,
187        response_body: &Bytes,
188    ) -> Result<()> {
189        let latency_ms = self.started_at.elapsed().as_millis().min(i32::MAX as u128) as i32;
190        let effective_model = self.effective_model();
191        let profile = systemprompt_config::ProfileBootstrap::get().ok();
192        let gateway = profile.as_ref().and_then(|p| p.gateway.as_ref());
193        let pricing_rates = pricing::resolve(&self.ctx.provider, &effective_model, gateway);
194        let cost =
195            pricing::cost_microdollars(pricing_rates, usage.input_tokens, usage.output_tokens);
196
197        self.requests
198            .update_completion(UpdateCompletionParams {
199                id: self.ctx.ai_request_id.clone(),
200                tokens_used: (usage.input_tokens + usage.output_tokens) as i32,
201                input_tokens: usage.input_tokens as i32,
202                output_tokens: usage.output_tokens as i32,
203                cost_microdollars: cost,
204                latency_ms,
205            })
206            .await?;
207
208        self.persist_tool_calls(&tool_calls).await;
209
210        let (body_json, excerpt, truncated, bytes) = slice_payload(response_body);
211        if let Err(e) = self
212            .payloads
213            .upsert_response(
214                &self.ctx.ai_request_id,
215                UpsertPayloadParams {
216                    body: body_json.as_ref(),
217                    excerpt: excerpt.as_deref(),
218                    truncated,
219                    bytes: Some(bytes),
220                },
221            )
222            .await
223        {
224            tracing::warn!(error = %e, ai_request_id = %self.ctx.ai_request_id, "payload insert (response) failed");
225        }
226
227        if let Some(assistant_text) = super::parse::extract_assistant_text(response) {
228            if let Err(e) = self
229                .requests
230                .add_response_message(&self.ctx.ai_request_id, &assistant_text)
231                .await
232            {
233                tracing::warn!(error = %e, "assistant response message insert failed");
234            }
235        }
236
237        tracing::info!(
238            ai_request_id = %self.ctx.ai_request_id,
239            user_id = %self.ctx.user_id,
240            provider = %self.ctx.provider,
241            model = %effective_model,
242            wire_protocol = %self.ctx.wire_protocol,
243            input_tokens = usage.input_tokens,
244            output_tokens = usage.output_tokens,
245            cost_microdollars = cost,
246            latency_ms,
247            tool_calls = tool_calls.len(),
248            "Gateway audit: request completed"
249        );
250        Ok(())
251    }
252
253    async fn persist_tool_calls(&self, tool_calls: &[CapturedToolUse]) {
254        for (idx, tool) in tool_calls.iter().enumerate() {
255            let seq = idx as i32 + 1;
256            let trimmed = truncate_for_tool_input(&tool.tool_input);
257            if let Err(e) = self
258                .requests
259                .insert_tool_call(InsertToolCallParams {
260                    request_id: &self.ctx.ai_request_id,
261                    ai_tool_call_id: &tool.ai_tool_call_id,
262                    tool_name: &tool.tool_name,
263                    tool_input: &trimmed,
264                    sequence_number: seq,
265                })
266                .await
267            {
268                tracing::warn!(error = %e, seq, "tool_call insert failed");
269            }
270        }
271    }
272
273    pub async fn fail(&self, error: &str) -> Result<()> {
274        if let Err(e) = self
275            .requests
276            .update_error(&self.ctx.ai_request_id, error)
277            .await
278        {
279            tracing::warn!(error = %e, "audit fail update failed");
280        }
281        tracing::info!(
282            ai_request_id = %self.ctx.ai_request_id,
283            user_id = %self.ctx.user_id,
284            provider = %self.ctx.provider,
285            model = %self.ctx.model,
286            error,
287            "Gateway audit: request failed"
288        );
289        Ok(())
290    }
291}
292
293fn flatten_message_content(parts: &[CanonicalContent]) -> String {
294    let mut out = String::new();
295    for part in parts {
296        match part {
297            CanonicalContent::Text(t) => push_with_sep(&mut out, t),
298            CanonicalContent::Thinking { text, .. } => push_with_sep(&mut out, text),
299            CanonicalContent::ToolUse { name, input, .. } => {
300                push_with_sep(&mut out, &format!("[tool_use:{name} {input}]"));
301            },
302            CanonicalContent::ToolResult { content, .. } => {
303                for inner in content {
304                    if let CanonicalContent::Text(t) = inner {
305                        push_with_sep(&mut out, t);
306                    }
307                }
308            },
309            CanonicalContent::Image(_) => {},
310        }
311    }
312    out
313}
314
/// Appends `fragment` to `out`, inserting a `'\n'` separator when `out`
/// already holds text. Empty fragments are ignored entirely, so no
/// separator is ever emitted for them.
fn push_with_sep(out: &mut String, fragment: &str) {
    match (fragment.is_empty(), out.is_empty()) {
        // Nothing to add; leave `out` untouched.
        (true, _) => {},
        // First fragment: no separator needed.
        (false, true) => out.push_str(fragment),
        // Subsequent fragment: separate from existing text.
        (false, false) => {
            out.push('\n');
            out.push_str(fragment);
        },
    }
}