1use std::error::Error;
14use std::fmt;
15
16use rkyv::{
17 Archive, Archived, Deserialize as RkyvDeserialize, Place, Serialize as RkyvSerialize,
18 rancor::{Fallible, Source},
19 ser::{Allocator, Writer},
20 vec::{ArchivedVec, VecResolver},
21};
22use serde::{Deserialize, Deserializer, Serialize, Serializer};
23use smallvec::SmallVec;
24
25use crate::DbString;
26
27#[derive(Clone, Debug, Eq, Hash, PartialEq)]
29pub struct LabelSet(SmallVec<[DbString; 3]>);
30
/// Error raised while deserializing an rkyv archive of a [`LabelSet`] whose labels are not in
/// canonical form (some label is not strictly greater than the one before it, i.e. the
/// archive is unsorted or contains a duplicate).
#[derive(Debug)]
struct InvalidArchivedLabelSet;

impl fmt::Display for InvalidArchivedLabelSet {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            formatter,
            "archived LabelSet must be sorted by DbString order with no duplicates"
        )
    }
}

// There is no underlying cause to chain, so the default `source()` (returning `None`) is kept.
impl Error for InvalidArchivedLabelSet {}
41
42impl LabelSet {
43 #[must_use]
45 pub fn new() -> Self {
46 Self(SmallVec::new())
47 }
48
49 #[allow(clippy::should_implement_trait)]
51 #[must_use]
52 pub fn from_iter(labels: impl IntoIterator<Item = DbString>) -> Self {
53 labels.into_iter().collect()
54 }
55
56 #[must_use]
58 pub fn single(label: DbString) -> Self {
59 let mut labels = SmallVec::new();
60 labels.push(label);
61 Self(labels)
62 }
63
64 #[must_use]
66 pub fn edge(label: DbString) -> Self {
67 Self::single(label)
68 }
69
70 pub fn insert(&mut self, label: DbString) -> bool {
72 match self.0.binary_search(&label) {
73 Ok(_) => false,
74 Err(idx) => {
75 self.0.insert(idx, label);
76 true
77 }
78 }
79 }
80
81 pub fn remove(&mut self, label: &DbString) -> bool {
83 match self.0.binary_search(label) {
84 Ok(idx) => {
85 self.0.remove(idx);
86 true
87 }
88 Err(_) => false,
89 }
90 }
91
92 #[must_use]
94 pub fn contains(&self, label: &DbString) -> bool {
95 self.0.binary_search(label).is_ok()
96 }
97
98 #[must_use]
100 pub fn len(&self) -> usize {
101 self.0.len()
102 }
103
104 #[must_use]
106 pub fn is_empty(&self) -> bool {
107 self.0.is_empty()
108 }
109
110 pub fn iter(&self) -> impl Iterator<Item = &DbString> {
112 self.0.iter()
113 }
114
115 #[cfg(test)]
116 fn sorted_deduped_invariant_holds(&self) -> bool {
117 self.0.windows(2).all(|pair| pair[0] < pair[1])
118 }
119
120 #[cfg(test)]
121 fn spilled(&self) -> bool {
122 self.0.spilled()
123 }
124}
125
126impl Default for LabelSet {
127 fn default() -> Self {
128 Self::new()
129 }
130}
131
132impl Serialize for LabelSet {
133 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
134 where
135 S: Serializer,
136 {
137 self.0.serialize(serializer)
142 }
143}
144
145impl<'de> Deserialize<'de> for LabelSet {
146 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
147 where
148 D: Deserializer<'de>,
149 {
150 let raw: SmallVec<[DbString; 3]> = SmallVec::deserialize(deserializer)?;
154 for window in raw.windows(2) {
155 if window[0] >= window[1] {
156 return Err(serde::de::Error::custom(
157 "LabelSet must be sorted by DbString order with no duplicate labels",
158 ));
159 }
160 }
161 Ok(Self(raw))
162 }
163}
164
165impl Archive for LabelSet {
166 type Archived = ArchivedVec<Archived<DbString>>;
167 type Resolver = VecResolver;
168
169 fn resolve(&self, resolver: Self::Resolver, out: Place<Self::Archived>) {
170 ArchivedVec::resolve_from_slice(self.0.as_slice(), resolver, out);
171 }
172}
173
174impl<S> RkyvSerialize<S> for LabelSet
175where
176 S: Fallible + Allocator + Writer + ?Sized,
177 DbString: RkyvSerialize<S>,
178{
179 fn serialize(&self, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
180 ArchivedVec::serialize_from_slice(self.0.as_slice(), serializer)
183 }
184}
185
186impl<D> RkyvDeserialize<LabelSet, D> for ArchivedVec<Archived<DbString>>
187where
188 D: Fallible + ?Sized,
189 D::Error: Source,
190 Archived<DbString>: RkyvDeserialize<DbString, D>,
191{
192 fn deserialize(&self, deserializer: &mut D) -> Result<LabelSet, D::Error> {
193 let mut raw: SmallVec<[DbString; 3]> = SmallVec::new();
194 for label in self.as_slice() {
195 raw.push(label.deserialize(deserializer)?);
196 }
197 for window in raw.windows(2) {
200 if window[0] >= window[1] {
201 rkyv::rancor::fail!(InvalidArchivedLabelSet);
202 }
203 }
204 Ok(LabelSet(raw))
205 }
206}
207
208impl FromIterator<DbString> for LabelSet {
209 fn from_iter<T: IntoIterator<Item = DbString>>(iter: T) -> Self {
210 let mut set = Self::new();
211 for label in iter {
212 set.insert(label);
213 }
214 set
215 }
216}
217
#[cfg(test)]
mod tests {
    //! Tests for `LabelSet`: set semantics, canonical ordering, inline vs. heap storage, and
    //! acceptance/rejection of serde (postcard) and rkyv payloads.

