Skip to main content

tap_agent/
message.rs

1//! Message types and utilities for the TAP Agent.
2//!
3//! This module provides constants and types for working with TAP messages,
4//! including security modes and message type identifiers.
5
6use base64::{engine::general_purpose, Engine};
7use serde::de::{self, MapAccess, Visitor};
8use serde::ser::SerializeMap;
9use serde::{Deserialize, Deserializer, Serialize, Serializer};
10
11/// Decode a base64-encoded string, accepting both standard base64 and base64url (with or without padding).
12pub fn base64_decode_flexible(input: &str) -> Result<Vec<u8>, base64::DecodeError> {
13    general_purpose::URL_SAFE_NO_PAD
14        .decode(input)
15        .or_else(|_| general_purpose::URL_SAFE.decode(input))
16        .or_else(|_| general_purpose::STANDARD.decode(input))
17        .or_else(|_| general_purpose::STANDARD_NO_PAD.decode(input))
18}
19
/// Security mode for message packing and unpacking.
///
/// Defines the level of protection applied to messages:
/// - `Plain`: no encryption or signing (insecure, only for testing)
/// - `Signed`: signed but not encrypted (integrity protected)
/// - `AuthCrypt`: authenticated and encrypted (confidentiality + integrity, sender revealed)
/// - `AnonCrypt`: anonymously encrypted (confidentiality only, sender hidden)
/// - `Any`: accept any security mode when unpacking (only used for receiving)
///
/// The type is `Copy + Eq + Hash`, so it can be stored directly in a `HashSet` or used as
/// a `HashMap` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SecurityMode {
    /// Plaintext - no encryption or signatures.
    Plain,
    /// Signed - message is signed but not encrypted.
    Signed,
    /// Authenticated encryption - confidentiality and integrity, sender revealed to the
    /// recipient.
    AuthCrypt,
    /// Anonymous encryption - message is encrypted but not signed (sender hidden).
    AnonCrypt,
    /// Any security mode - used for unpacking when any mode is acceptable.
    Any,
}
41
// Message type identifiers and media types used in TAP protocol communications.
// (Plain `//` so this section note is not glued onto the first constant's rustdoc.)

/// Type identifier for Presentation messages.
pub const PRESENTATION_MESSAGE_TYPE: &str = "https://tap.rsvp/schema/1.0#Presentation";

