// try_insert_ext/set.rs

1use std::{
2    borrow::Borrow,
3    collections::HashSet,
4    hash::{BuildHasher, Hash},
5};
6
/// Extends sets with `get_or_try_insert_with`.
pub trait SetInsertExt<T> {
    /// Returns a reference to the set's value equivalent to `value`, first
    /// inserting one produced by the fallible constructor `f` if none exists.
    ///
    /// When a value equivalent to `value` is already present, `f` is not called
    /// and a reference to the existing value is returned. When it is absent,
    /// `f(value)` is called: on `Ok` the produced value is inserted and a
    /// reference to it is returned; on `Err` the set is left unchanged and the
    /// error is returned.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `f`, if `f` was called and failed.
    ///
    /// # Panics
    ///
    /// The value returned by `f` must be equivalent to `value`, i.e. its
    /// `Borrow<Q>` view must hash and compare equal to `value`. The `HashSet`
    /// implementation panics if it is not, because the new value cannot then
    /// be found again under `value`. When that happens, the mismatched value
    /// has already been inserted into the set before the panic.
    ///
    /// # Examples
    ///
    /// ```
    /// use std::collections::HashSet;
    ///
    /// use try_insert_ext::SetInsertExt;
    ///
    /// let mut set: HashSet<String> = ["cat", "dog", "horse"]
    ///     .iter()
    ///     .map(|&pet| pet.to_owned())
    ///     .collect();
    ///
    /// assert_eq!(set.len(), 3);
    /// let value = set.get_or_try_insert_with("error", |_| Err(()));
    /// assert!(value.is_err());
    /// for &pet in &["cat", "dog", "fish"] {
    ///     let value = set.get_or_try_insert_with::<_, _, ()>(
    ///         pet,
    ///         |pet| Ok(pet.to_owned()),
    ///     );
    ///     assert_eq!(value, Ok(&pet.to_owned()));
    /// }
    ///
    /// assert_eq!(set.len(), 4);
    /// ```
    fn get_or_try_insert_with<Q: ?Sized, F, E>(&mut self, value: &Q, f: F) -> Result<&T, E>
    where
        T: Borrow<Q>,
        Q: Hash + Eq,
        F: FnOnce(&Q) -> Result<T, E>;
}

impl<T, S> SetInsertExt<T> for HashSet<T, S>
where
    T: Eq + Hash,
    S: BuildHasher,
{
    fn get_or_try_insert_with<Q: ?Sized, F, E>(&mut self, value: &Q, f: F) -> Result<&T, E>
    where
        T: Borrow<Q>,
        Q: Hash + Eq,
        F: FnOnce(&Q) -> Result<T, E>,
    {
        // Check membership first and look the value up afterwards. Returning a
        // borrow from an early `get` in one branch while mutating in the other
        // is rejected by the current (NLL) borrow checker.
        if !self.contains(value) {
            self.insert(f(value)?);
        }
        // `f` is arbitrary caller code, so nothing guarantees that the value it
        // produced is equivalent to `value`. Treating a failed lookup as
        // `unreachable_unchecked` would let safe code trigger undefined
        // behaviour, so the contract violation is reported with a panic instead.
        let found = self
            .get(value)
            .expect("value returned by `f` is not equivalent to the lookup value");
        Ok(found)
    }
}
65
#[cfg(test)]
mod tests {
    use super::SetInsertExt;
    use std::collections::HashSet;

    /// When the value is already present, the existing element is returned and
    /// the outcome of `f` (whether `Ok` or `Err`) has no effect on the result.
    #[test]
    fn it_works_when_present() {
        let mut set: HashSet<i32> = HashSet::new();
        set.insert(0);

        let with_ok = set.get_or_try_insert_with::<_, _, ()>(&0, |_| Ok(1));
        assert_eq!(with_ok, Ok(&0));

        let with_err = set.get_or_try_insert_with::<_, _, ()>(&0, |_| Err(()));
        assert_eq!(with_err, Ok(&0));
    }

    /// When the value is absent, the value built by `f` is inserted and returned.
    #[test]
    fn it_works_when_not_present() {
        let mut set: HashSet<i32> = HashSet::new();

        let inserted = set.get_or_try_insert_with::<_, _, ()>(&0, |_| Ok(0));
        assert_eq!(inserted, Ok(&0));
    }

    /// When the value is absent and `f` fails, its error is propagated.
    #[test]
    fn it_errors() {
        let mut set: HashSet<i32> = HashSet::new();

        let outcome = set.get_or_try_insert_with(&0, |_| Err(()));
        assert_eq!(outcome, Err(()));
    }
}