Skip to main content

dag/set/
id_lazy.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */
7
8use std::any::Any;
9use std::fmt;
10use std::sync::Arc;
11use std::sync::Mutex;
12use std::sync::MutexGuard;
13
14use indexmap::IndexSet;
15use nonblocking::non_blocking_result;
16
17use super::hints::Flags;
18use super::id_static::IdStaticSet;
19use super::AsyncSetQuery;
20use super::BoxVertexStream;
21use super::Hints;
22use crate::ops::DagAlgorithm;
23use crate::ops::IdConvert;
24use crate::protocol::disable_remote_protocol;
25use crate::Group;
26use crate::Id;
27use crate::IdSet;
28use crate::Result;
29use crate::Vertex;
30
31/// A set backed by a lazy iterator of Ids.
32pub struct IdLazySet {
33    // Mutex: iter() does not take &mut self.
34    // Arc: iter() result does not have a lifetime on this struct.
35    inner: Arc<Mutex<Inner>>,
36    pub map: Arc<dyn IdConvert + Send + Sync>,
37    pub(crate) dag: Arc<dyn DagAlgorithm + Send + Sync>,
38    hints: Hints,
39}
40
41struct Inner {
42    iter: Box<dyn Iterator<Item = Result<Id>> + Send + Sync>,
43    visited: IndexSet<Id>,
44    state: State,
45}
46
47impl Inner {
48    fn load_more(&mut self, n: usize, mut out: Option<&mut Vec<Id>>) -> Result<()> {
49        if matches!(self.state, State::Complete | State::Error) {
50            return Ok(());
51        }
52        for _ in 0..n {
53            match self.iter.next() {
54                Some(Ok(id)) => {
55                    if let Some(ref mut out) = out {
56                        out.push(id);
57                    }
58                    self.visited.insert(id);
59                }
60                None => {
61                    self.state = State::Complete;
62                    break;
63                }
64                Some(Err(err)) => {
65                    self.state = State::Error;
66                    return Err(err);
67                }
68            }
69        }
70        Ok(())
71    }
72}
73
/// Progress of draining the underlying id iterator.
#[derive(Copy, Clone, Debug, PartialEq)]
enum State {
    /// The iterator may still yield more ids.
    Incomplete,
    /// The iterator has been exhausted.
    Complete,
    /// The iterator, or an id-to-name lookup while streaming, reported an
    /// error; no further ids are loaded.
    Error,
}
80
81pub struct Iter {
82    inner: Arc<Mutex<Inner>>,
83    index: usize,
84    map: Arc<dyn IdConvert + Send + Sync>,
85}
86
87impl Iter {
88    fn into_box_stream(self) -> BoxVertexStream {
89        Box::pin(futures::stream::unfold(self, |this| this.next()))
90    }
91
92    async fn next(mut self) -> Option<(Result<Vertex>, Self)> {
93        loop {
94            let state = {
95                let inner = self.inner.lock().unwrap();
96                inner.state
97            };
98            match state {
99                State::Error => break None,
100                State::Complete if self.inner.lock().unwrap().visited.len() <= self.index => {
101                    break None;
102                }
103                State::Complete | State::Incomplete => {
104                    let opt_id = {
105                        let inner = self.inner.lock().unwrap();
106                        inner.visited.get_index(self.index).cloned()
107                    };
108                    match opt_id {
109                        Some(id) => {
110                            self.index += 1;
111                            match self.map.vertex_name(id).await {
112                                Err(err) => {
113                                    self.inner.lock().unwrap().state = State::Error;
114                                    return Some((Err(err), self));
115                                }
116                                Ok(vertex) => {
117                                    break Some((Ok(vertex), self));
118                                }
119                            }
120                        }
121                        None => {
122                            // Data not available. Load more.
123                            let more = {
124                                let mut inner = self.inner.lock().unwrap();
125                                inner.load_more(1, None)
126                            };
127                            if let Err(err) = more {
128                                return Some((Err(err), self));
129                            }
130                        }
131                    }
132                }
133            }
134        }
135    }
136}
137
138struct DebugId {
139    id: Id,
140    name: Option<Vertex>,
141}
142
143impl fmt::Debug for DebugId {
144    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
145        if let Some(name) = &self.name {
146            fmt::Debug::fmt(&name, f)?;
147            write!(f, "+{:?}", self.id)?;
148        } else {
149            write!(f, "{:?}", self.id)?;
150        }
151        Ok(())
152    }
153}
154
155impl fmt::Debug for IdLazySet {
156    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
157        f.write_str("<lazy ")?;
158        let inner = self.inner.lock().unwrap();
159        let limit = f.width().unwrap_or(3);
160        f.debug_list()
161            .entries(inner.visited.iter().take(limit).map(|&id| DebugId {
162                id,
163                name: disable_remote_protocol(|| {
164                    non_blocking_result(self.map.vertex_name(id)).ok()
165                }),
166            }))
167            .finish()?;
168        let remaining = inner.visited.len().max(limit) - limit;
169        match (remaining, inner.state) {
170            (0, State::Incomplete) => f.write_str(" + ? more")?,
171            (n, State::Incomplete) => write!(f, "+ {} + ? more", n)?,
172            (0, _) => {}
173            (n, _) => write!(f, " + {} more", n)?,
174        }
175        f.write_str(">")?;
176        Ok(())
177    }
178}
179
180impl IdLazySet {
181    pub fn from_iter_idmap_dag<I>(
182        names: I,
183        map: Arc<dyn IdConvert + Send + Sync>,
184        dag: Arc<dyn DagAlgorithm + Send + Sync>,
185    ) -> Self
186    where
187        I: IntoIterator<Item = Result<Id>> + 'static,
188        <I as IntoIterator>::IntoIter: Send + Sync,
189    {
190        let iter = names.into_iter();
191        let inner = Inner {
192            iter: Box::new(iter),
193            visited: IndexSet::new(),
194            state: State::Incomplete,
195        };
196        let hints = Hints::new_with_idmap_dag(map.clone(), dag.clone());
197        Self {
198            inner: Arc::new(Mutex::new(inner)),
199            map,
200            dag,
201            hints,
202        }
203    }
204
205    /// Convert to an IdStaticSet.
206    pub fn to_static(&self) -> Result<IdStaticSet> {
207        let inner = self.load_all()?;
208        let mut spans = IdSet::empty();
209        for &id in inner.visited.iter() {
210            spans.push(id);
211        }
212        Ok(IdStaticSet::from_id_set_idmap_dag(
213            spans,
214            self.map.clone(),
215            self.dag.clone(),
216        ))
217    }
218
219    fn load_all(&self) -> Result<MutexGuard<Inner>> {
220        let mut inner = self.inner.lock().unwrap();
221        inner.load_more(usize::max_value(), None)?;
222        Ok(inner)
223    }
224}
225
226#[async_trait::async_trait]
227impl AsyncSetQuery for IdLazySet {
228    async fn iter(&self) -> Result<BoxVertexStream> {
229        let inner = self.inner.clone();
230        let map = self.map.clone();
231        let iter = Iter {
232            inner,
233            index: 0,
234            map,
235        };
236        Ok(iter.into_box_stream())
237    }
238
239    async fn iter_rev(&self) -> Result<BoxVertexStream> {
240        let inner = self.load_all()?;
241        struct State {
242            map: Arc<dyn IdConvert + Send + Sync>,
243            iter: Box<dyn Iterator<Item = Id> + Send>,
244        }
245        let state = State {
246            map: self.map.clone(),
247            iter: Box::new(inner.visited.clone().into_iter().rev()),
248        };
249        async fn next(mut state: State) -> Option<(Result<Vertex>, State)> {
250            match state.iter.next() {
251                None => None,
252                Some(id) => {
253                    let result = state.map.vertex_name(id).await;
254                    Some((result, state))
255                }
256            }
257        }
258
259        let stream = futures::stream::unfold(state, next);
260        Ok(Box::pin(stream))
261    }
262
263    async fn count_slow(&self) -> Result<u64> {
264        let inner = self.load_all()?;
265        Ok(inner.visited.len().try_into()?)
266    }
267
268    async fn last(&self) -> Result<Option<Vertex>> {
269        let opt_id = {
270            let inner = self.load_all()?;
271            inner.visited.iter().rev().nth(0).cloned()
272        };
273        match opt_id {
274            Some(id) => Ok(Some(self.map.vertex_name(id).await?)),
275            None => Ok(None),
276        }
277    }
278
279    async fn contains(&self, name: &Vertex) -> Result<bool> {
280        let id = match self.map.vertex_id_with_max_group(name, Group::MAX).await? {
281            None => {
282                return Ok(false);
283            }
284            Some(id) => id,
285        };
286        let mut inner = self.inner.lock().unwrap();
287        if inner.visited.contains(&id) {
288            return Ok(true);
289        } else {
290            let mut loaded = Vec::new();
291            loop {
292                // Fast paths.
293                if let Some(&last_id) = inner.visited.iter().next_back() {
294                    let hints = self.hints();
295                    if hints.contains(Flags::ID_DESC) {
296                        if last_id < id {
297                            return Ok(false);
298                        }
299                    } else if hints.contains(Flags::ID_ASC) {
300                        if last_id > id {
301                            return Ok(false);
302                        }
303                    }
304                }
305                loaded.clear();
306                inner.load_more(1, Some(&mut loaded))?;
307                debug_assert!(loaded.len() <= 1);
308                if loaded.is_empty() {
309                    break;
310                }
311                if loaded.first() == Some(&id) {
312                    return Ok(true);
313                }
314            }
315        }
316        Ok(false)
317    }
318
319    async fn contains_fast(&self, name: &Vertex) -> Result<Option<bool>> {
320        let id = match self.map.vertex_id_with_max_group(name, Group::MAX).await? {
321            None => {
322                return Ok(Some(false));
323            }
324            Some(id) => id,
325        };
326        let inner = self.inner.lock().unwrap();
327        if inner.visited.contains(&id) {
328            return Ok(Some(true));
329        } else if inner.state != State::Incomplete {
330            return Ok(Some(false));
331        }
332        Ok(None)
333    }
334
335    fn as_any(&self) -> &dyn Any {
336        self
337    }
338
339    fn hints(&self) -> &Hints {
340        &self.hints
341    }
342
343    fn id_convert(&self) -> Option<&dyn IdConvert> {
344        Some(self.map.as_ref() as &dyn IdConvert)
345    }
346}
347
#[cfg(test)]
pub(crate) mod test_utils {
    use std::sync::atomic::AtomicU64;
    use std::sync::atomic::Ordering::AcqRel;

