1use super::*;
2use crate::error::Error;
3
4type Result<T> = std::result::Result<T, util::Error>;
5
// Wire codes for the channel-type byte. Each unordered code is the matching ordered code
// with the high bit (0x80) set.
const CHANNEL_TYPE_RELIABLE: u8 = 0x00;
const CHANNEL_TYPE_PARTIAL_RELIABLE_REXMIT: u8 = 0x01;
const CHANNEL_TYPE_PARTIAL_RELIABLE_TIMED: u8 = 0x02;
const CHANNEL_TYPE_RELIABLE_UNORDERED: u8 = CHANNEL_TYPE_RELIABLE | 0x80;
const CHANNEL_TYPE_PARTIAL_RELIABLE_REXMIT_UNORDERED: u8 =
    CHANNEL_TYPE_PARTIAL_RELIABLE_REXMIT | 0x80;
const CHANNEL_TYPE_PARTIAL_RELIABLE_TIMED_UNORDERED: u8 =
    CHANNEL_TYPE_PARTIAL_RELIABLE_TIMED | 0x80;

/// Number of bytes a channel type occupies on the wire (a single `u8`).
const CHANNEL_TYPE_LEN: usize = std::mem::size_of::<u8>();

/// "Below normal" channel priority value.
pub const CHANNEL_PRIORITY_BELOW_NORMAL: u16 = 128;
/// "Normal" channel priority value (twice "below normal").
pub const CHANNEL_PRIORITY_NORMAL: u16 = 256;
/// "High" channel priority value (twice "normal").
pub const CHANNEL_PRIORITY_HIGH: u16 = 512;
/// "Extra high" channel priority value (twice "high").
pub const CHANNEL_PRIORITY_EXTRA_HIGH: u16 = 1024;
19
/// Channel type carried in the first byte of a `DataChannelOpen` message.
///
/// Each variant maps to one `CHANNEL_TYPE_*` wire code (see the `Marshal` and `Unmarshal`
/// impls). `Reliable` is the default.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ChannelType {
    /// Reliable channel (wire code `0x00`).
    #[default]
    Reliable,
    /// Unordered form of [`ChannelType::Reliable`] (wire code `0x80`).
    ReliableUnordered,
    /// Partially reliable "rexmit" mode (wire code `0x01`).
    PartialReliableRexmit,
    /// Unordered form of [`ChannelType::PartialReliableRexmit`] (wire code `0x81`).
    PartialReliableRexmitUnordered,
    /// Partially reliable "timed" mode (wire code `0x02`).
    PartialReliableTimed,
    /// Unordered form of [`ChannelType::PartialReliableTimed`] (wire code `0x82`).
    PartialReliableTimedUnordered,
}
49
50impl MarshalSize for ChannelType {
51 fn marshal_size(&self) -> usize {
52 CHANNEL_TYPE_LEN
53 }
54}
55
56impl Marshal for ChannelType {
57 fn marshal_to(&self, mut buf: &mut [u8]) -> Result<usize> {
58 let required_len = self.marshal_size();
59 if buf.remaining_mut() < required_len {
60 return Err(Error::UnexpectedEndOfBuffer {
61 expected: required_len,
62 actual: buf.remaining_mut(),
63 }
64 .into());
65 }
66
67 let byte = match self {
68 Self::Reliable => CHANNEL_TYPE_RELIABLE,
69 Self::ReliableUnordered => CHANNEL_TYPE_RELIABLE_UNORDERED,
70 Self::PartialReliableRexmit => CHANNEL_TYPE_PARTIAL_RELIABLE_REXMIT,
71 Self::PartialReliableRexmitUnordered => CHANNEL_TYPE_PARTIAL_RELIABLE_REXMIT_UNORDERED,
72 Self::PartialReliableTimed => CHANNEL_TYPE_PARTIAL_RELIABLE_TIMED,
73 Self::PartialReliableTimedUnordered => CHANNEL_TYPE_PARTIAL_RELIABLE_TIMED_UNORDERED,
74 };
75
76 buf.put_u8(byte);
77
78 Ok(1)
79 }
80}
81
82impl Unmarshal for ChannelType {
83 fn unmarshal<B>(buf: &mut B) -> Result<Self>
84 where
85 Self: Sized,
86 B: Buf,
87 {
88 let required_len = CHANNEL_TYPE_LEN;
89 if buf.remaining() < required_len {
90 return Err(Error::UnexpectedEndOfBuffer {
91 expected: required_len,
92 actual: buf.remaining(),
93 }
94 .into());
95 }
96
97 let b0 = buf.get_u8();
98
99 match b0 {
100 CHANNEL_TYPE_RELIABLE => Ok(Self::Reliable),
101 CHANNEL_TYPE_RELIABLE_UNORDERED => Ok(Self::ReliableUnordered),
102 CHANNEL_TYPE_PARTIAL_RELIABLE_REXMIT => Ok(Self::PartialReliableRexmit),
103 CHANNEL_TYPE_PARTIAL_RELIABLE_REXMIT_UNORDERED => {
104 Ok(Self::PartialReliableRexmitUnordered)
105 }
106 CHANNEL_TYPE_PARTIAL_RELIABLE_TIMED => Ok(Self::PartialReliableTimed),
107 CHANNEL_TYPE_PARTIAL_RELIABLE_TIMED_UNORDERED => {
108 Ok(Self::PartialReliableTimedUnordered)
109 }
110 _ => Err(Error::InvalidChannelType(b0).into()),
111 }
112 }
113}
114
/// Size in bytes of the fixed `DataChannelOpen` header that precedes the label and protocol:
/// channel type (`u8`), priority (`u16`), reliability parameter (`u32`), label length (`u16`)
/// and protocol length (`u16`) — 11 bytes in total.
const CHANNEL_OPEN_HEADER_LEN: usize = std::mem::size_of::<u8>()
    + std::mem::size_of::<u16>()
    + std::mem::size_of::<u32>()
    + 2 * std::mem::size_of::<u16>();
