stun_client/
message.rs

1//! This module implements some of the STUN protocol message processing based on RFC 8489 and RFC 5780.
2use std::collections::HashMap;
3use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
4
5use rand::{thread_rng, Rng};
6
7use super::error::*;
8
/// Fixed value of the magic-cookie field in every STUN header (RFC 8489).
pub const MAGIC_COOKIE: u32 = 0x2112_A442;

// --- Methods -------------------------------------------------------------

/// Method value of the Binding method.
pub const METHOD_BINDING: u16 = 0x0001;

// --- Classes -------------------------------------------------------------
// The class occupies two bits of the message type: C0 is bit 4 and C1 is
// bit 8, so each constant below already has those bits in position.

/// Message-type bits of the request class (C1 = 0, C0 = 0).
pub const CLASS_REQUEST: u16 = 0b0000_0000_0000_0000;
/// Message-type bits of the indication class (C1 = 0, C0 = 1).
pub const CLASS_INDICATION: u16 = 0b0000_0000_0001_0000;
/// Message-type bits of the success-response class (C1 = 1, C0 = 0).
pub const CLASS_SUCCESS_RESPONSE: u16 = 0b0000_0001_0000_0000;
/// Message-type bits of the error-response class (C1 = 1, C0 = 1).
pub const CLASS_ERROR_RESPONSE: u16 = 0b0000_0001_0001_0000;

/// Size in bytes of the fixed STUN header:
/// message type (2) + message length (2) + magic cookie (4) + transaction ID (12).
pub const HEADER_BYTE_SIZE: usize = 2 + 2 + 4 + 12;

// --- Attribute types (RFC 8489) ------------------------------------------

/// Attribute type of MAPPED-ADDRESS.
pub const ATTR_MAPPED_ADDRESS: u16 = 0x0001;
/// Attribute type of XOR-MAPPED-ADDRESS.
pub const ATTR_XOR_MAPPED_ADDRESS: u16 = 0x0020;
/// Attribute type of ERROR-CODE.
pub const ATTR_ERROR_CODE: u16 = 0x0009;
/// Attribute type of SOFTWARE.
pub const ATTR_SOFTWARE: u16 = 0x8022;

// --- Attribute types (RFC 5780 NAT Behavior Discovery) -------------------

/// Attribute type of OTHER-ADDRESS.
pub const ATTR_OTHER_ADDRESS: u16 = 0x802C;
/// Attribute type of CHANGE-REQUEST.
pub const ATTR_CHANGE_REQUEST: u16 = 0x0003;
/// Attribute type of RESPONSE-ORIGIN.
pub const ATTR_RESPONSE_ORIGIN: u16 = 0x802B;

/// "Change IP" flag bit of the CHANGE-REQUEST attribute value.
pub const CHANGE_REQUEST_IP_FLAG: u32 = 1 << 2;
/// "Change port" flag bit of the CHANGE-REQUEST attribute value.
pub const CHANGE_REQUEST_PORT_FLAG: u32 = 1 << 1;

