1#[cfg(test)]
2mod relay_conn_test;
3
4use std::io;
6use std::net::SocketAddr;
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use stun::agent::*;
11use stun::attributes::*;
12use stun::error_code::*;
13use stun::fingerprint::*;
14use stun::integrity::*;
15use stun::message::*;
16use stun::textattrs::*;
17use tokio::sync::{mpsc, Mutex};
18use tokio::time::{Duration, Instant};
19use util::Conn;
20
21use super::binding::*;
22use super::periodic_timer::*;
23use super::permission::*;
24use super::transaction::*;
25use crate::{proto, Error};
26
/// Period of the timer that re-issues CreatePermission for every tracked peer.
const PERM_REFRESH_INTERVAL: Duration = Duration::from_secs(2 * 60);
/// Maximum number of attempts for an operation that keeps failing with `ErrTryAgain`.
const MAX_RETRY_ATTEMPTS: u16 = 3;
29
/// One datagram relayed from a peer, queued for delivery through `recv_from`.
pub(crate) struct InboundData {
    /// Payload bytes exactly as relayed.
    pub(crate) data: Vec<u8>,
    /// Address of the peer that sent the payload.
    pub(crate) from: SocketAddr,
}
34
35#[async_trait]
37pub trait RelayConnObserver {
38 fn turn_server_addr(&self) -> String;
39 fn username(&self) -> Username;
40 fn realm(&self) -> Realm;
41 async fn write_to(&self, data: &[u8], to: &str) -> Result<usize, util::Error>;
42 async fn perform_transaction(
43 &mut self,
44 msg: &Message,
45 to: &str,
46 ignore_result: bool,
47 ) -> Result<TransactionResult, Error>;
48}
49
50pub(crate) struct RelayConnConfig {
52 pub(crate) relayed_addr: SocketAddr,
53 pub(crate) integrity: MessageIntegrity,
54 pub(crate) nonce: Nonce,
55 pub(crate) lifetime: Duration,
56 pub(crate) binding_mgr: Arc<Mutex<BindingManager>>,
57 pub(crate) read_ch_rx: Arc<Mutex<mpsc::Receiver<InboundData>>>,
58}
59
60pub struct RelayConnInternal<T: 'static + RelayConnObserver + Send + Sync> {
61 obs: Arc<Mutex<T>>,
62 relayed_addr: SocketAddr,
63 perm_map: PermissionMap,
64 binding_mgr: Arc<Mutex<BindingManager>>,
65 integrity: MessageIntegrity,
66 nonce: Nonce,
67 lifetime: Duration,
68}
69
70pub struct RelayConn<T: 'static + RelayConnObserver + Send + Sync> {
72 relayed_addr: SocketAddr,
73 read_ch_rx: Arc<Mutex<mpsc::Receiver<InboundData>>>,
74 relay_conn: Arc<Mutex<RelayConnInternal<T>>>,
75 refresh_alloc_timer: PeriodicTimer,
76 refresh_perms_timer: PeriodicTimer,
77}
78
79impl<T: 'static + RelayConnObserver + Send + Sync> RelayConn<T> {
80 pub(crate) async fn new(obs: Arc<Mutex<T>>, config: RelayConnConfig) -> Self {
82 log::debug!("initial lifetime: {} seconds", config.lifetime.as_secs());
83
84 let c = RelayConn {
85 refresh_alloc_timer: PeriodicTimer::new(TimerIdRefresh::Alloc, config.lifetime / 2),
86 refresh_perms_timer: PeriodicTimer::new(TimerIdRefresh::Perms, PERM_REFRESH_INTERVAL),
87 relayed_addr: config.relayed_addr,
88 read_ch_rx: Arc::clone(&config.read_ch_rx),
89 relay_conn: Arc::new(Mutex::new(RelayConnInternal::new(obs, config))),
90 };
91
92 let rci1 = Arc::clone(&c.relay_conn);
93 let rci2 = Arc::clone(&c.relay_conn);
94
95 if c.refresh_alloc_timer.start(rci1).await {
96 log::debug!("refresh_alloc_timer started");
97 }
98 if c.refresh_perms_timer.start(rci2).await {
99 log::debug!("refresh_perms_timer started");
100 }
101
102 c
103 }
104}
105
106#[async_trait]
107impl<T: RelayConnObserver + Send + Sync> Conn for RelayConn<T> {
108 async fn connect(&self, _addr: SocketAddr) -> Result<(), util::Error> {
109 Err(io::Error::other("Not applicable").into())
110 }
111
112 async fn recv(&self, _buf: &mut [u8]) -> Result<usize, util::Error> {
113 Err(io::Error::other("Not applicable").into())
114 }
115
116 async fn recv_from(&self, p: &mut [u8]) -> Result<(usize, SocketAddr), util::Error> {
127 let mut read_ch_rx = self.read_ch_rx.lock().await;
128
129 if let Some(ib_data) = read_ch_rx.recv().await {
130 let n = ib_data.data.len();
131 if p.len() < n {
132 return Err(io::Error::new(
133 io::ErrorKind::InvalidInput,
134 Error::ErrShortBuffer.to_string(),
135 )
136 .into());
137 }
138 p[..n].copy_from_slice(&ib_data.data);
139 Ok((n, ib_data.from))
140 } else {
141 Err(io::Error::new(
142 io::ErrorKind::ConnectionAborted,
143 Error::ErrAlreadyClosed.to_string(),
144 )
145 .into())
146 }
147 }
148
149 async fn send(&self, _buf: &[u8]) -> Result<usize, util::Error> {
150 Err(io::Error::other("Not applicable").into())
151 }
152
153 async fn send_to(&self, p: &[u8], addr: SocketAddr) -> Result<usize, util::Error> {
159 let mut relay_conn = self.relay_conn.lock().await;
160 match relay_conn.send_to(p, addr).await {
161 Ok(n) => Ok(n),
162 Err(err) => Err(io::Error::other(err.to_string()).into()),
163 }
164 }
165
166 fn local_addr(&self) -> Result<SocketAddr, util::Error> {
168 Ok(self.relayed_addr)
169 }
170
171 fn remote_addr(&self) -> Option<SocketAddr> {
172 None
173 }
174
175 async fn close(&self) -> Result<(), util::Error> {
179 self.refresh_alloc_timer.stop().await;
180 self.refresh_perms_timer.stop().await;
181
182 let mut relay_conn = self.relay_conn.lock().await;
183 let _ = relay_conn
184 .close()
185 .await
186 .map_err(|err| util::Error::Other(format!("{err}")));
187 Ok(())
188 }
189
190 fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
191 self
192 }
193}
194
195impl<T: RelayConnObserver + Send + Sync> RelayConnInternal<T> {
196 fn new(obs: Arc<Mutex<T>>, config: RelayConnConfig) -> Self {
198 RelayConnInternal {
199 obs,
200 relayed_addr: config.relayed_addr,
201 perm_map: PermissionMap::new(),
202 binding_mgr: config.binding_mgr,
203 integrity: config.integrity,
204 nonce: config.nonce,
205 lifetime: config.lifetime,
206 }
207 }
208
209 async fn send_to(&mut self, p: &[u8], addr: SocketAddr) -> Result<usize, Error> {
215 let perm = if let Some(perm) = self.perm_map.find(&addr) {
217 Arc::clone(perm)
218 } else {
219 let perm = Arc::new(Permission::default());
220 self.perm_map.insert(&addr, Arc::clone(&perm));
221 perm
222 };
223
224 let mut result = Ok(());
225 for _ in 0..MAX_RETRY_ATTEMPTS {
226 result = self.create_perm(&perm, addr).await;
227 if let Err(err) = &result {
228 if Error::ErrTryAgain != *err {
229 break;
230 }
231 }
232 }
233 result?;
234
235 let number = {
236 let (bind_st, bind_at, bind_number, bind_addr) = {
237 let mut binding_mgr = self.binding_mgr.lock().await;
238 let b = if let Some(b) = binding_mgr.find_by_addr(&addr) {
239 b
240 } else {
241 binding_mgr
242 .create(addr)
243 .ok_or_else(|| Error::Other("Addr not found".to_owned()))?
244 };
245 (b.state(), b.refreshed_at(), b.number, b.addr)
246 };
247
248 if bind_st == BindingState::Idle
249 || bind_st == BindingState::Request
250 || bind_st == BindingState::Failed
251 {
252 if bind_st == BindingState::Idle {
256 let binding_mgr = Arc::clone(&self.binding_mgr);
257 let rc_obs = Arc::clone(&self.obs);
258 let nonce = self.nonce.clone();
259 let integrity = self.integrity.clone();
260 {
261 let mut bm = binding_mgr.lock().await;
262 if let Some(b) = bm.get_by_addr(&bind_addr) {
263 b.set_state(BindingState::Request);
264 }
265 }
266 tokio::spawn(async move {
267 let result = RelayConnInternal::bind(
268 rc_obs,
269 bind_addr,
270 bind_number,
271 nonce,
272 integrity,
273 )
274 .await;
275
276 {
277 let mut bm = binding_mgr.lock().await;
278 if let Err(err) = result {
279 if Error::ErrUnexpectedResponse != err {
280 bm.delete_by_addr(&bind_addr);
281 } else if let Some(b) = bm.get_by_addr(&bind_addr) {
282 b.set_state(BindingState::Failed);
283 }
284
285 log::warn!("bind() failed: {err}");
287 } else if let Some(b) = bm.get_by_addr(&bind_addr) {
288 b.set_state(BindingState::Ready);
289 }
290 }
291 });
292 }
293
294 let peer_addr = socket_addr2peer_address(&addr);
296 let mut msg = Message::new();
297 msg.build(&[
298 Box::new(TransactionId::new()),
299 Box::new(MessageType::new(METHOD_SEND, CLASS_INDICATION)),
300 Box::new(proto::data::Data(p.to_vec())),
301 Box::new(peer_addr),
302 Box::new(FINGERPRINT),
303 ])?;
304
305 let obs = self.obs.lock().await;
307 let turn_server_addr = obs.turn_server_addr();
308 return Ok(obs.write_to(&msg.raw, &turn_server_addr).await?);
309 }
310
311 if bind_st == BindingState::Ready
315 && Instant::now()
316 .checked_duration_since(bind_at)
317 .unwrap_or_else(|| Duration::from_secs(0))
318 > Duration::from_secs(5 * 60)
319 {
320 let binding_mgr = Arc::clone(&self.binding_mgr);
321 let rc_obs = Arc::clone(&self.obs);
322 let nonce = self.nonce.clone();
323 let integrity = self.integrity.clone();
324 {
325 let mut bm = binding_mgr.lock().await;
326 if let Some(b) = bm.get_by_addr(&bind_addr) {
327 b.set_state(BindingState::Refresh);
328 }
329 }
330 tokio::spawn(async move {
331 let result =
332 RelayConnInternal::bind(rc_obs, bind_addr, bind_number, nonce, integrity)
333 .await;
334
335 {
336 let mut bm = binding_mgr.lock().await;
337 if let Err(err) = result {
338 if Error::ErrUnexpectedResponse != err {
339 bm.delete_by_addr(&bind_addr);
340 } else if let Some(b) = bm.get_by_addr(&bind_addr) {
341 b.set_state(BindingState::Failed);
342 }
343
344 log::warn!("bind() for refresh failed: {err}");
346 } else if let Some(b) = bm.get_by_addr(&bind_addr) {
347 b.set_refreshed_at(Instant::now());
348 b.set_state(BindingState::Ready);
349 }
350 }
351 });
352 }
353
354 bind_number
355 };
356
357 self.send_channel_data(p, number).await
359 }
360
361 async fn create_perm(&mut self, perm: &Arc<Permission>, addr: SocketAddr) -> Result<(), Error> {
369 if perm.state() == PermState::Idle {
370 if let Err(err) = self.create_permissions(&[addr]).await {
372 self.perm_map.delete(&addr);
373 return Err(err);
374 }
375 perm.set_state(PermState::Permitted);
376 }
377 Ok(())
378 }
379
380 async fn send_channel_data(&self, data: &[u8], ch_num: u16) -> Result<usize, Error> {
381 let mut ch_data = proto::chandata::ChannelData {
382 data: data.to_vec(),
383 number: proto::channum::ChannelNumber(ch_num),
384 ..Default::default()
385 };
386 ch_data.encode();
387
388 let obs = self.obs.lock().await;
389 Ok(obs.write_to(&ch_data.raw, &obs.turn_server_addr()).await?)
390 }
391
392 async fn create_permissions(&mut self, addrs: &[SocketAddr]) -> Result<(), Error> {
393 let res = {
394 let msg = {
395 let obs = self.obs.lock().await;
396 let mut setters: Vec<Box<dyn Setter>> = vec![
397 Box::new(TransactionId::new()),
398 Box::new(MessageType::new(METHOD_CREATE_PERMISSION, CLASS_REQUEST)),
399 ];
400
401 for addr in addrs {
402 setters.push(Box::new(socket_addr2peer_address(addr)));
403 }
404
405 setters.push(Box::new(obs.username()));
406 setters.push(Box::new(obs.realm()));
407 setters.push(Box::new(self.nonce.clone()));
408 setters.push(Box::new(self.integrity.clone()));
409 setters.push(Box::new(FINGERPRINT));
410
411 let mut msg = Message::new();
412 msg.build(&setters)?;
413 msg
414 };
415
416 let mut obs = self.obs.lock().await;
417 let turn_server_addr = obs.turn_server_addr();
418
419 log::debug!("UDPConn.createPermissions call PerformTransaction 1");
420 let tr_res = obs
421 .perform_transaction(&msg, &turn_server_addr, false)
422 .await?;
423
424 tr_res.msg
425 };
426
427 if res.typ.class == CLASS_ERROR_RESPONSE {
428 let mut code = ErrorCodeAttribute::default();
429 let result = code.get_from(&res);
430 if result.is_err() {
431 return Err(Error::Other(format!("{}", res.typ)));
432 } else if code.code == CODE_STALE_NONCE {
433 self.set_nonce_from_msg(&res);
434 return Err(Error::ErrTryAgain);
435 } else {
436 return Err(Error::Other(format!("{} (error {})", res.typ, code)));
437 }
438 }
439
440 Ok(())
441 }
442
443 pub fn set_nonce_from_msg(&mut self, msg: &Message) {
444 match Nonce::get_from_as(msg, ATTR_NONCE) {
446 Ok(nonce) => {
447 self.nonce = nonce;
448 log::debug!("refresh allocation: 438, got new nonce.");
449 }
450 Err(_) => log::warn!("refresh allocation: 438 but no nonce."),
451 }
452 }
453
454 pub async fn close(&mut self) -> Result<(), Error> {
457 self.refresh_allocation(Duration::from_secs(0), true )
458 .await
459 }
460
461 async fn refresh_allocation(
462 &mut self,
463 lifetime: Duration,
464 dont_wait: bool,
465 ) -> Result<(), Error> {
466 let res = {
467 let mut obs = self.obs.lock().await;
468
469 let mut msg = Message::new();
470 msg.build(&[
471 Box::new(TransactionId::new()),
472 Box::new(MessageType::new(METHOD_REFRESH, CLASS_REQUEST)),
473 Box::new(proto::lifetime::Lifetime(lifetime)),
474 Box::new(obs.username()),
475 Box::new(obs.realm()),
476 Box::new(self.nonce.clone()),
477 Box::new(self.integrity.clone()),
478 Box::new(FINGERPRINT),
479 ])?;
480
481 log::debug!("send refresh request (dont_wait={dont_wait})");
482 let turn_server_addr = obs.turn_server_addr();
483 let tr_res = obs
484 .perform_transaction(&msg, &turn_server_addr, dont_wait)
485 .await?;
486
487 if dont_wait {
488 log::debug!("refresh request sent");
489 return Ok(());
490 }
491
492 log::debug!("refresh request sent, and waiting response");
493
494 tr_res.msg
495 };
496
497 if res.typ.class == CLASS_ERROR_RESPONSE {
498 let mut code = ErrorCodeAttribute::default();
499 let result = code.get_from(&res);
500 if result.is_err() {
501 return Err(Error::Other(format!("{}", res.typ)));
502 } else if code.code == CODE_STALE_NONCE {
503 self.set_nonce_from_msg(&res);
504 return Err(Error::ErrTryAgain);
505 } else {
506 return Ok(());
507 }
508 }
509
510 let mut updated_lifetime = proto::lifetime::Lifetime::default();
512 updated_lifetime.get_from(&res)?;
513
514 self.lifetime = updated_lifetime.0;
515 log::debug!("updated lifetime: {} seconds", self.lifetime.as_secs());
516 Ok(())
517 }
518
519 async fn refresh_permissions(&mut self) -> Result<(), Error> {
520 let addrs = self.perm_map.addrs();
521 if addrs.is_empty() {
522 log::debug!("no permission to refresh");
523 return Ok(());
524 }
525
526 if let Err(err) = self.create_permissions(&addrs).await {
527 if Error::ErrTryAgain != err {
528 log::error!("fail to refresh permissions: {err}");
529 }
530 return Err(err);
531 }
532
533 log::debug!("refresh permissions successful");
534 Ok(())
535 }
536
537 async fn bind(
538 rc_obs: Arc<Mutex<T>>,
539 bind_addr: SocketAddr,
540 bind_number: u16,
541 nonce: Nonce,
542 integrity: MessageIntegrity,
543 ) -> Result<(), Error> {
544 let (msg, turn_server_addr) = {
545 let obs = rc_obs.lock().await;
546
547 let setters: Vec<Box<dyn Setter>> = vec![
548 Box::new(TransactionId::new()),
549 Box::new(MessageType::new(METHOD_CHANNEL_BIND, CLASS_REQUEST)),
550 Box::new(socket_addr2peer_address(&bind_addr)),
551 Box::new(proto::channum::ChannelNumber(bind_number)),
552 Box::new(obs.username()),
553 Box::new(obs.realm()),
554 Box::new(nonce),
555 Box::new(integrity),
556 Box::new(FINGERPRINT),
557 ];
558
559 let mut msg = Message::new();
560 msg.build(&setters)?;
561
562 (msg, obs.turn_server_addr())
563 };
564
565 log::debug!("UDPConn.bind call PerformTransaction 1");
566 let tr_res = {
567 let mut obs = rc_obs.lock().await;
568 obs.perform_transaction(&msg, &turn_server_addr, false)
569 .await?
570 };
571
572 let res = tr_res.msg;
573
574 if res.typ != MessageType::new(METHOD_CHANNEL_BIND, CLASS_SUCCESS_RESPONSE) {
575 return Err(Error::ErrUnexpectedResponse);
576 }
577
578 log::debug!("channel binding successful: {bind_addr} {bind_number}");
579
580 Ok(())
582 }
583}
584
585#[async_trait]
586impl<T: RelayConnObserver + Send + Sync> PeriodicTimerTimeoutHandler for RelayConnInternal<T> {
587 async fn on_timeout(&mut self, id: TimerIdRefresh) {
588 log::debug!("refresh timer {id:?} expired");
589 match id {
590 TimerIdRefresh::Alloc => {
591 let lifetime = self.lifetime;
592 let mut result = Ok(());
595 for _ in 0..MAX_RETRY_ATTEMPTS {
596 result = self.refresh_allocation(lifetime, false).await;
597 if let Err(err) = &result {
598 if Error::ErrTryAgain != *err {
599 break;
600 }
601 }
602 }
603 if result.is_err() {
604 log::warn!("refresh allocation failed");
605 }
606 }
607 TimerIdRefresh::Perms => {
608 let mut result = Ok(());
609 for _ in 0..MAX_RETRY_ATTEMPTS {
610 result = self.refresh_permissions().await;
611 if let Err(err) = &result {
612 if Error::ErrTryAgain != *err {
613 break;
614 }
615 }
616 }
617 if result.is_err() {
618 log::warn!("refresh permissions failed");
619 }
620 }
621 }
622 }
623}
624
625fn socket_addr2peer_address(addr: &SocketAddr) -> proto::peeraddr::PeerAddress {
626 proto::peeraddr::PeerAddress {
627 ip: addr.ip(),
628 port: addr.port(),
629 }
630}