116
117#[derive(Eq, PartialEq, Clone, Debug)]
141pub struct DataChannelOpen {
142 pub channel_type: ChannelType,
143 pub priority: u16,
144 pub reliability_parameter: u32,
145 pub label: Vec<u8>,
146 pub protocol: Vec<u8>,
147}
148
149impl MarshalSize for DataChannelOpen {
150 fn marshal_size(&self) -> usize {
151 let label_len = self.label.len();
152 let protocol_len = self.protocol.len();
153
154 CHANNEL_OPEN_HEADER_LEN + label_len + protocol_len
155 }
156}
157
158impl Marshal for DataChannelOpen {
159 fn marshal_to(&self, mut buf: &mut [u8]) -> Result<usize> {
160 let required_len = self.marshal_size();
161 if buf.remaining_mut() < required_len {
162 return Err(Error::UnexpectedEndOfBuffer {
163 expected: required_len,
164 actual: buf.remaining_mut(),
165 }
166 .into());
167 }
168
169 let n = self.channel_type.marshal_to(buf)?;
170 buf = &mut buf[n..];
171 buf.put_u16(self.priority);
172 buf.put_u32(self.reliability_parameter);
173 buf.put_u16(self.label.len() as u16);
174 buf.put_u16(self.protocol.len() as u16);
175 buf.put_slice(self.label.as_slice());
176 buf.put_slice(self.protocol.as_slice());
177 Ok(self.marshal_size())
178 }
179}
180
181impl Unmarshal for DataChannelOpen {
182 fn unmarshal<B>(buf: &mut B) -> Result<Self>
183 where
184 B: Buf,
185 {
186 let required_len = CHANNEL_OPEN_HEADER_LEN;
187 if buf.remaining() < required_len {
188 return Err(Error::UnexpectedEndOfBuffer {
189 expected: required_len,
190 actual: buf.remaining(),
191 }
192 .into());
193 }
194
195 let channel_type = ChannelType::unmarshal(buf)?;
196 let priority = buf.get_u16();
197 let reliability_parameter = buf.get_u32();
198 let label_len = buf.get_u16() as usize;
199 let protocol_len = buf.get_u16() as usize;
200
201 let required_len = label_len + protocol_len;
202 if buf.remaining() < required_len {
203 return Err(Error::UnexpectedEndOfBuffer {
204 expected: required_len,
205 actual: buf.remaining(),
206 }
207 .into());
208 }
209
210 let mut label = vec![0; label_len];
211 let mut protocol = vec![0; protocol_len];
212
213 buf.copy_to_slice(&mut label[..]);
214 buf.copy_to_slice(&mut protocol[..]);
215
216 Ok(Self {
217 channel_type,
218 priority,
219 reliability_parameter,
220 label,
221 protocol,
222 })
223 }
224}
225
#[cfg(test)]
mod tests {
    use bytes::{Bytes, BytesMut};

