1use tokio::sync::mpsc;
7
/// One event carried over the streaming channel.
///
/// Values are produced by the `StreamingSender` helper methods and consumed
/// through `StreamingReceiver`. The type is plain data (only `String` and
/// `usize` payloads), so it derives `PartialEq`/`Eq` to allow direct
/// comparison with `==` and `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamChunk {
    /// A piece of text (sent by `StreamingSender::add_text`).
    Text(String),
    /// A tool invocation has started; carries the tool `name` and a caller-supplied `id`.
    ToolStart { name: String, id: String },
    /// A tool produced a result; carries the tool `name` and its `output`.
    ToolResult { name: String, output: String },
    /// The step with the given number has completed.
    StepDone { step: usize },
    /// End-of-stream marker; `StreamingReceiver::collect_all` stops after receiving it.
    Done,
    /// An error description (sent by `StreamingSender::add_error`).
    Error(String),
}
24
25#[derive(Clone)]
27pub struct StreamingSender {
28 tx: mpsc::UnboundedSender<StreamChunk>,
29}
30
31impl StreamingSender {
32 pub fn add_text(&self, text: impl Into<String>) {
34 let _ = self.tx.send(StreamChunk::Text(text.into()));
35 }
36
37 pub fn add_tool_start(&self, name: impl Into<String>, id: impl Into<String>) {
39 let _ = self.tx.send(StreamChunk::ToolStart {
40 name: name.into(),
41 id: id.into(),
42 });
43 }
44
45 pub fn add_tool_result(&self, name: impl Into<String>, output: impl Into<String>) {
47 let _ = self.tx.send(StreamChunk::ToolResult {
48 name: name.into(),
49 output: output.into(),
50 });
51 }
52
53 pub fn add_step_done(&self, step: usize) {
55 let _ = self.tx.send(StreamChunk::StepDone { step });
56 }
57
58 pub fn finish(&self) {
60 let _ = self.tx.send(StreamChunk::Done);
61 }
62
63 pub fn add_error(&self, err: impl Into<String>) {
65 let _ = self.tx.send(StreamChunk::Error(err.into()));
66 }
67}
68
69pub struct StreamingReceiver {
71 rx: mpsc::UnboundedReceiver<StreamChunk>,
72}
73
74impl StreamingReceiver {
75 pub async fn next(&mut self) -> Option<StreamChunk> {
77 self.rx.recv().await
78 }
79
80 pub async fn collect_all(&mut self) -> Vec<StreamChunk> {
82 let mut chunks = Vec::new();
83 while let Some(chunk) = self.rx.recv().await {
84 let is_done = matches!(chunk, StreamChunk::Done);
85 chunks.push(chunk);
86 if is_done {
87 break;
88 }
89 }
90 chunks
91 }
92}
93
94pub fn streaming_channel() -> (StreamingSender, StreamingReceiver) {
96 let (tx, rx) = mpsc::unbounded_channel();
97 (StreamingSender { tx }, StreamingReceiver { rx })
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Text chunks arrive in send order, followed by the `Done` marker.
    #[tokio::test]
    async fn channel_sends_and_receives() {
        let (tx, mut rx) = streaming_channel();
        tx.add_text("hello");
        tx.add_text("world");
        tx.finish();

        let chunks = rx.collect_all().await;
        assert_eq!(chunks.len(), 3);
        assert!(matches!(&chunks[0], StreamChunk::Text(s) if s == "hello"));
        assert!(matches!(&chunks[1], StreamChunk::Text(s) if s == "world"));
        assert!(matches!(&chunks[2], StreamChunk::Done));
    }

    /// Tool and step events round-trip with their payloads intact.
    #[tokio::test]
    async fn tool_events() {
        let (tx, mut rx) = streaming_channel();
        tx.add_tool_start("bash", "call_0");
        tx.add_tool_result("bash", "output here");
        tx.add_step_done(1);
        tx.finish();

        let chunks = rx.collect_all().await;
        assert_eq!(chunks.len(), 4);
        assert!(matches!(
            &chunks[0],
            StreamChunk::ToolStart { name, id } if name == "bash" && id == "call_0"
        ));
        assert!(matches!(
            &chunks[1],
            StreamChunk::ToolResult { name, output } if name == "bash" && output == "output here"
        ));
        assert!(matches!(&chunks[2], StreamChunk::StepDone { step: 1 }));
        assert!(matches!(&chunks[3], StreamChunk::Done));
    }

    /// After the only sender is dropped, queued chunks drain and then `next` yields `None`.
    #[tokio::test]
    async fn next_returns_none_on_drop() {
        let (tx, mut rx) = streaming_channel();
        tx.add_text("one");
        drop(tx);

        assert!(matches!(rx.next().await, Some(StreamChunk::Text(_))));
        assert!(rx.next().await.is_none());
    }

    /// An error chunk is delivered like any other chunk and does not end collection by itself.
    #[tokio::test]
    async fn error_chunk() {
        let (tx, mut rx) = streaming_channel();
        tx.add_error("something failed");
        tx.finish();

        let chunks = rx.collect_all().await;
        assert_eq!(chunks.len(), 2);
        assert!(matches!(&chunks[0], StreamChunk::Error(s) if s == "something failed"));
        assert!(matches!(&chunks[1], StreamChunk::Done));
    }

    /// Cloned senders feed the same receiver, preserving the order of sends.
    #[tokio::test]
    async fn sender_is_clone() {
        let (tx, mut rx) = streaming_channel();
        let tx2 = tx.clone();
        tx.add_text("from tx1");
        tx2.add_text("from tx2");
        tx.finish();

        let chunks = rx.collect_all().await;
        assert_eq!(chunks.len(), 3);
        assert!(matches!(&chunks[0], StreamChunk::Text(s) if s == "from tx1"));
        assert!(matches!(&chunks[1], StreamChunk::Text(s) if s == "from tx2"));
        assert!(matches!(&chunks[2], StreamChunk::Done));
    }

    /// `collect_all` also terminates when the channel closes without a `Done` marker.
    #[tokio::test]
    async fn collect_all_stops_when_sender_dropped() {
        let (tx, mut rx) = streaming_channel();
        tx.add_text("partial");
        drop(tx);

        let chunks = rx.collect_all().await;
        assert_eq!(chunks.len(), 1);
        assert!(matches!(&chunks[0], StreamChunk::Text(s) if s == "partial"));
    }
}