1use core::net::SocketAddr;
21use core::sync::atomic::{AtomicUsize, Ordering};
22
23use alloc::collections::{BTreeMap, BTreeSet};
24use alloc::vec;
25use alloc::vec::Vec;
26use core::time::Duration;
27
28use crate::Instant;
29
30use stun_types::attribute::*;
31use stun_types::data::Data;
32use stun_types::message::*;
33
34use stun_types::TransportType;
35
36use tracing::{debug, trace, warn};
37
/// Global counter from which each [`StunAgent`] built by [`StunAgentBuilder::build`] takes
/// its `id` (incremented once per build).
static STUN_AGENT_COUNT: AtomicUsize = AtomicUsize::new(0);
39
/// Tracks outstanding STUN requests and validated peers for one local address and transport.
///
/// Create one with [`StunAgent::builder`].
#[derive(Debug)]
pub struct StunAgent {
    // Value taken from `STUN_AGENT_COUNT` when the agent was built; recorded in tracing spans.
    id: usize,
    transport: TransportType,
    local_addr: SocketAddr,
    remote_addr: Option<SocketAddr>,
    // Peers that have been passed to `validated_peer`.
    validated_peers: BTreeSet<SocketAddr>,
    // Requests added by `send_request` that have not yet been answered, timed out or been
    // reported as cancelled, keyed by their transaction id.
    outstanding_requests: BTreeMap<TransactionId, StunRequestState>,
}
50
/// Builder for a [`StunAgent`], obtained from [`StunAgent::builder`].
#[derive(Debug)]
pub struct StunAgentBuilder {
    transport: TransportType,
    local_addr: SocketAddr,
    // Optional; `None` unless `remote_addr()` is called on the builder.
    remote_addr: Option<SocketAddr>,
}
58
59impl StunAgentBuilder {
60 pub fn remote_addr(mut self, addr: SocketAddr) -> Self {
62 self.remote_addr = Some(addr);
63 self
64 }
65
66 pub fn build(self) -> StunAgent {
68 let id = STUN_AGENT_COUNT.fetch_add(1, Ordering::SeqCst);
69 StunAgent {
70 id,
71 transport: self.transport,
72 local_addr: self.local_addr,
73 remote_addr: self.remote_addr,
74 validated_peers: Default::default(),
75 outstanding_requests: Default::default(),
76 }
77 }
78}
79
80impl StunAgent {
81 pub fn builder(transport: TransportType, local_addr: SocketAddr) -> StunAgentBuilder {
83 StunAgentBuilder {
84 transport,
85 local_addr,
86 remote_addr: None,
87 }
88 }
89
90 pub fn transport(&self) -> TransportType {
92 self.transport
93 }
94
95 pub fn local_addr(&self) -> SocketAddr {
97 self.local_addr
98 }
99
100 pub fn remote_addr(&self) -> Option<SocketAddr> {
102 self.remote_addr
103 }
104
105 pub fn send_data<T: AsRef<[u8]>>(&self, bytes: T, to: SocketAddr) -> Transmit<T> {
107 send_data(self.transport, bytes, self.local_addr, to)
108 }
109
110 #[tracing::instrument(name = "stun_agent_send",
118 skip(self, msg),
119 fields(
120 transport = %self.transport,
121 from = %self.local_addr,
122 transaction_id,
123 )
124 )]
125 pub fn send<T: AsRef<[u8]>>(
126 &mut self,
127 msg: T,
128 to: SocketAddr,
129 now: Instant,
130 ) -> Result<Transmit<T>, StunError> {
131 let data = msg.as_ref();
132 let hdr = MessageHeader::from_bytes(data)?;
133 tracing::Span::current().record(
134 "transaction_id",
135 tracing::field::display(hdr.transaction_id()),
136 );
137 assert!(!hdr.get_type().has_class(MessageClass::Request));
138 trace!("Sending {} to {to}", hdr.get_type());
139 Ok(Transmit::new(msg, self.transport, self.local_addr, to))
140 }
141
142 #[tracing::instrument(name = "stun_agent_send_request",
150 skip(self, msg),
151 fields(
152 transport = %self.transport,
153 from = %self.local_addr,
154 transaction_id,
155 )
156 )]
157 pub fn send_request<'a, T: AsRef<[u8]>>(
158 &'a mut self,
159 msg: T,
160 to: SocketAddr,
161 now: Instant,
162 ) -> Result<Transmit<Data<'a>>, StunError> {
163 let data = msg.as_ref();
164 let hdr = MessageHeader::from_bytes(data)?;
165 assert!(hdr.get_type().has_class(MessageClass::Request));
166 let transaction_id = hdr.transaction_id();
167 tracing::Span::current().record("transaction_id", tracing::field::display(transaction_id));
168 let state = match self.outstanding_requests.entry(transaction_id) {
169 alloc::collections::btree_map::Entry::Vacant(entry) => {
170 let integrity_algorithm = MessageAttributesIter::new(data)
171 .filter_map(|(_offset, attr)| match attr.get_type() {
172 MessageIntegrity::TYPE => Some(IntegrityAlgorithm::Sha1),
173 MessageIntegritySha256::TYPE => Some(IntegrityAlgorithm::Sha256),
174 _ => None,
175 })
176 .last();
177 trace!("Adding request to {to} with integrity algorithm: {integrity_algorithm:?}");
178 entry.insert(StunRequestState::new(
179 msg,
180 self.transport,
181 self.local_addr,
182 to,
183 transaction_id,
184 integrity_algorithm,
185 ))
186 }
187 alloc::collections::btree_map::Entry::Occupied(_entry) => {
188 return Err(StunError::AlreadyInProgress);
189 }
190 };
191 let Some(transmit) = state.poll_transmit(now) else {
192 unreachable!();
193 };
194 Ok(Transmit::new(
195 Data::from(transmit.data),
196 transmit.transport,
197 transmit.from,
198 transmit.to,
199 ))
200 }
201
202 pub fn is_validated_peer(&self, remote_addr: SocketAddr) -> bool {
209 self.validated_peers.contains(&remote_addr)
210 }
211
212 #[tracing::instrument(
214 name = "stun_validated_peer"
215 skip(self),
216 fields(stun_id = self.id)
217 )]
218 pub fn validated_peer(&mut self, addr: SocketAddr) {
219 if !self.validated_peers.contains(&addr) {
220 debug!("validated peer {:?}", addr);
221 self.validated_peers.insert(addr);
222 }
223 }
224
225 #[tracing::instrument(
234 name = "stun_handle_message"
235 skip(self, msg, from),
236 fields(
237 transaction_id = %msg.transaction_id(),
238 )
239 )]
240 pub fn handle_stun_message(&mut self, msg: &Message<'_>, from: SocketAddr) -> bool {
241 if msg.is_response()
242 && self
243 .take_outstanding_request(&msg.transaction_id())
244 .is_none()
245 {
246 trace!("original request disappeared");
247 return false;
248 }
249 self.validated_peer(from);
250 true
251 }
252
253 #[tracing::instrument(
254 skip(self, transaction_id),
255 fields(transaction_id = %transaction_id)
256 )]
257 fn take_outstanding_request(
258 &mut self,
259 transaction_id: &TransactionId,
260 ) -> Option<StunRequestState> {
261 if let Some(request) = self.outstanding_requests.remove(transaction_id) {
262 trace!("removing request");
263 Some(request)
264 } else {
265 trace!("no outstanding request");
266 None
267 }
268 }
269
270 pub fn request_transaction(&self, transaction_id: TransactionId) -> Option<StunRequest<'_>> {
276 if self.outstanding_requests.contains_key(&transaction_id) {
277 Some(StunRequest {
278 agent: self,
279 transaction_id,
280 })
281 } else {
282 None
283 }
284 }
285
286 pub fn mut_request_transaction(
292 &mut self,
293 transaction_id: TransactionId,
294 ) -> Option<StunRequestMut<'_>> {
295 if self.outstanding_requests.contains_key(&transaction_id) {
296 Some(StunRequestMut {
297 agent: self,
298 transaction_id,
299 })
300 } else {
301 None
302 }
303 }
304
305 fn mut_request_state(
306 &mut self,
307 transaction_id: TransactionId,
308 ) -> Option<&mut StunRequestState> {
309 self.outstanding_requests.get_mut(&transaction_id)
310 }
311
312 fn request_state(&self, transaction_id: TransactionId) -> Option<&StunRequestState> {
313 self.outstanding_requests.get(&transaction_id)
314 }
315
316 #[tracing::instrument(
322 name = "stun_agent_poll"
323 level = "debug",
324 skip(self),
325 )]
326 pub fn poll(&mut self, now: Instant) -> StunAgentPollRet {
327 let mut lowest_wait = now + Duration::from_secs(3600);
328 let mut timeout = None;
329 let mut cancelled = None;
330 for (transaction_id, request) in self.outstanding_requests.iter_mut() {
331 debug_assert_eq!(transaction_id, &request.transaction_id);
332 match request.poll(now) {
333 StunRequestPollRet::Cancelled => {
334 cancelled = Some(*transaction_id);
335 break;
336 }
337 StunRequestPollRet::WaitUntil(wait_until) => {
338 if wait_until < lowest_wait {
339 lowest_wait = wait_until;
340 }
341 }
342 StunRequestPollRet::TimedOut => {
343 timeout = Some(*transaction_id);
344 break;
345 }
346 }
347 }
348 if let Some(transaction) = timeout {
349 if let Some(_state) = self.outstanding_requests.remove(&transaction) {
350 return StunAgentPollRet::TransactionTimedOut(transaction);
351 }
352 }
353 if let Some(transaction) = cancelled {
354 if let Some(_state) = self.outstanding_requests.remove(&transaction) {
355 return StunAgentPollRet::TransactionCancelled(transaction);
356 }
357 }
358 StunAgentPollRet::WaitUntil(lowest_wait)
359 }
360
361 #[tracing::instrument(
363 name = "stun_agent_poll_transmit"
364 level = "debug",
365 skip(self),
366 )]
367 pub fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<&[u8]>> {
368 self.outstanding_requests
369 .values_mut()
370 .filter_map(|request| request.poll_transmit(now))
371 .next()
372 }
373}
374
/// Result of [`StunAgent::poll`].
#[derive(Debug)]
pub enum StunAgentPollRet {
    /// The request with this transaction id timed out and has been removed from the agent.
    TransactionTimedOut(TransactionId),
    /// The request with this transaction id was cancelled and has been removed from the agent.
    TransactionCancelled(TransactionId),
    /// No request finished. This is the earliest instant at which an outstanding request
    /// needs attention, capped at one hour after the polled instant.
    WaitUntil(Instant),
}
385
386fn send_data<T: AsRef<[u8]>>(
387 transport: TransportType,
388 bytes: T,
389 from: SocketAddr,
390 to: SocketAddr,
391) -> Transmit<T> {
392 Transmit::new(bytes, transport, from, to)
393}
394
/// Data to be sent, together with the transport and addresses needed to send it.
#[derive(Debug)]
pub struct Transmit<T: AsRef<[u8]>> {
    /// The bytes to send.
    pub data: T,
    /// The transport to send the bytes over.
    pub transport: TransportType,
    /// The local address to send from.
    pub from: SocketAddr,
    /// The remote address to send to.
    pub to: SocketAddr,
}
407
408impl<T: AsRef<[u8]>> core::fmt::Display for Transmit<T> {
409 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
410 write!(
411 f,
412 "Transmit({}: {} -> {} of {} bytes)",
413 self.transport,
414 self.from,
415 self.to,
416 self.data.as_ref().len()
417 )
418 }
419}
420
421impl<T: AsRef<[u8]>> Transmit<T> {
422 pub fn new(data: T, transport: TransportType, from: SocketAddr, to: SocketAddr) -> Self {
424 Self {
425 data,
426 transport,
427 from,
428 to,
429 }
430 }
431
432 pub fn reinterpret_data<O: AsRef<[u8]>, F: FnOnce(T) -> O>(self, f: F) -> Transmit<O> {
452 Transmit {
453 data: f(self.data),
454 transport: self.transport,
455 from: self.from,
456 to: self.to,
457 }
458 }
459}
460
461impl Transmit<Data<'_>> {
462 pub fn into_owned<'b>(self) -> Transmit<Data<'b>> {
464 self.reinterpret_data(|data| data.into_owned())
465 }
466}
467
/// Result of polling a single outstanding request.
#[derive(Debug)]
enum StunRequestPollRet {
    /// Nothing to do before this instant.
    WaitUntil(Instant),
    /// Returned when the request was cancelled with `StunRequestMut::cancel`, or when its
    /// retransmissions were cancelled and the last scheduled retransmission has passed.
    Cancelled,
    /// The final timeout after the last send has expired.
    TimedOut,
}
478
/// Retransmission and cancellation state of one outstanding request.
#[derive(Debug)]
struct StunRequestState {
    transaction_id: TransactionId,
    // Integrity algorithm found in the request when it was added, if any.
    request_integrity: Option<IntegrityAlgorithm>,
    // Owned copy of the request bytes, re-sent on every retransmission.
    bytes: Vec<u8>,
    transport: TransportType,
    from: SocketAddr,
    to: SocketAddr,
    // Delays in milliseconds between successive sends, indexed by `timeout_i`.
    timeouts_ms: Vec<u64>,
    // Delay after the final send before the request times out.
    last_retransmit_timeout_ms: u64,
    // Set by `StunRequestMut::cancel`; the request is reported as cancelled when polled.
    recv_cancelled: bool,
    // Set by `cancel_retransmissions` and `cancel`; no further data is produced.
    send_cancelled: bool,
    // Index into `timeouts_ms` of the next interval. It advances on every scheduled send
    // after the first, even when sending is cancelled.
    timeout_i: usize,
    // Instant of the most recent scheduled send; `None` until the first send.
    last_send_time: Option<Instant>,
}
494
495impl StunRequestState {
496 fn new<T: AsRef<[u8]>>(
497 request: T,
498 transport: TransportType,
499 from: SocketAddr,
500 to: SocketAddr,
501 transaction_id: TransactionId,
502 integrity_algorithm: Option<IntegrityAlgorithm>,
503 ) -> Self {
504 let data = request.as_ref();
505 let (timeouts_ms, last_retransmit_timeout_ms) = if transport == TransportType::Tcp {
506 (vec![], 39500)
507 } else {
508 (vec![500, 1000, 2000, 4000, 8000, 16000], 8000)
509 };
510 Self {
511 transaction_id,
512 bytes: data.to_vec(),
513 transport,
514 from,
515 to,
516 request_integrity: integrity_algorithm,
517 timeouts_ms,
518 timeout_i: 0,
519 last_retransmit_timeout_ms,
520 recv_cancelled: false,
521 send_cancelled: false,
522 last_send_time: None,
523 }
524 }
525
526 #[tracing::instrument(skip(self, now), level = "trace")]
527 fn next_send_time(&self, now: Instant) -> Option<Instant> {
528 let Some(last_send) = self.last_send_time else {
529 trace!("not sent yet -> send immediately");
530 return Some(now);
531 };
532 if self.timeout_i >= self.timeouts_ms.len() {
533 let next_send = last_send + Duration::from_millis(self.last_retransmit_timeout_ms);
534 trace!("final retransmission, final timeout ends at {next_send:?}");
535 if next_send > now {
536 return Some(next_send);
537 }
538 return None;
539 }
540 let next_send = last_send + Duration::from_millis(self.timeouts_ms[self.timeout_i]);
541 Some(next_send)
542 }
543
544 #[tracing::instrument(
545 name = "stun_request_poll"
546 level = "debug",
547 ret,
548 skip(self, now),
549 fields(transaction_id = %self.transaction_id),
550 )]
551 fn poll(&mut self, now: Instant) -> StunRequestPollRet {
552 if self.recv_cancelled {
553 return StunRequestPollRet::Cancelled;
554 }
555 let Some(next_send) = self.next_send_time(now) else {
557 return StunRequestPollRet::TimedOut;
558 };
559 if next_send >= now {
560 if self.send_cancelled && self.timeout_i >= self.timeouts_ms.len() {
561 return StunRequestPollRet::Cancelled;
563 }
564 return StunRequestPollRet::WaitUntil(next_send);
565 }
566 StunRequestPollRet::WaitUntil(now)
567 }
568
569 #[tracing::instrument(
570 name = "stun_request_poll_transmit",
571 skip(self, now),
572 fields(transaction_id = %self.transaction_id)
573 )]
574 fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<&[u8]>> {
575 if self.recv_cancelled {
576 return None;
577 };
578 let next_send = self.next_send_time(now)?;
579
580 if next_send > now {
581 return None;
582 }
583 if self.last_send_time.is_some() {
584 self.timeout_i += 1;
585 }
586 self.last_send_time = Some(now);
587 if self.send_cancelled {
588 return None;
589 };
590 trace!(
591 "sending {} bytes over {:?} from {:?} to {:?}",
592 self.bytes.len(),
593 self.transport,
594 self.from,
595 self.to
596 );
597 Some(send_data(
598 self.transport,
599 self.bytes.as_slice(),
600 self.from,
601 self.to,
602 ))
603 }
604}
605
/// Read-only handle to an outstanding request, obtained from
/// [`StunAgent::request_transaction`].
#[derive(Debug, Clone)]
pub struct StunRequest<'a> {
    agent: &'a StunAgent,
    transaction_id: TransactionId,
}
612
613impl StunRequest<'_> {
614 pub fn peer_address(&self) -> SocketAddr {
616 let state = self.agent.request_state(self.transaction_id).unwrap();
617 state.to
618 }
619
620 pub fn integrity(&self) -> Option<IntegrityAlgorithm> {
622 let state = self.agent.request_state(self.transaction_id).unwrap();
623 state.request_integrity
624 }
625}
626
/// Mutable handle to an outstanding request, obtained from
/// [`StunAgent::mut_request_transaction`].
#[derive(Debug)]
pub struct StunRequestMut<'a> {
    agent: &'a mut StunAgent,
    transaction_id: TransactionId,
}
633
634impl StunRequestMut<'_> {
635 pub fn peer_address(&self) -> SocketAddr {
637 let state = self.agent.request_state(self.transaction_id).unwrap();
638 state.to
639 }
640
641 pub fn integrity(&self) -> Option<IntegrityAlgorithm> {
643 let state = self.agent.request_state(self.transaction_id).unwrap();
644 state.request_integrity
645 }
646
647 pub fn cancel_retransmissions(&mut self) {
651 if let Some(state) = self.agent.mut_request_state(self.transaction_id) {
652 state.send_cancelled = true;
653 }
654 }
655
656 pub fn cancel(&mut self) {
659 if let Some(state) = self.agent.mut_request_state(self.transaction_id) {
660 state.send_cancelled = true;
661 state.recv_cancelled = true;
662 }
663 }
664
665 pub fn agent(&self) -> &StunAgent {
667 self.agent
668 }
669
670 pub fn mut_agent(&mut self) -> &mut StunAgent {
672 self.agent
673 }
674
675 pub fn configure_timeout(
682 &mut self,
683 initial_rto: Duration,
684 retransmits: u32,
685 last_retransmit_timeout: Duration,
686 ) {
687 if let Some(state) = self.agent.mut_request_state(self.transaction_id) {
688 match state.transport {
689 TransportType::Udp => {
690 state.timeouts_ms = (0..retransmits)
691 .map(|i| (initial_rto * 2u32.pow(i)).as_millis() as u64)
692 .collect::<Vec<_>>();
693 state.last_retransmit_timeout_ms = last_retransmit_timeout.as_millis() as u64;
694 }
695 TransportType::Tcp => {
696 state.timeouts_ms = vec![];
697 state.last_retransmit_timeout_ms = (last_retransmit_timeout
698 + (0..retransmits)
699 .fold(Duration::ZERO, |acc, i| acc + initial_rto * 2u32.pow(i)))
700 .as_millis() as u64;
701 }
702 }
703 }
704 }
705}
706
/// Errors returned by [`StunAgent`] operations.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum StunError {
    /// The operation is already in progress. For example, [`StunAgent::send_request`] was
    /// called with a transaction id that is already outstanding.
    #[error("The operation is already in progress")]
    AlreadyInProgress,
    /// A required resource could not be found.
    #[error("A required resource could not be found")]
    ResourceNotFound,
    /// An operation timed out.
    #[error("An operation timed out")]
    TimedOut,
    /// Unexpected data was received.
    #[error("Unexpected data was received")]
    ProtocolViolation,
    /// The operation was aborted.
    #[error("Operation was aborted")]
    Aborted,
    /// Parsing STUN data failed; displays the wrapped parse error.
    #[error("{}", .0)]
    ParseError(StunParseError),
    /// Writing STUN data failed; displays the wrapped write error.
    #[error("{}", .0)]
    WriteError(StunWriteError),
}
733
734impl From<StunParseError> for StunError {
735 fn from(e: StunParseError) -> Self {
736 StunError::ParseError(e)
737 }
738}
739
740impl From<StunWriteError> for StunError {
741 fn from(e: StunWriteError) -> Self {
742 StunError::WriteError(e)
743 }
744}
745
746#[cfg(test)]
747pub(crate) mod tests {
748 use alloc::string::String;
749 use tracing::error;
750
751 use crate::auth::ShortTermAuth;
752
753 use super::*;
754
755 #[test]
756 fn agent_getters_setters() {
757 let _log = crate::tests::test_init_log();
758 let local_addr = "10.0.0.1:12345".parse().unwrap();
759 let remote_addr = "10.0.0.2:3478".parse().unwrap();
760 let agent = StunAgent::builder(TransportType::Udp, local_addr)
761 .remote_addr(remote_addr)
762 .build();
763
764 assert_eq!(agent.transport(), TransportType::Udp);
765 assert_eq!(agent.local_addr(), local_addr);
766 assert_eq!(agent.remote_addr(), Some(remote_addr));
767 }
768
769 #[test]
770 fn request() {
771 let _log = crate::tests::test_init_log();
772 let local_addr = "127.0.0.1:2000".parse().unwrap();
773 let remote_addr = "127.0.0.1:1000".parse().unwrap();
774 let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
775 .remote_addr(remote_addr)
776 .build();
777 let now = Instant::ZERO;
778
779 let msg = Message::builder_request(BINDING, MessageWriteVec::new());
780 let transaction_id = msg.transaction_id();
781 let transmit = agent
782 .send_request(msg.finish(), remote_addr, now)
783 .unwrap()
784 .into_owned();
785 let request = agent.request_transaction(transaction_id).unwrap();
786 assert!(request.integrity().is_none());
787 assert_eq!(transmit.transport, TransportType::Udp);
788 assert_eq!(transmit.from, local_addr);
789 assert_eq!(transmit.to, remote_addr);
790 let request = Message::from_bytes(&transmit.data).unwrap();
791 let response = Message::builder_error(&request, MessageWriteVec::new());
792 let resp_data = response.finish();
793 let response = Message::from_bytes(&resp_data).unwrap();
794 assert!(agent.handle_stun_message(&response, remote_addr));
795 assert!(agent.request_transaction(transaction_id).is_none());
796 assert!(agent.mut_request_transaction(transaction_id).is_none());
797
798 let ret = agent.poll(now);
799 assert!(matches!(ret, StunAgentPollRet::WaitUntil(_)));
800 }
801
802 #[test]
803 fn indication_with_invalid_response() {
804 let _log = crate::tests::test_init_log();
805 let local_addr = "127.0.0.1:2000".parse().unwrap();
806 let remote_addr = "127.0.0.1:1000".parse().unwrap();
807 let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
808 .remote_addr(remote_addr)
809 .build();
810 let transaction_id = TransactionId::generate();
811 let msg = Message::builder(
812 MessageType::from_class_method(MessageClass::Indication, BINDING),
813 transaction_id,
814 MessageWriteVec::new(),
815 );
816 let transmit = agent
817 .send(msg.finish(), remote_addr, Instant::ZERO)
818 .unwrap();
819 assert_eq!(transmit.transport, TransportType::Udp);
820 assert_eq!(transmit.from, local_addr);
821 assert_eq!(transmit.to, remote_addr);
822 let _indication = Message::from_bytes(&transmit.data).unwrap();
823 assert!(agent.request_transaction(transaction_id).is_none());
824 assert!(agent.mut_request_transaction(transaction_id).is_none());
825 let response = Message::builder(
827 MessageType::from_class_method(MessageClass::Error, BINDING),
828 transaction_id,
829 MessageWriteVec::new(),
830 );
831 let resp_data = response.finish();
832 let response = Message::from_bytes(&resp_data).unwrap();
833 assert!(!agent.handle_stun_message(&response, remote_addr))
835 }
836
837 #[test]
838 fn request_with_credentials() {
839 let _log = crate::tests::test_init_log();
840 let local_addr = "10.0.0.1:12345".parse().unwrap();
841 let remote_addr = "10.0.0.2:3478".parse().unwrap();
842
843 let mut auth = ShortTermAuth::new();
844 let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
845 let credentials = ShortTermCredentials::new(String::from("local_password"));
846 auth.set_credentials(credentials.clone(), IntegrityAlgorithm::Sha1);
847
848 assert!(!agent.is_validated_peer(remote_addr));
850
851 let mut msg = Message::builder_request(BINDING, MessageWriteVec::new());
852 let transaction_id = msg.transaction_id();
853 msg.add_message_integrity(&credentials.clone().into(), IntegrityAlgorithm::Sha1)
854 .unwrap();
855 error!("send");
856 let transmit = agent
857 .send_request(msg.finish(), remote_addr, Instant::ZERO)
858 .unwrap();
859 error!("sent");
860
861 let request = Message::from_bytes(&transmit.data).unwrap();
862
863 error!("generate response");
864 let mut response = Message::builder_success(&request, MessageWriteVec::new());
865 let xor_addr = XorMappedAddress::new(transmit.from, request.transaction_id());
866 response.add_attribute(&xor_addr).unwrap();
867 response
868 .add_message_integrity(&credentials.into(), IntegrityAlgorithm::Sha1)
869 .unwrap();
870 error!("{response:?}");
871
872 let data = response.finish();
873 error!("{data:?}");
874 let response = Message::from_bytes(&data).unwrap();
875 error!("{response}");
876 assert_eq!(
877 auth.validate_incoming_message(&response).unwrap(),
878 Some(IntegrityAlgorithm::Sha1)
879 );
880 let request = agent
881 .request_transaction(response.transaction_id())
882 .unwrap();
883 assert_eq!(request.integrity(), Some(IntegrityAlgorithm::Sha1));
884 assert!(agent.handle_stun_message(&response, remote_addr));
885
886 assert_eq!(response.transaction_id(), transaction_id);
887 assert!(agent.request_transaction(transaction_id).is_none());
888 assert!(agent.mut_request_transaction(transaction_id).is_none());
889 assert!(agent.is_validated_peer(remote_addr));
890 }
891
892 #[test]
893 fn request_unanswered() {
894 let _log = crate::tests::test_init_log();
895 let local_addr = "127.0.0.1:2000".parse().unwrap();
896 let remote_addr = "127.0.0.1:1000".parse().unwrap();
897 let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
898 .remote_addr(remote_addr)
899 .build();
900 let msg = Message::builder_request(BINDING, MessageWriteVec::new());
901 let transaction_id = msg.transaction_id();
902 agent
903 .send_request(msg.finish(), remote_addr, Instant::ZERO)
904 .unwrap();
905 let mut now = Instant::ZERO;
906 loop {
907 let _ = agent.poll_transmit(now);
908 match agent.poll(now) {
909 StunAgentPollRet::WaitUntil(new_now) => {
910 now = new_now;
911 }
912 StunAgentPollRet::TransactionTimedOut(_) => break,
913 _ => unreachable!(),
914 }
915 }
916 assert!(agent.request_transaction(transaction_id).is_none());
917 assert!(agent.mut_request_transaction(transaction_id).is_none());
918
919 assert!(!agent.is_validated_peer(remote_addr));
921 }
922
923 #[test]
924 fn request_custom_timeout() {
925 let _log = crate::tests::test_init_log();
926 let local_addr = "127.0.0.1:2000".parse().unwrap();
927 let remote_addr = "127.0.0.1:1000".parse().unwrap();
928 let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
929 .remote_addr(remote_addr)
930 .build();
931 let msg = Message::builder_request(BINDING, MessageWriteVec::new());
932 let transaction_id = msg.transaction_id();
933 let mut now = Instant::ZERO;
934 agent.send_request(msg.finish(), remote_addr, now).unwrap();
935 let mut transaction = agent.mut_request_transaction(transaction_id).unwrap();
936 transaction.configure_timeout(Duration::from_secs(1), 2, Duration::from_secs(10));
937 let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
938 unreachable!();
939 };
940 assert_eq!(wait - now, Duration::from_secs(1));
941 now = wait;
942 let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
944 unreachable!();
945 };
946 assert_eq!(wait, now);
947 let Some(_) = agent.poll_transmit(now) else {
948 unreachable!();
949 };
950 let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
951 unreachable!();
952 };
953 assert_eq!(wait - now, Duration::from_secs(2));
954 now = wait;
955 let Some(_) = agent.poll_transmit(now) else {
956 unreachable!();
957 };
958 let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
959 unreachable!();
960 };
961 assert_eq!(wait - now, Duration::from_secs(10));
962 now = wait;
963 let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
964 unreachable!();
965 };
966 assert_eq!(timed_out, transaction_id);
967
968 assert!(agent.request_transaction(transaction_id).is_none());
969 assert!(agent.mut_request_transaction(transaction_id).is_none());
970
971 assert!(!agent.is_validated_peer(remote_addr));
973 }
974
975 #[test]
976 fn request_no_retransmit() {
977 let _log = crate::tests::test_init_log();
978 let local_addr = "127.0.0.1:2000".parse().unwrap();
979 let remote_addr = "127.0.0.1:1000".parse().unwrap();
980 let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
981 .remote_addr(remote_addr)
982 .build();
983 let msg = Message::builder_request(BINDING, MessageWriteVec::new());
984 let transaction_id = msg.transaction_id();
985 let mut now = Instant::ZERO;
986 agent.send_request(msg.finish(), remote_addr, now).unwrap();
987 let mut transaction = agent.mut_request_transaction(transaction_id).unwrap();
988 transaction.configure_timeout(Duration::from_secs(1), 0, Duration::from_secs(10));
989 let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
990 unreachable!();
991 };
992 assert_eq!(wait - now, Duration::from_secs(10));
993 now = wait;
994 let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
995 unreachable!();
996 };
997 assert_eq!(timed_out, transaction_id);
998
999 assert!(agent.request_transaction(transaction_id).is_none());
1000 assert!(agent.mut_request_transaction(transaction_id).is_none());
1001
1002 assert!(!agent.is_validated_peer(remote_addr));
1004 }
1005
1006 #[test]
1007 fn request_tcp_custom_timeout() {
1008 let _log = crate::tests::test_init_log();
1009 let local_addr = "127.0.0.1:2000".parse().unwrap();
1010 let remote_addr = "127.0.0.1:1000".parse().unwrap();
1011 let mut agent = StunAgent::builder(TransportType::Tcp, local_addr)
1012 .remote_addr(remote_addr)
1013 .build();
1014 let msg = Message::builder_request(BINDING, MessageWriteVec::new());
1015 let transaction_id = msg.transaction_id();
1016 let mut now = Instant::ZERO;
1017 agent.send_request(msg.finish(), remote_addr, now).unwrap();
1018 let mut transaction = agent.mut_request_transaction(transaction_id).unwrap();
1019 transaction.configure_timeout(Duration::from_secs(1), 3, Duration::from_secs(3));
1020 let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
1021 unreachable!();
1022 };
1023 assert_eq!(wait - now, Duration::from_secs(1 + 2 + 4 + 3));
1024 now = wait;
1025 let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
1026 unreachable!();
1027 };
1028 assert_eq!(timed_out, transaction_id);
1029
1030 assert!(agent.request_transaction(transaction_id).is_none());
1031 assert!(agent.mut_request_transaction(transaction_id).is_none());
1032
1033 assert!(!agent.is_validated_peer(remote_addr));
1035 }
1036
1037 #[test]
1038 fn request_without_credentials() {
1039 let _log = crate::tests::test_init_log();
1040 let local_addr = "10.0.0.1:12345".parse().unwrap();
1041 let remote_addr = "10.0.0.2:3478".parse().unwrap();
1042
1043 let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
1044
1045 assert!(!agent.is_validated_peer(remote_addr));
1047
1048 let msg = Message::builder_request(BINDING, MessageWriteVec::new());
1049 let transaction_id = msg.transaction_id();
1050 let transmit = agent
1051 .send_request(msg.finish(), remote_addr, Instant::ZERO)
1052 .unwrap();
1053
1054 let request = Message::from_bytes(&transmit.data).unwrap();
1055
1056 let mut response = Message::builder_success(&request, MessageWriteVec::new());
1057 let xor_addr = XorMappedAddress::new(transmit.from, request.transaction_id());
1058 response.add_attribute(&xor_addr).unwrap();
1059
1060 let data = response.finish();
1061 let to = transmit.to;
1062 trace!("data: {data:?}");
1063 let response = Message::from_bytes(&data).unwrap();
1064 let request = agent
1065 .request_transaction(response.transaction_id())
1066 .unwrap();
1067 assert_eq!(request.integrity(), None);
1068 assert!(agent.handle_stun_message(&response, to));
1069 assert_eq!(response.transaction_id(), transaction_id);
1070 assert!(agent.request_transaction(transaction_id).is_none());
1071 assert!(agent.mut_request_transaction(transaction_id).is_none());
1072 assert!(agent.is_validated_peer(remote_addr));
1073 }
1074
1075 #[test]
1076 fn response_with_incorrect_credentials() {
1077 let _log = crate::tests::test_init_log();
1078 let local_addr = "10.0.0.1:12345".parse().unwrap();
1079 let remote_addr = "10.0.0.2:3478".parse().unwrap();
1080
1081 let mut auth = ShortTermAuth::new();
1082 let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
1083 let credentials = ShortTermCredentials::new(String::from("local_password"));
1084 let wrong_credentials = ShortTermCredentials::new(String::from("wrong_password"));
1085 auth.set_credentials(credentials.clone(), IntegrityAlgorithm::Sha1);
1086
1087 let mut msg = Message::builder_request(BINDING, MessageWriteVec::new());
1088 msg.add_message_integrity(&credentials.clone().into(), IntegrityAlgorithm::Sha1)
1089 .unwrap();
1090 let transmit = agent
1091 .send_request(msg.finish(), remote_addr, Instant::ZERO)
1092 .unwrap();
1093 let data = transmit.data;
1094
1095 let request = Message::from_bytes(&data).unwrap();
1096
1097 let mut response = Message::builder_success(&request, MessageWriteVec::new());
1098 let xor_addr = XorMappedAddress::new(transmit.from, request.transaction_id());
1099 response.add_attribute(&xor_addr).unwrap();
1100 response
1102 .add_message_integrity(&wrong_credentials.into(), IntegrityAlgorithm::Sha1)
1103 .unwrap();
1104
1105 let data = response.finish();
1106 let response = Message::from_bytes(&data).unwrap();
1107 let request = agent
1109 .request_transaction(response.transaction_id())
1110 .unwrap();
1111 assert_eq!(request.integrity(), Some(IntegrityAlgorithm::Sha1));
1112 assert!(matches!(
1113 auth.validate_incoming_message(&response),
1114 Err(ValidateError::IntegrityFailed)
1115 ));
1116
1117 assert!(!agent.is_validated_peer(remote_addr));
1119
1120 assert!(agent.handle_stun_message(&response, remote_addr));
1122 assert!(!agent.handle_stun_message(&response, remote_addr));
1123 assert!(agent.is_validated_peer(remote_addr));
1124 }
1125
1126 #[test]
1127 fn duplicate_response_ignored() {
1128 let _log = crate::tests::test_init_log();
1129 let local_addr = "10.0.0.1:12345".parse().unwrap();
1130 let remote_addr = "10.0.0.2:3478".parse().unwrap();
1131
1132 let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
1133 assert!(!agent.is_validated_peer(remote_addr));
1134
1135 let msg = Message::builder_request(BINDING, MessageWriteVec::new());
1136 let transmit = agent
1137 .send_request(msg.finish(), remote_addr, Instant::ZERO)
1138 .unwrap();
1139 let data = transmit.data;
1140
1141 let request = Message::from_bytes(&data).unwrap();
1142
1143 let mut response = Message::builder_success(&request, MessageWriteVec::new());
1144 let xor_addr = XorMappedAddress::new(transmit.from, request.transaction_id());
1145 response.add_attribute(&xor_addr).unwrap();
1146
1147 let data = response.finish();
1148 let to = transmit.to;
1149 let response = Message::from_bytes(&data).unwrap();
1150 assert!(agent.handle_stun_message(&response, to));
1151
1152 let response = Message::from_bytes(&data).unwrap();
1153 assert!(!agent.handle_stun_message(&response, to));
1154 }
1155
1156 #[test]
1157 fn request_cancel() {
1158 let _log = crate::tests::test_init_log();
1159 let local_addr = "10.0.0.1:12345".parse().unwrap();
1160 let remote_addr = "10.0.0.2:3478".parse().unwrap();
1161
1162 let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
1163
1164 let msg = Message::builder_request(BINDING, MessageWriteVec::new());
1165 let transaction_id = msg.transaction_id();
1166 let _transmit = agent
1167 .send_request(msg.finish(), remote_addr, Instant::ZERO)
1168 .unwrap();
1169
1170 let mut request = agent.mut_request_transaction(transaction_id).unwrap();
1171 assert_eq!(request.integrity(), None);
1172 assert_eq!(request.agent().local_addr(), local_addr);
1173 assert_eq!(request.mut_agent().local_addr(), local_addr);
1174 assert_eq!(request.peer_address(), remote_addr);
1175 request.cancel();
1176
1177 let ret = agent.poll(Instant::ZERO);
1178 let StunAgentPollRet::TransactionCancelled(_request) = ret else {
1179 unreachable!();
1180 };
1181 assert_eq!(transaction_id, transaction_id);
1182 assert!(agent.request_transaction(transaction_id).is_none());
1183 assert!(agent.mut_request_transaction(transaction_id).is_none());
1184 assert!(!agent.is_validated_peer(remote_addr));
1185 }
1186
1187 #[test]
1188 fn request_cancel_send() {
1189 let _log = crate::tests::test_init_log();
1190 let local_addr = "10.0.0.1:12345".parse().unwrap();
1191 let remote_addr = "10.0.0.2:3478".parse().unwrap();
1192
1193 let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
1194
1195 let msg = Message::builder_request(BINDING, MessageWriteVec::new());
1196 let transaction_id = msg.transaction_id();
1197 let _transmit = agent
1198 .send_request(msg.finish(), remote_addr, Instant::ZERO)
1199 .unwrap();
1200
1201 let mut request = agent.mut_request_transaction(transaction_id).unwrap();
1202 assert_eq!(request.integrity(), None);
1203 assert_eq!(request.agent().local_addr(), local_addr);
1204 assert_eq!(request.mut_agent().local_addr(), local_addr);
1205 assert_eq!(request.peer_address(), remote_addr);
1206 request.cancel_retransmissions();
1207
1208 let mut now = Instant::ZERO;
1209 let start = now;
1210 loop {
1211 match agent.poll(now) {
1212 StunAgentPollRet::WaitUntil(new_now) => {
1213 assert_ne!(new_now, now);
1214 now = new_now;
1215 }
1216 StunAgentPollRet::TransactionCancelled(_) => break,
1217 _ => unreachable!(),
1218 }
1219 let _ = agent.poll_transmit(now);
1220 }
1221 assert!(now - start > Duration::from_secs(20));
1222 assert!(agent.request_transaction(transaction_id).is_none());
1223 assert!(agent.mut_request_transaction(transaction_id).is_none());
1224 assert!(!agent.is_validated_peer(remote_addr));
1225 }
1226
1227 #[test]
1228 fn request_duplicate() {
1229 let _log = crate::tests::test_init_log();
1230 let local_addr = "10.0.0.1:12345".parse().unwrap();
1231 let remote_addr = "10.0.0.2:3478".parse().unwrap();
1232
1233 let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
1234
1235 let msg = Message::builder_request(BINDING, MessageWriteVec::new());
1236 let transaction_id = msg.transaction_id();
1237 let msg = msg.finish();
1238 let transmit = agent
1239 .send_request(msg.clone(), remote_addr, Instant::ZERO)
1240 .unwrap();
1241 let to = transmit.to;
1242 let request = Message::from_bytes(&transmit.data).unwrap();
1243
1244 let mut response = Message::builder_success(&request, MessageWriteVec::new());
1245 let xor_addr = XorMappedAddress::new(transmit.from, transaction_id);
1246 response.add_attribute(&xor_addr).unwrap();
1247
1248 assert!(matches!(
1249 agent.send_request(msg, remote_addr, Instant::ZERO),
1250 Err(StunError::AlreadyInProgress)
1251 ));
1252
1253 let request = agent.request_transaction(transaction_id).unwrap();
1255 assert_eq!(request.peer_address(), remote_addr);
1256
1257 let data = response.finish();
1258 let response = Message::from_bytes(&data).unwrap();
1259 assert!(agent.handle_stun_message(&response, to));
1260
1261 assert!(agent.is_validated_peer(to));
1262 }
1263
1264 #[test]
1265 fn incoming_request() {
1266 let _log = crate::tests::test_init_log();
1267 let local_addr = "10.0.0.1:12345".parse().unwrap();
1268 let remote_addr = "10.0.0.2:3478".parse().unwrap();
1269
1270 let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
1271
1272 let msg = Message::builder_request(BINDING, MessageWriteVec::new());
1273 let data = msg.finish();
1274 let stun = Message::from_bytes(&data).unwrap();
1275 error!("{stun:?}");
1276 assert!(agent.handle_stun_message(&stun, remote_addr));
1277 agent.validated_peer(remote_addr);
1278 assert!(agent.is_validated_peer(remote_addr));
1279 }
1280
1281 #[test]
1282 fn tcp_request() {
1283 let _log = crate::tests::test_init_log();
1284 let local_addr = "127.0.0.1:2000".parse().unwrap();
1285 let remote_addr = "127.0.0.1:1000".parse().unwrap();
1286 let mut agent = StunAgent::builder(TransportType::Tcp, local_addr)
1287 .remote_addr(remote_addr)
1288 .build();
1289
1290 let msg = Message::builder_request(BINDING, MessageWriteVec::new());
1291 let transaction_id = msg.transaction_id();
1292 let transmit = agent
1293 .send_request(msg.finish(), remote_addr, Instant::ZERO)
1294 .unwrap();
1295 assert_eq!(transmit.transport, TransportType::Tcp);
1296 assert_eq!(transmit.from, local_addr);
1297 assert_eq!(transmit.to, remote_addr);
1298
1299 let request = Message::from_bytes(&transmit.data).unwrap();
1300 assert_eq!(request.transaction_id(), transaction_id);
1301 }
1302
1303 #[test]
1304 fn transmit_into_owned() {
1305 let data = [0x10, 0x20];
1306 let transport = TransportType::Udp;
1307 let from = "127.0.0.1:1000".parse().unwrap();
1308 let to = "127.0.0.1:2000".parse().unwrap();
1309 let transmit = Transmit::new(Data::from(data.as_ref()), TransportType::Udp, from, to);
1310 let owned = transmit.into_owned();
1311 assert_eq!(owned.data.as_ref(), data.as_ref());
1312 assert_eq!(owned.transport, transport);
1313 assert_eq!(owned.from, from);
1314 assert_eq!(owned.to, to);
1315 error!("{owned}");
1316 }
1317
1318 #[test]
1319 fn transmit_display() {
1320 let data = [0x10, 0x20];
1321 let from = "127.0.0.1:1000".parse().unwrap();
1322 let to = "127.0.0.1:2000".parse().unwrap();
1323 assert_eq!(
1324 alloc::format!(
1325 "{}",
1326 Transmit::new(Data::from(data.as_ref()), TransportType::Udp, from, to)
1327 ),
1328 String::from("Transmit(UDP: 127.0.0.1:1000 -> 127.0.0.1:2000 of 2 bytes)")
1329 );
1330 }
1331}