1use std::{
2 borrow::Borrow,
3 collections::HashSet,
4 hash::{BuildHasher, Hash},
5};
6
/// Extension methods for set types that can lazily build and insert an
/// element when a lookup misses.
pub trait SetInsertExt<T> {
    /// Returns a reference to the element equal to `value`, first inserting
    /// one produced by `f` if no such element is present.
    ///
    /// `f` is called at most once, and only when `value` is missing. It
    /// receives the lookup value so it can build the owned element from it.
    ///
    /// # Errors
    ///
    /// If `f` returns `Err(e)`, `e` is returned unchanged and the set is not
    /// modified.
    ///
    /// # Panics
    ///
    /// Panics if `f` returns an element that does not compare equal to
    /// `value` (through [`Borrow`]); the set is left unmodified in that case.
    /// Also panics if the `Hash`/`Eq` implementations of `T` and `Q` are
    /// inconsistent (violating the [`Borrow`] contract) so that the freshly
    /// inserted element cannot be found again.
    fn get_or_try_insert_with<Q: ?Sized, F, E>(&mut self, value: &Q, f: F) -> Result<&T, E>
    where
        T: Borrow<Q>,
        Q: Hash + Eq,
        F: FnOnce(&Q) -> Result<T, E>;
}

impl<T, S> SetInsertExt<T> for HashSet<T, S>
where
    T: Eq + Hash,
    S: BuildHasher,
{
    fn get_or_try_insert_with<Q: ?Sized, F, E>(&mut self, value: &Q, f: F) -> Result<&T, E>
    where
        T: Borrow<Q>,
        Q: Hash + Eq,
        F: FnOnce(&Q) -> Result<T, E>,
    {
        // A `contains` probe followed by a separate `get` is used instead of
        // `if let Some(v) = self.get(value) { return Ok(v); }`: the
        // conditionally returned shared borrow would conflict with the later
        // `insert` under the current borrow checker.
        if !self.contains(value) {
            let new = f(value)?;
            // `f` is arbitrary safe code, so nothing forces it to produce an
            // element equal to `value`. Checking before inserting means a
            // mismatch neither leaves a stray element in the set nor makes the
            // lookup below miss.
            assert!(
                <T as Borrow<Q>>::borrow(&new) == value,
                "get_or_try_insert_with: `f` returned a value not equal to the lookup value"
            );
            self.insert(new);
        }
        // The lookup can only miss if `T`'s and `Q`'s `Hash`/`Eq` disagree,
        // which breaks the `Borrow` contract; panic rather than invoke UB.
        let found = self
            .get(value)
            .expect("get_or_try_insert_with: inserted value not found (inconsistent Hash/Eq)");
        Ok(found)
    }
}
65
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::SetInsertExt;

    /// An already-present value is returned as-is and `f` is never invoked.
    #[test]
    fn it_works_when_present() {
        let mut set = HashSet::new();
        set.insert(0);
        // `f` must not run on a hit: the value it would produce differs from
        // the lookup value.
        assert_eq!(
            set.get_or_try_insert_with::<_, _, ()>(&0, |_| Ok(1)),
            Ok(&0)
        );
        // An erroring `f` is harmless on a hit because it is never called.
        assert_eq!(
            set.get_or_try_insert_with::<_, _, ()>(&0, |_| Err(())),
            Ok(&0)
        );
        // Hits must not grow the set.
        assert_eq!(set.len(), 1);
    }

    /// A missing value is built by `f`, inserted, and returned.
    #[test]
    fn it_works_when_not_present() {
        let mut set = HashSet::new();
        assert_eq!(
            set.get_or_try_insert_with::<_, _, ()>(&0, |_| Ok(0)),
            Ok(&0),
        );
        // The element must actually have been stored, not just returned.
        assert!(set.contains(&0));
        assert_eq!(set.len(), 1);
    }

    /// Lookups go through `Borrow`, e.g. a `&str` key for a `HashSet<String>`.
    #[test]
    fn it_works_with_borrowed_lookup() {
        let mut set: HashSet<String> = HashSet::new();
        let got = set.get_or_try_insert_with::<str, _, ()>("abc", |s| Ok(s.to_owned()));
        assert_eq!(got.map(String::as_str), Ok("abc"));
        assert!(set.contains("abc"));
    }

    /// An error from `f` is propagated and leaves the set untouched.
    #[test]
    fn it_errors() {
        let mut set = HashSet::<i32>::new();
        assert_eq!(set.get_or_try_insert_with(&0, |_| Err(())), Err(()));
        assert!(set.is_empty());
    }
}