1use crate::envelope::{Envelope, EnvelopeError};
2use base64::Engine;
3use base64::engine::general_purpose::STANDARD;
4
/// Serialization format a [`WireCodec`] uses for envelopes on the wire.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum WireFormat {
    /// JSON text produced and parsed by `serde_json`. This is the default.
    #[default]
    Json,
    /// Compact binary framing implemented by `encode_binary` / `decode_binary`
    /// in this module: a big-endian `u16` length of the message type, the
    /// message type bytes, then the raw payload bytes.
    ///
    /// NOTE: despite the variant name, the bytes are not Protocol Buffers
    /// encoding.
    Protobuf,
}
11
12pub struct WireCodec {
13 format: WireFormat,
14}
15
16impl WireCodec {
17 pub fn new(format: WireFormat) -> Self {
18 Self { format }
19 }
20
21 pub fn json() -> Self {
22 Self::new(WireFormat::Json)
23 }
24
25 pub fn protobuf() -> Self {
26 Self::new(WireFormat::Protobuf)
27 }
28
29 pub fn encode(&self, envelope: &Envelope) -> Result<Vec<u8>, EnvelopeError> {
30 match self.format {
31 WireFormat::Json => {
32 serde_json::to_vec(envelope).map_err(|e| EnvelopeError::Encode(e.to_string()))
33 }
34 WireFormat::Protobuf => encode_binary(envelope),
35 }
36 }
37
38 pub fn decode(&self, data: &[u8]) -> Result<Envelope, EnvelopeError> {
39 match self.format {
40 WireFormat::Json => {
41 serde_json::from_slice(data).map_err(|e| EnvelopeError::Decode(e.to_string()))
42 }
43 WireFormat::Protobuf => decode_binary(data),
44 }
45 }
46}
47
48fn encode_binary(envelope: &Envelope) -> Result<Vec<u8>, EnvelopeError> {
49 if envelope.request_id.is_some() {
50 return Err(EnvelopeError::Encode(
51 "binary envelope does not support request_id".to_string(),
52 ));
53 }
54
55 let payload_b64 = envelope
56 .payload
57 .as_str()
58 .ok_or_else(|| EnvelopeError::Encode("binary payload must be base64 string".to_string()))?;
59 let payload = STANDARD
60 .decode(payload_b64)
61 .map_err(|e| EnvelopeError::Encode(e.to_string()))?;
62
63 let msg_type_bytes = envelope.msg_type.as_bytes();
64 if msg_type_bytes.len() > u16::MAX as usize {
65 return Err(EnvelopeError::Encode("message type too long".to_string()));
66 }
67
68 let mut out = Vec::with_capacity(2 + msg_type_bytes.len() + payload.len());
69 let len = msg_type_bytes.len() as u16;
70 out.extend_from_slice(&len.to_be_bytes());
71 out.extend_from_slice(msg_type_bytes);
72 out.extend_from_slice(&payload);
73 Ok(out)
74}
75
76fn decode_binary(data: &[u8]) -> Result<Envelope, EnvelopeError> {
77 if data.len() < 2 {
78 return Err(EnvelopeError::Decode(
79 "binary envelope too short".to_string(),
80 ));
81 }
82 let len = u16::from_be_bytes([data[0], data[1]]) as usize;
83 if data.len() < 2 + len {
84 return Err(EnvelopeError::Decode(
85 "binary envelope type length invalid".to_string(),
86 ));
87 }
88 let msg_type = std::str::from_utf8(&data[2..2 + len])
89 .map_err(|e| EnvelopeError::Decode(e.to_string()))?
90 .to_string();
91 let payload = &data[2 + len..];
92 let payload_b64 = STANDARD.encode(payload);
93
94 Ok(Envelope {
95 msg_type,
96 request_id: None,
97 payload: serde_json::Value::String(payload_b64),
98 })
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn json_round_trip() {
        let envelope = Envelope::event("chat.say", serde_json::json!({"text": "hi"}));
        let codec = WireCodec::json();
        let data = codec.encode(&envelope).unwrap();
        let decoded = codec.decode(&data).unwrap();
        assert_eq!(decoded.msg_type, "chat.say");
        // Checking only msg_type would let a corrupted payload slip through.
        assert_eq!(decoded.payload, envelope.payload);
    }

    #[test]
    fn binary_round_trip() {
        let payload = STANDARD.encode(b"payload");
        let envelope = Envelope::event("event.type", serde_json::Value::String(payload));
        let codec = WireCodec::protobuf();
        let data = codec.encode(&envelope).unwrap();
        let decoded = codec.decode(&data).unwrap();
        assert_eq!(decoded.msg_type, "event.type");
        // The payload goes base64 -> raw bytes -> base64 and must come back intact.
        assert_eq!(decoded.payload, envelope.payload);
        assert!(decoded.request_id.is_none());
    }

    #[test]
    fn binary_frame_layout() {
        let envelope = Envelope::event("ab", serde_json::Value::String(STANDARD.encode(b"xyz")));
        let data = WireCodec::protobuf().encode(&envelope).unwrap();
        // Big-endian u16 type length, type bytes, then the raw (not base64) payload.
        assert_eq!(data, b"\x00\x02abxyz".to_vec());
    }

    #[test]
    fn binary_empty_frame_decodes() {
        let decoded = WireCodec::protobuf().decode(b"\x00\x00").unwrap();
        assert_eq!(decoded.msg_type, "");
        assert_eq!(decoded.payload, serde_json::Value::String(String::new()));
    }

    #[test]
    fn binary_encode_rejects_bad_payloads() {
        let codec = WireCodec::protobuf();

        let not_a_string = Envelope::event("t", serde_json::json!({"a": 1}));
        assert!(matches!(
            codec.encode(&not_a_string),
            Err(EnvelopeError::Encode(_))
        ));

        let not_base64 =
            Envelope::event("t", serde_json::Value::String("not base64!".to_string()));
        assert!(matches!(
            codec.encode(&not_base64),
            Err(EnvelopeError::Encode(_))
        ));
    }

    #[test]
    fn binary_encode_rejects_oversized_type() {
        let envelope = Envelope {
            msg_type: "a".repeat(usize::from(u16::MAX) + 1),
            request_id: None,
            payload: serde_json::Value::String(String::new()),
        };
        assert!(matches!(
            WireCodec::protobuf().encode(&envelope),
            Err(EnvelopeError::Encode(_))
        ));
    }

    #[test]
    fn binary_decode_rejects_malformed_frames() {
        let codec = WireCodec::protobuf();
        // Empty, half a header, type length past the end, invalid UTF-8 type.
        let frames: [&[u8]; 4] = [b"", b"\x00", b"\x00\x05ab", b"\x00\x01\xff"];
        for frame in frames {
            assert!(
                matches!(codec.decode(frame), Err(EnvelopeError::Decode(_))),
                "frame {frame:?} should be rejected"
            );
        }
    }
}