Skip to main content

yantrikdb_protocol/
lib.rs

1pub mod codec;
2pub mod error;
3pub mod frame;
4pub mod messages;
5pub mod opcodes;
6
7pub use codec::YantrikCodec;
8pub use error::ProtocolError;
9pub use frame::Frame;
10pub use opcodes::OpCode;
11
12/// Helper: serialize a message to MessagePack bytes.
13pub fn pack<T: serde::Serialize>(msg: &T) -> Result<bytes::Bytes, ProtocolError> {
14    let data = rmp_serde::to_vec_named(msg)?;
15    Ok(bytes::Bytes::from(data))
16}
17
18/// Helper: deserialize a message from MessagePack bytes.
19pub fn unpack<'de, T: serde::Deserialize<'de>>(data: &'de [u8]) -> Result<T, ProtocolError> {
20    Ok(rmp_serde::from_slice(data)?)
21}
22
23/// Decompress a frame's payload if it's marked compressed, then unpack.
24/// Use this in handlers instead of `unpack` directly when the payload may be compressed.
25pub fn unpack_frame<T: serde::de::DeserializeOwned>(frame: &Frame) -> Result<T, ProtocolError> {
26    if frame.is_compressed() {
27        let decompressed = zstd::decode_all(&frame.payload[..])
28            .map_err(|e| ProtocolError::Io(e))?;
29        Ok(rmp_serde::from_slice(&decompressed)?)
30    } else {
31        Ok(rmp_serde::from_slice(&frame.payload)?)
32    }
33}
34
35/// Pack and zstd-compress a message. Use for large payloads (oplog batches, recall results).
36pub fn pack_compressed<T: serde::Serialize>(msg: &T) -> Result<bytes::Bytes, ProtocolError> {
37    let data = rmp_serde::to_vec_named(msg)?;
38    let compressed = zstd::encode_all(data.as_slice(), 3)
39        .map_err(|e| ProtocolError::Io(e))?;
40    Ok(bytes::Bytes::from(compressed))
41}
42
43/// Build a frame whose payload is auto-compressed if it exceeds `min_size_bytes`.
44pub fn make_frame_auto_compress<T: serde::Serialize>(
45    opcode: OpCode,
46    stream_id: u32,
47    msg: &T,
48    min_size_bytes: usize,
49) -> Result<Frame, ProtocolError> {
50    let raw = pack(msg)?;
51    if raw.len() < min_size_bytes {
52        return Ok(Frame::new(opcode, stream_id, raw));
53    }
54    let compressed = zstd::encode_all(&raw[..], 3).map_err(|e| ProtocolError::Io(e))?;
55    // Only use compression if it actually saved space
56    if compressed.len() < raw.len() {
57        Ok(Frame::new(opcode, stream_id, bytes::Bytes::from(compressed)).with_compression())
58    } else {
59        Ok(Frame::new(opcode, stream_id, raw))
60    }
61}
62
63/// Build a complete frame from an opcode, stream ID, and serializable payload.
64pub fn make_frame<T: serde::Serialize>(
65    opcode: OpCode,
66    stream_id: u32,
67    msg: &T,
68) -> Result<Frame, ProtocolError> {
69    let payload = pack(msg)?;
70    Ok(Frame::new(opcode, stream_id, payload))
71}
72
73/// Build a frame with an error response.
74pub fn make_error(
75    stream_id: u32,
76    code: u16,
77    message: impl Into<String>,
78) -> Result<Frame, ProtocolError> {
79    make_frame(
80        OpCode::Error,
81        stream_id,
82        &messages::ErrorResponse {
83            code,
84            message: message.into(),
85            details: None,
86        },
87    )
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use messages::{RecallRequest, RememberRequest};

    /// Build a `RecallRequest` for `query` with fixed values for every other field.
    fn recall_request(query: &str) -> RecallRequest {
        RecallRequest {
            query: query.into(),
            top_k: 5,
            memory_type: None,
            include_consolidated: false,
            expand_entities: true,
            namespace: None,
            domain: None,
            source: None,
            query_embedding: None,
        }
    }

    #[test]
    fn pack_unpack_roundtrip() {
        let req = RememberRequest {
            text: "Alice leads engineering".into(),
            memory_type: "semantic".into(),
            importance: 0.9,
            valence: 0.0,
            half_life: 168.0,
            metadata: serde_json::json!({}),
            namespace: "default".into(),
            certainty: 1.0,
            domain: "work".into(),
            source: "user".into(),
            emotional_state: None,
            embedding: None,
        };

        let packed = pack(&req).unwrap();
        let unpacked: RememberRequest = unpack(&packed).unwrap();

        assert_eq!(unpacked.text, "Alice leads engineering");
        assert_eq!(unpacked.importance, 0.9);
        assert_eq!(unpacked.domain, "work");
    }

    #[test]
    fn make_frame_roundtrip() {
        let req = recall_request("who leads engineering?");

        let frame = make_frame(OpCode::Recall, 7, &req).unwrap();
        assert_eq!(frame.opcode, OpCode::Recall);
        assert_eq!(frame.stream_id, 7);

        let decoded: RecallRequest = unpack(&frame.payload).unwrap();
        assert_eq!(decoded.query, "who leads engineering?");
        assert_eq!(decoded.top_k, 5);
    }

    #[test]
    fn make_error_frame() {
        let frame =
            make_error(0, messages::error_codes::AUTH_REQUIRED, "not authenticated").unwrap();
        assert_eq!(frame.opcode, OpCode::Error);

        let err: messages::ErrorResponse = unpack(&frame.payload).unwrap();
        assert_eq!(err.code, 1000);
        assert_eq!(err.message, "not authenticated");
    }

    #[test]
    fn pack_compressed_roundtrip() {
        // Highly repetitive text so zstd has something to compress.
        let query = "recall ".repeat(256);
        let req = recall_request(&query);

        let compressed = pack_compressed(&req).unwrap();
        let raw = zstd::decode_all(&compressed[..]).unwrap();
        assert_eq!(&raw[..], &pack(&req).unwrap()[..]);

        let decoded: RecallRequest = unpack(&raw).unwrap();
        assert_eq!(decoded.query, query);
    }

    #[test]
    fn auto_compress_compresses_large_payloads() {
        let query = "engineering ".repeat(512);
        let req = recall_request(&query);

        let frame = make_frame_auto_compress(OpCode::Recall, 3, &req, 64).unwrap();
        assert_eq!(frame.opcode, OpCode::Recall);
        assert_eq!(frame.stream_id, 3);
        assert!(frame.is_compressed());
        assert!(frame.payload.len() < pack(&req).unwrap().len());

        let decoded: RecallRequest = unpack_frame(&frame).unwrap();
        assert_eq!(decoded.query, query);
        assert_eq!(decoded.top_k, 5);
    }

    #[test]
    fn auto_compress_skips_payloads_below_threshold() {
        let req = recall_request("short");

        // No payload can reach this threshold, so the raw bytes must be sent.
        let frame = make_frame_auto_compress(OpCode::Recall, 4, &req, usize::MAX).unwrap();
        assert!(!frame.is_compressed());
        assert_eq!(frame.payload, pack(&req).unwrap());

        let decoded: RecallRequest = unpack_frame(&frame).unwrap();
        assert_eq!(decoded.query, "short");
    }
}