systemprompt_api/services/gateway/stream_tap/
mod.rs1mod accumulator;
5
6use std::pin::Pin;
7use std::sync::{Arc, Mutex};
8use std::task::{Context, Poll};
9
10use axum::body::Body;
11use bytes::Bytes;
12use futures_util::stream::{BoxStream, Stream};
13
14use self::accumulator::{Summary, TapState, accumulate_event, extract_summary};
15use super::audit::GatewayAudit;
16use super::protocol::canonical_response::CanonicalEvent;
17use super::protocol::inbound::InboundAdapter;
18
19pub fn tap(
20 upstream: BoxStream<'static, Result<CanonicalEvent, String>>,
21 inbound: Arc<dyn InboundAdapter>,
22 request_model: String,
23 audit: Arc<GatewayAudit>,
24) -> Body {
25 let state = Arc::new(Mutex::new(TapState::default()));
26 let tapped = TappedStream {
27 inner: upstream,
28 state: Arc::clone(&state),
29 inbound,
30 request_model,
31 audit,
32 };
33 Body::from_stream(tapped)
34}
35
36struct TappedStream {
37 inner: BoxStream<'static, Result<CanonicalEvent, String>>,
38 state: Arc<Mutex<TapState>>,
39 inbound: Arc<dyn InboundAdapter>,
40 request_model: String,
41 audit: Arc<GatewayAudit>,
42}
43
44impl Stream for TappedStream {
45 type Item = Result<Bytes, std::io::Error>;
46
47 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
48 loop {
49 match self.inner.as_mut().poll_next(cx) {
50 Poll::Pending => return Poll::Pending,
51 Poll::Ready(None) => {
52 return self.finalize_on_eof();
53 },
54 Poll::Ready(Some(Err(e))) => {
55 if let Ok(mut s) = self.state.lock() {
56 s.error = Some(e.clone());
57 }
58 let err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, e);
59 return Poll::Ready(Some(Err(err)));
60 },
61 Poll::Ready(Some(Ok(event))) => {
62 if let Ok(mut s) = self.state.lock() {
63 accumulate_event(&mut s, &event);
64 }
65 let rendered = self.inbound.render_event(&event, &self.request_model);
66 if let Some(bytes) = rendered {
67 if let Ok(mut s) = self.state.lock() {
68 s.final_bytes.extend_from_slice(&bytes);
69 }
70 return Poll::Ready(Some(Ok(bytes)));
71 }
72 },
73 }
74 }
75 }
76}
77
78impl TappedStream {
79 fn take_summary(&self) -> Option<Summary> {
80 self.state.lock().ok().and_then(|mut s| {
81 if s.finalized {
82 return None;
83 }
84 s.finalized = true;
85 Some(extract_summary(&mut s))
86 })
87 }
88
89 fn finalize_on_eof(&self) -> Poll<Option<Result<Bytes, std::io::Error>>> {
90 let Some(summary) = self.take_summary() else {
91 return Poll::Ready(None);
92 };
93 finalize(Arc::clone(&self.audit), summary, "eof");
94 Poll::Ready(None)
95 }
96}
97
98impl Drop for TappedStream {
99 fn drop(&mut self) {
100 let Some(summary) = self.take_summary() else {
101 return;
102 };
103 finalize(Arc::clone(&self.audit), summary, "drop");
104 }
105}
106
/// Outcome chosen for a stream when it is finalized.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FinalizeDecision {
    /// The stream is treated as failed; carries a fixed reason string.
    Fail(&'static str),
    /// The stream is treated as complete; `cost_capture_miss` is set when
    /// content was produced but no usage was reported.
    Complete { cost_capture_miss: bool },
}

/// Decides how a finished stream should be recorded.
///
/// Any upstream error fails the stream. Without an error, a missing stop
/// event also fails it, with a reason that depends on whether any content was
/// produced. Otherwise the stream is complete.
pub const fn classify(
    error: Option<&str>,
    saw_stop: bool,
    has_content: bool,
    has_usage: bool,
) -> FinalizeDecision {
    match (error, saw_stop, has_content) {
        (Some(_), _, _) => FinalizeDecision::Fail("upstream stream error"),
        (None, false, true) => FinalizeDecision::Fail("stream ended without stop event"),
        (None, false, false) => FinalizeDecision::Fail("empty upstream stream"),
        (None, true, _) => FinalizeDecision::Complete {
            cost_capture_miss: has_content && !has_usage,
        },
    }
}
133
134fn finalize(audit: Arc<GatewayAudit>, summary: Summary, origin: &'static str) {
135 tokio::spawn(async move {
136 if let Some(model) = summary.served_model.as_deref() {
137 audit.set_served_model(model).await;
138 }
139
140 let has_content = !summary.final_bytes.is_empty();
141 let has_usage = summary.usage.input_tokens > 0 || summary.usage.output_tokens > 0;
142 match classify(
143 summary.error.as_deref(),
144 summary.saw_stop,
145 has_content,
146 has_usage,
147 ) {
148 FinalizeDecision::Fail(reason) => {
149 let msg = summary.error.as_deref().unwrap_or(reason);
150 if let Err(e) = audit.fail(msg).await {
151 tracing::warn!(origin, error = %e, "stream audit fail failed");
152 }
153 },
154 FinalizeDecision::Complete { cost_capture_miss } => {
155 if cost_capture_miss {
156 tracing::warn!(
157 origin,
158 "stream completed with content but zero usage: cost capture miss"
159 );
160 }
161 if let Err(e) = audit
162 .complete(
163 summary.usage,
164 summary.tool_calls,
165 &summary.response,
166 &summary.final_bytes,
167 )
168 .await
169 {
170 tracing::warn!(origin, error = %e, "stream audit complete failed");
171 }
172 },
173 }
174 });
175}