1use base64::{engine::general_purpose, Engine};
7use serde::de::{self, MapAccess, Visitor};
8use serde::ser::SerializeMap;
9use serde::{Deserialize, Deserializer, Serialize, Serializer};
10
11pub fn base64_decode_flexible(input: &str) -> Result<Vec<u8>, base64::DecodeError> {
13 general_purpose::URL_SAFE_NO_PAD
14 .decode(input)
15 .or_else(|_| general_purpose::URL_SAFE.decode(input))
16 .or_else(|_| general_purpose::STANDARD.decode(input))
17 .or_else(|_| general_purpose::STANDARD_NO_PAD.decode(input))
18}
19
/// Level of message protection used or required.
///
/// NOTE(review): the per-variant descriptions below are inferred from the DIDComm naming
/// used in this module; confirm them against the code that consumes this enum.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SecurityMode {
    /// Plaintext message: presumably neither signed nor encrypted.
    Plain,
    /// Signed message (presumably a JWS envelope).
    Signed,
    /// Authenticated encryption (presumably the sender is identified to the recipient).
    AuthCrypt,
    /// Anonymous encryption (presumably the sender is not identified).
    AnonCrypt,
    /// Wildcard: presumably any of the other modes is acceptable.
    Any,
}
41
/// Message type URI for Presentation messages (TAP schema 1.0).
pub const PRESENTATION_MESSAGE_TYPE: &str = "https://tap.rsvp/schema/1.0#Presentation";

/// Media type for DIDComm signed (JWS) messages.
pub const DIDCOMM_SIGNED: &str = "application/didcomm-signed+json";
/// Media type for DIDComm encrypted (JWE) messages.
pub const DIDCOMM_ENCRYPTED: &str = "application/didcomm-encrypted+json";
50
51#[derive(Debug)]
61pub struct Jws {
62 pub payload: String,
63 pub signatures: Vec<JwsSignature>,
64}
65
66impl Serialize for Jws {
67 fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
68 if self.signatures.len() == 1 {
69 let sig = &self.signatures[0];
71 let mut map = serializer.serialize_map(Some(3))?;
72 map.serialize_entry("payload", &self.payload)?;
73 map.serialize_entry("protected", &sig.protected)?;
74 map.serialize_entry("signature", &sig.signature)?;
75 map.end()
76 } else {
77 let mut map = serializer.serialize_map(Some(2))?;
79 map.serialize_entry("payload", &self.payload)?;
80 map.serialize_entry("signatures", &self.signatures)?;
81 map.end()
82 }
83 }
84}
85
86impl<'de> Deserialize<'de> for Jws {
87 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
88 struct JwsVisitor;
89
90 impl<'de> Visitor<'de> for JwsVisitor {
91 type Value = Jws;
92
93 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
94 formatter.write_str("a JWS in General or Flattened serialization")
95 }
96
97 fn visit_map<M: MapAccess<'de>>(
98 self,
99 mut map: M,
100 ) -> std::result::Result<Jws, M::Error> {
101 let mut payload: Option<String> = None;
102 let mut signatures: Option<Vec<JwsSignature>> = None;
103 let mut protected: Option<String> = None;
105 let mut signature: Option<String> = None;
106
107 while let Some(key) = map.next_key::<String>()? {
108 match key.as_str() {
109 "payload" => payload = Some(map.next_value()?),
110 "signatures" => signatures = Some(map.next_value()?),
111 "protected" => protected = Some(map.next_value()?),
112 "signature" => signature = Some(map.next_value()?),
113 _ => {
114 let _: serde_json::Value = map.next_value()?;
115 }
116 }
117 }
118
119 let payload = payload.ok_or_else(|| de::Error::missing_field("payload"))?;
120
121 if let Some(sigs) = signatures {
123 Ok(Jws {
124 payload,
125 signatures: sigs,
126 })
127 } else if let (Some(prot), Some(sig)) = (protected, signature) {
128 Ok(Jws {
130 payload,
131 signatures: vec![JwsSignature {
132 protected: prot,
133 signature: sig,
134 }],
135 })
136 } else {
137 Err(de::Error::custom(
138 "JWS must have either 'signatures' array or 'protected'+'signature' fields",
139 ))
140 }
141 }
142 }
143
144 deserializer.deserialize_map(JwsVisitor)
145 }
146}
147
148#[derive(Serialize, Deserialize, Debug)]
149pub struct JwsSignature {
150 pub protected: String,
151 pub signature: String,
152}
153
154#[derive(Serialize, Deserialize, Debug, Clone)]
156pub struct JwsProtected {
157 #[serde(default = "default_didcomm_signed")]
158 pub typ: String,
159 pub alg: String,
160 pub kid: String,
161}
162
163fn default_didcomm_signed() -> String {
165 DIDCOMM_SIGNED.to_string()
166}
167
168impl JwsSignature {
169 pub fn get_kid(&self) -> Option<String> {
171 if let Ok(protected_bytes) = base64_decode_flexible(&self.protected) {
172 if let Ok(protected) = serde_json::from_slice::<JwsProtected>(&protected_bytes) {
173 return Some(protected.kid);
174 }
175 }
176 None
177 }
178
179 pub fn get_protected_header(&self) -> Result<JwsProtected, Box<dyn std::error::Error>> {
181 let protected_bytes = base64_decode_flexible(&self.protected)?;
182 let protected = serde_json::from_slice::<JwsProtected>(&protected_bytes)?;
183 Ok(protected)
184 }
185}
186#[derive(Serialize, Deserialize, Debug)]
189pub struct Jwe {
190 pub ciphertext: String,
191 pub protected: String,
192 pub recipients: Vec<JweRecipient>,
193 pub tag: String,
194 pub iv: String,
195}
196
197#[derive(Serialize, Deserialize, Debug)]
198pub struct JweRecipient {
199 pub encrypted_key: String,
200 pub header: JweHeader,
201}
202
203#[derive(Serialize, Deserialize, Debug)]
204pub struct JweHeader {
205 pub kid: String,
206 #[serde(skip_serializing_if = "Option::is_none")]
207 pub sender_kid: Option<String>,
208}
209
210#[derive(Serialize, Deserialize, Debug)]
212pub struct JweProtected {
213 pub epk: EphemeralPublicKey,
214 #[serde(default, skip_serializing_if = "String::is_empty")]
215 pub apv: String,
216 #[serde(default, skip_serializing_if = "String::is_empty")]
217 pub apu: String,
218 #[serde(default = "default_didcomm_encrypted")]
219 pub typ: String,
220 pub enc: String,
221 pub alg: String,
222}
223
224fn default_didcomm_encrypted() -> String {
226 DIDCOMM_ENCRYPTED.to_string()
227}
228
229#[derive(Serialize, Deserialize, Debug)]
231#[serde(tag = "kty")]
232pub enum EphemeralPublicKey {
233 #[serde(rename = "EC")]
234 Ec { crv: String, x: String, y: String },
235 #[serde(rename = "OKP")]
236 Okp { crv: String, x: String },
237}