webrtc_turn/client/
relay_conn.rs

1#[cfg(test)]
2mod relay_conn_test;
3
4// client implements the API for a TURN client
5use super::binding::*;
6use super::periodic_timer::*;
7use super::permission::*;
8use super::transaction::*;
9use crate::proto;
10
11use crate::errors::*;
12
13use stun::agent::*;
14use stun::attributes::*;
15use stun::error_code::*;
16use stun::fingerprint::*;
17use stun::integrity::*;
18use stun::message::*;
19use stun::textattrs::*;
20
21use util::{Conn, Error};
22
23use std::io;
24use std::net::SocketAddr;
25use std::sync::Arc;
26
27use tokio::sync::{mpsc, Mutex};
28use tokio::time::{Duration, Instant};
29
30use async_trait::async_trait;
31
/// Period of the permission-refresh timer: cached permissions are re-requested
/// every two minutes.
const PERM_REFRESH_INTERVAL: Duration = Duration::from_secs(2 * 60);
/// Upper bound on attempts for a request that keeps failing with `ERR_TRY_AGAIN`.
const MAX_RETRY_ATTEMPTS: u16 = 3;
34
/// One packet received through the relay, queued for `RelayConn::recv_from`.
pub(crate) struct InboundData {
    /// Address of the peer the payload came from.
    pub(crate) from: SocketAddr,
    /// Payload bytes.
    pub(crate) data: Vec<u8>,
}
39
40// UDPConnObserver is an interface to UDPConn observer
41#[async_trait]
42pub trait RelayConnObserver {
43    fn turn_server_addr(&self) -> String;
44    fn username(&self) -> Username;
45    fn realm(&self) -> Realm;
46    async fn write_to(&self, data: &[u8], to: &str) -> Result<usize, Error>;
47    async fn perform_transaction(
48        &mut self,
49        msg: &Message,
50        to: &str,
51        ignore_result: bool,
52    ) -> Result<TransactionResult, Error>;
53}
54
55// RelayConnConfig is a set of configuration params use by NewUDPConn
56pub(crate) struct RelayConnConfig {
57    pub(crate) relayed_addr: SocketAddr,
58    pub(crate) integrity: MessageIntegrity,
59    pub(crate) nonce: Nonce,
60    pub(crate) lifetime: Duration,
61    pub(crate) binding_mgr: Arc<Mutex<BindingManager>>,
62    pub(crate) read_ch_rx: Arc<Mutex<mpsc::Receiver<InboundData>>>,
63}
64
65pub struct RelayConnInternal<T: 'static + RelayConnObserver + Send + Sync> {
66    obs: Arc<Mutex<T>>,
67    relayed_addr: SocketAddr,
68    perm_map: PermissionMap,
69    binding_mgr: Arc<Mutex<BindingManager>>,
70    integrity: MessageIntegrity,
71    nonce: Nonce,
72    lifetime: Duration,
73}
74
75// RelayConn is the implementation of the Conn interfaces for UDP Relayed network connections.
76pub struct RelayConn<T: 'static + RelayConnObserver + Send + Sync> {
77    relayed_addr: SocketAddr,
78    read_ch_rx: Arc<Mutex<mpsc::Receiver<InboundData>>>,
79    relay_conn: Arc<Mutex<RelayConnInternal<T>>>,
80    refresh_alloc_timer: PeriodicTimer,
81    refresh_perms_timer: PeriodicTimer,
82}
83
84impl<T: 'static + RelayConnObserver + Send + Sync> RelayConn<T> {
85    // new creates a new instance of UDPConn
86    pub(crate) fn new(obs: Arc<Mutex<T>>, config: RelayConnConfig) -> Self {
87        log::debug!("initial lifetime: {} seconds", config.lifetime.as_secs());
88
89        let mut c = RelayConn {
90            refresh_alloc_timer: PeriodicTimer::new(TimerIdRefresh::Alloc, config.lifetime / 2),
91            refresh_perms_timer: PeriodicTimer::new(TimerIdRefresh::Perms, PERM_REFRESH_INTERVAL),
92            relayed_addr: config.relayed_addr,
93            read_ch_rx: Arc::clone(&config.read_ch_rx),
94            relay_conn: Arc::new(Mutex::new(RelayConnInternal::new(obs, config))),
95        };
96
97        let rci1 = Arc::clone(&c.relay_conn);
98        let rci2 = Arc::clone(&c.relay_conn);
99
100        if c.refresh_alloc_timer.start(rci1) {
101            log::debug!("refresh_alloc_timer started");
102        }
103        if c.refresh_perms_timer.start(rci2) {
104            log::debug!("refresh_perms_timer started");
105        }
106
107        c
108    }
109
110    // Close closes the connection.
111    // Any blocked ReadFrom or write_to operations will be unblocked and return errors.
112    pub async fn close(&mut self) -> Result<(), Error> {
113        self.refresh_alloc_timer.stop();
114        self.refresh_perms_timer.stop();
115
116        let mut relay_conn = self.relay_conn.lock().await;
117        relay_conn.close().await
118    }
119}
120
121#[async_trait]
122impl<T: RelayConnObserver + Send + Sync> Conn for RelayConn<T> {
123    async fn connect(&self, _addr: SocketAddr) -> io::Result<()> {
124        Err(io::Error::new(io::ErrorKind::Other, "Not applicable"))
125    }
126
127    async fn recv(&self, _buf: &mut [u8]) -> io::Result<usize> {
128        Err(io::Error::new(io::ErrorKind::Other, "Not applicable"))
129    }
130
131    // ReadFrom reads a packet from the connection,
132    // copying the payload into p. It returns the number of
133    // bytes copied into p and the return address that
134    // was on the packet.
135    // It returns the number of bytes read (0 <= n <= len(p))
136    // and any error encountered. Callers should always process
137    // the n > 0 bytes returned before considering the error err.
138    // ReadFrom can be made to time out and return
139    // an Error with Timeout() == true after a fixed time limit;
140    // see SetDeadline and SetReadDeadline.
141    async fn recv_from(&self, p: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
142        let mut read_ch_rx = self.read_ch_rx.lock().await;
143
144        if let Some(ib_data) = read_ch_rx.recv().await {
145            let n = ib_data.data.len();
146            if p.len() < n {
147                return Err(io::Error::new(
148                    io::ErrorKind::InvalidInput,
149                    ERR_SHORT_BUFFER.to_string(),
150                ));
151            }
152            p[..n].copy_from_slice(&ib_data.data);
153            Ok((n, ib_data.from))
154        } else {
155            Err(io::Error::new(
156                io::ErrorKind::ConnectionAborted,
157                ERR_ALREADY_CLOSED.to_string(),
158            ))
159        }
160    }
161
162    async fn send(&self, _buf: &[u8]) -> io::Result<usize> {
163        Err(io::Error::new(io::ErrorKind::Other, "Not applicable"))
164    }
165
166    // write_to writes a packet with payload p to addr.
167    // write_to can be made to time out and return
168    // an Error with Timeout() == true after a fixed time limit;
169    // see SetDeadline and SetWriteDeadline.
170    // On packet-oriented connections, write timeouts are rare.
171    async fn send_to(&self, p: &[u8], addr: SocketAddr) -> io::Result<usize> {
172        let mut relay_conn = self.relay_conn.lock().await;
173        match relay_conn.send_to(p, addr).await {
174            Ok(n) => Ok(n),
175            Err(err) => Err(io::Error::new(io::ErrorKind::Other, err.to_string())),
176        }
177    }
178
179    // LocalAddr returns the local network address.
180    async fn local_addr(&self) -> io::Result<SocketAddr> {
181        Ok(self.relayed_addr)
182    }
183}
184
185impl<T: RelayConnObserver + Send + Sync> RelayConnInternal<T> {
186    // new creates a new instance of UDPConn
187    fn new(obs: Arc<Mutex<T>>, config: RelayConnConfig) -> Self {
188        RelayConnInternal {
189            obs,
190            relayed_addr: config.relayed_addr,
191            perm_map: PermissionMap::new(),
192            binding_mgr: config.binding_mgr,
193            integrity: config.integrity,
194            nonce: config.nonce,
195            lifetime: config.lifetime,
196        }
197    }
198
199    // write_to writes a packet with payload p to addr.
200    // write_to can be made to time out and return
201    // an Error with Timeout() == true after a fixed time limit;
202    // see SetDeadline and SetWriteDeadline.
203    // On packet-oriented connections, write timeouts are rare.
204    async fn send_to(&mut self, p: &[u8], addr: SocketAddr) -> Result<usize, Error> {
205        // check if we have a permission for the destination IP addr
206        let mut perm = if let Some(perm) = self.perm_map.find(&addr) {
207            *perm
208        } else {
209            let perm = Permission::default();
210            self.perm_map.insert(&addr, perm);
211            perm
212        };
213
214        let mut result = Ok(());
215        for _ in 0..MAX_RETRY_ATTEMPTS {
216            result = self.create_perm(&mut perm, addr).await;
217            if let Err(err) = &result {
218                if *err != *ERR_TRY_AGAIN {
219                    break;
220                }
221            }
222        }
223        if let Err(err) = result {
224            return Err(err);
225        }
226
227        let number = {
228            let (bind_st, bind_at, bind_number, bind_addr) = {
229                let mut binding_mgr = self.binding_mgr.lock().await;
230                let b = if let Some(b) = binding_mgr.find_by_addr(&addr) {
231                    b
232                } else {
233                    binding_mgr
234                        .create(addr)
235                        .ok_or_else(|| Error::new("Addr not found".to_owned()))?
236                };
237                (b.state(), b.refreshed_at(), b.number, b.addr)
238            };
239
240            if bind_st == BindingState::Idle
241                || bind_st == BindingState::Request
242                || bind_st == BindingState::Failed
243            {
244                // block only callers with the same binding until
245                // the binding transaction has been complete
246                // binding state may have been changed while waiting. check again.
247                if bind_st == BindingState::Idle {
248                    let binding_mgr = Arc::clone(&self.binding_mgr);
249                    let rc_obs = Arc::clone(&self.obs);
250                    let nonce = self.nonce.clone();
251                    let integrity = self.integrity.clone();
252                    tokio::spawn(async move {
253                        {
254                            let mut bm = binding_mgr.lock().await;
255                            if let Some(b) = bm.get_by_addr(&bind_addr) {
256                                b.set_state(BindingState::Request);
257                            }
258                        }
259
260                        let result = RelayConnInternal::bind(
261                            rc_obs,
262                            bind_addr,
263                            bind_number,
264                            nonce,
265                            integrity,
266                        )
267                        .await;
268
269                        {
270                            let mut bm = binding_mgr.lock().await;
271                            if let Err(err) = result {
272                                if err != *ERR_UNEXPECTED_RESPONSE {
273                                    bm.delete_by_addr(&bind_addr);
274                                } else if let Some(b) = bm.get_by_addr(&bind_addr) {
275                                    b.set_state(BindingState::Failed);
276                                }
277
278                                // keep going...
279                                log::warn!("bind() failed: {}", err);
280                            } else if let Some(b) = bm.get_by_addr(&bind_addr) {
281                                b.set_state(BindingState::Ready);
282                            }
283                        }
284                    });
285                }
286
287                // send data using SendIndication
288                let peer_addr = socket_addr2peer_address(&addr);
289                let mut msg = Message::new();
290                msg.build(&[
291                    Box::new(TransactionId::new()),
292                    Box::new(MessageType::new(METHOD_SEND, CLASS_INDICATION)),
293                    Box::new(proto::data::Data(p.to_vec())),
294                    Box::new(peer_addr),
295                    Box::new(FINGERPRINT),
296                ])?;
297
298                // indication has no transaction (fire-and-forget)
299                let obs = self.obs.lock().await;
300                let turn_server_addr = obs.turn_server_addr();
301                return obs.write_to(&msg.raw, &turn_server_addr).await;
302            }
303
304            // binding is either ready
305
306            // check if the binding needs a refresh
307            if bind_st == BindingState::Ready
308                && Instant::now().duration_since(bind_at) > Duration::from_secs(5 * 60)
309            {
310                let binding_mgr = Arc::clone(&self.binding_mgr);
311                let rc_obs = Arc::clone(&self.obs);
312                let nonce = self.nonce.clone();
313                let integrity = self.integrity.clone();
314                tokio::spawn(async move {
315                    {
316                        let mut bm = binding_mgr.lock().await;
317                        if let Some(b) = bm.get_by_addr(&bind_addr) {
318                            b.set_state(BindingState::Refresh);
319                        }
320                    }
321
322                    let result =
323                        RelayConnInternal::bind(rc_obs, bind_addr, bind_number, nonce, integrity)
324                            .await;
325
326                    {
327                        let mut bm = binding_mgr.lock().await;
328                        if let Err(err) = result {
329                            if err != *ERR_UNEXPECTED_RESPONSE {
330                                bm.delete_by_addr(&bind_addr);
331                            } else if let Some(b) = bm.get_by_addr(&bind_addr) {
332                                b.set_state(BindingState::Failed);
333                            }
334
335                            // keep going...
336                            log::warn!("bind() for refresh failed: {}", err);
337                        } else if let Some(b) = bm.get_by_addr(&bind_addr) {
338                            b.set_refreshed_at(Instant::now());
339                            b.set_state(BindingState::Ready);
340                        }
341                    }
342                });
343            }
344
345            bind_number
346        };
347
348        // send via ChannelData
349        self.send_channel_data(p, number).await
350    }
351
352    // This func-block would block, per destination IP (, or perm), until
353    // the perm state becomes "requested". Purpose of this is to guarantee
354    // the order of packets (within the same perm).
355    // Note that CreatePermission transaction may not be complete before
356    // all the data transmission. This is done assuming that the request
357    // will be mostly likely successful and we can tolerate some loss of
358    // UDP packet (or reorder), inorder to minimize the latency in most cases.
359    async fn create_perm(&mut self, perm: &mut Permission, addr: SocketAddr) -> Result<(), Error> {
360        if perm.state() == PermState::Idle {
361            // punch a hole! (this would block a bit..)
362            if let Err(err) = self.create_permissions(&[addr]).await {
363                self.perm_map.delete(&addr);
364                return Err(err);
365            }
366            perm.set_state(PermState::Permitted);
367        }
368        Ok(())
369    }
370
371    async fn send_channel_data(&self, data: &[u8], ch_num: u16) -> Result<usize, Error> {
372        let mut ch_data = proto::chandata::ChannelData {
373            data: data.to_vec(),
374            number: proto::channum::ChannelNumber(ch_num),
375            ..Default::default()
376        };
377        ch_data.encode();
378
379        let obs = self.obs.lock().await;
380        obs.write_to(&ch_data.raw, &obs.turn_server_addr()).await
381    }
382
383    async fn create_permissions(&mut self, addrs: &[SocketAddr]) -> Result<(), Error> {
384        let res = {
385            let msg = {
386                let obs = self.obs.lock().await;
387                let mut setters: Vec<Box<dyn Setter>> = vec![
388                    Box::new(TransactionId::new()),
389                    Box::new(MessageType::new(METHOD_CREATE_PERMISSION, CLASS_REQUEST)),
390                ];
391
392                for addr in addrs {
393                    setters.push(Box::new(socket_addr2peer_address(addr)));
394                }
395
396                setters.push(Box::new(obs.username()));
397                setters.push(Box::new(obs.realm()));
398                setters.push(Box::new(self.nonce.clone()));
399                setters.push(Box::new(self.integrity.clone()));
400                setters.push(Box::new(FINGERPRINT));
401
402                let mut msg = Message::new();
403                msg.build(&setters)?;
404                msg
405            };
406
407            let mut obs = self.obs.lock().await;
408            let turn_server_addr = obs.turn_server_addr();
409
410            log::debug!("UDPConn.createPermissions call PerformTransaction 1");
411            let tr_res = obs
412                .perform_transaction(&msg, &turn_server_addr, false)
413                .await?;
414
415            tr_res.msg
416        };
417
418        if res.typ.class == CLASS_ERROR_RESPONSE {
419            let mut code = ErrorCodeAttribute::default();
420            let result = code.get_from(&res);
421            if result.is_err() {
422                return Err(Error::new(format!("{}", res.typ)));
423            } else if code.code == CODE_STALE_NONCE {
424                self.set_nonce_from_msg(&res);
425                return Err(ERR_TRY_AGAIN.to_owned());
426            } else {
427                return Err(Error::new(format!("{} (error {})", res.typ, code)));
428            }
429        }
430
431        Ok(())
432    }
433
434    pub fn set_nonce_from_msg(&mut self, msg: &Message) {
435        // Update nonce
436        match Nonce::get_from_as(msg, ATTR_NONCE) {
437            Ok(nonce) => {
438                self.nonce = nonce;
439                log::debug!("refresh allocation: 438, got new nonce.");
440            }
441            Err(_) => log::warn!("refresh allocation: 438 but no nonce."),
442        }
443    }
444
445    // Close closes the connection.
446    // Any blocked ReadFrom or write_to operations will be unblocked and return errors.
447    pub async fn close(&mut self) -> Result<(), Error> {
448        self.refresh_allocation(Duration::from_secs(0), true /* dontWait=true */)
449            .await
450    }
451
452    async fn refresh_allocation(
453        &mut self,
454        lifetime: Duration,
455        dont_wait: bool,
456    ) -> Result<(), Error> {
457        let res = {
458            let mut obs = self.obs.lock().await;
459
460            let mut msg = Message::new();
461            msg.build(&[
462                Box::new(TransactionId::new()),
463                Box::new(MessageType::new(METHOD_REFRESH, CLASS_REQUEST)),
464                Box::new(proto::lifetime::Lifetime(lifetime)),
465                Box::new(obs.username()),
466                Box::new(obs.realm()),
467                Box::new(self.nonce.clone()),
468                Box::new(self.integrity.clone()),
469                Box::new(FINGERPRINT),
470            ])?;
471
472            log::debug!("send refresh request (dont_wait={})", dont_wait);
473            let turn_server_addr = obs.turn_server_addr();
474            let tr_res = obs
475                .perform_transaction(&msg, &turn_server_addr, dont_wait)
476                .await?;
477
478            if dont_wait {
479                log::debug!("refresh request sent");
480                return Ok(());
481            }
482
483            log::debug!("refresh request sent, and waiting response");
484
485            tr_res.msg
486        };
487
488        if res.typ.class == CLASS_ERROR_RESPONSE {
489            let mut code = ErrorCodeAttribute::default();
490            let result = code.get_from(&res);
491            if result.is_err() {
492                return Err(Error::new(format!("{}", res.typ)));
493            } else if code.code == CODE_STALE_NONCE {
494                self.set_nonce_from_msg(&res);
495                return Err(ERR_TRY_AGAIN.to_owned());
496            } else {
497                return Ok(());
498            }
499        }
500
501        // Getting lifetime from response
502        let mut updated_lifetime = proto::lifetime::Lifetime::default();
503        updated_lifetime.get_from(&res)?;
504
505        self.lifetime = updated_lifetime.0;
506        log::debug!("updated lifetime: {} seconds", self.lifetime.as_secs());
507        Ok(())
508    }
509
510    async fn refresh_permissions(&mut self) -> Result<(), Error> {
511        let addrs = self.perm_map.addrs();
512        if addrs.is_empty() {
513            log::debug!("no permission to refresh");
514            return Ok(());
515        }
516
517        if let Err(err) = self.create_permissions(&addrs).await {
518            if err != *ERR_TRY_AGAIN {
519                log::error!("fail to refresh permissions: {}", err);
520            }
521            return Err(err);
522        }
523
524        log::debug!("refresh permissions successful");
525        Ok(())
526    }
527
528    async fn bind(
529        rc_obs: Arc<Mutex<T>>,
530        bind_addr: SocketAddr,
531        bind_number: u16,
532        nonce: Nonce,
533        integrity: MessageIntegrity,
534    ) -> Result<(), Error> {
535        let (msg, turn_server_addr) = {
536            let obs = rc_obs.lock().await;
537
538            let setters: Vec<Box<dyn Setter>> = vec![
539                Box::new(TransactionId::new()),
540                Box::new(MessageType::new(METHOD_CHANNEL_BIND, CLASS_REQUEST)),
541                Box::new(socket_addr2peer_address(&bind_addr)),
542                Box::new(proto::channum::ChannelNumber(bind_number)),
543                Box::new(obs.username()),
544                Box::new(obs.realm()),
545                Box::new(nonce),
546                Box::new(integrity),
547                Box::new(FINGERPRINT),
548            ];
549
550            let mut msg = Message::new();
551            msg.build(&setters)?;
552
553            (msg, obs.turn_server_addr())
554        };
555
556        log::debug!("UDPConn.bind call PerformTransaction 1");
557        let tr_res = {
558            let mut obs = rc_obs.lock().await;
559            obs.perform_transaction(&msg, &turn_server_addr, false)
560                .await?
561        };
562
563        let res = tr_res.msg;
564
565        if res.typ != MessageType::new(METHOD_CHANNEL_BIND, CLASS_SUCCESS_RESPONSE) {
566            return Err(ERR_UNEXPECTED_RESPONSE.to_owned());
567        }
568
569        log::debug!("channel binding successful: {} {}", bind_addr, bind_number);
570
571        // Success.
572        Ok(())
573    }
574}
575
576#[async_trait]
577impl<T: RelayConnObserver + Send + Sync> PeriodicTimerTimeoutHandler for RelayConnInternal<T> {
578    async fn on_timeout(&mut self, id: TimerIdRefresh) {
579        log::debug!("refresh timer {:?} expired", id);
580        match id {
581            TimerIdRefresh::Alloc => {
582                let lifetime = self.lifetime;
583                // limit the max retries on errTryAgain to 3
584                // when stale nonce returns, sencond retry should succeed
585                let mut result = Ok(());
586                for _ in 0..MAX_RETRY_ATTEMPTS {
587                    result = self.refresh_allocation(lifetime, false).await;
588                    if let Err(err) = &result {
589                        if *err != *ERR_TRY_AGAIN {
590                            break;
591                        }
592                    }
593                }
594                if result.is_err() {
595                    log::warn!("refresh allocation failed");
596                }
597            }
598            TimerIdRefresh::Perms => {
599                let mut result = Ok(());
600                for _ in 0..MAX_RETRY_ATTEMPTS {
601                    result = self.refresh_permissions().await;
602                    if let Err(err) = &result {
603                        if *err != *ERR_TRY_AGAIN {
604                            break;
605                        }
606                    }
607                }
608                if result.is_err() {
609                    log::warn!("refresh permissions failed");
610                }
611            }
612        }
613    }
614}
615
616fn socket_addr2peer_address(addr: &SocketAddr) -> proto::peeraddr::PeerAddress {
617    proto::peeraddr::PeerAddress {
618        ip: addr.ip(),
619        port: addr.port(),
620    }
621}