Skip to main content

valinor_wire/
codec.rs

1use crate::envelope::{Envelope, EnvelopeError};
2use base64::Engine;
3use base64::engine::general_purpose::STANDARD;
4
/// Serialization format a [`WireCodec`] uses when turning envelopes into bytes.
///
/// [`WireFormat::Json`] is the default.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum WireFormat {
    /// The envelope is serialized and parsed with `serde_json`.
    #[default]
    Json,
    /// Binary framing produced by `encode_binary` / `decode_binary` in this
    /// module: a 2-byte big-endian message-type length, the message-type
    /// bytes, then the raw payload bytes.
    Protobuf,
}
11
12pub struct WireCodec {
13    format: WireFormat,
14}
15
16impl WireCodec {
17    pub fn new(format: WireFormat) -> Self {
18        Self { format }
19    }
20
21    pub fn json() -> Self {
22        Self::new(WireFormat::Json)
23    }
24
25    pub fn protobuf() -> Self {
26        Self::new(WireFormat::Protobuf)
27    }
28
29    pub fn encode(&self, envelope: &Envelope) -> Result<Vec<u8>, EnvelopeError> {
30        match self.format {
31            WireFormat::Json => {
32                serde_json::to_vec(envelope).map_err(|e| EnvelopeError::Encode(e.to_string()))
33            }
34            WireFormat::Protobuf => encode_binary(envelope),
35        }
36    }
37
38    pub fn decode(&self, data: &[u8]) -> Result<Envelope, EnvelopeError> {
39        match self.format {
40            WireFormat::Json => {
41                serde_json::from_slice(data).map_err(|e| EnvelopeError::Decode(e.to_string()))
42            }
43            WireFormat::Protobuf => decode_binary(data),
44        }
45    }
46}
47
48fn encode_binary(envelope: &Envelope) -> Result<Vec<u8>, EnvelopeError> {
49    if envelope.request_id.is_some() {
50        return Err(EnvelopeError::Encode(
51            "binary envelope does not support request_id".to_string(),
52        ));
53    }
54
55    let payload_b64 = envelope
56        .payload
57        .as_str()
58        .ok_or_else(|| EnvelopeError::Encode("binary payload must be base64 string".to_string()))?;
59    let payload = STANDARD
60        .decode(payload_b64)
61        .map_err(|e| EnvelopeError::Encode(e.to_string()))?;
62
63    let msg_type_bytes = envelope.msg_type.as_bytes();
64    if msg_type_bytes.len() > u16::MAX as usize {
65        return Err(EnvelopeError::Encode("message type too long".to_string()));
66    }
67
68    let mut out = Vec::with_capacity(2 + msg_type_bytes.len() + payload.len());
69    let len = msg_type_bytes.len() as u16;
70    out.extend_from_slice(&len.to_be_bytes());
71    out.extend_from_slice(msg_type_bytes);
72    out.extend_from_slice(&payload);
73    Ok(out)
74}
75
76fn decode_binary(data: &[u8]) -> Result<Envelope, EnvelopeError> {
77    if data.len() < 2 {
78        return Err(EnvelopeError::Decode(
79            "binary envelope too short".to_string(),
80        ));
81    }
82    let len = u16::from_be_bytes([data[0], data[1]]) as usize;
83    if data.len() < 2 + len {
84        return Err(EnvelopeError::Decode(
85            "binary envelope type length invalid".to_string(),
86        ));
87    }
88    let msg_type = std::str::from_utf8(&data[2..2 + len])
89        .map_err(|e| EnvelopeError::Decode(e.to_string()))?
90        .to_string();
91    let payload = &data[2 + len..];
92    let payload_b64 = STANDARD.encode(payload);
93
94    Ok(Envelope {
95        msg_type,
96        request_id: None,
97        payload: serde_json::Value::String(payload_b64),
98    })
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// JSON encoding must bring back the message type and the payload.
    #[test]
    fn json_round_trip() {
        let envelope = Envelope::event("chat.say", serde_json::json!({"text": "hi"}));
        let codec = WireCodec::json();
        let data = codec.encode(&envelope).unwrap();
        let decoded = codec.decode(&data).unwrap();
        assert_eq!(decoded.msg_type, "chat.say");
        assert_eq!(decoded.payload, serde_json::json!({"text": "hi"}));
    }

    /// Binary encoding must produce the documented frame and decode back to
    /// the same message type and base64 payload.
    #[test]
    fn binary_round_trip() {
        let payload = STANDARD.encode(b"payload");
        let envelope = Envelope::event("event.type", serde_json::Value::String(payload.clone()));
        let codec = WireCodec::protobuf();
        let data = codec.encode(&envelope).unwrap();

        // Frame layout: big-endian u16 type length, type bytes, raw payload bytes.
        let mut expected = vec![0u8, 10];
        expected.extend_from_slice(b"event.type");
        expected.extend_from_slice(b"payload");
        assert_eq!(data, expected);

        let decoded = codec.decode(&data).unwrap();
        assert_eq!(decoded.msg_type, "event.type");
        assert!(decoded.request_id.is_none());
        assert_eq!(decoded.payload, serde_json::Value::String(payload));
    }

    /// Input shorter than the header, or shorter than the header claims, is rejected.
    #[test]
    fn binary_decode_rejects_truncated_input() {
        let codec = WireCodec::protobuf();
        assert!(matches!(codec.decode(&[]), Err(EnvelopeError::Decode(_))));
        assert!(matches!(codec.decode(&[0]), Err(EnvelopeError::Decode(_))));
        // Header declares five type bytes but only two follow.
        assert!(matches!(
            codec.decode(&[0, 5, b'a', b'b']),
            Err(EnvelopeError::Decode(_))
        ));
    }

    /// A message type that is not valid UTF-8 is rejected.
    #[test]
    fn binary_decode_rejects_non_utf8_type() {
        let codec = WireCodec::protobuf();
        assert!(matches!(
            codec.decode(&[0, 1, 0xFF]),
            Err(EnvelopeError::Decode(_))
        ));
    }

    /// An empty payload after the message type decodes to an empty base64 string.
    #[test]
    fn binary_decode_accepts_empty_payload() {
        let codec = WireCodec::protobuf();
        let decoded = codec.decode(&[0, 1, b'x']).unwrap();
        assert_eq!(decoded.msg_type, "x");
        assert_eq!(decoded.payload, serde_json::Value::String(String::new()));
    }

    /// The binary format only accepts payloads that are JSON strings.
    #[test]
    fn binary_encode_rejects_non_string_payload() {
        let envelope = Envelope::event("event.type", serde_json::json!({"n": 1}));
        let codec = WireCodec::protobuf();
        assert!(matches!(
            codec.encode(&envelope),
            Err(EnvelopeError::Encode(_))
        ));
    }

    /// A payload string that is not valid base64 is rejected.
    #[test]
    fn binary_encode_rejects_invalid_base64() {
        let envelope = Envelope::event(
            "event.type",
            serde_json::Value::String("not base64!".to_string()),
        );
        let codec = WireCodec::protobuf();
        assert!(matches!(
            codec.encode(&envelope),
            Err(EnvelopeError::Encode(_))
        ));
    }
}