vortex_layout/
sequence.rs1use std::cmp::Ordering;
5use std::collections::BTreeSet;
6use std::fmt;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::Context;
12use std::task::Poll;
13use std::task::Waker;
14
15use futures::Stream;
16use futures::StreamExt;
17use parking_lot::Mutex;
18use pin_project_lite::pin_project;
19use vortex_array::ArrayRef;
20use vortex_array::dtype::DType;
21use vortex_array::stream::ArrayStream;
22use vortex_error::VortexExpect;
23use vortex_error::VortexResult;
24use vortex_utils::aliases::hash_map::HashMap;
25
26pub struct SequenceId {
45 id: Vec<usize>,
46 universe: Arc<Mutex<SequenceUniverse>>,
47}
48
49impl PartialEq for SequenceId {
50 fn eq(&self, other: &Self) -> bool {
51 self.id == other.id
52 }
53}
54
55impl Eq for SequenceId {}
56
57impl PartialOrd for SequenceId {
58 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
59 Some(self.cmp(other))
60 }
61}
62
63impl Ord for SequenceId {
64 fn cmp(&self, other: &Self) -> Ordering {
65 self.id.cmp(&other.id)
66 }
67}
68
69impl Hash for SequenceId {
70 fn hash<H: Hasher>(&self, state: &mut H) {
71 self.id.hash(state);
72 }
73}
74
75impl fmt::Debug for SequenceId {
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77 f.debug_struct("SequenceId").field("id", &self.id).finish()
78 }
79}
80
81impl SequenceId {
82 pub fn root() -> SequencePointer {
88 SequencePointer(SequenceId::new(vec![0], Default::default()))
89 }
90
91 pub fn descend(self) -> SequencePointer {
102 let mut id = self.id.clone();
103 id.push(0);
104 SequencePointer(SequenceId::new(id, Arc::clone(&self.universe)))
105 }
106
107 pub async fn collapse(&mut self) {
133 WaitSequenceFuture(self).await;
134 }
135
136 fn new(id: Vec<usize>, universe: Arc<Mutex<SequenceUniverse>>) -> Self {
139 let res = Self { id, universe };
142 res.universe.lock().add(&res);
143 res
144 }
145}
146
147impl Drop for SequenceId {
148 fn drop(&mut self) {
149 let waker = self.universe.lock().remove(self);
150 if let Some(w) = waker {
151 w.wake();
152 }
153 }
154}
155
156#[derive(Debug)]
161pub struct SequencePointer(SequenceId);
162
163impl SequencePointer {
164 pub fn split(mut self) -> (SequencePointer, SequencePointer) {
166 (self.split_off(), self)
167 }
168
169 pub fn split_off(&mut self) -> SequencePointer {
173 self.advance().descend()
175 }
176
177 pub fn advance(&mut self) -> SequenceId {
184 let mut next_id = self.0.id.clone();
185
186 let last = next_id.last_mut();
188 let last = last.vortex_expect("must have at least one element");
189 *last += 1;
190 let next_sibling = SequenceId::new(next_id, Arc::clone(&self.0.universe));
191 std::mem::replace(&mut self.0, next_sibling)
192 }
193
194 pub fn downgrade(self) -> SequenceId {
200 self.0
201 }
202}
203
204#[derive(Default)]
205struct SequenceUniverse {
206 active: BTreeSet<Vec<usize>>,
207 wakers: HashMap<Vec<usize>, Waker>,
208}
209
210impl SequenceUniverse {
211 fn add(&mut self, sequence_id: &SequenceId) {
212 self.active.insert(sequence_id.id.clone());
213 }
214
215 fn remove(&mut self, sequence_id: &SequenceId) -> Option<Waker> {
216 self.active.remove(&sequence_id.id);
217 let Some(first) = self.active.first() else {
218 assert!(self.wakers.is_empty(), "all wakers must have been removed");
220 return None;
221 };
222 self.wakers.remove(first)
223 }
224}
225
226struct WaitSequenceFuture<'a>(&'a mut SequenceId);
227
228impl Future for WaitSequenceFuture<'_> {
229 type Output = ();
230
231 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
232 let mut guard = self.0.universe.lock();
233 let current_first = guard
234 .active
235 .first()
236 .cloned()
237 .vortex_expect("if we have a future, we must have at least one active sequence");
238 if self.0.id == current_first {
239 guard.wakers.remove(&self.0.id);
240 return Poll::Ready(());
241 }
242
243 guard.wakers.insert(self.0.id.clone(), cx.waker().clone());
244 Poll::Pending
245 }
246}
247
248impl Drop for WaitSequenceFuture<'_> {
250 fn drop(&mut self) {
251 self.0.universe.lock().wakers.remove(&self.0.id);
252 }
253}
254
255pub trait SequentialStream: Stream<Item = VortexResult<(SequenceId, ArrayRef)>> {
256 fn dtype(&self) -> &DType;
257}
258
259pub type SendableSequentialStream = Pin<Box<dyn SequentialStream + Send>>;
260
261impl SequentialStream for SendableSequentialStream {
262 fn dtype(&self) -> &DType {
263 (**self).dtype()
264 }
265}
266
267pub trait SequentialStreamExt: SequentialStream {
268 fn sendable(self) -> SendableSequentialStream
270 where
271 Self: Sized + Send + 'static,
272 {
273 Box::pin(self)
274 }
275}
276
277impl<S: SequentialStream> SequentialStreamExt for S {}
278
279pin_project! {
280 pub struct SequentialStreamAdapter<S> {
281 dtype: DType,
282 #[pin]
283 inner: S,
284 }
285}
286
287impl<S> SequentialStreamAdapter<S> {
288 pub fn new(dtype: DType, inner: S) -> Self {
289 Self { dtype, inner }
290 }
291}
292
293impl<S> SequentialStream for SequentialStreamAdapter<S>
294where
295 S: Stream<Item = VortexResult<(SequenceId, ArrayRef)>>,
296{
297 fn dtype(&self) -> &DType {
298 &self.dtype
299 }
300}
301
302impl<S> Stream for SequentialStreamAdapter<S>
303where
304 S: Stream<Item = VortexResult<(SequenceId, ArrayRef)>>,
305{
306 type Item = VortexResult<(SequenceId, ArrayRef)>;
307
308 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
309 let this = self.project();
310 let array = futures::ready!(this.inner.poll_next(cx));
311 if let Some(Ok((_, array))) = array.as_ref() {
312 assert_eq!(
313 array.dtype(),
314 this.dtype,
315 "Sequential stream of {} got chunk of {}.",
316 array.dtype(),
317 this.dtype
318 );
319 }
320
321 Poll::Ready(array)
322 }
323
324 fn size_hint(&self) -> (usize, Option<usize>) {
325 self.inner.size_hint()
326 }
327}
328
329pub trait SequentialArrayStreamExt: ArrayStream {
330 fn sequenced(self, mut pointer: SequencePointer) -> SendableSequentialStream
332 where
333 Self: Sized + Send + 'static,
334 {
335 Box::pin(SequentialStreamAdapter::new(
336 self.dtype().clone(),
337 StreamExt::map(self, move |item| {
338 item.map(|array| (pointer.advance(), array))
339 }),
340 ))
341 }
342}
343
344impl<S: ArrayStream> SequentialArrayStreamExt for S {}