1use core::fmt;
9use iref::UriBuf;
10use serde::{Deserialize, Serialize};
11use std::{hash::Hash, str::FromStr, time::Duration};
12
13use crate::{Overflow, StatusMap, StatusSizeError};
14
15mod syntax;
16pub use syntax::*;
17
18#[derive(Clone, Debug, Serialize, Deserialize)]
19pub struct StatusMessage {
20 #[serde(with = "prefixed_hexadecimal")]
21 pub status: u8,
22 pub message: String,
23}
24
25impl StatusMessage {
26 pub fn new(status: u8, message: String) -> Self {
27 Self { status, message }
28 }
29}
30
31#[derive(Debug, thiserror::Error)]
32#[error("invalid status size `{0}`")]
33pub struct InvalidStatusSize(u8);
34
35#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)]
36pub struct StatusSize(u8);
37
38impl TryFrom<u8> for StatusSize {
39 type Error = InvalidStatusSize;
40
41 fn try_from(value: u8) -> Result<Self, Self::Error> {
42 if value <= 8 {
43 Ok(Self(value))
44 } else {
45 Err(InvalidStatusSize(value))
46 }
47 }
48}
49
50impl Default for StatusSize {
51 fn default() -> Self {
52 Self::DEFAULT
53 }
54}
55
56impl StatusSize {
57 pub const DEFAULT: Self = Self(1);
58
59 pub fn is_default(&self) -> bool {
60 *self == Self::DEFAULT
61 }
62
63 fn offset_of(&self, index: usize) -> Offset {
64 let bit_offset = self.0 as usize * index;
65 Offset {
66 byte: bit_offset / 8,
67 bit: bit_offset % 8,
68 }
69 }
70
71 fn mask(&self) -> u8 {
72 if self.0 == 8 {
73 0xff
74 } else {
75 (1 << self.0) - 1
76 }
77 }
78}
79
80impl<'de> Deserialize<'de> for StatusSize {
81 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
82 where
83 D: serde::Deserializer<'de>,
84 {
85 u8::deserialize(deserializer)?
86 .try_into()
87 .map_err(serde::de::Error::custom)
88 }
89}
90
91#[derive(Debug)]
92struct Offset {
93 byte: usize,
94 bit: usize,
95}
96
97impl Offset {
98 fn left_shift(&self, status_size: StatusSize) -> (i32, Option<u32>) {
99 let high = (8 - status_size.0 as isize - self.bit as isize) as i32;
100 let low = if high < 0 {
101 Some((8 + high) as u32)
102 } else {
103 None
104 };
105
106 (high, low)
107 }
108}
109
110#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
115#[serde(transparent)]
116pub struct TimeToLive(pub u64);
117
118impl Default for TimeToLive {
119 fn default() -> Self {
120 Self::DEFAULT
121 }
122}
123
124impl TimeToLive {
125 pub const DEFAULT: Self = Self(300000);
126
127 pub fn is_default(&self) -> bool {
128 *self == Self::DEFAULT
129 }
130}
131
132impl From<TimeToLive> for Duration {
133 fn from(value: TimeToLive) -> Self {
134 Duration::from_millis(value.0)
135 }
136}
137
138#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
139#[serde(rename_all = "camelCase")]
140pub enum StatusPurpose {
141 Revocation,
145
146 Suspension,
150
151 Message,
157}
158
159impl StatusPurpose {
160 pub fn from_name(name: &str) -> Option<Self> {
162 match name {
163 "revocation" => Some(Self::Revocation),
164 "suspension" => Some(Self::Suspension),
165 "message" => Some(Self::Message),
166 _ => None,
167 }
168 }
169
170 pub fn name(&self) -> &'static str {
172 match self {
173 Self::Revocation => "revocation",
174 Self::Suspension => "suspension",
175 Self::Message => "message",
176 }
177 }
178
179 pub fn as_str(&self) -> &'static str {
183 self.name()
184 }
185
186 pub fn into_name(self) -> &'static str {
190 self.name()
191 }
192
193 pub fn into_str(self) -> &'static str {
197 self.name()
198 }
199}
200
201impl<'a> From<&'a StatusPurpose> for crate::StatusPurpose<&'a str> {
202 fn from(value: &'a StatusPurpose) -> Self {
203 match value {
204 StatusPurpose::Revocation => Self::Revocation,
205 StatusPurpose::Suspension => Self::Suspension,
206 StatusPurpose::Message => Self::Other("message"),
207 }
208 }
209}
210
211impl<'a> PartialEq<crate::StatusPurpose<&'a str>> for StatusPurpose {
212 fn eq(&self, other: &crate::StatusPurpose<&'a str>) -> bool {
213 matches!(
214 (self, other),
215 (Self::Revocation, crate::StatusPurpose::Revocation)
216 | (Self::Suspension, crate::StatusPurpose::Suspension)
217 | (Self::Message, crate::StatusPurpose::Other("message"))
218 )
219 }
220}
221
222impl fmt::Display for StatusPurpose {
223 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
224 self.name().fmt(f)
225 }
226}
227
228#[derive(Debug, Clone, thiserror::Error)]
230#[error("invalid status purpose: {0}")]
231pub struct InvalidStatusPurpose(pub String);
232
233impl FromStr for StatusPurpose {
234 type Err = InvalidStatusPurpose;
235
236 fn from_str(s: &str) -> Result<Self, Self::Err> {
237 Self::from_name(s).ok_or_else(|| InvalidStatusPurpose(s.to_owned()))
238 }
239}
240
241#[derive(Debug, Clone)]
251pub struct BitString {
252 status_size: StatusSize,
253 bytes: Vec<u8>,
254 len: usize,
255}
256
257impl BitString {
258 pub fn new(status_size: StatusSize) -> Self {
260 Self {
261 status_size,
262 bytes: Vec::new(),
263 len: 0,
264 }
265 }
266
267 pub fn new_with(
272 status_size: StatusSize,
273 len: usize,
274 mut f: impl FnMut(usize) -> u8,
275 ) -> Result<Self, Overflow> {
276 let mut result = Self::with_capacity(status_size, len);
277
278 for i in 0..len {
279 result.push(f(i))?;
280 }
281
282 Ok(result)
283 }
284
285 pub fn new_with_value(
288 status_size: StatusSize,
289 len: usize,
290 value: u8,
291 ) -> Result<Self, Overflow> {
292 Self::new_with(status_size, len, |_| value)
293 }
294
295 pub fn new_zeroed(status_size: StatusSize, len: usize) -> Self {
298 Self::new_with_value(status_size, len, 0).unwrap() }
300
301 pub fn with_capacity(status_size: StatusSize, capacity: usize) -> Self {
304 Self {
305 status_size,
306 bytes: Vec::with_capacity((capacity * status_size.0 as usize).div_ceil(8)),
307 len: 0,
308 }
309 }
310
311 pub fn from_bytes(status_size: StatusSize, bytes: Vec<u8>) -> Self {
313 let len = bytes.len() * 8usize / status_size.0 as usize;
314 Self {
315 status_size,
316 bytes,
317 len,
318 }
319 }
320
321 pub fn is_empty(&self) -> bool {
323 self.len == 0
324 }
325
326 pub fn len(&self) -> usize {
328 self.len
329 }
330
331 pub fn get(&self, index: usize) -> Option<u8> {
333 if index >= self.len {
334 return None;
335 }
336
337 let offset = self.status_size.offset_of(index);
338 let (high_shift, low_shift) = offset.left_shift(self.status_size);
339
340 Some(self.get_at(offset.byte, high_shift, low_shift))
341 }
342
343 fn get_at(&self, byte_offset: usize, high_shift: i32, low_shift: Option<u32>) -> u8 {
344 let high = self
345 .bytes
346 .get(byte_offset)
347 .unwrap()
348 .overflowing_signed_shr(high_shift)
349 .0;
350
351 let low = match low_shift {
352 Some(low_shift) => {
353 self.bytes
354 .get(byte_offset + 1)
355 .unwrap()
356 .overflowing_shr(low_shift)
357 .0
358 }
359 None => 0,
360 };
361
362 (high | low) & self.status_size.mask()
363 }
364
365 pub fn set(&mut self, index: usize, value: u8) -> Result<u8, Overflow> {
370 if index >= self.len {
371 return Err(Overflow::Index(index));
372 }
373
374 let mask = self.status_size.mask();
375 let masked_value = value & mask;
376 if masked_value != value {
377 return Err(Overflow::Value(value));
378 }
379
380 let offset = self.status_size.offset_of(index);
381 let (high_shift, low_shift) = offset.left_shift(self.status_size);
382
383 let old_value = self.get_at(offset.byte, high_shift, low_shift);
384
385 self.bytes[offset.byte] &= !mask.overflowing_signed_shl(high_shift).0; self.bytes[offset.byte] |= masked_value.overflowing_signed_shl(high_shift).0; if let Some(low_shift) = low_shift {
388 self.bytes[offset.byte + 1] &= !mask.overflowing_shl(low_shift).0; self.bytes[offset.byte + 1] |= masked_value.overflowing_shl(low_shift).0;
390 }
392
393 Ok(old_value)
394 }
395
396 pub fn push(&mut self, value: u8) -> Result<usize, Overflow> {
401 let masked_value = value & self.status_size.mask();
402 if masked_value != value {
403 return Err(Overflow::Value(value));
404 }
405
406 let index = self.len;
407 let offset = self.status_size.offset_of(index);
408
409 let (high_shift, low_shift) = offset.left_shift(self.status_size);
410
411 if offset.byte == self.bytes.len() {
412 self.bytes
413 .push(masked_value.overflowing_signed_shl(high_shift).0);
414 } else {
415 self.bytes[offset.byte] |= masked_value.overflowing_signed_shl(high_shift).0
416 }
417
418 if let Some(low_shift) = low_shift {
419 self.bytes.push(masked_value.overflowing_shl(low_shift).0);
420 }
421
422 self.len += 1;
423 Ok(index)
424 }
425
426 pub fn iter(&self) -> BitStringIter {
428 BitStringIter {
429 bit_string: self,
430 index: 0,
431 }
432 }
433
434 pub fn encode(&self) -> EncodedList {
436 EncodedList::encode(&self.bytes)
437 }
438}
439
/// Overflowing shifts whose direction depends on the sign of the shift amount.
trait OverflowingSignedShift: Sized {
    /// Shifts left by `shift` bits, or right by `-shift` bits when `shift` is
    /// negative. Same overflow semantics as the unsigned `overflowing_sh*`.
    fn overflowing_signed_shl(self, shift: i32) -> (Self, bool);