    use proptest::prelude::*;

    use super::*;
    use crate::db_string;

    /// Builds a test label, unwrapping the result of `db_string`.
    fn label(name: &str) -> DbString {
        db_string(name).unwrap()
    }

    #[test]
    fn insert_remove_contains_round_trip() {
        let a = label("ls.a");
        let mut set = LabelSet::new();
        assert!(set.insert(a.clone()));
        assert!(set.contains(&a));
        assert!(set.remove(&a));
        assert!(!set.contains(&a));
    }

    #[test]
    fn insert_returns_false_on_duplicate() {
        let a = label("ls.dup");
        let mut set = LabelSet::new();
        assert!(set.insert(a.clone()));
        assert!(!set.insert(a));
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn remove_absent_label_returns_false() {
        let present = label("ls.absent.present");
        let missing = label("ls.absent.missing");
        let mut set = LabelSet::single(present.clone());
        assert!(!set.remove(&missing));
        assert_eq!(set.len(), 1);
        assert!(set.contains(&present));
    }

    #[test]
    fn default_and_edge_constructors() {
        assert_eq!(LabelSet::default(), LabelSet::new());
        let e = label("ls.edge");
        assert_eq!(LabelSet::edge(e.clone()), LabelSet::single(e));
    }

    #[test]
    fn iter_yields_sorted_order() {
        let a = label("ls.sorted.a");
        let b = label("ls.sorted.b");
        let set = LabelSet::from_iter([b.clone(), a.clone()]);
        assert_eq!(set.iter().cloned().collect::<Vec<_>>(), vec![a, b]);
    }

    #[test]
    fn set_with_three_inline_does_not_spill() {
        let set = LabelSet::from_iter(["ls.i.1", "ls.i.2", "ls.i.3"].map(label));
        assert_eq!(set.len(), 3);
        assert!(!set.spilled());
    }

    #[test]
    fn set_with_four_or_more_spills_to_heap() {
        let set = LabelSet::from_iter(["ls.s.1", "ls.s.2", "ls.s.3", "ls.s.4"].map(label));
        assert_eq!(set.len(), 4);
        assert!(set.spilled());
    }

    #[test]
    fn from_iter_with_duplicates_stays_inline() {
        // Five inputs but only three distinct labels: the result must still fit inline.
        let set = LabelSet::from_iter(
            [
                "ls.inline.dup.a",
                "ls.inline.dup.b",
                "ls.inline.dup.a",
                "ls.inline.dup.c",
                "ls.inline.dup.b",
            ]
            .map(label),
        );
        assert_eq!(set.len(), 3);
        assert!(!set.spilled());
        assert!(set.sorted_deduped_invariant_holds());
    }

    #[test]
    fn from_iter_dedups_and_sorts() {
        let a = label("ls.dedup.a");
        let b = label("ls.dedup.b");
        let set = LabelSet::from_iter([b.clone(), a.clone(), b.clone()]);
        assert_eq!(set.iter().cloned().collect::<Vec<_>>(), vec![a, b]);
    }

    #[test]
    fn eq_independent_of_insertion_order() {
        let a = label("ls.eq.a");
        let b = label("ls.eq.b");
        assert_eq!(
            LabelSet::from_iter([a.clone(), b.clone()]),
            LabelSet::from_iter([b, a])
        );
    }

    #[test]
    fn deserialize_round_trips_sorted_set() {
        let a = label("ls.de.a");
        let b = label("ls.de.b");
        let set = LabelSet::from_iter([a, b]);
        let bytes = postcard::to_allocvec(&set).unwrap();
        let round: LabelSet = postcard::from_bytes(&bytes).unwrap();
        assert_eq!(round, set);
    }

    #[test]
    fn empty_set_round_trips_through_serde_and_rkyv() {
        let empty = LabelSet::new();

        let bytes = postcard::to_allocvec(&empty).unwrap();
        let round: LabelSet = postcard::from_bytes(&bytes).unwrap();
        assert_eq!(round, empty);

        let archived = rkyv::to_bytes::<rkyv::rancor::Error>(&empty).unwrap();
        let round = rkyv::from_bytes::<LabelSet, rkyv::rancor::Error>(&archived).unwrap();
        assert_eq!(round, empty);
    }

    #[test]
    fn serialize_independent_of_insertion_order() {
        let labels = ["ls.wire.gamma", "ls.wire.alpha", "ls.wire.beta"];
        let forward = LabelSet::from_iter(labels.map(label));
        let mut rev = labels;
        rev.reverse();
        let reverse = LabelSet::from_iter(rev.map(label));

        assert_eq!(
            postcard::to_allocvec(&forward).unwrap(),
            postcard::to_allocvec(&reverse).unwrap(),
            "serde wire must be insertion-order-independent",
        );
        assert_eq!(
            rkyv::to_bytes::<rkyv::rancor::Error>(&forward)
                .unwrap()
                .to_vec(),
            rkyv::to_bytes::<rkyv::rancor::Error>(&reverse)
                .unwrap()
                .to_vec(),
            "rkyv archive must be insertion-order-independent",
        );
    }

    #[test]
    fn deserialize_round_trips_canonical_payload() {
        let b = label("ls.de.canon.zebra");
        let a = label("ls.de.canon.apple");
        let bytes = postcard::to_allocvec::<SmallVec<[DbString; 3]>>(&{
            let mut v = SmallVec::<[DbString; 3]>::new();
            v.push(a.clone());
            v.push(b.clone());
            v
        })
        .unwrap();
        let result: LabelSet = postcard::from_bytes(&bytes).unwrap();
        assert!(result.contains(&a));
        assert!(result.contains(&b));
        assert!(result.sorted_deduped_invariant_holds());
    }

    #[test]
    fn deserialize_rejects_non_canonical_payload() {
        let zebra = label("ls.de.noncanon.zebra");
        let apple = label("ls.de.noncanon.apple");
        let bytes = postcard::to_allocvec::<SmallVec<[DbString; 3]>>(&{
            let mut v = SmallVec::<[DbString; 3]>::new();
            v.push(zebra);
            v.push(apple);
            v
        })
        .unwrap();
        let result: Result<LabelSet, _> = postcard::from_bytes(&bytes);
        assert!(result.is_err());
    }

    #[test]
    fn deserialize_rejects_duplicate_payload() {
        let a = label("ls.de.dup.a");
        let bytes = postcard::to_allocvec::<SmallVec<[DbString; 3]>>(&{
            let mut v = SmallVec::<[DbString; 3]>::new();
            v.push(a.clone());
            v.push(a);
            v
        })
        .unwrap();
        let result: Result<LabelSet, _> = postcard::from_bytes(&bytes);
        assert!(result.is_err());
    }

    #[test]
    fn empty_single_and_large_sets() {
        assert!(LabelSet::new().is_empty());
        assert_eq!(LabelSet::single(label("ls.one")).len(), 1);
        let large = LabelSet::from_iter((0..100).map(|idx| {
            let name = format!("ls.large.{idx}");
            db_string(&name).unwrap()
        }));
        assert_eq!(large.len(), 100);
        assert!(large.sorted_deduped_invariant_holds());
    }

    #[test]
    fn rkyv_deserialize_round_trips_sorted_set() {
        let a = label("ls.rkyv.a");
        let b = label("ls.rkyv.b");
        let set = LabelSet::from_iter([a, b]);
        let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&set).unwrap();
        let round: LabelSet = rkyv::from_bytes::<LabelSet, rkyv::rancor::Error>(&bytes).unwrap();
        assert_eq!(round, set);
    }