    use super::*;

    // Asserts that `$result` is an `Err` wrapping a crate `Error` that matches `$pattern`.
    // The same pattern is printed on failure, so the diagnostic always names the error the
    // test actually expects (previously every test printed `InvalidMessageType(1)`).
    macro_rules! assert_err_matches {
        ($result:expr, $pattern:pat) => {
            match $result {
                Ok(_) => panic!("expected Error, but got Ok"),
                Err(err) => assert!(
                    matches!(err.downcast_ref::<Error>(), Some($pattern)),
                    "unexpected err {:?}, want {}",
                    err,
                    stringify!($pattern),
                ),
            }
        };
    }

    // Every channel type, used by the round-trip tests.
    const ALL_CHANNEL_TYPES: [ChannelType; 6] = [
        ChannelType::Reliable,
        ChannelType::ReliableUnordered,
        ChannelType::PartialReliableRexmit,
        ChannelType::PartialReliableRexmitUnordered,
        ChannelType::PartialReliableTimed,
        ChannelType::PartialReliableTimedUnordered,
    ];

    // Encoding of `sample_channel_open()`.
    static MARSHALED_BYTES: [u8; 24] = [
        0x00, // channel type: Reliable
        0x0f, 0x35, // priority: 3893
        0x00, 0xff, 0x0f, 0x35, // reliability parameter: 16715573
        0x00, 0x05, // label length: 5
        0x00, 0x08, // protocol length: 8
        0x6c, 0x61, 0x62, 0x65, 0x6c, // "label"
        0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, // "protocol"
    ];

    // The message encoded by `MARSHALED_BYTES`.
    fn sample_channel_open() -> DataChannelOpen {
        DataChannelOpen {
            channel_type: ChannelType::Reliable,
            priority: 3893,
            reliability_parameter: 16715573,
            label: b"label".to_vec(),
            protocol: b"protocol".to_vec(),
        }
    }

    #[test]
    fn test_channel_type_unmarshal_success() -> Result<()> {
        let mut bytes = Bytes::from_static(&[0x00]);
        let channel_type = ChannelType::unmarshal(&mut bytes)?;

        assert_eq!(channel_type, ChannelType::Reliable);
        Ok(())
    }

    #[test]
    fn test_channel_type_unmarshal_invalid() {
        let mut bytes = Bytes::from_static(&[0x11]);
        assert_err_matches!(
            ChannelType::unmarshal(&mut bytes),
            Error::InvalidChannelType(0x11)
        );
    }

    #[test]
    fn test_channel_type_unmarshal_unexpected_end_of_buffer() {
        let mut bytes = Bytes::from_static(&[]);
        assert_err_matches!(
            ChannelType::unmarshal(&mut bytes),
            Error::UnexpectedEndOfBuffer {
                expected: 1,
                actual: 0,
            }
        );
    }

    #[test]
    fn test_channel_type_marshal_size() {
        assert_eq!(ChannelType::Reliable.marshal_size(), 1);
    }

    #[test]
    fn test_channel_type_marshal() -> Result<()> {
        let mut buf = BytesMut::with_capacity(1);
        buf.resize(1, 0u8);
        let channel_type = ChannelType::Reliable;
        let bytes_written = channel_type.marshal_to(&mut buf)?;
        assert_eq!(bytes_written, channel_type.marshal_size());

        let bytes = buf.freeze();
        assert_eq!(&bytes[..], &[0x00]);
        Ok(())
    }