    /// Shifts right by `shift` bits, or left by `-shift` bits when `shift` is
    /// negative. Same overflow semantics as the unsigned `overflowing_sh*`.
    fn overflowing_signed_shr(self, shift: i32) -> (Self, bool);
}

impl OverflowingSignedShift for u8 {
    fn overflowing_signed_shl(self, shift: i32) -> (u8, bool) {
        let amount = shift.unsigned_abs();
        if shift >= 0 {
            self.overflowing_shl(amount)
        } else {
            self.overflowing_shr(amount)
        }
    }

    fn overflowing_signed_shr(self, shift: i32) -> (u8, bool) {
        let amount = shift.unsigned_abs();
        if shift >= 0 {
            self.overflowing_shr(amount)
        } else {
            self.overflowing_shl(amount)
        }
    }
}
463
464#[derive(Debug, Clone)]
465pub struct StatusList {
466 bit_string: BitString,
467 ttl: TimeToLive,
468}
469
470impl StatusList {
471 pub fn new(status_size: StatusSize, ttl: TimeToLive) -> Self {
472 Self {
473 bit_string: BitString::new(status_size),
474 ttl,
475 }
476 }
477
478 pub fn from_bytes(status_size: StatusSize, bytes: Vec<u8>, ttl: TimeToLive) -> Self {
479 Self {
480 bit_string: BitString::from_bytes(status_size, bytes),
481 ttl,
482 }
483 }
484
485 pub fn is_empty(&self) -> bool {
486 self.bit_string.is_empty()
487 }
488
489 pub fn len(&self) -> usize {
490 self.bit_string.len()
491 }
492
493 pub fn get(&self, index: usize) -> Option<u8> {
494 self.bit_string.get(index)
495 }
496
497 pub fn set(&mut self, index: usize, value: u8) -> Result<u8, Overflow> {
498 self.bit_string.set(index, value)
499 }
500
501 pub fn push(&mut self, value: u8) -> Result<usize, Overflow> {
502 self.bit_string.push(value)
503 }
504
505 pub fn iter(&self) -> BitStringIter {
506 self.bit_string.iter()
507 }
508
509 pub fn to_credential_subject(
510 &self,
511 id: Option<UriBuf>,
512 status_purpose: StatusPurpose,
513 status_message: Vec<StatusMessage>,
514 ) -> BitstringStatusList {
515 BitstringStatusList::new(
516 id,
517 status_purpose,
518 self.bit_string.status_size,
519 self.bit_string.encode(),
520 self.ttl,
521 status_message,
522 )
523 }
524}
525
526pub struct BitStringIter<'a> {
527 bit_string: &'a BitString,
528 index: usize,
529}
530
531impl Iterator for BitStringIter<'_> {
532 type Item = u8;
533
534 fn next(&mut self) -> Option<Self::Item> {
535 self.bit_string.get(self.index).inspect(|_| {
536 self.index += 1;
537 })
538 }
539}
540
541impl StatusMap for StatusList {
542 type Key = usize;
543 type Status = u8;
544 type StatusSize = StatusSize;
545
546 fn time_to_live(&self) -> Option<Duration> {
547 Some(self.ttl.into())
548 }
549
550 fn get_by_key(
551 &self,
552 _status_size: Option<StatusSize>,
553 key: Self::Key,
554 ) -> Result<Option<u8>, StatusSizeError> {
555 Ok(self.bit_string.get(key))
556 }
557}
558
559mod prefixed_hexadecimal {
560 use serde::{Deserialize, Deserializer, Serialize, Serializer};
561
562 pub fn serialize<S>(value: &u8, serializer: S) -> Result<S::Ok, S::Error>
563 where
564 S: Serializer,
565 {
566 format!("{value:#x}").serialize(serializer)
567 }
568
569 pub fn deserialize<'de, D>(deserializer: D) -> Result<u8, D::Error>
570 where
571 D: Deserializer<'de>,
572 {
573 let string = String::deserialize(deserializer)?;
574 let number = string
575 .strip_prefix("0x")
576 .ok_or_else(|| serde::de::Error::custom("missing `0x` prefix"))?;
577 u8::from_str_radix(number, 16).map_err(serde::de::Error::custom)
578 }
579}
580
#[cfg(test)]
mod tests {
    use rand::{rngs::StdRng, RngCore, SeedableRng};

