Skip to main content

tdb_succinct/
util.rs

1use futures::io::Result;
2use futures::stream::{Peekable, Stream, StreamExt};
3use futures::task::{Context, Poll};
4use futures::TryStreamExt;
5use std::cmp::{Ordering, Reverse};
6use std::collections::BinaryHeap;
7use std::fmt;
8use std::marker::{PhantomData, Unpin};
9use std::pin::Pin;
10use tokio::io::{AsyncWrite, AsyncWriteExt};
11
/// Return the length of the longest common prefix of `b1` and `b2`.
///
/// Comparison stops at the first differing byte or at the end of the shorter slice,
/// so the result is never larger than `min(b1.len(), b2.len())`.
pub fn find_common_prefix(b1: &[u8], b2: &[u8]) -> usize {
    b1.iter()
        .zip(b2)
        .take_while(|(left, right)| left == right)
        .count()
}
24
25pub fn find_common_prefix_ord(b1: &[u8], b2: &[u8]) -> (usize, Ordering) {
26    let common_prefix = find_common_prefix(b1, b2);
27
28    if common_prefix == b1.len() && b1.len() == b2.len() {
29        (common_prefix, Ordering::Equal)
30    } else if b1.len() == common_prefix {
31        (common_prefix, Ordering::Less)
32    } else if b2.len() == common_prefix {
33        (common_prefix, Ordering::Greater)
34    } else {
35        (common_prefix, b1[common_prefix].cmp(&b2[common_prefix]))
36    }
37}
38
39pub async fn write_nul_terminated_bytes<W: AsyncWrite + Unpin>(
40    w: &mut W,
41    bytes: &[u8],
42) -> Result<usize> {
43    w.write_all(&bytes).await?;
44    w.write_all(&[0]).await?;
45
46    let count = bytes.len() + 1;
47
48    Ok(count)
49}
50
51/// Write a buffer to `w`.
52pub async fn write_padding<W: AsyncWrite + Unpin>(
53    w: &mut W,
54    current_pos: usize,
55    width: u8,
56) -> Result<()> {
57    let required_padding = (width as usize - current_pos % width as usize) % width as usize;
58    w.write_all(&vec![0; required_padding]).await?;
59
60    Ok(())
61}
62
63/// Write a `u64` in big-endian order to `w`.
64pub async fn write_u64<W: AsyncWrite + Unpin>(w: &mut W, num: u64) -> Result<()> {
65    w.write_all(&num.to_be_bytes()).await?;
66
67    Ok(())
68}
69
/// Iterator that merges several input iterators by repeatedly yielding the smallest
/// pending item.
///
/// Construct it with [`heap_sorted_iter`]. If every input yields its items in ascending
/// order, the merged output is ascending as well. For items that compare equal, the one
/// from the input at the higher position in the input vector is yielded first, because
/// heap ties are broken by the `usize` position tag.
pub struct HeapSortedIterator<'a, T: Ord, I: 'a + Iterator<Item = T> + Unpin + Send> {
    /// The input iterators; the `usize` tags stored in `heap` index into this vector.
    iters: Vec<I>,
    /// At most one pending item per non-exhausted input, tagged with that input's position.
    /// `Reverse` turns the max-heap into a min-heap on `T`.
    heap: BinaryHeap<(Reverse<T>, usize)>,
    _x: PhantomData<&'a ()>,
}

/// Create a [`HeapSortedIterator`] that merges `iters`.
///
/// The first item of every input is pulled immediately, in input order, to seed the heap.
/// Inputs that are already empty never contribute anything.
pub fn heap_sorted_iter<'a, T: Ord, I: 'a + Iterator<Item = T> + Unpin + Send>(
    mut iters: Vec<I>,
) -> HeapSortedIterator<'a, T, I> {
    let mut heap = BinaryHeap::with_capacity(iters.len());
    heap.extend(
        iters
            .iter_mut()
            .enumerate()
            .filter_map(|(position, source)| source.next().map(|head| (Reverse(head), position))),
    );

    HeapSortedIterator {
        iters,
        heap,
        _x: PhantomData,
    }
}