    #[test]
    fn test_channel_type_round_trip() -> Result<()> {
        for channel_type in ALL_CHANNEL_TYPES {
            let mut buf = BytesMut::new();
            buf.resize(channel_type.marshal_size(), 0u8);
            channel_type.marshal_to(&mut buf)?;

            let decoded = ChannelType::unmarshal(&mut buf.freeze())?;
            assert_eq!(decoded, channel_type);
        }
        Ok(())
    }

    #[test]
    fn test_channel_open_unmarshal_success() -> Result<()> {
        let mut bytes = Bytes::from_static(&MARSHALED_BYTES);

        let channel_open = DataChannelOpen::unmarshal(&mut bytes)?;

        assert_eq!(channel_open.channel_type, ChannelType::Reliable);
        assert_eq!(channel_open.priority, 3893);
        assert_eq!(channel_open.reliability_parameter, 16715573);
        assert_eq!(channel_open.label, b"label");
        assert_eq!(channel_open.protocol, b"protocol");
        Ok(())
    }

    #[test]
    fn test_channel_open_unmarshal_invalid_channel_type() {
        let mut bytes = Bytes::from_static(&[
            0x11, 0x0f, 0x35, 0x00, 0xff, 0x0f, 0x35, 0x00, 0x05, 0x00, 0x08,
        ]);
        assert_err_matches!(
            DataChannelOpen::unmarshal(&mut bytes),
            Error::InvalidChannelType(0x11)
        );
    }

    #[test]
    fn test_channel_open_unmarshal_unexpected_end_of_buffer() {
        let mut bytes = Bytes::from_static(&[0x00; 5]);
        assert_err_matches!(
            DataChannelOpen::unmarshal(&mut bytes),
            Error::UnexpectedEndOfBuffer {
                expected: 11,
                actual: 5,
            }
        );
    }

    #[test]
    fn test_channel_open_unmarshal_unexpected_length_mismatch() {
        // Valid header announcing a 5-byte label and 8-byte protocol, but no tail bytes.
        let mut bytes = Bytes::from_static(&[
            0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x08,
        ]);
        assert_err_matches!(
            DataChannelOpen::unmarshal(&mut bytes),
            Error::UnexpectedEndOfBuffer {
                expected: 13,
                actual: 0,
            }
        );
    }

    #[test]
    fn test_channel_open_marshal_size() {
        assert_eq!(sample_channel_open().marshal_size(), 11 + 5 + 8);
    }

    #[test]
    fn test_channel_open_marshal() -> Result<()> {
        let channel_open = sample_channel_open();

        let mut buf = BytesMut::with_capacity(11 + 5 + 8);
        buf.resize(11 + 5 + 8, 0u8);
        let bytes_written = channel_open.marshal_to(&mut buf)?;
        let bytes = buf.freeze();

        assert_eq!(bytes_written, channel_open.marshal_size());
        assert_eq!(&bytes[..], &MARSHALED_BYTES);
        Ok(())
    }

    #[test]
    fn test_channel_open_marshal_buffer_too_small() {
        let channel_open = sample_channel_open();
        let mut buf = [0u8; 10];
        assert_err_matches!(
            channel_open.marshal_to(&mut buf),
            Error::UnexpectedEndOfBuffer {
                expected: 24,
                actual: 10,
            }
        );
    }

    #[test]
    fn test_channel_open_round_trip() -> Result<()> {
        for channel_type in ALL_CHANNEL_TYPES {
            let original = DataChannelOpen {
                channel_type,
                ..sample_channel_open()
            };
            let mut buf = BytesMut::new();
            buf.resize(original.marshal_size(), 0u8);
            original.marshal_to(&mut buf)?;

            let decoded = DataChannelOpen::unmarshal(&mut buf.freeze())?;
            assert_eq!(decoded, original);
        }
        Ok(())
    }
}