1use crate::{ZoeyError, Result};
4use tokio::sync::mpsc;
5
6#[derive(Debug, Clone)]
8pub struct TextChunk {
9 pub text: String,
11 pub is_final: bool,
13 pub metadata: Option<serde_json::Value>,
15}
16
17pub type TextStream = mpsc::Receiver<Result<TextChunk>>;
19
20pub type TextStreamSender = mpsc::Sender<Result<TextChunk>>;
22
23pub fn create_text_stream(buffer_size: usize) -> (TextStreamSender, TextStream) {
25 mpsc::channel(buffer_size)
26}
27
28pub struct StreamHandler {
30 sender: TextStreamSender,
31}
32
33impl StreamHandler {
34 pub fn new(sender: TextStreamSender) -> Self {
36 Self { sender }
37 }
38
39 pub async fn send_chunk(&self, text: String, is_final: bool) -> Result<()> {
41 self.sender
42 .send(Ok(TextChunk {
43 text,
44 is_final,
45 metadata: None,
46 }))
47 .await
48 .map_err(|e| ZoeyError::other(format!("Failed to send chunk: {}", e)))
49 }
50
51 pub async fn send_chunk_with_meta(&self, text: String, is_final: bool, metadata: Option<serde_json::Value>) -> Result<()> {
53 self.sender
54 .send(Ok(TextChunk {
55 text,
56 is_final,
57 metadata,
58 }))
59 .await
60 .map_err(|e| ZoeyError::other(format!("Failed to send chunk: {}", e)))
61 }
62
63 pub async fn send_error(&self, error: ZoeyError) -> Result<()> {
65 self.sender
66 .send(Err(error))
67 .await
68 .map_err(|e| ZoeyError::other(format!("Failed to send error: {}", e)))
69 }
70
71 pub async fn finish(&self, text: String) -> Result<()> {
73 self.send_chunk(text, true).await
74 }
75}
76
77pub async fn collect_stream(mut stream: TextStream) -> Result<String> {
79 let mut result = String::new();
80
81 while let Some(chunk_result) = stream.recv().await {
82 let chunk = chunk_result?;
83 result.push_str(&chunk.text);
84
85 if chunk.is_final {
86 break;
87 }
88 }
89
90 Ok(result)
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    /// Two chunks pushed straight into the channel come out in order with
    /// their text and `is_final` flags intact.
    #[tokio::test]
    async fn test_stream_creation() {
        let (tx, mut rx) = create_text_stream(10);

        let hello = TextChunk {
            text: "Hello".to_string(),
            is_final: false,
            metadata: None,
        };
        tx.send(Ok(hello)).await.unwrap();

        let world = TextChunk {
            text: " World".to_string(),
            is_final: true,
            metadata: None,
        };
        tx.send(Ok(world)).await.unwrap();

        let first = rx.recv().await.unwrap().unwrap();
        assert_eq!(first.text, "Hello");
        assert!(!first.is_final);

        let second = rx.recv().await.unwrap().unwrap();
        assert_eq!(second.text, " World");
        assert!(second.is_final);
    }

    /// Chunks sent through a `StreamHandler` from a spawned task are joined
    /// by `collect_stream` into one string.
    #[tokio::test]
    async fn test_stream_handler() {
        let (tx, rx) = create_text_stream(10);
        let handler = StreamHandler::new(tx);

        // Producer runs concurrently with the collector below.
        tokio::spawn(async move {
            let first = handler.send_chunk("Chunk 1".to_string(), false).await;
            first.unwrap();

            let second = handler.send_chunk("Chunk 2".to_string(), false).await;
            second.unwrap();

            let last = handler.finish("Final chunk".to_string()).await;
            last.unwrap();
        });

        let collected = collect_stream(rx).await.unwrap();
        assert_eq!(collected, "Chunk 1Chunk 2Final chunk");
    }

    /// An `Err` item sent into the channel is received as an `Err`.
    #[tokio::test]
    async fn test_stream_error() {
        let (tx, mut rx) = create_text_stream(10);

        let failure = Err(ZoeyError::other("Test error"));
        tx.send(failure).await.unwrap();

        let received = rx.recv().await.unwrap();
        assert!(received.is_err());
    }
}