Skip to main content

sgr_agent/
streaming.rs

//! Streaming abstraction — channel-based streaming for agent output.
//!
//! Provides [`StreamingSender`] and [`StreamingReceiver`] for streaming text
//! chunks, tool start/result events, step completions, errors and an explicit
//! end-of-stream marker from agent execution to a consumer.
5
6use tokio::sync::mpsc;
7
/// A single event emitted on a streaming channel.
///
/// Chunks are produced by [`StreamingSender`] and read through
/// [`StreamingReceiver`]. `PartialEq`/`Eq` are derived so consumers and tests
/// can compare chunks directly, e.g. `assert_eq!(chunk, StreamChunk::Done)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamChunk {
    /// Text content chunk.
    Text(String),
    /// Tool call started; `id` is the caller-supplied call identifier.
    ToolStart { name: String, id: String },
    /// Tool result received for the tool called `name`.
    ToolResult { name: String, output: String },
    /// Agent step number `step` completed.
    StepDone { step: usize },
    /// Stream finished; `StreamingReceiver::collect_all` stops after this chunk.
    Done,
    /// Error occurred; carries the error message as text.
    Error(String),
}
24
25/// Sender side — used by agent loop to emit chunks.
26#[derive(Clone)]
27pub struct StreamingSender {
28    tx: mpsc::UnboundedSender<StreamChunk>,
29}
30
31impl StreamingSender {
32    /// Send a text chunk.
33    pub fn add_text(&self, text: impl Into<String>) {
34        let _ = self.tx.send(StreamChunk::Text(text.into()));
35    }
36
37    /// Signal tool execution started.
38    pub fn add_tool_start(&self, name: impl Into<String>, id: impl Into<String>) {
39        let _ = self.tx.send(StreamChunk::ToolStart {
40            name: name.into(),
41            id: id.into(),
42        });
43    }
44
45    /// Send tool result.
46    pub fn add_tool_result(&self, name: impl Into<String>, output: impl Into<String>) {
47        let _ = self.tx.send(StreamChunk::ToolResult {
48            name: name.into(),
49            output: output.into(),
50        });
51    }
52
53    /// Signal step completion.
54    pub fn add_step_done(&self, step: usize) {
55        let _ = self.tx.send(StreamChunk::StepDone { step });
56    }
57
58    /// Signal stream is complete.
59    pub fn finish(&self) {
60        let _ = self.tx.send(StreamChunk::Done);
61    }
62
63    /// Signal error.
64    pub fn add_error(&self, err: impl Into<String>) {
65        let _ = self.tx.send(StreamChunk::Error(err.into()));
66    }
67}
68
69/// Receiver side — used by UI/consumer to read chunks.
70pub struct StreamingReceiver {
71    rx: mpsc::UnboundedReceiver<StreamChunk>,
72}
73
74impl StreamingReceiver {
75    /// Receive next chunk. Returns None when sender is dropped.
76    pub async fn next(&mut self) -> Option<StreamChunk> {
77        self.rx.recv().await
78    }
79
80    /// Collect all chunks until Done or sender drops.
81    pub async fn collect_all(&mut self) -> Vec<StreamChunk> {
82        let mut chunks = Vec::new();
83        while let Some(chunk) = self.rx.recv().await {
84            let is_done = matches!(chunk, StreamChunk::Done);
85            chunks.push(chunk);
86            if is_done {
87                break;
88            }
89        }
90        chunks
91    }
92}
93
94/// Create a streaming channel pair.
95pub fn streaming_channel() -> (StreamingSender, StreamingReceiver) {
96    let (tx, rx) = mpsc::unbounded_channel();
97    (StreamingSender { tx }, StreamingReceiver { rx })
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Text chunks arrive in order, followed by `Done`.
    #[tokio::test]
    async fn channel_sends_and_receives() {
        let (tx, mut rx) = streaming_channel();
        tx.add_text("hello");
        tx.add_text("world");
        tx.finish();

        let chunks = rx.collect_all().await;
        assert_eq!(chunks.len(), 3);
        assert!(matches!(&chunks[0], StreamChunk::Text(s) if s == "hello"));
        assert!(matches!(&chunks[1], StreamChunk::Text(s) if s == "world"));
        assert!(matches!(&chunks[2], StreamChunk::Done));
    }

    /// Tool and step events keep all of their fields, and the stream ends with `Done`.
    #[tokio::test]
    async fn tool_events() {
        let (tx, mut rx) = streaming_channel();
        tx.add_tool_start("bash", "call_0");
        tx.add_tool_result("bash", "output here");
        tx.add_step_done(1);
        tx.finish();

        let chunks = rx.collect_all().await;
        assert_eq!(chunks.len(), 4);
        assert!(matches!(
            &chunks[0],
            StreamChunk::ToolStart { name, id } if name == "bash" && id == "call_0"
        ));
        assert!(matches!(
            &chunks[1],
            StreamChunk::ToolResult { name, output } if name == "bash" && output == "output here"
        ));
        assert!(matches!(&chunks[2], StreamChunk::StepDone { step: 1 }));
        assert!(matches!(&chunks[3], StreamChunk::Done));
    }

    /// `next` drains buffered chunks, then yields `None` once the sender is gone.
    #[tokio::test]
    async fn next_returns_none_on_drop() {
        let (tx, mut rx) = streaming_channel();
        tx.add_text("one");
        drop(tx);

        assert!(matches!(rx.next().await, Some(StreamChunk::Text(s)) if s == "one"));
        assert!(rx.next().await.is_none());
    }

    /// Error chunks carry their message and do not end the stream by themselves.
    #[tokio::test]
    async fn error_chunk() {
        let (tx, mut rx) = streaming_channel();
        tx.add_error("something failed");
        tx.finish();

        let chunks = rx.collect_all().await;
        assert_eq!(chunks.len(), 2);
        assert!(matches!(&chunks[0], StreamChunk::Error(s) if s == "something failed"));
        assert!(matches!(&chunks[1], StreamChunk::Done));
    }

    /// Clones share one channel, and sequential sends keep their order.
    #[tokio::test]
    async fn sender_is_clone() {
        let (tx, mut rx) = streaming_channel();
        let tx2 = tx.clone();
        tx.add_text("from tx1");
        tx2.add_text("from tx2");
        tx.finish();

        let chunks = rx.collect_all().await;
        assert_eq!(chunks.len(), 3); // 2 texts + Done
        assert!(matches!(&chunks[0], StreamChunk::Text(s) if s == "from tx1"));
        assert!(matches!(&chunks[1], StreamChunk::Text(s) if s == "from tx2"));
        assert!(matches!(&chunks[2], StreamChunk::Done));
    }

    /// `collect_all` also returns when every sender drops without sending `Done`.
    #[tokio::test]
    async fn collect_all_stops_when_senders_drop() {
        let (tx, mut rx) = streaming_channel();
        tx.add_text("partial");
        drop(tx);

        let chunks = rx.collect_all().await;
        assert_eq!(chunks.len(), 1);
        assert!(matches!(&chunks[0], StreamChunk::Text(s) if s == "partial"));
    }

    /// Chunks sent after `Done` are not consumed by `collect_all`.
    #[tokio::test]
    async fn collect_all_leaves_chunks_after_done_buffered() {
        let (tx, mut rx) = streaming_channel();
        tx.add_text("a");
        tx.finish();
        tx.add_text("after");

        let chunks = rx.collect_all().await;
        assert_eq!(chunks.len(), 2);
        assert!(matches!(rx.next().await, Some(StreamChunk::Text(s)) if s == "after"));
    }

    /// Sending after the receiver is dropped is silently ignored, not a panic.
    #[tokio::test]
    async fn send_after_receiver_dropped_is_silent() {
        let (tx, rx) = streaming_channel();
        drop(rx);

        tx.add_text("ignored");
        tx.add_tool_start("bash", "call_0");
        tx.add_tool_result("bash", "ignored");
        tx.add_step_done(1);
        tx.add_error("ignored");
        tx.finish();
    }
}