Skip to main content

zamsync_network/protocol/
codec.rs

1use super::frame;
2use std::io::{Read, Write};
3use zamsync_core::{SyncMessage, ZamError, ZamResult};
4
5/// Encodes `msg` onto `writer` and returns the number of bytes written to the wire.
6pub fn encode(msg: &SyncMessage, writer: &mut impl Write) -> ZamResult<usize> {
7    let bytes =
8        rkyv::to_bytes::<_, 1024>(msg).map_err(|e| ZamError::Serialization(e.to_string()))?;
9    frame::write_frame(writer, &bytes)
10}
11
12pub fn decode(reader: &mut impl Read) -> ZamResult<SyncMessage> {
13    let bytes = frame::read_frame(reader)?;
14    rkyv::from_bytes::<SyncMessage>(&bytes).map_err(|e| ZamError::Serialization(format!("{}", e)))
15}
16
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use zamsync_core::{Event, Hlc, NodeId, SequenceNumber, VersionVector};

    /// Encodes `msg` into a fresh buffer, checks that `encode` reported exactly
    /// the bytes it appended, then decodes the buffer back.
    fn roundtrip(msg: &SyncMessage) -> SyncMessage {
        let mut buf = Vec::new();
        let written = encode(msg, &mut buf).unwrap();
        // `encode` documents its return value as the number of bytes put on the
        // wire; with a fresh Vec as the sink that must equal the buffer length.
        assert_eq!(written, buf.len());
        decode(&mut Cursor::new(&buf)).unwrap()
    }

    #[test]
    fn test_handshake_roundtrip() {
        let msg = SyncMessage::Handshake {
            node_id: NodeId(42),
            vv: VersionVector::new(),
        };

        match roundtrip(&msg) {
            SyncMessage::Handshake { node_id, .. } => assert_eq!(node_id.0, 42),
            _ => panic!("unexpected message type"),
        }
    }

    #[test]
    fn test_pull_request_roundtrip() {
        let msg = SyncMessage::PullRequest {
            origin_node: NodeId(1),
            start_seq: SequenceNumber(100),
            limit: 50,
        };

        match roundtrip(&msg) {
            SyncMessage::PullRequest {
                origin_node,
                start_seq,
                limit,
            } => {
                assert_eq!(origin_node.0, 1);
                assert_eq!(start_seq.0, 100);
                assert_eq!(limit, 50);
            }
            _ => panic!("unexpected message type"),
        }
    }

    #[test]
    fn test_event_batch_roundtrip() {
        let event = Event {
            origin_node: NodeId(3),
            seq: SequenceNumber(7),
            hlc: Hlc::new(9999, 0),
            event_type: 2,
            payload: b"payload".to_vec(),
        };
        let msg = SyncMessage::EventBatch {
            origin_node: NodeId(3),
            events: vec![event],
        };

        match roundtrip(&msg) {
            SyncMessage::EventBatch {
                origin_node,
                events,
            } => {
                assert_eq!(origin_node.0, 3);
                assert_eq!(events.len(), 1);
                // Check every scalar field of the event, not just the payload.
                assert_eq!(events[0].origin_node.0, 3);
                assert_eq!(events[0].seq.0, 7);
                assert_eq!(events[0].event_type, 2);
                assert_eq!(events[0].payload, b"payload");
            }
            _ => panic!("unexpected message type"),
        }
    }

    /// Two frames written back-to-back on one stream must decode in order, and
    /// each `decode` must consume exactly one frame.
    #[test]
    fn test_back_to_back_frames_decode_in_order() {
        let mut buf = Vec::new();
        encode(
            &SyncMessage::Handshake {
                node_id: NodeId(5),
                vv: VersionVector::new(),
            },
            &mut buf,
        )
        .unwrap();
        encode(
            &SyncMessage::PullRequest {
                origin_node: NodeId(6),
                start_seq: SequenceNumber(1),
                limit: 10,
            },
            &mut buf,
        )
        .unwrap();

        let mut cursor = Cursor::new(&buf);
        match decode(&mut cursor).unwrap() {
            SyncMessage::Handshake { node_id, .. } => assert_eq!(node_id.0, 5),
            _ => panic!("unexpected message type"),
        }
        match decode(&mut cursor).unwrap() {
            SyncMessage::PullRequest {
                origin_node, limit, ..
            } => {
                assert_eq!(origin_node.0, 6);
                assert_eq!(limit, 10);
            }
            _ => panic!("unexpected message type"),
        }
        assert_eq!(cursor.position(), buf.len() as u64);
    }

    /// An empty stream carries no message, so decoding it must fail rather than
    /// fabricate a value.
    #[test]
    fn test_decode_empty_input_is_error() {
        let mut cursor = Cursor::new(Vec::<u8>::new());
        assert!(decode(&mut cursor).is_err());
    }
}