    #[test]
    fn rkyv_deserialize_round_trips_canonical_payload() {
        let b = label("ls.rkyv.canon.zebra");
        let a = label("ls.rkyv.canon.apple");
        let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&vec![a.clone(), b.clone()]).unwrap();
        let result = rkyv::from_bytes::<LabelSet, rkyv::rancor::Error>(&bytes).unwrap();
        assert!(result.contains(&a));
        assert!(result.contains(&b));
        assert!(result.sorted_deduped_invariant_holds());
    }

    #[test]
    fn rkyv_deserialize_rejects_non_canonical_payload() {
        let zebra = label("ls.rkyv.noncanon.zebra");
        let apple = label("ls.rkyv.noncanon.apple");
        let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&vec![zebra, apple]).unwrap();
        let result = rkyv::from_bytes::<LabelSet, rkyv::rancor::Error>(&bytes);
        assert!(result.is_err());
    }

    #[test]
    fn rkyv_deserialize_rejects_late_out_of_order_label() {
        // The first pair is canonical; only the final label breaks the ordering.
        let a = label("ls.rkyv.late.a");
        let b = label("ls.rkyv.late.b");
        let c = label("ls.rkyv.late.c");
        let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&vec![a, c, b]).unwrap();
        let result = rkyv::from_bytes::<LabelSet, rkyv::rancor::Error>(&bytes);
        assert!(result.is_err());
    }

    #[test]
    fn rkyv_deserialize_rejects_duplicate_payload() {
        let a = label("ls.rkyv.dup.a");
        let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&vec![a.clone(), a]).unwrap();
        let result = rkyv::from_bytes::<LabelSet, rkyv::rancor::Error>(&bytes);
        assert!(result.is_err());
    }

    proptest! {
        #[test]
        fn random_inserts_are_sorted_and_deduped(raw in proptest::collection::vec(0_u8..64, 1..128)) {
            let mut set = LabelSet::new();
            let mut expected = std::collections::BTreeSet::new();
            for value in raw {
                let name = format!("ls.prop.{value}");
                let label = db_string(&name).unwrap();
                let inserted = set.insert(label.clone());
                prop_assert_eq!(inserted, expected.insert(label));
                prop_assert!(set.sorted_deduped_invariant_holds());
                prop_assert_eq!(set.len(), expected.len());
            }
        }
    }
}