    use super::*;
    use crate::ops::PrefixLookup;
    use crate::tests::dummy_dag::DummyDag;
    use crate::VerLink;

    /// Counter used to give every `StrIdMap` a distinct `map_id`.
    static STR_ID_MAP_ID: AtomicU64 = AtomicU64::new(0);

    /// Test-only `IdConvert` where a vertex name is simply the 8-byte
    /// little-endian encoding of its id.
    pub(crate) struct StrIdMap {
        id: String,
        version: VerLink,
    }

    impl StrIdMap {
        pub(crate) fn new() -> Self {
            Self {
                id: format!("str:{}", STR_ID_MAP_ID.fetch_add(1, AcqRel)),
                version: VerLink::new(),
            }
        }
    }

    #[async_trait::async_trait]
    impl PrefixLookup for StrIdMap {
        async fn vertexes_by_hex_prefix(&self, _: &[u8], _: usize) -> Result<Vec<Vertex>> {
            // Dummy implementation.
            Ok(Vec::new())
        }
    }

    #[async_trait::async_trait]
    impl IdConvert for StrIdMap {
        /// Decode an 8-byte name into its id. Panics if `name` is not
        /// exactly 8 bytes long.
        async fn vertex_id(&self, name: Vertex) -> Result<Id> {
            let bytes: [u8; 8] = name.as_ref().try_into().unwrap();
            Ok(Id(u64::from_le_bytes(bytes)))
        }
        async fn vertex_id_with_max_group(
            &self,
            name: &Vertex,
            _max_group: Group,
        ) -> Result<Option<Id>> {
            if name.as_ref().len() == 8 {
                let id = self.vertex_id(name.clone()).await?;
                Ok(Some(id))
            } else {
                Ok(None)
            }
        }
        /// Encode an id as its 8-byte little-endian name.
        async fn vertex_name(&self, id: Id) -> Result<Vertex> {
            Ok(Vertex::copy_from(&id.0.to_le_bytes()))
        }
        async fn contains_vertex_name(&self, name: &Vertex) -> Result<bool> {
            Ok(name.as_ref().len() == 8)
        }
        fn map_id(&self) -> &str {
            &self.id
        }
        fn map_version(&self) -> &VerLink {
            &self.version
        }
        async fn contains_vertex_id_locally(&self, ids: &[Id]) -> Result<Vec<bool>> {
            Ok(ids.iter().map(|_| true).collect())
        }
        async fn contains_vertex_name_locally(&self, names: &[Vertex]) -> Result<Vec<bool>> {
            Ok(names.iter().map(|name| name.as_ref().len() == 8).collect())
        }
    }

