1use flate2::{bufread::ZlibDecoder, write::ZlibEncoder, Compression};
12use iref::Uri;
13use serde::Serialize;
14use ssi_claims_core::{DateTimeProvider, ResolverProvider};
15use std::{
16 io::{self, Read, Write},
17 time::Duration,
18};
19
20pub mod cbor;
21pub mod json;
22
23pub use json::StatusListJwt;
24use ssi_jwk::JWKResolver;
25use ssi_jws::{InvalidJws, JwsSlice};
26use ssi_jwt::{ClaimSet, InvalidClaimValue, JWTClaims, ToDecodedJwt};
27
28use crate::{
29 EncodedStatusMap, FromBytes, FromBytesOptions, Overflow, StatusMap, StatusMapEntry,
30 StatusMapEntrySet, StatusSizeError,
31};
32
/// Status value `0`.
///
/// Going by its name, this marks a referenced token as valid. It is the value
/// [`BitString::new_valid`] fills a bit-string with.
pub const VALID: u8 = 0;

/// Status value `1`.
///
/// Going by its name, this marks a referenced token as invalid. It is the value
/// [`BitString::new_invalid`] fills a bit-string with.
pub const INVALID: u8 = 1;

/// Status value `2`.
///
/// Going by its name, this marks a referenced token as suspended. It needs a
/// status size of at least 2 bits to be representable.
pub const SUSPENDED: u8 = 2;
53
/// Status list token, in one of the supported serializations.
///
/// Only the JWT serialization has a variant at the moment.
pub enum StatusListToken {
    /// Status list carried as the claims of a JWT.
    Jwt(StatusListJwt),
}
58
59impl StatusListToken {
60 pub fn decode_status_list(self) -> Result<StatusList, DecodeError> {
61 match self {
62 Self::Jwt(claims) => json::decode_status_list_jwt(claims),
63 }
64 }
65}
66
/// Error returned when building a [`StatusListToken`] from raw bytes.
#[derive(Debug, thiserror::Error)]
pub enum FromBytesError {
    /// No media type was provided.
    #[error("missing media type")]
    MissingMediaType,

    /// The media type is not one this decoder handles; carries the media type.
    #[error("unexpected media type `{0}`")]
    UnexpectedMediaType(String),

