Skip to main content

systemprompt_api/services/gateway/stream_tap/
mod.rs

1//! Streaming response tap: re-renders upstream canonical events to the inbound
2//! wire format while accumulating a full response snapshot for the audit sink.
3
4mod accumulator;
5
6use std::pin::Pin;
7use std::sync::{Arc, Mutex};
8use std::task::{Context, Poll};
9
10use axum::body::Body;
11use bytes::Bytes;
12use futures_util::stream::{BoxStream, Stream};
13
14use self::accumulator::{Summary, TapState, accumulate_event, extract_summary};
15use super::audit::GatewayAudit;
16use super::protocol::canonical_response::CanonicalEvent;
17use super::protocol::inbound::InboundAdapter;
18
19pub fn tap(
20    upstream: BoxStream<'static, Result<CanonicalEvent, String>>,
21    inbound: Arc<dyn InboundAdapter>,
22    request_model: String,
23    audit: Arc<GatewayAudit>,
24) -> Body {
25    let state = Arc::new(Mutex::new(TapState::default()));
26    let tapped = TappedStream {
27        inner: upstream,
28        state: Arc::clone(&state),
29        inbound,
30        request_model,
31        audit,
32    };
33    Body::from_stream(tapped)
34}
35
/// Stream adapter behind `tap`: yields upstream events rendered for the
/// inbound wire format while recording them for the audit sink.
///
/// Fields are dropped in declaration order, after `Drop::drop` has already
/// run (which finalizes the audit record if EOF did not).
struct TappedStream {
    /// Upstream canonical events, or upstream error strings.
    inner: BoxStream<'static, Result<CanonicalEvent, String>>,
    /// Accumulated response snapshot; taken exactly once at finalization.
    state: Arc<Mutex<TapState>>,
    /// Renders canonical events into the inbound wire format.
    inbound: Arc<dyn InboundAdapter>,
    /// Model name passed to `render_event` alongside every event.
    request_model: String,
    /// Audit record finalized at EOF or on drop.
    audit: Arc<GatewayAudit>,
}
43
44impl Stream for TappedStream {
45    type Item = Result<Bytes, std::io::Error>;
46
47    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
48        loop {
49            match self.inner.as_mut().poll_next(cx) {
50                Poll::Pending => return Poll::Pending,
51                Poll::Ready(None) => {
52                    return self.finalize_on_eof();
53                },
54                Poll::Ready(Some(Err(e))) => {
55                    if let Ok(mut s) = self.state.lock() {
56                        s.error = Some(e.clone());
57                    }
58                    let err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, e);
59                    return Poll::Ready(Some(Err(err)));
60                },
61                Poll::Ready(Some(Ok(event))) => {
62                    if let Ok(mut s) = self.state.lock() {
63                        accumulate_event(&mut s, &event);
64                    }
65                    let rendered = self.inbound.render_event(&event, &self.request_model);
66                    if let Some(bytes) = rendered {
67                        if let Ok(mut s) = self.state.lock() {
68                            s.final_bytes.extend_from_slice(&bytes);
69                        }
70                        return Poll::Ready(Some(Ok(bytes)));
71                    }
72                },
73            }
74        }
75    }
76}
77
78impl TappedStream {
79    fn take_summary(&self) -> Option<Summary> {
80        self.state.lock().ok().and_then(|mut s| {
81            if s.finalized {
82                return None;
83            }
84            s.finalized = true;
85            Some(extract_summary(&mut s))
86        })
87    }
88
89    fn finalize_on_eof(&self) -> Poll<Option<Result<Bytes, std::io::Error>>> {
90        let Some(summary) = self.take_summary() else {
91            return Poll::Ready(None);
92        };
93        finalize(Arc::clone(&self.audit), summary, "eof");
94        Poll::Ready(None)
95    }
96}
97
98impl Drop for TappedStream {
99    fn drop(&mut self) {
100        let Some(summary) = self.take_summary() else {
101            return;
102        };
103        finalize(Arc::clone(&self.audit), summary, "drop");
104    }
105}
106
/// How a finished stream should be recorded on the audit sink.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FinalizeDecision {
    /// Record the request as failed, carrying a static reason.
    Fail(&'static str),
    /// Record the request as completed.
    ///
    /// `cost_capture_miss` is `true` when content was produced but no token
    /// usage was observed.
    Complete { cost_capture_miss: bool },
}
112
113pub const fn classify(
114    error: Option<&str>,
115    saw_stop: bool,
116    has_content: bool,
117    has_usage: bool,
118) -> FinalizeDecision {
119    if error.is_some() {
120        return FinalizeDecision::Fail("upstream stream error");
121    }
122    if !saw_stop {
123        return FinalizeDecision::Fail(if has_content {
124            "stream ended without stop event"
125        } else {
126            "empty upstream stream"
127        });
128    }
129    FinalizeDecision::Complete {
130        cost_capture_miss: has_content && !has_usage,
131    }
132}
133
134fn finalize(audit: Arc<GatewayAudit>, summary: Summary, origin: &'static str) {
135    tokio::spawn(async move {
136        if let Some(model) = summary.served_model.as_deref() {
137            audit.set_served_model(model).await;
138        }
139
140        let has_content = !summary.final_bytes.is_empty();
141        let has_usage = summary.usage.input_tokens > 0 || summary.usage.output_tokens > 0;
142        match classify(
143            summary.error.as_deref(),
144            summary.saw_stop,
145            has_content,
146            has_usage,
147        ) {
148            FinalizeDecision::Fail(reason) => {
149                let msg = summary.error.as_deref().unwrap_or(reason);
150                if let Err(e) = audit.fail(msg).await {
151                    tracing::warn!(origin, error = %e, "stream audit fail failed");
152                }
153            },
154            FinalizeDecision::Complete { cost_capture_miss } => {
155                if cost_capture_miss {
156                    tracing::warn!(
157                        origin,
158                        "stream completed with content but zero usage: cost capture miss"
159                    );
160                }
161                if let Err(e) = audit
162                    .complete(
163                        summary.usage,
164                        summary.tool_calls,
165                        &summary.response,
166                        &summary.final_bytes,
167                    )
168                    .await
169                {
170                    tracing::warn!(origin, error = %e, "stream audit complete failed");
171                }
172            },
173        }
174    });
175}