1use alloc::vec::Vec;
7
8use zerodds_cdr::{BufferReader, BufferWriter, Endianness};
9
10use crate::cancel_request::CancelRequest;
11use crate::close_connection::CloseConnection;
12use crate::error::{GiopError, GiopResult};
13use crate::flags::Flags;
14use crate::fragment::Fragment;
15use crate::header::{HEADER_SIZE, MessageHeader};
16use crate::locate_reply::LocateReply;
17use crate::locate_request::LocateRequest;
18use crate::message_error::MessageError;
19use crate::message_type::MessageType;
20use crate::reply::Reply;
21use crate::request::Request;
22use crate::version::Version;
23
24#[derive(Debug, Clone, PartialEq, Eq)]
26pub enum Message {
27 Request(Request),
29 Reply(Reply),
31 CancelRequest(CancelRequest),
33 LocateRequest(LocateRequest),
35 LocateReply(LocateReply),
37 CloseConnection(CloseConnection),
39 MessageError(MessageError),
41 Fragment(Fragment),
43}
44
45impl Message {
46 #[must_use]
48 pub const fn message_type(&self) -> MessageType {
49 match self {
50 Self::Request(_) => MessageType::Request,
51 Self::Reply(_) => MessageType::Reply,
52 Self::CancelRequest(_) => MessageType::CancelRequest,
53 Self::LocateRequest(_) => MessageType::LocateRequest,
54 Self::LocateReply(_) => MessageType::LocateReply,
55 Self::CloseConnection(_) => MessageType::CloseConnection,
56 Self::MessageError(_) => MessageType::MessageError,
57 Self::Fragment(_) => MessageType::Fragment,
58 }
59 }
60}
61
62pub fn encode_message(
71 version: Version,
72 endianness: Endianness,
73 more_fragments: bool,
74 msg: &Message,
75) -> GiopResult<Vec<u8>> {
76 if more_fragments && !version.supports_fragments() {
77 return Err(GiopError::FragmentNotSupported {
78 major: version.major,
79 minor: version.minor,
80 });
81 }
82 let mut body_writer = BufferWriter::new(endianness);
84 encode_body(version, msg, &mut body_writer)?;
85 let body = body_writer.into_bytes();
86 let body_size =
87 u32::try_from(body.len()).map_err(|_| GiopError::Malformed("body exceeds u32".into()))?;
88
89 let mut flags = Flags::from_endianness(endianness);
90 flags = flags.with_fragment(more_fragments);
91 let header = MessageHeader::new(version, flags, msg.message_type(), body_size);
92
93 let mut out = BufferWriter::with_capacity(endianness, HEADER_SIZE + body.len());
94 header.encode(&mut out)?;
95 out.write_bytes(&body)?;
96 Ok(out.into_bytes())
97}
98
99fn encode_body(version: Version, msg: &Message, w: &mut BufferWriter) -> GiopResult<()> {
100 match msg {
101 Message::Request(r) => r.encode(version, w),
102 Message::Reply(r) => r.encode(version, w),
103 Message::CancelRequest(c) => c.encode(w),
104 Message::LocateRequest(l) => l.encode(version, w),
105 Message::LocateReply(l) => l.encode(version, w),
106 Message::CloseConnection(_) | Message::MessageError(_) => Ok(()),
107 Message::Fragment(f) => f.encode(version, w),
108 }
109}
110
111pub fn decode_message(bytes: &[u8]) -> GiopResult<(Message, &[u8])> {
117 let (header, body) = MessageHeader::decode(bytes)?;
118 let body_size = header.message_size as usize;
119 if body.len() < body_size {
120 return Err(GiopError::BodyTooLarge {
121 body_size,
122 message_size_field: header.message_size,
123 });
124 }
125 let body_slice = &body[..body_size];
126 let mut r = BufferReader::new(body_slice, header.endianness());
127 let msg = decode_body(header, &mut r)?;
128 Ok((msg, &body[body_size..]))
129}
130
131fn decode_body(header: MessageHeader, r: &mut BufferReader<'_>) -> GiopResult<Message> {
132 let v = header.version;
133 Ok(match header.message_type {
134 MessageType::Request => Message::Request(Request::decode(v, r)?),
135 MessageType::Reply => Message::Reply(Reply::decode(v, r)?),
136 MessageType::CancelRequest => Message::CancelRequest(CancelRequest::decode(r)?),
137 MessageType::LocateRequest => Message::LocateRequest(LocateRequest::decode(v, r)?),
138 MessageType::LocateReply => Message::LocateReply(LocateReply::decode(v, r)?),
139 MessageType::CloseConnection => Message::CloseConnection(CloseConnection),
140 MessageType::MessageError => Message::MessageError(MessageError),
141 MessageType::Fragment => Message::Fragment(Fragment::decode(v, r)?),
142 })
143}
144
#[cfg(test)]
// A panic in a test is the test failing, so unwrap/expect/panic are acceptable here.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::request::ResponseFlags;
    use crate::service_context::ServiceContextList;
    use crate::target_address::TargetAddress;

    /// Builds a small `Request` whose principal field matches the layout for `version`.
    fn sample_request_msg(version: Version) -> Message {
        // The 1.2 request layout has no principal; older layouts carry an empty one.
        let requesting_principal =
            (!version.uses_v1_2_request_layout()).then(alloc::vec::Vec::new);
        Message::Request(Request {
            request_id: 7,
            response_flags: ResponseFlags::SYNC_WITH_TARGET,
            target: TargetAddress::Key(alloc::vec![0xab, 0xcd]),
            operation: "ping".into(),
            requesting_principal,
            service_context: ServiceContextList::default(),
            body: alloc::vec![1, 2, 3, 4, 5, 6, 7, 8],
        })
    }

    /// Encodes `m`, then decodes the wire bytes.
    ///
    /// Returns the encoded bytes, the decoded message, and the number of bytes left
    /// over after the decoded message.
    fn encode_decode(
        version: Version,
        endianness: Endianness,
        more_fragments: bool,
        m: &Message,
    ) -> (alloc::vec::Vec<u8>, Message, usize) {
        let wire = encode_message(version, endianness, more_fragments, m).unwrap();
        let (decoded, rest) = decode_message(&wire).unwrap();
        let leftover = rest.len();
        (wire, decoded, leftover)
    }

    #[test]
    fn round_trip_request_giop_1_0_be() {
        let m = sample_request_msg(Version::V1_0);
        let (_, decoded, leftover) = encode_decode(Version::V1_0, Endianness::Big, false, &m);
        assert_eq!(decoded, m);
        assert_eq!(leftover, 0);
    }

    #[test]
    fn round_trip_request_giop_1_2_le() {
        let m = sample_request_msg(Version::V1_2);
        let (_, decoded, leftover) = encode_decode(Version::V1_2, Endianness::Little, false, &m);
        assert_eq!(decoded, m);
        assert_eq!(leftover, 0);
    }

    #[test]
    fn round_trip_close_connection() {
        let m = Message::CloseConnection(CloseConnection);
        let (wire, decoded, _) = encode_decode(Version::V1_2, Endianness::Big, false, &m);
        // Header only: CloseConnection has no body.
        assert_eq!(wire.len(), 12);
        assert_eq!(decoded, m);
    }

    #[test]
    fn round_trip_message_error() {
        let m = Message::MessageError(MessageError);
        let (wire, decoded, _) = encode_decode(Version::V1_1, Endianness::Little, false, &m);
        // Header only: MessageError has no body.
        assert_eq!(wire.len(), 12);
        assert_eq!(decoded, m);
    }

    #[test]
    fn round_trip_fragment_with_more_bit() {
        let m = Message::Fragment(Fragment {
            header: Some(crate::fragment::FragmentHeader { request_id: 3 }),
            body: alloc::vec![0; 32],
        });
        let (wire, decoded, _) = encode_decode(Version::V1_2, Endianness::Big, true, &m);
        // Byte 6 of the header holds the flags octet.
        assert_eq!(wire[6] & Flags::FRAGMENT_BIT, Flags::FRAGMENT_BIT);
        assert_eq!(decoded, m);
    }

    #[test]
    fn fragment_bit_in_giop_1_0_is_rejected() {
        let m = Message::Fragment(Fragment {
            header: None,
            body: alloc::vec::Vec::new(),
        });
        let outcome = encode_message(Version::V1_0, Endianness::Big, true, &m);
        assert!(matches!(
            outcome.unwrap_err(),
            GiopError::FragmentNotSupported { .. }
        ));
    }

    #[test]
    fn header_message_size_matches_actual_body() {
        let m = sample_request_msg(Version::V1_2);
        let wire = encode_message(Version::V1_2, Endianness::Big, false, &m).unwrap();
        let (parsed, body) = MessageHeader::decode(&wire).unwrap();
        assert_eq!(parsed.message_size as usize, body.len());
    }

    #[test]
    fn round_trip_cancel_request() {
        let m = Message::CancelRequest(CancelRequest { request_id: 99 });
        let (_, decoded, _) = encode_decode(Version::V1_1, Endianness::Big, false, &m);
        assert_eq!(decoded, m);
    }

    #[test]
    fn round_trip_locate_request_giop_1_2() {
        let m = Message::LocateRequest(LocateRequest {
            request_id: 1,
            target: TargetAddress::Key(alloc::vec![0xa, 0xb]),
        });
        let (_, decoded, _) = encode_decode(Version::V1_2, Endianness::Big, false, &m);
        assert_eq!(decoded, m);
    }

    #[test]
    fn round_trip_locate_reply_object_here() {
        let m = Message::LocateReply(LocateReply {
            request_id: 1,
            locate_status: crate::locate_reply::LocateStatusType::ObjectHere,
            body: alloc::vec::Vec::new(),
        });
        let (_, decoded, _) = encode_decode(Version::V1_0, Endianness::Big, false, &m);
        assert_eq!(decoded, m);
    }
}