turn_server_proto/
server.rs

1// Copyright (C) 2025 Matthew Waters <matthew@centricular.com>
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use std::collections::{HashMap, VecDeque};
10use std::net::{IpAddr, SocketAddr};
11use std::time::{Duration, Instant};
12
13use stun_proto::agent::{StunAgent, StunError, Transmit, TransmitBuild};
14use stun_proto::prelude::*;
15use stun_proto::types::attribute::{
16    ErrorCode, Fingerprint, MessageIntegrity, Nonce, Realm, Username, XorMappedAddress,
17};
18use stun_proto::types::data::Data as SData;
19use stun_proto::types::message::{
20    LongTermCredentials, Message, MessageBuilder, MessageClass, MessageIntegrityCredentials,
21    MessageType, TransactionId, BINDING,
22};
23use stun_proto::types::prelude::{Attribute, AttributeFromRaw, AttributeStaticType};
24use stun_proto::types::TransportType;
25use turn_types::channel::ChannelData;
26
27use turn_types::message::CREATE_PERMISSION;
28
29use turn_types::attribute::{
30    ChannelNumber, Data, Lifetime, RequestedTransport, XorPeerAddress, XorRelayedAddress,
31};
32use turn_types::message::{ALLOCATE, CHANNEL_BIND, DATA, REFRESH, SEND};
33use turn_types::TurnCredentials;
34
35use tracing::{debug, error, info, trace, warn};
36
/// TURN server state machine.
///
/// The server does not own any sockets itself: received data is fed in through `recv`, data to
/// send is returned from `recv` or `poll_transmit`, and UDP relay sockets are requested from the
/// caller through `poll` and reported back through `allocated_udp_socket`.
#[derive(Debug)]
pub struct TurnServer {
    /// Realm advertised in REALM attributes and used when deriving long-term credentials.
    realm: String,
    // FIXME: remove
    credentials: TurnCredentials,
    /// STUN agent for the listening address; used to build and send responses.
    stun: StunAgent,

    /// Clients whose Allocate request succeeded; looked up by their 5-tuple.
    clients: Vec<Client>,
    /// Nonces issued to individual 5-tuples.
    nonces: Vec<NonceData>,
    /// Transmits waiting to be handed out by `poll_transmit`.
    pending_transmits: VecDeque<Transmit<SData<'static>>>,
    /// Allocate requests waiting for the caller to provide a relay socket.
    pending_allocates: VecDeque<PendingClient>,

    // username -> password mapping.
    users: HashMap<String, String>,
}
52
/// An authenticated Allocate request waiting for its relay socket.
#[derive(Debug)]
struct PendingClient {
    /// The client to register once the relay socket is available.
    client: Client,
    /// When true, `poll` no longer requests a socket for this entry.
    asked: bool,
    /// Transaction ID of the Allocate request; reused for the eventual response.
    transaction_id: TransactionId,
}
59
/// A nonce issued to one client 5-tuple.
#[derive(Debug)]
struct NonceData {
    /// The nonce value sent in NONCE attributes.
    nonce: String,
    /// Instant after which the nonce is considered stale.
    expires_at: Instant,

