Skip to main content

telex/
stream_state.rs

1//! Stream state management for Telex.
2//!
3//! Provides the `use_stream` hook for handling async streams (e.g., LLM token streaming).
4
5use std::cell::RefCell;
6use std::rc::Rc;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::mpsc::{self, Receiver, Sender};
9use std::sync::Arc;
10use std::thread;
11
/// Lifecycle phase of a stream tracked by a `StreamHandle`.
///
/// `PartialEq` and `Eq` are derived in addition to `Clone` and `Debug` so
/// that states can be compared directly (for example with `==` or
/// `assert_eq!`) instead of only through `matches!`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum StreamState {
    /// The stream has not been started yet.
    Idle,
    /// The stream has been started and may still deliver items.
    Streaming,
    /// The stream finished without reporting an error.
    Done,
    /// The stream failed; the payload is the error message.
    Error(String),
}
24
25/// Handle for stream state that can be stored and cloned.
26pub struct StreamHandle<T> {
27    inner: Rc<RefCell<StreamInner<T>>>,
28}
29
30struct StreamInner<T> {
31    /// Accumulated values from the stream.
32    accumulated: T,
33    /// Current state of the stream.
34    state: StreamState,
35    /// Whether the stream has been started.
36    started: bool,
37    /// Receiver for stream items.
38    receiver: Option<Receiver<StreamItem<T>>>,
39    /// Wake flag to notify the event loop when tokens arrive.
40    wake_flag: Option<Arc<AtomicBool>>,
41}
42
/// Message passed from the producer thread to the polling side.
enum StreamItem<T> {
    /// One value produced by the stream.
    Value(T),
    /// The stream ran to completion.
    Done,
    /// The stream failed with the given message.
    Error(String),
}
52
53impl<T> Clone for StreamHandle<T>
54where
55    T: Clone,
56{
57    fn clone(&self) -> Self {
58        Self {
59            inner: Rc::clone(&self.inner),
60        }
61    }
62}
63
64impl<T: Clone + Default + 'static> StreamHandle<T> {
65    /// Create a new stream handle with default accumulated value.
66    pub fn new() -> Self {
67        Self {
68            inner: Rc::new(RefCell::new(StreamInner {
69                accumulated: T::default(),
70                state: StreamState::Idle,
71                started: false,
72                receiver: None,
73                wake_flag: None,
74            })),
75        }
76    }
77
78    /// Create a new stream handle with an event-loop wake flag.
79    pub fn with_wake_flag(wake_flag: Arc<AtomicBool>) -> Self {
80        Self {
81            inner: Rc::new(RefCell::new(StreamInner {
82                accumulated: T::default(),
83                state: StreamState::Idle,
84                started: false,
85                receiver: None,
86                wake_flag: Some(wake_flag),
87            })),
88        }
89    }
90
91    /// Create a new stream handle with a specific initial value.
92    pub fn with_initial(initial: T) -> Self {
93        Self {
94            inner: Rc::new(RefCell::new(StreamInner {
95                accumulated: initial,
96                state: StreamState::Idle,
97                started: false,
98                receiver: None,
99                wake_flag: None,
100            })),
101        }
102    }
103
104    /// Get the current accumulated value.
105    pub fn get(&self) -> T {
106        self.inner.borrow().accumulated.clone()
107    }
108
109    /// Check if the stream is currently loading/streaming.
110    pub fn is_loading(&self) -> bool {
111        matches!(
112            self.inner.borrow().state,
113            StreamState::Idle | StreamState::Streaming
114        )
115    }
116
117    /// Check if the stream is actively receiving data.
118    pub fn is_streaming(&self) -> bool {
119        matches!(self.inner.borrow().state, StreamState::Streaming)
120    }
121
122    /// Check if the stream has completed.
123    pub fn is_done(&self) -> bool {
124        matches!(self.inner.borrow().state, StreamState::Done)
125    }
126
127    /// Check if the stream encountered an error.
128    pub fn is_error(&self) -> bool {
129        matches!(self.inner.borrow().state, StreamState::Error(_))
130    }
131
132    /// Get the error message if there was an error.
133    pub fn error(&self) -> Option<String> {
134        match &self.inner.borrow().state {
135            StreamState::Error(e) => Some(e.clone()),
136            _ => None,
137        }
138    }
139
140    /// Get the current stream state.
141    pub fn state(&self) -> StreamState {
142        self.inner.borrow().state.clone()
143    }
144}
145
146impl<T: Clone + Send + 'static> StreamHandle<T> {
147    /// Start the stream if not already started.
148    ///
149    /// The `stream_fn` should be a function that returns an iterator.
150    /// Each item from the iterator will be sent through the channel.
151    pub fn start<F, I>(&self, stream_fn: F)
152    where
153        F: FnOnce() -> I + Send + 'static,
154        I: Iterator<Item = T> + Send + 'static,
155    {
156        let mut inner = self.inner.borrow_mut();
157        if inner.started {
158            return;
159        }
160
161        inner.started = true;
162        inner.state = StreamState::Streaming;
163
164        // Create channel for stream items
165        let (tx, rx): (Sender<StreamItem<T>>, Receiver<StreamItem<T>>) = mpsc::channel();
166        inner.receiver = Some(rx);
167        let wake_flag = inner.wake_flag.clone();
168
169        // Spawn thread to run the stream
170        thread::spawn(move || {
171            let iter = stream_fn();
172            for item in iter {
173                if tx.send(StreamItem::Value(item)).is_err() {
174                    // Receiver dropped, stop streaming
175                    return;
176                }
177                if let Some(ref flag) = wake_flag {
178                    flag.store(true, Ordering::Release);
179                }
180            }
181            let _ = tx.send(StreamItem::Done);
182            if let Some(ref flag) = wake_flag {
183                flag.store(true, Ordering::Release);
184            }
185        });
186    }
187
188    /// Start the stream with error handling.
189    pub fn start_with_result<F, I>(&self, stream_fn: F)
190    where
191        F: FnOnce() -> Result<I, String> + Send + 'static,
192        I: Iterator<Item = T> + Send + 'static,
193    {
194        let mut inner = self.inner.borrow_mut();
195        if inner.started {
196            return;
197        }
198
199        inner.started = true;
200        inner.state = StreamState::Streaming;
201
202        let (tx, rx): (Sender<StreamItem<T>>, Receiver<StreamItem<T>>) = mpsc::channel();
203        inner.receiver = Some(rx);
204        let wake_flag = inner.wake_flag.clone();
205
206        thread::spawn(move || match stream_fn() {
207            Ok(iter) => {
208                for item in iter {
209                    if tx.send(StreamItem::Value(item)).is_err() {
210                        return;
211                    }
212                    if let Some(ref flag) = wake_flag {
213                        flag.store(true, Ordering::Release);
214                    }
215                }
216                let _ = tx.send(StreamItem::Done);
217                if let Some(ref flag) = wake_flag {
218                    flag.store(true, Ordering::Release);
219                }
220            }
221            Err(e) => {
222                let _ = tx.send(StreamItem::Error(e));
223                if let Some(ref flag) = wake_flag {
224                    flag.store(true, Ordering::Release);
225                }
226            }
227        });
228    }
229
230    /// Poll for new items and update accumulated value.
231    /// Returns true if there were updates.
232    pub fn poll(&self, accumulate: impl Fn(&mut T, T)) -> bool {
233        let mut inner = self.inner.borrow_mut();
234        let mut updated = false;
235
236        // Take receiver temporarily to avoid borrow conflicts
237        if let Some(receiver) = inner.receiver.take() {
238            // Drain all available items
239            let mut new_state = None;
240            loop {
241                match receiver.try_recv() {
242                    Ok(StreamItem::Value(item)) => {
243                        accumulate(&mut inner.accumulated, item);
244                        updated = true;
245                    }
246                    Ok(StreamItem::Done) => {
247                        new_state = Some(StreamState::Done);
248                        break;
249                    }
250                    Ok(StreamItem::Error(e)) => {
251                        new_state = Some(StreamState::Error(e));
252                        break;
253                    }
254                    Err(mpsc::TryRecvError::Empty) => {
255                        break;
256                    }
257                    Err(mpsc::TryRecvError::Disconnected) => {
258                        if !matches!(inner.state, StreamState::Done | StreamState::Error(_)) {
259                            new_state = Some(StreamState::Error(
260                                "Stream disconnected unexpectedly".to_string(),
261                            ));
262                        }
263                        break;
264                    }
265                }
266            }
267
268            // Put receiver back (unless stream is done)
269            if new_state.is_none() || matches!(new_state, Some(StreamState::Streaming)) {
270                inner.receiver = Some(receiver);
271            }
272
273            if let Some(state) = new_state {
274                inner.state = state;
275            }
276        }
277
278        updated
279    }
280
281    /// Reset the stream to allow restarting.
282    pub fn reset(&self)
283    where
284        T: Default,
285    {
286        let mut inner = self.inner.borrow_mut();
287        inner.accumulated = T::default();
288        inner.state = StreamState::Idle;
289        inner.started = false;
290        inner.receiver = None;
291        // wake_flag is preserved across resets
292    }
293
294    /// Reset the stream with a specific initial value.
295    pub fn reset_with(&self, initial: T) {
296        let mut inner = self.inner.borrow_mut();
297        inner.accumulated = initial;
298        inner.state = StreamState::Idle;
299        inner.started = false;
300        inner.receiver = None;
301        // wake_flag is preserved across resets
302    }
303}
304
305impl<T: Clone + Default + 'static> Default for StreamHandle<T> {
306    fn default() -> Self {
307        Self::new()
308    }
309}
310
311/// Convenience handle specifically for text streaming.
312/// Automatically accumulates string tokens.
313pub type TextStreamHandle = StreamHandle<String>;
314
315impl TextStreamHandle {
316    /// Poll and accumulate text by concatenation.
317    pub fn poll_text(&self) -> bool {
318        self.poll(|acc, item| acc.push_str(&item))
319    }
320}