impl<'a, T: Ord + Unpin, I: 'a + Iterator<Item = T> + Unpin + Send> Iterator
    for HeapSortedIterator<'a, T, I>
{
    type Item = T;

    /// Yield the smallest pending item, then refill the heap from the input it came from.
    fn next(&mut self) -> Option<Self::Item> {
        // An empty heap means every input has been exhausted.
        let (Reverse(smallest), source) = self.heap.pop()?;

        // Keep the source represented in the heap for as long as it has items left.
        if let Some(next_head) = self.iters[source].next() {
            self.heap.push((Reverse(next_head), source));
        }

        Some(smallest)
    }
}
120
121pub struct HeapSortedStream<
122    'a,
123    T: Ord,
124    E,
125    S: 'a + Stream<Item = std::result::Result<T, E>> + Unpin + Send,
126> {
127    streams: Vec<S>,
128    heap: BinaryHeap<(Reverse<T>, usize)>,
129    _x: PhantomData<&'a ()>,
130}
131
132pub async fn heap_sorted_stream<
133    'a,
134    T: Ord,
135    E,
136    S: 'a + Stream<Item = std::result::Result<T, E>> + Unpin + Send,
137>(
138    mut streams: Vec<S>,
139) -> std::result::Result<HeapSortedStream<'a, T, E, S>, E> {
140    let mut heap = BinaryHeap::with_capacity(streams.len());
141
142    for (ix, s) in streams.iter_mut().enumerate() {
143        if let Some(item) = s.try_next().await? {
144            heap.push((Reverse(item), ix));
145        }
146    }
147
148    Ok(HeapSortedStream {
149        streams,
150        heap,
151        _x: Default::default(),
152    })
153}
154
155impl<'a, T: Ord + Unpin, E, S: 'a + Stream<Item = std::result::Result<T, E>> + Unpin + Send> Stream
156    for HeapSortedStream<'a, T, E, S>
157{
158    type Item = std::result::Result<T, E>;
159
160    fn poll_next(
161        self: Pin<&mut Self>,
162        cx: &mut Context,
163    ) -> Poll<Option<std::result::Result<T, E>>> {
164        let self_ = self.get_mut();
165        if let Some(ix) = self_.heap.peek().map(|(_, ix)| *ix) {
166            // we're about to pop an element from the heap. we'll need to read the next item in its corresponding stream to add to the heap afterwards.
167            let stream = &mut self_.streams[ix];
168            match Pin::new(stream).poll_next(cx) {
169                Poll::Ready(Some(Ok(next_item))) => {
170                    let item = self_.heap.pop().unwrap();
171                    self_.heap.push((Reverse(next_item), ix));
172
173                    Poll::Ready(Some(Ok(item.0 .0)))
174                }
175                Poll::Ready(None) => {
176                    let item = self_.heap.pop().unwrap();
177                    Poll::Ready(Some(Ok(item.0 .0)))
178                }
179                Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
180                Poll::Pending => Poll::Pending,
181            }
182        } else {
183            Poll::Ready(None)
184        }
185    }
186}
187
188pub struct SortedStream<
189    'a,
190    T,
191    S: 'a + Stream<Item = T> + Unpin + Send,
192    F: 'a + Fn(&[Option<&T>]) -> Option<usize>,
193> {
194    streams: Vec<Peekable<S>>,
195    pick_fn: F,
196    _x: PhantomData<&'a ()>,
197}
198
199impl<
200        'a,
201        T,
202        S: 'a + Stream<Item = T> + Unpin + Send,
203        F: 'a + Fn(&[Option<&T>]) -> Option<usize> + Unpin,
204    > Stream for SortedStream<'a, T, S, F>
205{
206    type Item = T;
207
208    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<T>> {
209        let mut v = Vec::with_capacity(self.streams.len());
210        let self_ = self.get_mut();
211        for s in self_.streams.iter_mut() {
212            match Pin::new(s).poll_peek(cx) {
213                Poll::Ready(val) => v.push(val),
214                Poll::Pending => return Poll::Pending,
215            }
216        }
217
218        let ix = (self_.pick_fn)(&v[..]);
219
220        match ix {
221            None => Poll::Ready(None),
222            Some(ix) => {
223                let next = Pin::new(&mut self_.streams[ix]).poll_next(cx);
224                match next {
225                    Poll::Ready(next) => Poll::Ready(next),
226                    _ => panic!("unexpected result in stream polling - reported ready earlier but not on later poll")
227                }
228            }
229        }
230    }
231}
232
233pub fn sorted_stream<
234    'a,
235    T: 'a,
236    S: 'a + Stream<Item = T> + Unpin + Send,
237    F: 'a + Fn(&[Option<&T>]) -> Option<usize> + Unpin,
238>(
239    streams: Vec<S>,
240    pick_fn: F,
241) -> SortedStream<'a, T, S, F> {
242    let peekable_streams = streams.into_iter().map(|s| s.peekable()).collect();
243    SortedStream {
244        streams: peekable_streams,
245        pick_fn,
246        _x: Default::default(),
247    }
248}
249
/// Compare two `Result`s, ordering every `Err` before every `Ok`.
///
/// Two `Err`s compare as `Equal` regardless of their contents. Two `Ok`s are compared
/// with `T`'s `Ord` implementation.
pub fn compare_or_result<T: Ord, E: fmt::Debug>(
    r1: &std::result::Result<T, E>,
    r2: &std::result::Result<T, E>,
) -> Ordering {
    match (r1, r2) {
        (Err(_), Err(_)) => Ordering::Equal,
        (Err(_), Ok(_)) => Ordering::Less,
        (Ok(_), Err(_)) => Ordering::Greater,
        (Ok(left), Ok(right)) => left.cmp(right),
    }
}
266
/// Iterator that merges several input iterators, with `pick_fn` deciding which input
/// supplies each next item.
struct SortedIterator<
    T,
    I: Iterator<Item = T> + Send,
    F: 'static + Fn(&[Option<&T>]) -> Option<usize>,