    // The 5-tuple the nonce was issued to.
    transport: TransportType,
    remote_addr: SocketAddr,
    local_addr: SocketAddr,
}
69
/// What the caller of [`TurnServer::poll`] should do next.
#[derive(Debug)]
pub enum TurnServerPollRet {
    /// Nothing needs doing right now; poll again at (or before) this instant.
    WaitUntil(Instant),
    /// Allocate a UDP relay socket for the client identified by this 5-tuple and report the
    /// outcome through [`TurnServer::allocated_udp_socket`].
    AllocateSocketUdp {
        transport: TransportType,
        local_addr: SocketAddr,
        remote_addr: SocketAddr,
    },
}
79
80impl TurnServer {
81    pub fn new(
82        ttype: TransportType,
83        listen_addr: SocketAddr,
84        credentials: TurnCredentials,
85        realm: String,
86    ) -> Self {
87        let stun = StunAgent::builder(ttype, listen_addr).build();
88        Self {
89            realm,
90            credentials,
91            stun,
92            clients: vec![],
93            nonces: vec![],
94            pending_transmits: VecDeque::default(),
95            pending_allocates: VecDeque::default(),
96            users: HashMap::default(),
97        }
98    }
99
100    pub fn add_user(&mut self, username: String, password: String) {
101        self.users.insert(username, password);
102    }
103
104    pub fn listen_address(&self) -> SocketAddr {
105        self.stun.local_addr()
106    }
107
108    #[tracing::instrument(
109        name = "turn_server_recv",
110        skip(self, transmit),
111        fields(
112            transport = %transmit.transport,
113            remote_addr = %transmit.from,
114            local_addr = %transmit.to,
115            data_len = transmit.data.as_ref().len(),
116        )
117        err,
118        ret,
119    )]
120    pub fn recv<T: AsRef<[u8]>>(
121        &mut self,
122        transmit: Transmit<T>,
123        now: Instant,
124    ) -> Result<Option<Transmit<SData<'static>>>, StunError> {
125        match Message::from_bytes(transmit.data.as_ref()) {
126            Ok(msg) => {
127                match self.handle_stun(&msg, transmit.transport, transmit.from, transmit.to, now) {
128                    Err(builder) => {
129                        let data = builder.build();
130                        return Ok(Some(Transmit::new(
131                            data.into_boxed_slice().into(),
132                            transmit.transport,
133                            transmit.to,
134                            transmit.from,
135                        )));
136                    }
137                    Ok(Some(transmit)) => Ok(Some(transmit.into_owned())),
138                    Ok(None) => Ok(None),
139                }
140            }
141            Err(_) => {
142                // TODO: TCP buffering requirements
143                if let Some(client) =
144                    self.client_from_5tuple(transmit.transport, transmit.to, transmit.from)
145                {
146                    let Ok(channel) = ChannelData::parse(transmit.data.as_ref()) else {
147                        return Ok(None);
148                    };
149                    let Some((allocation, existing)) =
150                        client.allocations.iter().find_map(|allocation| {
151                            allocation
152                                .channel_from_id(channel.id())
153                                .map(|perm| (allocation, perm))
154                        })
155                    else {
156                        // no channel with that id
157                        return Ok(None);
158                    };
159
160                    // A packet from the client needs to be sent to the peer referenced by the
161                    // configured channel.
162                    let Some(_permission) = allocation.permissions_from_5tuple(
163                        transmit.transport,
164                        allocation.addr,
165                        existing.peer_addr,
166                    ) else {
167                        return Ok(None);
168                    };
169                    Ok(Some(
170                        Transmit::new(
171                            SData::from(channel.data()),
172                            allocation.ttype,
173                            allocation.addr,
174                            existing.peer_addr,
175                        )
176                        .into_owned(),
177                    ))
178                } else if let Some((client, allocation)) = self.allocation_from_public_5tuple(
179                    transmit.transport,
180                    transmit.to,
181                    transmit.from,
182                ) {
183                    // A packet from the relayed address needs to be sent to the client that set up
184                    // the allocation.
185                    let Some(_permission) = allocation.permissions_from_5tuple(
186                        transmit.transport,
187                        transmit.to,
188                        transmit.from,
189                    ) else {
190                        return Ok(None);
191                    };
192
193                    if let Some(existing) = allocation.channel_from_5tuple(
194                        transmit.transport,
195                        transmit.from,
196                        transmit.to,
197                    ) {
198                        // no channel with that id
199                        let mut data = vec![0; 4];
200                        data[0..2].copy_from_slice(&existing.id.to_be_bytes());
201                        data[2..4]
202                            .copy_from_slice(&(transmit.data.as_ref().len() as u16).to_be_bytes());
203                        // XXX: try to avoid copy?
204                        data.extend_from_slice(transmit.data.as_ref());
205                        Ok(Some(Transmit::new(
206                            data.into_boxed_slice().into(),
207                            client.transport,
208                            client.local_addr,
209                            client.remote_addr,
210                        )))
211                    } else {
212                        let transaction_id = TransactionId::generate();
213                        let mut builder = Message::builder(
214                            MessageType::from_class_method(MessageClass::Indication, DATA),
215                            transaction_id,
216                        );
217                        let peer_address = XorPeerAddress::new(transmit.from, transaction_id);
218                        builder.add_attribute(&peer_address).unwrap();
219                        let data = Data::new(transmit.data.as_ref());
220                        builder.add_attribute(&data).unwrap();
221                        let msg_data = builder.build();
222
223                        Ok(Some(Transmit::new(
224                            msg_data.into_boxed_slice().into(),
225                            client.transport,
226                            client.local_addr,
227                            client.remote_addr,
228                        )))
229                    }
230                } else {
231                    Ok(None)
232                }
233            }
234        }
235    }
236
237    #[tracing::instrument(name = "turn_server_poll", skip(self), ret)]
238    pub fn poll(&mut self, now: Instant) -> TurnServerPollRet {
239        for pending in self.pending_allocates.iter_mut() {
240            if pending.asked {
241                continue;
242            }
243
244            // TODO: TCP
245            return TurnServerPollRet::AllocateSocketUdp {
246                transport: pending.client.transport,
247                local_addr: pending.client.local_addr,
248                remote_addr: pending.client.remote_addr,
249            };
250        }
251
252        for client in self.clients.iter_mut() {
253            client.allocations.retain_mut(|allocation| {
254                if allocation.expires_at <= now {
255                    allocation
256                        .permissions
257                        .retain_mut(|permission| permission.expires_at <= now);
258                    allocation
259                        .channels
260                        .retain_mut(|channel| channel.expires_at <= now);
261                    true
262                } else {
263                    false
264                }
265            });
266        }
267
268        TurnServerPollRet::WaitUntil(now + Duration::from_secs(60))
269    }
270
271    #[tracing::instrument(name = "turn_server_poll_transmit", skip(self))]
272    pub fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<SData<'static>>> {
273        if let Some(transmit) = self.pending_transmits.pop_back() {
274            return Some(transmit);
275        }
276        None
277    }
278
279    #[tracing::instrument(name = "turn_server_allocated_udp_socket", skip(self))]
280    pub fn allocated_udp_socket(
281        &mut self,
282        local_addr: SocketAddr,
283        remote_addr: SocketAddr,
284        socket_addr: Result<SocketAddr, ()>,
285        now: Instant,
286    ) {
287        let Some(position) = self.pending_allocates.iter().position(|pending| {
288            pending.client.transport == TransportType::Udp
289                && pending.client.local_addr == local_addr
290                && pending.client.remote_addr == remote_addr
291        }) else {
292            warn!("No pending allocation for transport: Udp, local: {local_addr:?}, remote {remote_addr:?}");
293            return;
294        };
295        info!("pending allocation for transport: Udp, local: {local_addr:?}, remote {remote_addr:?} resulted in {socket_addr:?}");
296        let mut pending = self.pending_allocates.remove(position).unwrap();
297        let transaction_id = pending.transaction_id;
298        let to = pending.client.remote_addr;
299
300        let mut builder = if let Ok(socket_addr) = socket_addr {
301            pending.client.allocations.push(Allocation {
302                addr: socket_addr,
303                ttype: TransportType::Udp,
304                expires_at: now + Duration::from_secs(600),
305                permissions: vec![],
306                channels: vec![],
307            });
308
309            let mut builder = Message::builder(
310                MessageType::from_class_method(MessageClass::Success, ALLOCATE),
311                transaction_id,
312            );
313            let relayed_address = XorRelayedAddress::new(socket_addr, transaction_id);
314            builder.add_attribute(&relayed_address).unwrap();
315            let lifetime = Lifetime::new(600);
316            builder.add_attribute(&lifetime).unwrap();
317            // TODO RESERVATION-TOKEN
318            let mapped_address = XorMappedAddress::new(pending.client.remote_addr, transaction_id);
319            builder.add_attribute(&mapped_address).unwrap();
320
321            builder.into_owned()
322        } else {
323            let mut builder = Message::builder(
324                MessageType::from_class_method(MessageClass::Error, ALLOCATE),
325                transaction_id,
326            );
327            let error = ErrorCode::builder(ErrorCode::INSUFFICIENT_CAPACITY)
328                .build()
329                .unwrap();
330            builder.add_attribute(&error).unwrap();
331            builder.into_owned()
332        };
333        builder
334            .add_message_integrity(
335                &MessageIntegrityCredentials::LongTerm(pending.client.credentials.clone()),
336                stun_proto::types::message::IntegrityAlgorithm::Sha1,
337            )
338            .unwrap();
339
340        let Ok(transmit) = self.stun.send(builder, to, now) else {
341            return;
342        };
343        if socket_addr.is_ok() {
344            self.clients.push(pending.client);
345        }
346        self.pending_transmits
347            .push_back(transmit_send_build(transmit));
348    }
349
350    fn validate_stun<'a>(
351        &mut self,
352        msg: &Message<'_>,
353        ttype: TransportType,
354        from: SocketAddr,
355        to: SocketAddr,
356        now: Instant,
357    ) -> Result<LongTermCredentials, MessageBuilder<'a>> {
358        let integrity = msg.attribute::<MessageIntegrity>().ok();
359        // TODO: check for SHA256 integrity
360        if integrity.is_none() {
361            //   o  If the message does not contain a MESSAGE-INTEGRITY attribute, the
362            //      server MUST generate an error response with an error code of 401
363            //      (Unauthorized).  This response MUST include a REALM value.  It is
364            //      RECOMMENDED that the REALM value be the domain name of the
365            //      provider of the STUN server.  The response MUST include a NONCE,
366            //      selected by the server.  The response SHOULD NOT contain a
367            //      USERNAME or MESSAGE-INTEGRITY attribute.
368            let nonce = if let Some(nonce) = self.nonce_from_5tuple(ttype, to, from) {
369                nonce
370            } else {
371                self.nonces.push(NonceData {
372                    transport: ttype,
373                    remote_addr: from,
374                    local_addr: to,
375                    // FIXME: use an actual random source.
376                    nonce: String::from("random"),
377                    expires_at: now + Duration::from_secs(3600),
378                });
379                self.nonces.last().unwrap()
380            };
381            trace!(
382                "no message-integrity, returning unauthorized with nonce: {}",
383                nonce.nonce
384            );
385            let mut builder = Message::builder_error(msg);
386            let nonce = Nonce::new(&nonce.nonce).unwrap();
387            builder.add_attribute(&nonce).unwrap();
388            let realm = Realm::new(&self.realm).unwrap();
389            builder.add_attribute(&realm).unwrap();
390            let error = ErrorCode::builder(ErrorCode::UNAUTHORIZED).build().unwrap();
391            builder.add_attribute(&error).unwrap();
392            return Err(builder.into_owned());
393        }
394
395        //  o  If the message contains a MESSAGE-INTEGRITY attribute, but is
396        //      missing the USERNAME, REALM, or NONCE attribute, the server MUST
397        //      generate an error response with an error code of 400 (Bad
398        //      Request).  This response SHOULD NOT include a USERNAME, NONCE,
399        //      REALM, or MESSAGE-INTEGRITY attribute.
400        let username = msg.attribute::<Username>().ok();
401        let realm = msg.attribute::<Realm>().ok();
402        let nonce = msg.attribute::<Nonce>().ok();
403        let Some(((username, _realm), nonce)) = username.zip(realm).zip(nonce) else {
404            trace!("bad request due to missing username, realm, nonce");
405            let error = ErrorCode::builder(ErrorCode::BAD_REQUEST).build().unwrap();
406            let mut builder = Message::builder_error(msg);
407            builder.add_attribute(&error).unwrap();
408            return Err(builder.into_owned());
409        };
410
411        //   o  If the NONCE is no longer valid, the server MUST generate an error
412        //      response with an error code of 438 (Stale Nonce).  This response
413        //      MUST include NONCE and REALM attributes and SHOULD NOT include the
414        //      USERNAME or MESSAGE-INTEGRITY attribute.  Servers can invalidate
415        //      nonces in order to provide additional security.  See Section 4.3
416        //      of [RFC2617] for guidelines.
417        let nonce_data = self.mut_nonce_from_5tuple(ttype, to, from);
418        let mut stale_nonce = false;
419        let nonce_value = if let Some(nonce_data) = nonce_data {
420            if nonce_data.expires_at < now {
421                nonce_data.nonce = String::from("random");
422                nonce_data.expires_at = now + Duration::from_secs(3600);
423                stale_nonce = true;
424            } else if nonce_data.nonce != nonce.nonce() {
425                stale_nonce = true;
426            }
427            nonce_data.nonce.clone()
428        } else {
429            let nonce_value = String::from("randome");
430            self.nonces.push(NonceData {
431                transport: ttype,
432                remote_addr: from,
433                local_addr: to,
434                // FIXME: use an actual random source.
435                nonce: nonce_value.clone(),
436                expires_at: now + Duration::from_secs(3600),
437            });
438            stale_nonce = true;
439            nonce_value
440        };
441
442        if stale_nonce {
443            let mut builder = Message::builder_error(msg);
444            let error = ErrorCode::builder(ErrorCode::STALE_NONCE).build().unwrap();
445            builder.add_attribute(&error).unwrap();
446            let realm = Realm::new(&self.realm).unwrap();
447            builder.add_attribute(&realm).unwrap();
448            let nonce = Nonce::new(&nonce_value).unwrap();
449            builder.add_attribute(&nonce).unwrap();
450
451            return Err(builder.into_owned());
452        };
453
454        //   o  Using the password associated with the username in the USERNAME
455        //      attribute, compute the value for the message integrity as
456        //      described in Section 15.4.  If the resulting value does not match
457        //      the contents of the MESSAGE-INTEGRITY attribute, the server MUST
458        //      reject the request with an error response.  This response MUST use
459        //      an error code of 401 (Unauthorized).  It MUST include REALM and
460        //      NONCE attributes and SHOULD NOT include the USERNAME or MESSAGE-
461        //      INTEGRITY attribute.
462        let password = self.users.get(username.username());
463        let credentials = TurnCredentials::new(
464            username.username(),
465            password.map_or("", |pass| pass.as_str()),
466        )
467        .into_long_term_credentials(&self.realm);
468        if password.map_or(true, |_password| {
469            msg.validate_integrity(&MessageIntegrityCredentials::LongTerm(credentials.clone()))
470                .is_err()
471        }) {
472            let mut builder = Message::builder_error(msg);
473            let error = ErrorCode::builder(ErrorCode::UNAUTHORIZED).build().unwrap();
474            builder.add_attribute(&error).unwrap();
475            let realm = Realm::new(&self.realm).unwrap();
476            builder.add_attribute(&realm).unwrap();
477            let nonce = Nonce::new(&nonce_value).unwrap();
478            builder.add_attribute(&nonce).unwrap();
479            return Err(builder.into_owned());
480        }
481
482        if let Some(client) = self.client_from_5tuple(ttype, to, from) {
483            if client.credentials.username() != username.username() {
484                let mut builder = Message::builder_error(msg);
485                let error = ErrorCode::builder(ErrorCode::WRONG_CREDENTIALS)
486                    .build()
487                    .unwrap();
488                builder.add_attribute(&error).unwrap();
489                builder
490                    .add_message_integrity(
491                        &MessageIntegrityCredentials::LongTerm(client.credentials.clone()),
492                        stun_proto::types::message::IntegrityAlgorithm::Sha1,
493                    )
494                    .unwrap();
495                return Err(builder.into_owned());
496            }
497        }
498
499        Ok(credentials)
500    }
501
502    fn handle_stun_binding<'a>(
503        &mut self,
504        msg: &Message<'_>,
505        _ttype: TransportType,
506        from: SocketAddr,
507        to: SocketAddr,
508        now: Instant,
509    ) -> Result<Transmit<SData<'a>>, MessageBuilder<'a>> {
510        let response = if let Some(error_msg) =
511            Message::check_attribute_types(msg, &[Fingerprint::TYPE], &[])
512        {
513            error_msg
514        } else {
515            let mut response = Message::builder_success(msg);
516            let xor_addr = XorMappedAddress::new(from, msg.transaction_id());
517            response.add_attribute(&xor_addr).unwrap();
518            response.add_fingerprint().unwrap();
519            response.into_owned()
520        };
521
522        let Ok(transmit) = self.stun.send(response, to, now) else {
523            error!("Failed to send");
524            let mut response = Message::builder_error(msg);
525            let error = ErrorCode::builder(ErrorCode::SERVER_ERROR).build().unwrap();
526            response.add_attribute(&error).unwrap();
527            response.add_fingerprint().unwrap();
528            return Err(response.into_owned());
529        };
530
531        Ok(transmit_send_build(transmit))
532    }
533
534    fn handle_stun_allocate<'a>(
535        &mut self,
536        msg: &Message<'_>,
537        ttype: TransportType,
538        from: SocketAddr,
539        to: SocketAddr,
540        now: Instant,
541    ) -> Result<(), MessageBuilder<'a>> {
542        let credentials = self.validate_stun(msg, ttype, from, to, now)?;
543
544        if let Some(_client) = self.mut_client_from_5tuple(ttype, to, from) {
545            let mut builder = Message::builder_error(msg);
546            let error = ErrorCode::builder(ErrorCode::ALLOCATION_MISMATCH)
547                .build()
548                .unwrap();
549            builder.add_attribute(&error).unwrap();
550            return Err(builder.into_owned());
551        };
552
553        let Ok(requested_transport) = msg.attribute::<RequestedTransport>() else {
554            let mut builder = Message::builder_error(msg);
555            let error = ErrorCode::builder(ErrorCode::BAD_REQUEST).build().unwrap();
556            builder.add_attribute(&error).unwrap();
557            builder
558                .add_message_integrity(
559                    &MessageIntegrityCredentials::LongTerm(credentials),
560                    stun_proto::types::message::IntegrityAlgorithm::Sha1,
561                )
562                .unwrap();
563            return Err(builder.into_owned());
564        };
565
566        if requested_transport.protocol() != RequestedTransport::UDP {
567            let mut builder = Message::builder_error(msg);
568            let error = ErrorCode::builder(ErrorCode::UNSUPPORTED_TRANSPORT_PROTOCOL)
569                .build()
570                .unwrap();
571            builder.add_attribute(&error).unwrap();
572            builder
573                .add_message_integrity(
574                    &MessageIntegrityCredentials::LongTerm(credentials),
575                    stun_proto::types::message::IntegrityAlgorithm::Sha1,
576                )
577                .unwrap();
578            return Err(builder.into_owned());
579        }
580
581        // TODO: DONT-FRAGMENT
582        // TODO: EVEN-PORT
583        // TODO: RESERVATION-TOKEN
584        // TODO: allocation quota
585        // XXX: TRY-ALTERNATE
586
587        let client = Client {
588            transport: ttype,
589            remote_addr: from,
590            local_addr: to,
591            allocations: vec![],
592            credentials,
593        };
594
595        self.pending_allocates.push_front(PendingClient {
596            client,
597            asked: false,
598            transaction_id: msg.transaction_id(),
599        });
600
601        Ok(())
602    }
603
604    fn handle_stun_refresh<'a>(
605        &mut self,
606        msg: &Message<'_>,
607        ttype: TransportType,
608        from: SocketAddr,
609        to: SocketAddr,
610        now: Instant,
611    ) -> Result<Transmit<SData<'static>>, MessageBuilder<'a>> {
612        let _credentials = self.validate_stun(msg, ttype, from, to, now)?;
613
614        let Some(client) = self.mut_client_from_5tuple(ttype, to, from) else {
615            let mut builder = Message::builder_error(msg);
616            let error = ErrorCode::builder(ErrorCode::ALLOCATION_MISMATCH)
617                .build()
618                .unwrap();
619            builder.add_attribute(&error).unwrap();
620            return Err(builder.into_owned());
621        };
622
623        // TODO: proper lifetime handling
624        let request_lifetime = msg
625            .attribute::<Lifetime>()
626            .map(|lt| lt.seconds())
627            .unwrap_or(600);
628        let credentials = if request_lifetime == 0 {
629            // TODO: handle dual IPv4/6 allocations.
630            let credentials = client.credentials.clone();
631            self.remove_client_by_5tuple(ttype, to, from);
632            credentials
633        } else {
634            for allocation in client.allocations.iter_mut() {
635                allocation.expires_at = now + Duration::from_secs(request_lifetime as u64)
636            }
637            client.credentials.clone()
638        };
639
640        let mut builder = Message::builder_success(msg);
641        let lifetime = Lifetime::new(request_lifetime);
642        builder.add_attribute(&lifetime).unwrap();
643        builder
644            .add_message_integrity(
645                &MessageIntegrityCredentials::LongTerm(credentials),
646                stun_proto::types::message::IntegrityAlgorithm::Sha1,
647            )
648            .unwrap();
649        let Ok(transmit) = self.stun.send(builder, from, now) else {
650            error!("Failed to send");
651            let mut response = Message::builder_error(msg);
652            let error = ErrorCode::builder(ErrorCode::SERVER_ERROR).build().unwrap();
653            response.add_attribute(&error).unwrap();
654            response.add_fingerprint().unwrap();
655            return Err(response.into_owned());
656        };
657
658        Ok(transmit_send_build(transmit))
659    }
660
661    fn handle_stun_create_permission<'a>(
662        &mut self,
663        msg: &Message<'_>,
664        ttype: TransportType,
665        from: SocketAddr,
666        to: SocketAddr,
667        now: Instant,
668    ) -> Result<Transmit<SData<'static>>, MessageBuilder<'a>> {
669        let credentials = self.validate_stun(msg, ttype, from, to, now)?;
670
671        let Some(client) = self.mut_client_from_5tuple(ttype, to, from) else {
672            let mut builder = Message::builder_error(msg);
673            let error = ErrorCode::builder(ErrorCode::ALLOCATION_MISMATCH)
674                .build()
675                .unwrap();
676            builder.add_attribute(&error).unwrap();
677            return Err(builder.into_owned());
678        };
679
680        let mut at_least_one_peer_addr = false;
681        for peer_addr in msg
682            .iter_attributes()
683            .filter(|a| a.get_type() == XorPeerAddress::TYPE)
684        {
685            let Ok(peer_addr) = XorPeerAddress::from_raw(peer_addr) else {
686                let mut builder = Message::builder_error(msg);
687                let error = ErrorCode::builder(ErrorCode::BAD_REQUEST).build().unwrap();
688                builder.add_attribute(&error).unwrap();
689                builder
690                    .add_message_integrity(
691                        &MessageIntegrityCredentials::LongTerm(client.credentials.clone()),
692                        stun_proto::types::message::IntegrityAlgorithm::Sha1,
693                    )
694                    .unwrap();
695                return Err(builder.into_owned());
696            };
697            at_least_one_peer_addr = true;
698            let peer_addr = peer_addr.addr(msg.transaction_id());
699
700            let Some(alloc) = client
701                .allocations
702                .iter_mut()
703                .find(|a| a.addr.is_ipv4() == peer_addr.is_ipv4())
704            else {
705                // XXX: Should always be an allocation available.
706                // TODO: support IPv6
707                unreachable!();
708            };
709
710            if now > alloc.expires_at {
711                trace!("allocation has expired");
712                // allocation has expired
713                let mut builder = Message::builder_error(msg);
714                let error = ErrorCode::builder(ErrorCode::ALLOCATION_MISMATCH)
715                    .build()
716                    .unwrap();
717                builder.add_attribute(&error).unwrap();
718                builder
719                    .add_message_integrity(
720                        &MessageIntegrityCredentials::LongTerm(client.credentials.clone()),
721                        stun_proto::types::message::IntegrityAlgorithm::Sha1,
722                    )
723                    .unwrap();
724                return Err(builder.into_owned());
725            }
726
727            // TODO: support TCP allocations
728            if let Some(position) = alloc
729                .permissions
730                .iter()
731                .position(|perm| perm.ttype == TransportType::Udp && perm.addr == peer_addr.ip())
732            {
733                alloc.permissions[position].expires_at = now + Duration::from_secs(300);
734            } else {
735                alloc.permissions.push(Permission {
736                    addr: peer_addr.ip(),
737                    ttype: TransportType::Udp,
738                    expires_at: now + Duration::from_secs(300),
739                });
740            }
741        }
742
743        if !at_least_one_peer_addr {
744            let mut builder = Message::builder_error(msg);
745            let error = ErrorCode::builder(ErrorCode::BAD_REQUEST).build().unwrap();
746            builder.add_attribute(&error).unwrap();
747            builder
748                .add_message_integrity(
749                    &MessageIntegrityCredentials::LongTerm(client.credentials.clone()),
750                    stun_proto::types::message::IntegrityAlgorithm::Sha1,
751                )
752                .unwrap();
753            return Err(builder.into_owned());
754        }
755
756        let mut builder = Message::builder_success(msg);
757        builder
758            .add_message_integrity(
759                &MessageIntegrityCredentials::LongTerm(credentials),
760                stun_proto::types::message::IntegrityAlgorithm::Sha1,
761            )
762            .unwrap();
763
764        let Ok(transmit) = self.stun.send(builder, from, now) else {
765            error!("Failed to send");
766            let mut response = Message::builder_error(msg);
767            let error = ErrorCode::builder(ErrorCode::SERVER_ERROR).build().unwrap();
768            response.add_attribute(&error).unwrap();
769            response.add_fingerprint().unwrap();
770            return Err(response.into_owned());
771        };
772
773        Ok(transmit_send_build(transmit))
774    }
775
776    fn handle_stun_channel_bind<'a>(
777        &mut self,
778        msg: &Message<'_>,
779        ttype: TransportType,
780        from: SocketAddr,
781        to: SocketAddr,
782        now: Instant,
783    ) -> Result<Transmit<SData<'static>>, MessageBuilder<'a>> {
784        let credentials = self.validate_stun(msg, ttype, from, to, now)?;
785
786        let Some(client) = self.mut_client_from_5tuple(ttype, to, from) else {
787            let mut builder = Message::builder_error(msg);
788            let error = ErrorCode::builder(ErrorCode::ALLOCATION_MISMATCH)
789                .build()
790                .unwrap();
791            builder.add_attribute(&error).unwrap();
792            return Err(builder.into_owned());
793        };
794
795        let bad_request = move |msg: &Message<'_>, credentials: LongTermCredentials| {
796            let mut builder = Message::builder_error(msg);
797            let error = ErrorCode::builder(ErrorCode::BAD_REQUEST).build().unwrap();
798            builder.add_attribute(&error).unwrap();
799            builder
800                .add_message_integrity(
801                    &MessageIntegrityCredentials::LongTerm(credentials),
802                    stun_proto::types::message::IntegrityAlgorithm::Sha1,
803                )
804                .unwrap();
805            builder.into_owned()
806        };
807
808        let peer_addr = msg
809            .attribute::<XorPeerAddress>()
810            .ok()
811            .map(|peer_addr| peer_addr.addr(msg.transaction_id()));
812        let Some(peer_addr) = peer_addr else {
813            trace!("No peer address");
814            return Err(bad_request(msg, credentials));
815        };
816
817        let Some(alloc) = client
818            .allocations
819            .iter_mut()
820            .find(|allocation| allocation.addr.is_ipv4() == peer_addr.is_ipv4())
821        else {
822            let mut builder = Message::builder_error(msg);
823            let error = ErrorCode::builder(ErrorCode::ALLOCATION_MISMATCH)
824                .build()
825                .unwrap();
826            builder.add_attribute(&error).unwrap();
827            return Err(builder.into_owned());
828        };
829
830        if now > alloc.expires_at {
831            trace!("allocation has expired");
832            // allocation has expired
833            let mut builder = Message::builder_error(msg);
834            let error = ErrorCode::builder(ErrorCode::ALLOCATION_MISMATCH)
835                .build()
836                .unwrap();
837            builder.add_attribute(&error).unwrap();
838            builder
839                .add_message_integrity(
840                    &MessageIntegrityCredentials::LongTerm(client.credentials.clone()),
841                    stun_proto::types::message::IntegrityAlgorithm::Sha1,
842                )
843                .unwrap();
844            return Err(builder.into_owned());
845        }
846
847        let mut existing = alloc.channels.iter_mut().find(|channel| {
848            channel.peer_addr == peer_addr && channel.peer_transport == TransportType::Udp
849        });
850
851        let channel_no = msg
852            .attribute::<ChannelNumber>()
853            .ok()
854            .map(|channel| channel.channel());
855        if let Some(channel_no) = channel_no {
856            if !(0x4000..=0x7fff).contains(&channel_no) {
857                trace!("Channel id out of range");
858                return Err(bad_request(msg, credentials));
859            }
860            if existing
861                .as_ref()
862                .map_or(false, |existing| existing.id != channel_no)
863            {
864                trace!("channel peer address does not match channel ID");
865                return Err(bad_request(msg, credentials));
866            }
867        } else {
868            trace!("No requested channel id");
869            return Err(bad_request(msg, credentials));
870        }
871
872        if let Some(existing) = existing.as_mut() {
873            existing.expires_at = now + Duration::from_secs(600);
874        } else {
875            alloc.channels.push(Channel {
876                id: channel_no.unwrap(),
877                peer_addr,
878                peer_transport: TransportType::Udp,
879                expires_at: now + Duration::from_secs(600),
880            });
881        }
882
883        if let Some(existing) = alloc
884            .permissions
885            .iter_mut()
886            .find(|perm| perm.ttype == TransportType::Udp && perm.addr == peer_addr.ip())
887        {
888            existing.expires_at = now + Duration::from_secs(300);
889        } else {
890            alloc.permissions.push(Permission {
891                addr: peer_addr.ip(),
892                ttype: TransportType::Udp,
893                expires_at: now + Duration::from_secs(300),
894            });
895        }
896
897        let mut builder = Message::builder_success(msg);
898        builder
899            .add_message_integrity(
900                &MessageIntegrityCredentials::LongTerm(credentials),
901                stun_proto::types::message::IntegrityAlgorithm::Sha1,
902            )
903            .unwrap();
904
905        let Ok(transmit) = self.stun.send(builder, from, now) else {
906            error!("Failed to send");
907            let mut response = Message::builder_error(msg);
908            let error = ErrorCode::builder(ErrorCode::SERVER_ERROR).build().unwrap();
909            response.add_attribute(&error).unwrap();
910            response.add_fingerprint().unwrap();
911            return Err(response.into_owned());
912        };
913
914        Ok(transmit_send_build(transmit))
915    }
916
917    fn handle_stun_send_indication<'a>(
918        &mut self,
919        msg: &'a Message<'a>,
920        ttype: TransportType,
921        from: SocketAddr,
922        to: SocketAddr,
923        now: Instant,
924    ) -> Result<Transmit<SData<'a>>, ()> {
925        let peer_address = msg.attribute::<XorPeerAddress>().map_err(|_| ())?;
926        let peer_address = peer_address.addr(msg.transaction_id());
927
928        let Some(client) = self.client_from_5tuple(ttype, to, from) else {
929            trace!("no client for transport {ttype:?} from {from:?}, to {to:?}");
930            trace!("clients: {:?}", self.clients);
931            return Err(());
932        };
933
934        let Some(alloc) = client
935            .allocations
936            .iter()
937            .find(|allocation| allocation.addr.ip().is_ipv4() == peer_address.is_ipv4())
938        else {
939            trace!("no allocation for transport {ttype:?} from {from:?}, to {to:?}");
940            trace!("allocations: {:?}", client.allocations);
941            return Err(());
942        };
943        if now > alloc.expires_at {
944            trace!("allocation has expired");
945            // allocation has expired
946            return Err(());
947        }
948
949        let Some(permission) = alloc
950            .permissions
951            .iter()
952            .find(|permission| permission.addr == peer_address.ip())
953        else {
954            trace!("permission not installed");
955            // no permission installed for this peer, ignoring
956            return Err(());
957        };
958        if now > permission.expires_at {
959            trace!("permission has expired");
960            // permission has expired
961            return Err(());
962        }
963
964        let data = msg.attribute::<Data>().map_err(|_| ())?;
965        trace!("have {} to send to {:?}", data.data().len(), peer_address);
966        Ok(Transmit::new(
967            SData::from(data.data()),
968            permission.ttype,
969            alloc.addr,
970            peer_address,
971        )
972        .into_owned())
973        // XXX: copies the data.  Try to figure out a way to not do this
974        /*
975        self.pending_transmits.push_back(Transmit::new_owned(
976            data.data(),
977            permission.ttype,
978            alloc.addr,
979            peer_address,
980        ));
981        Ok(())*/
982    }
983
984    #[tracing::instrument(name = "turn_server_handle_stun", skip(self, msg, from, to, now))]
985    fn handle_stun<'a>(
986        &mut self,
987        msg: &'a Message<'a>,
988        ttype: TransportType,
989        from: SocketAddr,
990        to: SocketAddr,
991        now: Instant,
992    ) -> Result<Option<Transmit<SData<'a>>>, MessageBuilder<'a>> {
993        trace!("received STUN message {msg}");
994        let ret = if msg.has_class(stun_proto::types::message::MessageClass::Request) {
995            match msg.method() {
996                BINDING => self
997                    .handle_stun_binding(msg, ttype, from, to, now)
998                    .map(Some),
999                ALLOCATE => self
1000                    .handle_stun_allocate(msg, ttype, from, to, now)
1001                    .map(|_| None),
1002                REFRESH => self
1003                    .handle_stun_refresh(msg, ttype, from, to, now)
1004                    .map(Some),
1005                CREATE_PERMISSION => self
1006                    .handle_stun_create_permission(msg, ttype, from, to, now)
1007                    .map(Some),
1008                CHANNEL_BIND => self
1009                    .handle_stun_channel_bind(msg, ttype, from, to, now)
1010                    .map(Some),
1011                _ => {
1012                    let mut builder = Message::builder_error(msg);
1013                    let error = ErrorCode::builder(ErrorCode::BAD_REQUEST).build().unwrap();
1014                    builder.add_attribute(&error).unwrap();
1015                    Err(builder.into_owned())
1016                }
1017            }
1018        } else if msg.has_class(stun_proto::types::message::MessageClass::Indication) {
1019            match msg.method() {
1020                SEND => Ok(self
1021                    .handle_stun_send_indication(msg, ttype, from, to, now)
1022                    .ok()),
1023                _ => Ok(None),
1024            }
1025        } else {
1026            Ok(None)
1027        };
1028        debug!("result: {ret:?}");
1029        ret
1030    }
1031
1032    fn nonce_from_5tuple(
1033        &self,
1034        ttype: TransportType,
1035        local_addr: SocketAddr,
1036        remote_addr: SocketAddr,
1037    ) -> Option<&NonceData> {
1038        self.nonces.iter().find(|nonce| {
1039            nonce.transport == ttype
1040                && nonce.remote_addr == remote_addr
1041                && nonce.local_addr == local_addr
1042        })
1043    }
1044
1045    fn mut_nonce_from_5tuple(
1046        &mut self,
1047        ttype: TransportType,
1048        local_addr: SocketAddr,
1049        remote_addr: SocketAddr,
1050    ) -> Option<&mut NonceData> {
1051        self.nonces.iter_mut().find(|nonce| {
1052            nonce.transport == ttype
1053                && nonce.remote_addr == remote_addr
1054                && nonce.local_addr == local_addr
1055        })
1056    }
1057
1058    fn client_from_5tuple(
1059        &self,
1060        ttype: TransportType,
1061        local_addr: SocketAddr,
1062        remote_addr: SocketAddr,
1063    ) -> Option<&Client> {
1064        self.clients.iter().find(|client| {
1065            client.transport == ttype
1066                && client.remote_addr == remote_addr
1067                && client.local_addr == local_addr
1068        })
1069    }
1070
1071    fn mut_client_from_5tuple(
1072        &mut self,
1073        ttype: TransportType,
1074        local_addr: SocketAddr,
1075        remote_addr: SocketAddr,
1076    ) -> Option<&mut Client> {
1077        self.clients.iter_mut().find(|client| {
1078            client.transport == ttype
1079                && client.remote_addr == remote_addr
1080                && client.local_addr == local_addr
1081        })
1082    }
1083
1084    fn remove_client_by_5tuple(
1085        &mut self,
1086        ttype: TransportType,
1087        local_addr: SocketAddr,
1088        remote_addr: SocketAddr,
1089    ) {
1090        self.clients.retain(|client| {
1091            client.transport != ttype
1092                && client.remote_addr != remote_addr
1093                && client.local_addr == local_addr
1094        })
1095    }
1096
1097    fn allocation_from_public_5tuple(
1098        &self,
1099        ttype: TransportType,
1100        local_addr: SocketAddr,
1101        remote_addr: SocketAddr,
1102    ) -> Option<(&Client, &Allocation)> {
1103        self.clients.iter().find_map(|client| {
1104            client
1105                .allocations
1106                .iter()
1107                .find(|allocation| {
1108                    allocation.ttype == ttype
1109                        && allocation.addr == local_addr
1110                        && allocation
1111                            .permissions
1112                            .iter()
1113                            .any(|permission| permission.addr == remote_addr.ip())
1114                })
1115                .map(|allocation| (client, allocation))
1116        })
1117    }
1118}
1119
/// Server-side state for one TURN client, identified by its 5-tuple: transport, server-local
/// address and client remote address.
#[derive(Debug)]
struct Client {
    transport: TransportType,
    local_addr: SocketAddr,
    remote_addr: SocketAddr,

