tiy_core/stream/
event_stream.rs1use futures::Stream;
4use parking_lot::Mutex;
5use std::collections::VecDeque;
6use std::pin::Pin;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::Arc;
9use std::task::{Context, Poll, Waker};
10
11struct EventStreamInner<T, R> {
13 events: Mutex<VecDeque<T>>,
15 done: AtomicBool,
17 result: Mutex<Option<R>>,
19 waker: Mutex<Option<Waker>>,
21}
22
23pub struct EventStream<T, R = T> {
25 inner: Arc<EventStreamInner<T, R>>,
26 is_complete: fn(&T) -> bool,
27 extract_result: fn(T) -> R,
28}
29
30impl<T, R> EventStream<T, R>
31where
32 T: Clone + Send + 'static,
33 R: Send + 'static,
34{
35 pub fn new(is_complete: fn(&T) -> bool, extract_result: fn(T) -> R) -> Self {
37 Self {
38 inner: Arc::new(EventStreamInner {
39 events: Mutex::new(VecDeque::new()),
40 done: AtomicBool::new(false),
41 result: Mutex::new(None),
42 waker: Mutex::new(None),
43 }),
44 is_complete,
45 extract_result,
46 }
47 }
48
49 fn wake(&self) {
51 if let Some(waker) = self.inner.waker.lock().take() {
52 waker.wake();
53 }
54 }
55
56 pub fn push(&self, event: T) {
58 if self.inner.done.load(Ordering::SeqCst) {
59 return;
61 }
62
63 let is_complete = (self.is_complete)(&event);
64 if is_complete {
65 self.inner.events.lock().push_back(event.clone());
68 let result = (self.extract_result)(event);
70 *self.inner.result.lock() = Some(result);
71 self.inner.done.store(true, Ordering::SeqCst);
72 } else {
73 self.inner.events.lock().push_back(event);
74 }
75 self.wake();
76 }
77
78 pub fn end(&self, result: Option<R>) {
80 if result.is_some() {
81 *self.inner.result.lock() = result;
82 }
83 self.inner.done.store(true, Ordering::SeqCst);
84 self.wake();
85 }
86
87 pub fn is_done(&self) -> bool {
89 self.inner.done.load(Ordering::SeqCst)
90 }
91
92 pub async fn result(&self) -> R {
94 loop {
95 {
96 let mut result = self.inner.result.lock();
97 if let Some(r) = result.take() {
98 return r;
99 }
100 }
101 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
102 }
103 }
104
105 pub async fn try_result(&self, timeout: std::time::Duration) -> Option<R> {
108 match tokio::time::timeout(timeout, self.result()).await {
109 Ok(r) => Some(r),
110 Err(_) => None,
111 }
112 }
113}
114
115impl<T, R> Stream for EventStream<T, R>
116where
117 T: Send + Unpin,
118{
119 type Item = T;
120
121 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
122 let this = self.get_mut();
123
124 {
126 let mut queue = this.inner.events.lock();
127 if let Some(event) = queue.pop_front() {
128 return Poll::Ready(Some(event));
129 }
130 }
131
132 if this.inner.done.load(Ordering::SeqCst) {
134 return Poll::Ready(None);
135 }
136
137 *this.inner.waker.lock() = Some(cx.waker().clone());
139
140 {
142 let mut queue = this.inner.events.lock();
143 if let Some(event) = queue.pop_front() {
144 return Poll::Ready(Some(event));
145 }
146 }
147 if this.inner.done.load(Ordering::SeqCst) {
148 return Poll::Ready(None);
149 }
150
151 Poll::Pending
152 }
153}
154
155impl<T, R> Clone for EventStream<T, R> {
156 fn clone(&self) -> Self {
157 Self {
158 inner: Arc::clone(&self.inner),
159 is_complete: self.is_complete,
160 extract_result: self.extract_result,
161 }
162 }
163}
164
165pub type AssistantMessageEventStream =
167 EventStream<crate::types::AssistantMessageEvent, crate::types::AssistantMessage>;
168
169impl AssistantMessageEventStream {
170 pub fn new_assistant_stream() -> Self {
172 Self::new(
173 |event| event.is_complete(),
174 |event| match event {
175 crate::types::AssistantMessageEvent::Done { message, .. } => message.clone(),
176 crate::types::AssistantMessageEvent::Error { error, .. } => error.clone(),
177 _ => unreachable!("is_complete should only return true for Done/Error"),
178 },
179 )
180 }
181}