turn_client_proto/
client.rs

1// Copyright (C) 2025 Matthew Waters <matthew@centricular.com>
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use std::collections::VecDeque;
10use std::net::{IpAddr, SocketAddr};
11use std::ops::Range;
12use std::time::{Duration, Instant};
13
14use byteorder::{BigEndian, ByteOrder};
15use stun_proto::agent::{
16    DelayedTransmitBuild, HandleStunReply, StunAgent, StunAgentPollRet, StunError, Transmit,
17    TransmitBuild,
18};
19use stun_proto::types::attribute::{ErrorCode, Nonce, Realm, Username};
20use stun_proto::types::data::Data;
21use stun_proto::types::message::{
22    LongTermCredentials, Message, MessageClass, MessageHeader, MessageIntegrityCredentials,
23    MessageType, TransactionId,
24};
25use stun_proto::types::prelude::AttributeExt;
26use turn_types::channel::ChannelData;
27
28use stun_proto::types::TransportType;
29
30use turn_types::attribute::Data as AData;
31use turn_types::attribute::{
32    ChannelNumber, DontFragment, Lifetime, RequestedTransport, XorPeerAddress, XorRelayedAddress,
33};
34use turn_types::message::*;
35use turn_types::TurnCredentials;
36
37use tracing::{error, info, trace, warn};
38
39/// A set of events that can occur within a TURN client's connection to a TURN server.
40#[derive(Debug)]
41pub enum TurnEvent {
42    /// An allocation was created on the server for the client.  The allocation as the associated
43    /// transport and address.
44    AllocationCreated(TransportType, SocketAddr),
45    /// Allocation failed to be created.
46    AllocationCreateFailed,
47    /// A permission was created for the provided transport and IP address.
48    PermissionCreated(TransportType, IpAddr),
49    /// A permission could not be installed for the provided transport and IP address.
50    PermissionCreateFailed(TransportType, IpAddr),
51}
52
53#[derive(Debug)]
54struct Channel {
55    id: u16,
56    peer_addr: SocketAddr,
57    expires_at: Instant,
58    pending_refresh: Option<TransactionId>,
59}
60
61#[derive(Debug)]
62struct Permission {
63    expired: bool,
64    expires_at: Instant,
65    ip: IpAddr,
66    pending_refresh: Option<TransactionId>,
67}
68
69#[derive(Debug)]
70struct Allocation {
71    relayed_address: SocketAddr,
72    transport: TransportType,
73    expired: bool,
74    lifetime: Duration,
75    expires_at: Instant,
76    permissions: Vec<Permission>,
77    channels: Vec<Channel>,
78
79    pending_permissions: VecDeque<(Permission, TransactionId)>,
80    pending_channels: VecDeque<(Channel, TransactionId)>,
81    pending_refresh: Option<(TransactionId, u32)>,
82
83    expired_channels: Vec<Channel>,
84}
85
86/// A TURN client.
87#[derive(Debug)]
88pub struct TurnClient {
89    stun_agent: StunAgent,
90    credentials: TurnCredentials,
91    state: AuthState,
92    allocations: Vec<Allocation>,
93
94    tcp_buffer: Vec<u8>,
95    pending_transmits: VecDeque<Transmit<Data<'static>>>,
96
97    pending_events: VecDeque<TurnEvent>,
98}
99
100#[derive(Debug)]
101enum AuthState {
102    Initial,
103    InitialSent(TransactionId),
104    Authenticating {
105        credentials: LongTermCredentials,
106        nonce: String,
107        transaction_id: TransactionId,
108    },
109    Authenticated {
110        credentials: LongTermCredentials,
111        nonce: String,
112    },
113    Error,
114}
115
/// Result of calling [`poll`](TurnClient::poll).
#[derive(Debug)]
pub enum TurnPollRet {
    /// Call `poll` again no later than the contained time.  Other activity (such as received
    /// data) may change this deadline, so `poll` should be called again after such activity.
    WaitUntil(Instant),
    /// The connection has closed; no further progress is possible.
    Closed,
}
125
126/// Return value from call [recv](TurnClient::recv).
127#[derive(Debug)]
128pub enum TurnRecvRet<T: AsRef<[u8]> + std::fmt::Debug> {
129    /// The data has been handled internally and should not be forwarded any further.
130    Handled,
131    /// The data is not directed at this [TurnClient].
132    Ignored(Transmit<T>),
133    /// Data has been received from a peer of the TURN server.
134    // TODO: try to return existing data without a copy
135    PeerData {
136        /// The data received.
137        data: Vec<u8>,
138        /// The transport the data was received over.
139        transport: TransportType,
140        /// The address of the peer that sent the data.
141        peer: SocketAddr,
142    },
143}
144
145#[derive(Debug)]
146enum InternalHandleStunReply {
147    Handled,
148    Ignored,
149    PeerData {
150        data: Vec<u8>,
151        transport: TransportType,
152        peer: SocketAddr,
153    },
154}
155
/// Errors that can occur when asking for a permission for a peer address.
#[derive(Debug)]
pub enum CreatePermissionError {
    /// A permission for this peer already exists and will not be created again.
    AlreadyExists,
    /// There is no allocation on the TURN server to which the permission could be added.
    NoAllocation,
}

impl std::fmt::Display for CreatePermissionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Displayed as the bare variant name, identical to the derived `Debug` output.
        let text = match self {
            Self::AlreadyExists => "AlreadyExists",
            Self::NoAllocation => "NoAllocation",
        };
        f.write_str(text)
    }
}
170
/// Errors that can occur when asking for a channel to be bound.
#[derive(Debug)]
pub enum BindChannelError {
    /// The channel already exists and will not be created again.
    AlreadyExists,
    /// There is no allocation on the TURN server that could host this channel.
    NoAllocation,
}

