Skip to main content

tiy_core/stream/
event_stream.rs

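//! Simplified event stream: a mutex-protected event queue consumed through
//! `futures::Stream`, plus a final result that can be awaited separately. The locks
//! are synchronous (`parking_lot`) and are never held across an `.await`.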
2
3use futures::Stream;
4use parking_lot::Mutex;
5use std::collections::VecDeque;
6use std::pin::Pin;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::Arc;
9use std::task::{Context, Poll, Waker};
10
/// State shared by every handle to one event stream; kept behind an `Arc`.
struct EventStreamInner<T, R> {
    /// Events pushed but not yet yielded to the consumer, oldest first.
    events: Mutex<VecDeque<T>>,
    /// Set once the stream has finished (terminal event pushed, or `end` called).
    done: AtomicBool,
    /// Final result; stored on completion and moved out by the first `result()` caller.
    result: Mutex<Option<R>>,
    /// Waker of the consumer that most recently found the queue empty, if any.
    waker: Mutex<Option<Waker>>,
}
22
23/// A generic event stream that supports async iteration and final result retrieval.
24pub struct EventStream<T, R = T> {
25    inner: Arc<EventStreamInner<T, R>>,
26    is_complete: fn(&T) -> bool,
27    extract_result: fn(T) -> R,
28}
29
30impl<T, R> EventStream<T, R>
31where
32    T: Clone + Send + 'static,
33    R: Send + 'static,
34{
35    /// Create a new event stream.
36    pub fn new(is_complete: fn(&T) -> bool, extract_result: fn(T) -> R) -> Self {
37        Self {
38            inner: Arc::new(EventStreamInner {
39                events: Mutex::new(VecDeque::new()),
40                done: AtomicBool::new(false),
41                result: Mutex::new(None),
42                waker: Mutex::new(None),
43            }),
44            is_complete,
45            extract_result,
46        }
47    }
48
49    /// Wake the stream consumer.
50    fn wake(&self) {
51        if let Some(waker) = self.inner.waker.lock().take() {
52            waker.wake();
53        }
54    }
55
56    /// Push an event to the stream.
57    pub fn push(&self, event: T) {
58        if self.inner.done.load(Ordering::SeqCst) {
59            // Stream is already done, ignore further events
60            return;
61        }
62
63        let is_complete = (self.is_complete)(&event);
64        if is_complete {
65            // Push the completion event to the queue so Stream
66            // consumers can observe Done/Error before the stream ends.
67            self.inner.events.lock().push_back(event.clone());
68            // Extract the result and store it for result() callers.
69            let result = (self.extract_result)(event);
70            *self.inner.result.lock() = Some(result);
71            self.inner.done.store(true, Ordering::SeqCst);
72        } else {
73            self.inner.events.lock().push_back(event);
74        }
75        self.wake();
76    }
77
78    /// End the stream with an optional result.
79    pub fn end(&self, result: Option<R>) {
80        if result.is_some() {
81            *self.inner.result.lock() = result;
82        }
83        self.inner.done.store(true, Ordering::SeqCst);
84        self.wake();
85    }
86
87    /// Check if the stream has ended.
88    pub fn is_done(&self) -> bool {
89        self.inner.done.load(Ordering::SeqCst)
90    }
91
92    /// Get the final result (async, non-blocking wait).
93    pub async fn result(&self) -> R {
94        loop {
95            {
96                let mut result = self.inner.result.lock();
97                if let Some(r) = result.take() {
98                    return r;
99                }
100            }
101            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
102        }
103    }
104
105    /// Get the final result with a timeout.
106    /// Returns `Some(result)` on success, `None` if the timeout expires.
107    pub async fn try_result(&self, timeout: std::time::Duration) -> Option<R> {
108        match tokio::time::timeout(timeout, self.result()).await {
109            Ok(r) => Some(r),
110            Err(_) => None,
111        }
112    }
113}
114
115impl<T, R> Stream for EventStream<T, R>
116where
117    T: Send + Unpin,
118{
119    type Item = T;
120
121    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
122        let this = self.get_mut();
123
124        // Try to get an event from the queue
125        {
126            let mut queue = this.inner.events.lock();
127            if let Some(event) = queue.pop_front() {
128                return Poll::Ready(Some(event));
129            }
130        }
131
132        // Queue is empty — check if we're done
133        if this.inner.done.load(Ordering::SeqCst) {
134            return Poll::Ready(None);
135        }
136
137        // Not done, no events: register waker and return Pending
138        *this.inner.waker.lock() = Some(cx.waker().clone());
139
140        // Double-check after registering waker to avoid race condition
141        {
142            let mut queue = this.inner.events.lock();
143            if let Some(event) = queue.pop_front() {
144                return Poll::Ready(Some(event));
145            }
146        }
147        if this.inner.done.load(Ordering::SeqCst) {
148            return Poll::Ready(None);
149        }
150
151        Poll::Pending
152    }
153}
154
155impl<T, R> Clone for EventStream<T, R> {
156    fn clone(&self) -> Self {
157        Self {
158            inner: Arc::clone(&self.inner),
159            is_complete: self.is_complete,
160            extract_result: self.extract_result,
161        }
162    }
163}
164
165/// Assistant message event stream type alias.
166pub type AssistantMessageEventStream =
167    EventStream<crate::types::AssistantMessageEvent, crate::types::AssistantMessage>;
168
169impl AssistantMessageEventStream {
170    /// Create a new assistant message event stream.
171    pub fn new_assistant_stream() -> Self {
172        Self::new(
173            |event| event.is_complete(),
174            |event| match event {
175                crate::types::AssistantMessageEvent::Done { message, .. } => message.clone(),
176                crate::types::AssistantMessageEvent::Error { error, .. } => error.clone(),
177                _ => unreachable!("is_complete should only return true for Done/Error"),
178            },
179        )
180    }
181}