vortex_layout/
sequence.rs1use std::cmp::Ordering;
5use std::collections::BTreeSet;
6use std::fmt;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::Context;
12use std::task::Poll;
13use std::task::Waker;
14
15use futures::Stream;
16use futures::StreamExt;
17use parking_lot::Mutex;
18use pin_project_lite::pin_project;
19use vortex_array::Array;
20use vortex_array::ArrayRef;
21use vortex_array::stream::ArrayStream;
22use vortex_dtype::DType;
23use vortex_error::VortexExpect;
24use vortex_error::VortexResult;
25use vortex_utils::aliases::hash_map::HashMap;
26
27pub struct SequenceId {
46 id: Vec<usize>,
47 universe: Arc<Mutex<SequenceUniverse>>,
48}
49
50impl PartialEq for SequenceId {
51 fn eq(&self, other: &Self) -> bool {
52 self.id == other.id
53 }
54}
55
56impl Eq for SequenceId {}
57
58impl PartialOrd for SequenceId {
59 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
60 Some(self.cmp(other))
61 }
62}
63
64impl Ord for SequenceId {
65 fn cmp(&self, other: &Self) -> Ordering {
66 self.id.cmp(&other.id)
67 }
68}
69
70impl Hash for SequenceId {
71 fn hash<H: Hasher>(&self, state: &mut H) {
72 self.id.hash(state);
73 }
74}
75
76impl fmt::Debug for SequenceId {
77 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78 f.debug_struct("SequenceId").field("id", &self.id).finish()
79 }
80}
81
82impl SequenceId {
83 pub fn root() -> SequencePointer {
89 SequencePointer(SequenceId::new(vec![0], Default::default()))
90 }
91
92 pub fn descend(self) -> SequencePointer {
103 let mut id = self.id.clone();
104 id.push(0);
105 SequencePointer(SequenceId::new(id, self.universe.clone()))
106 }
107
108 pub async fn collapse(&mut self) {
134 WaitSequenceFuture(self).await;
135 }
136
137 fn new(id: Vec<usize>, universe: Arc<Mutex<SequenceUniverse>>) -> Self {
140 let res = Self { id, universe };
143 res.universe.lock().add(&res);
144 res
145 }
146}
147
148impl Drop for SequenceId {
149 fn drop(&mut self) {
150 let waker = self.universe.lock().remove(self);
151 if let Some(w) = waker {
152 w.wake();
153 }
154 }
155}
156
157#[derive(Debug)]
162pub struct SequencePointer(SequenceId);
163
164impl SequencePointer {
165 pub fn split(mut self) -> (SequencePointer, SequencePointer) {
167 (self.split_off(), self)
168 }
169
170 pub fn split_off(&mut self) -> SequencePointer {
174 self.advance().descend()
176 }
177
178 pub fn advance(&mut self) -> SequenceId {
185 let mut next_id = self.0.id.clone();
186
187 let last = next_id.last_mut();
189 let last = last.vortex_expect("must have at least one element");
190 *last += 1;
191 let next_sibling = SequenceId::new(next_id, self.0.universe.clone());
192 std::mem::replace(&mut self.0, next_sibling)
193 }
194
195 pub fn downgrade(self) -> SequenceId {
201 self.0
202 }
203}
204
205#[derive(Default)]
206struct SequenceUniverse {
207 active: BTreeSet<Vec<usize>>,
208 wakers: HashMap<Vec<usize>, Waker>,
209}
210
211impl SequenceUniverse {
212 fn add(&mut self, sequence_id: &SequenceId) {
213 self.active.insert(sequence_id.id.clone());
214 }
215
216 fn remove(&mut self, sequence_id: &SequenceId) -> Option<Waker> {
217 self.active.remove(&sequence_id.id);
218 let Some(first) = self.active.first() else {
219 assert!(self.wakers.is_empty(), "all wakers must have been removed");
221 return None;
222 };
223 self.wakers.remove(first)
224 }
225}
226
227struct WaitSequenceFuture<'a>(&'a mut SequenceId);
228
229impl Future for WaitSequenceFuture<'_> {
230 type Output = ();
231
232 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
233 let mut guard = self.0.universe.lock();
234 let current_first = guard
235 .active
236 .first()
237 .cloned()
238 .vortex_expect("if we have a future, we must have at least one active sequence");
239 if self.0.id == current_first {
240 guard.wakers.remove(&self.0.id);
241 return Poll::Ready(());
242 }
243
244 guard.wakers.insert(self.0.id.clone(), cx.waker().clone());
245 Poll::Pending
246 }
247}
248
249impl Drop for WaitSequenceFuture<'_> {
251 fn drop(&mut self) {
252 self.0.universe.lock().wakers.remove(&self.0.id);
253 }
254}
255
256pub trait SequentialStream: Stream<Item = VortexResult<(SequenceId, ArrayRef)>> {
257 fn dtype(&self) -> &DType;
258}
259
260pub type SendableSequentialStream = Pin<Box<dyn SequentialStream + Send>>;
261
262impl SequentialStream for SendableSequentialStream {
263 fn dtype(&self) -> &DType {
264 (**self).dtype()
265 }
266}
267
268pub trait SequentialStreamExt: SequentialStream {
269 fn sendable(self) -> SendableSequentialStream
271 where
272 Self: Sized + Send + 'static,
273 {
274 Box::pin(self)
275 }
276}
277
278impl<S: SequentialStream> SequentialStreamExt for S {}
279
280pin_project! {
281 pub struct SequentialStreamAdapter<S> {
282 dtype: DType,
283 #[pin]
284 inner: S,
285 }
286}
287
288impl<S> SequentialStreamAdapter<S> {
289 pub fn new(dtype: DType, inner: S) -> Self {
290 Self { dtype, inner }
291 }
292}
293
294impl<S> SequentialStream for SequentialStreamAdapter<S>
295where
296 S: Stream<Item = VortexResult<(SequenceId, ArrayRef)>>,
297{
298 fn dtype(&self) -> &DType {
299 &self.dtype
300 }
301}
302
303impl<S> Stream for SequentialStreamAdapter<S>
304where
305 S: Stream<Item = VortexResult<(SequenceId, ArrayRef)>>,
306{
307 type Item = VortexResult<(SequenceId, ArrayRef)>;
308
309 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
310 let this = self.project();
311 let array = futures::ready!(this.inner.poll_next(cx));
312 if let Some(Ok((_, array))) = array.as_ref() {
313 assert_eq!(
314 array.dtype(),
315 this.dtype,
316 "Sequential stream of {} got chunk of {}.",
317 array.dtype(),
318 this.dtype
319 );
320 }
321
322 Poll::Ready(array)
323 }
324
325 fn size_hint(&self) -> (usize, Option<usize>) {
326 self.inner.size_hint()
327 }
328}
329
330pub trait SequentialArrayStreamExt: ArrayStream {
331 fn sequenced(self, mut pointer: SequencePointer) -> SendableSequentialStream
333 where
334 Self: Sized + Send + 'static,
335 {
336 Box::pin(SequentialStreamAdapter::new(
337 self.dtype().clone(),
338 StreamExt::map(self, move |item| {
339 item.map(|array| (pointer.advance(), array))
340 }),
341 ))
342 }
343}
344
345impl<S: ArrayStream> SequentialArrayStreamExt for S {}