impl std::fmt::Display for BindChannelError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Displayed as the bare variant name, identical to the derived `Debug` output.
        let text = match self {
            Self::AlreadyExists => "AlreadyExists",
            Self::NoAllocation => "NoAllocation",
        };
        f.write_str(text)
    }
}
185
186impl TurnClient {
187    /// Allocate an address on a TURN server to relay data to and from peers.
188    ///
189    /// # Examples
190    /// ```
191    /// # use turn_types::TurnCredentials;
192    /// # use turn_client_proto::TurnClient;
193    /// # use stun_proto::types::TransportType;
194    /// let credentials = TurnCredentials::new("tuser", "tpass");
195    /// let transport = TransportType::Udp;
196    /// let local_addr = "192.168.0.1:4000".parse().unwrap();
197    /// let remote_addr = "10.0.0.1:3478".parse().unwrap();
198    /// let client = TurnClient::allocate(transport, local_addr, remote_addr, credentials);
199    /// assert_eq!(client.transport(), transport);
200    /// assert_eq!(client.local_addr(), local_addr);
201    /// assert_eq!(client.remote_addr(), remote_addr);
202    /// ```
203    #[tracing::instrument(
204        name = "turn_client_allocate"
205        skip(credentials)
206    )]
207    pub fn allocate(
208        ttype: TransportType,
209        local_addr: SocketAddr,
210        remote_addr: SocketAddr,
211        credentials: TurnCredentials,
212    ) -> Self {
213        turn_types::debug_init();
214        let stun_agent = StunAgent::builder(ttype, local_addr)
215            .remote_addr(remote_addr)
216            .build();
217
218        Self {
219            stun_agent,
220            credentials,
221            state: AuthState::Initial,
222            allocations: vec![],
223            pending_transmits: VecDeque::default(),
224            tcp_buffer: vec![],
225            pending_events: VecDeque::default(),
226        }
227    }
228
229    /// The transport of the connection to the TURN server.
230    pub fn transport(&self) -> TransportType {
231        self.stun_agent.transport()
232    }
233
234    /// The local address of this TURN client.
235    pub fn local_addr(&self) -> SocketAddr {
236        self.stun_agent.local_addr()
237    }
238
239    /// The remote TURN server's address.
240    pub fn remote_addr(&self) -> SocketAddr {
241        self.stun_agent.remote_addr().unwrap()
242    }
243
244    /// Poll the client for further progress.
245    #[tracing::instrument(name = "turn_client_poll", ret, skip(self))]
246    pub fn poll(&mut self, now: Instant) -> TurnPollRet {
247        trace!("polling at {now:?}");
248        if !self.pending_events.is_empty() || !self.pending_transmits.is_empty() {
249            return TurnPollRet::WaitUntil(now);
250        }
251        let mut earliest_wait = now + Duration::from_secs(9999);
252        let cancelled_transaction = match self.stun_agent.poll(now) {
253            StunAgentPollRet::WaitUntil(wait) => {
254                earliest_wait = earliest_wait.min(wait);
255                None
256            }
257            StunAgentPollRet::TransactionTimedOut(transaction)
258            | StunAgentPollRet::TransactionCancelled(transaction) => Some(transaction),
259        };
260        if let Some(transaction) = cancelled_transaction {
261            trace!("STUN transaction {transaction} was cancelled/timed out");
262        }
263        match &mut self.state {
264            AuthState::Error => return TurnPollRet::Closed,
265            AuthState::Initial => {
266                return TurnPollRet::WaitUntil(now);
267            }
268            AuthState::InitialSent(transaction_id) => {
269                if cancelled_transaction.is_some_and(|cancelled| &cancelled == transaction_id) {
270                    self.state = AuthState::Error;
271                }
272                return TurnPollRet::WaitUntil(earliest_wait);
273            }
274            AuthState::Authenticating {
275                credentials: _,
276                nonce: _,
277                transaction_id,
278            } => {
279                if cancelled_transaction.is_some_and(|cancelled| &cancelled == transaction_id) {
280                    self.state = AuthState::Error;
281                }
282                return TurnPollRet::WaitUntil(earliest_wait);
283            }
284            AuthState::Authenticated { credentials, nonce } => {
285                for alloc in self.allocations.iter_mut() {
286                    let mut expires_at = alloc.expires_at
287                        - if alloc.pending_refresh.is_none() {
288                            if alloc.lifetime > Duration::from_secs(120) {
289                                Duration::from_secs(60)
290                            } else {
291                                alloc.lifetime / 2
292                            }
293                        } else {
294                            Duration::ZERO
295                        };
296                    if alloc.pending_refresh.is_none() && expires_at <= now {
297                        let mut refresh = Message::builder_request(REFRESH);
298                        let transaction_id = refresh.transaction_id();
299                        let lifetime = Lifetime::new(1800);
300                        refresh.add_attribute(&lifetime).unwrap();
301                        let username = Username::new(credentials.username()).unwrap();
302                        refresh.add_attribute(&username).unwrap();
303                        let realm = Realm::new(credentials.realm()).unwrap();
304                        refresh.add_attribute(&realm).unwrap();
305                        let nonce = Nonce::new(nonce).unwrap();
306                        refresh.add_attribute(&nonce).unwrap();
307                        refresh
308                            .add_message_integrity(
309                                &MessageIntegrityCredentials::LongTerm(credentials.clone()),
310                                stun_proto::types::message::IntegrityAlgorithm::Sha1,
311                            )
312                            .unwrap();
313                        let remote_addr = self.stun_agent.remote_addr().unwrap();
314                        let transmit = self
315                            .stun_agent
316                            .send_request(refresh, remote_addr, now)
317                            .unwrap();
318                        alloc.pending_refresh = Some((transaction_id, 1800));
319                        self.pending_transmits.push_back(transmit.into_owned());
320                        earliest_wait = now.min(earliest_wait);
321                    }
322                    if let Some((pending, _lifetime)) = alloc.pending_refresh {
323                        if cancelled_transaction.is_some_and(|cancelled| cancelled == pending) {
324                            // TODO: need to eventually fail when the allocation times out.
325                            warn!("Refresh timed out or was cancelled");
326                            expires_at = alloc.expires_at;
327                        } else {
328                            expires_at = earliest_wait;
329                        }
330                    }
331                    for channel in alloc.channels.iter_mut() {
332                        let refresh_time = channel.expires_at - Duration::from_secs(60);
333                        if let Some(pending) = channel.pending_refresh {
334                            if cancelled_transaction.is_some_and(|cancelled| cancelled == pending) {
335                                // TODO: need to eventually fail when the permission times out.
336                                warn!("{} channel {} from {} to {} refresh timed out or was cancelled", alloc.transport, channel.id, alloc.relayed_address, channel.peer_addr);
337                                expires_at = channel.expires_at;
338                            } else if channel.expires_at <= now {
339                                info!(
340                                    "{} channel {} from {} to {} has expired",
341                                    alloc.transport,
342                                    channel.id,
343                                    alloc.relayed_address,
344                                    channel.peer_addr
345                                );
346                            } else {
347                                expires_at = expires_at.min(channel.expires_at);
348                            }
349                        } else if refresh_time <= now {
350                            info!(
351                                "refreshing {} channel {} from {} to {}",
352                                alloc.transport,
353                                channel.id,
354                                alloc.relayed_address,
355                                channel.peer_addr
356                            );
357                            let (transmit, transaction_id) = Self::send_channel_bind_request(
358                                &mut self.stun_agent,
359                                credentials.clone(),
360                                nonce,
361                                channel.id,
362                                channel.peer_addr,
363                                now,
364                            );
365                            channel.pending_refresh = Some(transaction_id);
366                            self.pending_transmits.push_back(transmit);
367                            expires_at = expires_at.min(refresh_time);
368                        } else {
369                            expires_at = expires_at.min(refresh_time);
370                        }
371                    }
372
373                    // refresh permission if necessary
374                    for permission in alloc.permissions.iter_mut() {
375                        let refresh_time = permission.expires_at - Duration::from_secs(60);
376                        if let Some(pending) = permission.pending_refresh {
377                            if cancelled_transaction.is_some_and(|cancelled| cancelled == pending) {
378                                warn!(
379                                    "permission {} from {} to {} refresh timed out or was cancelled",
380                                    alloc.transport, alloc.relayed_address, permission.ip
381                                );
382                                expires_at = permission.expires_at;
383                            } else if permission.expires_at <= now {
384                                info!(
385                                    "permission {} from {} to {} has expired",
386                                    alloc.transport, alloc.relayed_address, permission.ip
387                                );
388                                permission.expired = true;
389                                self.pending_events
390                                    .push_back(TurnEvent::PermissionCreateFailed(
391                                        alloc.transport,
392                                        permission.ip,
393                                    ));
394                            } else {
395                                expires_at = expires_at.min(permission.expires_at);
396                            }
397                        } else if refresh_time <= now {
398                            info!(
399                                "refreshing {} permission from {} to {}",
400                                alloc.transport, alloc.relayed_address, permission.ip
401                            );
402                            let (transmit, transaction_id) = Self::send_create_permission_request(
403                                &mut self.stun_agent,
404                                credentials.clone(),
405                                nonce,
406                                permission.ip,
407                                now,
408                            );
409                            permission.pending_refresh = Some(transaction_id);
410                            self.pending_transmits.push_back(transmit);
411                            expires_at = expires_at.min(refresh_time);
412                        } else {
413                            expires_at = expires_at.min(refresh_time);
414                        }
415                    }
416                    earliest_wait = expires_at.min(earliest_wait)
417                }
418                return TurnPollRet::WaitUntil(earliest_wait.max(now));
419            }
420        }
421    }
422
423    /// The list of allocated relayed addresses on the TURN server.
424    pub fn relayed_addresses(&self) -> impl Iterator<Item = (TransportType, SocketAddr)> + '_ {
425        self.allocations
426            .iter()
427            .filter(|allocation| !allocation.expired)
428            .map(|allocation| (allocation.transport, allocation.relayed_address))
429    }
430
431    /// The list of permissions available for the provided relayed address.
432    pub fn permissions(
433        &self,
434        transport: TransportType,
435        relayed: SocketAddr,
436    ) -> impl Iterator<Item = IpAddr> + '_ {
437        self.allocations
438            .iter()
439            .filter(move |allocation| {
440                !allocation.expired
441                    && allocation.transport == transport
442                    && allocation.relayed_address == relayed
443            })
444            .flat_map(|allocation| {
445                allocation
446                    .permissions
447                    .iter()
448                    .filter(|permission| !permission.expired)
449                    .map(|permission| permission.ip)
450            })
451    }
452
453    /// Poll for a packet to send.
454    #[tracing::instrument(
455        name = "turn_client_poll_transmit"
456        skip(self)
457    )]
458    pub fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Data<'static>>> {
459        if let Some(transmit) = self.pending_transmits.pop_front() {
460            return Some(transmit);
461        }
462        if let Some(transmit) = self
463            .stun_agent
464            .poll_transmit(now)
465            .map(|transmit| transmit_send(&transmit))
466        {
467            return Some(transmit);
468        }
469        match &mut self.state {
470            AuthState::Error => None,
471            AuthState::Initial => {
472                let (transmit, transaction_id) = self.send_initial_request(now);
473                self.state = AuthState::InitialSent(transaction_id);
474                Some(transmit)
475            }
476            AuthState::InitialSent(_transaction_id) => None,
477            AuthState::Authenticating {
478                credentials: _,
479                nonce: _,
480                transaction_id: _,
481            } => None,
482            AuthState::Authenticated {
483                credentials: _,
484                nonce: _,
485            } => None,
486        }
487    }
488
489    /// Poll for an event that has occurred.
490    #[tracing::instrument(name = "turn_client_poll_event", ret, skip(self))]
491    pub fn poll_event(&mut self) -> Option<TurnEvent> {
492        self.pending_events.pop_back()
493    }
494
495    fn send_initial_request(&mut self, now: Instant) -> (Transmit<Data<'static>>, TransactionId) {
496        let mut msg = Message::builder_request(ALLOCATE);
497        let lifetime = Lifetime::new(3600);
498        msg.add_attribute(&lifetime).unwrap();
499        let requested = RequestedTransport::new(RequestedTransport::UDP);
500        msg.add_attribute(&requested).unwrap();
501        let dont_fragment = DontFragment::new();
502        msg.add_attribute(&dont_fragment).unwrap();
503        let transaction_id = msg.transaction_id();
504
505        let remote_addr = self.stun_agent.remote_addr().unwrap();
506        let transmit = self.stun_agent.send_request(msg, remote_addr, now).unwrap();
507        (transmit.into_owned(), transaction_id)
508    }
509
510    fn send_authenticating_request(
511        &mut self,
512        credentials: LongTermCredentials,
513        nonce: &str,
514        now: Instant,
515    ) -> (Transmit<Data<'static>>, TransactionId) {
516        let mut builder = Message::builder_request(ALLOCATE);
517        let requested_transport = RequestedTransport::new(RequestedTransport::UDP);
518        builder.add_attribute(&requested_transport).unwrap();
519        let username = Username::new(credentials.username()).unwrap();
520        builder.add_attribute(&username).unwrap();
521        let realm = Realm::new(credentials.realm()).unwrap();
522        builder.add_attribute(&realm).unwrap();
523        let nonce = Nonce::new(nonce).unwrap();
524        builder.add_attribute(&nonce).unwrap();
525        builder
526            .add_message_integrity(
527                &stun_proto::types::message::MessageIntegrityCredentials::LongTerm(credentials),
528                stun_proto::types::message::IntegrityAlgorithm::Sha1,
529            )
530            .unwrap();
531        let transaction_id = builder.transaction_id();
532        let transmit = self
533            .stun_agent
534            .send_request(builder, self.stun_agent.remote_addr().unwrap(), now)
535            .unwrap();
536        (transmit.into_owned(), transaction_id)
537    }
538
539    fn update_permission_state(&mut self, msg: Message<'_>, now: Instant) -> bool {
540        if let Some((alloc_idx, pending_idx)) =
541            self.allocations
542                .iter()
543                .enumerate()
544                .find_map(|(idx, allocation)| {
545                    allocation
546                        .pending_permissions
547                        .iter()
548                        .position(|(_permission, transaction_id)| {
549                            transaction_id == &msg.transaction_id()
550                        })
551                        .map(|pending_idx| (idx, pending_idx))
552                })
553        {
554            let (mut permission, _transaction_id) = self.allocations[alloc_idx]
555                .pending_permissions
556                .swap_remove_back(pending_idx)
557                .unwrap();
558            info!("Succesfully created {permission:?}");
559            if msg.has_class(stun_proto::types::message::MessageClass::Error) {
560                warn!(
561                    "Received error response to create permission request for {}",
562                    permission.ip
563                );
564                permission.expired = true;
565                permission.expires_at = now;
566                permission.pending_refresh = None;
567                self.pending_events
568                    .push_back(TurnEvent::PermissionCreateFailed(
569                        self.allocations[alloc_idx].transport,
570                        permission.ip,
571                    ));
572            } else {
573                self.pending_events.push_front(TurnEvent::PermissionCreated(
574                    self.allocations[alloc_idx].transport,
575                    permission.ip,
576                ));
577                permission.expires_at = now + Duration::from_secs(300);
578                permission.expired = false;
579                self.allocations[alloc_idx].permissions.push(permission);
580            }
581            true
582        } else if let Some((alloc_idx, existing_idx)) =
583            self.allocations
584                .iter()
585                .enumerate()
586                .find_map(|(idx, allocation)| {
587                    allocation
588                        .permissions
589                        .iter()
590                        .enumerate()
591                        .find_map(|(idx, existing_permission)| {
592                            if existing_permission.pending_refresh.is_some_and(
593                                |refresh_transaction| refresh_transaction == msg.transaction_id(),
594                            ) {
595                                Some(idx)
596                            } else {
597                                None
598                            }
599                        })
600                        .map(|pending_idx| (idx, pending_idx))
601                })
602        {
603            let transport = self.allocations[alloc_idx].transport;
604            let permission = &mut self.allocations[alloc_idx].permissions[existing_idx];
605            permission.pending_refresh = None;
606            if msg.has_class(stun_proto::types::message::MessageClass::Error) {
607                warn!(
608                    "Received error response to create permission request for {}",
609                    permission.ip
610                );
611                permission.expired = true;
612                permission.expires_at = now;
613                self.pending_events
614                    .push_back(TurnEvent::PermissionCreateFailed(transport, permission.ip));
615            } else {
616                permission.expires_at = now + Duration::from_secs(300);
617            }
618            true
619        } else {
620            false
621        }
622    }
623
624    /// Provide received data to the TURN client for handling.
625    ///
626    /// The returned data outlines what to do with this data.
627    #[tracing::instrument(
628        name = "turn_client_recv",
629        skip(self, transmit),
630        fields(
631            transport = ?transmit.transport,
632            from = ?transmit.from,
633            to = ?transmit.to,
634            data_len = transmit.data.as_ref().len(),
635        )
636    )]
637    pub fn recv<T: AsRef<[u8]> + std::fmt::Debug>(
638        &mut self,
639        transmit: Transmit<T>,
640        now: Instant,
641    ) -> TurnRecvRet<T> {
642        /* is this data for our client? */
643        if transmit.to != self.stun_agent.local_addr()
644            || self.stun_agent.transport() != transmit.transport
645            || transmit.from != self.stun_agent.remote_addr().unwrap()
646        {
647            trace!(
648                "received data not directed at us ({:?}) but for {:?}!",
649                self.stun_agent.local_addr(),
650                transmit.to
651            );
652            return TurnRecvRet::Ignored(transmit);
653        }
654        let (credentials, nonce) = match &mut self.state {
655            AuthState::Error | AuthState::Initial => return TurnRecvRet::Ignored(transmit),
656            AuthState::InitialSent(transaction_id) => {
657                let msg = if self.stun_agent.transport() == TransportType::Tcp {
658                    self.tcp_buffer.extend_from_slice(transmit.data.as_ref());
659                    let Ok(hdr) = MessageHeader::from_bytes(&self.tcp_buffer) else {
660                        return TurnRecvRet::Handled;
661                    };
662                    if self.tcp_buffer.len() < MessageHeader::LENGTH + hdr.data_length() as usize {
663                        return TurnRecvRet::Handled;
664                    }
665                    let Ok(ret) = Message::from_bytes(transmit.data.as_ref()) else {
666                        return TurnRecvRet::Ignored(transmit);
667                    };
668                    ret
669                } else {
670                    let Ok(ret) = Message::from_bytes(transmit.data.as_ref()) else {
671                        return TurnRecvRet::Ignored(transmit);
672                    };
673                    ret
674                };
675                trace!("received STUN message {msg}");
676                let msg = match self.stun_agent.handle_stun(msg, transmit.from) {
677                    HandleStunReply::Drop => return TurnRecvRet::Handled,
678                    HandleStunReply::IncomingStun(_) => return TurnRecvRet::Ignored(transmit),
679                    HandleStunReply::StunResponse(msg) => msg,
680                };
681                if !msg.is_response() || &msg.transaction_id() != transaction_id {
682                    return TurnRecvRet::Ignored(transmit);
683                }
684                /* The Initial stun request should result in an unauthorized error as there were
685                 * no credentials in the initial request */
686                if !msg.has_class(stun_proto::types::message::MessageClass::Error) {
687                    self.state = AuthState::Error;
688                    self.pending_events
689                        .push_front(TurnEvent::AllocationCreateFailed);
690                    return TurnRecvRet::Ignored(transmit);
691                }
692                let Ok(error_code) = msg.attribute::<ErrorCode>() else {
693                    self.state = AuthState::Error;
694                    self.pending_events
695                        .push_front(TurnEvent::AllocationCreateFailed);
696                    return TurnRecvRet::Ignored(transmit);
697                };
698                let Ok(realm) = msg.attribute::<Realm>() else {
699                    self.state = AuthState::Error;
700                    self.pending_events
701                        .push_front(TurnEvent::AllocationCreateFailed);
702                    return TurnRecvRet::Ignored(transmit);
703                };
704                let Ok(nonce) = msg.attribute::<Nonce>() else {
705                    self.state = AuthState::Error;
706                    self.pending_events
707                        .push_front(TurnEvent::AllocationCreateFailed);
708                    return TurnRecvRet::Ignored(transmit);
709                };
710                match error_code.code() {
711                    ErrorCode::UNAUTHORIZED => {
712                        /* retry the request with the correct credentials */
713                        let credentials = self
714                            .credentials
715                            .clone()
716                            .into_long_term_credentials(realm.realm());
717                        let (transmit, transaction_id) = self.send_authenticating_request(
718                            credentials.clone(),
719                            nonce.nonce(),
720                            now,
721                        );
722                        self.stun_agent.set_remote_credentials(
723                            MessageIntegrityCredentials::LongTerm(credentials.clone()),
724                        );
725                        self.pending_transmits.push_back(transmit.into_owned());
726                        self.state = AuthState::Authenticating {
727                            credentials,
728                            nonce: nonce.nonce().to_string(),
729                            transaction_id,
730                        };
731                        return TurnRecvRet::Handled;
732                    }
733                    ErrorCode::TRY_ALTERNATE => (), // FIXME: implement
734                    code => {
735                        trace!("Unknown error code returned {code:?}");
736                        self.state = AuthState::Error;
737                        self.pending_events
738                            .push_front(TurnEvent::AllocationCreateFailed);
739                    }
740                }
741                return TurnRecvRet::Ignored(transmit);
742            }
743            AuthState::Authenticating {
744                credentials,
745                nonce,
746                transaction_id,
747            } => {
748                let msg = if self.stun_agent.transport() == TransportType::Tcp {
749                    self.tcp_buffer.extend_from_slice(transmit.data.as_ref());
750                    let Ok(hdr) = MessageHeader::from_bytes(&self.tcp_buffer) else {
751                        return TurnRecvRet::Handled;
752                    };
753                    if self.tcp_buffer.len() < MessageHeader::LENGTH + hdr.data_length() as usize {
754                        return TurnRecvRet::Handled;
755                    }
756                    let Ok(ret) = Message::from_bytes(transmit.data.as_ref()) else {
757                        return TurnRecvRet::Ignored(transmit);
758                    };
759                    ret
760                } else {
761                    let Ok(ret) = Message::from_bytes(transmit.data.as_ref()) else {
762                        return TurnRecvRet::Ignored(transmit);
763                    };
764                    ret
765                };
766                trace!("received STUN message {msg}");
767                let msg = match self.stun_agent.handle_stun(msg, transmit.from) {
768                    HandleStunReply::Drop => return TurnRecvRet::Handled,
769                    HandleStunReply::IncomingStun(_) => return TurnRecvRet::Ignored(transmit),
770                    HandleStunReply::StunResponse(msg) => msg,
771                };
772                if !msg.is_response() || &msg.transaction_id() != transaction_id {
773                    return TurnRecvRet::Ignored(transmit);
774                }
775                match msg.class() {
776                    stun_proto::types::message::MessageClass::Error => {
777                        let Ok(error_code) = msg.attribute::<ErrorCode>() else {
778                            self.state = AuthState::Error;
779                            return TurnRecvRet::Ignored(transmit);
780                        };
781                        match error_code.code() {
782                            ErrorCode::STALE_NONCE => {
783                                let Ok(realm) = msg.attribute::<Realm>() else {
784                                    self.state = AuthState::Error;
785                                    return TurnRecvRet::Ignored(transmit);
786                                };
787                                let Ok(nonce) = msg.attribute::<Nonce>() else {
788                                    self.state = AuthState::Error;
789                                    return TurnRecvRet::Ignored(transmit);
790                                };
791                                let credentials = self
792                                    .credentials
793                                    .clone()
794                                    .into_long_term_credentials(realm.realm());
795                                let (transmit, transaction_id) = self.send_authenticating_request(
796                                    credentials.clone(),
797                                    nonce.nonce(),
798                                    now,
799                                );
800                                self.pending_transmits.push_back(transmit.into_owned());
801                                self.state = AuthState::Authenticating {
802                                    credentials,
803                                    nonce: nonce.nonce().to_string(),
804                                    transaction_id,
805                                };
806                                return TurnRecvRet::Handled;
807                            }
808                            code => {
809                                warn!("Unknown error code returned while authenticating: {code:?}");
810                                self.state = AuthState::Error;
811                            }
812                        }
813                    }
814                    stun_proto::types::message::MessageClass::Success => {
815                        let Ok(_) = msg.validate_integrity(&MessageIntegrityCredentials::LongTerm(
816                            credentials.clone(),
817                        )) else {
818                            return TurnRecvRet::Ignored(transmit);
819                        };
820                        let xor_relayed_address = msg.attribute::<XorRelayedAddress>();
821                        let lifetime = msg.attribute::<Lifetime>();
822                        let (Ok(xor_relayed_address), Ok(lifetime)) =
823                            (xor_relayed_address, lifetime)
824                        else {
825                            self.state = AuthState::Error;
826                            return TurnRecvRet::Ignored(transmit);
827                        };
828                        let relayed_address = xor_relayed_address.addr(msg.transaction_id());
829                        let lifetime = Duration::from_secs(lifetime.seconds() as u64);
830                        let expires_at = now + lifetime;
831                        self.state = AuthState::Authenticated {
832                            credentials: credentials.clone(),
833                            nonce: nonce.clone(),
834                        };
835                        info!(relayed = ?relayed_address, transport = ?TransportType::Udp, "New allocation expiring in {}s", lifetime.as_secs());
836                        self.allocations.push(Allocation {
837                            relayed_address,
838                            // TODO support TCP
839                            transport: TransportType::Udp,
840                            expired: false,
841                            lifetime,
842                            expires_at,
843                            permissions: vec![],
844                            channels: vec![],
845                            pending_permissions: VecDeque::default(),
846                            pending_channels: VecDeque::default(),
847                            pending_refresh: None,
848                            expired_channels: vec![],
849                        });
850                        self.pending_events.push_front(TurnEvent::AllocationCreated(
851                            TransportType::Udp,
852                            relayed_address,
853                        ));
854                        return TurnRecvRet::Handled;
855                    }
856                    _ => (),
857                }
858                return TurnRecvRet::Ignored(transmit);
859            }
860            AuthState::Authenticated { credentials, nonce } => (credentials.clone(), nonce),
861        };
862
863        if self.stun_agent.transport() == TransportType::Tcp {
864            // TODO: handle multiple messages/channeldata in a single transmit
865            self.tcp_buffer.extend_from_slice(transmit.data.as_ref());
866            let Ok(hdr) = MessageHeader::from_bytes(&self.tcp_buffer) else {
867                let Ok(channel) = ChannelData::parse(&self.tcp_buffer) else {
868                    return TurnRecvRet::Ignored(transmit);
869                };
870                let data = channel.data();
871                for alloc in self.allocations.iter_mut() {
872                    if let Some(chan) = alloc
873                        .channels
874                        .iter_mut()
875                        .find(|chan| chan.id == channel.id())
876                    {
877                        let data = data.to_vec();
878                        self.tcp_buffer = self.tcp_buffer.split_at(data.len() + 2).1.to_vec();
879                        return TurnRecvRet::PeerData {
880                            data,
881                            transport: alloc.transport,
882                            peer: chan.peer_addr,
883                        };
884                    }
885                }
886                self.tcp_buffer = self.tcp_buffer.split_at(data.len() + 2).1.to_vec();
887                return TurnRecvRet::Handled;
888            };
889            if self.tcp_buffer.len() < MessageHeader::LENGTH + hdr.data_length() as usize {
890                return TurnRecvRet::Handled;
891            }
892            let Ok(msg) = Message::from_bytes(transmit.data.as_ref()) else {
893                return TurnRecvRet::Ignored(transmit);
894            };
895
896            // FIXME: dual allocations
897            let transport = self
898                .allocations
899                .iter()
900                .map(|allocation| allocation.transport)
901                .next()
902                .unwrap();
903
904            match self.handle_stun(msg, transport, transmit.from, credentials, now) {
905                InternalHandleStunReply::Handled => TurnRecvRet::Handled,
906                InternalHandleStunReply::Ignored => TurnRecvRet::Ignored(transmit),
907                InternalHandleStunReply::PeerData {
908                    data,
909                    transport,
910                    peer,
911                } => TurnRecvRet::PeerData {
912                    data,
913                    transport,
914                    peer,
915                },
916            }
917        } else {
918            let Ok(msg) = Message::from_bytes(transmit.data.as_ref()) else {
919                let Ok(channel) = ChannelData::parse(transmit.data.as_ref()) else {
920                    return TurnRecvRet::Ignored(transmit);
921                };
922                for alloc in self.allocations.iter_mut() {
923                    if let Some(chan) = alloc
924                        .channels
925                        .iter_mut()
926                        .find(|chan| chan.id == channel.id())
927                    {
928                        return TurnRecvRet::PeerData {
929                            data: channel.data().to_vec(),
930                            transport: alloc.transport,
931                            peer: chan.peer_addr,
932                        };
933                    }
934                }
935                return TurnRecvRet::Ignored(transmit);
936            };
937
938            // FIXME: TCP allocations
939            let transport = self
940                .allocations
941                .iter()
942                .map(|allocation| allocation.transport)
943                .next()
944                .unwrap();
945
946            match self.handle_stun(msg, transport, transmit.from, credentials, now) {
947                InternalHandleStunReply::Handled => TurnRecvRet::Handled,
948                InternalHandleStunReply::Ignored => TurnRecvRet::Ignored(transmit),
949                InternalHandleStunReply::PeerData {
950                    data,
951                    transport,
952                    peer,
953                } => TurnRecvRet::PeerData {
954                    data,
955                    transport,
956                    peer,
957                },
958            }
959        }
960    }
961
962    fn handle_stun(
963        &mut self,
964        msg: Message<'_>,
965        transport: TransportType,
966        from: SocketAddr,
967        credentials: LongTermCredentials,
968        now: Instant,
969    ) -> InternalHandleStunReply {
970        trace!("received STUN message {msg}");
971        let msg = match self.stun_agent.handle_stun(msg, from) {
972            HandleStunReply::Drop => return InternalHandleStunReply::Ignored,
973            HandleStunReply::IncomingStun(msg) => msg,
974            HandleStunReply::StunResponse(msg) => msg,
975        };
976        if msg.is_response() {
977            let Ok(_) = msg.validate_integrity(&MessageIntegrityCredentials::LongTerm(credentials))
978            else {
979                trace!("incoming message failed integrity check");
980                return InternalHandleStunReply::Ignored;
981            };
982
983            match msg.method() {
984                REFRESH => {
985                    let is_success = if msg.has_class(MessageClass::Error) {
986                        msg.attribute::<ErrorCode>()
987                            .is_ok_and(|err| err.code() == ErrorCode::ALLOCATION_MISMATCH)
988                    } else {
989                        msg.has_class(MessageClass::Success)
990                    };
991                    let mut remove_allocations = false;
992                    let mut handled = false;
993                    if is_success {
994                        for alloc in self.allocations.iter_mut() {
995                            let Ok(lifetime) = msg.attribute::<Lifetime>() else {
996                                continue;
997                            };
998                            let (_transaction_id, requested_lifetime) = if alloc
999                                .pending_refresh
1000                                .is_some_and(|(transaction_id, _requested_lifetime)| {
1001                                    transaction_id == msg.transaction_id() && is_success
1002                                }) {
1003                                alloc.pending_refresh.take().unwrap()
1004                            } else {
1005                                continue;
1006                            };
1007                            if requested_lifetime == 0 {
1008                                remove_allocations = true;
1009                            } else {
1010                                alloc.expires_at =
1011                                    now + Duration::from_secs(lifetime.seconds() as u64);
1012                            }
1013                            handled = true;
1014                        }
1015                    }
1016
1017                    if remove_allocations {
1018                        self.allocations.clear();
1019                        self.state = AuthState::Error;
1020                    }
1021                    if handled {
1022                        if remove_allocations {
1023                            info!("Successfully deleted allocation");
1024                        } else {
1025                            info!("Successfully refreshed allocation");
1026                        }
1027                        InternalHandleStunReply::Handled
1028                    } else {
1029                        InternalHandleStunReply::Ignored
1030                    }
1031                }
1032                CREATE_PERMISSION => {
1033                    if self.update_permission_state(msg, now) {
1034                        InternalHandleStunReply::Handled
1035                    } else {
1036                        InternalHandleStunReply::Ignored
1037                    }
1038                }
1039                CHANNEL_BIND => {
1040                    if let Some((alloc_idx, channel_idx)) = self
1041                        .allocations
1042                        .iter()
1043                        .enumerate()
1044                        .find_map(|(idx, allocation)| {
1045                            allocation
1046                                .pending_channels
1047                                .iter()
1048                                .position(|(_channel, transaction_id)| {
1049                                    transaction_id == &msg.transaction_id()
1050                                })
1051                                .map(|perm_idx| (idx, perm_idx))
1052                        })
1053                    {
1054                        let (mut channel, _transaction_id) = self.allocations[alloc_idx]
1055                            .pending_channels
1056                            .swap_remove_back(channel_idx)
1057                            .unwrap();
1058                        if msg.has_class(stun_proto::types::message::MessageClass::Error) {
1059                            error!("Received error response to channel bind request");
1060                            // TODO: handle
1061                            return InternalHandleStunReply::Handled;
1062                        }
1063                        info!("Succesfully created/refreshed {channel:?}");
1064                        self.update_permission_state(msg, now);
1065                        if let Some(existing_idx) = self.allocations[alloc_idx]
1066                            .channels
1067                            .iter()
1068                            .enumerate()
1069                            .find_map(|(idx, existing_channel)| {
1070                                if channel.peer_addr == existing_channel.peer_addr {
1071                                    Some(idx)
1072                                } else {
1073                                    None
1074                                }
1075                            })
1076                        {
1077                            self.allocations[alloc_idx].channels[existing_idx].expires_at =
1078                                now + Duration::from_secs(600);
1079                        } else {
1080                            channel.expires_at = now + Duration::from_secs(600);
1081                            self.allocations[alloc_idx].channels.push(channel);
1082                        }
1083                        return InternalHandleStunReply::Handled;
1084                    }
1085                    InternalHandleStunReply::Ignored
1086                }
1087                _ => InternalHandleStunReply::Ignored, // Other responses are not expected
1088            }
1089        } else if msg.has_class(stun_proto::types::message::MessageClass::Request) {
1090            let Ok(_) = msg.validate_integrity(&MessageIntegrityCredentials::LongTerm(credentials))
1091            else {
1092                trace!("incoming message failed integrity check");
1093                return InternalHandleStunReply::Ignored;
1094            };
1095
1096            // TODO: reply with an error?
1097            InternalHandleStunReply::Ignored
1098        } else {
1099            /* The message is an indication */
1100            match msg.method() {
1101                DATA => {
1102                    let Ok(peer_addr) = msg.attribute::<XorPeerAddress>() else {
1103                        return InternalHandleStunReply::Ignored;
1104                    };
1105                    let Ok(data) = msg.attribute::<AData>() else {
1106                        return InternalHandleStunReply::Ignored;
1107                    };
1108                    InternalHandleStunReply::PeerData {
1109                        data: data.data().to_vec(),
1110                        transport,
1111                        peer: peer_addr.addr(msg.transaction_id()),
1112                    }
1113                }
1114                _ => InternalHandleStunReply::Ignored, // All other indications should be ignored
1115            }
1116        }
1117    }
1118
1119    /// Remove the allocation/s on the server.
1120    pub fn delete(&mut self, now: Instant) -> Option<Transmit<Data<'static>>> {
1121        let mut builder = Message::builder_request(REFRESH);
1122        let transaction_id = builder.transaction_id();
1123
1124        let AuthState::Authenticated { credentials, nonce } = &self.state else {
1125            return None;
1126        };
1127
1128        let lifetime = Lifetime::new(0);
1129        builder.add_attribute(&lifetime).unwrap();
1130        let username = Username::new(credentials.username()).unwrap();
1131        builder.add_attribute(&username).unwrap();
1132        let realm = Realm::new(credentials.realm()).unwrap();
1133        builder.add_attribute(&realm).unwrap();
1134        let nonce = Nonce::new(nonce).unwrap();
1135        builder.add_attribute(&nonce).unwrap();
1136        builder
1137            .add_message_integrity(
1138                &stun_proto::types::message::MessageIntegrityCredentials::LongTerm(
1139                    credentials.clone(),
1140                ),
1141                stun_proto::types::message::IntegrityAlgorithm::Sha1,
1142            )
1143            .unwrap();
1144
1145        let transmit = self
1146            .stun_agent
1147            .send_request(builder, self.stun_agent.remote_addr().unwrap(), now)
1148            .unwrap();
1149        info!("Deleting allocations");
1150        for alloc in self.allocations.iter_mut() {
1151            alloc.permissions.clear();
1152            alloc.channels.clear();
1153            alloc.expires_at = now;
1154            alloc.expired = true;
1155            alloc.pending_refresh = Some((transaction_id, 0));
1156        }
1157        Some(transmit.into_owned())
1158    }
1159
1160    fn send_create_permission_request(
1161        stun_agent: &mut StunAgent,
1162        credentials: LongTermCredentials,
1163        nonce: &str,
1164        peer_addr: IpAddr,
1165        now: Instant,
1166    ) -> (Transmit<Data<'static>>, TransactionId) {
1167        let mut builder = Message::builder_request(CREATE_PERMISSION);
1168        let transaction_id = builder.transaction_id();
1169
1170        let xor_peer_address = XorPeerAddress::new(SocketAddr::new(peer_addr, 0), transaction_id);
1171        builder.add_attribute(&xor_peer_address).unwrap();
1172        let username = Username::new(credentials.username()).unwrap();
1173        builder.add_attribute(&username).unwrap();
1174        let realm = Realm::new(credentials.realm()).unwrap();
1175        builder.add_attribute(&realm).unwrap();
1176        let nonce = Nonce::new(nonce).unwrap();
1177        builder.add_attribute(&nonce).unwrap();
1178        builder
1179            .add_message_integrity(
1180                &stun_proto::types::message::MessageIntegrityCredentials::LongTerm(
1181                    credentials.clone(),
1182                ),
1183                stun_proto::types::message::IntegrityAlgorithm::Sha1,
1184            )
1185            .unwrap();
1186        let transmit = stun_agent
1187            .send_request(builder, stun_agent.remote_addr().unwrap(), now)
1188            .unwrap();
1189        (transmit.into_owned(), transaction_id)
1190    }
1191
1192    /// Create a permission address to allow sending/receiving data to/from.
1193    #[tracing::instrument(name = "turn_client_create_permission", skip(self, now), err)]
1194    pub fn create_permission(
1195        &mut self,
1196        transport: TransportType,
1197        peer_addr: IpAddr,
1198        now: Instant,
1199    ) -> Result<Transmit<Data<'static>>, CreatePermissionError> {
1200        let Some(allocation) = self.allocations.iter_mut().find(|allocation| {
1201            allocation.transport == transport
1202                && allocation.relayed_address.is_ipv4() == peer_addr.is_ipv4()
1203        }) else {
1204            warn!("No allocation available to create this permission");
1205            return Err(CreatePermissionError::NoAllocation);
1206        };
1207
1208        if now >= allocation.expires_at {
1209            allocation.expired = true;
1210            warn!("Allocation has expired");
1211            return Err(CreatePermissionError::NoAllocation);
1212        }
1213
1214        if allocation
1215            .permissions
1216            .iter()
1217            .any(|permission| permission.ip == peer_addr)
1218        {
1219            return Err(CreatePermissionError::AlreadyExists);
1220        }
1221        if allocation
1222            .pending_permissions
1223            .iter()
1224            .any(|(permission, _transaction_id)| permission.ip == peer_addr)
1225        {
1226            return Err(CreatePermissionError::AlreadyExists);
1227        }
1228        let AuthState::Authenticated { credentials, nonce } = &self.state else {
1229            warn!("Not authenticated yet: {:?}", self.state);
1230            return Err(CreatePermissionError::NoAllocation);
1231        };
1232        let permission = Permission {
1233            expired: false,
1234            expires_at: now,
1235            ip: peer_addr,
1236            pending_refresh: None,
1237        };
1238
1239        let (transmit, transaction_id) = Self::send_create_permission_request(
1240            &mut self.stun_agent,
1241            credentials.clone(),
1242            nonce,
1243            peer_addr,
1244            now,
1245        );
1246        info!("Creating {permission:?}");
1247        allocation
1248            .pending_permissions
1249            .push_back((permission, transaction_id));
1250        Ok(transmit)
1251    }
1252
1253    fn send_channel_bind_request(
1254        stun_agent: &mut StunAgent,
1255        credentials: LongTermCredentials,
1256        nonce: &str,
1257        id: u16,
1258        peer_addr: SocketAddr,
1259        now: Instant,
1260    ) -> (Transmit<Data<'static>>, TransactionId) {
1261        let mut builder = Message::builder_request(CHANNEL_BIND);
1262        let transaction_id = builder.transaction_id();
1263        let channel_no = ChannelNumber::new(id);
1264        builder.add_attribute(&channel_no).unwrap();
1265        let xor_peer_address = XorPeerAddress::new(peer_addr, transaction_id);
1266        builder.add_attribute(&xor_peer_address).unwrap();
1267        let username = Username::new(credentials.username()).unwrap();
1268        builder.add_attribute(&username).unwrap();
1269        let realm = Realm::new(credentials.realm()).unwrap();
1270        builder.add_attribute(&realm).unwrap();
1271        let nonce = Nonce::new(nonce).unwrap();
1272        builder.add_attribute(&nonce).unwrap();
1273        builder
1274            .add_message_integrity(
1275                &stun_proto::types::message::MessageIntegrityCredentials::LongTerm(
1276                    credentials.clone(),
1277                ),
1278                stun_proto::types::message::IntegrityAlgorithm::Sha1,
1279            )
1280            .unwrap();
1281
1282        let transmit = stun_agent
1283            .send_request(builder, stun_agent.remote_addr().unwrap(), now)
1284            .unwrap();
1285        (transmit.into_owned(), transaction_id)
1286    }
1287
1288    /// Bind a channel for sending/receiving data to/from a particular peer.
1289    pub fn bind_channel(
1290        &mut self,
1291        transport: TransportType,
1292        peer_addr: SocketAddr,
1293        now: Instant,
1294    ) -> Result<Transmit<Data<'static>>, BindChannelError> {
1295        let Some(allocation) = self.allocations.iter_mut().find(|allocation| {
1296            allocation.transport == transport
1297                && allocation.relayed_address.is_ipv4() == peer_addr.is_ipv4()
1298        }) else {
1299            warn!("No allocation available to create this permission");
1300            return Err(BindChannelError::NoAllocation);
1301        };
1302
1303        if now >= allocation.expires_at {
1304            allocation.expired = true;
1305            return Err(BindChannelError::NoAllocation);
1306        }
1307
1308        if allocation
1309            .channels
1310            .iter()
1311            .any(|channel| channel.peer_addr == peer_addr)
1312        {
1313            return Err(BindChannelError::AlreadyExists);
1314        }
1315
1316        let AuthState::Authenticated { credentials, nonce } = &self.state else {
1317            return Err(BindChannelError::NoAllocation);
1318        };
1319
1320        let mut channel_id = 0x4000;
1321        for channel in 0x4000..=0x7FFF {
1322            channel_id = channel;
1323            if allocation
1324                .channels
1325                .iter()
1326                .chain(
1327                    allocation
1328                        .pending_channels
1329                        .iter()
1330                        .map(|(channel, _transaction_id)| channel),
1331                )
1332                .chain(allocation.expired_channels.iter())
1333                .any(|channel| {
1334                    channel.expires_at + Duration::from_secs(300) <= now && channel.id == channel_id
1335                })
1336            {
1337                continue;
1338            }
1339            break;
1340        }
1341
1342        let (transmit, transaction_id) = Self::send_channel_bind_request(
1343            &mut self.stun_agent,
1344            credentials.clone(),
1345            nonce,
1346            channel_id,
1347            peer_addr,
1348            now,
1349        );
1350
1351        // FIXME: update any existing permission
1352        let permission = Permission {
1353            expired: false,
1354            expires_at: now,
1355            ip: peer_addr.ip(),
1356            pending_refresh: None,
1357        };
1358        allocation
1359            .pending_permissions
1360            .push_back((permission, transaction_id));
1361        let channel = Channel {
1362            id: channel_id,
1363            expires_at: now,
1364            peer_addr,
1365            pending_refresh: None,
1366        };
1367        info!("Creating channel {channel:?}");
1368        allocation
1369            .pending_channels
1370            .push_back((channel, transaction_id));
1371        Ok(transmit.into_owned())
1372    }
1373
1374    fn have_permission(&self, transport: TransportType, to: SocketAddr, now: Instant) -> bool {
1375        self.allocations.iter().any(|allocation| {
1376            allocation.transport == transport
1377                && allocation.expires_at >= now
1378                && allocation
1379                    .permissions
1380                    .iter()
1381                    .any(|permission| permission.expires_at >= now && permission.ip == to.ip())
1382        })
1383    }
1384
1385    /// Send data to a peer through the TURN server.
1386    ///
1387    /// The provided transport, address and data are the data to send to the peer.
1388    ///
1389    /// The returned value will instruct the caller to send a message to the turn server.
1390    pub fn send_to<T: AsRef<[u8]> + std::fmt::Debug>(
1391        &mut self,
1392        transport: TransportType,
1393        to: SocketAddr,
1394        data: T,
1395        now: Instant,
1396    ) -> Result<TransmitBuild<DelayedMessageOrChannelSend<T>>, StunError> {
1397        if !self.have_permission(transport, to, now) {
1398            return Err(StunError::ResourceNotFound);
1399        }
1400
1401        if let Some(channel) = self.channel(transport, to) {
1402            if channel.expires_at >= now {
1403                return Ok(TransmitBuild::new(
1404                    DelayedMessageOrChannelSend::Channel(DelayedChannelSend {
1405                        data,
1406                        channel_id: channel.id,
1407                    }),
1408                    self.stun_agent.transport(),
1409                    self.stun_agent.local_addr(),
1410                    self.stun_agent.remote_addr().unwrap(),
1411                ));
1412            }
1413        }
1414        Ok(TransmitBuild::new(
1415            DelayedMessageOrChannelSend::Message(DelayedMessageSend {
1416                data,
1417                peer_addr: to,
1418            }),
1419            self.stun_agent.transport(),
1420            self.stun_agent.local_addr(),
1421            self.stun_agent.remote_addr().unwrap(),
1422        ))
1423    }
1424
1425    #[cfg(test)]
1426    fn permission(&self, transport: TransportType, ip: IpAddr) -> Option<&Permission> {
1427        self.allocations
1428            .iter()
1429            .filter(|allocation| allocation.transport == transport)
1430            .find_map(|allocation| {
1431                allocation
1432                    .permissions
1433                    .iter()
1434                    .find(|permission| permission.ip == ip)
1435            })
1436    }
1437
1438    fn channel(&self, transport: TransportType, addr: SocketAddr) -> Option<&Channel> {
1439        self.allocations
1440            .iter()
1441            .filter(|allocation| allocation.transport == transport)
1442            .find_map(|allocation| {
1443                allocation
1444                    .channels
1445                    .iter()
1446                    .find(|channel| channel.peer_addr == addr)
1447            })
1448    }
1449}
1450
/// A `Transmit` whose payload is the `range` sub-slice of the provided `data` region.
///
/// The selected bytes are only copied out when the transmit is built or written into a
/// caller-provided buffer.
#[derive(Debug)]
pub struct DelayedTransmit<T: AsRef<[u8]> + std::fmt::Debug> {
    /// Backing storage; only the bytes selected by `range` are transmitted.
    data: T,
    /// Byte range within `data.as_ref()` that forms the payload.
    range: Range<usize>,
}
1457
1458impl<T: AsRef<[u8]> + std::fmt::Debug> DelayedTransmit<T> {
1459    fn data(&self) -> &[u8] {
1460        &self.data.as_ref()[self.range.clone()]
1461    }
1462}
1463
1464impl<T: AsRef<[u8]> + std::fmt::Debug> DelayedTransmitBuild for DelayedTransmit<T> {
1465    fn len(&self) -> usize {
1466        self.range.len()
1467    }
1468
1469    fn build(self) -> Vec<u8> {
1470        self.data().to_vec()
1471    }
1472
1473    fn write_into(self, data: &mut [u8]) -> usize {
1474        data.copy_from_slice(self.data());
1475        self.len()
1476    }
1477}
1478
/// A delayed `Transmit` that will construct a TURN SEND indication, destined for the TURN
/// server, asking it to relay `data` to `peer_addr`.
#[derive(Debug)]
pub struct DelayedMessageSend<T: AsRef<[u8]> + std::fmt::Debug> {
    /// The payload placed in the DATA attribute.
    data: T,
    /// The peer the server should relay the payload to (encoded as XOR-PEER-ADDRESS).
    peer_addr: SocketAddr,
}
1485
1486impl<T: AsRef<[u8]> + std::fmt::Debug> DelayedTransmitBuild for DelayedMessageSend<T> {
1487    fn len(&self) -> usize {
1488        let xor_peer_addr = XorPeerAddress::new(self.peer_addr, 0.into());
1489        let data = AData::new(self.data.as_ref());
1490        MessageHeader::LENGTH + xor_peer_addr.padded_len() + data.padded_len()
1491    }
1492
1493    fn build(self) -> Vec<u8> {
1494        let transaction_id = TransactionId::generate();
1495        let mut msg = Message::builder(
1496            MessageType::from_class_method(
1497                stun_proto::types::message::MessageClass::Indication,
1498                SEND,
1499            ),
1500            transaction_id,
1501        );
1502        let xor_peer_address = XorPeerAddress::new(self.peer_addr, transaction_id);
1503        msg.add_attribute(&xor_peer_address).unwrap();
1504        let data = AData::new(self.data.as_ref());
1505        msg.add_attribute(&data).unwrap();
1506        msg.build()
1507    }
1508
1509    fn write_into(self, dest: &mut [u8]) -> usize {
1510        let transaction_id = TransactionId::generate();
1511        let mut msg = Message::builder(
1512            MessageType::from_class_method(
1513                stun_proto::types::message::MessageClass::Indication,
1514                SEND,
1515            ),
1516            transaction_id,
1517        );
1518        let xor_peer_address = XorPeerAddress::new(self.peer_addr, transaction_id);
1519        msg.add_attribute(&xor_peer_address).unwrap();
1520        let data = AData::new(self.data.as_ref());
1521        msg.add_attribute(&data).unwrap();
1522        msg.write_into(dest)
1523    }
1524}
1525
/// A delayed `Transmit` that will construct a TURN ChannelData message, destined for the TURN
/// server, carrying `data` on channel `channel_id`.
#[derive(Debug)]
pub struct DelayedChannelSend<T: AsRef<[u8]> + std::fmt::Debug> {
    /// The payload that follows the 4-byte ChannelData header.
    data: T,
    /// The channel number written at the start of the ChannelData header.
    channel_id: u16,
}
1532
1533impl<T: AsRef<[u8]> + std::fmt::Debug> DelayedTransmitBuild for DelayedChannelSend<T> {
1534    fn len(&self) -> usize {
1535        self.data.as_ref().len() + 4
1536    }
1537
1538    fn build(self) -> Vec<u8> {
1539        let mut data = vec![0; self.data.as_ref().len() + 4];
1540        self.write_into(&mut data);
1541        data
1542    }
1543
1544    fn write_into(self, dest: &mut [u8]) -> usize {
1545        let data_len = self.data.as_ref().len();
1546        BigEndian::write_u16(&mut dest[..2], self.channel_id);
1547        BigEndian::write_u16(&mut dest[2..4], data_len as u16);
1548        dest[4..].copy_from_slice(self.data.as_ref());
1549        data_len + 4
1550    }
1551}
1552
/// A delayed `Transmit` produced by a TURN client for sending peer data to its TURN server,
/// either as a ChannelData message or as a SEND indication.
#[derive(Debug)]
pub enum DelayedMessageOrChannelSend<T: AsRef<[u8]> + std::fmt::Debug> {
    /// A [`DelayedChannelSend`] (ChannelData message over a bound channel).
    Channel(DelayedChannelSend<T>),
    /// A [`DelayedMessageSend`] (SEND indication).
    Message(DelayedMessageSend<T>),
}
1561
1562impl<T: AsRef<[u8]> + std::fmt::Debug> DelayedTransmitBuild for DelayedMessageOrChannelSend<T> {
1563    fn len(&self) -> usize {
1564        match self {
1565            Self::Channel(channel) => channel.len(),
1566            Self::Message(msg) => msg.len(),
1567        }
1568    }
1569
1570    fn build(self) -> Vec<u8> {
1571        match self {
1572            Self::Channel(channel) => channel.build(),
1573            Self::Message(msg) => msg.build(),
1574        }
1575    }
1576
1577    fn write_into(self, data: &mut [u8]) -> usize {
1578        match self {
1579            Self::Channel(channel) => channel.write_into(data),
1580            Self::Message(msg) => msg.write_into(data),
1581        }
1582    }
1583}
1584
1585fn transmit_send<T: AsRef<[u8]> + std::fmt::Debug>(
1586    transmit: &Transmit<T>,
1587) -> Transmit<Data<'static>> {
1588    Transmit::new(
1589        Data::from(transmit.data.as_ref()),
1590        transmit.transport,
1591        transmit.from,
1592        transmit.to,
1593    )
1594    .into_owned()
1595}
1596
1597#[cfg(test)]
1598mod tests {
1599    use stun_proto::types::{
1600        attribute::{MessageIntegrity, MessageIntegritySha256, XorMappedAddress},
1601        prelude::AttributeStaticType,
1602    };
1603
1604    use super::*;
1605    use turn_server_proto::{TurnServer, TurnServerPollRet};
1606
1607    #[test]
1608    fn test_turn_client_new_properties() {
1609        let _log = crate::tests::test_init_log();
1610
1611        let local_addr = "192.168.0.1:31234".parse().unwrap();
1612        let remote_addr = "10.0.0.1:3478".parse().unwrap();
1613        let credentials = TurnCredentials::new("tuser", "tpass");
1614
1615        let mut client =
1616            TurnClient::allocate(TransportType::Udp, local_addr, remote_addr, credentials);
1617        assert_eq!(client.transport(), TransportType::Udp);
1618        assert_eq!(client.local_addr(), local_addr);
1619        assert_eq!(client.remote_addr(), remote_addr);
1620
1621        let now = Instant::now();
1622        let TurnPollRet::WaitUntil(new_now) = client.poll(now) else {
1623            unreachable!();
1624        };
1625        assert_eq!(now, new_now);
1626        assert!(client.poll_event().is_none());
1627
1628        assert_eq!(client.relayed_addresses().count(), 0);
1629    }
1630
1631    fn transmit_send_build<T: DelayedTransmitBuild>(
1632        transmit: TransmitBuild<T>,
1633    ) -> Transmit<Data<'static>> {
1634        let data = transmit.data.build().into_boxed_slice();
1635        Transmit::new(
1636            Data::from(data),
1637            transmit.transport,
1638            transmit.from,
1639            transmit.to,
1640        )
1641        .into_owned()
1642    }
1643
1644    #[test]
1645    fn test_delayed_message() {
1646        let data = [5; 5];
1647        let peer_addr = "127.0.0.1:1".parse().unwrap();
1648        let transmit = DelayedMessageOrChannelSend::Message(DelayedMessageSend { data, peer_addr });
1649        let len = transmit.len();
1650        let out = transmit.build();
1651        assert_eq!(len, out.len());
1652        let msg = Message::from_bytes(&out).unwrap();
1653        let addr = msg.attribute::<XorPeerAddress>().unwrap();
1654        assert_eq!(addr.addr(msg.transaction_id()), peer_addr);
1655        let out_data = msg.attribute::<AData>().unwrap();
1656        assert_eq!(out_data.data(), data.as_ref());
1657        let transmit = DelayedMessageOrChannelSend::Message(DelayedMessageSend { data, peer_addr });
1658        let mut out2 = vec![0; len];
1659        transmit.write_into(&mut out2);
1660        let msg = Message::from_bytes(&out2).unwrap();
1661        let addr = msg.attribute::<XorPeerAddress>().unwrap();
1662        assert_eq!(addr.addr(msg.transaction_id()), peer_addr);
1663        let out_data = msg.attribute::<AData>().unwrap();
1664        assert_eq!(out_data.data(), data.as_ref());
1665    }
1666
1667    #[test]
1668    fn test_delayed_channel() {
1669        let data = [5; 5];
1670        let channel_id = 0x4567;
1671        let transmit =
1672            DelayedMessageOrChannelSend::Channel(DelayedChannelSend { data, channel_id });
1673        let len = transmit.len();
1674        let out = transmit.build();
1675        assert_eq!(len, out.len());
1676        let channel = ChannelData::parse(&out).unwrap();
1677        assert_eq!(channel.id(), channel_id);
1678        assert_eq!(channel.data(), data.as_ref());
1679        let transmit =
1680            DelayedMessageOrChannelSend::Channel(DelayedChannelSend { data, channel_id });
1681        let mut out2 = vec![0; len];
1682        transmit.write_into(&mut out2);
1683        assert_eq!(len, out.len());
1684        let channel = ChannelData::parse(&out).unwrap();
1685        assert_eq!(channel.id(), channel_id);
1686        assert_eq!(channel.data(), data.as_ref());
1687    }
1688
    /// Configuration for a client/server [`TurnTest`] pair.
    struct TurnTestBuilder {
        /// Address the TURN server listens on; also the client's remote address.
        turn_listen_addr: SocketAddr,
        /// Credentials registered on the server and used by the client.
        credentials: TurnCredentials,
        /// Realm the server is created with.
        realm: String,
        /// The client's local address.
        client_addr: SocketAddr,
        /// Transport between the client and the TURN server.
        client_transport: TransportType,
        /// Address handed to the server as the result of its UDP socket allocation.
        turn_alloc_addr: SocketAddr,
        /// Address of the remote peer used for permissions, channels and data.
        peer_addr: SocketAddr,
    }