    // Allocations owned by this client. Request handlers pick one by matching its address
    // family against the peer address.
    allocations: Vec<Allocation>,
    // Long-term credentials used to add MESSAGE-INTEGRITY to some error responses for this
    // client (e.g. an expired allocation).
    credentials: LongTermCredentials,
}
1129
/// A relayed transport address allocated on behalf of a client.
#[derive(Debug)]
struct Allocation {
    // the peer-side address of this allocation
    addr: SocketAddr,
    // Transport of the relayed address. Compared with the incoming transport when looking an
    // allocation up from the peer side.
    ttype: TransportType,

    // Handlers treat the allocation as expired once `now` is later than this instant.
    expires_at: Instant,

    // Peer IP addresses this allocation may exchange data with.
    permissions: Vec<Permission>,
    // Channel-number bindings created by CHANNEL-BIND.
    channels: Vec<Channel>,
}
1141
1142impl Allocation {
1143    fn permissions_from_5tuple(
1144        &self,
1145        ttype: TransportType,
1146        _local_addr: SocketAddr,
1147        remote_addr: SocketAddr,
1148    ) -> Option<&Permission> {
1149        self.permissions
1150            .iter()
1151            .find(|permission| permission.ttype == ttype && remote_addr.ip() == permission.addr)
1152    }
1153
1154    fn channel_from_id(&self, id: u16) -> Option<&Channel> {
1155        self.channels.iter().find(|channel| channel.id == id)
1156    }
1157
1158    fn channel_from_5tuple(
1159        &self,
1160        transport: TransportType,
1161        local_addr: SocketAddr,
1162        remote_addr: SocketAddr,
1163    ) -> Option<&Channel> {
1164        if self.addr != local_addr {
1165            return None;
1166        }
1167        self.channels
1168            .iter()
1169            .find(|channel| transport == channel.peer_transport && remote_addr == channel.peer_addr)
1170    }
1171}
1172
/// Permission for an allocation to exchange traffic with a peer.
///
/// Only the peer's IP is stored, and lookups compare IPs only, so any port on that IP matches.
#[derive(Debug)]
struct Permission {
    addr: IpAddr,
    ttype: TransportType,

    // Refreshed to `now + 300s` when CREATE-PERMISSION or CHANNEL-BIND touches this peer. It is
    // treated as expired once `now` is later than this instant.
    expires_at: Instant,
}
1180
/// A channel-number binding between an allocation and a peer transport address, created by
/// CHANNEL-BIND.
#[derive(Debug)]
struct Channel {
    // Channel number. CHANNEL-BIND only accepts values in 0x4000..=0x7fff.
    id: u16,
    // Peer transport address (IP and port) this channel relays to.
    peer_addr: SocketAddr,
    peer_transport: TransportType,

    // Set to `now + 600s` each time the binding is created or refreshed.
    expires_at: Instant,
}
1189
1190fn transmit_send_build<T: DelayedTransmitBuild>(
1191    transmit: TransmitBuild<T>,
1192) -> Transmit<SData<'static>> {
1193    let data = transmit.data.build().into_boxed_slice();
1194    Transmit::new(
1195        SData::from(data),
1196        transmit.transport,
1197        transmit.from,
1198        transmit.to,
1199    )
1200    .into_owned()
1201}