turn_server_proto/
server.rs

1// Copyright (C) 2025 Matthew Waters <matthew@centricular.com>
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use std::collections::{HashMap, VecDeque};
10use std::net::{IpAddr, SocketAddr};
11use std::time::{Duration, Instant};
12
13use stun_proto::agent::{StunAgent, StunError, Transmit, TransmitBuild};
14use stun_proto::prelude::*;
15use stun_proto::types::attribute::{
16    ErrorCode, Fingerprint, MessageIntegrity, Nonce, Realm, Username, XorMappedAddress,
17};
18use stun_proto::types::data::Data;
19use stun_proto::types::message::{
20    LongTermCredentials, Message, MessageBuilder, MessageClass, MessageIntegrityCredentials,
21    MessageType, TransactionId, BINDING,
22};
23use stun_proto::types::prelude::{Attribute, AttributeFromRaw, AttributeStaticType};
24use stun_proto::types::TransportType;
25use turn_types::channel::ChannelData;
26
27use turn_types::message::CREATE_PERMISSION;
28
29use turn_types::attribute::Data as AData;
30use turn_types::attribute::{
31    ChannelNumber, Lifetime, RequestedTransport, XorPeerAddress, XorRelayedAddress,
32};
33use turn_types::message::{ALLOCATE, CHANNEL_BIND, DATA, REFRESH, SEND};
34use turn_types::TurnCredentials;
35
36use tracing::{debug, error, info, trace, warn};
37
/// A TURN server.
///
/// The server owns no sockets: received data is fed in through [`TurnServer::recv`], and the
/// caller learns what to do next (allocate a relayed socket, send data, wait) through
/// [`TurnServer::poll`] and [`TurnServer::poll_transmit`].
#[derive(Debug)]
pub struct TurnServer {
    // Realm advertised in REALM attributes and used to derive long-term credentials.
    realm: String,
    // FIXME: remove
    // Provides the listening address and is used to send responses.
    stun: StunAgent,

    // Clients whose Allocate request completed with a relayed socket.
    clients: Vec<Client>,
    // Nonces handed out, one per client 5-tuple.
    nonces: Vec<NonceData>,
    // Transmits queued for retrieval through `poll_transmit()`.
    pending_transmits: VecDeque<Transmit<Data<'static>>>,
    // Allocate requests waiting for the caller to provide a relayed socket.
    pending_allocates: VecDeque<PendingClient>,

    // username -> password mapping.
    users: HashMap<String, String>,
}
53
/// An Allocate request that is waiting for the caller to provide a relayed socket.
#[derive(Debug)]
struct PendingClient {
    // The client (5-tuple and credentials) that sent the Allocate request.
    client: Client,
    // Whether `poll()` has already asked the caller to allocate a socket for this client.
    asked: bool,
    // Transaction ID of the Allocate request, reused for the eventual response.
    transaction_id: TransactionId,
}
60
/// A NONCE handed out to the client identified by a 5-tuple.
#[derive(Debug)]
struct NonceData {
    // The value sent to (and expected back from) the client in NONCE attributes.
    nonce: String,
    // Time after which the nonce is treated as stale.
    expires_at: Instant,

