1use crate::{Offset, Operation, TopicName};
2use bytes::{BufMut, Bytes, BytesMut};
3use selium_std::errors::{ProtocolError, Result, SeliumError};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
/// Optional string-to-string metadata map; this is the type of `MessagePayload::headers`.
type Headers = Option<HashMap<String, String>>;

// One-byte type tags, one per `Frame` variant. `Frame::get_type` returns them and
// `TryFrom<(u8, BytesMut)> for Frame` matches on them, so the numeric values must
// stay exactly as they are.

/// Type tag for `Frame::RegisterPublisher`.
const REGISTER_PUBLISHER: u8 = 0x00;
/// Type tag for `Frame::RegisterSubscriber`.
const REGISTER_SUBSCRIBER: u8 = 0x01;
/// Type tag for `Frame::RegisterReplier`.
const REGISTER_REPLIER: u8 = 0x02;
/// Type tag for `Frame::RegisterRequestor`.
const REGISTER_REQUESTOR: u8 = 0x03;
/// Type tag for `Frame::Message`.
const MESSAGE: u8 = 0x04;
/// Type tag for `Frame::BatchMessage`.
const BATCH_MESSAGE: u8 = 0x05;
/// Type tag for `Frame::Error`.
const ERROR: u8 = 0x06;
/// Type tag for `Frame::Ok`.
const OK: u8 = 0x07;
17
18#[derive(Clone, Debug, PartialEq)]
19pub enum Frame {
20 RegisterPublisher(PublisherPayload),
21 RegisterSubscriber(SubscriberPayload),
22 RegisterReplier(ReplierPayload),
23 RegisterRequestor(RequestorPayload),
24 Message(MessagePayload),
25 BatchMessage(BatchPayload),
26 Error(ErrorPayload),
27 Ok,
28}
29
30impl Frame {
31 pub fn get_length(&self) -> Result<u64> {
32 Ok(match self {
33 Self::RegisterPublisher(payload) => {
34 bincode::serialized_size(payload).map_err(ProtocolError::SerdeError)?
35 }
36 Self::RegisterSubscriber(payload) => {
37 bincode::serialized_size(payload).map_err(ProtocolError::SerdeError)?
38 }
39 Self::RegisterReplier(payload) => {
40 bincode::serialized_size(payload).map_err(ProtocolError::SerdeError)?
41 }
42 Self::RegisterRequestor(payload) => {
43 bincode::serialized_size(payload).map_err(ProtocolError::SerdeError)?
44 }
45 Self::Message(payload) => {
46 bincode::serialized_size(payload).map_err(ProtocolError::SerdeError)?
47 }
48 Self::BatchMessage(payload) => {
49 bincode::serialized_size(payload).map_err(ProtocolError::SerdeError)?
50 }
51 Self::Error(payload) => {
52 bincode::serialized_size(payload).map_err(ProtocolError::SerdeError)?
53 }
54 Self::Ok => 0,
55 })
56 }
57
58 pub fn get_type(&self) -> u8 {
59 match self {
60 Self::RegisterPublisher(_) => REGISTER_PUBLISHER,
61 Self::RegisterSubscriber(_) => REGISTER_SUBSCRIBER,
62 Self::RegisterReplier(_) => REGISTER_REPLIER,
63 Self::RegisterRequestor(_) => REGISTER_REQUESTOR,
64 Self::Message(_) => MESSAGE,
65 Self::BatchMessage(_) => BATCH_MESSAGE,
66 Self::Error(_) => ERROR,
67 Self::Ok => OK,
68 }
69 }
70
71 pub fn get_topic(&self) -> Option<&TopicName> {
72 match self {
73 Self::RegisterPublisher(p) => Some(&p.topic),
74 Self::RegisterSubscriber(s) => Some(&s.topic),
75 Self::RegisterReplier(s) => Some(&s.topic),
76 Self::RegisterRequestor(c) => Some(&c.topic),
77 Self::Message(_) => None,
78 Self::BatchMessage(_) => None,
79 Self::Error(_) => None,
80 Self::Ok => None,
81 }
82 }
83
84 pub fn write_to_bytes(self, dst: &mut BytesMut) -> Result<()> {
85 match self {
86 Frame::RegisterPublisher(payload) => bincode::serialize_into(dst.writer(), &payload)
87 .map_err(ProtocolError::SerdeError)?,
88 Frame::RegisterSubscriber(payload) => bincode::serialize_into(dst.writer(), &payload)
89 .map_err(ProtocolError::SerdeError)?,
90 Frame::RegisterReplier(payload) => bincode::serialize_into(dst.writer(), &payload)
91 .map_err(ProtocolError::SerdeError)?,
92 Frame::RegisterRequestor(payload) => bincode::serialize_into(dst.writer(), &payload)
93 .map_err(ProtocolError::SerdeError)?,
94 Frame::Message(payload) => bincode::serialize_into(dst.writer(), &payload)
95 .map_err(ProtocolError::SerdeError)?,
96 Frame::Error(payload) => bincode::serialize_into(dst.writer(), &payload)
97 .map_err(ProtocolError::SerdeError)?,
98 Frame::BatchMessage(payload) => bincode::serialize_into(dst.writer(), &payload)
99 .map_err(ProtocolError::SerdeError)?,
100 Frame::Ok => (),
101 }
102
103 Ok(())
104 }
105
106 pub fn retention_policy(&self) -> Option<u64> {
107 match self {
108 Self::RegisterPublisher(payload) => Some(payload.retention_policy),
109 Self::RegisterSubscriber(payload) => Some(payload.retention_policy),
110 _ => None,
111 }
112 }
113
114 pub fn batch_size(&self) -> Option<u32> {
115 match self {
116 Self::Message(_) => Some(1),
117 Self::BatchMessage(payload) => Some(payload.size),
118 _ => None,
119 }
120 }
121
122 pub fn message(&self) -> Option<&[u8]> {
123 match self {
124 Self::Message(payload) => Some(&payload.message),
125 Self::BatchMessage(payload) => Some(&payload.message),
126 _ => None,
127 }
128 }
129
130 pub fn unwrap_message(self) -> MessagePayload {
131 match self {
132 Self::Message(p) => p,
133 _ => panic!("Attempted to unwrap non-Message Frame variant"),
134 }
135 }
136}
137
138impl TryFrom<(u8, BytesMut)> for Frame {
139 type Error = SeliumError;
140
141 fn try_from(
142 (message_type, bytes): (u8, BytesMut),
143 ) -> Result<Self, <Frame as TryFrom<(u8, BytesMut)>>::Error> {
144 let frame = match message_type {
145 REGISTER_PUBLISHER => Frame::RegisterPublisher(
146 bincode::deserialize(&bytes).map_err(ProtocolError::SerdeError)?,
147 ),
148 REGISTER_SUBSCRIBER => Frame::RegisterSubscriber(
149 bincode::deserialize(&bytes).map_err(ProtocolError::SerdeError)?,
150 ),
151 REGISTER_REPLIER => Frame::RegisterReplier(
152 bincode::deserialize(&bytes).map_err(ProtocolError::SerdeError)?,
153 ),
154 REGISTER_REQUESTOR => Frame::RegisterRequestor(
155 bincode::deserialize(&bytes).map_err(ProtocolError::SerdeError)?,
156 ),
157 MESSAGE => {
158 Frame::Message(bincode::deserialize(&bytes).map_err(ProtocolError::SerdeError)?)
159 }
160 BATCH_MESSAGE => Frame::BatchMessage(
161 bincode::deserialize(&bytes).map_err(ProtocolError::SerdeError)?,
162 ),
163 ERROR => Frame::Error(bincode::deserialize(&bytes).map_err(ProtocolError::SerdeError)?),
164 OK => Frame::Ok,
165 _type => return Err(ProtocolError::UnknownMessageType(_type))?,
166 };
167
168 Ok(frame)
169 }
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
173pub struct PublisherPayload {
174 pub topic: TopicName,
175 pub retention_policy: u64,
176 pub operations: Vec<Operation>,
177}
178
179#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
180pub struct SubscriberPayload {
181 pub topic: TopicName,
182 pub retention_policy: u64,
183 pub operations: Vec<Operation>,
184 pub offset: Offset,
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
188pub struct ReplierPayload {
189 pub topic: TopicName,
190}
191
192#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
193pub struct RequestorPayload {
194 pub topic: TopicName,
195}
196
197#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
198pub struct MessagePayload {
199 pub headers: Headers,
200 pub message: Bytes,
201}
202
203#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
204pub struct BatchPayload {
205 pub message: Bytes,
206 pub size: u32,
207}
208
209#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
210pub struct ErrorPayload {
211 pub code: u32,
212 pub message: Bytes,
213}