1#[cfg(test)]
2mod relay_conn_test;
3
4use super::binding::*;
6use super::periodic_timer::*;
7use super::permission::*;
8use super::transaction::*;
9use crate::proto;
10
11use crate::errors::*;
12
13use stun::agent::*;
14use stun::attributes::*;
15use stun::error_code::*;
16use stun::fingerprint::*;
17use stun::integrity::*;
18use stun::message::*;
19use stun::textattrs::*;
20
21use util::{Conn, Error};
22
23use std::io;
24use std::net::SocketAddr;
25use std::sync::Arc;
26
27use tokio::sync::{mpsc, Mutex};
28use tokio::time::{Duration, Instant};
29
30use async_trait::async_trait;
31
/// How often CreatePermission is re-issued for every peer in the permission map.
const PERM_REFRESH_INTERVAL: Duration = Duration::from_secs(2 * 60);
/// Upper bound on attempts for a request that keeps failing with `ERR_TRY_AGAIN`.
const MAX_RETRY_ATTEMPTS: u16 = 3;
34
/// One datagram received through the relay and queued for `RelayConn::recv_from`.
pub(crate) struct InboundData {
    /// Address reported to the caller of `recv_from` as the sender.
    pub(crate) from: SocketAddr,
    /// Payload bytes, copied verbatim into the caller's buffer.
    pub(crate) data: Vec<u8>,
}
39
40#[async_trait]
42pub trait RelayConnObserver {
43 fn turn_server_addr(&self) -> String;
44 fn username(&self) -> Username;
45 fn realm(&self) -> Realm;
46 async fn write_to(&self, data: &[u8], to: &str) -> Result<usize, Error>;
47 async fn perform_transaction(
48 &mut self,
49 msg: &Message,
50 to: &str,
51 ignore_result: bool,
52 ) -> Result<TransactionResult, Error>;
53}
54
55pub(crate) struct RelayConnConfig {
57 pub(crate) relayed_addr: SocketAddr,
58 pub(crate) integrity: MessageIntegrity,
59 pub(crate) nonce: Nonce,
60 pub(crate) lifetime: Duration,
61 pub(crate) binding_mgr: Arc<Mutex<BindingManager>>,
62 pub(crate) read_ch_rx: Arc<Mutex<mpsc::Receiver<InboundData>>>,
63}
64
65pub struct RelayConnInternal<T: 'static + RelayConnObserver + Send + Sync> {
66 obs: Arc<Mutex<T>>,
67 relayed_addr: SocketAddr,
68 perm_map: PermissionMap,
69 binding_mgr: Arc<Mutex<BindingManager>>,
70 integrity: MessageIntegrity,
71 nonce: Nonce,
72 lifetime: Duration,
73}
74
75pub struct RelayConn<T: 'static + RelayConnObserver + Send + Sync> {
77 relayed_addr: SocketAddr,
78 read_ch_rx: Arc<Mutex<mpsc::Receiver<InboundData>>>,
79 relay_conn: Arc<Mutex<RelayConnInternal<T>>>,
80 refresh_alloc_timer: PeriodicTimer,
81 refresh_perms_timer: PeriodicTimer,
82}
83
84impl<T: 'static + RelayConnObserver + Send + Sync> RelayConn<T> {
85 pub(crate) fn new(obs: Arc<Mutex<T>>, config: RelayConnConfig) -> Self {
87 log::debug!("initial lifetime: {} seconds", config.lifetime.as_secs());
88
89 let mut c = RelayConn {
90 refresh_alloc_timer: PeriodicTimer::new(TimerIdRefresh::Alloc, config.lifetime / 2),
91 refresh_perms_timer: PeriodicTimer::new(TimerIdRefresh::Perms, PERM_REFRESH_INTERVAL),
92 relayed_addr: config.relayed_addr,
93 read_ch_rx: Arc::clone(&config.read_ch_rx),
94 relay_conn: Arc::new(Mutex::new(RelayConnInternal::new(obs, config))),
95 };
96
97 let rci1 = Arc::clone(&c.relay_conn);
98 let rci2 = Arc::clone(&c.relay_conn);
99
100 if c.refresh_alloc_timer.start(rci1) {
101 log::debug!("refresh_alloc_timer started");
102 }
103 if c.refresh_perms_timer.start(rci2) {
104 log::debug!("refresh_perms_timer started");
105 }
106
107 c
108 }
109
110 pub async fn close(&mut self) -> Result<(), Error> {
113 self.refresh_alloc_timer.stop();
114 self.refresh_perms_timer.stop();
115
116 let mut relay_conn = self.relay_conn.lock().await;
117 relay_conn.close().await
118 }
119}
120
121#[async_trait]
122impl<T: RelayConnObserver + Send + Sync> Conn for RelayConn<T> {
123 async fn connect(&self, _addr: SocketAddr) -> io::Result<()> {
124 Err(io::Error::new(io::ErrorKind::Other, "Not applicable"))
125 }
126
127 async fn recv(&self, _buf: &mut [u8]) -> io::Result<usize> {
128 Err(io::Error::new(io::ErrorKind::Other, "Not applicable"))
129 }
130
131 async fn recv_from(&self, p: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
142 let mut read_ch_rx = self.read_ch_rx.lock().await;
143
144 if let Some(ib_data) = read_ch_rx.recv().await {
145 let n = ib_data.data.len();
146 if p.len() < n {
147 return Err(io::Error::new(
148 io::ErrorKind::InvalidInput,
149 ERR_SHORT_BUFFER.to_string(),
150 ));
151 }
152 p[..n].copy_from_slice(&ib_data.data);
153 Ok((n, ib_data.from))
154 } else {
155 Err(io::Error::new(
156 io::ErrorKind::ConnectionAborted,
157 ERR_ALREADY_CLOSED.to_string(),
158 ))
159 }
160 }
161
162 async fn send(&self, _buf: &[u8]) -> io::Result<usize> {
163 Err(io::Error::new(io::ErrorKind::Other, "Not applicable"))
164 }
165
166 async fn send_to(&self, p: &[u8], addr: SocketAddr) -> io::Result<usize> {
172 let mut relay_conn = self.relay_conn.lock().await;
173 match relay_conn.send_to(p, addr).await {
174 Ok(n) => Ok(n),
175 Err(err) => Err(io::Error::new(io::ErrorKind::Other, err.to_string())),
176 }
177 }
178
179 async fn local_addr(&self) -> io::Result<SocketAddr> {
181 Ok(self.relayed_addr)
182 }
183}
184
185impl<T: RelayConnObserver + Send + Sync> RelayConnInternal<T> {
186 fn new(obs: Arc<Mutex<T>>, config: RelayConnConfig) -> Self {
188 RelayConnInternal {
189 obs,
190 relayed_addr: config.relayed_addr,
191 perm_map: PermissionMap::new(),
192 binding_mgr: config.binding_mgr,
193 integrity: config.integrity,
194 nonce: config.nonce,
195 lifetime: config.lifetime,
196 }
197 }
198
199 async fn send_to(&mut self, p: &[u8], addr: SocketAddr) -> Result<usize, Error> {
205 let mut perm = if let Some(perm) = self.perm_map.find(&addr) {
207 *perm
208 } else {
209 let perm = Permission::default();
210 self.perm_map.insert(&addr, perm);
211 perm
212 };
213
214 let mut result = Ok(());
215 for _ in 0..MAX_RETRY_ATTEMPTS {
216 result = self.create_perm(&mut perm, addr).await;
217 if let Err(err) = &result {
218 if *err != *ERR_TRY_AGAIN {
219 break;
220 }
221 }
222 }
223 if let Err(err) = result {
224 return Err(err);
225 }
226
227 let number = {
228 let (bind_st, bind_at, bind_number, bind_addr) = {
229 let mut binding_mgr = self.binding_mgr.lock().await;
230 let b = if let Some(b) = binding_mgr.find_by_addr(&addr) {
231 b
232 } else {
233 binding_mgr
234 .create(addr)
235 .ok_or_else(|| Error::new("Addr not found".to_owned()))?
236 };
237 (b.state(), b.refreshed_at(), b.number, b.addr)
238 };
239
240 if bind_st == BindingState::Idle
241 || bind_st == BindingState::Request
242 || bind_st == BindingState::Failed
243 {
244 if bind_st == BindingState::Idle {
248 let binding_mgr = Arc::clone(&self.binding_mgr);
249 let rc_obs = Arc::clone(&self.obs);
250 let nonce = self.nonce.clone();
251 let integrity = self.integrity.clone();
252 tokio::spawn(async move {
253 {
254 let mut bm = binding_mgr.lock().await;
255 if let Some(b) = bm.get_by_addr(&bind_addr) {
256 b.set_state(BindingState::Request);
257 }
258 }
259
260 let result = RelayConnInternal::bind(
261 rc_obs,
262 bind_addr,
263 bind_number,
264 nonce,
265 integrity,
266 )
267 .await;
268
269 {
270 let mut bm = binding_mgr.lock().await;
271 if let Err(err) = result {
272 if err != *ERR_UNEXPECTED_RESPONSE {
273 bm.delete_by_addr(&bind_addr);
274 } else if let Some(b) = bm.get_by_addr(&bind_addr) {
275 b.set_state(BindingState::Failed);
276 }
277
278 log::warn!("bind() failed: {}", err);
280 } else if let Some(b) = bm.get_by_addr(&bind_addr) {
281 b.set_state(BindingState::Ready);
282 }
283 }
284 });
285 }
286
287 let peer_addr = socket_addr2peer_address(&addr);
289 let mut msg = Message::new();
290 msg.build(&[
291 Box::new(TransactionId::new()),
292 Box::new(MessageType::new(METHOD_SEND, CLASS_INDICATION)),
293 Box::new(proto::data::Data(p.to_vec())),
294 Box::new(peer_addr),
295 Box::new(FINGERPRINT),
296 ])?;
297
298 let obs = self.obs.lock().await;
300 let turn_server_addr = obs.turn_server_addr();
301 return obs.write_to(&msg.raw, &turn_server_addr).await;
302 }
303
304 if bind_st == BindingState::Ready
308 && Instant::now().duration_since(bind_at) > Duration::from_secs(5 * 60)
309 {
310 let binding_mgr = Arc::clone(&self.binding_mgr);
311 let rc_obs = Arc::clone(&self.obs);
312 let nonce = self.nonce.clone();
313 let integrity = self.integrity.clone();
314 tokio::spawn(async move {
315 {
316 let mut bm = binding_mgr.lock().await;
317 if let Some(b) = bm.get_by_addr(&bind_addr) {
318 b.set_state(BindingState::Refresh);
319 }
320 }
321
322 let result =
323 RelayConnInternal::bind(rc_obs, bind_addr, bind_number, nonce, integrity)
324 .await;
325
326 {
327 let mut bm = binding_mgr.lock().await;
328 if let Err(err) = result {
329 if err != *ERR_UNEXPECTED_RESPONSE {
330 bm.delete_by_addr(&bind_addr);
331 } else if let Some(b) = bm.get_by_addr(&bind_addr) {
332 b.set_state(BindingState::Failed);
333 }
334
335 log::warn!("bind() for refresh failed: {}", err);
337 } else if let Some(b) = bm.get_by_addr(&bind_addr) {
338 b.set_refreshed_at(Instant::now());
339 b.set_state(BindingState::Ready);
340 }
341 }
342 });
343 }
344
345 bind_number
346 };
347
348 self.send_channel_data(p, number).await
350 }
351
352 async fn create_perm(&mut self, perm: &mut Permission, addr: SocketAddr) -> Result<(), Error> {
360 if perm.state() == PermState::Idle {
361 if let Err(err) = self.create_permissions(&[addr]).await {
363 self.perm_map.delete(&addr);
364 return Err(err);
365 }
366 perm.set_state(PermState::Permitted);
367 }
368 Ok(())
369 }
370
371 async fn send_channel_data(&self, data: &[u8], ch_num: u16) -> Result<usize, Error> {
372 let mut ch_data = proto::chandata::ChannelData {
373 data: data.to_vec(),
374 number: proto::channum::ChannelNumber(ch_num),
375 ..Default::default()
376 };
377 ch_data.encode();
378
379 let obs = self.obs.lock().await;
380 obs.write_to(&ch_data.raw, &obs.turn_server_addr()).await
381 }
382
383 async fn create_permissions(&mut self, addrs: &[SocketAddr]) -> Result<(), Error> {
384 let res = {
385 let msg = {
386 let obs = self.obs.lock().await;
387 let mut setters: Vec<Box<dyn Setter>> = vec![
388 Box::new(TransactionId::new()),
389 Box::new(MessageType::new(METHOD_CREATE_PERMISSION, CLASS_REQUEST)),
390 ];
391
392 for addr in addrs {
393 setters.push(Box::new(socket_addr2peer_address(addr)));
394 }
395
396 setters.push(Box::new(obs.username()));
397 setters.push(Box::new(obs.realm()));
398 setters.push(Box::new(self.nonce.clone()));
399 setters.push(Box::new(self.integrity.clone()));
400 setters.push(Box::new(FINGERPRINT));
401
402 let mut msg = Message::new();
403 msg.build(&setters)?;
404 msg
405 };
406
407 let mut obs = self.obs.lock().await;
408 let turn_server_addr = obs.turn_server_addr();
409
410 log::debug!("UDPConn.createPermissions call PerformTransaction 1");
411 let tr_res = obs
412 .perform_transaction(&msg, &turn_server_addr, false)
413 .await?;
414
415 tr_res.msg
416 };
417
418 if res.typ.class == CLASS_ERROR_RESPONSE {
419 let mut code = ErrorCodeAttribute::default();
420 let result = code.get_from(&res);
421 if result.is_err() {
422 return Err(Error::new(format!("{}", res.typ)));
423 } else if code.code == CODE_STALE_NONCE {
424 self.set_nonce_from_msg(&res);
425 return Err(ERR_TRY_AGAIN.to_owned());
426 } else {
427 return Err(Error::new(format!("{} (error {})", res.typ, code)));
428 }
429 }
430
431 Ok(())
432 }
433
434 pub fn set_nonce_from_msg(&mut self, msg: &Message) {
435 match Nonce::get_from_as(msg, ATTR_NONCE) {
437 Ok(nonce) => {
438 self.nonce = nonce;
439 log::debug!("refresh allocation: 438, got new nonce.");
440 }
441 Err(_) => log::warn!("refresh allocation: 438 but no nonce."),
442 }
443 }
444
445 pub async fn close(&mut self) -> Result<(), Error> {
448 self.refresh_allocation(Duration::from_secs(0), true )
449 .await
450 }
451
452 async fn refresh_allocation(
453 &mut self,
454 lifetime: Duration,
455 dont_wait: bool,
456 ) -> Result<(), Error> {
457 let res = {
458 let mut obs = self.obs.lock().await;
459
460 let mut msg = Message::new();
461 msg.build(&[
462 Box::new(TransactionId::new()),
463 Box::new(MessageType::new(METHOD_REFRESH, CLASS_REQUEST)),
464 Box::new(proto::lifetime::Lifetime(lifetime)),
465 Box::new(obs.username()),
466 Box::new(obs.realm()),
467 Box::new(self.nonce.clone()),
468 Box::new(self.integrity.clone()),
469 Box::new(FINGERPRINT),
470 ])?;
471
472 log::debug!("send refresh request (dont_wait={})", dont_wait);
473 let turn_server_addr = obs.turn_server_addr();
474 let tr_res = obs
475 .perform_transaction(&msg, &turn_server_addr, dont_wait)
476 .await?;
477
478 if dont_wait {
479 log::debug!("refresh request sent");
480 return Ok(());
481 }
482
483 log::debug!("refresh request sent, and waiting response");
484
485 tr_res.msg
486 };
487
488 if res.typ.class == CLASS_ERROR_RESPONSE {
489 let mut code = ErrorCodeAttribute::default();
490 let result = code.get_from(&res);
491 if result.is_err() {
492 return Err(Error::new(format!("{}", res.typ)));
493 } else if code.code == CODE_STALE_NONCE {
494 self.set_nonce_from_msg(&res);
495 return Err(ERR_TRY_AGAIN.to_owned());
496 } else {
497 return Ok(());
498 }
499 }
500
501 let mut updated_lifetime = proto::lifetime::Lifetime::default();
503 updated_lifetime.get_from(&res)?;
504
505 self.lifetime = updated_lifetime.0;
506 log::debug!("updated lifetime: {} seconds", self.lifetime.as_secs());
507 Ok(())
508 }
509
510 async fn refresh_permissions(&mut self) -> Result<(), Error> {
511 let addrs = self.perm_map.addrs();
512 if addrs.is_empty() {
513 log::debug!("no permission to refresh");
514 return Ok(());
515 }
516
517 if let Err(err) = self.create_permissions(&addrs).await {
518 if err != *ERR_TRY_AGAIN {
519 log::error!("fail to refresh permissions: {}", err);
520 }
521 return Err(err);
522 }
523
524 log::debug!("refresh permissions successful");
525 Ok(())
526 }
527
528 async fn bind(
529 rc_obs: Arc<Mutex<T>>,
530 bind_addr: SocketAddr,
531 bind_number: u16,
532 nonce: Nonce,
533 integrity: MessageIntegrity,
534 ) -> Result<(), Error> {
535 let (msg, turn_server_addr) = {
536 let obs = rc_obs.lock().await;
537
538 let setters: Vec<Box<dyn Setter>> = vec![
539 Box::new(TransactionId::new()),
540 Box::new(MessageType::new(METHOD_CHANNEL_BIND, CLASS_REQUEST)),
541 Box::new(socket_addr2peer_address(&bind_addr)),
542 Box::new(proto::channum::ChannelNumber(bind_number)),
543 Box::new(obs.username()),
544 Box::new(obs.realm()),
545 Box::new(nonce),
546 Box::new(integrity),
547 Box::new(FINGERPRINT),
548 ];
549
550 let mut msg = Message::new();
551 msg.build(&setters)?;
552
553 (msg, obs.turn_server_addr())
554 };
555
556 log::debug!("UDPConn.bind call PerformTransaction 1");
557 let tr_res = {
558 let mut obs = rc_obs.lock().await;
559 obs.perform_transaction(&msg, &turn_server_addr, false)
560 .await?
561 };
562
563 let res = tr_res.msg;
564
565 if res.typ != MessageType::new(METHOD_CHANNEL_BIND, CLASS_SUCCESS_RESPONSE) {
566 return Err(ERR_UNEXPECTED_RESPONSE.to_owned());
567 }
568
569 log::debug!("channel binding successful: {} {}", bind_addr, bind_number);
570
571 Ok(())
573 }
574}
575
576#[async_trait]
577impl<T: RelayConnObserver + Send + Sync> PeriodicTimerTimeoutHandler for RelayConnInternal<T> {
578 async fn on_timeout(&mut self, id: TimerIdRefresh) {
579 log::debug!("refresh timer {:?} expired", id);
580 match id {
581 TimerIdRefresh::Alloc => {
582 let lifetime = self.lifetime;
583 let mut result = Ok(());
586 for _ in 0..MAX_RETRY_ATTEMPTS {
587 result = self.refresh_allocation(lifetime, false).await;
588 if let Err(err) = &result {
589 if *err != *ERR_TRY_AGAIN {
590 break;
591 }
592 }
593 }
594 if result.is_err() {
595 log::warn!("refresh allocation failed");
596 }
597 }
598 TimerIdRefresh::Perms => {
599 let mut result = Ok(());
600 for _ in 0..MAX_RETRY_ATTEMPTS {
601 result = self.refresh_permissions().await;
602 if let Err(err) = &result {
603 if *err != *ERR_TRY_AGAIN {
604 break;
605 }
606 }
607 }
608 if result.is_err() {
609 log::warn!("refresh permissions failed");
610 }
611 }
612 }
613 }
614}
615
616fn socket_addr2peer_address(addr: &SocketAddr) -> proto::peeraddr::PeerAddress {
617 proto::peeraddr::PeerAddress {
618 ip: addr.ip(),
619 port: addr.port(),
620 }
621}