> {
    /// The inputs, wrapped so `pick_fn` can see each head without consuming it.
    iters: Vec<std::iter::Peekable<I>>,
    /// Returns the index of the input to take the next item from, or `None` to stop.
    pick_fn: F,
}

impl<'a, T, I: 'a + Iterator<Item = T> + Send, F: 'static + Fn(&[Option<&T>]) -> Option<usize>>
    Iterator for SortedIterator<T, I, F>
{
    type Item = T;

    fn next(&mut self) -> Option<T> {
        // Heads in input order; `None` marks an input with no further item.
        let heads: Vec<Option<&T>> = self
            .iters
            .iter_mut()
            .map(std::iter::Peekable::peek)
            .collect();

        let chosen = (self.pick_fn)(&heads[..])?;
        self.iters[chosen].next()
    }
}

/// Merge `iters` into a single iterator, calling `pick_fn` on the current heads (in input
/// order, `None` for an exhausted input) to choose which input supplies each next item.
///
/// Iteration ends as soon as `pick_fn` returns `None`, or when it picks an input that has
/// no further item.
pub fn sorted_iterator<
    'a,
    T: 'a,
    I: 'a + Iterator<Item = T> + Send,
    F: 'static + Fn(&[Option<&T>]) -> Option<usize>,
>(
    iters: Vec<I>,
    pick_fn: F,
) -> impl Iterator<Item = T> + 'a {
    let mut wrapped = Vec::with_capacity(iters.len());
    for input in iters {
        wrapped.push(input.peekable());
    }

    SortedIterator {
        iters: wrapped,
        pick_fn,
    }
}
314
315pub fn stream_iter_ok<T, E, I: IntoIterator<Item = T>>(
316    iter: I,
317) -> impl Stream<Item = std::result::Result<T, E>> {
318    futures::stream::iter(iter).map(Ok::<T, E>)
319}
320
321pub fn assert_poll_next<T, S: Stream<Item = T>>(stream: Pin<&mut S>, cx: &mut Context) -> T {
322    match stream.poll_next(cx) {
323        Poll::Ready(Some(item)) => item,
324        _ => panic!("stream was expected to have a result but did not."),
325    }
326}
327
/// Return the number of bits needed to represent `size`, with a minimum of 1.
///
/// For example, `0` and `1` need 1 bit, `255` needs 8, and `256` needs 9.
pub fn calculate_width(size: u64) -> u8 {
    // Significant bits = total bits minus the leading zero bits.
    let significant_bits = u64::BITS - size.leading_zeros();
    // Zero has no significant bits but still needs one bit to be represented.
    significant_bits.max(1) as u8
}
336
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;

    /// Merging three pre-sorted streams by always taking the smallest available head
    /// must yield every element in ascending order, duplicates included.
    #[test]
    fn sort_some_streams() {
        let streams: Vec<_> = vec![vec![1, 3, 5, 8, 12], vec![7, 9, 15], vec![0, 1, 2, 3, 4]]
            .into_iter()
            .map(futures::stream::iter)
            .collect();

        // Pick the input whose head is smallest; ties go to the earliest input.
        let merged = sorted_stream(streams, |heads| {
            heads
                .iter()
                .enumerate()
                .filter_map(|(ix, head)| head.map(|value| (value, ix)))
                .min()
                .map(|(_, ix)| ix)
        });

        let result: Vec<_> = block_on(merged.collect());

        assert_eq!(vec![0, 1, 1, 2, 3, 3, 4, 5, 7, 8, 9, 12, 15], result);
    }
}