    /// Build an `IdLazySet` over `a` with a fresh `StrIdMap` and `DummyDag`.
    pub fn lazy_set(a: &[u64]) -> IdLazySet {
        let ids: Vec<Id> = a.iter().map(|i| Id(*i as _)).collect();
        IdLazySet::from_iter_idmap_dag(
            ids.into_iter().map(Ok),
            Arc::new(StrIdMap::new()),
            Arc::new(DummyDag::new()),
        )
    }

    /// Build an `IdLazySet` over `a` that shares the map and dag of `set`.
    pub fn lazy_set_inherit(a: &[u64], set: &IdLazySet) -> IdLazySet {
        let ids: Vec<Id> = a.iter().map(|i| Id(*i as _)).collect();
        IdLazySet::from_iter_idmap_dag(ids.into_iter().map(Ok), set.map.clone(), set.dag.clone())
    }
}
435
#[cfg(all(test, feature = "indexedlog-backend"))]
// `test_flatten` clones the sets on every use, including the last one.
#[allow(clippy::redundant_clone)]
pub(crate) mod tests {
    use std::collections::HashSet;

    use nonblocking::non_blocking_result as r;

    use super::super::tests::*;
    use super::super::Set;
    use super::test_utils::*;
    use super::*;

    /// Iteration order, reverse order, count, first and last.
    #[test]
    fn test_id_lazy_basic() -> Result<()> {
        let set = lazy_set(&[0x11, 0x33, 0x22, 0x77, 0x55]);
        check_invariants(&set)?;

        let forward = shorten_iter(ni(set.iter()));
        assert_eq!(forward, ["11", "33", "22", "77", "55"]);
        let backward = shorten_iter(ni(set.iter_rev()));
        assert_eq!(backward, ["55", "77", "22", "33", "11"]);

        assert!(!nb(set.is_empty())?);
        assert_eq!(nb(set.count_slow())?, 5);
        assert_eq!(shorten_name(nb(set.first())?.unwrap()), "11");
        assert_eq!(shorten_name(nb(set.last())?.unwrap()), "55");
        Ok(())
    }