    use crate::Overflow;

    use super::{BitString, StatusSize};

    /// Entry counts exercised by every randomized test (each with seeds `0..10`).
    const LENGTHS: [usize; 3] = [10, 100, 1000];

    /// Builds a bit string by pushing `len` random entries that fit in
    /// `status_size`, and returns the expected entries alongside it.
    fn random_bit_string(
        rng: &mut StdRng,
        status_size: StatusSize,
        len: usize,
    ) -> (Vec<u8>, BitString) {
        let mut values = Vec::with_capacity(len);

        for _ in 0..len {
            values.push((rng.next_u32() & 0xff) as u8 & status_size.mask())
        }

        let mut bit_string = BitString::new(status_size);
        for &s in &values {
            bit_string.push(s).unwrap();
        }

        (values, bit_string)
    }

    /// Checks that encoding and then decoding a random bit string preserves
    /// every entry.
    fn randomized_roundtrip(seed: u64, status_size: StatusSize, len: usize) {
        let mut rng = StdRng::seed_from_u64(seed);
        let (values, bit_string) = random_bit_string(&mut rng, status_size, len);

        let encoded = bit_string.encode();
        let decoded = BitString::from_bytes(status_size, encoded.decode(None).unwrap());

        assert!(decoded.len() >= len);

        for (i, &value) in values.iter().enumerate() {
            assert_eq!(decoded.get(i), Some(value))
        }
    }