    /// The input bytes are not a well-formed JWS.
    #[error(transparent)]
    JWS(#[from] InvalidJws<Vec<u8>>),

    /// The JWS could not be decoded as a JWT.
    #[error("invalid JWT: {0}")]
    JWT(#[from] ssi_jwt::DecodeError),

    /// The JWS header type is not the expected one; carries the type found.
    #[error("unexpected JWS type `{0}`")]
    UnexpectedJWSType(String),

    /// The JWS header has no type.
    #[error("missing JWS type")]
    MissingJWSType,

    /// Preparing the proof for verification failed.
    #[error("proof preparation failed: {0}")]
    Preparation(#[from] ssi_claims_core::ProofPreparationError),

    /// Running the proof verification failed.
    #[error("proof verification failed: {0}")]
    Verification(#[from] ssi_claims_core::ProofValidationError),

    /// Verification ran and rejected the token.
    #[error(transparent)]
    Rejected(#[from] ssi_claims_core::Invalid),
}
96
97impl<V> FromBytes<V> for StatusListToken
98where
99 V: ResolverProvider + DateTimeProvider,
100 V::Resolver: JWKResolver,
101{
102 type Error = FromBytesError;
103
104 async fn from_bytes_with(
105 bytes: &[u8],
106 media_type: &str,
107 verifier: &V,
108 _options: FromBytesOptions,
109 ) -> Result<Self, Self::Error> {
110 match media_type {
111 "statuslist+jwt" => {
112 let jwt = JwsSlice::new(bytes)
113 .map_err(InvalidJws::into_owned)?
114 .to_decoded_custom_jwt::<json::StatusListJwtPrivateClaims>()?;
115
116 match jwt.signing_bytes.header.type_.as_deref() {
117 Some("statuslist+jwt") => {
118 jwt.verify(verifier).await??;
119 Ok(Self::Jwt(jwt.signing_bytes.payload))
120 }
121 Some(other) => Err(FromBytesError::UnexpectedJWSType(other.to_owned())),
122 None => Err(FromBytesError::MissingJWSType),
123 }
124 }
125 "statuslist+cwt" => {
126 todo!()
127 }
128 other => Err(FromBytesError::UnexpectedMediaType(other.to_owned())),
129 }
130 }
131}
132
/// Error returned when decoding a status list out of a status list token.
#[derive(Debug, thiserror::Error)]
pub enum DecodeError {
    /// A claim could not be read; carries a textual description of the cause.
    #[error("invalid claim: {0}")]
    Claim(String),

    /// The issuer is missing.
    #[error("missing issuer")]
    MissingIssuer,

    /// The subject is missing.
    #[error("missing subject")]
    MissingSubject,

    /// The `status_list` claim is missing.
    #[error("missing `status_list` claim")]
    MissingStatusList,

    /// Base64 decoding failed.
    #[error("invalid base64: {0}")]
    Base64(#[from] base64::DecodeError),

    /// I/O error reported during ZLIB decompression.
    #[error("ZLIB decompression: {0}")]
    Zlib(#[from] io::Error),
}
153
154impl DecodeError {
155 pub fn claim(e: impl ToString) -> Self {
156 Self::Claim(e.to_string())
157 }
158}
159
160impl EncodedStatusMap for StatusListToken {
161 type Decoded = StatusList;
162 type DecodeError = DecodeError;
163
164 fn decode(self) -> Result<Self::Decoded, Self::DecodeError> {
165 self.decode_status_list()
166 }
167}
168
/// Type string `statuslist+jwt`.
///
/// This is the same string that `StatusListToken`'s `FromBytes`
/// implementation accepts as media type and as JWS header type.
pub const JWT_TYPE: &str = "statuslist+jwt";
173
/// Error returned when a value is not an accepted status size.
///
/// Carries the rejected value. The `TryFrom<u8>` implementation of
/// [`StatusSize`] accepts only 1, 2, 4 and 8.
#[derive(Debug, thiserror::Error)]
#[error("invalid status size {0}")]
pub struct InvalidStatusSize(u8);
177
178impl From<InvalidStatusSize> for StatusSizeError {
179 fn from(_value: InvalidStatusSize) -> Self {
180 Self::Invalid
181 }
182}
183
/// Number of bits used to encode one status.
///
/// Values built through `TryFrom<u8>` or deserialization are always 1, 2, 4
/// or 8. The type serializes transparently as the underlying `u8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)]
#[serde(transparent)]
pub struct StatusSize(u8);
188
189impl StatusSize {
190 pub const fn status_per_byte(&self) -> usize {
192 8 / self.0 as usize
193 }
194
195 pub const fn status_mask(&self) -> u8 {
197 match self.0 {
198 1 => 0b1,
199 2 => 0b11,
200 4 => 0b1111,
201 _ => 0b11111111,
202 }
203 }
204
205 pub const fn offset_of(&self, index: usize) -> (usize, usize) {
208 let spb = self.status_per_byte();
209 (index / spb, (index % spb) * self.0 as usize)
210 }
211}
212
213impl TryFrom<u8> for StatusSize {
214 type Error = InvalidStatusSize;
215
216 fn try_from(value: u8) -> Result<Self, Self::Error> {
217 if matches!(value, 1 | 2 | 4 | 8) {
218 Ok(Self(value))
219 } else {
220 Err(InvalidStatusSize(value))
221 }
222 }
223}
224
225impl From<StatusSize> for u8 {
226 fn from(value: StatusSize) -> Self {
227 value.0
228 }
229}
230
231impl<'de> serde::Deserialize<'de> for StatusSize {
232 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
233 where
234 D: serde::Deserializer<'de>,
235 {
236 u8::deserialize(deserializer)?
237 .try_into()
238 .map_err(serde::de::Error::custom)
239 }
240}
241
/// Decoded status list: a bit-string of statuses plus an optional time to live.
#[derive(Clone)]
pub struct StatusList {
    /// Packed statuses.
    bit_string: BitString,
    /// Optional time to live, in seconds (it is turned into a `Duration` with
    /// `Duration::from_secs`).
    ttl: Option<u64>,
}
247
248impl StatusList {
249 pub fn new(bit_string: BitString, ttl: Option<u64>) -> Self {
250 Self { bit_string, ttl }
251 }
252
253 pub fn iter(&self) -> BitStringIter {
254 self.bit_string.iter()
255 }
256}
257
258impl StatusMap for StatusList {
259 type Key = usize;
260 type StatusSize = StatusSize;
261 type Status = u8;
262
263 fn time_to_live(&self) -> Option<Duration> {
264 self.ttl.map(Duration::from_secs)
265 }
266
267 fn get_by_key(
268 &self,
269 _status_size: Option<StatusSize>,
270 key: Self::Key,
271 ) -> Result<Option<Self::Status>, StatusSizeError> {
272 Ok(self.bit_string.get(key))
273 }
274}
275
/// Sequence of fixed-width statuses packed into bytes.
///
/// Each status occupies `status_size` bits. Within a byte, the status with
/// the lowest index sits in the lowest-order bits.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct BitString {
    /// Width of one status, in bits.
    status_size: StatusSize,
    /// Packed statuses.
    bytes: Vec<u8>,
    /// Number of statuses stored (not the number of bytes).
    len: usize,
}
288
289impl BitString {
290 pub const DEFAULT_LIMIT: u64 = 16 * 1024 * 1024;
294
295 pub fn new(status_size: StatusSize) -> Self {
297 Self {
298 status_size,
299 bytes: Vec::new(),
300 len: 0,
301 }
302 }
303
304 pub fn new_with(
309 status_size: StatusSize,
310 len: usize,
311 mut f: impl FnMut(usize) -> u8,
312 ) -> Result<Self, Overflow> {
313 let mut result = Self::with_capacity(status_size, len);
314
315 for i in 0..len {
316 result.push(f(i))?;
317 }
318
319 Ok(result)
320 }
321
322 pub fn new_with_value(
325 status_size: StatusSize,
326 len: usize,
327 value: u8,
328 ) -> Result<Self, Overflow> {
329 Self::new_with(status_size, len, |_| value)
330 }
331
332 pub fn new_zeroed(status_size: StatusSize, len: usize) -> Self {
337 Self::new_valid(status_size, len)
338 }
339
340 pub fn new_valid(status_size: StatusSize, len: usize) -> Self {
343 Self::new_with_value(status_size, len, VALID).unwrap() }
345
346 pub fn new_invalid(status_size: StatusSize, len: usize) -> Self {
349 Self::new_with_value(status_size, len, INVALID).unwrap() }
351
352 pub fn with_capacity(status_size: StatusSize, capacity: usize) -> Self {
355 Self {
356 status_size,
357 bytes: Vec::with_capacity(capacity.div_ceil(status_size.status_per_byte())),
358 len: 0,
359 }
360 }
361
362 pub fn from_parts(status_size: StatusSize, data: Vec<u8>) -> Self {
364 let len = data.len() * 8usize / status_size.0 as usize;
365 Self {
366 status_size,
367 bytes: data,
368 len,
369 }
370 }
371
372 pub fn from_compressed_bytes(
378 status_size: StatusSize,
379 bytes: &[u8],
380 limit: Option<u64>,
381 ) -> Result<Self, io::Error> {
382 let limit = limit.unwrap_or(Self::DEFAULT_LIMIT);
383 let mut decoder = ZlibDecoder::new(bytes).take(limit);
384 let mut buffer = Vec::new();
385 decoder.read_to_end(&mut buffer)?;
386 Ok(Self::from_parts(status_size, buffer))
387 }
388
389 pub fn status_size(&self) -> StatusSize {
391 self.status_size
392 }
393
394 pub fn is_empty(&self) -> bool {
396 self.len == 0
397 }
398
399 pub fn len(&self) -> usize {
401 self.len
402 }
403
404 pub fn get(&self, index: usize) -> Option<u8> {
406 if index < self.len {
407 let (a, b) = self.status_size.offset_of(index);
408 Some((self.bytes[a] >> b) & self.status_size.status_mask())
409 } else {
410 None
411 }
412 }
413
414 pub fn set(&mut self, index: usize, value: u8) -> Result<u8, Overflow> {
419 if index >= self.len {
420 return Err(Overflow::Index(index));
421 }
422
423 let status_mask = self.status_size.status_mask();
424 let masked_value = value & status_mask;
425
426 if masked_value != value {
427 return Err(Overflow::Value(value));
428 }
429
430 let (a, b) = self.status_size.offset_of(index);
431
432 let old_value = (self.bytes[a] >> b) & status_mask;
433 self.bytes[a] &= !(status_mask << b); self.bytes[a] |= masked_value << b; Ok(old_value)
437 }
438
439 pub fn push(&mut self, value: u8) -> Result<usize, Overflow> {
444 let status_mask = self.status_size.status_mask();
445 let masked_value = value & status_mask;
446
447 if masked_value != value {
448 return Err(Overflow::Value(value));
449 }
450
451 let index = self.len;
452 self.len += 1;
453 let (a, b) = self.status_size.offset_of(index);
454
455 if a == self.bytes.len() {
456 self.bytes.push(masked_value << b)
457 } else {
458 self.bytes[a] |= masked_value << b
459 }
460
461 Ok(a)
462 }
463
464 pub fn as_bytes(&self) -> &[u8] {
466 self.bytes.as_slice()
467 }
468
469 pub fn to_compressed_bytes(&self, compression: Compression) -> Vec<u8> {
475 let mut buffer = Vec::new();
476
477 {
478 let mut encoder = ZlibEncoder::new(&mut buffer, compression);
479
480 encoder.write_all(&self.bytes).unwrap();
481 }
482
483 buffer
484 }
485
486 pub fn iter(&self) -> BitStringIter {
488 BitStringIter {
489 bit_string: self,
490 index: 0,
491 }
492 }
493
494 pub fn into_parts(self) -> (StatusSize, Vec<u8>) {
497 (self.status_size, self.bytes)
498 }
499}
500
/// Iterator over the statuses of a [`BitString`], in index order.
pub struct BitStringIter<'a> {
    /// Bit-string being iterated.
    bit_string: &'a BitString,
    /// Index of the next status to yield.
    index: usize,
}
505
506impl Iterator for BitStringIter<'_> {
507 type Item = u8;
508
509 fn next(&mut self) -> Option<Self::Item> {
510 self.bit_string.get(self.index).inspect(|_| {
511 self.index += 1;
512 })
513 }
514}
515
/// Error returned when building an [`AnyStatusListEntrySet`] from raw bytes.
#[derive(Debug, thiserror::Error)]
pub enum EntrySetFromBytesError {
    /// The input is not valid JSON for the expected claims.
    #[error(transparent)]
    Json(#[from] serde_json::Error),

    /// The input bytes are not a well-formed JWS.
    #[error(transparent)]
    JWS(#[from] InvalidJws<Vec<u8>>),

    /// The JWS could not be decoded as a JWT.
    #[error(transparent)]
    JWT(#[from] ssi_jwt::DecodeError),

    /// A claim is present but its value is invalid.
    #[error(transparent)]
    ClaimValue(#[from] InvalidClaimValue),

    /// Preparing the proof for verification failed.
    #[error("proof preparation failed: {0}")]
    ProofPreparation(#[from] ssi_claims_core::ProofPreparationError),

    /// Running the proof validation failed.
    #[error("proof validation failed: {0}")]
    ProofValidation(#[from] ssi_claims_core::ProofValidationError),

    /// Verification ran and rejected the claims.
    #[error("rejected claims: {0}")]
    Rejected(#[from] ssi_claims_core::Invalid),

    /// The claims carry no status.
    #[error("missing status")]
    MissingStatus,
}
542
/// Set of status list entries, in one of the supported representations.
///
/// Only the JSON representation has a variant at the moment.
pub enum AnyStatusListEntrySet {
    /// Status claim in its JSON representation.
    Json(json::Status),
}
546
547impl<V> FromBytes<V> for AnyStatusListEntrySet
548where
549 V: ResolverProvider + DateTimeProvider,
550 V::Resolver: JWKResolver,
551{
552 type Error = EntrySetFromBytesError;
553
554 async fn from_bytes_with(
555 bytes: &[u8],
556 media_type: &str,
557 verifier: &V,
558 _options: FromBytesOptions,
559 ) -> Result<Self, EntrySetFromBytesError> {
560 match media_type {
561 "application/json" => {
562 let claims: JWTClaims = serde_json::from_slice(bytes)?;
563 Ok(Self::Json(
564 claims
565 .try_get::<json::Status>()?
566 .ok_or(EntrySetFromBytesError::MissingStatus)?
567 .into_owned(),
568 ))
569 }
570 "application/jwt" => {
571 let jwt = JwsSlice::new(bytes)
572 .map_err(InvalidJws::into_owned)?
573 .to_decoded_jwt()?;
574 jwt.verify(verifier).await??;
575
576 Ok(Self::Json(
577 jwt.signing_bytes
578 .payload
579 .try_get::<json::Status>()?
580 .ok_or(EntrySetFromBytesError::MissingStatus)?
581 .into_owned(),
582 ))
583 }
584 _ => todo!(),
591 }
592 }
593}
594
595impl StatusMapEntrySet for AnyStatusListEntrySet {
596 type Entry<'a>
597 = AnyStatusListReference<'a>
598 where
599 Self: 'a;
600
601 fn get_entry(&self, purpose: crate::StatusPurpose<&str>) -> Option<Self::Entry<'_>> {
602 match self {
603 Self::Json(s) => s.get_entry(purpose).map(AnyStatusListReference::Json),
604 }
605 }
606}
607
/// Borrowed reference to a status list entry, in one of the supported
/// representations.
///
/// Only the JSON representation has a variant at the moment.
pub enum AnyStatusListReference<'a> {
    /// Reference in its JSON representation.
    Json(&'a json::StatusListReference),
}
611
612impl StatusMapEntry for AnyStatusListReference<'_> {
613 type Key = usize;
614 type StatusSize = StatusSize;
615
616 fn key(&self) -> Self::Key {
617 match self {
618 Self::Json(e) => e.key(),
619 }
620 }
621
622 fn status_list_url(&self) -> &Uri {
623 match self {
624 Self::Json(e) => e.status_list_url(),
625 }
626 }
627
628 fn status_size(&self) -> Option<Self::StatusSize> {
629 match self {
630 Self::Json(e) => e.status_size(),
631 }
632 }
633}
634
#[cfg(test)]
mod tests {
    use flate2::Compression;
    use rand::{rngs::StdRng, RngCore, SeedableRng};

    use crate::Overflow;

    use super::{json::JsonStatusList, BitString, StatusSize};

    /// Draws one random status that fits in `status_size` bits.
    fn random_status(rng: &mut StdRng, status_size: StatusSize) -> u8 {
        (rng.next_u32() & 0xff) as u8 & status_size.status_mask()
    }

    /// Builds `len` random statuses and the bit-string holding them.
    fn random_bit_string(
        rng: &mut StdRng,
        status_size: StatusSize,
        len: usize,
    ) -> (Vec<u8>, BitString) {
        let values: Vec<u8> = (0..len)
            .map(|_| random_status(rng, status_size))
            .collect();

        let mut bit_string = BitString::new(status_size);
        for &value in &values {
            bit_string.push(value).unwrap();
        }

        (values, bit_string)
    }

    /// Encodes then decodes a random bit-string and checks every status.
    fn randomized_roundtrip(seed: u64, status_size: StatusSize, len: usize) {
        let mut rng = StdRng::seed_from_u64(seed);
        let (values, bit_string) = random_bit_string(&mut rng, status_size, len);

        let encoded = JsonStatusList::encode(&bit_string, Compression::fast());
        let decoded = encoded.decode(None).unwrap();

        // Decoding rounds the length up to a whole number of bytes.
        assert!(decoded.len() >= len);

        for (i, expected) in values.iter().copied().enumerate() {
            assert_eq!(decoded.get(i), Some(expected))
        }
    }

    /// Overwrites random positions and checks the bit-string against a
    /// plain `Vec` model.
    fn randomized_write(seed: u64, status_size: StatusSize, len: usize) {
        let mut rng = StdRng::seed_from_u64(seed);
        let (mut values, mut bit_string) = random_bit_string(&mut rng, status_size, len);

        for _ in 0..len {
            let i = (rng.next_u32() as usize) % len;
            let value = random_status(&mut rng, status_size);
            bit_string.set(i, value).unwrap();
            values[i] = value;
        }

        for (i, expected) in values.iter().copied().enumerate() {
            assert_eq!(bit_string.get(i), Some(expected))
        }
    }

    /// Runs `case` with seeds 0 to 9 for lengths 10, 100 and 1000, in that order.
    fn for_each_case(bits: u8, case: fn(u64, StatusSize, usize)) {
        let status_size: StatusSize = bits.try_into().unwrap();

        for &len in &[10usize, 100, 1000] {
            for seed in 0..10 {
                case(seed, status_size, len);
            }
        }
    }

    /// Returns `true` if `json` deserializes into a `StatusSize`.
    fn parses_as_status_size(json: &str) -> bool {
        serde_json::from_str::<StatusSize>(json).is_ok()
    }

    #[test]
    fn randomized_roundtrip_1bit() {
        for_each_case(1, randomized_roundtrip);
    }

    #[test]
    fn randomized_write_1bits() {
        for_each_case(1, randomized_write);
    }

    #[test]
    fn randomized_roundtrip_2bits() {
        for_each_case(2, randomized_roundtrip);
    }

    #[test]
    fn randomized_write_2bits() {
        for_each_case(2, randomized_write);
    }

    #[test]
    fn randomized_roundtrip_4bits() {
        for_each_case(4, randomized_roundtrip);
    }

    #[test]
    fn randomized_write_4bits() {
        for_each_case(4, randomized_write);
    }

    #[test]
    fn overflows() {
        let mut rng = StdRng::seed_from_u64(0);
        let (_, mut bitstring) = random_bit_string(&mut rng, 1u8.try_into().unwrap(), 15);

        // Reading past the end yields nothing.
        assert!(bitstring.get(15).is_none());

        // Writing past the end is an index overflow.
        assert_eq!(bitstring.set(15, 0), Err(Overflow::Index(15)));

        // A value wider than the status size is a value overflow.
        assert_eq!(bitstring.set(14, 2), Err(Overflow::Value(2)));
    }

    #[test]
    fn deserialize_status_size_1() {
        assert!(parses_as_status_size("1"))
    }

    #[test]
    fn deserialize_status_size_2() {
        assert!(parses_as_status_size("2"))
    }

    #[test]
    fn deserialize_status_size_4() {
        assert!(parses_as_status_size("4"))
    }

    #[test]
    fn deserialize_status_size_8() {
        assert!(parses_as_status_size("8"))
    }

    #[test]
    fn deserialize_status_size_non_power_of_two() {
        assert!(!parses_as_status_size("3"))
    }

    #[test]
    fn deserialize_status_size_negative() {
        assert!(!parses_as_status_size("-1"))
    }

    #[test]
    fn deserialize_status_size_overflow() {
        assert!(!parses_as_status_size("9"))
    }
}