rust_p2p_core/stun/mod.rs

1use std::collections::HashSet;
2use std::io;
3use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
4use std::time::Duration;
5
6use crate::nat::NatType;
7use crate::socket::{bind_udp, LocalInterface};
8use rand::RngCore;
9use stun_format::Attr;
10use tokio::net::UdpSocket;
11
12/// Obtain nat information with the option specified interface from the stun servers
13pub async fn stun_test_nat(
14    stun_servers: Vec<String>,
15    default_interface: Option<&LocalInterface>,
16) -> io::Result<(NatType, Vec<Ipv4Addr>, u16)> {
17    let mut nat_type = NatType::Cone;
18    let mut port_range = 0;
19    let mut hash_set = HashSet::new();
20    for _ in 0..2 {
21        let stun_servers = stun_servers.clone();
22        match stun_test_nat0(stun_servers, default_interface).await {
23            Ok((nat_type_t, ip_list_t, port_range_t)) => {
24                if nat_type_t == NatType::Symmetric {
25                    nat_type = NatType::Symmetric;
26                }
27                for x in ip_list_t {
28                    hash_set.insert(x);
29                }
30                if port_range < port_range_t {
31                    port_range = port_range_t;
32                }
33            }
34            Err(e) => {
35                log::warn!("{:?}", e);
36            }
37        }
38    }
39    Ok((nat_type, hash_set.into_iter().collect(), port_range))
40}
41
42pub(crate) async fn stun_test_nat0(
43    stun_servers: Vec<String>,
44    default_interface: Option<&LocalInterface>,
45) -> io::Result<(NatType, Vec<Ipv4Addr>, u16)> {
46    let udp = bind_udp("0.0.0.0:0".parse().unwrap(), default_interface)?;
47    let udp = UdpSocket::from_std(udp.into())?;
48    let mut nat_type = NatType::Cone;
49    let mut min_port = u16::MAX;
50    let mut max_port = 0;
51    let mut hash_set = HashSet::new();
52    let mut pub_addrs = HashSet::new();
53    for x in &stun_servers {
54        match test_nat(&udp, x).await {
55            Ok(addr) => {
56                pub_addrs.extend(addr);
57            }
58            Err(e) => {
59                log::warn!("stun {} error {:?} ", x, e);
60            }
61        }
62    }
63    if pub_addrs.len() > 1 {
64        nat_type = NatType::Symmetric;
65    }
66    for addr in &pub_addrs {
67        if let SocketAddr::V4(addr) = addr {
68            hash_set.insert(*addr.ip());
69            if min_port > addr.port() {
70                min_port = addr.port()
71            }
72            if max_port < addr.port() {
73                max_port = addr.port()
74            }
75        }
76    }
77    if hash_set.is_empty() {
78        Ok((nat_type, vec![], 0))
79    } else {
80        Ok((
81            nat_type,
82            hash_set.into_iter().collect(),
83            max_port - min_port,
84        ))
85    }
86}
87
88async fn test_nat(udp: &UdpSocket, stun_server: &String) -> io::Result<HashSet<SocketAddr>> {
89    udp.connect(stun_server).await?;
90    let tid = rand::rng().next_u64() as u128;
91    let mut addr = HashSet::new();
92    let (mapped_addr1, changed_addr1) = test_nat_(udp, stun_server, true, true, tid).await?;
93    if mapped_addr1.is_ipv4() {
94        addr.insert(mapped_addr1);
95    }
96    if let Some(changed_addr1) = changed_addr1 {
97        if udp.connect(changed_addr1).await.is_ok() {
98            match test_nat_(udp, stun_server, false, false, tid + 1).await {
99                Ok((mapped_addr2, _)) => {
100                    if mapped_addr2.is_ipv4() {
101                        addr.insert(mapped_addr1);
102                    }
103                }
104                Err(e) => {
105                    log::warn!("stun {} error {:?} ", stun_server, e);
106                }
107            }
108        }
109    }
110    log::info!(
111        "stun {} mapped_addr {:?}  changed_addr {:?}",
112        stun_server,
113        addr,
114        changed_addr1,
115    );
116
117    Ok(addr)
118}
119
120async fn test_nat_(
121    udp: &UdpSocket,
122    stun_server: &String,
123    change_ip: bool,
124    change_port: bool,
125    tid: u128,
126) -> io::Result<(SocketAddr, Option<SocketAddr>)> {
127    for _ in 0..2 {
128        let mut buf = [0u8; 28];
129        let mut msg = stun_format::MsgBuilder::from(buf.as_mut_slice());
130        msg.typ(stun_format::MsgType::BindingRequest);
131        msg.tid(tid);
132        msg.add_attr(Attr::ChangeRequest {
133            change_ip,
134            change_port,
135        });
136        udp.send(msg.as_bytes()).await?;
137        let mut buf = [0; 10240];
138        let (len, _addr) =
139            match tokio::time::timeout(Duration::from_secs(3), udp.recv_from(&mut buf)).await {
140                Ok(rs) => rs?,
141                Err(e) => {
142                    log::warn!("stun {} error {:?}", stun_server, e);
143                    continue;
144                }
145            };
146        let msg = stun_format::Msg::from(&buf[..len]);
147        let mut mapped_addr = None;
148        let mut changed_addr = None;
149        for x in msg.attrs_iter() {
150            match x {
151                Attr::MappedAddress(addr) => {
152                    if mapped_addr.is_none() {
153                        let _ = mapped_addr.insert(stun_addr(addr));
154                    }
155                }
156                Attr::ChangedAddress(addr) => {
157                    if changed_addr.is_none() {
158                        let _ = changed_addr.insert(stun_addr(addr));
159                    }
160                }
161                Attr::XorMappedAddress(addr) => {
162                    if mapped_addr.is_none() {
163                        let _ = mapped_addr.insert(stun_addr(addr));
164                    }
165                }
166                _ => {}
167            }
168            if let Some(mapped_addr) = mapped_addr {
169                if changed_addr.is_some() {
170                    return Ok((mapped_addr, changed_addr));
171                }
172            }
173        }
174        if let Some(addr) = mapped_addr {
175            return Ok((addr, changed_addr));
176        }
177    }
178    Err(io::Error::new(io::ErrorKind::Other, "stun response err"))
179}
180
181fn stun_addr(addr: stun_format::SocketAddr) -> SocketAddr {
182    match addr {
183        stun_format::SocketAddr::V4(ip, port) => {
184            SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(ip), port))
185        }
186        stun_format::SocketAddr::V6(ip, port) => {
187            SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(ip), port, 0, 0))
188        }
189    }
190}
191
/// Marker bits (0x6CEE_30B8 = 1827549368, shifted into bits 64..) OR-ed into the
/// transaction id by `send_stun_request`; `recv_stun_response` rejects replies
/// whose transaction id does not carry all of these bits.
const TAG: u128 = 0x6CEE_30B8 << 64;
193
194pub fn send_stun_request() -> Vec<u8> {
195    let mut buf = [0u8; 28];
196    let mut msg = stun_format::MsgBuilder::from(buf.as_mut_slice());
197    msg.typ(stun_format::MsgType::BindingRequest);
198    let id = rand::rng().next_u64() as u128;
199    msg.tid(id | TAG);
200    msg.add_attr(Attr::ChangeRequest {
201        change_ip: false,
202        change_port: false,
203    });
204    msg.as_bytes().to_vec()
205}
/// Cheap demultiplexing check: returns `true` when the first byte of `buf` is
/// `0x01`, the high byte of STUN response types such as the Binding Success
/// Response `0x0101`.
///
/// Only that byte is inspected; the message is not otherwise validated. An empty
/// buffer returns `false` instead of panicking on an out-of-bounds index.
pub fn is_stun_response(buf: &[u8]) -> bool {
    buf.first() == Some(&0x01)
}
209pub fn recv_stun_response(buf: &[u8]) -> Option<SocketAddr> {
210    let msg = stun_format::Msg::from(buf);
211    if let Some(tid) = msg.tid() {
212        if tid & TAG != TAG {
213            return None;
214        }
215    }
216    for x in msg.attrs_iter() {
217        match x {
218            Attr::MappedAddress(addr) => {
219                return Some(stun_addr(addr));
220            }
221            Attr::XorMappedAddress(addr) => {
222                return Some(stun_addr(addr));
223            }
224            _ => {}
225        }
226    }
227    None
228}