    /// `contains` stops early based on the `ID_ASC` / `ID_DESC` hints.
    #[test]
    fn test_hints_fast_paths() -> Result<()> {
        let set = lazy_set(&[0x20, 0x50, 0x30, 0x70]);

        // Incorrect hints, but useful for testing.
        set.hints().add_flags(Flags::ID_ASC);

        let vertex = |i: u64| -> Vertex { r(StrIdMap::new().vertex_name(Id(i))).unwrap() };
        assert!(nb(set.contains(&vertex(0x20)))?);
        assert!(nb(set.contains(&vertex(0x50)))?);
        assert!(!nb(set.contains(&vertex(0x30)))?);

        set.hints().add_flags(Flags::ID_DESC);
        assert!(nb(set.contains(&vertex(0x30)))?);
        assert!(!nb(set.contains(&vertex(0x70)))?);

        Ok(())
    }

    /// Debug output only reflects what has been loaded so far.
    #[test]
    fn test_debug() {
        let single = lazy_set(&[0]);
        assert_eq!(dbg(&single), "<lazy [] + ? more>");
        nb(single.count_slow()).unwrap();
        assert_eq!(dbg(&single), "<lazy [0000000000000000+0]>");

        let set = lazy_set(&[1, 3, 2]);
        assert_eq!(dbg(&set), "<lazy [] + ? more>");
        let mut stream = ni(set.iter()).unwrap();
        stream.next();
        assert_eq!(dbg(&set), "<lazy [0100000000000000+1] + ? more>");
        stream.next();
        assert_eq!(
            dbg(&set),
            "<lazy [0100000000000000+1, 0300000000000000+3] + ? more>"
        );
        stream.next();
        assert_eq!(format!("{:2.2?}", &set), "<lazy [01+1, 03+3]+ 1 + ? more>");
        stream.next();
        assert_eq!(format!("{:1.3?}", &set), "<lazy [010+1] + 2 more>");
    }

    /// Union, intersection and difference, flattened by names and by ids.
    #[test]
    fn test_flatten() {
        let lazy1 = lazy_set(&[3, 2, 4]);
        let lazy2 = lazy_set_inherit(&[3, 7, 6], &lazy1);
        let set1 = Set::from_query(lazy1);
        let set2 = Set::from_query(lazy2);

        // Show flatten by names, and flatten by ids.
        // The first should be <static ...>, the second should be <spans ...>.
        let show = |set: Set| {
            let by_names = format!("{:5.2?}", r(set.flatten_names()).unwrap());
            let by_ids = format!("{:5.2?}", r(set.flatten()).unwrap());
            [by_names, by_ids]
        };

        assert_eq!(
            show(set1.clone() | set2.clone()),
            [
                "<static [03, 02, 04, 07, 06]>",
                "<spans [06:07+6:7, 02:04+2:4]>"
            ]
        );
        assert_eq!(
            show(set1.clone() & set2.clone()),
            ["<static [03]>", "<spans [03+3]>"]
        );
        assert_eq!(
            show(set1.clone() - set2.clone()),
            ["<static [02, 04]>", "<spans [04+4, 02+2]>"]
        );
    }

    quickcheck::quickcheck! {
        fn test_id_lazy_quickcheck(a: Vec<u64>) -> bool {
            let set = lazy_set(&a);
            check_invariants(&set).unwrap();

            let count = nb(set.count_slow()).unwrap() as usize;
            assert!(count <= a.len());

            // The set deduplicates, so its size equals the number of unique inputs.
            let unique: HashSet<_> = a.iter().cloned().collect();
            assert_eq!(count, unique.len());

            true
        }
    }
}