    // The 5-tuple (transport, client address, server address) this nonce belongs to.
    transport: TransportType,
    remote_addr: SocketAddr,
    local_addr: SocketAddr,
}
70
/// Return value for [poll](TurnServer::poll).
#[derive(Debug)]
pub enum TurnServerPollRet {
    /// Wait until the specified time before calling poll() again.
    WaitUntil(Instant),
    /// Allocate a UDP socket for a client specified by the client's network 5-tuple.
    ///
    /// The outcome is reported back through
    /// [`allocated_udp_socket`](TurnServer::allocated_udp_socket).
    AllocateSocketUdp {
        /// The transport of the client asking for an allocation.
        transport: TransportType,
        /// The TURN server address the client sent its Allocate request to.
        local_addr: SocketAddr,
        /// The address of the client asking for an allocation, as seen by the TURN server.
        remote_addr: SocketAddr,
    },
}
86
87impl TurnServer {
88    /// Construct a new [`TurnServer`]
89    ///
90    /// # Examples
91    /// ```
92    /// # use turn_server_proto::TurnServer;
93    /// # use stun_proto::types::TransportType;
94    /// let realm = String::from("realm");
95    /// let listen_addr = "10.0.0.1:3478".parse().unwrap();
96    /// let server = TurnServer::new(TransportType::Udp, listen_addr, realm);
97    /// assert_eq!(server.listen_address(), listen_addr);
98    /// ```
99    pub fn new(ttype: TransportType, listen_addr: SocketAddr, realm: String) -> Self {
100        let stun = StunAgent::builder(ttype, listen_addr).build();
101        Self {
102            realm,
103            stun,
104            clients: vec![],
105            nonces: vec![],
106            pending_transmits: VecDeque::default(),
107            pending_allocates: VecDeque::default(),
108            users: HashMap::default(),
109        }
110    }
111
112    /// Add a user credentials that would be accepted by this [`TurnServer`].
113    pub fn add_user(&mut self, username: String, password: String) {
114        self.users.insert(username, password);
115    }
116
117    /// The address that the [`TurnServer`] is listening on for incoming client connections.
118    pub fn listen_address(&self) -> SocketAddr {
119        self.stun.local_addr()
120    }
121
122    /// Provide received data to the [`TurnServer`].
123    ///
124    /// Any returned Transmit should be forwarded to the appropriate socket.
125    #[tracing::instrument(
126        name = "turn_server_recv",
127        skip(self, transmit),
128        fields(
129            transport = %transmit.transport,
130            remote_addr = %transmit.from,
131            local_addr = %transmit.to,
132            data_len = transmit.data.as_ref().len(),
133        )
134        err,
135        ret,
136    )]
137    pub fn recv<T: AsRef<[u8]>>(
138        &mut self,
139        transmit: Transmit<T>,
140        now: Instant,
141    ) -> Result<Option<Transmit<Data<'static>>>, StunError> {
142        if let Some((client, allocation)) =
143            self.allocation_from_public_5tuple(transmit.transport, transmit.to, transmit.from)
144        {
145            // A packet from the relayed address needs to be sent to the client that set up
146            // the allocation.
147            let Some(_permission) =
148                allocation.permissions_from_5tuple(transmit.transport, transmit.to, transmit.from)
149            else {
150                warn!(
151                    "no permission for {:?} for this allocation {:?}",
152                    transmit.from, allocation.addr
153                );
154                return Ok(None);
155            };
156
157            if let Some(existing) =
158                allocation.channel_from_5tuple(transmit.transport, transmit.to, transmit.from)
159            {
160                debug!(
161                    "found existing channel {} for {:?} for this allocation {:?}",
162                    existing.id, transmit.from, allocation.addr
163                );
164                let mut data = vec![0; 4];
165                data[0..2].copy_from_slice(&existing.id.to_be_bytes());
166                data[2..4].copy_from_slice(&(transmit.data.as_ref().len() as u16).to_be_bytes());
167                // XXX: try to avoid copy?
168                data.extend_from_slice(transmit.data.as_ref());
169                Ok(Some(Transmit::new(
170                    data.into_boxed_slice().into(),
171                    client.transport,
172                    client.local_addr,
173                    client.remote_addr,
174                )))
175            } else {
176                // no channel with that id
177                debug!(
178                    "no channel for {:?} for this allocation {:?}, using DATA indication",
179                    transmit.from, allocation.addr
180                );
181                let transaction_id = TransactionId::generate();
182                let mut builder = Message::builder(
183                    MessageType::from_class_method(MessageClass::Indication, DATA),
184                    transaction_id,
185                );
186                let peer_address = XorPeerAddress::new(transmit.from, transaction_id);
187                builder.add_attribute(&peer_address).unwrap();
188                let data = AData::new(transmit.data.as_ref());
189                builder.add_attribute(&data).unwrap();
190                // XXX: try to avoid copy?
191                let msg_data = builder.build();
192
193                Ok(Some(Transmit::new(
194                    msg_data.into_boxed_slice().into(),
195                    client.transport,
196                    client.local_addr,
197                    client.remote_addr,
198                )))
199            }
200        } else {
201            // TODO: TCP buffering requirements
202            match Message::from_bytes(transmit.data.as_ref()) {
203                Ok(msg) => {
204                    trace!("received {} from {:?}", msg, transmit.from);
205                    match self.handle_stun(
206                        &msg,
207                        transmit.transport,
208                        transmit.from,
209                        transmit.to,
210                        now,
211                    ) {
212                        Err(builder) => {
213                            let data = builder.build();
214                            return Ok(Some(Transmit::new(
215                                data.into_boxed_slice().into(),
216                                transmit.transport,
217                                transmit.to,
218                                transmit.from,
219                            )));
220                        }
221                        Ok(Some(transmit)) => Ok(Some(transmit.into_owned())),
222                        Ok(None) => Ok(None),
223                    }
224                }
225                Err(_) => {
226                    if let Some(client) =
227                        self.client_from_5tuple(transmit.transport, transmit.to, transmit.from)
228                    {
229                        trace!(
230                            "received {} bytes from {:?}",
231                            transmit.data.as_ref().len(),
232                            transmit.from
233                        );
234                        let Ok(channel) = ChannelData::parse(transmit.data.as_ref()) else {
235                            return Ok(None);
236                        };
237                        trace!(
238                            "parsed channel data with id {} and data length {}",
239                            channel.id(),
240                            channel.data().len()
241                        );
242                        let Some((allocation, existing)) =
243                            client.allocations.iter().find_map(|allocation| {
244                                allocation
245                                    .channel_from_id(channel.id())
246                                    .map(|perm| (allocation, perm))
247                            })
248                        else {
249                            warn!(
250                                "no channel id {} for this client {:?}",
251                                channel.id(),
252                                client.remote_addr
253                            );
254                            // no channel with that id
255                            return Ok(None);
256                        };
257
258                        // A packet from the client needs to be sent to the peer referenced by the
259                        // configured channel.
260                        let Some(_permission) = allocation.permissions_from_5tuple(
261                            transmit.transport,
262                            allocation.addr,
263                            existing.peer_addr,
264                        ) else {
265                            warn!(
266                                "no permission for {:?} for this allocation {:?}",
267                                existing.peer_addr, allocation.addr
268                            );
269                            return Ok(None);
270                        };
271                        Ok(Some(
272                            Transmit::new(
273                                Data::from(channel.data()),
274                                allocation.ttype,
275                                allocation.addr,
276                                existing.peer_addr,
277                            )
278                            .into_owned(),
279                        ))
280                    } else {
281                        trace!(
282                            "No handler for {} bytes over {:?} from {:?}, to {:?}. Ignoring",
283                            transmit.data.as_ref().len(),
284                            transmit.transport,
285                            transmit.from,
286                            transmit.to
287                        );
288                        Ok(None)
289                    }
290                }
291            }
292        }
293    }
294
295    /// Poll the [`TurnServer`] in order to make further progress.
296    ///
297    /// The returned value indicates what the caller should do.
298    #[tracing::instrument(name = "turn_server_poll", skip(self), ret)]
299    pub fn poll(&mut self, now: Instant) -> TurnServerPollRet {
300        for pending in self.pending_allocates.iter_mut() {
301            if pending.asked {
302                continue;
303            }
304
305            // TODO: TCP
306            return TurnServerPollRet::AllocateSocketUdp {
307                transport: pending.client.transport,
308                local_addr: pending.client.local_addr,
309                remote_addr: pending.client.remote_addr,
310            };
311        }
312
313        for client in self.clients.iter_mut() {
314            client.allocations.retain_mut(|allocation| {
315                if allocation.expires_at <= now {
316                    allocation
317                        .permissions
318                        .retain_mut(|permission| permission.expires_at <= now);
319                    allocation
320                        .channels
321                        .retain_mut(|channel| channel.expires_at <= now);
322                    true
323                } else {
324                    false
325                }
326            });
327        }
328
329        TurnServerPollRet::WaitUntil(now + Duration::from_secs(60))
330    }
331
332    /// Poll for a new Transmit to send over a socket.
333    #[tracing::instrument(name = "turn_server_poll_transmit", skip(self))]
334    pub fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Data<'static>>> {
335        if let Some(transmit) = self.pending_transmits.pop_back() {
336            return Some(transmit);
337        }
338        None
339    }
340
341    /// Notify the [`TurnServer`] that a UDP socket has been allocated (or an error) in response to
342    /// [TurnServerPollRet::AllocateSocketUdp].
343    #[tracing::instrument(name = "turn_server_allocated_udp_socket", skip(self))]
344    pub fn allocated_udp_socket(
345        &mut self,
346        transport: TransportType,
347        local_addr: SocketAddr,
348        remote_addr: SocketAddr,
349        socket_addr: Result<SocketAddr, ()>,
350        now: Instant,
351    ) {
352        let Some(position) = self.pending_allocates.iter().position(|pending| {
353            pending.client.transport == transport
354                && pending.client.local_addr == local_addr
355                && pending.client.remote_addr == remote_addr
356        }) else {
357            warn!("No pending allocation for transport: Udp, local: {local_addr:?}, remote {remote_addr:?}");
358            return;
359        };
360        info!("pending allocation for transport: Udp, local: {local_addr:?}, remote {remote_addr:?} resulted in {socket_addr:?}");
361        let mut pending = self.pending_allocates.remove(position).unwrap();
362        let transaction_id = pending.transaction_id;
363        let to = pending.client.remote_addr;
364
365        let mut builder = if let Ok(socket_addr) = socket_addr {
366            pending.client.allocations.push(Allocation {
367                addr: socket_addr,
368                ttype: TransportType::Udp,
369                expires_at: now + Duration::from_secs(1800),
370                permissions: vec![],
371                channels: vec![],
372            });
373
374            let mut builder = Message::builder(
375                MessageType::from_class_method(MessageClass::Success, ALLOCATE),
376                transaction_id,
377            );
378            let relayed_address = XorRelayedAddress::new(socket_addr, transaction_id);
379            builder.add_attribute(&relayed_address).unwrap();
380            let lifetime = Lifetime::new(1800);
381            builder.add_attribute(&lifetime).unwrap();
382            // TODO RESERVATION-TOKEN
383            let mapped_address = XorMappedAddress::new(pending.client.remote_addr, transaction_id);
384            builder.add_attribute(&mapped_address).unwrap();
385
386            builder.into_owned()
387        } else {
388            let mut builder = Message::builder(
389                MessageType::from_class_method(MessageClass::Error, ALLOCATE),
390                transaction_id,
391            );
392            let error = ErrorCode::builder(ErrorCode::INSUFFICIENT_CAPACITY)
393                .build()
394                .unwrap();
395            builder.add_attribute(&error).unwrap();
396            builder.into_owned()
397        };
398        builder
399            .add_message_integrity(
400                &MessageIntegrityCredentials::LongTerm(pending.client.credentials.clone()),
401                stun_proto::types::message::IntegrityAlgorithm::Sha1,
402            )
403            .unwrap();
404
405        let Ok(transmit) = self.stun.send(builder, to, now) else {
406            return;
407        };
408        if socket_addr.is_ok() {
409            self.clients.push(pending.client);
410        }
411        self.pending_transmits
412            .push_back(transmit_send_build(transmit));
413    }
414
415    fn validate_stun<'a>(
416        &mut self,
417        msg: &Message<'_>,
418        ttype: TransportType,
419        from: SocketAddr,
420        to: SocketAddr,
421        now: Instant,
422    ) -> Result<LongTermCredentials, MessageBuilder<'a>> {
423        let integrity = msg.attribute::<MessageIntegrity>().ok();
424        // TODO: check for SHA256 integrity
425        if integrity.is_none() {
426            //   o  If the message does not contain a MESSAGE-INTEGRITY attribute, the
427            //      server MUST generate an error response with an error code of 401
428            //      (Unauthorized).  This response MUST include a REALM value.  It is
429            //      RECOMMENDED that the REALM value be the domain name of the
430            //      provider of the STUN server.  The response MUST include a NONCE,
431            //      selected by the server.  The response SHOULD NOT contain a
432            //      USERNAME or MESSAGE-INTEGRITY attribute.
433            let nonce = if let Some(nonce) = self.nonce_from_5tuple(ttype, to, from) {
434                nonce
435            } else {
436                self.nonces.push(NonceData {
437                    transport: ttype,
438                    remote_addr: from,
439                    local_addr: to,
440                    // FIXME: use an actual random source.
441                    nonce: String::from("random"),
442                    expires_at: now + Duration::from_secs(3600),
443                });
444                self.nonces.last().unwrap()
445            };
446            trace!(
447                "no message-integrity, returning unauthorized with nonce: {}",
448                nonce.nonce
449            );
450            let mut builder = Message::builder_error(msg);
451            let nonce = Nonce::new(&nonce.nonce).unwrap();
452            builder.add_attribute(&nonce).unwrap();
453            let realm = Realm::new(&self.realm).unwrap();
454            builder.add_attribute(&realm).unwrap();
455            let error = ErrorCode::builder(ErrorCode::UNAUTHORIZED).build().unwrap();
456            builder.add_attribute(&error).unwrap();
457            return Err(builder.into_owned());
458        }
459
460        //  o  If the message contains a MESSAGE-INTEGRITY attribute, but is
461        //      missing the USERNAME, REALM, or NONCE attribute, the server MUST
462        //      generate an error response with an error code of 400 (Bad
463        //      Request).  This response SHOULD NOT include a USERNAME, NONCE,
464        //      REALM, or MESSAGE-INTEGRITY attribute.
465        let username = msg.attribute::<Username>().ok();
466        let realm = msg.attribute::<Realm>().ok();
467        let nonce = msg.attribute::<Nonce>().ok();
468        let Some(((username, _realm), nonce)) = username.zip(realm).zip(nonce) else {
469            trace!("bad request due to missing username, realm, nonce");
470            let error = ErrorCode::builder(ErrorCode::BAD_REQUEST).build().unwrap();
471            let mut builder = Message::builder_error(msg);
472            builder.add_attribute(&error).unwrap();
473            return Err(builder.into_owned());
474        };
475
476        //   o  If the NONCE is no longer valid, the server MUST generate an error
477        //      response with an error code of 438 (Stale Nonce).  This response
478        //      MUST include NONCE and REALM attributes and SHOULD NOT include the
479        //      USERNAME or MESSAGE-INTEGRITY attribute.  Servers can invalidate
480        //      nonces in order to provide additional security.  See Section 4.3
481        //      of [RFC2617] for guidelines.
482        let nonce_data = self.mut_nonce_from_5tuple(ttype, to, from);
483        let mut stale_nonce = false;
484        let nonce_value = if let Some(nonce_data) = nonce_data {
485            if nonce_data.expires_at < now {
486                nonce_data.nonce = String::from("random");
487                nonce_data.expires_at = now + Duration::from_secs(3600);
488                stale_nonce = true;
489            } else if nonce_data.nonce != nonce.nonce() {
490                stale_nonce = true;
491            }
492            nonce_data.nonce.clone()
493        } else {
494            let nonce_value = String::from("randome");
495            self.nonces.push(NonceData {
496                transport: ttype,
497                remote_addr: from,
498                local_addr: to,
499                // FIXME: use an actual random source.
500                nonce: nonce_value.clone(),
501                expires_at: now + Duration::from_secs(3600),
502            });
503            stale_nonce = true;
504            nonce_value
505        };
506
507        if stale_nonce {
508            let mut builder = Message::builder_error(msg);
509            let error = ErrorCode::builder(ErrorCode::STALE_NONCE).build().unwrap();
510            builder.add_attribute(&error).unwrap();
511            let realm = Realm::new(&self.realm).unwrap();
512            builder.add_attribute(&realm).unwrap();
513            let nonce = Nonce::new(&nonce_value).unwrap();
514            builder.add_attribute(&nonce).unwrap();
515
516            return Err(builder.into_owned());
517        };
518
519        //   o  Using the password associated with the username in the USERNAME
520        //      attribute, compute the value for the message integrity as
521        //      described in Section 15.4.  If the resulting value does not match
522        //      the contents of the MESSAGE-INTEGRITY attribute, the server MUST
523        //      reject the request with an error response.  This response MUST use
524        //      an error code of 401 (Unauthorized).  It MUST include REALM and
525        //      NONCE attributes and SHOULD NOT include the USERNAME or MESSAGE-
526        //      INTEGRITY attribute.
527        let password = self.users.get(username.username());
528        let credentials = TurnCredentials::new(
529            username.username(),
530            password.map_or("", |pass| pass.as_str()),
531        )
532        .into_long_term_credentials(&self.realm);
533        if password.map_or(true, |_password| {
534            msg.validate_integrity(&MessageIntegrityCredentials::LongTerm(credentials.clone()))
535                .is_err()
536        }) {
537            let mut builder = Message::builder_error(msg);
538            let error = ErrorCode::builder(ErrorCode::UNAUTHORIZED).build().unwrap();
539            builder.add_attribute(&error).unwrap();
540            let realm = Realm::new(&self.realm).unwrap();
541            builder.add_attribute(&realm).unwrap();
542            let nonce = Nonce::new(&nonce_value).unwrap();
543            builder.add_attribute(&nonce).unwrap();
544            return Err(builder.into_owned());
545        }
546
547        if let Some(client) = self.client_from_5tuple(ttype, to, from) {
548            if client.credentials.username() != username.username() {
549                let mut builder = Message::builder_error(msg);
550                let error = ErrorCode::builder(ErrorCode::WRONG_CREDENTIALS)
551                    .build()
552                    .unwrap();
553                builder.add_attribute(&error).unwrap();
554                builder
555                    .add_message_integrity(
556                        &MessageIntegrityCredentials::LongTerm(client.credentials.clone()),
557                        stun_proto::types::message::IntegrityAlgorithm::Sha1,
558                    )
559                    .unwrap();
560                return Err(builder.into_owned());
561            }
562        }
563
564        Ok(credentials)
565    }
566
567    fn handle_stun_binding<'a>(
568        &mut self,
569        msg: &Message<'_>,
570        _ttype: TransportType,
571        from: SocketAddr,
572        to: SocketAddr,
573        now: Instant,
574    ) -> Result<Transmit<Data<'a>>, MessageBuilder<'a>> {
575        let response = if let Some(error_msg) =
576            Message::check_attribute_types(msg, &[Fingerprint::TYPE], &[])
577        {
578            error_msg
579        } else {
580            let mut response = Message::builder_success(msg);
581            let xor_addr = XorMappedAddress::new(from, msg.transaction_id());
582            response.add_attribute(&xor_addr).unwrap();
583            response.add_fingerprint().unwrap();
584            response.into_owned()
585        };
586
587        let Ok(transmit) = self.stun.send(response, to, now) else {
588            error!("Failed to send");
589            let mut response = Message::builder_error(msg);
590            let error = ErrorCode::builder(ErrorCode::SERVER_ERROR).build().unwrap();
591            response.add_attribute(&error).unwrap();
592            response.add_fingerprint().unwrap();
593            return Err(response.into_owned());
594        };
595
596        Ok(transmit_send_build(transmit))
597    }
598
599    fn handle_stun_allocate<'a>(
600        &mut self,
601        msg: &Message<'_>,
602        ttype: TransportType,
603        from: SocketAddr,
604        to: SocketAddr,
605        now: Instant,
606    ) -> Result<(), MessageBuilder<'a>> {
607        let credentials = self.validate_stun(msg, ttype, from, to, now)?;
608
609        if let Some(_client) = self.mut_client_from_5tuple(ttype, to, from) {
610            let mut builder = Message::builder_error(msg);
611            let error = ErrorCode::builder(ErrorCode::ALLOCATION_MISMATCH)
612                .build()
613                .unwrap();
614            builder.add_attribute(&error).unwrap();
615            return Err(builder.into_owned());
616        };
617
618        let Ok(requested_transport) = msg.attribute::<RequestedTransport>() else {
619            let mut builder = Message::builder_error(msg);
620            let error = ErrorCode::builder(ErrorCode::BAD_REQUEST).build().unwrap();
621            builder.add_attribute(&error).unwrap();
622            builder
623                .add_message_integrity(
624                    &MessageIntegrityCredentials::LongTerm(credentials),
625                    stun_proto::types::message::IntegrityAlgorithm::Sha1,
626                )
627                .unwrap();
628            return Err(builder.into_owned());
629        };
630
631        if requested_transport.protocol() != RequestedTransport::UDP {
632            let mut builder = Message::builder_error(msg);
633            let error = ErrorCode::builder(ErrorCode::UNSUPPORTED_TRANSPORT_PROTOCOL)
634                .build()
635                .unwrap();
636            builder.add_attribute(&error).unwrap();
637            builder
638                .add_message_integrity(
639                    &MessageIntegrityCredentials::LongTerm(credentials),
640                    stun_proto::types::message::IntegrityAlgorithm::Sha1,
641                )
642                .unwrap();
643            return Err(builder.into_owned());
644        }
645
646        // TODO: DONT-FRAGMENT
647        // TODO: EVEN-PORT
648        // TODO: RESERVATION-TOKEN
649        // TODO: allocation quota
650        // XXX: TRY-ALTERNATE
651
652        let client = Client {
653            transport: ttype,
654            remote_addr: from,
655            local_addr: to,
656            allocations: vec![],
657            credentials,
658        };
659
660        self.pending_allocates.push_front(PendingClient {
661            client,
662            asked: false,
663            transaction_id: msg.transaction_id(),
664        });
665
666        Ok(())
667    }
668
669    fn handle_stun_refresh<'a>(
670        &mut self,
671        msg: &Message<'_>,
672        ttype: TransportType,
673        from: SocketAddr,
674        to: SocketAddr,
675        now: Instant,
676    ) -> Result<Transmit<Data<'static>>, MessageBuilder<'a>> {
677        let _credentials = self.validate_stun(msg, ttype, from, to, now)?;
678
679        let Some(client) = self.mut_client_from_5tuple(ttype, to, from) else {
680            let mut builder = Message::builder_error(msg);
681            let error = ErrorCode::builder(ErrorCode::ALLOCATION_MISMATCH)
682                .build()
683                .unwrap();
684            builder.add_attribute(&error).unwrap();
685            return Err(builder.into_owned());
686        };
687
688        // TODO: proper lifetime handling
689        let request_lifetime = msg
690            .attribute::<Lifetime>()
691            .map(|lt| lt.seconds())
692            .unwrap_or(600);
693        let credentials = if request_lifetime == 0 {
694            // TODO: handle dual IPv4/6 allocations.
695            let credentials = client.credentials.clone();
696            self.remove_client_by_5tuple(ttype, to, from);
697            credentials
698        } else {
699            for allocation in client.allocations.iter_mut() {
700                allocation.expires_at = now + Duration::from_secs(request_lifetime as u64)
701            }
702            client.credentials.clone()
703        };
704
705        let mut builder = Message::builder_success(msg);
706        let lifetime = Lifetime::new(request_lifetime);
707        builder.add_attribute(&lifetime).unwrap();
708        builder
709            .add_message_integrity(
710                &MessageIntegrityCredentials::LongTerm(credentials),
711                stun_proto::types::message::IntegrityAlgorithm::Sha1,
712            )
713            .unwrap();
714        let Ok(transmit) = self.stun.send(builder, from, now) else {
715            error!("Failed to send");
716            let mut response = Message::builder_error(msg);
717            let error = ErrorCode::builder(ErrorCode::SERVER_ERROR).build().unwrap();
718            response.add_attribute(&error).unwrap();
719            response.add_fingerprint().unwrap();
720            return Err(response.into_owned());
721        };
722
723        Ok(transmit_send_build(transmit))
724    }
725
726    fn handle_stun_create_permission<'a>(
727        &mut self,
728        msg: &Message<'_>,
729        ttype: TransportType,
730        from: SocketAddr,
731        to: SocketAddr,
732        now: Instant,
733    ) -> Result<Transmit<Data<'static>>, MessageBuilder<'a>> {
734        let credentials = self.validate_stun(msg, ttype, from, to, now)?;
735
736        let Some(client) = self.mut_client_from_5tuple(ttype, to, from) else {
737            let mut builder = Message::builder_error(msg);
738            let error = ErrorCode::builder(ErrorCode::ALLOCATION_MISMATCH)
739                .build()
740                .unwrap();
741            builder.add_attribute(&error).unwrap();
742            return Err(builder.into_owned());
743        };
744
745        let mut at_least_one_peer_addr = false;
746        for peer_addr in msg
747            .iter_attributes()
748            .filter(|a| a.get_type() == XorPeerAddress::TYPE)
749        {
750            let Ok(peer_addr) = XorPeerAddress::from_raw(peer_addr) else {
751                let mut builder = Message::builder_error(msg);
752                let error = ErrorCode::builder(ErrorCode::BAD_REQUEST).build().unwrap();
753                builder.add_attribute(&error).unwrap();
754                builder
755                    .add_message_integrity(
756                        &MessageIntegrityCredentials::LongTerm(client.credentials.clone()),
757                        stun_proto::types::message::IntegrityAlgorithm::Sha1,
758                    )
759                    .unwrap();
760                return Err(builder.into_owned());
761            };
762            at_least_one_peer_addr = true;
763            let peer_addr = peer_addr.addr(msg.transaction_id());
764
765            let Some(alloc) = client
766                .allocations
767                .iter_mut()
768                .find(|a| a.addr.is_ipv4() == peer_addr.is_ipv4())
769            else {
770                // XXX: Should always be an allocation available.
771                // TODO: support IPv6
772                unreachable!();
773            };
774
775            if now > alloc.expires_at {
776                trace!("allocation has expired");
777                // allocation has expired
778                let mut builder = Message::builder_error(msg);
779                let error = ErrorCode::builder(ErrorCode::ALLOCATION_MISMATCH)
780                    .build()
781                    .unwrap();
782                builder.add_attribute(&error).unwrap();
783                builder
784                    .add_message_integrity(
785                        &MessageIntegrityCredentials::LongTerm(client.credentials.clone()),
786                        stun_proto::types::message::IntegrityAlgorithm::Sha1,
787                    )
788                    .unwrap();
789                return Err(builder.into_owned());
790            }
791
792            // TODO: support TCP allocations
793            if let Some(position) = alloc
794                .permissions
795                .iter()
796                .position(|perm| perm.ttype == TransportType::Udp && perm.addr == peer_addr.ip())
797            {
798                alloc.permissions[position].expires_at = now + Duration::from_secs(300);
799            } else {
800                alloc.permissions.push(Permission {
801                    addr: peer_addr.ip(),
802                    ttype: TransportType::Udp,
803                    expires_at: now + Duration::from_secs(300),
804                });
805            }
806        }
807
808        if !at_least_one_peer_addr {
809            let mut builder = Message::builder_error(msg);
810            let error = ErrorCode::builder(ErrorCode::BAD_REQUEST).build().unwrap();
811            builder.add_attribute(&error).unwrap();
812            builder
813                .add_message_integrity(
814                    &MessageIntegrityCredentials::LongTerm(client.credentials.clone()),
815                    stun_proto::types::message::IntegrityAlgorithm::Sha1,
816                )
817                .unwrap();
818            return Err(builder.into_owned());
819        }
820
821        let mut builder = Message::builder_success(msg);
822        builder
823            .add_message_integrity(
824                &MessageIntegrityCredentials::LongTerm(credentials),
825                stun_proto::types::message::IntegrityAlgorithm::Sha1,
826            )
827            .unwrap();
828
829        let Ok(transmit) = self.stun.send(builder, from, now) else {
830            error!("Failed to send");
831            let mut response = Message::builder_error(msg);
832            let error = ErrorCode::builder(ErrorCode::SERVER_ERROR).build().unwrap();
833            response.add_attribute(&error).unwrap();
834            response.add_fingerprint().unwrap();
835            return Err(response.into_owned());
836        };
837
838        Ok(transmit_send_build(transmit))
839    }
840
841    fn handle_stun_channel_bind<'a>(
842        &mut self,
843        msg: &Message<'_>,
844        ttype: TransportType,
845        from: SocketAddr,
846        to: SocketAddr,
847        now: Instant,
848    ) -> Result<Transmit<Data<'static>>, MessageBuilder<'a>> {
849        let credentials = self.validate_stun(msg, ttype, from, to, now)?;
850
851        let Some(client) = self.mut_client_from_5tuple(ttype, to, from) else {
852            let mut builder = Message::builder_error(msg);
853            let error = ErrorCode::builder(ErrorCode::ALLOCATION_MISMATCH)
854                .build()
855                .unwrap();
856            builder.add_attribute(&error).unwrap();
857            return Err(builder.into_owned());
858        };
859
860        let bad_request = move |msg: &Message<'_>, credentials: LongTermCredentials| {
861            let mut builder = Message::builder_error(msg);
862            let error = ErrorCode::builder(ErrorCode::BAD_REQUEST).build().unwrap();
863            builder.add_attribute(&error).unwrap();
864            builder
865                .add_message_integrity(
866                    &MessageIntegrityCredentials::LongTerm(credentials),
867                    stun_proto::types::message::IntegrityAlgorithm::Sha1,
868                )
869                .unwrap();
870            builder.into_owned()
871        };
872
873        let peer_addr = msg
874            .attribute::<XorPeerAddress>()
875            .ok()
876            .map(|peer_addr| peer_addr.addr(msg.transaction_id()));
877        let Some(peer_addr) = peer_addr else {
878            trace!("No peer address");
879            return Err(bad_request(msg, credentials));
880        };
881
882        let Some(alloc) = client
883            .allocations
884            .iter_mut()
885            .find(|allocation| allocation.addr.is_ipv4() == peer_addr.is_ipv4())
886        else {
887            let mut builder = Message::builder_error(msg);
888            let error = ErrorCode::builder(ErrorCode::ALLOCATION_MISMATCH)
889                .build()
890                .unwrap();
891            builder.add_attribute(&error).unwrap();
892            return Err(builder.into_owned());
893        };
894
895        if now > alloc.expires_at {
896            trace!("allocation has expired");
897            // allocation has expired
898            let mut builder = Message::builder_error(msg);
899            let error = ErrorCode::builder(ErrorCode::ALLOCATION_MISMATCH)
900                .build()
901                .unwrap();
902            builder.add_attribute(&error).unwrap();
903            builder
904                .add_message_integrity(
905                    &MessageIntegrityCredentials::LongTerm(client.credentials.clone()),
906                    stun_proto::types::message::IntegrityAlgorithm::Sha1,
907                )
908                .unwrap();
909            return Err(builder.into_owned());
910        }
911
912        let mut existing = alloc.channels.iter_mut().find(|channel| {
913            channel.peer_addr == peer_addr && channel.peer_transport == TransportType::Udp
914        });
915
916        let channel_no = msg
917            .attribute::<ChannelNumber>()
918            .ok()
919            .map(|channel| channel.channel());
920        if let Some(channel_no) = channel_no {
921            if !(0x4000..=0x7fff).contains(&channel_no) {
922                trace!("Channel id out of range");
923                return Err(bad_request(msg, credentials));
924            }
925            if existing
926                .as_ref()
927                .is_some_and(|existing| existing.id != channel_no)
928            {
929                trace!("channel peer address does not match channel ID");
930                return Err(bad_request(msg, credentials));
931            }
932        } else {
933            debug!("Bad request: no requested channel id");
934            return Err(bad_request(msg, credentials));
935        }
936
937        if let Some(existing) = existing.as_mut() {
938            existing.expires_at = now + Duration::from_secs(600);
939        } else {
940            alloc.channels.push(Channel {
941                id: channel_no.unwrap(),
942                peer_addr,
943                peer_transport: TransportType::Udp,
944                expires_at: now + Duration::from_secs(600),
945            });
946        }
947
948        if let Some(existing) = alloc
949            .permissions
950            .iter_mut()
951            .find(|perm| perm.ttype == TransportType::Udp && perm.addr == peer_addr.ip())
952        {
953            existing.expires_at = now + Duration::from_secs(300);
954        } else {
955            alloc.permissions.push(Permission {
956                addr: peer_addr.ip(),
957                ttype: TransportType::Udp,
958                expires_at: now + Duration::from_secs(300),
959            });
960        }
961
962        let mut builder = Message::builder_success(msg);
963        builder
964            .add_message_integrity(
965                &MessageIntegrityCredentials::LongTerm(credentials),
966                stun_proto::types::message::IntegrityAlgorithm::Sha1,
967            )
968            .unwrap();
969
970        let Ok(transmit) = self.stun.send(builder, from, now) else {
971            error!("Failed to send");
972            let mut response = Message::builder_error(msg);
973            let error = ErrorCode::builder(ErrorCode::SERVER_ERROR).build().unwrap();
974            response.add_attribute(&error).unwrap();
975            response.add_fingerprint().unwrap();
976            return Err(response.into_owned());
977        };
978
979        Ok(transmit_send_build(transmit))
980    }
981
982    fn handle_stun_send_indication<'a>(
983        &mut self,
984        msg: &'a Message<'a>,
985        ttype: TransportType,
986        from: SocketAddr,
987        to: SocketAddr,
988        now: Instant,
989    ) -> Result<Transmit<Data<'a>>, ()> {
990        let peer_address = msg.attribute::<XorPeerAddress>().map_err(|_| ())?;
991        let peer_address = peer_address.addr(msg.transaction_id());
992
993        let Some(client) = self.client_from_5tuple(ttype, to, from) else {
994            trace!("no client for transport {ttype:?} from {from:?}, to {to:?}");
995            trace!("clients: {:?}", self.clients);
996            return Err(());
997        };
998
999        let Some(alloc) = client
1000            .allocations
1001            .iter()
1002            .find(|allocation| allocation.addr.ip().is_ipv4() == peer_address.is_ipv4())
1003        else {
1004            trace!("no allocation for transport {ttype:?} from {from:?}, to {to:?}");
1005            trace!("allocations: {:?}", client.allocations);
1006            return Err(());
1007        };
1008        if now > alloc.expires_at {
1009            trace!("allocation has expired");
1010            // allocation has expired
1011            return Err(());
1012        }
1013
1014        let Some(permission) = alloc
1015            .permissions
1016            .iter()
1017            .find(|permission| permission.addr == peer_address.ip())
1018        else {
1019            trace!("permission not installed");
1020            // no permission installed for this peer, ignoring
1021            return Err(());
1022        };
1023        if now > permission.expires_at {
1024            trace!("permission has expired");
1025            // permission has expired
1026            return Err(());
1027        }
1028
1029        let data = msg.attribute::<AData>().map_err(|_| ())?;
1030        trace!("have {} to send to {:?}", data.data().len(), peer_address);
1031        Ok(Transmit::new(
1032            Data::from(data.data()),
1033            permission.ttype,
1034            alloc.addr,
1035            peer_address,
1036        )
1037        .into_owned())
1038        // XXX: copies the data.  Try to figure out a way to not do this
1039        /*
1040        self.pending_transmits.push_back(Transmit::new_owned(
1041            data.data(),
1042            permission.ttype,
1043            alloc.addr,
1044            peer_address,
1045        ));
1046        Ok(())*/
1047    }
1048
1049    #[tracing::instrument(name = "turn_server_handle_stun", skip(self, msg, from, to, now))]
1050    fn handle_stun<'a>(
1051        &mut self,
1052        msg: &'a Message<'a>,
1053        ttype: TransportType,
1054        from: SocketAddr,
1055        to: SocketAddr,
1056        now: Instant,
1057    ) -> Result<Option<Transmit<Data<'a>>>, MessageBuilder<'a>> {
1058        trace!("received STUN message {msg}");
1059        let ret = if msg.has_class(stun_proto::types::message::MessageClass::Request) {
1060            match msg.method() {
1061                BINDING => self
1062                    .handle_stun_binding(msg, ttype, from, to, now)
1063                    .map(Some),
1064                ALLOCATE => self
1065                    .handle_stun_allocate(msg, ttype, from, to, now)
1066                    .map(|_| None),
1067                REFRESH => self
1068                    .handle_stun_refresh(msg, ttype, from, to, now)
1069                    .map(Some),
1070                CREATE_PERMISSION => self
1071                    .handle_stun_create_permission(msg, ttype, from, to, now)
1072                    .map(Some),
1073                CHANNEL_BIND => self
1074                    .handle_stun_channel_bind(msg, ttype, from, to, now)
1075                    .map(Some),
1076                _ => {
1077                    let mut builder = Message::builder_error(msg);
1078                    let error = ErrorCode::builder(ErrorCode::BAD_REQUEST).build().unwrap();
1079                    builder.add_attribute(&error).unwrap();
1080                    Err(builder.into_owned())
1081                }
1082            }
1083        } else if msg.has_class(stun_proto::types::message::MessageClass::Indication) {
1084            match msg.method() {
1085                SEND => Ok(self
1086                    .handle_stun_send_indication(msg, ttype, from, to, now)
1087                    .ok()),
1088                _ => Ok(None),
1089            }
1090        } else {
1091            Ok(None)
1092        };
1093        debug!("result: {ret:?}");
1094        ret
1095    }
1096
1097    fn nonce_from_5tuple(
1098        &self,
1099        ttype: TransportType,
1100        local_addr: SocketAddr,
1101        remote_addr: SocketAddr,
1102    ) -> Option<&NonceData> {
1103        self.nonces.iter().find(|nonce| {
1104            nonce.transport == ttype
1105                && nonce.remote_addr == remote_addr
1106                && nonce.local_addr == local_addr
1107        })
1108    }
1109
1110    fn mut_nonce_from_5tuple(
1111        &mut self,
1112        ttype: TransportType,
1113        local_addr: SocketAddr,
1114        remote_addr: SocketAddr,
1115    ) -> Option<&mut NonceData> {
1116        self.nonces.iter_mut().find(|nonce| {
1117            nonce.transport == ttype
1118                && nonce.remote_addr == remote_addr
1119                && nonce.local_addr == local_addr
1120        })
1121    }
1122
1123    fn client_from_5tuple(
1124        &self,
1125        ttype: TransportType,
1126        local_addr: SocketAddr,
1127        remote_addr: SocketAddr,
1128    ) -> Option<&Client> {
1129        self.clients.iter().find(|client| {
1130            client.transport == ttype
1131                && client.remote_addr == remote_addr
1132                && client.local_addr == local_addr
1133        })
1134    }
1135
1136    fn mut_client_from_5tuple(
1137        &mut self,
1138        ttype: TransportType,
1139        local_addr: SocketAddr,
1140        remote_addr: SocketAddr,
1141    ) -> Option<&mut Client> {
1142        self.clients.iter_mut().find(|client| {
1143            client.transport == ttype
1144                && client.remote_addr == remote_addr
1145                && client.local_addr == local_addr
1146        })
1147    }
1148
1149    fn remove_client_by_5tuple(
1150        &mut self,
1151        ttype: TransportType,
1152        local_addr: SocketAddr,
1153        remote_addr: SocketAddr,
1154    ) {
1155        self.clients.retain(|client| {
1156            client.transport != ttype
1157                && client.remote_addr != remote_addr
1158                && client.local_addr == local_addr
1159        })
1160    }
1161
1162    fn allocation_from_public_5tuple(
1163        &self,
1164        ttype: TransportType,
1165        local_addr: SocketAddr,
1166        remote_addr: SocketAddr,
1167    ) -> Option<(&Client, &Allocation)> {
1168        self.clients.iter().find_map(|client| {
1169            client
1170                .allocations
1171                .iter()
1172                .find(|allocation| {
1173                    allocation.ttype == ttype
1174                        && allocation.addr == local_addr
1175                        && allocation
1176                            .permissions
1177                            .iter()
1178                            .any(|permission| permission.addr == remote_addr.ip())
1179                })
1180                .map(|allocation| (client, allocation))
1181        })
1182    }
1183}
1184
/// Server-side state for a single TURN client.
///
/// A client is identified by the 5-tuple it uses to talk to the server: `transport`,
/// `local_addr` (the server side) and `remote_addr` (the client side).
#[derive(Debug)]
struct Client {
    // Transport protocol of the client <-> server connection.
    transport: TransportType,
    // Server-side address the client's requests are received on.
    local_addr: SocketAddr,
    // Client-side address the requests come from.
    remote_addr: SocketAddr,

