1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
use std::{
borrow::Borrow,
collections::HashSet,
hash::{BuildHasher, Hash},
};
/// Extension methods for sets that can look up an element by a borrowed key
/// and lazily, fallibly create that element when it is missing.
pub trait SetInsertExt<T> {
    /// Looks up the element matching `value`; if there is none, builds one
    /// with `f(value)`, stores it, and returns a reference to the stored
    /// element.
    ///
    /// `f` receives the lookup key and must produce an element whose
    /// [`Borrow`] form hashes and compares equal to `value`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `f` when it fails.
    fn get_or_try_insert_with<Q, F, E>(&mut self, value: &Q, f: F) -> Result<&T, E>
    where
        Q: ?Sized + Eq + Hash,
        T: Borrow<Q>,
        F: FnOnce(&Q) -> Result<T, E>;
}
impl<T, S> SetInsertExt<T> for HashSet<T, S>
where
    T: Eq + Hash,
    S: BuildHasher,
{
    /// Returns the element equal to `value`, first inserting `f(value)` if
    /// the set has no such element.
    ///
    /// `f` is only called when `value` is absent. Its result must borrow as
    /// something equal to `value` (and, per the [`Borrow`] contract, hash the
    /// same way); otherwise the new element could not be found by `value`
    /// and there would be nothing to return.
    ///
    /// # Errors
    ///
    /// Propagates any `Err` returned by `f`; in that case the set is not
    /// modified.
    ///
    /// # Panics
    ///
    /// Panics if the element produced by `f` does not compare equal to
    /// `value`. This is checked before insertion, so the set is left
    /// unmodified. It also panics if the inserted element cannot be found by
    /// `value` afterwards, which can only happen when the `Hash` impls of `T`
    /// and `Q` disagree.
    fn get_or_try_insert_with<Q: ?Sized, F, E>(&mut self, value: &Q, f: F) -> Result<&T, E>
    where
        T: Borrow<Q>,
        Q: Hash + Eq,
        F: FnOnce(&Q) -> Result<T, E>,
    {
        // Stable `HashSet` has no entry API. Returning the borrow from an
        // early `get` before mutating is also rejected by the current borrow
        // checker. So membership is checked first and the element is looked
        // up again at the end.
        if !self.contains(value) {
            let new = f(value)?;
            let key: &Q = new.borrow();
            assert!(
                key == value,
                "get_or_try_insert_with: value returned by `f` is not equal to the lookup value"
            );
            self.insert(new);
        }
        // Only reachable as `None` if `Hash` is inconsistent with `Borrow`.
        // Panicking (instead of `unreachable_unchecked`) keeps such a logic
        // error in user code from becoming undefined behaviour.
        Ok(self.get(value).expect(
            "get_or_try_insert_with: inserted value is not retrievable by the lookup value \
             (inconsistent `Hash`/`Borrow` implementations?)",
        ))
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::SetInsertExt;

    #[test]
    fn it_works_when_present() {
        let mut set = HashSet::new();
        set.insert(0);
        assert_eq!(
            set.get_or_try_insert_with::<_, _, ()>(&0, |_| Ok(1)),
            Ok(&0)
        );
        // The closure must not run when the value is already present; if it
        // did, this call would return `Err(())`.
        assert_eq!(
            set.get_or_try_insert_with::<_, _, ()>(&0, |_| Err(())),
            Ok(&0)
        );
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn it_works_when_not_present() {
        let mut set = HashSet::new();
        assert_eq!(
            set.get_or_try_insert_with::<_, _, ()>(&0, |_| Ok(0)),
            Ok(&0),
        );
        // The produced value must actually have been stored.
        assert!(set.contains(&0));
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn it_works_with_borrowed_lookup_keys() {
        // Exercises the `Q: ?Sized` path: `HashSet<String>` queried by `&str`,
        // with the closure receiving the lookup key.
        let mut set: HashSet<String> = HashSet::new();
        assert_eq!(
            set.get_or_try_insert_with::<str, _, ()>("key", |s: &str| Ok(s.to_owned())),
            Ok(&String::from("key"))
        );
        assert!(set.contains("key"));
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn it_errors() {
        let mut set = HashSet::<i32>::new();
        assert_eq!(set.get_or_try_insert_with(&0, |_| Err(())), Err(()));
        // A failed creation must leave the set untouched.
        assert!(set.is_empty());
    }
}