/// DIDComm media type for signed (JWS) messages; the default `typ` of a decoded JWS
/// protected header.
pub const DIDCOMM_SIGNED: &str = "application/didcomm-signed+json";
/// DIDComm media type for encrypted (JWE) messages; the default `typ` of a decoded JWE
/// protected header.
pub const DIDCOMM_ENCRYPTED: &str = "application/didcomm-encrypted+json";
50
51// JWS-related types
52
53/// JWS (JSON Web Signature) supporting both General and Flattened serializations per RFC 7515.
54///
55/// When serializing:
56/// - Single signature: uses Flattened JWS format (`protected`, `payload`, `signature` at top level)
57/// - Multiple signatures: uses General JWS format (`payload`, `signatures` array)
58///
59/// When deserializing: accepts both formats.
60#[derive(Debug)]
61pub struct Jws {
62    pub payload: String,
63    pub signatures: Vec<JwsSignature>,
64}
65
66impl Serialize for Jws {
67    fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
68        if self.signatures.len() == 1 {
69            // Flattened JWS: { "protected", "payload", "signature" }
70            let sig = &self.signatures[0];
71            let mut map = serializer.serialize_map(Some(3))?;
72            map.serialize_entry("payload", &self.payload)?;
73            map.serialize_entry("protected", &sig.protected)?;
74            map.serialize_entry("signature", &sig.signature)?;
75            map.end()
76        } else {
77            // General JWS: { "payload", "signatures": [...] }
78            let mut map = serializer.serialize_map(Some(2))?;
79            map.serialize_entry("payload", &self.payload)?;
80            map.serialize_entry("signatures", &self.signatures)?;
81            map.end()
82        }
83    }
84}
85
86impl<'de> Deserialize<'de> for Jws {
87    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
88        struct JwsVisitor;
89
90        impl<'de> Visitor<'de> for JwsVisitor {
91            type Value = Jws;
92
93            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
94                formatter.write_str("a JWS in General or Flattened serialization")
95            }
96
97            fn visit_map<M: MapAccess<'de>>(
98                self,
99                mut map: M,
100            ) -> std::result::Result<Jws, M::Error> {
101                let mut payload: Option<String> = None;
102                let mut signatures: Option<Vec<JwsSignature>> = None;
103                // Flattened fields
104                let mut protected: Option<String> = None;
105                let mut signature: Option<String> = None;
106
107                while let Some(key) = map.next_key::<String>()? {
108                    match key.as_str() {
109                        "payload" => payload = Some(map.next_value()?),
110                        "signatures" => signatures = Some(map.next_value()?),
111                        "protected" => protected = Some(map.next_value()?),
112                        "signature" => signature = Some(map.next_value()?),
113                        _ => {
114                            let _: serde_json::Value = map.next_value()?;
115                        }
116                    }
117                }
118
119                let payload = payload.ok_or_else(|| de::Error::missing_field("payload"))?;
120
121                // Prefer General format if "signatures" is present
122                if let Some(sigs) = signatures {
123                    Ok(Jws {
124                        payload,
125                        signatures: sigs,
126                    })
127                } else if let (Some(prot), Some(sig)) = (protected, signature) {
128                    // Flattened format
129                    Ok(Jws {
130                        payload,
131                        signatures: vec![JwsSignature {
132                            protected: prot,
133                            signature: sig,
134                        }],
135                    })
136                } else {
137                    Err(de::Error::custom(
138                        "JWS must have either 'signatures' array or 'protected'+'signature' fields",
139                    ))
140                }
141            }
142        }
143
144        deserializer.deserialize_map(JwsVisitor)
145    }
146}
147
148#[derive(Serialize, Deserialize, Debug)]
149pub struct JwsSignature {
150    pub protected: String,
151    pub signature: String,
152}
153
154// Structure for decoded JWS protected field
155#[derive(Serialize, Deserialize, Debug, Clone)]
156pub struct JwsProtected {
157    #[serde(default = "default_didcomm_signed")]
158    pub typ: String,
159    pub alg: String,
160    pub kid: String,
161}
162
163// Helper function for JwsProtected typ default
164fn default_didcomm_signed() -> String {
165    DIDCOMM_SIGNED.to_string()
166}
167
168impl JwsSignature {
169    /// Extracts the kid (key identifier) from the protected header
170    pub fn get_kid(&self) -> Option<String> {
171        if let Ok(protected_bytes) = base64_decode_flexible(&self.protected) {
172            if let Ok(protected) = serde_json::from_slice::<JwsProtected>(&protected_bytes) {
173                return Some(protected.kid);
174            }
175        }
176        None
177    }
178
179    /// Decodes and returns the protected header
180    pub fn get_protected_header(&self) -> Result<JwsProtected, Box<dyn std::error::Error>> {
181        let protected_bytes = base64_decode_flexible(&self.protected)?;
182        let protected = serde_json::from_slice::<JwsProtected>(&protected_bytes)?;
183        Ok(protected)
184    }
185}
186// JWE-related types
187
188#[derive(Serialize, Deserialize, Debug)]
189pub struct Jwe {
190    pub ciphertext: String,
191    pub protected: String,
192    pub recipients: Vec<JweRecipient>,
193    pub tag: String,
194    pub iv: String,
195}
196
197#[derive(Serialize, Deserialize, Debug)]
198pub struct JweRecipient {
199    pub encrypted_key: String,
200    pub header: JweHeader,
201}
202
203#[derive(Serialize, Deserialize, Debug)]
204pub struct JweHeader {
205    pub kid: String,
206    #[serde(skip_serializing_if = "Option::is_none")]
207    pub sender_kid: Option<String>,
208}
209
210// Structure for decoded JWE protected field
211#[derive(Serialize, Deserialize, Debug)]
212pub struct JweProtected {
213    pub epk: EphemeralPublicKey,
214    #[serde(default, skip_serializing_if = "String::is_empty")]
215    pub apv: String,
216    #[serde(default, skip_serializing_if = "String::is_empty")]
217    pub apu: String,
218    #[serde(default = "default_didcomm_encrypted")]
219    pub typ: String,
220    pub enc: String,
221    pub alg: String,
222}
223
224// Helper function for JweProtected typ default
225fn default_didcomm_encrypted() -> String {
226    DIDCOMM_ENCRYPTED.to_string()
227}
228
229// Enum to handle different ephemeral public key types
230#[derive(Serialize, Deserialize, Debug)]
231#[serde(tag = "kty")]
232pub enum EphemeralPublicKey {
233    #[serde(rename = "EC")]
234    Ec { crv: String, x: String, y: String },
235    #[serde(rename = "OKP")]
236    Okp { crv: String, x: String },
237}