1use std::collections::HashMap;
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::{Arc, Mutex};
5use std::time::Duration;
6
7use async_macros::select;
8use async_std::future;
9use async_std::net::{SocketAddr, ToSocketAddrs, UdpSocket};
10use async_std::task;
11use futures::channel::mpsc;
12use futures::stream::StreamExt;
13use futures::SinkExt;
14
15use super::error::*;
16use super::message::*;
17
/// Response timeout, in milliseconds, used when the caller supplies no `Options`.
const DEFAULT_RECV_TIMEOUT_MS: u64 = 3_000;
/// Receive buffer size, in bytes, used when the caller supplies no `Options`.
const DEFAULT_RECV_BUF_SIZE: usize = 1_024;
20
/// Tunables accepted by the `Client` constructors.
///
/// Passing `None` instead of an `Options` value makes the constructors fall
/// back to `DEFAULT_RECV_TIMEOUT_MS` and `DEFAULT_RECV_BUF_SIZE`.
#[derive(Debug, Clone)]
pub struct Options {
    /// How long `binding_request` waits for a response, in milliseconds.
    pub recv_timeout_ms: u64,
    /// Size in bytes of the buffer the receiver task reads datagrams into.
    pub recv_buf_size: usize,
}
27
28pub struct Client {
31 socket: Arc<UdpSocket>,
32 recv_timeout_ms: u64,
33 transactions: Arc<Mutex<HashMap<Vec<u8>, mpsc::Sender<Result<Message, STUNClientError>>>>>,
34 running: Arc<AtomicBool>,
35 stop_tx: mpsc::Sender<bool>,
36}
37
38impl Client {
39 pub async fn new<A: ToSocketAddrs>(
41 local_bind_addr: A,
42 opts: Option<Options>,
43 ) -> Result<Client, STUNClientError> {
44 let socket = UdpSocket::bind(local_bind_addr)
45 .await
46 .map_err(|e| STUNClientError::IOError(e))?;
47 let socket = Arc::new(socket);
48 let transactions = Arc::new(Mutex::new(HashMap::new()));
49 let running = Arc::new(AtomicBool::new(true));
50 let (tx, rx) = mpsc::channel(1);
51 let recv_timeout_ms = opts
52 .clone()
53 .map(|o| o.recv_timeout_ms)
54 .unwrap_or_else(|| DEFAULT_RECV_TIMEOUT_MS);
55 let client = Client {
56 socket: socket.clone(),
57 recv_timeout_ms: recv_timeout_ms,
58 transactions: transactions.clone(),
59 running: running.clone(),
60 stop_tx: tx,
61 };
62
63 let recv_buf_size = opts
64 .map(|o| o.recv_buf_size)
65 .unwrap_or_else(|| DEFAULT_RECV_BUF_SIZE);
66 task::spawn(async move {
67 Self::run_message_receiver(socket, recv_buf_size, running, rx, transactions).await
68 });
69 Ok(client)
70 }
71
72 pub fn from_socket(socket: Arc<UdpSocket>, opts: Option<Options>) -> Client {
74 let transactions = Arc::new(Mutex::new(HashMap::new()));
75 let running = Arc::new(AtomicBool::new(true));
76 let (tx, rx) = mpsc::channel(1);
77 let recv_timeout_ms = opts
78 .clone()
79 .map(|o| o.recv_timeout_ms)
80 .unwrap_or_else(|| DEFAULT_RECV_TIMEOUT_MS);
81 let client = Client {
82 socket: socket.clone(),
83 recv_timeout_ms: recv_timeout_ms,
84 transactions: transactions.clone(),
85 running: running.clone(),
86 stop_tx: tx,
87 };
88
89 let recv_buf_size = opts
90 .map(|o| o.recv_buf_size)
91 .unwrap_or_else(|| DEFAULT_RECV_BUF_SIZE);
92 task::spawn(async move {
93 Self::run_message_receiver(socket, recv_buf_size, running, rx, transactions).await
94 });
95 client
96 }
97
98 pub async fn binding_request<A: ToSocketAddrs>(
100 &mut self,
101 stun_addr: A,
102 attrs: Option<HashMap<Attribute, Vec<u8>>>,
103 ) -> Result<Message, STUNClientError> {
104 let msg = Message::new(Method::Binding, Class::Request, attrs);
105 let (tx, mut rx) = mpsc::channel(1);
106 {
107 let mut m = self.transactions.lock().unwrap();
108 m.insert(msg.get_transaction_id(), tx);
109 }
110 let raw_msg = msg.to_raw();
111 self.socket
112 .send_to(&raw_msg, stun_addr)
113 .await
114 .map_err(|e| STUNClientError::IOError(e))?;
115
116 let fut = rx.next();
117 let res = future::timeout(Duration::from_millis(self.recv_timeout_ms), fut)
118 .await
119 .map_err(|_| STUNClientError::TimeoutError())?
120 .ok_or(STUNClientError::Unknown(String::from(
121 "Receive stream terminated unintentionally",
122 )))?;
123
124 {
125 let mut m = self.transactions.lock().unwrap();
126 m.remove(&msg.get_transaction_id());
127 }
128
129 res
130 }
131
132 async fn run_message_receiver(
133 socket: Arc<UdpSocket>,
134 recv_buf_size: usize,
135 running: Arc<AtomicBool>,
136 rx: mpsc::Receiver<bool>,
137 transactions: Arc<Mutex<HashMap<Vec<u8>, mpsc::Sender<Result<Message, STUNClientError>>>>>,
138 ) {
139 let mut rx = rx;
140 while running.load(Ordering::Relaxed) {
141 let mut buf = vec![0u8; recv_buf_size];
142 let sock_fut = Self::socket_recv(socket.clone(), &mut buf);
143 let stop_fut = Self::stop_recv(&mut rx);
144 let result = select!(sock_fut, stop_fut).await;
145
146 let socket_recv_result;
147 match result {
148 Event::Stop(_) => return,
149 Event::Socket(ev) => {
150 socket_recv_result = ev;
151 }
152 }
153
154 let result = socket_recv_result.map_err(|e| STUNClientError::IOError(e));
155 match result {
156 Ok(result) => {
157 let msg = Message::from_raw(&buf[..result.0]);
158 match msg {
159 Ok(msg) => {
160 let tx: Option<mpsc::Sender<Result<Message, STUNClientError>>>;
161 {
162 let transactions = transactions.lock().unwrap();
164 tx = transactions
165 .get(&msg.get_transaction_id())
166 .map(|v| v.clone());
167 }
168 if let Some(mut tx) = tx {
169 tx.send(Ok(msg)).await.ok();
170 }
171 }
172 Err(e) => {
173 let transactions_unlocked: Option<
174 HashMap<Vec<u8>, mpsc::Sender<Result<Message, STUNClientError>>>,
175 >;
176 {
177 let t = transactions.lock().unwrap();
179 transactions_unlocked = Some(t.clone());
180 }
181 if let Some(transactions_unlocked) = transactions_unlocked {
182 for (_, transaction) in transactions_unlocked.iter() {
183 let mut transaction = transaction.clone();
184 transaction.send(Err(e.clone())).await.ok();
185 }
186 }
187 }
188 }
189 }
190 Err(e) => {
191 let transactions_unlocked: Option<
192 HashMap<Vec<u8>, mpsc::Sender<Result<Message, STUNClientError>>>,
193 >;
194 {
195 let t = transactions.lock().unwrap();
197 transactions_unlocked = Some(t.clone());
198 }
199 if let Some(transactions_unlocked) = transactions_unlocked {
200 for transaction in transactions_unlocked.iter() {
201 let mut transaction = transaction.1.clone();
202 transaction.send(Err(e.clone())).await.ok();
203 }
204 }
205 }
206 }
207 }
208 }
209
210 async fn socket_recv(socket: Arc<UdpSocket>, buf: &mut [u8]) -> Event {
211 let result = socket.recv_from(buf).await;
212 Event::Socket(result)
213 }
214
215 async fn stop_recv(rx: &mut mpsc::Receiver<bool>) -> Event {
216 Event::Stop(rx.next().await.unwrap_or_else(|| true))
217 }
218}
219
220impl Drop for Client {
221 fn drop(&mut self) {
222 self.running.store(false, Ordering::Relaxed);
223 let mut tx = self.stop_tx.clone();
224 task::spawn(async move {
225 tx.send(true).await.ok();
226 });
227 }
228}
229
/// What woke the receiver loop.
enum Event {
    /// Outcome of one `recv_from` call: byte count and sender address, or the I/O error.
    Socket(std::io::Result<(usize, SocketAddr)>),
    /// A stop signal was received on the stop channel.
    Stop(bool),
}