1use crate::agent_key::VerificationKey;
8use crate::error::{Error, Result};
9use crate::message::{Jwe, Jws, SecurityMode};
10use async_trait::async_trait;
11use base64::Engine;
12use serde::de::DeserializeOwned;
13use serde::Serialize;
14use serde_json::Value;
15use std::any::Any;
16use std::fmt::Debug;
17use std::sync::Arc;
18use tap_msg::didcomm::PlainMessage;
19use uuid::Uuid;
20
/// Errors raised while packing or unpacking messages.
///
/// Every variant converts into the crate-wide [`Error`] through
/// `From<MessageError> for Error`.
#[derive(Debug, thiserror::Error)]
pub enum MessageError {
    /// JSON (de)serialization failed. `#[from]` lets `?` convert a
    /// `serde_json::Error` into this variant.
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),

    /// The key manager reported an error; the string carries its details.
    #[error("Key manager error: {0}")]
    KeyManager(String),

    /// A cryptographic operation failed; the string carries its details.
    #[error("Crypto operation failed: {0}")]
    Crypto(String),

    /// The message does not have the expected structure.
    #[error("Invalid message format: {0}")]
    InvalidFormat(String),

    /// The requested security mode cannot be handled.
    #[error("Unsupported security mode: {0:?}")]
    UnsupportedSecurityMode(SecurityMode),

    /// A required parameter was not supplied; the string names it.
    #[error("Missing required parameter: {0}")]
    MissingParameter(String),

    /// No key was found; the string identifies the key that was looked up.
    #[error("Key not found: {0}")]
    KeyNotFound(String),

    /// A signature did not verify.
    #[error("Verification failed")]
    VerificationFailed,

    /// A message could not be decrypted.
    #[error("Decryption failed")]
    DecryptionFailed,
}
51
52impl From<MessageError> for Error {
53 fn from(err: MessageError) -> Self {
54 match err {
55 MessageError::Serialization(e) => Error::Serialization(e.to_string()),
56 MessageError::KeyManager(e) => Error::Cryptography(e),
57 MessageError::Crypto(e) => Error::Cryptography(e),
58 MessageError::InvalidFormat(e) => Error::Validation(e),
59 MessageError::UnsupportedSecurityMode(mode) => {
60 Error::Validation(format!("Unsupported security mode: {:?}", mode))
61 }
62 MessageError::MissingParameter(e) => {
63 Error::Validation(format!("Missing parameter: {}", e))
64 }
65 MessageError::KeyNotFound(e) => Error::Cryptography(format!("Key not found: {}", e)),
66 MessageError::VerificationFailed => {
67 Error::Cryptography("Verification failed".to_string())
68 }
69 MessageError::DecryptionFailed => Error::Cryptography("Decryption failed".to_string()),
70 }
71 }
72}
73
/// Options controlling how a message is packed (plain, signed or
/// auth-encrypted).
#[derive(Debug, Clone)]
pub struct PackOptions {
    /// Which envelope to produce.
    pub security_mode: SecurityMode,
    /// Key ID of the recipient; packing in `SecurityMode::AuthCrypt` fails
    /// with a validation error when this is `None`.
    pub recipient_kid: Option<String>,
    /// Key ID of the sender; packing in `SecurityMode::Signed` or
    /// `SecurityMode::AuthCrypt` fails with a validation error when this is
    /// `None`.
    pub sender_kid: Option<String>,
}
84
85impl Default for PackOptions {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
91impl PackOptions {
92 pub fn new() -> Self {
94 Self {
95 security_mode: SecurityMode::Plain,
96 recipient_kid: None,
97 sender_kid: None,
98 }
99 }
100
101 pub fn with_plain(mut self) -> Self {
103 self.security_mode = SecurityMode::Plain;
104 self
105 }
106
107 pub fn with_sign(mut self, sender_kid: &str) -> Self {
109 self.security_mode = SecurityMode::Signed;
110 self.sender_kid = Some(sender_kid.to_string());
111 self
112 }
113
114 pub fn with_auth_crypt(mut self, sender_kid: &str, recipient_jwk: &serde_json::Value) -> Self {
116 self.security_mode = SecurityMode::AuthCrypt;
117 self.sender_kid = Some(sender_kid.to_string());
118
119 if let Some(kid) = recipient_jwk.get("kid").and_then(|k| k.as_str()) {
121 self.recipient_kid = Some(kid.to_string());
122 }
123
124 self
125 }
126
127 pub fn security_mode(&self) -> SecurityMode {
129 self.security_mode
130 }
131}
132
/// Options controlling how a packed message is unpacked.
#[derive(Debug, Clone)]
pub struct UnpackOptions {
    /// Security mode the caller expects the message to use.
    /// `SecurityMode::Any` (the value set by `UnpackOptions::new`) expresses
    /// no expectation.
    pub expected_security_mode: SecurityMode,
    /// When set, only JWE recipients whose `kid` equals this value are tried
    /// during decryption.
    pub expected_recipient_kid: Option<String>,
    /// Whether the caller requires the message to be signed.
    pub require_signature: bool,
}
143
144impl Default for UnpackOptions {
145 fn default() -> Self {
146 Self::new()
147 }
148}
149
150impl UnpackOptions {
151 pub fn new() -> Self {
153 Self {
154 expected_security_mode: SecurityMode::Any,
155 expected_recipient_kid: None,
156 require_signature: false,
157 }
158 }
159
160 pub fn with_require_signature(mut self, require: bool) -> Self {
162 self.require_signature = require;
163 self
164 }
165}
166
/// A value that can be packed into a transportable representation.
///
/// `Output` defaults to `String`, i.e. a serialized envelope.
#[async_trait]
pub trait Packable<Output = String>: Sized {
    /// Packs `self` according to `options`, fetching any signing or
    /// encryption keys it needs from `key_manager`.
    async fn pack(
        &self,
        key_manager: &(impl KeyManagerPacking + ?Sized),
        options: PackOptions,
    ) -> Result<Output>;
}
177
/// A packed representation `Input` that can be unpacked into `Output`
/// (by default a [`PlainMessage`]).
#[async_trait]
pub trait Unpackable<Input, Output = PlainMessage>: Sized {
    /// Unpacks `packed_message`.
    ///
    /// Any verification or decryption keys come from `key_manager`.
    /// `options` describes the caller's expectations.
    async fn unpack(
        packed_message: &Input,
        key_manager: &(impl KeyManagerPacking + ?Sized),
        options: UnpackOptions,
    ) -> Result<Output>;
}
188
/// Key lookup interface used by packing and unpacking.
///
/// Every lookup is keyed by a key ID (`kid`). It returns a shared
/// (`Arc`) key handle that is `Send + Sync`.
#[async_trait]
pub trait KeyManagerPacking: Send + Sync + Debug {
    /// Returns the key used to create a JWS for the sender `kid`.
    async fn get_signing_key(
        &self,
        kid: &str,
    ) -> Result<Arc<dyn crate::agent_key::SigningKey + Send + Sync>>;

    /// Returns the sender-side key used to create a JWE for `kid`.
    async fn get_encryption_key(
        &self,
        kid: &str,
    ) -> Result<Arc<dyn crate::agent_key::EncryptionKey + Send + Sync>>;

    /// Returns the key used to unwrap a JWE addressed to the recipient `kid`.
    async fn get_decryption_key(
        &self,
        kid: &str,
    ) -> Result<Arc<dyn crate::agent_key::DecryptionKey + Send + Sync>>;

    /// Resolves the verification key for `kid`.
    ///
    /// The packing code in this module uses it both to verify JWS
    /// signatures and as the recipient key when creating a JWE.
    async fn resolve_verification_key(
        &self,
        kid: &str,
    ) -> Result<Arc<dyn VerificationKey + Send + Sync>>;
}
216
217#[async_trait]
219impl Packable for PlainMessage {
220 async fn pack(
221 &self,
222 key_manager: &(impl KeyManagerPacking + ?Sized),
223 options: PackOptions,
224 ) -> Result<String> {
225 match options.security_mode {
226 SecurityMode::Plain => {
227 serde_json::to_string(self).map_err(|e| Error::Serialization(e.to_string()))
229 }
230 SecurityMode::Signed => {
231 let sender_kid = options.sender_kid.clone().ok_or_else(|| {
233 Error::Validation("Signed mode requires sender_kid".to_string())
234 })?;
235
236 let signing_key = key_manager.get_signing_key(&sender_kid).await?;
238
239 let payload =
241 serde_json::to_string(self).map_err(|e| Error::Serialization(e.to_string()))?;
242
243 let jws = signing_key
245 .create_jws(payload.as_bytes(), None)
246 .await
247 .map_err(|e| Error::Cryptography(format!("Failed to create JWS: {}", e)))?;
248
249 serde_json::to_string(&jws).map_err(|e| Error::Serialization(e.to_string()))
251 }
252 SecurityMode::AuthCrypt => {
253 let sender_kid = options.sender_kid.clone().ok_or_else(|| {
255 Error::Validation("AuthCrypt mode requires sender_kid".to_string())
256 })?;
257
258 let recipient_kid = options.recipient_kid.clone().ok_or_else(|| {
259 Error::Validation("AuthCrypt mode requires recipient_kid".to_string())
260 })?;
261
262 let encryption_key = key_manager.get_encryption_key(&sender_kid).await?;
264
265 let recipient_key = key_manager.resolve_verification_key(&recipient_kid).await?;
267
268 let plaintext =
270 serde_json::to_string(self).map_err(|e| Error::Serialization(e.to_string()))?;
271
272 let jwe = encryption_key
274 .create_jwe(plaintext.as_bytes(), &[recipient_key], None)
275 .await
276 .map_err(|e| Error::Cryptography(format!("Failed to create JWE: {}", e)))?;
277
278 serde_json::to_string(&jwe).map_err(|e| Error::Serialization(e.to_string()))
280 }
281 SecurityMode::Any => {
282 Err(Error::Validation(
284 "SecurityMode::Any is not valid for packing".to_string(),
285 ))
286 }
287 }
288 }
289}
290
291pub async fn pack_any<T>(
294 obj: &T,
295 key_manager: &(impl KeyManagerPacking + ?Sized),
296 options: PackOptions,
297) -> Result<String>
298where
299 T: Serialize + Send + Sync + std::fmt::Debug + 'static + Sized,
300{
301 if obj.type_id() == std::any::TypeId::of::<PlainMessage>() {
305 let value = serde_json::to_value(obj).map_err(|e| Error::Serialization(e.to_string()))?;
307 let plain_msg: PlainMessage =
308 serde_json::from_value(value).map_err(|e| Error::Serialization(e.to_string()))?;
309 return plain_msg.pack(key_manager, options).await;
310 }
311
312 match options.security_mode {
314 SecurityMode::Plain => {
315 serde_json::to_string(obj).map_err(|e| Error::Serialization(e.to_string()))
317 }
318 SecurityMode::Signed => {
319 let sender_kid = options
321 .sender_kid
322 .clone()
323 .ok_or_else(|| Error::Validation("Signed mode requires sender_kid".to_string()))?;
324
325 let signing_key = key_manager.get_signing_key(&sender_kid).await?;
327
328 let value =
330 serde_json::to_value(obj).map_err(|e| Error::Serialization(e.to_string()))?;
331
332 let obj = value
334 .as_object()
335 .ok_or_else(|| Error::Validation("Message is not a JSON object".to_string()))?;
336
337 let id_string = obj
339 .get("id")
340 .map(|v| v.as_str().unwrap_or_default().to_string())
341 .unwrap_or_else(|| Uuid::new_v4().to_string());
342 let id = id_string.as_str();
343
344 let msg_type = obj
346 .get("type")
347 .and_then(|v| v.as_str())
348 .unwrap_or("https://tap.rsvp/schema/1.0/message");
349
350 let from = options.sender_kid.as_ref().map(|kid| {
352 kid.split('#').next().unwrap_or(kid).to_string()
354 });
355
356 let to = if let Some(kid) = &options.recipient_kid {
357 let did = kid.split('#').next().unwrap_or(kid).to_string();
359 vec![did]
360 } else {
361 vec![]
362 };
363
364 let plain_message = PlainMessage {
366 id: id.to_string(),
367 typ: "application/didcomm-plain+json".to_string(),
368 type_: msg_type.to_string(),
369 body: value,
370 from: from.unwrap_or_default(),
371 to,
372 thid: None,
373 pthid: None,
374 created_time: Some(chrono::Utc::now().timestamp() as u64),
375 expires_time: None,
376 from_prior: None,
377 attachments: None,
378 extra_headers: std::collections::HashMap::new(),
379 };
380
381 let payload = serde_json::to_string(&plain_message)
383 .map_err(|e| Error::Serialization(e.to_string()))?;
384
385 let jws = signing_key
387 .create_jws(payload.as_bytes(), None)
388 .await
389 .map_err(|e| Error::Cryptography(format!("Failed to create JWS: {}", e)))?;
390
391 serde_json::to_string(&jws).map_err(|e| Error::Serialization(e.to_string()))
393 }
394 SecurityMode::AuthCrypt => {
395 let sender_kid = options.sender_kid.clone().ok_or_else(|| {
397 Error::Validation("AuthCrypt mode requires sender_kid".to_string())
398 })?;
399
400 let recipient_kid = options.recipient_kid.clone().ok_or_else(|| {
401 Error::Validation("AuthCrypt mode requires recipient_kid".to_string())
402 })?;
403
404 let encryption_key = key_manager.get_encryption_key(&sender_kid).await?;
406
407 let recipient_key = key_manager.resolve_verification_key(&recipient_kid).await?;
409
410 let value =
412 serde_json::to_value(obj).map_err(|e| Error::Serialization(e.to_string()))?;
413
414 let obj = value
416 .as_object()
417 .ok_or_else(|| Error::Validation("Message is not a JSON object".to_string()))?;
418
419 let id_string = obj
421 .get("id")
422 .map(|v| v.as_str().unwrap_or_default().to_string())
423 .unwrap_or_else(|| Uuid::new_v4().to_string());
424 let id = id_string.as_str();
425
426 let msg_type = obj
428 .get("type")
429 .and_then(|v| v.as_str())
430 .unwrap_or("https://tap.rsvp/schema/1.0/message");
431
432 let from = options.sender_kid.as_ref().map(|kid| {
434 kid.split('#').next().unwrap_or(kid).to_string()
436 });
437
438 let to = if let Some(kid) = &options.recipient_kid {
439 let did = kid.split('#').next().unwrap_or(kid).to_string();
441 vec![did]
442 } else {
443 vec![]
444 };
445
446 let plain_message = PlainMessage {
448 id: id.to_string(),
449 typ: "application/didcomm-plain+json".to_string(),
450 type_: msg_type.to_string(),
451 body: value,
452 from: from.unwrap_or_default(),
453 to,
454 thid: None,
455 pthid: None,
456 created_time: Some(chrono::Utc::now().timestamp() as u64),
457 expires_time: None,
458 from_prior: None,
459 attachments: None,
460 extra_headers: std::collections::HashMap::new(),
461 };
462
463 let plaintext = serde_json::to_string(&plain_message)
465 .map_err(|e| Error::Serialization(e.to_string()))?;
466
467 let jwe = encryption_key
469 .create_jwe(plaintext.as_bytes(), &[recipient_key], None)
470 .await
471 .map_err(|e| Error::Cryptography(format!("Failed to create JWE: {}", e)))?;
472
473 serde_json::to_string(&jwe).map_err(|e| Error::Serialization(e.to_string()))
475 }
476 SecurityMode::Any => {
477 Err(Error::Validation(
479 "SecurityMode::Any is not valid for packing".to_string(),
480 ))
481 }
482 }
483}
484
485#[async_trait]
487impl<T: DeserializeOwned + Send + 'static> Unpackable<Jws, T> for Jws {
488 async fn unpack(
489 packed_message: &Jws,
490 key_manager: &(impl KeyManagerPacking + ?Sized),
491 _options: UnpackOptions,
492 ) -> Result<T> {
493 let payload_bytes = base64::engine::general_purpose::STANDARD
495 .decode(&packed_message.payload)
496 .map_err(|e| Error::Cryptography(format!("Failed to decode JWS payload: {}", e)))?;
497
498 let payload_str = String::from_utf8(payload_bytes)
500 .map_err(|e| Error::Validation(format!("Invalid UTF-8 in payload: {}", e)))?;
501
502 let plain_message: PlainMessage =
504 serde_json::from_str(&payload_str).map_err(|e| Error::Serialization(e.to_string()))?;
505
506 let mut verified = false;
508
509 for signature in &packed_message.signatures {
510 let protected_bytes = base64::engine::general_purpose::STANDARD
512 .decode(&signature.protected)
513 .map_err(|e| {
514 Error::Cryptography(format!("Failed to decode protected header: {}", e))
515 })?;
516
517 let protected: crate::message::JwsProtected = serde_json::from_slice(&protected_bytes)
519 .map_err(|e| {
520 Error::Serialization(format!("Failed to parse protected header: {}", e))
521 })?;
522
523 let kid = &signature.header.kid;
525
526 let verification_key = match key_manager.resolve_verification_key(kid).await {
528 Ok(key) => key,
529 Err(_) => continue, };
531
532 let signature_bytes = base64::engine::general_purpose::STANDARD
534 .decode(&signature.signature)
535 .map_err(|e| Error::Cryptography(format!("Failed to decode signature: {}", e)))?;
536
537 let signing_input = format!("{}.{}", signature.protected, packed_message.payload);
539
540 match verification_key
542 .verify_signature(signing_input.as_bytes(), &signature_bytes, &protected)
543 .await
544 {
545 Ok(true) => {
546 verified = true;
547 break;
548 }
549 _ => continue,
550 }
551 }
552
553 if !verified {
554 return Err(Error::Cryptography(
555 "Signature verification failed".to_string(),
556 ));
557 }
558
559 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<PlainMessage>() {
561 let result = serde_json::to_value(plain_message).unwrap();
563 return serde_json::from_value(result).map_err(|e| Error::Serialization(e.to_string()));
564 }
565
566 serde_json::from_value(plain_message.body).map_err(|e| Error::Serialization(e.to_string()))
568 }
569}
570
571#[async_trait]
573impl<T: DeserializeOwned + Send + 'static> Unpackable<Jwe, T> for Jwe {
574 async fn unpack(
575 packed_message: &Jwe,
576 key_manager: &(impl KeyManagerPacking + ?Sized),
577 options: UnpackOptions,
578 ) -> Result<T> {
579 let recipients = if let Some(kid) = &options.expected_recipient_kid {
581 packed_message
583 .recipients
584 .iter()
585 .filter(|r| r.header.kid == *kid)
586 .collect::<Vec<_>>()
587 } else {
588 packed_message.recipients.iter().collect::<Vec<_>>()
590 };
591
592 for recipient in recipients {
594 let kid = &recipient.header.kid;
596
597 let decryption_key = match key_manager.get_decryption_key(kid).await {
599 Ok(key) => key,
600 Err(_) => continue, };
602
603 match decryption_key.unwrap_jwe(packed_message).await {
605 Ok(plaintext) => {
606 let plaintext_str = String::from_utf8(plaintext).map_err(|e| {
608 Error::Validation(format!("Invalid UTF-8 in plaintext: {}", e))
609 })?;
610
611 let plain_message: PlainMessage = match serde_json::from_str(&plaintext_str) {
613 Ok(msg) => msg,
614 Err(e) => {
615 return Err(Error::Serialization(e.to_string()));
616 }
617 };
618
619 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<PlainMessage>() {
621 let result = serde_json::to_value(plain_message).unwrap();
623 return serde_json::from_value(result)
624 .map_err(|e| Error::Serialization(e.to_string()));
625 }
626
627 return serde_json::from_value(plain_message.body)
629 .map_err(|e| Error::Serialization(e.to_string()));
630 }
631 Err(_) => continue, }
633 }
634
635 Err(Error::Cryptography("Failed to decrypt message".to_string()))
637 }
638}
639
640#[async_trait]
642impl<T: DeserializeOwned + Send + 'static> Unpackable<String, T> for String {
643 async fn unpack(
644 packed_message: &String,
645 key_manager: &(impl KeyManagerPacking + ?Sized),
646 options: UnpackOptions,
647 ) -> Result<T> {
648 if let Ok(value) = serde_json::from_str::<Value>(packed_message) {
650 if value.get("payload").is_some() && value.get("signatures").is_some() {
652 let jws: Jws = serde_json::from_str(packed_message)
654 .map_err(|e| Error::Serialization(e.to_string()))?;
655
656 return Jws::unpack(&jws, key_manager, options).await;
657 }
658
659 if value.get("ciphertext").is_some()
661 && value.get("protected").is_some()
662 && value.get("recipients").is_some()
663 {
664 let jwe: Jwe = serde_json::from_str(packed_message)
666 .map_err(|e| Error::Serialization(e.to_string()))?;
667
668 return Jwe::unpack(&jwe, key_manager, options).await;
669 }
670
671 if value.get("body").is_some() && value.get("type").is_some() {
673 let plain: PlainMessage = serde_json::from_str(packed_message)
675 .map_err(|e| Error::Serialization(e.to_string()))?;
676
677 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<PlainMessage>() {
679 let result = serde_json::to_value(plain).unwrap();
681 return serde_json::from_value(result)
682 .map_err(|e| Error::Serialization(e.to_string()));
683 }
684
685 return serde_json::from_value(plain.body)
687 .map_err(|e| Error::Serialization(e.to_string()));
688 }
689
690 return serde_json::from_value(value).map_err(|e| Error::Serialization(e.to_string()));
692 }
693
694 Err(Error::Validation("Message is not valid JSON".to_string()))
696 }
697}