1698    impl TurnTestBuilder {
1699        fn build(self) -> TurnTest {
1700            let mut server =
1701                TurnServer::new(self.client_transport, self.turn_listen_addr, self.realm);
1702            server.add_user(
1703                self.credentials.username().to_owned(),
1704                self.credentials.password().to_owned(),
1705            );
1706            let client = TurnClient::allocate(
1707                self.client_transport,
1708                self.client_addr,
1709                self.turn_listen_addr,
1710                self.credentials,
1711            );
1712            TurnTest {
1713                client,
1714                server,
1715                turn_alloc_addr: self.turn_alloc_addr,
1716                peer_addr: self.peer_addr,
1717            }
1718        }
1719
1720        fn client_transport(mut self, transport: TransportType) -> Self {
1721            self.client_transport = transport;
1722            self
1723        }
1724    }
1725
    /// A connected TURN client and TURN server driven in lock-step by the tests.
    struct TurnTest {
        /// The client under test.
        client: TurnClient,
        /// The server the client talks to.
        server: TurnServer,
        /// Address the server was told it allocated for the relay.
        turn_alloc_addr: SocketAddr,
        /// Address of the remote peer.
        peer_addr: SocketAddr,
    }
1732
1733    impl TurnTest {
1734        fn builder() -> TurnTestBuilder {
1735            let credentials = TurnCredentials::new("turnuser", "turnpass");
1736            TurnTestBuilder {
1737                turn_listen_addr: "127.0.0.1:3478".parse().unwrap(),
1738                credentials,
1739                realm: String::from("realm"),
1740                client_addr: "127.0.0.1:2000".parse().unwrap(),
1741                client_transport: TransportType::Udp,
1742                turn_alloc_addr: "10.0.0.20:2000".parse().unwrap(),
1743                peer_addr: "10.0.0.3:3000".parse().unwrap(),
1744            }
1745        }
1746
1747        fn allocate(&mut self, now: Instant) {
1748            // initial allocate
1749            let transmit = self.client.poll_transmit(now).unwrap();
1750            let msg = Message::from_bytes(&transmit.data).unwrap();
1751            assert!(msg.has_method(ALLOCATE));
1752            assert!(msg.has_class(stun_proto::types::message::MessageClass::Request));
1753            assert!(msg.has_attribute(RequestedTransport::TYPE));
1754            assert!(!msg.has_attribute(Realm::TYPE));
1755            assert!(!msg.has_attribute(Nonce::TYPE));
1756            assert!(!msg.has_attribute(Username::TYPE));
1757            assert!(!msg.has_attribute(MessageIntegrity::TYPE));
1758            assert!(!msg.has_attribute(MessageIntegritySha256::TYPE));
1759            // error reply
1760            let Ok(Some(transmit)) = self.server.recv(transmit, now) else {
1761                unreachable!();
1762            };
1763            let msg = Message::from_bytes(&transmit.data).unwrap();
1764            assert!(msg.has_method(ALLOCATE));
1765            assert!(msg.has_class(stun_proto::types::message::MessageClass::Error));
1766            assert!(msg.has_attribute(Realm::TYPE));
1767            let err = msg.attribute::<ErrorCode>().unwrap();
1768            assert_eq!(err.code(), ErrorCode::UNAUTHORIZED);
1769            assert!(msg.has_attribute(Nonce::TYPE));
1770            self.client.recv(transmit, now);
1771
1772            // authenticated allocate
1773            let transmit = self.client.poll_transmit(now).unwrap();
1774            let msg = Message::from_bytes(&transmit.data).unwrap();
1775            assert!(msg.has_method(ALLOCATE));
1776            assert!(msg.has_class(stun_proto::types::message::MessageClass::Request));
1777            assert!(msg.has_attribute(RequestedTransport::TYPE));
1778            assert!(msg.has_attribute(Realm::TYPE));
1779            assert!(msg.has_attribute(Nonce::TYPE));
1780            assert!(msg.has_attribute(Username::TYPE));
1781            assert!(msg.has_attribute(MessageIntegrity::TYPE));
1782            let Ok(None) = self.server.recv(transmit, now) else {
1783                unreachable!();
1784            };
1785            let TurnServerPollRet::AllocateSocketUdp {
1786                transport,
1787                local_addr: alloc_local_addr,
1788                remote_addr: alloc_remote_addr,
1789            } = self.server.poll(now)
1790            else {
1791                unreachable!();
1792            };
1793            assert_eq!(transport, self.client.transport());
1794            assert_eq!(alloc_local_addr, self.server.listen_address());
1795            assert_eq!(alloc_remote_addr, self.client.local_addr());
1796            self.server.allocated_udp_socket(
1797                transport,
1798                alloc_local_addr,
1799                alloc_remote_addr,
1800                Ok(self.turn_alloc_addr),
1801                now,
1802            );
1803            // ok reply
1804            let Some(transmit) = self.server.poll_transmit(now) else {
1805                unreachable!();
1806            };
1807            let msg = Message::from_bytes(&transmit.data).unwrap();
1808            assert!(msg.has_method(ALLOCATE));
1809            assert!(msg.has_class(stun_proto::types::message::MessageClass::Success));
1810            assert!(msg.has_attribute(XorRelayedAddress::TYPE));
1811            assert!(msg.has_attribute(Lifetime::TYPE));
1812            assert!(msg.has_attribute(XorMappedAddress::TYPE));
1813            assert!(msg.has_attribute(MessageIntegrity::TYPE));
1814            self.client.recv(transmit, now);
1815            assert!(self
1816                .client
1817                .relayed_addresses()
1818                .any(|(transport, relayed)| transport == TransportType::Udp
1819                    && relayed == self.turn_alloc_addr))
1820        }
1821
1822        fn refresh(&mut self, now: Instant) {
1823            let TurnPollRet::WaitUntil(expiry) = self.client.poll(now) else {
1824                unreachable!()
1825            };
1826            assert_eq!(now, expiry);
1827            let transmit = self.client.poll_transmit(now).unwrap();
1828            trace!("transmit {:?}", transmit.data);
1829            let msg = Message::from_bytes(&transmit.data).unwrap();
1830            assert!(msg.has_method(REFRESH));
1831            assert!(msg.has_class(stun_proto::types::message::MessageClass::Request));
1832            assert!(msg.has_attribute(Realm::TYPE));
1833            assert!(msg.has_attribute(Nonce::TYPE));
1834            assert!(msg.has_attribute(Username::TYPE));
1835            assert!(msg.has_attribute(MessageIntegrity::TYPE));
1836            // ok reply
1837            let Ok(Some(transmit)) = self.server.recv(transmit, now) else {
1838                unreachable!();
1839            };
1840            let msg = Message::from_bytes(&transmit.data).unwrap();
1841            assert!(msg.has_method(REFRESH));
1842            assert!(msg.has_class(stun_proto::types::message::MessageClass::Success));
1843            assert!(msg.has_attribute(Lifetime::TYPE));
1844            assert!(msg.has_attribute(MessageIntegrity::TYPE));
1845            self.client.recv(transmit, now);
1846            assert!(self
1847                .client
1848                .relayed_addresses()
1849                .any(|(transport, relayed)| transport == TransportType::Udp
1850                    && relayed == self.turn_alloc_addr))
1851        }
1852
1853        fn delete_allocation(&mut self, now: Instant) {
1854            let transmit = self.client.delete(now).unwrap();
1855            let msg = Message::from_bytes(&transmit.data).unwrap();
1856            assert!(msg.has_method(REFRESH));
1857            assert!(msg.has_class(stun_proto::types::message::MessageClass::Request));
1858            assert!(msg.has_attribute(Lifetime::TYPE));
1859            assert!(msg.has_attribute(Realm::TYPE));
1860            assert!(msg.has_attribute(Nonce::TYPE));
1861            assert!(msg.has_attribute(Username::TYPE));
1862            assert!(msg.has_attribute(MessageIntegrity::TYPE));
1863            // ok reply
1864            let Ok(Some(transmit)) = self.server.recv(transmit, now) else {
1865                unreachable!();
1866            };
1867            let msg = Message::from_bytes(&transmit.data).unwrap();
1868            assert!(msg.has_method(REFRESH));
1869            assert!(msg.has_class(stun_proto::types::message::MessageClass::Success));
1870            assert!(msg.has_attribute(Lifetime::TYPE));
1871            assert!(msg.has_attribute(MessageIntegrity::TYPE));
1872            self.client.recv(transmit, now);
1873            assert!(!self
1874                .client
1875                .relayed_addresses()
1876                .any(|(transport, relayed)| transport == TransportType::Udp
1877                    && relayed == self.turn_alloc_addr))
1878        }
1879
1880        fn create_permission(&mut self, now: Instant) {
1881            let transmit = self
1882                .client
1883                .create_permission(TransportType::Udp, self.peer_addr.ip(), now)
1884                .unwrap();
1885            let msg = Message::from_bytes(&transmit.data).unwrap();
1886            assert!(msg.has_method(CREATE_PERMISSION));
1887            assert!(msg.has_class(stun_proto::types::message::MessageClass::Request));
1888            assert!(msg.has_attribute(XorPeerAddress::TYPE));
1889            assert!(msg.has_attribute(MessageIntegrity::TYPE));
1890            let Ok(Some(transmit)) = self.server.recv(transmit, now) else {
1891                unreachable!();
1892            };
1893            self.client.recv(transmit, now);
1894            self.validate_client_permission_state(now);
1895        }
1896
1897        fn validate_client_permission_state(&self, now: Instant) {
1898            let Some(permision) = self
1899                .client
1900                .permission(TransportType::Udp, self.peer_addr.ip())
1901            else {
1902                unreachable!();
1903            };
1904            assert_eq!(permision.expires_at, now + Duration::from_secs(300));
1905            assert!(self
1906                .client
1907                .permissions(TransportType::Udp, self.turn_alloc_addr)
1908                .any(|perm_addr| perm_addr == self.peer_addr.ip()));
1909        }
1910
1911        fn bind_channel(&mut self, now: Instant) {
1912            let transmit = self
1913                .client
1914                .bind_channel(TransportType::Udp, self.peer_addr, now)
1915                .unwrap();
1916            let msg = Message::from_bytes(&transmit.data).unwrap();
1917            assert!(msg.has_method(CHANNEL_BIND));
1918            assert!(msg.has_class(stun_proto::types::message::MessageClass::Request));
1919            assert!(msg.has_attribute(XorPeerAddress::TYPE));
1920            assert!(msg.has_attribute(MessageIntegrity::TYPE));
1921            let Ok(Some(transmit)) = self.server.recv(transmit, now) else {
1922                unreachable!();
1923            };
1924            self.client.recv(transmit, now);
1925            let Some(permision) = self
1926                .client
1927                .permission(TransportType::Udp, self.peer_addr.ip())
1928            else {
1929                unreachable!();
1930            };
1931            assert_eq!(permision.expires_at, now + Duration::from_secs(300));
1932            let Some(channel) = self.client.channel(TransportType::Udp, self.peer_addr) else {
1933                unreachable!();
1934            };
1935            assert_eq!(channel.expires_at, now + Duration::from_secs(600));
1936        }
1937
1938        fn sendrecv_data(&mut self, now: Instant) {
1939            // client to peer
1940            let data = [4; 8];
1941            let transmit = self
1942                .client
1943                .send_to(TransportType::Udp, self.peer_addr, data, now)
1944                .unwrap();
1945            assert!(matches!(
1946                transmit.data,
1947                DelayedMessageOrChannelSend::Message(_)
1948            ));
1949            let transmit = transmit_send_build(transmit);
1950            let Ok(Some(transmit)) = self.server.recv(transmit, now) else {
1951                unreachable!();
1952            };
1953            assert_eq!(transmit.transport, TransportType::Udp);
1954            assert_eq!(transmit.from, self.turn_alloc_addr);
1955            assert_eq!(transmit.to, self.peer_addr);
1956
1957            // peer to client
1958            let sent_data = [5; 12];
1959            let Some(transmit) = self
1960                .server
1961                .recv(
1962                    Transmit::new(
1963                        sent_data,
1964                        TransportType::Udp,
1965                        self.peer_addr,
1966                        self.turn_alloc_addr,
1967                    ),
1968                    now,
1969                )
1970                .unwrap()
1971            else {
1972                unreachable!();
1973            };
1974            assert_eq!(transmit.transport, self.client.transport());
1975            assert_eq!(transmit.from, self.server.listen_address());
1976            assert_eq!(transmit.to, self.client.local_addr());
1977            let msg = Message::from_bytes(&transmit.data).unwrap();
1978            assert!(msg.has_class(stun_proto::types::message::MessageClass::Indication));
1979            assert!(msg.has_method(DATA));
1980            let data = msg.attribute::<AData>().unwrap();
1981            assert_eq!(data.data(), sent_data);
1982            let TurnRecvRet::PeerData {
1983                data: recv_data,
1984                transport,
1985                peer,
1986            } = self.client.recv(transmit, now)
1987            else {
1988                unreachable!();
1989            };
1990            assert_eq!(transport, TransportType::Udp);
1991            assert_eq!(peer, self.peer_addr);
1992            assert_eq!(recv_data, sent_data);
1993        }
1994
1995        fn sendrecv_data_channel(&mut self, now: Instant) {
1996            let to_peer = [4; 8];
1997            let from_peer = [5; 12];
1998            self.sendrecv_data_channel_with_data(&to_peer, &from_peer, now);
1999        }
2000
2001        fn sendrecv_data_channel_with_data(
2002            &mut self,
2003            to_peer: &[u8],
2004            from_peer: &[u8],
2005            now: Instant,
2006        ) {
2007            let transmit = self
2008                .client
2009                .send_to(TransportType::Udp, self.peer_addr, to_peer, now)
2010                .unwrap();
2011            assert!(matches!(
2012                transmit.data,
2013                DelayedMessageOrChannelSend::Channel(_)
2014            ));
2015            let transmit = transmit_send_build(transmit);
2016            let Ok(Some(transmit)) = self.server.recv(transmit, now) else {
2017                unreachable!();
2018            };
2019            assert_eq!(transmit.transport, TransportType::Udp);
2020            assert_eq!(transmit.from, self.turn_alloc_addr);
2021            assert_eq!(transmit.to, self.peer_addr);
2022
2023            // peer to client
2024            let Some(transmit) = self
2025                .server
2026                .recv(
2027                    Transmit::new(
2028                        from_peer,
2029                        TransportType::Udp,
2030                        self.peer_addr,
2031                        self.turn_alloc_addr,
2032                    ),
2033                    now,
2034                )
2035                .unwrap()
2036            else {
2037                unreachable!();
2038            };
2039            assert_eq!(transmit.transport, TransportType::Udp);
2040            assert_eq!(transmit.from, self.server.listen_address());
2041            assert_eq!(transmit.to, self.client.local_addr());
2042            let cd = ChannelData::parse(&transmit.data).unwrap();
2043            assert_eq!(cd.data(), from_peer);
2044        }
2045    }
2046
2047    fn turn_allocate_permission(client_transport: TransportType) {
2048        let mut test = TurnTest::builder()
2049            .client_transport(client_transport)
2050            .build();
2051        let now = Instant::now();
2052
2053        test.allocate(now);
2054        let Some(TurnEvent::AllocationCreated(TransportType::Udp, relayed_address)) =
2055            test.client.poll_event()
2056        else {
2057            unreachable!();
2058        };
2059        assert_eq!(relayed_address, test.turn_alloc_addr);
2060        test.create_permission(now);
2061        let Some(TurnEvent::PermissionCreated(TransportType::Udp, permission_ip)) =
2062            test.client.poll_event()
2063        else {
2064            unreachable!();
2065        };
2066        assert_eq!(permission_ip, test.peer_addr.ip());
2067
2068        test.sendrecv_data(now);
2069    }
2070
2071    #[test]
2072    fn test_turn_udp_allocate_udp_permission() {
2073        let _log = crate::tests::test_init_log();
2074
2075        turn_allocate_permission(TransportType::Udp);
2076    }
2077
2078    #[test]
2079    fn test_turn_tcp_allocate_udp_permission() {
2080        let _log = crate::tests::test_init_log();
2081
2082        turn_allocate_permission(TransportType::Tcp);
2083    }
2084
2085    #[test]
2086    fn test_turn_allocate_expire_server() {
2087        let _log = crate::tests::test_init_log();
2088
2089        let mut test = TurnTest::builder().build();
2090        let now = Instant::now();
2091
2092        test.allocate(now);
2093        let Some(TurnEvent::AllocationCreated(TransportType::Udp, relayed_address)) =
2094            test.client.poll_event()
2095        else {
2096            unreachable!();
2097        };
2098        assert_eq!(relayed_address, test.turn_alloc_addr);
2099        let transmit = test
2100            .client
2101            .create_permission(TransportType::Udp, test.peer_addr.ip(), now)
2102            .unwrap();
2103        let now = now + Duration::from_secs(3000);
2104        let Ok(Some(transmit)) = test.server.recv(transmit, now) else {
2105            unreachable!();
2106        };
2107        let msg = Message::from_bytes(&transmit.data).unwrap();
2108        assert!(msg.has_method(CREATE_PERMISSION));
2109        assert!(msg.has_class(stun_proto::types::message::MessageClass::Error));
2110        let err = msg.attribute::<ErrorCode>().unwrap();
2111        assert_eq!(err.code(), ErrorCode::ALLOCATION_MISMATCH);
2112        test.client.recv(transmit, now);
2113    }
2114
2115    #[test]
2116    fn test_turn_allocate_expire_client() {
2117        let _log = crate::tests::test_init_log();
2118
2119        let mut test = TurnTest::builder().build();
2120        let now = Instant::now();
2121
2122        test.allocate(now);
2123        let Some(TurnEvent::AllocationCreated(TransportType::Udp, relayed_address)) =
2124            test.client.poll_event()
2125        else {
2126            unreachable!();
2127        };
2128        assert_eq!(relayed_address, test.turn_alloc_addr);
2129        let now = now + Duration::from_secs(3000);
2130        let Err(CreatePermissionError::NoAllocation) =
2131            test.client
2132                .create_permission(TransportType::Udp, test.peer_addr.ip(), now)
2133        else {
2134            unreachable!();
2135        };
2136    }
2137
2138    #[test]
2139    fn test_turn_allocate_refresh() {
2140        let _log = crate::tests::test_init_log();
2141
2142        let mut test = TurnTest::builder().build();
2143        let now = Instant::now();
2144
2145        test.allocate(now);
2146        let Some(TurnEvent::AllocationCreated(TransportType::Udp, relayed_address)) =
2147            test.client.poll_event()
2148        else {
2149            unreachable!();
2150        };
2151        assert_eq!(relayed_address, test.turn_alloc_addr);
2152
2153        let TurnPollRet::WaitUntil(expiry) = test.client.poll(now) else {
2154            unreachable!()
2155        };
2156        trace!("expiry: {expiry:?}");
2157        assert!(expiry > now + Duration::from_secs(1000));
2158
2159        test.refresh(expiry);
2160        test.create_permission(expiry);
2161        let Some(TurnEvent::PermissionCreated(TransportType::Udp, permission_ip)) =
2162            test.client.poll_event()
2163        else {
2164            unreachable!();
2165        };
2166        assert_eq!(permission_ip, test.peer_addr.ip());
2167        test.sendrecv_data(expiry);
2168    }
2169
2170    #[test]
2171    fn test_turn_allocate_delete() {
2172        let _log = crate::tests::test_init_log();
2173
2174        let mut test = TurnTest::builder().build();
2175        let now = Instant::now();
2176
2177        test.allocate(now);
2178        test.delete_allocation(now);
2179
2180        let Err(CreatePermissionError::NoAllocation) =
2181            test.client
2182                .create_permission(TransportType::Udp, test.peer_addr.ip(), now)
2183        else {
2184            unreachable!();
2185        };
2186    }
2187
2188    #[test]
2189    fn test_turn_channel_bind() {
2190        let _log = crate::tests::test_init_log();
2191
2192        let mut test = TurnTest::builder().build();
2193        let now = Instant::now();
2194
2195        test.allocate(now);
2196        let Some(TurnEvent::AllocationCreated(TransportType::Udp, relayed_address)) =
2197            test.client.poll_event()
2198        else {
2199            unreachable!();
2200        };
2201        assert_eq!(relayed_address, test.turn_alloc_addr);
2202        test.bind_channel(now);
2203        let Some(TurnEvent::PermissionCreated(TransportType::Udp, permission_ip)) =
2204            test.client.poll_event()
2205        else {
2206            unreachable!();
2207        };
2208        assert_eq!(permission_ip, test.peer_addr.ip());
2209        test.sendrecv_data_channel(now);
2210    }
2211
2212    #[test]
2213    fn test_turn_peer_incoming_stun() {
2214        // tests that sending stun messages can be passed through the turn server
2215        let _log = crate::tests::test_init_log();
2216
2217        let mut test = TurnTest::builder().build();
2218        let now = Instant::now();
2219
2220        test.allocate(now);
2221        let Some(TurnEvent::AllocationCreated(TransportType::Udp, relayed_address)) =
2222            test.client.poll_event()
2223        else {
2224            unreachable!();
2225        };
2226        assert_eq!(relayed_address, test.turn_alloc_addr);
2227        test.bind_channel(now);
2228        let Some(TurnEvent::PermissionCreated(TransportType::Udp, permission_ip)) =
2229            test.client.poll_event()
2230        else {
2231            unreachable!();
2232        };
2233        assert_eq!(permission_ip, test.peer_addr.ip());
2234
2235        let mut msg = Message::builder(
2236            MessageType::from_class_method(MessageClass::Indication, 0x1432),
2237            TransactionId::generate(),
2238        );
2239        let realm = Realm::new("realm").unwrap();
2240        msg.add_attribute(&realm).unwrap();
2241        let data = msg.build();
2242        test.sendrecv_data_channel_with_data(&data, &data, now);
2243    }
2244
2245    #[test]
2246    fn test_turn_create_permission_refresh() {
2247        let _log = crate::tests::test_init_log();
2248
2249        let mut test = TurnTest::builder().build();
2250        let now = Instant::now();
2251
2252        test.allocate(now);
2253        let Some(TurnEvent::AllocationCreated(TransportType::Udp, relayed_address)) =
2254            test.client.poll_event()
2255        else {
2256            unreachable!();
2257        };
2258        assert_eq!(relayed_address, test.turn_alloc_addr);
2259
2260        test.create_permission(now);
2261        let Some(TurnEvent::PermissionCreated(TransportType::Udp, permission_ip)) =
2262            test.client.poll_event()
2263        else {
2264            unreachable!();
2265        };
2266        assert_eq!(permission_ip, test.peer_addr.ip());
2267
2268        let TurnPollRet::WaitUntil(expiry) = test.client.poll(now) else {
2269            unreachable!()
2270        };
2271        assert_eq!(expiry, now + Duration::from_secs(240));
2272        let TurnPollRet::WaitUntil(now) = test.client.poll(expiry) else {
2273            unreachable!()
2274        };
2275        assert_eq!(now, expiry);
2276
2277        let transmit = test.client.poll_transmit(expiry).unwrap();
2278        let msg = Message::from_bytes(&transmit.data).unwrap();
2279        assert_eq!(msg.method(), CREATE_PERMISSION);
2280        let Ok(Some(transmit)) = test.server.recv(transmit, now) else {
2281            unreachable!();
2282        };
2283        test.client.recv(transmit, expiry);
2284        test.validate_client_permission_state(expiry);
2285
2286        test.sendrecv_data(expiry);
2287    }
2288
    #[test]
    fn test_turn_create_permission_timeout() {
        // Checks what happens when a permission refresh never reaches the server. The client
        // must stop reporting the permission and emit PermissionCreateFailed for the peer's IP.
        let _log = crate::tests::test_init_log();

        let mut test = TurnTest::builder().build();
        let now = Instant::now();

        test.allocate(now);
        let Some(TurnEvent::AllocationCreated(TransportType::Udp, relayed_address)) =
            test.client.poll_event()
        else {
            unreachable!();
        };
        assert_eq!(relayed_address, test.turn_alloc_addr);

        test.create_permission(now);
        let Some(TurnEvent::PermissionCreated(TransportType::Udp, permission_ip)) =
            test.client.poll_event()
        else {
            unreachable!();
        };
        assert_eq!(permission_ip, test.peer_addr.ip());

        // The first deadline is the permission refresh, 240 seconds after creation.
        let TurnPollRet::WaitUntil(expiry) = test.client.poll(now) else {
            unreachable!()
        };
        assert_eq!(expiry, now + Duration::from_secs(240));
        // Polling at that deadline returns the same instant. From here on `now` is shadowed
        // with it, i.e. creation time + 240 s.
        let TurnPollRet::WaitUntil(now) = test.client.poll(expiry) else {
            unreachable!()
        };
        assert_eq!(now, expiry);

        // The refresh request is taken from the client but deliberately never handed to the
        // server.
        let transmit = test.client.poll_transmit(expiry).unwrap();
        let msg = Message::from_bytes(&transmit.data).unwrap();
        assert_eq!(msg.method(), CREATE_PERMISSION);
        // drop the create permission refresh (and retransmits): each iteration advances time to
        // the next deadline the client reports and discards whatever it wants to send there, so
        // no response ever arrives.
        let mut expiry = now;
        for _i in 0..8 {
            let TurnPollRet::WaitUntil(new_now) = test.client.poll(expiry) else {
                unreachable!()
            };
            let _ = test.client.poll_transmit(new_now);
            expiry = new_now;
        }
        // The final deadline is 60 seconds after the refresh started, i.e. 300 seconds after the
        // permission was created. NOTE(review): presumably the TURN permission lifetime; confirm.
        assert_eq!(expiry, now + Duration::from_secs(60));
        let TurnPollRet::WaitUntil(now) = test.client.poll(expiry) else {
            unreachable!()
        };

        // After polling at that deadline, the permission is no longer installed and its failure
        // is reported as an event.
        assert!(!test
            .client
            .have_permission(TransportType::Udp, test.peer_addr, now));
        let Some(TurnEvent::PermissionCreateFailed(_transport, ip)) = test.client.poll_event()
        else {
            unreachable!();
        };
        assert_eq!(ip, test.peer_addr.ip());
    }
