1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
use std::{
    borrow::Borrow,
    collections::HashSet,
    hash::{BuildHasher, Hash},
};

/// Extends sets with `get_or_try_insert_with`.
pub trait SetInsertExt<T> {
    /// Returns a reference to the set's element equal to `value`, inserting a
    /// fallibly computed element first if none is present.
    ///
    /// If the set already contains `value`, `f` is not called and the existing
    /// element is returned. Otherwise `f` is called with `value`; if it returns
    /// `Ok`, the produced element is inserted and a reference to the stored
    /// element is returned.
    ///
    /// The element produced by `f` must be equal to `value` when borrowed as
    /// `Q`, and `T`'s `Hash`/`Eq` must agree with `Q`'s, as required by
    /// [`Borrow`].
    ///
    /// # Errors
    ///
    /// Returns the error produced by `f` when `value` was absent and `f`
    /// failed. Nothing is inserted in that case.
    ///
    /// # Panics
    ///
    /// Implementations may panic if `f` returns an element that is not equal
    /// to `value`.
    ///
    /// # Examples
    ///
    /// ```
    /// use std::collections::HashSet;
    ///
    /// use try_insert_ext::SetInsertExt;
    ///
    /// let mut set: HashSet<String> = ["cat", "dog", "horse"]
    ///     .iter()
    ///     .map(|&pet| pet.to_owned())
    ///     .collect();
    ///
    /// assert_eq!(set.len(), 3);
    /// let value = set.get_or_try_insert_with("error", |_| Err(()));
    /// assert!(value.is_err());
    /// for &pet in &["cat", "dog", "fish"] {
    ///     let value = set.get_or_try_insert_with::<_, _, ()>(
    ///         pet,
    ///         |pet| Ok(pet.to_owned()),
    ///     );
    ///     assert_eq!(value, Ok(&pet.to_owned()));
    /// }
    ///
    /// // `f` is not called when the value is already present.
    /// assert_eq!(
    ///     set.get_or_try_insert_with("cat", |_| Err(())),
    ///     Ok(&"cat".to_owned()),
    /// );
    ///
    /// assert_eq!(set.len(), 4);
    /// ```
    fn get_or_try_insert_with<Q: ?Sized, F, E>(&mut self, value: &Q, f: F) -> Result<&T, E>
    where
        T: Borrow<Q>,
        Q: Hash + Eq,
        F: FnOnce(&Q) -> Result<T, E>;
}

impl<T, S> SetInsertExt<T> for HashSet<T, S>
where
    T: Eq + Hash,
    S: BuildHasher,
{
    /// Returns the element equal to `value`, inserting the result of `f` first
    /// if the set does not contain it.
    ///
    /// # Errors
    ///
    /// Returns `f`'s error if `value` was absent and `f` failed; the set is not
    /// modified in that case.
    ///
    /// # Panics
    ///
    /// Panics (without inserting) if `f` returns an element that is not equal
    /// to `value`. Also panics if the element cannot be found after insertion,
    /// which can only happen when `T`'s and `Q`'s `Hash`/`Eq` implementations
    /// disagree, in violation of the `Borrow` contract.
    fn get_or_try_insert_with<Q: ?Sized, F, E>(&mut self, value: &Q, f: F) -> Result<&T, E>
    where
        T: Borrow<Q>,
        Q: Hash + Eq,
        F: FnOnce(&Q) -> Result<T, E>,
    {
        // `contains` followed by `get` is used instead of
        // `if let Some(v) = self.get(value) { return Ok(v); }` because the
        // current borrow checker rejects mutating `self` after conditionally
        // returning a borrow of it.
        if !self.contains(value) {
            let new_value = f(value)?;
            // Validate before inserting so a misbehaving `f` cannot leave an
            // unrelated element in the set. Previously this case reached
            // `unreachable_unchecked`, i.e. undefined behavior from safe code.
            assert!(
                <T as Borrow<Q>>::borrow(&new_value) == value,
                "get_or_try_insert_with: `f` returned a value not equal to the lookup value"
            );
            self.insert(new_value);
        }
        // User-supplied `Hash`/`Eq`/`Borrow` impls are safe code and may be
        // inconsistent, so a miss here is a caller bug, not an impossibility:
        // panic instead of assuming it cannot happen.
        Ok(self.get(value).expect(
            "get_or_try_insert_with: value missing after insertion \
             (inconsistent `Hash`/`Eq`/`Borrow` implementations)",
        ))
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::SetInsertExt;

    #[test]
    fn it_works_when_present() {
        let mut set = HashSet::new();
        set.insert(0);
        // `f` must not be consulted when the value is already present, so
        // neither its `Ok` nor its `Err` result has any effect.
        assert_eq!(
            set.get_or_try_insert_with::<_, _, ()>(&0, |_| Ok(1)),
            Ok(&0)
        );
        assert_eq!(
            set.get_or_try_insert_with::<_, _, ()>(&0, |_| Err(())),
            Ok(&0)
        );
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn it_works_when_not_present() {
        let mut set = HashSet::new();
        assert_eq!(
            set.get_or_try_insert_with::<_, _, ()>(&0, |_| Ok(0)),
            Ok(&0),
        );
        assert!(set.contains(&0));
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn it_errors() {
        let mut set = HashSet::<i32>::new();
        assert_eq!(set.get_or_try_insert_with(&0, |_| Err(())), Err(()));
        // A failed computation must not leave anything behind.
        assert!(set.is_empty());
    }

    #[test]
    fn it_looks_up_through_borrow() {
        let mut set: HashSet<String> = HashSet::new();
        assert_eq!(
            set.get_or_try_insert_with::<str, _, ()>("cat", |s| Ok(s.to_owned())),
            Ok(&"cat".to_owned())
        );
        assert_eq!(
            set.get_or_try_insert_with::<str, _, ()>("cat", |_| Err(())),
            Ok(&"cat".to_owned())
        );
        assert_eq!(set.len(), 1);
    }
}