// ssi_vc/syntax/types.rs

use educe::Educe;
use serde::{ser::SerializeSeq, Deserialize, Serialize};
use ssi_json_ld::JsonLdTypes;
use std::{borrow::Cow, marker::PhantomData};

/// Designates a single type name that must be present in a type list.
pub trait RequiredType {
    /// The mandatory type name.
    const REQUIRED_TYPE: &'static str;
}

/// Designates a (possibly empty) collection of mandatory type names.
pub trait RequiredTypeSet {
    /// Every mandatory type name in the collection.
    const REQUIRED_TYPES: &'static [&'static str];
}

/// The unit type imposes no requirement at all.
impl RequiredTypeSet for () {
    const REQUIRED_TYPES: &'static [&'static str] = &[];
}

/// Any single [`RequiredType`] acts as a one-element requirement set.
impl<T> RequiredTypeSet for T
where
    T: RequiredType,
{
    const REQUIRED_TYPES: &'static [&'static str] = &[T::REQUIRED_TYPE];
}

/// Chooses the serialized shape of a type list.
///
/// When `PREFER_ARRAY` is `false` and the list holds nothing besides the base
/// type, the `Serialize` impl for `Types` writes the base type as a bare
/// string. Otherwise it writes a sequence.
pub trait TypeSerializationPolicy {
    /// Whether to always serialize as a sequence, even for a single type.
    const PREFER_ARRAY: bool;
}

26/// List of types.
27///
28/// An unordered list of types that must include `B` (a base type) implementing
29/// [`RequiredType`], and more required types given by `T` implementing
30/// [`RequiredTypeSet`].
31#[derive(Educe)]
32#[educe(Debug, Clone)]
33pub struct Types<B, T = ()>(Vec<String>, PhantomData<(B, T)>);
34
35impl<B, T: RequiredTypeSet> Default for Types<B, T> {
36    fn default() -> Self {
37        Self(
38            T::REQUIRED_TYPES
39                .iter()
40                .copied()
41                .map(ToOwned::to_owned)
42                .collect(),
43            PhantomData,
44        )
45    }
46}
47
48impl<B, T> Types<B, T> {
49    pub fn additional_types(&self) -> &[String] {
50        &self.0
51    }
52}
53
54impl<B: RequiredType, T> Types<B, T> {
55    pub fn to_json_ld_types(&self) -> JsonLdTypes {
56        JsonLdTypes::new(&[B::REQUIRED_TYPE], Cow::Borrowed(&self.0))
57    }
58}
59
60impl<B: RequiredType + TypeSerializationPolicy, T> Serialize for Types<B, T> {
61    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
62    where
63        S: serde::Serializer,
64    {
65        if !B::PREFER_ARRAY && self.0.is_empty() {
66            B::REQUIRED_TYPE.serialize(serializer)
67        } else {
68            let mut seq = serializer.serialize_seq(Some(1 + self.0.len()))?;
69            seq.serialize_element(B::REQUIRED_TYPE)?;
70            for t in &self.0 {
71                seq.serialize_element(t)?;
72            }
73            seq.end()
74        }
75    }
76}
77
78impl<'de, B: RequiredType, T: RequiredTypeSet> Deserialize<'de> for Types<B, T> {
79    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
80    where
81        D: serde::Deserializer<'de>,
82    {
83        struct Visitor<B, T>(PhantomData<(B, T)>);
84
85        impl<'de, B: RequiredType, T: RequiredTypeSet> serde::de::Visitor<'de> for Visitor<B, T> {
86            type Value = Types<B, T>;
87
88            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
89                write!(formatter, "credential types")
90            }
91
92            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
93            where
94                E: serde::de::Error,
95            {
96                if v == B::REQUIRED_TYPE {
97                    for &required in T::REQUIRED_TYPES {
98                        if required != B::REQUIRED_TYPE {
99                            return Err(E::custom(format!(
100                                "expected required `{}` type",
101                                required
102                            )));
103                        }
104                    }
105
106                    Ok(Types(Vec::new(), PhantomData))
107                } else {
108                    Err(E::custom(format!(
109                        "expected required `{}` type",
110                        B::REQUIRED_TYPE
111                    )))
112                }
113            }
114
115            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
116            where
117                A: serde::de::SeqAccess<'de>,
118            {
119                let mut base_type = false;
120                let mut types = Vec::new();
121
122                while let Some(t) = seq.next_element()? {
123                    if t == B::REQUIRED_TYPE {
124                        base_type = true
125                    } else {
126                        types.push(t)
127                    }
128                }
129
130                if !base_type {
131                    return Err(<A::Error as serde::de::Error>::custom(format!(
132                        "expected required `{}` type",
133                        B::REQUIRED_TYPE
134                    )));
135                }
136
137                for &required in T::REQUIRED_TYPES {
138                    if !types.iter().any(|s| s == required) {
139                        return Err(<A::Error as serde::de::Error>::custom(format!(
140                            "expected required `{required}` type"
141                        )));
142                    }
143                }
144
145                Ok(Types(types, PhantomData))
146            }
147        }
148
149        deserializer.deserialize_any(Visitor(PhantomData))
150    }
151}