1use educe::Educe;
2use serde::{ser::SerializeSeq, Deserialize, Serialize};
3use ssi_json_ld::JsonLdTypes;
4use std::{borrow::Cow, marker::PhantomData};
5
/// A marker type that contributes exactly one mandatory type name.
///
/// Used as the base-type parameter `B` of `Types<B, T>`: `B::REQUIRED_TYPE`
/// is always emitted first when serializing, and deserialization fails if it
/// is absent.
pub trait RequiredType {
    /// The type name that must be present.
    const REQUIRED_TYPE: &'static str;
}
9
/// A (possibly empty) set of type names that must be present.
///
/// Used as the `T` parameter of `Types<B, T>` to demand additional types
/// beyond the base type. Implemented in this file for `()` (no extra
/// requirement) and for every `RequiredType` (one extra requirement).
pub trait RequiredTypeSet {
    /// The required type names, in the order they are checked.
    const REQUIRED_TYPES: &'static [&'static str];
}
13
/// The unit type requires no additional type names.
impl RequiredTypeSet for () {
    const REQUIRED_TYPES: &'static [&'static str] = &[];
}
17
/// Any single `RequiredType` is a one-element requirement set.
impl<T: RequiredType> RequiredTypeSet for T {
    const REQUIRED_TYPES: &'static [&'static str] = &[T::REQUIRED_TYPE];
}
21
/// Controls the serialized shape of a `Types` list.
pub trait TypeSerializationPolicy {
    /// When `false`, a list containing only the base type is serialized as a
    /// single string; when `true`, it is always serialized as an array.
    const PREFER_ARRAY: bool;
}
25
/// A list of type names made of a mandatory base type `B` plus additional
/// types.
///
/// The `Vec<String>` field holds the additional types; the base type comes
/// from `B::REQUIRED_TYPE` (when `B: RequiredType`) and is written before
/// them on serialization. `T` lists additional types that must be present
/// (see `RequiredTypeSet`); it defaults to `()`, i.e. no extra requirement.
// `B` and `T` are type-level markers only: no values of them are stored.
// NOTE(review): presumably Educe is used instead of std derives so that
// `Debug`/`Clone` do not require `B`/`T` to implement them — confirm.
#[derive(Educe)]
#[educe(Debug, Clone)]
pub struct Types<B, T = ()>(Vec<String>, PhantomData<(B, T)>);
34
35impl<B, T: RequiredTypeSet> Default for Types<B, T> {
36 fn default() -> Self {
37 Self(
38 T::REQUIRED_TYPES
39 .iter()
40 .copied()
41 .map(ToOwned::to_owned)
42 .collect(),
43 PhantomData,
44 )
45 }
46}
47
48impl<B, T> Types<B, T> {
49 pub fn additional_types(&self) -> &[String] {
50 &self.0
51 }
52}
53
impl<B: RequiredType, T> Types<B, T> {
    /// Converts this list into a `JsonLdTypes` value.
    ///
    /// `B::REQUIRED_TYPE` is passed as a one-element slice in the first
    /// argument, and the additional types are borrowed from `self` (not
    /// cloned) in the second.
    pub fn to_json_ld_types(&self) -> JsonLdTypes {
        JsonLdTypes::new(&[B::REQUIRED_TYPE], Cow::Borrowed(&self.0))
    }
}
59
60impl<B: RequiredType + TypeSerializationPolicy, T> Serialize for Types<B, T> {
61 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
62 where
63 S: serde::Serializer,
64 {
65 if !B::PREFER_ARRAY && self.0.is_empty() {
66 B::REQUIRED_TYPE.serialize(serializer)
67 } else {
68 let mut seq = serializer.serialize_seq(Some(1 + self.0.len()))?;
69 seq.serialize_element(B::REQUIRED_TYPE)?;
70 for t in &self.0 {
71 seq.serialize_element(t)?;
72 }
73 seq.end()
74 }
75 }
76}
77
78impl<'de, B: RequiredType, T: RequiredTypeSet> Deserialize<'de> for Types<B, T> {
79 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
80 where
81 D: serde::Deserializer<'de>,
82 {
83 struct Visitor<B, T>(PhantomData<(B, T)>);
84
85 impl<'de, B: RequiredType, T: RequiredTypeSet> serde::de::Visitor<'de> for Visitor<B, T> {
86 type Value = Types<B, T>;
87
88 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
89 write!(formatter, "credential types")
90 }
91
92 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
93 where
94 E: serde::de::Error,
95 {
96 if v == B::REQUIRED_TYPE {
97 for &required in T::REQUIRED_TYPES {
98 if required != B::REQUIRED_TYPE {
99 return Err(E::custom(format!(
100 "expected required `{}` type",
101 required
102 )));
103 }
104 }
105
106 Ok(Types(Vec::new(), PhantomData))
107 } else {
108 Err(E::custom(format!(
109 "expected required `{}` type",
110 B::REQUIRED_TYPE
111 )))
112 }
113 }
114
115 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
116 where
117 A: serde::de::SeqAccess<'de>,
118 {
119 let mut base_type = false;
120 let mut types = Vec::new();
121
122 while let Some(t) = seq.next_element()? {
123 if t == B::REQUIRED_TYPE {
124 base_type = true
125 } else {
126 types.push(t)
127 }
128 }
129
130 if !base_type {
131 return Err(<A::Error as serde::de::Error>::custom(format!(
132 "expected required `{}` type",
133 B::REQUIRED_TYPE
134 )));
135 }
136
137 for &required in T::REQUIRED_TYPES {
138 if !types.iter().any(|s| s == required) {
139 return Err(<A::Error as serde::de::Error>::custom(format!(
140 "expected required `{required}` type"
141 )));
142 }
143 }
144
145 Ok(Types(types, PhantomData))
146 }
147 }
148
149 deserializer.deserialize_any(Visitor(PhantomData))
150 }
151}