vortex_layout/
sequence.rs1use std::cmp::Ordering;
5use std::collections::BTreeSet;
6use std::fmt;
7use std::hash::{Hash, Hasher};
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll, Waker};
11
12use futures::{Stream, StreamExt};
13use parking_lot::Mutex;
14use pin_project_lite::pin_project;
15use vortex_array::stream::ArrayStream;
16use vortex_array::{Array, ArrayRef};
17use vortex_dtype::DType;
18use vortex_error::{VortexExpect, VortexResult};
19use vortex_utils::aliases::hash_map::HashMap;
20
21pub struct SequenceId {
40 id: Vec<usize>,
41 universe: Arc<Mutex<SequenceUniverse>>,
42}
43
44impl PartialEq for SequenceId {
45 fn eq(&self, other: &Self) -> bool {
46 self.id == other.id
47 }
48}
49
50impl Eq for SequenceId {}
51
52impl PartialOrd for SequenceId {
53 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
54 Some(self.cmp(other))
55 }
56}
57
58impl Ord for SequenceId {
59 fn cmp(&self, other: &Self) -> Ordering {
60 self.id.cmp(&other.id)
61 }
62}
63
64impl Hash for SequenceId {
65 fn hash<H: Hasher>(&self, state: &mut H) {
66 self.id.hash(state);
67 }
68}
69
70impl fmt::Debug for SequenceId {
71 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72 f.debug_struct("SequenceId").field("id", &self.id).finish()
73 }
74}
75
76impl SequenceId {
77 pub fn root() -> SequencePointer {
83 SequencePointer(SequenceId::new(vec![0], Default::default()))
84 }
85
86 pub fn descend(self) -> SequencePointer {
97 let mut id = self.id.clone();
98 id.push(0);
99 SequencePointer(SequenceId::new(id, self.universe.clone()))
100 }
101
102 pub async fn collapse(&mut self) {
128 WaitSequenceFuture(self).await;
129 }
130
131 fn new(id: Vec<usize>, universe: Arc<Mutex<SequenceUniverse>>) -> Self {
134 let res = Self { id, universe };
137 res.universe.lock().add(&res);
138 res
139 }
140}
141
142impl Drop for SequenceId {
143 fn drop(&mut self) {
144 let waker = self.universe.lock().remove(self);
145 if let Some(w) = waker {
146 w.wake();
147 }
148 }
149}
150
151#[derive(Debug)]
156pub struct SequencePointer(SequenceId);
157
158impl SequencePointer {
159 pub fn split(mut self) -> (SequencePointer, SequencePointer) {
161 (self.split_off(), self)
162 }
163
164 pub fn split_off(&mut self) -> SequencePointer {
168 self.advance().descend()
170 }
171
172 pub fn advance(&mut self) -> SequenceId {
179 let mut next_id = self.0.id.clone();
180
181 let last = next_id.last_mut();
183 let last = last.vortex_expect("must have at least one element");
184 *last += 1;
185 let next_sibling = SequenceId::new(next_id, self.0.universe.clone());
186 std::mem::replace(&mut self.0, next_sibling)
187 }
188
189 pub fn downgrade(self) -> SequenceId {
195 self.0
196 }
197}
198
/// Shared registry of the live ids that descend from one `SequenceId::root`.
#[derive(Default)]
struct SequenceUniverse {
    /// Paths of every live `SequenceId`, kept sorted so the smallest is at the front.
    active: BTreeSet<Vec<usize>>,
    /// Wakers of tasks blocked in `SequenceId::collapse`, keyed by the path they wait on.
    wakers: HashMap<Vec<usize>, Waker>,
}
204
205impl SequenceUniverse {
206 fn add(&mut self, sequence_id: &SequenceId) {
207 self.active.insert(sequence_id.id.clone());
208 }
209
210 fn remove(&mut self, sequence_id: &SequenceId) -> Option<Waker> {
211 self.active.remove(&sequence_id.id);
212 let Some(first) = self.active.first() else {
213 assert!(self.wakers.is_empty(), "all wakers must have been removed");
215 return None;
216 };
217 self.wakers.remove(first)
218 }
219}
220
221struct WaitSequenceFuture<'a>(&'a mut SequenceId);
222
223impl Future for WaitSequenceFuture<'_> {
224 type Output = ();
225
226 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
227 let mut guard = self.0.universe.lock();
228 let current_first = guard
229 .active
230 .first()
231 .cloned()
232 .vortex_expect("if we have a future, we must have at least one active sequence");
233 if self.0.id == current_first {
234 guard.wakers.remove(&self.0.id);
235 return Poll::Ready(());
236 }
237
238 guard.wakers.insert(self.0.id.clone(), cx.waker().clone());
239 Poll::Pending
240 }
241}
242
243impl Drop for WaitSequenceFuture<'_> {
245 fn drop(&mut self) {
246 self.0.universe.lock().wakers.remove(&self.0.id);
247 }
248}
249
250pub trait SequentialStream: Stream<Item = VortexResult<(SequenceId, ArrayRef)>> {
251 fn dtype(&self) -> &DType;
252}
253
254pub type SendableSequentialStream = Pin<Box<dyn SequentialStream + Send>>;
255
256impl SequentialStream for SendableSequentialStream {
257 fn dtype(&self) -> &DType {
258 (**self).dtype()
259 }
260}
261
262pub trait SequentialStreamExt: SequentialStream {
263 fn sendable(self) -> SendableSequentialStream
265 where
266 Self: Sized + Send + 'static,
267 {
268 Box::pin(self)
269 }
270}
271
272impl<S: SequentialStream> SequentialStreamExt for S {}
273
274pin_project! {
275 pub struct SequentialStreamAdapter<S> {
276 dtype: DType,
277 #[pin]
278 inner: S,
279 }
280}
281
282impl<S> SequentialStreamAdapter<S> {
283 pub fn new(dtype: DType, inner: S) -> Self {
284 Self { dtype, inner }
285 }
286}
287
288impl<S> SequentialStream for SequentialStreamAdapter<S>
289where
290 S: Stream<Item = VortexResult<(SequenceId, ArrayRef)>>,
291{
292 fn dtype(&self) -> &DType {
293 &self.dtype
294 }
295}
296
297impl<S> Stream for SequentialStreamAdapter<S>
298where
299 S: Stream<Item = VortexResult<(SequenceId, ArrayRef)>>,
300{
301 type Item = VortexResult<(SequenceId, ArrayRef)>;
302
303 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
304 let this = self.project();
305 let array = futures::ready!(this.inner.poll_next(cx));
306 if let Some(Ok((_, array))) = array.as_ref() {
307 assert_eq!(
308 array.dtype(),
309 this.dtype,
310 "Sequential stream of {} got chunk of {}.",
311 array.dtype(),
312 this.dtype
313 );
314 }
315
316 Poll::Ready(array)
317 }
318
319 fn size_hint(&self) -> (usize, Option<usize>) {
320 self.inner.size_hint()
321 }
322}
323
324pub trait SequentialArrayStreamExt: ArrayStream {
325 fn sequenced(self, mut pointer: SequencePointer) -> SendableSequentialStream
327 where
328 Self: Sized + Send + 'static,
329 {
330 Box::pin(SequentialStreamAdapter::new(
331 self.dtype().clone(),
332 StreamExt::map(self, move |item| {
333 item.map(|array| (pointer.advance(), array))
334 }),
335 ))
336 }
337}
338
339impl<S: ArrayStream> SequentialArrayStreamExt for S {}