    /// Checks that random `set` calls are observed by subsequent `get` calls.
    fn randomized_write(seed: u64, status_size: StatusSize, len: usize) {
        let mut rng = StdRng::seed_from_u64(seed);
        let (mut values, mut bit_string) = random_bit_string(&mut rng, status_size, len);

        for _ in 0..len {
            let i = (rng.next_u32() as usize) % len;
            let value = (rng.next_u32() & 0xff) as u8 & status_size.mask();
            bit_string.set(i, value).unwrap();
            values[i] = value;
        }

        for (i, item) in values.into_iter().enumerate() {
            assert_eq!(bit_string.get(i), Some(item))
        }
    }

    /// Runs `check` with a `bits`-wide status size, for every length in
    /// `LENGTHS` and every seed in `0..10`.
    fn for_each_case(bits: u8, check: fn(u64, StatusSize, usize)) {
        let status_size: StatusSize = bits.try_into().unwrap();
        for len in LENGTHS {
            for seed in 0..10 {
                check(seed, status_size, len);
            }
        }
    }

    #[test]
    fn randomized_roundtrip_1bit() {
        for_each_case(1, randomized_roundtrip);
    }

    #[test]
    fn randomized_write_1bits() {
        for_each_case(1, randomized_write);
    }

    // Byte-aligned width: entries never straddle a byte boundary.
    #[test]
    fn randomized_roundtrip_2bits() {
        for_each_case(2, randomized_roundtrip);
    }