2347
2348    #[test]
2349    fn test_turn_channel_bind_refresh() {
2350        let _log = crate::tests::test_init_log();
2351
2352        let mut test = TurnTest::builder().build();
2353        let now = Instant::now();
2354
2355        test.allocate(now);
2356        let Some(TurnEvent::AllocationCreated(TransportType::Udp, relayed_address)) =
2357            test.client.poll_event()
2358        else {
2359            unreachable!();
2360        };
2361        assert_eq!(relayed_address, test.turn_alloc_addr);
2362
2363        test.bind_channel(now);
2364        let Some(TurnEvent::PermissionCreated(TransportType::Udp, permission_ip)) =
2365            test.client.poll_event()
2366        else {
2367            unreachable!();
2368        };
2369        assert_eq!(permission_ip, test.peer_addr.ip());
2370
2371        // two permission refreshes
2372        let mut permissions_done = now;
2373        for _i in 0..2 {
2374            let now = permissions_done;
2375            let TurnPollRet::WaitUntil(expiry) = test.client.poll(now) else {
2376                unreachable!()
2377            };
2378            assert_eq!(expiry, now + Duration::from_secs(240));
2379            let TurnPollRet::WaitUntil(now) = test.client.poll(expiry) else {
2380                unreachable!()
2381            };
2382            assert_eq!(now, expiry);
2383
2384            let transmit = test.client.poll_transmit(now).unwrap();
2385            let msg = Message::from_bytes(&transmit.data).unwrap();
2386            assert_eq!(msg.method(), CREATE_PERMISSION);
2387            let Ok(Some(transmit)) = test.server.recv(transmit, now) else {
2388                unreachable!();
2389            };
2390            test.client.recv(transmit, now);
2391            test.validate_client_permission_state(now);
2392            permissions_done = now;
2393        }
2394        let now = permissions_done;
2395
2396        let TurnPollRet::WaitUntil(expiry) = test.client.poll(now) else {
2397            unreachable!()
2398        };
2399        assert_eq!(expiry, now + Duration::from_secs(60));
2400        let TurnPollRet::WaitUntil(now) = test.client.poll(expiry) else {
2401            unreachable!()
2402        };
2403        assert_eq!(now, expiry);
2404        let transmit = test.client.poll_transmit(expiry).unwrap();
2405        let msg = Message::from_bytes(&transmit.data).unwrap();
2406        println!("message {msg}");
2407        assert_eq!(msg.method(), CHANNEL_BIND);
2408        let Ok(Some(transmit)) = test.server.recv(transmit, now) else {
2409            unreachable!();
2410        };
2411        test.client.recv(transmit, expiry);
2412
2413        test.sendrecv_data_channel(expiry);
2414    }
2415}