turn_server/codec/message/
mod.rs

1pub mod attributes;
2pub mod methods;
3
4use super::{
5    Attributes, Error,
6    crypto::{Password, fingerprint, hmac_sha1},
7    message::{
8        attributes::{Attribute, AttributeType, MessageIntegrity, MessageIntegritySha256},
9        methods::Method,
10    },
11};
12
13use bytes::{BufMut, BytesMut};
14
/// Fixed magic number carried in bytes 4..8 of every message header.
///
/// The encoder writes it into the header and `Message::decode` rejects any
/// input whose bytes 4..8 differ from it.
const MAGIC_NUMBER: u32 = 0x2112_A442;
16
17pub struct MessageEncoder<'a> {
18    transaction_id: &'a [u8],
19    bytes: &'a mut BytesMut,
20}
21
22impl<'a> MessageEncoder<'a> {
23    pub fn new(method: Method, transaction_id: &'a [u8; 12], bytes: &'a mut BytesMut) -> Self {
24        bytes.clear();
25        bytes.put_u16(method.into());
26        bytes.put_u16(0);
27        bytes.put_u32(MAGIC_NUMBER);
28        bytes.put(transaction_id.as_slice());
29
30        Self {
31            bytes,
32            transaction_id,
33        }
34    }
35
36    /// rely on old message to create new message.
37    ///
38    /// # Test
39    ///
40    /// ```
41    /// use bytes::BytesMut;
42    /// use std::convert::TryFrom;
43    /// use turn_server::codec::message::methods::*;
44    /// use turn_server::codec::message::*;
45    /// use turn_server::codec::*;
46    ///
47    /// let buffer = [
48    ///     0x00u8, 0x01, 0x00, 0x00, 0x21, 0x12, 0xa4, 0x42, 0x72, 0x6d, 0x49,
49    ///     0x42, 0x72, 0x52, 0x64, 0x48, 0x57, 0x62, 0x4b, 0x2b,
50    /// ];
51    ///
52    /// let mut attributes = Attributes::default();
53    /// let mut buf = BytesMut::new();
54    /// let old = Message::decode(&buffer[..], &mut attributes).unwrap();
55    /// MessageEncoder::extend(Method::Binding(MethodType::Request), &old, &mut buf);
56    ///
57    /// assert_eq!(&buf[..], &buffer[..]);
58    /// ```
59    pub fn extend(method: Method, reader: &Message<'a>, bytes: &'a mut BytesMut) -> Self {
60        let transaction_id = reader.transaction_id();
61
62        bytes.clear();
63        bytes.put_u16(method.into());
64        bytes.put_u16(0);
65        bytes.put_u32(MAGIC_NUMBER);
66        bytes.put(transaction_id);
67        Self {
68            bytes,
69            transaction_id,
70        }
71    }
72
73    /// append attribute.
74    ///
75    /// append attribute to message attribute list.
76    ///
77    /// # Test
78    ///
79    /// ```
80    /// use bytes::BytesMut;
81    /// use std::convert::TryFrom;
82    /// use turn_server::codec::message::attributes::*;
83    /// use turn_server::codec::message::methods::*;
84    /// use turn_server::codec::message::*;
85    /// use turn_server::codec::*;
86    ///
87    /// let buffer = [
88    ///     0x00u8, 0x01, 0x00, 0x00, 0x21, 0x12, 0xa4, 0x42, 0x72, 0x6d, 0x49,
89    ///     0x42, 0x72, 0x52, 0x64, 0x48, 0x57, 0x62, 0x4b, 0x2b,
90    /// ];
91    ///
92    /// let new_buf = [
93    ///     0x00u8, 0x01, 0x00, 0x00, 0x21, 0x12, 0xa4, 0x42, 0x72, 0x6d, 0x49,
94    ///     0x42, 0x72, 0x52, 0x64, 0x48, 0x57, 0x62, 0x4b, 0x2b, 0x00, 0x06, 0x00,
95    ///     0x05, 0x70, 0x61, 0x6e, 0x64, 0x61, 0x00, 0x00, 0x00,
96    /// ];
97    ///
98    /// let mut buf = BytesMut::new();
99    /// let mut attributes = Attributes::default();
100    /// let old = Message::decode(&buffer[..], &mut attributes).unwrap();
101    /// let mut message =
102    ///     MessageEncoder::extend(Method::Binding(MethodType::Request), &old, &mut buf);
103    ///
104    /// message.append::<UserName>("panda");
105    ///
106    /// assert_eq!(&new_buf[..], &buf[..]);
107    /// ```
108    pub fn append<'c, T: Attribute<'c>>(&'c mut self, value: T::Item) {
109        self.bytes.put_u16(T::TYPE as u16);
110
111        // record the current position,
112        // and then advance the internal cursor 2 bytes,
113        // here is to reserve the position.
114        let os = self.bytes.len();
115        unsafe { self.bytes.advance_mut(2) }
116        T::serialize(value, self.bytes, self.transaction_id);
117
118        // compute write index,
119        // back to source index write size.
120        let size = self.bytes.len() - os - 2;
121        let size_buf = (size as u16).to_be_bytes();
122        self.bytes[os] = size_buf[0];
123        self.bytes[os + 1] = size_buf[1];
124
125        // if you need to padding,
126        // padding in the zero bytes.
127        let psize = alignment_32(size);
128        if psize > 0 {
129            self.bytes.put(&[0u8; 10][0..psize]);
130        }
131    }
132
133    /// try decoder bytes as message.
134    ///
135    /// # Test
136    ///
137    /// ```
138    /// use bytes::BytesMut;
139    /// use std::convert::TryFrom;
140    /// use turn_server::codec::message::methods::*;
141    /// use turn_server::codec::message::*;
142    /// use turn_server::codec::*;
143    ///
144    /// let buffer = [
145    ///     0x00u8, 0x01, 0x00, 0x00, 0x21, 0x12, 0xa4, 0x42, 0x72, 0x6d, 0x49,
146    ///     0x42, 0x72, 0x52, 0x64, 0x48, 0x57, 0x62, 0x4b, 0x2b,
147    /// ];
148    ///
149    /// let result = [
150    ///     0, 1, 0, 32, 33, 18, 164, 66, 114, 109, 73, 66, 114, 82, 100, 72, 87,
151    ///     98, 75, 43, 0, 8, 0, 20, 69, 14, 110, 68, 82, 30, 232, 222, 44, 240,
152    ///     250, 182, 156, 92, 25, 23, 152, 198, 217, 222, 128, 40, 0, 4, 74, 165,
153    ///     171, 86,
154    /// ];
155    ///
156    /// let mut attributes = Attributes::default();
157    /// let mut buf = BytesMut::with_capacity(1280);
158    /// let old = Message::decode(&buffer[..], &mut attributes).unwrap();
159    /// let mut message =
160    ///     MessageEncoder::extend(Method::Binding(MethodType::Request), &old, &mut buf);
161    ///
162    /// message
163    ///     .flush(Some(&turn_server::codec::crypto::generate_password(
164    ///         "panda",
165    ///         "panda",
166    ///         "raspberry",
167    ///         turn_server::codec::message::attributes::PasswordAlgorithm::Md5,
168    ///     )))
169    ///     .unwrap();
170    ///
171    /// assert_eq!(&buf[..], &result);
172    /// ```
173    pub fn flush(&mut self, password: Option<&Password>) -> Result<(), Error> {
174        // write attribute list size.
175        self.set_len(self.bytes.len() - 20);
176
177        // if need message integrity?
178        if let Some(it) = password {
179            self.verify(it)?;
180        }
181
182        Ok(())
183    }
184
185    /// append MessageIntegrity attribute.
186    ///
187    /// add the `MessageIntegrity` attribute to the stun message
188    /// and serialize the message into a buffer.
189    ///
190    /// # Test
191    ///
192    /// ```
193    /// use bytes::BytesMut;
194    /// use std::convert::TryFrom;
195    /// use turn_server::codec::message::methods::*;
196    /// use turn_server::codec::message::*;
197    /// use turn_server::codec::*;
198    ///
199    /// let buffer = [
200    ///     0x00u8, 0x01, 0x00, 0x00, 0x21, 0x12, 0xa4, 0x42, 0x72, 0x6d, 0x49,
201    ///     0x42, 0x72, 0x52, 0x64, 0x48, 0x57, 0x62, 0x4b, 0x2b,
202    /// ];
203    ///
204    /// let result = [
205    ///     0, 1, 0, 32, 33, 18, 164, 66, 114, 109, 73, 66, 114, 82, 100, 72, 87,
206    ///     98, 75, 43, 0, 8, 0, 20, 69, 14, 110, 68, 82, 30, 232, 222, 44, 240,
207    ///     250, 182, 156, 92, 25, 23, 152, 198, 217, 222, 128, 40, 0, 4, 74, 165,
208    ///     171, 86,
209    /// ];
210    ///
211    /// let mut attributes = Attributes::default();
212    /// let mut buf = BytesMut::from(&buffer[..]);
213    /// let old = Message::decode(&buffer[..], &mut attributes).unwrap();
214    /// let mut message =
215    ///     MessageEncoder::extend(Method::Binding(MethodType::Request), &old, &mut buf);
216    ///
217    /// message
218    ///     .flush(Some(&turn_server::codec::crypto::generate_password(
219    ///         "panda",
220    ///         "panda",
221    ///         "raspberry",
222    ///         turn_server::codec::message::attributes::PasswordAlgorithm::Md5,
223    ///     )))
224    ///     .unwrap();
225    ///
226    /// assert_eq!(&buf[..], &result);
227    /// ```
228    fn verify(&mut self, passwrd: &Password) -> Result<(), Error> {
229        assert!(self.bytes.len() >= 20);
230        let len = self.bytes.len();
231
232        // compute new size,
233        // new size include the MessageIntegrity attribute size.
234        self.set_len(len + 4);
235
236        // write MessageIntegrity attribute.
237        {
238            let hmac = hmac_sha1(passwrd, &[self.bytes]);
239            self.bytes.put_u16(match passwrd {
240                Password::Md5(_) => AttributeType::MessageIntegrity as u16,
241                Password::Sha256(_) => AttributeType::MessageIntegritySha256 as u16,
242            });
243
244            self.bytes.put_u16(20);
245            self.bytes.put(hmac.as_slice());
246        }
247
248        // compute new size,
249        // new size include the Fingerprint attribute size.
250        self.set_len(len + 4 + 8);
251
252        // CRC Fingerprint
253        let fingerprint = fingerprint(self.bytes);
254        self.bytes.put_u16(AttributeType::Fingerprint as u16);
255        self.bytes.put_u16(4);
256        self.bytes.put_u32(fingerprint);
257
258        Ok(())
259    }
260
261    // set stun message header size.
262    fn set_len(&mut self, len: usize) {
263        self.bytes[2..4].copy_from_slice((len as u16).to_be_bytes().as_slice());
264    }
265}
266
267pub struct Message<'a> {
268    /// message method.
269    method: Method,
270    /// message source bytes.
271    bytes: &'a [u8],
272    /// message payload size.
273    size: u16,
274    // message attribute list.
275    attributes: &'a Attributes,
276}
277
278impl<'a> Message<'a> {
279    /// message method.
280    #[inline]
281    pub fn method(&self) -> Method {
282        self.method
283    }
284
285    /// message transaction id.
286    #[inline]
287    pub fn transaction_id(&self) -> &'a [u8] {
288        &self.bytes[8..20]
289    }
290
291    /// get attribute.
292    ///
293    /// get attribute from message attribute list.
294    ///
295    /// # Test
296    ///
297    /// ```
298    /// use std::convert::TryFrom;
299    /// use turn_server::codec::message::attributes::*;
300    /// use turn_server::codec::message::methods::*;
301    /// use turn_server::codec::message::*;
302    /// use turn_server::codec::*;
303    ///
304    /// let buffer = [
305    ///     0x00u8, 0x01, 0x00, 0x00, 0x21, 0x12, 0xa4, 0x42, 0x72, 0x6d, 0x49,
306    ///     0x42, 0x72, 0x52, 0x64, 0x48, 0x57, 0x62, 0x4b, 0x2b,
307    /// ];
308    ///
309    /// let mut attributes = Attributes::default();
310    /// let message = Message::decode(&buffer[..], &mut attributes).unwrap();
311    ///
312    /// assert!(message.get::<UserName>().is_none());
313    /// ```
314    pub fn get<T: Attribute<'a>>(&self) -> Option<T::Item> {
315        let range = self.attributes.get(&T::TYPE)?;
316        T::deserialize(&self.bytes[range], self.transaction_id()).ok()
317    }
318
319    /// get attribute for type.
320    ///
321    /// get attribute from message attribute list.
322    ///
323    /// # Test
324    ///
325    /// ```
326    /// use std::convert::TryFrom;
327    /// use turn_server::codec::message::attributes::*;
328    /// use turn_server::codec::message::methods::*;
329    /// use turn_server::codec::message::*;
330    /// use turn_server::codec::*;
331    ///
332    /// let buffer = [
333    ///     0x00u8, 0x01, 0x00, 0x00, 0x21, 0x12, 0xa4, 0x42, 0x72, 0x6d, 0x49,
334    ///     0x42, 0x72, 0x52, 0x64, 0x48, 0x57, 0x62, 0x4b, 0x2b,
335    /// ];
336    ///
337    /// let mut attributes = Attributes::default();
338    /// let message = Message::decode(&buffer[..], &mut attributes).unwrap();
339    ///
340    /// assert!(message.get_for_type(AttributeType::UserName).is_none());
341    /// ```
342    pub fn get_for_type(&self, attr_type: AttributeType) -> Option<&'a [u8]> {
343        let range = self.attributes.get(&attr_type)?;
344        Some(&self.bytes[range])
345    }
346
347    /// Gets all the values of an attribute from a list.
348    ///
349    /// Normally a stun message can have multiple attributes with the same name,
350    /// and this function will all the values of the current attribute.
351    ///
352    /// # Test
353    ///
354    /// ```
355    /// use std::convert::TryFrom;
356    /// use turn_server::codec::message::attributes::*;
357    /// use turn_server::codec::message::methods::*;
358    /// use turn_server::codec::message::*;
359    /// use turn_server::codec::*;
360    ///
361    /// let buffer = [
362    ///     0x00u8, 0x01, 0x00, 0x00, 0x21, 0x12, 0xa4, 0x42, 0x72, 0x6d, 0x49,
363    ///     0x42, 0x72, 0x52, 0x64, 0x48, 0x57, 0x62, 0x4b, 0x2b,
364    /// ];
365    ///
366    /// let mut attributes = Attributes::default();
367    /// let message = Message::decode(&buffer[..], &mut attributes).unwrap();
368    ///
369    /// assert_eq!(message.get_all::<UserName>().next(), None);
370    /// ```
371    pub fn get_all<T: Attribute<'a>>(&self) -> impl Iterator<Item = T::Item> {
372        self.attributes
373            .get_all(&T::TYPE)
374            .map(|it| T::deserialize(&self.bytes[it.clone()], self.transaction_id()))
375            .filter(|it| it.is_ok())
376            .flatten()
377    }
378
379    /// check MessageRefIntegrity attribute.
380    ///
381    /// return whether the `MessageRefIntegrity` attribute contained in the message
382    /// can pass the check.
383    ///
384    ///
385    /// # Test
386    ///
387    /// ```
388    /// use std::convert::TryFrom;
389    /// use turn_server::codec::message::methods::*;
390    /// use turn_server::codec::message::*;
391    /// use turn_server::codec::*;
392    ///
393    /// let buffer = [
394    ///     0x00u8, 0x03, 0x00, 0x50, 0x21, 0x12, 0xa4, 0x42, 0x64, 0x4f, 0x5a,
395    ///     0x78, 0x6a, 0x56, 0x33, 0x62, 0x4b, 0x52, 0x33, 0x31, 0x00, 0x19, 0x00,
396    ///     0x04, 0x11, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x05, 0x70, 0x61, 0x6e,
397    ///     0x64, 0x61, 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x09, 0x72, 0x61, 0x73,
398    ///     0x70, 0x62, 0x65, 0x72, 0x72, 0x79, 0x00, 0x00, 0x00, 0x00, 0x15, 0x00,
399    ///     0x10, 0x31, 0x63, 0x31, 0x33, 0x64, 0x32, 0x62, 0x32, 0x34, 0x35, 0x62,
400    ///     0x33, 0x61, 0x37, 0x33, 0x34, 0x00, 0x08, 0x00, 0x14, 0xd6, 0x78, 0x26,
401    ///     0x99, 0x0e, 0x15, 0x56, 0x15, 0xe5, 0xf4, 0x24, 0x74, 0xe2, 0x3c, 0x26,
402    ///     0xc5, 0xb1, 0x03, 0xb2, 0x6d,
403    /// ];
404    ///
405    /// let mut attributes = Attributes::default();
406    /// let message = Message::decode(&buffer[..], &mut attributes).unwrap();
407    /// let result = message
408    ///     .verify(&turn_server::codec::crypto::generate_password(
409    ///         "panda",
410    ///         "panda",
411    ///         "raspberry",
412    ///         turn_server::codec::message::attributes::PasswordAlgorithm::Md5,
413    ///     ))
414    ///     .is_ok();
415    ///
416    /// assert!(result);
417    /// ```
418    pub fn verify(&self, password: &Password) -> Result<(), Error> {
419        if self.bytes.is_empty() || self.size < 20 {
420            return Err(Error::InvalidInput);
421        }
422
423        // unwrap MessageIntegrity attribute,
424        // an error occurs if not found.
425        let integrity = match password {
426            Password::Md5(_) => self.get::<MessageIntegrity>(),
427            Password::Sha256(_) => self.get::<MessageIntegritySha256>(),
428        }
429        .ok_or(Error::NotFoundIntegrity)?;
430
431        // create multiple submit.
432        let size_buf = (self.size + 4).to_be_bytes();
433        let body = [
434            &self.bytes[0..2],
435            &size_buf,
436            &self.bytes[4..self.size as usize],
437        ];
438
439        // digest the message buffer.
440        {
441            // Compare local and original attribute.
442            if integrity != hmac_sha1(password, &body).as_slice() {
443                return Err(Error::IntegrityFailed);
444            }
445        }
446
447        Ok(())
448    }
449
450    /// # Test
451    ///
452    /// ```
453    /// use std::convert::TryFrom;
454    /// use turn_server::codec::message::attributes::*;
455    /// use turn_server::codec::message::methods::*;
456    /// use turn_server::codec::message::*;
457    /// use turn_server::codec::*;
458    ///
459    /// let buffer: [u8; 20] = [
460    ///     0x00, 0x01, 0x00, 0x00, 0x21, 0x12, 0xa4, 0x42, 0x72, 0x6d, 0x49, 0x42,
461    ///     0x72, 0x52, 0x64, 0x48, 0x57, 0x62, 0x4b, 0x2b,
462    /// ];
463    ///
464    /// let mut attributes = Attributes::default();
465    /// let message = Message::decode(&buffer[..], &mut attributes).unwrap();
466    ///
467    /// assert_eq!(
468    ///     message.method(),
469    ///     Method::Binding(MethodType::Request)
470    /// );
471    ///
472    /// assert!(message.get::<UserName>().is_none());
473    /// ```
474    pub fn decode(bytes: &'a [u8], attributes: &'a mut Attributes) -> Result<Self, Error> {
475        let len = bytes.len();
476
477        // There must be at least a complete header.
478        if len < 20 {
479            return Err(Error::InvalidInput);
480        }
481
482        let method = Method::try_from(u16::from_be_bytes(bytes[..2].try_into()?))?;
483
484        // First check whether the message length is valid. Here, the length needs
485        // to add the 20 bytes of the header, because the length field here does
486        // not include the header length.
487        {
488            let size = u16::from_be_bytes(bytes[2..4].try_into()?) as usize + 20;
489            if len < size {
490                return Err(Error::InvalidInput);
491            }
492        }
493
494        // Check whether the magic number is the same.
495        if bytes[4..8] != MAGIC_NUMBER.to_be_bytes() {
496            return Err(Error::NotFoundMagicNumber);
497        }
498
499        let mut find_integrity = false;
500        let mut content_len = 0;
501        let mut offset = 20;
502
503        loop {
504            // if the buf length is not long enough to continue,
505            // jump out of the loop.
506            if len - offset < 4 {
507                break;
508            }
509
510            // get attribute type
511            let key = u16::from_be_bytes([bytes[offset], bytes[offset + 1]]);
512
513            // whether the MessageIntegrity attribute has been found,
514            // if found, record the current offset position.
515            if !find_integrity {
516                content_len = offset as u16;
517            }
518
519            // get attribute size
520            let size = u16::from_be_bytes([bytes[offset + 2], bytes[offset + 3]]) as usize;
521
522            // check if the attribute length has overflowed.
523            offset += 4;
524            if len - offset < size {
525                break;
526            }
527
528            // body range.
529            let range = offset..(offset + size);
530
531            // if there are padding bytes,
532            // skip padding size.
533            if size > 0 {
534                offset += size + alignment_32(size);
535            }
536
537            // skip the attributes that are not supported.
538            let attrkind = if let Ok(kind) = AttributeType::try_from(key) {
539                // check whether the current attribute is MessageIntegrity,
540                // if it is, mark this attribute has been found.
541                if kind == AttributeType::MessageIntegrity {
542                    find_integrity = true;
543                }
544
545                kind
546            } else {
547                continue;
548            };
549
550            // get attribute body
551            // insert attribute to attributes list.
552            attributes.append(attrkind, range);
553        }
554
555        Ok(Self {
556            size: content_len,
557            attributes,
558            method,
559            bytes,
560        })
561    }
562
563    /// # Test
564    ///
565    /// ```
566    /// use turn_server::codec::message::*;
567    ///
568    /// let buffer: [u8; 20] = [
569    ///     0x00, 0x01, 0x00, 0x00, 0x21, 0x12, 0xa4, 0x42, 0x72, 0x6d, 0x49, 0x42,
570    ///     0x72, 0x52, 0x64, 0x48, 0x57, 0x62, 0x4b, 0x2b,
571    /// ];
572    ///
573    /// let size = Message::message_size(&buffer[..]).unwrap();
574    ///
575    /// assert_eq!(size, 20);
576    /// ```
577    pub fn message_size(buffer: &[u8]) -> Result<usize, Error> {
578        if buffer[0] >> 6 != 0 || buffer.len() < 20 {
579            return Err(Error::InvalidInput);
580        }
581
582        Ok((u16::from_be_bytes(buffer[2..4].try_into()?) + 20) as usize)
583    }
584}
585
/// Computes the number of padding bytes needed after an attribute value.
///
/// STUN (RFC 8489, section 14) requires every attribute to end on a 4-byte
/// boundary, so a value of `size` bytes is followed by this many padding
/// bytes. The result is always in `0..=3`, and `0` when `size` is already a
/// multiple of 4 (including `size == 0`).
///
/// # Test
///
/// ```
/// use turn_server::codec::message::alignment_32;
///
/// assert_eq!(alignment_32(4), 0);
/// assert_eq!(alignment_32(0), 0);
/// assert_eq!(alignment_32(5), 3);
/// ```
#[inline]
pub const fn alignment_32(size: usize) -> usize {
    // `4 - size % 4` is in 1..=4; the outer modulo folds the aligned case
    // (4) back to 0.
    (4 - size % 4) % 4
}