    #[test]
    fn randomized_write_2bits() {
        for_each_case(2, randomized_write);
    }

    #[test]
    fn randomized_roundtrip_3bits() {
        for_each_case(3, randomized_roundtrip);
    }

    #[test]
    fn randomized_write_3bits() {
        for_each_case(3, randomized_write);
    }

    #[test]
    fn randomized_roundtrip_7bits() {
        for_each_case(7, randomized_roundtrip);
    }

    #[test]
    fn randomized_write_7bits() {
        for_each_case(7, randomized_write);
    }

    // Full-byte entries exercise the special case in `StatusSize::mask`.
    #[test]
    fn randomized_roundtrip_8bits() {
        for_each_case(8, randomized_roundtrip);
    }

    #[test]
    fn randomized_write_8bits() {
        for_each_case(8, randomized_write);
    }

    #[test]
    fn overflows() {
        let mut rng = StdRng::seed_from_u64(0);
        let (_, mut bitstring) = random_bit_string(&mut rng, 1u8.try_into().unwrap(), 15);

        // Reading past the end.
        assert!(bitstring.get(15).is_none());

        // Writing past the end.
        assert_eq!(bitstring.set(15, 0), Err(Overflow::Index(15)));

        // Writing a value that does not fit in one bit.
        assert_eq!(bitstring.set(14, 2), Err(Overflow::Value(2)));
    }

    #[test]
    fn deserialize_status_size_1() {
        assert!(serde_json::from_str::<StatusSize>("1").is_ok())
    }

    #[test]
    fn deserialize_status_size_2() {
        assert!(serde_json::from_str::<StatusSize>("2").is_ok())
    }

    #[test]
    fn deserialize_status_size_3() {
        assert!(serde_json::from_str::<StatusSize>("3").is_ok())
    }

    #[test]
    fn deserialize_status_size_8() {
        assert!(serde_json::from_str::<StatusSize>("8").is_ok())
    }

    #[test]
    fn deserialize_status_size_negative() {
        assert!(serde_json::from_str::<StatusSize>("-1").is_err())
    }

    #[test]
    fn deserialize_status_size_overflow() {
        assert!(serde_json::from_str::<StatusSize>("9").is_err())
    }
}
767}