stun_client/
client.rs

1//! This module is a thread-safe async-std-based asynchronous STUN client.
2use std::collections::HashMap;
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::{Arc, Mutex};
5use std::time::Duration;
6
7use async_macros::select;
8use async_std::future;
9use async_std::net::{SocketAddr, ToSocketAddrs, UdpSocket};
10use async_std::task;
11use futures::channel::mpsc;
12use futures::stream::StreamExt;
13use futures::SinkExt;
14
15use super::error::*;
16use super::message::*;
17
/// Default time `binding_request` waits for a response, in milliseconds.
const DEFAULT_RECV_TIMEOUT_MS: u64 = 3000;
/// Default size, in bytes, of the buffer each incoming datagram is read into.
const DEFAULT_RECV_BUF_SIZE: usize = 1024;

/// STUN client options.
///
/// `Options::default()` yields the same values the client uses when no
/// options are supplied.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Options {
    /// How long a Binding request waits for its response, in milliseconds.
    pub recv_timeout_ms: u64,
    /// Size in bytes of the buffer each incoming datagram is read into.
    pub recv_buf_size: usize,
}

impl Default for Options {
    fn default() -> Self {
        Options {
            recv_timeout_ms: DEFAULT_RECV_TIMEOUT_MS,
            recv_buf_size: DEFAULT_RECV_BUF_SIZE,
        }
    }
}
27
28/// STUN client.
29/// The transport protocol is UDP only and only supports simple STUN Binding requests.
30pub struct Client {
31    socket: Arc<UdpSocket>,
32    recv_timeout_ms: u64,
33    transactions: Arc<Mutex<HashMap<Vec<u8>, mpsc::Sender<Result<Message, STUNClientError>>>>>,
34    running: Arc<AtomicBool>,
35    stop_tx: mpsc::Sender<bool>,
36}
37
38impl Client {
39    /// Create a Client.
40    pub async fn new<A: ToSocketAddrs>(
41        local_bind_addr: A,
42        opts: Option<Options>,
43    ) -> Result<Client, STUNClientError> {
44        let socket = UdpSocket::bind(local_bind_addr)
45            .await
46            .map_err(|e| STUNClientError::IOError(e))?;
47        let socket = Arc::new(socket);
48        let transactions = Arc::new(Mutex::new(HashMap::new()));
49        let running = Arc::new(AtomicBool::new(true));
50        let (tx, rx) = mpsc::channel(1);
51        let recv_timeout_ms = opts
52            .clone()
53            .map(|o| o.recv_timeout_ms)
54            .unwrap_or_else(|| DEFAULT_RECV_TIMEOUT_MS);
55        let client = Client {
56            socket: socket.clone(),
57            recv_timeout_ms: recv_timeout_ms,
58            transactions: transactions.clone(),
59            running: running.clone(),
60            stop_tx: tx,
61        };
62
63        let recv_buf_size = opts
64            .map(|o| o.recv_buf_size)
65            .unwrap_or_else(|| DEFAULT_RECV_BUF_SIZE);
66        task::spawn(async move {
67            Self::run_message_receiver(socket, recv_buf_size, running, rx, transactions).await
68        });
69        Ok(client)
70    }
71
72    /// Create a Client from async_std::net::UdpSocket.
73    pub fn from_socket(socket: Arc<UdpSocket>, opts: Option<Options>) -> Client {
74        let transactions = Arc::new(Mutex::new(HashMap::new()));
75        let running = Arc::new(AtomicBool::new(true));
76        let (tx, rx) = mpsc::channel(1);
77        let recv_timeout_ms = opts
78            .clone()
79            .map(|o| o.recv_timeout_ms)
80            .unwrap_or_else(|| DEFAULT_RECV_TIMEOUT_MS);
81        let client = Client {
82            socket: socket.clone(),
83            recv_timeout_ms: recv_timeout_ms,
84            transactions: transactions.clone(),
85            running: running.clone(),
86            stop_tx: tx,
87        };
88
89        let recv_buf_size = opts
90            .map(|o| o.recv_buf_size)
91            .unwrap_or_else(|| DEFAULT_RECV_BUF_SIZE);
92        task::spawn(async move {
93            Self::run_message_receiver(socket, recv_buf_size, running, rx, transactions).await
94        });
95        client
96    }
97
98    /// Send STUN Binding request asynchronously.
99    pub async fn binding_request<A: ToSocketAddrs>(
100        &mut self,
101        stun_addr: A,
102        attrs: Option<HashMap<Attribute, Vec<u8>>>,
103    ) -> Result<Message, STUNClientError> {
104        let msg = Message::new(Method::Binding, Class::Request, attrs);
105        let (tx, mut rx) = mpsc::channel(1);
106        {
107            let mut m = self.transactions.lock().unwrap();
108            m.insert(msg.get_transaction_id(), tx);
109        }
110        let raw_msg = msg.to_raw();
111        self.socket
112            .send_to(&raw_msg, stun_addr)
113            .await
114            .map_err(|e| STUNClientError::IOError(e))?;
115
116        let fut = rx.next();
117        let res = future::timeout(Duration::from_millis(self.recv_timeout_ms), fut)
118            .await
119            .map_err(|_| STUNClientError::TimeoutError())?
120            .ok_or(STUNClientError::Unknown(String::from(
121                "Receive stream terminated unintentionally",
122            )))?;
123
124        {
125            let mut m = self.transactions.lock().unwrap();
126            m.remove(&msg.get_transaction_id());
127        }
128
129        res
130    }
131
132    async fn run_message_receiver(
133        socket: Arc<UdpSocket>,
134        recv_buf_size: usize,
135        running: Arc<AtomicBool>,
136        rx: mpsc::Receiver<bool>,
137        transactions: Arc<Mutex<HashMap<Vec<u8>, mpsc::Sender<Result<Message, STUNClientError>>>>>,
138    ) {
139        let mut rx = rx;
140        while running.load(Ordering::Relaxed) {
141            let mut buf = vec![0u8; recv_buf_size];
142            let sock_fut = Self::socket_recv(socket.clone(), &mut buf);
143            let stop_fut = Self::stop_recv(&mut rx);
144            let result = select!(sock_fut, stop_fut).await;
145
146            let socket_recv_result;
147            match result {
148                Event::Stop(_) => return,
149                Event::Socket(ev) => {
150                    socket_recv_result = ev;
151                }
152            }
153
154            let result = socket_recv_result.map_err(|e| STUNClientError::IOError(e));
155            match result {
156                Ok(result) => {
157                    let msg = Message::from_raw(&buf[..result.0]);
158                    match msg {
159                        Ok(msg) => {
160                            let tx: Option<mpsc::Sender<Result<Message, STUNClientError>>>;
161                            {
162                                // It's a bug if you panic with this unwrap
163                                let transactions = transactions.lock().unwrap();
164                                tx = transactions
165                                    .get(&msg.get_transaction_id())
166                                    .map(|v| v.clone());
167                            }
168                            if let Some(mut tx) = tx {
169                                tx.send(Ok(msg)).await.ok();
170                            }
171                        }
172                        Err(e) => {
173                            let transactions_unlocked: Option<
174                                HashMap<Vec<u8>, mpsc::Sender<Result<Message, STUNClientError>>>,
175                            >;
176                            {
177                                // It's a bug if you panic with this unwrap
178                                let t = transactions.lock().unwrap();
179                                transactions_unlocked = Some(t.clone());
180                            }
181                            if let Some(transactions_unlocked) = transactions_unlocked {
182                                for (_, transaction) in transactions_unlocked.iter() {
183                                    let mut transaction = transaction.clone();
184                                    transaction.send(Err(e.clone())).await.ok();
185                                }
186                            }
187                        }
188                    }
189                }
190                Err(e) => {
191                    let transactions_unlocked: Option<
192                        HashMap<Vec<u8>, mpsc::Sender<Result<Message, STUNClientError>>>,
193                    >;
194                    {
195                        // It's a bug if you panic with this unwrap
196                        let t = transactions.lock().unwrap();
197                        transactions_unlocked = Some(t.clone());
198                    }
199                    if let Some(transactions_unlocked) = transactions_unlocked {
200                        for transaction in transactions_unlocked.iter() {
201                            let mut transaction = transaction.1.clone();
202                            transaction.send(Err(e.clone())).await.ok();
203                        }
204                    }
205                }
206            }
207        }
208    }
209
210    async fn socket_recv(socket: Arc<UdpSocket>, buf: &mut [u8]) -> Event {
211        let result = socket.recv_from(buf).await;
212        Event::Socket(result)
213    }
214
215    async fn stop_recv(rx: &mut mpsc::Receiver<bool>) -> Event {
216        Event::Stop(rx.next().await.unwrap_or_else(|| true))
217    }
218}
219
220impl Drop for Client {
221    fn drop(&mut self) {
222        self.running.store(false, Ordering::Relaxed);
223        let mut tx = self.stop_tx.clone();
224        task::spawn(async move {
225            tx.send(true).await.ok();
226        });
227    }
228}
229
/// Outcome of one race between a socket read and a stop signal in the
/// receiver loop.
#[derive(Debug)]
enum Event {
    /// Result of `recv_from`: bytes read and sender address, or the I/O error.
    Socket(Result<(usize, SocketAddr), std::io::Error>),
    /// A stop request was received (or the stop channel closed).
    Stop(bool),
}