selium_protocol/
frame.rs

1use crate::{Offset, Operation, TopicName};
2use bytes::{BufMut, Bytes, BytesMut};
3use selium_std::errors::{ProtocolError, Result, SeliumError};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
/// Optional string key/value headers attached to a [`MessagePayload`].
type Headers = Option<HashMap<String, String>>;

// One-byte wire tags, one per `Frame` variant. `Frame::get_type` emits these and
// `Frame::try_from` dispatches on them, so their values are part of the wire protocol.
const REGISTER_PUBLISHER: u8 = 0x00;
const REGISTER_SUBSCRIBER: u8 = 0x01;
const REGISTER_REPLIER: u8 = 0x02;
const REGISTER_REQUESTOR: u8 = 0x03;
const MESSAGE: u8 = 0x04;
const BATCH_MESSAGE: u8 = 0x05;
const ERROR: u8 = 0x06;
const OK: u8 = 0x07;
17
18#[derive(Clone, Debug, PartialEq)]
19pub enum Frame {
20    RegisterPublisher(PublisherPayload),
21    RegisterSubscriber(SubscriberPayload),
22    RegisterReplier(ReplierPayload),
23    RegisterRequestor(RequestorPayload),
24    Message(MessagePayload),
25    BatchMessage(BatchPayload),
26    Error(ErrorPayload),
27    Ok,
28}
29
30impl Frame {
31    pub fn get_length(&self) -> Result<u64> {
32        Ok(match self {
33            Self::RegisterPublisher(payload) => {
34                bincode::serialized_size(payload).map_err(ProtocolError::SerdeError)?
35            }
36            Self::RegisterSubscriber(payload) => {
37                bincode::serialized_size(payload).map_err(ProtocolError::SerdeError)?
38            }
39            Self::RegisterReplier(payload) => {
40                bincode::serialized_size(payload).map_err(ProtocolError::SerdeError)?
41            }
42            Self::RegisterRequestor(payload) => {
43                bincode::serialized_size(payload).map_err(ProtocolError::SerdeError)?
44            }
45            Self::Message(payload) => {
46                bincode::serialized_size(payload).map_err(ProtocolError::SerdeError)?
47            }
48            Self::BatchMessage(payload) => {
49                bincode::serialized_size(payload).map_err(ProtocolError::SerdeError)?
50            }
51            Self::Error(payload) => {
52                bincode::serialized_size(payload).map_err(ProtocolError::SerdeError)?
53            }
54            Self::Ok => 0,
55        })
56    }
57
58    pub fn get_type(&self) -> u8 {
59        match self {
60            Self::RegisterPublisher(_) => REGISTER_PUBLISHER,
61            Self::RegisterSubscriber(_) => REGISTER_SUBSCRIBER,
62            Self::RegisterReplier(_) => REGISTER_REPLIER,
63            Self::RegisterRequestor(_) => REGISTER_REQUESTOR,
64            Self::Message(_) => MESSAGE,
65            Self::BatchMessage(_) => BATCH_MESSAGE,
66            Self::Error(_) => ERROR,
67            Self::Ok => OK,
68        }
69    }
70
71    pub fn get_topic(&self) -> Option<&TopicName> {
72        match self {
73            Self::RegisterPublisher(p) => Some(&p.topic),
74            Self::RegisterSubscriber(s) => Some(&s.topic),
75            Self::RegisterReplier(s) => Some(&s.topic),
76            Self::RegisterRequestor(c) => Some(&c.topic),
77            Self::Message(_) => None,
78            Self::BatchMessage(_) => None,
79            Self::Error(_) => None,
80            Self::Ok => None,
81        }
82    }
83
84    pub fn write_to_bytes(self, dst: &mut BytesMut) -> Result<()> {
85        match self {
86            Frame::RegisterPublisher(payload) => bincode::serialize_into(dst.writer(), &payload)
87                .map_err(ProtocolError::SerdeError)?,
88            Frame::RegisterSubscriber(payload) => bincode::serialize_into(dst.writer(), &payload)
89                .map_err(ProtocolError::SerdeError)?,
90            Frame::RegisterReplier(payload) => bincode::serialize_into(dst.writer(), &payload)
91                .map_err(ProtocolError::SerdeError)?,
92            Frame::RegisterRequestor(payload) => bincode::serialize_into(dst.writer(), &payload)
93                .map_err(ProtocolError::SerdeError)?,
94            Frame::Message(payload) => bincode::serialize_into(dst.writer(), &payload)
95                .map_err(ProtocolError::SerdeError)?,
96            Frame::Error(payload) => bincode::serialize_into(dst.writer(), &payload)
97                .map_err(ProtocolError::SerdeError)?,
98            Frame::BatchMessage(payload) => bincode::serialize_into(dst.writer(), &payload)
99                .map_err(ProtocolError::SerdeError)?,
100            Frame::Ok => (),
101        }
102
103        Ok(())
104    }
105
106    pub fn retention_policy(&self) -> Option<u64> {
107        match self {
108            Self::RegisterPublisher(payload) => Some(payload.retention_policy),
109            Self::RegisterSubscriber(payload) => Some(payload.retention_policy),
110            _ => None,
111        }
112    }
113
114    pub fn batch_size(&self) -> Option<u32> {
115        match self {
116            Self::Message(_) => Some(1),
117            Self::BatchMessage(payload) => Some(payload.size),
118            _ => None,
119        }
120    }
121
122    pub fn message(&self) -> Option<&[u8]> {
123        match self {
124            Self::Message(payload) => Some(&payload.message),
125            Self::BatchMessage(payload) => Some(&payload.message),
126            _ => None,
127        }
128    }
129
130    pub fn unwrap_message(self) -> MessagePayload {
131        match self {
132            Self::Message(p) => p,
133            _ => panic!("Attempted to unwrap non-Message Frame variant"),
134        }
135    }
136}
137
138impl TryFrom<(u8, BytesMut)> for Frame {
139    type Error = SeliumError;
140
141    fn try_from(
142        (message_type, bytes): (u8, BytesMut),
143    ) -> Result<Self, <Frame as TryFrom<(u8, BytesMut)>>::Error> {
144        let frame = match message_type {
145            REGISTER_PUBLISHER => Frame::RegisterPublisher(
146                bincode::deserialize(&bytes).map_err(ProtocolError::SerdeError)?,
147            ),
148            REGISTER_SUBSCRIBER => Frame::RegisterSubscriber(
149                bincode::deserialize(&bytes).map_err(ProtocolError::SerdeError)?,
150            ),
151            REGISTER_REPLIER => Frame::RegisterReplier(
152                bincode::deserialize(&bytes).map_err(ProtocolError::SerdeError)?,
153            ),
154            REGISTER_REQUESTOR => Frame::RegisterRequestor(
155                bincode::deserialize(&bytes).map_err(ProtocolError::SerdeError)?,
156            ),
157            MESSAGE => {
158                Frame::Message(bincode::deserialize(&bytes).map_err(ProtocolError::SerdeError)?)
159            }
160            BATCH_MESSAGE => Frame::BatchMessage(
161                bincode::deserialize(&bytes).map_err(ProtocolError::SerdeError)?,
162            ),
163            ERROR => Frame::Error(bincode::deserialize(&bytes).map_err(ProtocolError::SerdeError)?),
164            OK => Frame::Ok,
165            _type => return Err(ProtocolError::UnknownMessageType(_type))?,
166        };
167
168        Ok(frame)
169    }
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
173pub struct PublisherPayload {
174    pub topic: TopicName,
175    pub retention_policy: u64,
176    pub operations: Vec<Operation>,
177}
178
179#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
180pub struct SubscriberPayload {
181    pub topic: TopicName,
182    pub retention_policy: u64,
183    pub operations: Vec<Operation>,
184    pub offset: Offset,
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
188pub struct ReplierPayload {
189    pub topic: TopicName,
190}
191
192#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
193pub struct RequestorPayload {
194    pub topic: TopicName,
195}
196
197#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
198pub struct MessagePayload {
199    pub headers: Headers,
200    pub message: Bytes,
201}
202
203#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
204pub struct BatchPayload {
205    pub message: Bytes,
206    pub size: u32,
207}
208
209#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
210pub struct ErrorPayload {
211    pub code: u32,
212    pub message: Bytes,
213}