/// Address-family byte identifying an IPv4 address in address attributes.
pub const FAMILY_IPV4: u8 = 0x01;
/// Address-family byte identifying an IPv6 address in address attributes.
pub const FAMILY_IPV6: u8 = 0x02;
54
55/// Enum representing STUN method
56#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
57pub enum Method {
58    Binding,
59    Unknown(u16),
60}
61
62impl Method {
63    /// Convert from u16 to Method.
64    pub fn from_u16(method: u16) -> Self {
65        match method {
66            METHOD_BINDING => Self::Binding,
67            _ => Self::Unknown(method),
68        }
69    }
70
71    /// Convert from Method to u16.
72    pub fn to_u16(&self) -> u16 {
73        match self {
74            Self::Binding => METHOD_BINDING,
75            Self::Unknown(method) => method.clone(),
76        }
77    }
78}
79
80/// Enum representing STUN class
81#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
82pub enum Class {
83    Request,
84    Indication,
85    SuccessResponse,
86    ErrorResponse,
87    Unknown(u16),
88}
89
90impl Class {
91    /// Convert from u16 to Class.
92    pub fn from_u16(class: u16) -> Self {
93        match class {
94            CLASS_REQUEST => Self::Request,
95            CLASS_INDICATION => Self::Indication,
96            CLASS_SUCCESS_RESPONSE => Self::SuccessResponse,
97            CLASS_ERROR_RESPONSE => Self::ErrorResponse,
98            _ => Self::Unknown(class),
99        }
100    }
101
102    /// Convert from u16 to Class.
103    pub fn to_u16(&self) -> u16 {
104        match self {
105            Self::Request => CLASS_REQUEST,
106            Self::Indication => CLASS_INDICATION,
107            Self::SuccessResponse => CLASS_SUCCESS_RESPONSE,
108            Self::ErrorResponse => CLASS_ERROR_RESPONSE,
109            Self::Unknown(class) => class.clone(),
110        }
111    }
112}
113
114/// Enum representing STUN attribute
115#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
116pub enum Attribute {
117    MappedAddress,
118    XORMappedAddress,
119    Software,
120    OtherAddress,
121    ChangeRequest,
122    ResponseOrigin,
123    ErrorCode,
124    Unknown(u16),
125}
126
127impl Attribute {
128    /// Convert from u16 to Attribute.
129    pub fn from_u16(attribute: u16) -> Self {
130        match attribute {
131            ATTR_MAPPED_ADDRESS => Self::MappedAddress,
132            ATTR_XOR_MAPPED_ADDRESS => Self::XORMappedAddress,
133            ATTR_SOFTWARE => Self::Software,
134            ATTR_OTHER_ADDRESS => Self::OtherAddress,
135            ATTR_CHANGE_REQUEST => Self::ChangeRequest,
136            ATTR_RESPONSE_ORIGIN => Self::ResponseOrigin,
137            ATTR_ERROR_CODE => Self::ErrorCode,
138            _ => Self::Unknown(attribute),
139        }
140    }
141
142    /// Convert from u16 to Attribute.
143    pub fn to_u16(&self) -> u16 {
144        match self {
145            Self::MappedAddress => ATTR_MAPPED_ADDRESS,
146            Self::XORMappedAddress => ATTR_XOR_MAPPED_ADDRESS,
147            Self::Software => ATTR_SOFTWARE,
148            Self::OtherAddress => ATTR_OTHER_ADDRESS,
149            Self::ChangeRequest => ATTR_CHANGE_REQUEST,
150            Self::ResponseOrigin => ATTR_RESPONSE_ORIGIN,
151            Self::ErrorCode => ATTR_ERROR_CODE,
152            Self::Unknown(attribute) => attribute.clone(),
153        }
154    }
155
156    /// Gets the value of the MAPPED-ADDRESS attribute from Message.
157    pub fn get_mapped_address(message: &Message) -> Option<SocketAddr> {
158        Self::decode_simple_address_attribute(message, Self::MappedAddress)
159    }
160
161    /// Gets the value of the XOR-MAPPED-ADDRESS attribute from Message.
162    pub fn get_xor_mapped_address(message: &Message) -> Option<SocketAddr> {
163        let attr_value = message.get_raw_attr_value(Self::XORMappedAddress)?;
164        let family = attr_value[1];
165        // RFC8489: X-Port is computed by XOR'ing the mapped port with the most significant 16 bits of the magic cookie.
166        let mc_bytes = MAGIC_COOKIE.to_be_bytes();
167        let port = u16::from_be_bytes([attr_value[2], attr_value[3]])
168            ^ u16::from_be_bytes([mc_bytes[0], mc_bytes[1]]);
169        match family {
170            FAMILY_IPV4 => {
171                // RFC8489: If the IP address family is IPv4, X-Address is computed by XOR'ing the mapped IP address with the magic cookie.
172                let encoded_ip = &attr_value[4..];
173                let b: Vec<u8> = encoded_ip
174                    .iter()
175                    .zip(&MAGIC_COOKIE.to_be_bytes())
176                    .map(|(b, m)| b ^ m)
177                    .collect();
178                let ip_addr = bytes_to_ip_addr(family, b)?;
179                Some(SocketAddr::new(ip_addr, port))
180            }
181            FAMILY_IPV6 => {
182                // RFC8489: If the IP address family is IPv6, X-Address is computed by XOR'ing the mapped IP address with the concatenation of the magic cookie and the 96-bit transaction ID.
183                let encoded_ip = &attr_value[4..];
184                let mut mc_ti: Vec<u8> = vec![];
185                mc_ti.extend(&MAGIC_COOKIE.to_be_bytes());
186                mc_ti.extend(&message.header.transaction_id);
187                let b: Vec<u8> = encoded_ip.iter().zip(&mc_ti).map(|(b, m)| b ^ m).collect();
188                let ip_addr = bytes_to_ip_addr(family, b)?;
189                Some(SocketAddr::new(ip_addr, port))
190            }
191            _ => None,
192        }
193    }
194
195    /// Gets the value of the SOFTWARE attribute from message.
196    pub fn get_software(message: &Message) -> Option<String> {
197        let attr_value = message.get_raw_attr_value(Self::Software)?;
198        String::from_utf8(attr_value).ok()
199    }
200
201    /// Gets the value of the ERROR-CODE attribute from Message.
202    pub fn get_error_code(message: &Message) -> Option<ErrorCode> {
203        let attr_value = message.get_raw_attr_value(Self::ErrorCode)?;
204        let class = (attr_value[2] as u16) * 100;
205        let number = attr_value[3] as u16;
206        let code = class + number;
207        let reason = String::from_utf8(attr_value[4..].to_vec())
208            .unwrap_or(String::from("cannot parse error reason"));
209        Some(ErrorCode::from(code, reason))
210    }
211
212    /// Gets the value of the OTHER-ADDRESS attribute from Message.
213    pub fn get_other_address(message: &Message) -> Option<SocketAddr> {
214        // RFC5780: it is simply a new name with the same semantics as CHANGED-ADDRESS.
215        // RCF3489: Its syntax is identical to MAPPED-ADDRESS.
216        Self::decode_simple_address_attribute(message, Self::OtherAddress)
217    }
218
219    /// Gets the value of the RESPONSE-ORIGIN attribute from Message.
220    pub fn get_response_origin(message: &Message) -> Option<SocketAddr> {
221        Self::decode_simple_address_attribute(message, Self::ResponseOrigin)
222    }
223
224    /// Generates a value for the CHANGE-REQUEST attribute.
225    pub fn generate_change_request_value(change_ip: bool, change_port: bool) -> Vec<u8> {
226        let mut value: u32 = 0;
227        if change_ip {
228            value |= CHANGE_REQUEST_IP_FLAG;
229        }
230
231        if change_port {
232            value |= CHANGE_REQUEST_PORT_FLAG;
233        }
234
235        value.to_be_bytes().to_vec()
236    }
237
238    pub fn decode_simple_address_attribute(message: &Message, attr: Self) -> Option<SocketAddr> {
239        let attr_value = message.get_raw_attr_value(attr)?;
240        let family = attr_value[1];
241        let port = u16::from_be_bytes([attr_value[2], attr_value[3]]);
242        let ip_addr = bytes_to_ip_addr(family, attr_value[4..].to_vec())?;
243        Some(SocketAddr::new(ip_addr, port))
244    }
245}
246
247/// Struct representing STUN message
248#[derive(Debug, Eq, PartialEq)]
249pub struct Message {
250    header: Header,
251    attributes: Option<HashMap<Attribute, Vec<u8>>>,
252}
253
254impl Message {
255    /// Create a STUN Message.
256    pub fn new(
257        method: Method,
258        class: Class,
259        attributes: Option<HashMap<Attribute, Vec<u8>>>,
260    ) -> Message {
261        let attr_type_byte_size = 2;
262        let attr_length_byte_size = 2;
263        let length: u16 = if let Some(attributes) = &attributes {
264            attributes
265                .iter()
266                .map(|e| attr_type_byte_size + attr_length_byte_size + e.1.len() as u16)
267                .sum()
268        } else {
269            0
270        };
271
272        let transaction_id: Vec<u8> = thread_rng().gen::<[u8; 12]>().to_vec();
273
274        Message {
275            header: Header::new(method, class, length, transaction_id),
276            attributes: attributes,
277        }
278    }
279
280    /// Create a STUN message from raw bytes.
281    pub fn from_raw(buf: &[u8]) -> Result<Message, STUNClientError> {
282        if buf.len() < HEADER_BYTE_SIZE {
283            return Err(STUNClientError::ParseError());
284        }
285
286        let header = Header::from_raw(&buf[..HEADER_BYTE_SIZE])?;
287        let mut attrs = None;
288        if buf.len() > HEADER_BYTE_SIZE {
289            attrs = Some(Message::decode_attrs(&buf[HEADER_BYTE_SIZE..])?);
290        }
291
292        Ok(Message {
293            header: header,
294            attributes: attrs,
295        })
296    }
297
298    /// Converts a Message to a STUN protocol message raw bytes.
299    pub fn to_raw(&self) -> Vec<u8> {
300        let mut bytes = self.header.to_raw();
301        if let Some(attributes) = &self.attributes {
302            for (k, v) in attributes.iter() {
303                bytes.extend(&k.to_u16().to_be_bytes());
304                bytes.extend(&(v.len() as u16).to_be_bytes());
305                bytes.extend(v);
306            }
307        }
308
309        bytes
310    }
311
312    /// Get the method from Message.
313    pub fn get_method(&self) -> Method {
314        self.header.method
315    }
316
317    /// Get the class from Message.
318    pub fn get_class(&self) -> Class {
319        self.header.class
320    }
321
322    /// Get the raw attribute bytes from Message.
323    pub fn get_raw_attr_value(&self, attr: Attribute) -> Option<Vec<u8>> {
324        self.attributes
325            .as_ref()?
326            .get(&attr)
327            .and_then(|v| Some(v.clone()))
328    }
329
330    /// Get the transaction id from Message.
331    pub fn get_transaction_id(&self) -> Vec<u8> {
332        self.header.transaction_id.clone()
333    }
334
335    fn decode_attrs(attrs_buf: &[u8]) -> Result<HashMap<Attribute, Vec<u8>>, STUNClientError> {
336        let mut attrs_buf = attrs_buf.to_vec();
337        let mut attributes = HashMap::new();
338
339        if attrs_buf.is_empty() {
340            return Err(STUNClientError::ParseError());
341        }
342
343        while !attrs_buf.is_empty() {
344            if attrs_buf.len() < 4 {
345                break;
346            }
347
348            let attribute_type = Attribute::from_u16(u16::from_be_bytes([
349                attrs_buf.remove(0),
350                attrs_buf.remove(0),
351            ]));
352            let length =
353                u16::from_be_bytes([attrs_buf.remove(0), attrs_buf.remove(0)]) as usize;
354            if attrs_buf.len() < length {
355                return Err(STUNClientError::ParseError());
356            }
357
358            let value: Vec<u8> = attrs_buf.drain(..length).collect();
359            attributes.insert(attribute_type, value);
360        }
361
362        Ok(attributes)
363    }
364}
365
366/// Struct representing STUN header
367#[derive(Debug, Eq, PartialEq)]
368pub struct Header {
369    method: Method,
370    class: Class,
371    length: u16,
372    transaction_id: Vec<u8>,
373}
374
375impl Header {
376    /// Create a STUN header.
377    pub fn new(method: Method, class: Class, length: u16, transaction_id: Vec<u8>) -> Header {
378        Header {
379            class: class,
380            method: method,
381            length: length,
382            transaction_id: transaction_id,
383        }
384    }
385
386    /// Create a STUN header from raw bytes.
387    pub fn from_raw(buf: &[u8]) -> Result<Header, STUNClientError> {
388        let mut buf = buf.to_vec();
389        if buf.len() < HEADER_BYTE_SIZE {
390            return Err(STUNClientError::ParseError());
391        }
392
393        let message_type = u16::from_be_bytes([buf.remove(0), buf.remove(0)]);
394        let class = Header::decode_class(message_type);
395        let method = Header::decode_method(message_type);
396        let length = u16::from_be_bytes([buf.remove(0), buf.remove(0)]);
397
398        Ok(Header {
399            class: class,
400            method: method,
401            length: length,
402            // 0..3 is Magic Cookie
403            transaction_id: buf[4..].to_vec(),
404        })
405    }
406
407    /// Converts a Header to a STUN protocol header raw bytes.
408    pub fn to_raw(&self) -> Vec<u8> {
409        let message_type = self.message_type();
410        let mut bytes = vec![];
411        bytes.extend(&message_type.to_be_bytes());
412        bytes.extend(&self.length.to_be_bytes());
413        bytes.extend(&MAGIC_COOKIE.to_be_bytes());
414        bytes.extend(&self.transaction_id);
415        bytes
416    }
417
418    fn message_type(&self) -> u16 {
419        self.class.to_u16() | self.method.to_u16()
420    }
421
422    fn decode_method(message_type: u16) -> Method {
423        // RFC8489: M11 through M0 represent a 12-bit encoding of the method
424        Method::from_u16(message_type & 0x3EEF)
425    }
426
427    fn decode_class(message_type: u16) -> Class {
428        // RFC8489: C1 and C0 represent a 2-bit encoding of the class
429        Class::from_u16(message_type & 0x0110)
430    }
431}
432
/// Builds an [`IpAddr`] from raw address bytes of the given address family.
///
/// Reads the first 4 bytes for `FAMILY_IPV4` and the first 16 bytes for
/// `FAMILY_IPV6`; any further bytes are ignored. Returns `None` for any other
/// family, or when `b` is too short for the family (instead of panicking on
/// malformed attribute data).
fn bytes_to_ip_addr(family: u8, b: Vec<u8>) -> Option<IpAddr> {
    match family {
        FAMILY_IPV4 => {
            let octets = b.get(..4)?;
            Some(IpAddr::V4(Ipv4Addr::new(
                octets[0], octets[1], octets[2], octets[3],
            )))
        }
        FAMILY_IPV6 => {
            let mut octets = [0u8; 16];
            octets.copy_from_slice(b.get(..16)?);
            Some(IpAddr::V6(Ipv6Addr::from(octets)))
        }
        _ => None,
    }
}
443
/// STUN error codes recognised by this module.
///
/// Every variant carries the reason phrase that accompanied the code. Codes
/// without a dedicated variant collapse into [`ErrorCode::Unknown`], which keeps
/// only the reason phrase.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ErrorCode {
    /// Code 300.
    TryAlternate(String),
    /// Code 400.
    BadRequest(String),
    /// Code 401.
    Unauthorized(String),
    /// Code 420.
    UnknownAttribute(String),
    /// Code 438.
    StaleNonce(String),
    /// Code 500.
    ServerError(String),
    /// Any other code.
    Unknown(String),
}

impl ErrorCode {
    /// Builds an `ErrorCode` from its numeric `code` and `reason` phrase.
    pub fn from(code: u16, reason: String) -> Self {
        // Pick the variant constructor first, then apply it to the reason once.
        let make: fn(String) -> ErrorCode = match code {
            300 => ErrorCode::TryAlternate,
            400 => ErrorCode::BadRequest,
            401 => ErrorCode::Unauthorized,
            420 => ErrorCode::UnknownAttribute,
            438 => ErrorCode::StaleNonce,
            500 => ErrorCode::ServerError,
            _ => ErrorCode::Unknown,
        };
        make(reason)
    }
}
469
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding a freshly built message and decoding the bytes again must
    /// reproduce an identical message.
    #[test]
    fn message_new_and_message_from_raw_are_equivalent() {
        let change_request = Attribute::generate_change_request_value(true, false);
        let attributes: HashMap<Attribute, Vec<u8>> = vec![(Attribute::ChangeRequest, change_request)]
            .into_iter()
            .collect();

        let original = Message::new(Method::Binding, Class::Request, Some(attributes));
        let encoded = original.to_raw();
        let decoded = Message::from_raw(&encoded).unwrap();

        assert_eq!(original, decoded);
    }
}