Skip to main content

turn/client/
relay_conn.rs

1#[cfg(test)]
2mod relay_conn_test;
3
4// client implements the API for a TURN client
5use std::io;
6use std::net::SocketAddr;
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use stun::agent::*;
11use stun::attributes::*;
12use stun::error_code::*;
13use stun::fingerprint::*;
14use stun::integrity::*;
15use stun::message::*;
16use stun::textattrs::*;
17use tokio::sync::{mpsc, Mutex};
18use tokio::time::{Duration, Instant};
19use util::Conn;
20
21use super::binding::*;
22use super::periodic_timer::*;
23use super::permission::*;
24use super::transaction::*;
25use crate::{proto, Error};
26
/// Period of the timer that re-issues CreatePermission for every tracked peer.
const PERM_REFRESH_INTERVAL: Duration = Duration::from_secs(2 * 60);
/// Upper bound on attempts for an operation that keeps failing with `Error::ErrTryAgain`.
const MAX_RETRY_ATTEMPTS: u16 = 3;
29
/// One inbound payload queued for [`RelayConn`]'s `recv_from`, together with
/// the address it is reported as coming from.
pub(crate) struct InboundData {
    /// Payload bytes handed to the reader.
    pub(crate) data: Vec<u8>,
    /// Address returned to the reader alongside the payload.
    pub(crate) from: SocketAddr,
}
34
35/// `RelayConnObserver` is an interface to [`RelayConn`] observer.
36#[async_trait]
37pub trait RelayConnObserver {
38    fn turn_server_addr(&self) -> String;
39    fn username(&self) -> Username;
40    fn realm(&self) -> Realm;
41    async fn write_to(&self, data: &[u8], to: &str) -> Result<usize, util::Error>;
42    async fn perform_transaction(
43        &mut self,
44        msg: &Message,
45        to: &str,
46        ignore_result: bool,
47    ) -> Result<TransactionResult, Error>;
48}
49
50/// `RelayConnConfig` is a set of configuration params used by [`RelayConn::new()`].
51pub(crate) struct RelayConnConfig {
52    pub(crate) relayed_addr: SocketAddr,
53    pub(crate) integrity: MessageIntegrity,
54    pub(crate) nonce: Nonce,
55    pub(crate) lifetime: Duration,
56    pub(crate) binding_mgr: Arc<Mutex<BindingManager>>,
57    pub(crate) read_ch_rx: Arc<Mutex<mpsc::Receiver<InboundData>>>,
58}
59
60pub struct RelayConnInternal<T: 'static + RelayConnObserver + Send + Sync> {
61    obs: Arc<Mutex<T>>,
62    relayed_addr: SocketAddr,
63    perm_map: PermissionMap,
64    binding_mgr: Arc<Mutex<BindingManager>>,
65    integrity: MessageIntegrity,
66    nonce: Nonce,
67    lifetime: Duration,
68}
69
70/// `RelayConn` is the implementation of the Conn interfaces for UDP Relayed network connections.
71pub struct RelayConn<T: 'static + RelayConnObserver + Send + Sync> {
72    relayed_addr: SocketAddr,
73    read_ch_rx: Arc<Mutex<mpsc::Receiver<InboundData>>>,
74    relay_conn: Arc<Mutex<RelayConnInternal<T>>>,
75    refresh_alloc_timer: PeriodicTimer,
76    refresh_perms_timer: PeriodicTimer,
77}
78
79impl<T: 'static + RelayConnObserver + Send + Sync> RelayConn<T> {
80    /// Creates a new [`RelayConn`].
81    pub(crate) async fn new(obs: Arc<Mutex<T>>, config: RelayConnConfig) -> Self {
82        log::debug!("initial lifetime: {} seconds", config.lifetime.as_secs());
83
84        let c = RelayConn {
85            refresh_alloc_timer: PeriodicTimer::new(TimerIdRefresh::Alloc, config.lifetime / 2),
86            refresh_perms_timer: PeriodicTimer::new(TimerIdRefresh::Perms, PERM_REFRESH_INTERVAL),
87            relayed_addr: config.relayed_addr,
88            read_ch_rx: Arc::clone(&config.read_ch_rx),
89            relay_conn: Arc::new(Mutex::new(RelayConnInternal::new(obs, config))),
90        };
91
92        let rci1 = Arc::clone(&c.relay_conn);
93        let rci2 = Arc::clone(&c.relay_conn);
94
95        if c.refresh_alloc_timer.start(rci1).await {
96            log::debug!("refresh_alloc_timer started");
97        }
98        if c.refresh_perms_timer.start(rci2).await {
99            log::debug!("refresh_perms_timer started");
100        }
101
102        c
103    }
104}
105
106#[async_trait]
107impl<T: RelayConnObserver + Send + Sync> Conn for RelayConn<T> {
108    async fn connect(&self, _addr: SocketAddr) -> Result<(), util::Error> {
109        Err(io::Error::other("Not applicable").into())
110    }
111
112    async fn recv(&self, _buf: &mut [u8]) -> Result<usize, util::Error> {
113        Err(io::Error::other("Not applicable").into())
114    }
115
116    /// Reads a packet from the connection,
117    /// copying the payload into `p`. It returns the number of
118    /// bytes copied into `p` and the return address that
119    /// was on the packet.
120    /// It returns the number of bytes read `(0 <= n <= len(p))`
121    /// and any error encountered. Callers should always process
122    /// the `n > 0` bytes returned before considering the error.
123    /// It can be made to time out and return
124    /// an Error with Timeout() == true after a fixed time limit;
125    /// see SetDeadline and SetReadDeadline.
126    async fn recv_from(&self, p: &mut [u8]) -> Result<(usize, SocketAddr), util::Error> {
127        let mut read_ch_rx = self.read_ch_rx.lock().await;
128
129        if let Some(ib_data) = read_ch_rx.recv().await {
130            let n = ib_data.data.len();
131            if p.len() < n {
132                return Err(io::Error::new(
133                    io::ErrorKind::InvalidInput,
134                    Error::ErrShortBuffer.to_string(),
135                )
136                .into());
137            }
138            p[..n].copy_from_slice(&ib_data.data);
139            Ok((n, ib_data.from))
140        } else {
141            Err(io::Error::new(
142                io::ErrorKind::ConnectionAborted,
143                Error::ErrAlreadyClosed.to_string(),
144            )
145            .into())
146        }
147    }
148
149    async fn send(&self, _buf: &[u8]) -> Result<usize, util::Error> {
150        Err(io::Error::other("Not applicable").into())
151    }
152
153    /// Writes a packet with payload `p` to `addr`.
154    /// It can be made to time out and return
155    /// an Error with Timeout() == true after a fixed time limit;
156    /// see SetDeadline and SetWriteDeadline.
157    /// On packet-oriented connections, write timeouts are rare.
158    async fn send_to(&self, p: &[u8], addr: SocketAddr) -> Result<usize, util::Error> {
159        let mut relay_conn = self.relay_conn.lock().await;
160        match relay_conn.send_to(p, addr).await {
161            Ok(n) => Ok(n),
162            Err(err) => Err(io::Error::other(err.to_string()).into()),
163        }
164    }
165
166    /// Returns the local network address.
167    fn local_addr(&self) -> Result<SocketAddr, util::Error> {
168        Ok(self.relayed_addr)
169    }
170
171    fn remote_addr(&self) -> Option<SocketAddr> {
172        None
173    }
174
175    /// Closes the connection.
176    /// Any blocked [`Self::recv_from()`] or [`Self::send_to()`] operations
177    /// will be unblocked and return errors.
178    async fn close(&self) -> Result<(), util::Error> {
179        self.refresh_alloc_timer.stop().await;
180        self.refresh_perms_timer.stop().await;
181
182        let mut relay_conn = self.relay_conn.lock().await;
183        let _ = relay_conn
184            .close()
185            .await
186            .map_err(|err| util::Error::Other(format!("{err}")));
187        Ok(())
188    }
189
190    fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
191        self
192    }
193}
194
195impl<T: RelayConnObserver + Send + Sync> RelayConnInternal<T> {
196    /// Creates a new [`RelayConnInternal`].
197    fn new(obs: Arc<Mutex<T>>, config: RelayConnConfig) -> Self {
198        RelayConnInternal {
199            obs,
200            relayed_addr: config.relayed_addr,
201            perm_map: PermissionMap::new(),
202            binding_mgr: config.binding_mgr,
203            integrity: config.integrity,
204            nonce: config.nonce,
205            lifetime: config.lifetime,
206        }
207    }
208
209    /// Writes a packet with payload `p` to `addr`.
210    /// It can be made to time out and return
211    /// an Error with Timeout() == true after a fixed time limit;
212    /// see SetDeadline and SetWriteDeadline.
213    /// On packet-oriented connections, write timeouts are rare.
214    async fn send_to(&mut self, p: &[u8], addr: SocketAddr) -> Result<usize, Error> {
215        // check if we have a permission for the destination IP addr
216        let perm = if let Some(perm) = self.perm_map.find(&addr) {
217            Arc::clone(perm)
218        } else {
219            let perm = Arc::new(Permission::default());
220            self.perm_map.insert(&addr, Arc::clone(&perm));
221            perm
222        };
223
224        let mut result = Ok(());
225        for _ in 0..MAX_RETRY_ATTEMPTS {
226            result = self.create_perm(&perm, addr).await;
227            if let Err(err) = &result {
228                if Error::ErrTryAgain != *err {
229                    break;
230                }
231            }
232        }
233        result?;
234
235        let number = {
236            let (bind_st, bind_at, bind_number, bind_addr) = {
237                let mut binding_mgr = self.binding_mgr.lock().await;
238                let b = if let Some(b) = binding_mgr.find_by_addr(&addr) {
239                    b
240                } else {
241                    binding_mgr
242                        .create(addr)
243                        .ok_or_else(|| Error::Other("Addr not found".to_owned()))?
244                };
245                (b.state(), b.refreshed_at(), b.number, b.addr)
246            };
247
248            if bind_st == BindingState::Idle
249                || bind_st == BindingState::Request
250                || bind_st == BindingState::Failed
251            {
252                // block only callers with the same binding until
253                // the binding transaction has been complete
254                // binding state may have been changed while waiting. check again.
255                if bind_st == BindingState::Idle {
256                    let binding_mgr = Arc::clone(&self.binding_mgr);
257                    let rc_obs = Arc::clone(&self.obs);
258                    let nonce = self.nonce.clone();
259                    let integrity = self.integrity.clone();
260                    {
261                        let mut bm = binding_mgr.lock().await;
262                        if let Some(b) = bm.get_by_addr(&bind_addr) {
263                            b.set_state(BindingState::Request);
264                        }
265                    }
266                    tokio::spawn(async move {
267                        let result = RelayConnInternal::bind(
268                            rc_obs,
269                            bind_addr,
270                            bind_number,
271                            nonce,
272                            integrity,
273                        )
274                        .await;
275
276                        {
277                            let mut bm = binding_mgr.lock().await;
278                            if let Err(err) = result {
279                                if Error::ErrUnexpectedResponse != err {
280                                    bm.delete_by_addr(&bind_addr);
281                                } else if let Some(b) = bm.get_by_addr(&bind_addr) {
282                                    b.set_state(BindingState::Failed);
283                                }
284
285                                // keep going...
286                                log::warn!("bind() failed: {err}");
287                            } else if let Some(b) = bm.get_by_addr(&bind_addr) {
288                                b.set_state(BindingState::Ready);
289                            }
290                        }
291                    });
292                }
293
294                // send data using SendIndication
295                let peer_addr = socket_addr2peer_address(&addr);
296                let mut msg = Message::new();
297                msg.build(&[
298                    Box::new(TransactionId::new()),
299                    Box::new(MessageType::new(METHOD_SEND, CLASS_INDICATION)),
300                    Box::new(proto::data::Data(p.to_vec())),
301                    Box::new(peer_addr),
302                    Box::new(FINGERPRINT),
303                ])?;
304
305                // indication has no transaction (fire-and-forget)
306                let obs = self.obs.lock().await;
307                let turn_server_addr = obs.turn_server_addr();
308                return Ok(obs.write_to(&msg.raw, &turn_server_addr).await?);
309            }
310
311            // binding is either ready
312
313            // check if the binding needs a refresh
314            if bind_st == BindingState::Ready
315                && Instant::now()
316                    .checked_duration_since(bind_at)
317                    .unwrap_or_else(|| Duration::from_secs(0))
318                    > Duration::from_secs(5 * 60)
319            {
320                let binding_mgr = Arc::clone(&self.binding_mgr);
321                let rc_obs = Arc::clone(&self.obs);
322                let nonce = self.nonce.clone();
323                let integrity = self.integrity.clone();
324                {
325                    let mut bm = binding_mgr.lock().await;
326                    if let Some(b) = bm.get_by_addr(&bind_addr) {
327                        b.set_state(BindingState::Refresh);
328                    }
329                }
330                tokio::spawn(async move {
331                    let result =
332                        RelayConnInternal::bind(rc_obs, bind_addr, bind_number, nonce, integrity)
333                            .await;
334
335                    {
336                        let mut bm = binding_mgr.lock().await;
337                        if let Err(err) = result {
338                            if Error::ErrUnexpectedResponse != err {
339                                bm.delete_by_addr(&bind_addr);
340                            } else if let Some(b) = bm.get_by_addr(&bind_addr) {
341                                b.set_state(BindingState::Failed);
342                            }
343
344                            // keep going...
345                            log::warn!("bind() for refresh failed: {err}");
346                        } else if let Some(b) = bm.get_by_addr(&bind_addr) {
347                            b.set_refreshed_at(Instant::now());
348                            b.set_state(BindingState::Ready);
349                        }
350                    }
351                });
352            }
353
354            bind_number
355        };
356
357        // send via ChannelData
358        self.send_channel_data(p, number).await
359    }
360
361    /// This func-block would block, per destination IP (, or perm), until
362    /// the perm state becomes "requested". Purpose of this is to guarantee
363    /// the order of packets (within the same perm).
364    /// Note that CreatePermission transaction may not be complete before
365    /// all the data transmission. This is done assuming that the request
366    /// will be mostly likely successful and we can tolerate some loss of
367    /// UDP packet (or reorder), inorder to minimize the latency in most cases.
368    async fn create_perm(&mut self, perm: &Arc<Permission>, addr: SocketAddr) -> Result<(), Error> {
369        if perm.state() == PermState::Idle {
370            // punch a hole! (this would block a bit..)
371            if let Err(err) = self.create_permissions(&[addr]).await {
372                self.perm_map.delete(&addr);
373                return Err(err);
374            }
375            perm.set_state(PermState::Permitted);
376        }
377        Ok(())
378    }
379
380    async fn send_channel_data(&self, data: &[u8], ch_num: u16) -> Result<usize, Error> {
381        let mut ch_data = proto::chandata::ChannelData {
382            data: data.to_vec(),
383            number: proto::channum::ChannelNumber(ch_num),
384            ..Default::default()
385        };
386        ch_data.encode();
387
388        let obs = self.obs.lock().await;
389        Ok(obs.write_to(&ch_data.raw, &obs.turn_server_addr()).await?)
390    }
391
392    async fn create_permissions(&mut self, addrs: &[SocketAddr]) -> Result<(), Error> {
393        let res = {
394            let msg = {
395                let obs = self.obs.lock().await;
396                let mut setters: Vec<Box<dyn Setter>> = vec![
397                    Box::new(TransactionId::new()),
398                    Box::new(MessageType::new(METHOD_CREATE_PERMISSION, CLASS_REQUEST)),
399                ];
400
401                for addr in addrs {
402                    setters.push(Box::new(socket_addr2peer_address(addr)));
403                }
404
405                setters.push(Box::new(obs.username()));
406                setters.push(Box::new(obs.realm()));
407                setters.push(Box::new(self.nonce.clone()));
408                setters.push(Box::new(self.integrity.clone()));
409                setters.push(Box::new(FINGERPRINT));
410
411                let mut msg = Message::new();
412                msg.build(&setters)?;
413                msg
414            };
415
416            let mut obs = self.obs.lock().await;
417            let turn_server_addr = obs.turn_server_addr();
418
419            log::debug!("UDPConn.createPermissions call PerformTransaction 1");
420            let tr_res = obs
421                .perform_transaction(&msg, &turn_server_addr, false)
422                .await?;
423
424            tr_res.msg
425        };
426
427        if res.typ.class == CLASS_ERROR_RESPONSE {
428            let mut code = ErrorCodeAttribute::default();
429            let result = code.get_from(&res);
430            if result.is_err() {
431                return Err(Error::Other(format!("{}", res.typ)));
432            } else if code.code == CODE_STALE_NONCE {
433                self.set_nonce_from_msg(&res);
434                return Err(Error::ErrTryAgain);
435            } else {
436                return Err(Error::Other(format!("{} (error {})", res.typ, code)));
437            }
438        }
439
440        Ok(())
441    }
442
443    pub fn set_nonce_from_msg(&mut self, msg: &Message) {
444        // Update nonce
445        match Nonce::get_from_as(msg, ATTR_NONCE) {
446            Ok(nonce) => {
447                self.nonce = nonce;
448                log::debug!("refresh allocation: 438, got new nonce.");
449            }
450            Err(_) => log::warn!("refresh allocation: 438 but no nonce."),
451        }
452    }
453
454    /// Closes the connection.
455    /// Any blocked `recv_from` or `send_to` operations will be unblocked and return errors.
456    pub async fn close(&mut self) -> Result<(), Error> {
457        self.refresh_allocation(Duration::from_secs(0), true /* dontWait=true */)
458            .await
459    }
460
461    async fn refresh_allocation(
462        &mut self,
463        lifetime: Duration,
464        dont_wait: bool,
465    ) -> Result<(), Error> {
466        let res = {
467            let mut obs = self.obs.lock().await;
468
469            let mut msg = Message::new();
470            msg.build(&[
471                Box::new(TransactionId::new()),
472                Box::new(MessageType::new(METHOD_REFRESH, CLASS_REQUEST)),
473                Box::new(proto::lifetime::Lifetime(lifetime)),
474                Box::new(obs.username()),
475                Box::new(obs.realm()),
476                Box::new(self.nonce.clone()),
477                Box::new(self.integrity.clone()),
478                Box::new(FINGERPRINT),
479            ])?;
480
481            log::debug!("send refresh request (dont_wait={dont_wait})");
482            let turn_server_addr = obs.turn_server_addr();
483            let tr_res = obs
484                .perform_transaction(&msg, &turn_server_addr, dont_wait)
485                .await?;
486
487            if dont_wait {
488                log::debug!("refresh request sent");
489                return Ok(());
490            }
491
492            log::debug!("refresh request sent, and waiting response");
493
494            tr_res.msg
495        };
496
497        if res.typ.class == CLASS_ERROR_RESPONSE {
498            let mut code = ErrorCodeAttribute::default();
499            let result = code.get_from(&res);
500            if result.is_err() {
501                return Err(Error::Other(format!("{}", res.typ)));
502            } else if code.code == CODE_STALE_NONCE {
503                self.set_nonce_from_msg(&res);
504                return Err(Error::ErrTryAgain);
505            } else {
506                return Ok(());
507            }
508        }
509
510        // Getting lifetime from response
511        let mut updated_lifetime = proto::lifetime::Lifetime::default();
512        updated_lifetime.get_from(&res)?;
513
514        self.lifetime = updated_lifetime.0;
515        log::debug!("updated lifetime: {} seconds", self.lifetime.as_secs());
516        Ok(())
517    }
518
519    async fn refresh_permissions(&mut self) -> Result<(), Error> {
520        let addrs = self.perm_map.addrs();
521        if addrs.is_empty() {
522            log::debug!("no permission to refresh");
523            return Ok(());
524        }
525
526        if let Err(err) = self.create_permissions(&addrs).await {
527            if Error::ErrTryAgain != err {
528                log::error!("fail to refresh permissions: {err}");
529            }
530            return Err(err);
531        }
532
533        log::debug!("refresh permissions successful");
534        Ok(())
535    }
536
537    async fn bind(
538        rc_obs: Arc<Mutex<T>>,
539        bind_addr: SocketAddr,
540        bind_number: u16,
541        nonce: Nonce,
542        integrity: MessageIntegrity,
543    ) -> Result<(), Error> {
544        let (msg, turn_server_addr) = {
545            let obs = rc_obs.lock().await;
546
547            let setters: Vec<Box<dyn Setter>> = vec![
548                Box::new(TransactionId::new()),
549                Box::new(MessageType::new(METHOD_CHANNEL_BIND, CLASS_REQUEST)),
550                Box::new(socket_addr2peer_address(&bind_addr)),
551                Box::new(proto::channum::ChannelNumber(bind_number)),
552                Box::new(obs.username()),
553                Box::new(obs.realm()),
554                Box::new(nonce),
555                Box::new(integrity),
556                Box::new(FINGERPRINT),
557            ];
558
559            let mut msg = Message::new();
560            msg.build(&setters)?;
561
562            (msg, obs.turn_server_addr())
563        };
564
565        log::debug!("UDPConn.bind call PerformTransaction 1");
566        let tr_res = {
567            let mut obs = rc_obs.lock().await;
568            obs.perform_transaction(&msg, &turn_server_addr, false)
569                .await?
570        };
571
572        let res = tr_res.msg;
573
574        if res.typ != MessageType::new(METHOD_CHANNEL_BIND, CLASS_SUCCESS_RESPONSE) {
575            return Err(Error::ErrUnexpectedResponse);
576        }
577
578        log::debug!("channel binding successful: {bind_addr} {bind_number}");
579
580        // Success.
581        Ok(())
582    }
583}
584
585#[async_trait]
586impl<T: RelayConnObserver + Send + Sync> PeriodicTimerTimeoutHandler for RelayConnInternal<T> {
587    async fn on_timeout(&mut self, id: TimerIdRefresh) {
588        log::debug!("refresh timer {id:?} expired");
589        match id {
590            TimerIdRefresh::Alloc => {
591                let lifetime = self.lifetime;
592                // limit the max retries on errTryAgain to 3
593                // when stale nonce returns, second retry should succeed
594                let mut result = Ok(());
595                for _ in 0..MAX_RETRY_ATTEMPTS {
596                    result = self.refresh_allocation(lifetime, false).await;
597                    if let Err(err) = &result {
598                        if Error::ErrTryAgain != *err {
599                            break;
600                        }
601                    }
602                }
603                if result.is_err() {
604                    log::warn!("refresh allocation failed");
605                }
606            }
607            TimerIdRefresh::Perms => {
608                let mut result = Ok(());
609                for _ in 0..MAX_RETRY_ATTEMPTS {
610                    result = self.refresh_permissions().await;
611                    if let Err(err) = &result {
612                        if Error::ErrTryAgain != *err {
613                            break;
614                        }
615                    }
616                }
617                if result.is_err() {
618                    log::warn!("refresh permissions failed");
619                }
620            }
621        }
622    }
623}
624
625fn socket_addr2peer_address(addr: &SocketAddr) -> proto::peeraddr::PeerAddress {
626    proto::peeraddr::PeerAddress {
627        ip: addr.ip(),
628        port: addr.port(),
629    }
630}