    // Relay allocations held by this client. Request handlers pick the one whose address
    // family (IPv4/IPv6) matches the peer address.
    allocations: Vec<Allocation>,
    // Long-term credentials used to add MESSAGE-INTEGRITY to responses for this client.
    credentials: LongTermCredentials,
}
1194
/// A relayed transport address allocated for a client, together with the permissions and
/// channel bindings installed on it.
#[derive(Debug)]
struct Allocation {
    // the peer-side address of this allocation
    addr: SocketAddr,
    // Transport protocol of the relayed address.
    ttype: TransportType,

    // Requests treat the allocation as expired once `now` is later than this instant.
    expires_at: Instant,

    // Peer IP addresses that are allowed to exchange data through this allocation.
    permissions: Vec<Permission>,
    // Channel numbers bound to peer addresses on this allocation.
    channels: Vec<Channel>,
}
1206
1207impl Allocation {
1208    fn permissions_from_5tuple(
1209        &self,
1210        ttype: TransportType,
1211        _local_addr: SocketAddr,
1212        remote_addr: SocketAddr,
1213    ) -> Option<&Permission> {
1214        self.permissions
1215            .iter()
1216            .find(|permission| permission.ttype == ttype && remote_addr.ip() == permission.addr)
1217    }
1218
1219    fn channel_from_id(&self, id: u16) -> Option<&Channel> {
1220        self.channels.iter().find(|channel| channel.id == id)
1221    }
1222
1223    fn channel_from_5tuple(
1224        &self,
1225        transport: TransportType,
1226        local_addr: SocketAddr,
1227        remote_addr: SocketAddr,
1228    ) -> Option<&Channel> {
1229        if self.addr != local_addr {
1230            return None;
1231        }
1232        self.channels
1233            .iter()
1234            .find(|channel| transport == channel.peer_transport && remote_addr == channel.peer_addr)
1235    }
1236}
1237
/// Permission for a peer IP address to exchange data through an allocation.
///
/// Only the IP is stored; the port of the peer's `XOR-PEER-ADDRESS` is not part of it.
/// CreatePermission and ChannelBind in this file install or refresh it for 300 seconds.
#[derive(Debug)]
struct Permission {
    // Peer IP address the permission applies to.
    addr: IpAddr,
    // Transport protocol the permission applies to.
    ttype: TransportType,

    // Requests treat the permission as expired once `now` is later than this instant.
    expires_at: Instant,
}
1245
/// A channel binding between a channel number and a peer transport address.
///
/// ChannelBind in this file accepts ids in `0x4000..=0x7fff` and creates or refreshes
/// bindings for 600 seconds.
#[derive(Debug)]
struct Channel {
    // Channel number requested by the client.
    id: u16,
    // Peer address (IP and port) the channel relays to.
    peer_addr: SocketAddr,
    // Transport protocol used towards the peer.
    peer_transport: TransportType,

    // Instant after which the binding is considered expired.
    expires_at: Instant,
}
1254
1255fn transmit_send_build<T: DelayedTransmitBuild>(
1256    transmit: TransmitBuild<T>,
1257) -> Transmit<Data<'static>> {
1258    let data = transmit.data.build().into_boxed_slice();
1259    Transmit::new(
1260        Data::from(data),
1261        transmit.transport,
1262        transmit.from,
1263        transmit.to,
1264    )
1265    .into_owned()
1266}