Skip to main content

saorsa_webrtc_core/
protocol_handler.rs

1//! WebRTC protocol handler for SharedTransport integration
2//!
3//! Implements the `ProtocolHandler` trait from saorsa-transport to handle
4//! WebRTC-specific stream types over the shared transport layer.
5
6use ant_quic::{
7    LinkError as TransportError, LinkResult as TransportResult, PeerId, ProtocolHandler, StreamType,
8};
9use async_trait::async_trait;
10use bytes::Bytes;
11use std::collections::HashMap;
12use thiserror::Error;
13use tokio::sync::{mpsc, RwLock};
14use tracing::{debug, error, trace, warn};
15
16use crate::quic_bridge::RtpPacket;
17use crate::signaling::SignalingMessage;
18
/// Errors specific to WebRTC protocol handling.
///
/// Each variant wraps a `String` describing the underlying failure, which is
/// interpolated into the variant's `Display` message.
///
/// NOTE(review): no variant is constructed in this file — the handlers here
/// report failures as `TransportError` (`LinkError`) values instead. Confirm
/// whether other modules use this type before relying on it.
#[derive(Debug, Error)]
pub enum WebRtcHandlerError {
    /// Failed to deserialize signaling message.
    #[error("failed to deserialize signaling message: {0}")]
    SignalingDeserialize(String),

    /// Failed to deserialize media packet.
    #[error("failed to deserialize media packet: {0}")]
    MediaDeserialize(String),

    /// Failed to serialize response.
    #[error("failed to serialize response: {0}")]
    Serialize(String),

    /// Channel send error.
    #[error("failed to send to channel: {0}")]
    ChannelSend(String),
}
38
/// Incoming WebRTC message types.
///
/// Each variant is delivered on its own receiver returned by
/// `WebRtcProtocolHandler::new` (signal, media and data receivers, in that order).
#[derive(Debug, Clone)]
pub enum WebRtcIncoming {
    /// Signaling message (SDP offers, answers, ICE candidates), JSON-decoded
    /// from a `WebRtcSignal` stream.
    Signal {
        /// Remote peer ID.
        peer: PeerId,
        /// The signaling message.
        message: SignalingMessage,
    },
    /// Media packet (RTP audio/video), decoded from a `WebRtcMedia` stream or
    /// a `WebRtcMedia` datagram.
    Media {
        /// Remote peer ID.
        peer: PeerId,
        /// The RTP packet.
        packet: RtpPacket,
    },
    /// Data channel message from a `WebRtcData` stream.
    Data {
        /// Remote peer ID.
        peer: PeerId,
        /// Channel ID, read from the big-endian 4-byte prefix of the frame.
        channel_id: u32,
        /// The data payload (the frame with the 4-byte channel ID prefix removed).
        data: Bytes,
    },
}
66
/// Configuration for the WebRTC protocol handler.
///
/// Every field is the capacity of one bounded channel that queues decoded
/// messages until the application drains the matching receiver.
#[derive(Debug, Clone)]
pub struct WebRtcHandlerConfig {
    /// Capacity of the queue for incoming signaling messages.
    pub signal_buffer_size: usize,
    /// Capacity of the queue for incoming media (RTP) packets.
    pub media_buffer_size: usize,
    /// Capacity of the queue for incoming data channel messages.
    pub data_buffer_size: usize,
}

impl Default for WebRtcHandlerConfig {
    /// Default queue depths: 256 signaling, 1024 media and 512 data channel
    /// messages.
    fn default() -> Self {
        let signal_buffer_size = 256;
        let media_buffer_size = 1024;
        let data_buffer_size = 512;

        Self {
            signal_buffer_size,
            media_buffer_size,
            data_buffer_size,
        }
    }
}
87
/// WebRTC protocol handler for SharedTransport.
///
/// Routes incoming streams to the appropriate WebRTC subsystem based on
/// stream type:
/// - `WebRtcSignal` (0x20): SDP offers/answers, ICE candidates
/// - `WebRtcMedia` (0x21): RTP packets for audio/video
/// - `WebRtcData` (0x22): Data channel messages
///
/// Decoded messages are pushed onto three bounded `tokio::sync::mpsc` channels
/// whose receivers are returned by `WebRtcProtocolHandler::new`.
pub struct WebRtcProtocolHandler {
    /// Sender for decoded signaling messages (`WebRtcIncoming::Signal`).
    signal_tx: mpsc::Sender<WebRtcIncoming>,
    /// Sender for decoded RTP packets (`WebRtcIncoming::Media`); fed by both
    /// media streams and media datagrams.
    media_tx: mpsc::Sender<WebRtcIncoming>,
    /// Sender for data channel messages (`WebRtcIncoming::Data`).
    data_tx: mpsc::Sender<WebRtcIncoming>,

    /// Per-peer session state, created when a stream message from a peer is
    /// first decoded and cleared on shutdown.
    sessions: RwLock<HashMap<PeerId, PeerSession>>,

    /// Shutdown flag, set by `shutdown()` and checked before handling streams.
    shutdown: RwLock<bool>,
}

/// State for a peer's WebRTC session.
///
/// Updated only by the stream handlers; datagrams do not touch it.
#[derive(Debug, Default)]
struct PeerSession {
    /// Data channel IDs seen from this peer, deduplicated, in first-seen order.
    data_channels: Vec<u32>,
    /// Number of stream messages (signal, media, data) decoded from this peer.
    messages_received: u64,
    /// Time of the most recent decoded stream message from this peer.
    /// NOTE(review): written but never read in this file.
    last_activity: Option<std::time::Instant>,
}
120
121impl WebRtcProtocolHandler {
122    /// Create a new WebRTC protocol handler.
123    ///
124    /// Returns the handler and receivers for each message type.
125    pub fn new(
126        config: WebRtcHandlerConfig,
127    ) -> (
128        Self,
129        mpsc::Receiver<WebRtcIncoming>,
130        mpsc::Receiver<WebRtcIncoming>,
131        mpsc::Receiver<WebRtcIncoming>,
132    ) {
133        let (signal_tx, signal_rx) = mpsc::channel(config.signal_buffer_size);
134        let (media_tx, media_rx) = mpsc::channel(config.media_buffer_size);
135        let (data_tx, data_rx) = mpsc::channel(config.data_buffer_size);
136
137        let handler = Self {
138            signal_tx,
139            media_tx,
140            data_tx,
141            sessions: RwLock::new(HashMap::new()),
142            shutdown: RwLock::new(false),
143        };
144
145        (handler, signal_rx, media_rx, data_rx)
146    }
147
148    /// Create with default configuration.
149    pub fn with_defaults() -> (
150        Self,
151        mpsc::Receiver<WebRtcIncoming>,
152        mpsc::Receiver<WebRtcIncoming>,
153        mpsc::Receiver<WebRtcIncoming>,
154    ) {
155        Self::new(WebRtcHandlerConfig::default())
156    }
157
158    /// Handle incoming signaling message.
159    async fn handle_signal(&self, peer: PeerId, data: Bytes) -> TransportResult<Option<Bytes>> {
160        trace!(peer = ?peer, size = data.len(), "Processing WebRTC signal");
161
162        // Deserialize the signaling message
163        let message: SignalingMessage = serde_json::from_slice(&data).map_err(|e| {
164            TransportError::Internal(format!("Failed to deserialize signaling message: {}", e))
165        })?;
166
167        debug!(
168            peer = ?peer,
169            session_id = %message.session_id(),
170            "Received signaling message"
171        );
172
173        // Update session state
174        {
175            let mut sessions = self.sessions.write().await;
176            let session = sessions.entry(peer).or_default();
177            session.messages_received += 1;
178            session.last_activity = Some(std::time::Instant::now());
179        }
180
181        // Send to signal channel
182        self.signal_tx
183            .send(WebRtcIncoming::Signal { peer, message })
184            .await
185            .map_err(|e| {
186                TransportError::Internal(format!("Failed to send to signal channel: {}", e))
187            })?;
188
189        // Signaling typically expects a response, but we handle that asynchronously
190        Ok(None)
191    }
192
193    /// Handle incoming media packet.
194    async fn handle_media(&self, peer: PeerId, data: Bytes) -> TransportResult<Option<Bytes>> {
195        trace!(peer = ?peer, size = data.len(), "Processing WebRTC media");
196
197        // Deserialize the RTP packet
198        let packet = RtpPacket::from_bytes(&data).map_err(|e| {
199            TransportError::Internal(format!("Failed to deserialize RTP packet: {}", e))
200        })?;
201
202        trace!(
203            peer = ?peer,
204            stream_type = ?packet.stream_type,
205            seq = packet.sequence_number,
206            "Received media packet"
207        );
208
209        // Update session state
210        {
211            let mut sessions = self.sessions.write().await;
212            let session = sessions.entry(peer).or_default();
213            session.messages_received += 1;
214            session.last_activity = Some(std::time::Instant::now());
215        }
216
217        // Send to media channel - use try_send for non-blocking media
218        match self
219            .media_tx
220            .try_send(WebRtcIncoming::Media { peer, packet })
221        {
222            Ok(()) => {}
223            Err(mpsc::error::TrySendError::Full(_)) => {
224                warn!(peer = ?peer, "Media channel full, dropping packet");
225            }
226            Err(mpsc::error::TrySendError::Closed(_)) => {
227                return Err(TransportError::Shutdown);
228            }
229        }
230
231        // Media packets do not require responses
232        Ok(None)
233    }
234
235    /// Handle incoming data channel message.
236    async fn handle_data(&self, peer: PeerId, data: Bytes) -> TransportResult<Option<Bytes>> {
237        trace!(peer = ?peer, size = data.len(), "Processing WebRTC data");
238
239        // Data channel format: 4-byte channel ID + payload
240        if data.len() < 4 {
241            return Err(TransportError::Internal(
242                "Data channel message too short".into(),
243            ));
244        }
245
246        let channel_id = u32::from_be_bytes([data[0], data[1], data[2], data[3]]);
247        let payload = data.slice(4..);
248
249        debug!(
250            peer = ?peer,
251            channel_id = channel_id,
252            payload_size = payload.len(),
253            "Received data channel message"
254        );
255
256        // Update session state
257        {
258            let mut sessions = self.sessions.write().await;
259            let session = sessions.entry(peer).or_default();
260            session.messages_received += 1;
261            session.last_activity = Some(std::time::Instant::now());
262            if !session.data_channels.contains(&channel_id) {
263                session.data_channels.push(channel_id);
264            }
265        }
266
267        // Send to data channel
268        self.data_tx
269            .send(WebRtcIncoming::Data {
270                peer,
271                channel_id,
272                data: payload,
273            })
274            .await
275            .map_err(|e| {
276                TransportError::Internal(format!("Failed to send to data channel: {}", e))
277            })?;
278
279        // Data channel messages may or may not require responses
280        Ok(None)
281    }
282
283    /// Get number of active sessions.
284    pub async fn session_count(&self) -> usize {
285        self.sessions.read().await.len()
286    }
287
288    /// Get session info for a peer.
289    pub async fn get_session_stats(&self, peer: &PeerId) -> Option<(u64, Vec<u32>)> {
290        let sessions = self.sessions.read().await;
291        sessions
292            .get(peer)
293            .map(|s| (s.messages_received, s.data_channels.clone()))
294    }
295
296    /// Remove a peer session.
297    pub async fn remove_session(&self, peer: &PeerId) {
298        let mut sessions = self.sessions.write().await;
299        if sessions.remove(peer).is_some() {
300            debug!(peer = ?peer, "Removed WebRTC session");
301        }
302    }
303}
304
305#[async_trait]
306impl ProtocolHandler for WebRtcProtocolHandler {
307    fn stream_types(&self) -> &[StreamType] {
308        StreamType::webrtc_types()
309    }
310
311    async fn handle_stream(
312        &self,
313        peer: PeerId,
314        stream_type: StreamType,
315        data: Bytes,
316    ) -> TransportResult<Option<Bytes>> {
317        // Check shutdown flag
318        if *self.shutdown.read().await {
319            return Err(TransportError::Shutdown);
320        }
321
322        match stream_type {
323            StreamType::WebRtcSignal => self.handle_signal(peer, data).await,
324            StreamType::WebRtcMedia => self.handle_media(peer, data).await,
325            StreamType::WebRtcData => self.handle_data(peer, data).await,
326            _ => {
327                error!(stream_type = %stream_type, "Unexpected stream type in WebRTC handler");
328                Err(TransportError::Internal(format!(
329                    "Unknown stream type: {}",
330                    stream_type
331                )))
332            }
333        }
334    }
335
336    async fn handle_datagram(
337        &self,
338        peer: PeerId,
339        stream_type: StreamType,
340        data: Bytes,
341    ) -> TransportResult<()> {
342        // Datagrams are used for unreliable media (e.g., low-priority video frames)
343        if stream_type == StreamType::WebRtcMedia {
344            trace!(peer = ?peer, size = data.len(), "Received media datagram");
345
346            // Try to deserialize and forward, but do not fail on errors for datagrams
347            if let Ok(packet) = RtpPacket::from_bytes(&data) {
348                let _ = self
349                    .media_tx
350                    .try_send(WebRtcIncoming::Media { peer, packet });
351            }
352        }
353        Ok(())
354    }
355
356    async fn shutdown(&self) -> TransportResult<()> {
357        debug!("Shutting down WebRTC protocol handler");
358
359        let mut shutdown = self.shutdown.write().await;
360        *shutdown = true;
361
362        // Clear sessions
363        self.sessions.write().await.clear();
364
365        Ok(())
366    }
367
368    fn name(&self) -> &str {
369        "WebRtcProtocolHandler"
370    }
371}
372
373/// Builder for creating WebRtcProtocolHandler with custom configuration.
374pub struct WebRtcProtocolHandlerBuilder {
375    config: WebRtcHandlerConfig,
376}
377
378impl WebRtcProtocolHandlerBuilder {
379    /// Create a new builder with default configuration.
380    pub fn new() -> Self {
381        Self {
382            config: WebRtcHandlerConfig::default(),
383        }
384    }
385
386    /// Set signal buffer size.
387    pub fn signal_buffer_size(mut self, size: usize) -> Self {
388        self.config.signal_buffer_size = size;
389        self
390    }
391
392    /// Set media buffer size.
393    pub fn media_buffer_size(mut self, size: usize) -> Self {
394        self.config.media_buffer_size = size;
395        self
396    }
397
398    /// Set data buffer size.
399    pub fn data_buffer_size(mut self, size: usize) -> Self {
400        self.config.data_buffer_size = size;
401        self
402    }
403
404    /// Build the handler and return receivers.
405    pub fn build(
406        self,
407    ) -> (
408        WebRtcProtocolHandler,
409        mpsc::Receiver<WebRtcIncoming>,
410        mpsc::Receiver<WebRtcIncoming>,
411        mpsc::Receiver<WebRtcIncoming>,
412    ) {
413        WebRtcProtocolHandler::new(self.config)
414    }
415}
416
417impl Default for WebRtcProtocolHandlerBuilder {
418    fn default() -> Self {
419        Self::new()
420    }
421}
422
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    /// Build a data channel frame: big-endian channel ID followed by the payload.
    fn data_frame(channel_id: u32, payload: &[u8]) -> Bytes {
        let mut frame = Vec::with_capacity(4 + payload.len());
        frame.extend_from_slice(&channel_id.to_be_bytes());
        frame.extend_from_slice(payload);
        Bytes::from(frame)
    }

    #[tokio::test]
    async fn test_handler_stream_types() {
        let (handler, _, _, _) = WebRtcProtocolHandler::with_defaults();
        let types = handler.stream_types();

        for wanted in [
            StreamType::WebRtcSignal,
            StreamType::WebRtcMedia,
            StreamType::WebRtcData,
        ] {
            assert!(types.contains(&wanted));
        }
        assert_eq!(types.len(), 3);
    }

    #[tokio::test]
    async fn test_handler_name() {
        let (handler, _, _, _) = WebRtcProtocolHandler::with_defaults();
        assert_eq!(handler.name(), "WebRtcProtocolHandler");
    }

    #[tokio::test]
    async fn test_handle_signal_message() {
        let (handler, mut signal_rx, _, _) = WebRtcProtocolHandler::with_defaults();
        let peer = PeerId::from([1u8; 32]);

        let offer = SignalingMessage::Offer {
            session_id: "test-session".to_string(),
            sdp: "v=0\r\no=- 123 1 IN IP4 127.0.0.1\r\n".to_string(),
            quic_endpoint: None,
        };
        let encoded = Bytes::from(serde_json::to_vec(&offer).unwrap());

        let outcome = handler
            .handle_stream(peer, StreamType::WebRtcSignal, encoded)
            .await;
        assert!(outcome.is_ok());

        // The decoded offer must come out of the signal receiver untouched.
        match signal_rx.try_recv() {
            Ok(WebRtcIncoming::Signal { peer: from, message }) => {
                assert_eq!(from, peer);
                assert_eq!(message.session_id(), "test-session");
            }
            _ => panic!("Expected Signal message"),
        }
    }

    #[tokio::test]
    async fn test_handle_media_packet() {
        let (handler, _, mut media_rx, _) = WebRtcProtocolHandler::with_defaults();
        let peer = PeerId::from([2u8; 32]);

        let rtp = RtpPacket::new(
            96,         // payload type
            1000,       // sequence number
            12345,      // timestamp
            0xDEADBEEF, // SSRC
            vec![1, 2, 3, 4],
            crate::quic_bridge::StreamType::Audio,
        )
        .unwrap();
        let wire = Bytes::from(rtp.to_bytes().unwrap());

        let outcome = handler
            .handle_stream(peer, StreamType::WebRtcMedia, wire)
            .await;
        assert!(outcome.is_ok());

        match media_rx.try_recv() {
            Ok(WebRtcIncoming::Media { peer: from, packet }) => {
                assert_eq!(from, peer);
                assert_eq!(packet.sequence_number, 1000);
            }
            _ => panic!("Expected Media message"),
        }
    }

    #[tokio::test]
    async fn test_handle_data_channel() {
        let (handler, _, _, mut data_rx) = WebRtcProtocolHandler::with_defaults();
        let peer = PeerId::from([3u8; 32]);
        let payload = b"hello world";

        let outcome = handler
            .handle_stream(peer, StreamType::WebRtcData, data_frame(42, payload))
            .await;
        assert!(outcome.is_ok());

        match data_rx.try_recv() {
            Ok(WebRtcIncoming::Data {
                peer: from,
                channel_id,
                data,
            }) => {
                assert_eq!(from, peer);
                assert_eq!(channel_id, 42);
                assert_eq!(&data[..], payload);
            }
            _ => panic!("Expected Data message"),
        }
    }

    #[tokio::test]
    async fn test_data_channel_too_short() {
        let (handler, _, _, _) = WebRtcProtocolHandler::with_defaults();
        let peer = PeerId::from([4u8; 32]);

        // Three bytes cannot hold the 4-byte channel ID header.
        let truncated = Bytes::from_static(&[1, 2, 3]);
        let outcome = handler
            .handle_stream(peer, StreamType::WebRtcData, truncated)
            .await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_session_tracking() {
        let (handler, _, _, _) = WebRtcProtocolHandler::with_defaults();
        let peer = PeerId::from([5u8; 32]);

        assert_eq!(handler.session_count().await, 0);

        let _ = handler
            .handle_stream(peer, StreamType::WebRtcData, data_frame(1, b"test"))
            .await;
        assert_eq!(handler.session_count().await, 1);

        let (msgs, channels) = handler.get_session_stats(&peer).await.unwrap();
        assert_eq!(msgs, 1);
        assert!(channels.contains(&1));

        handler.remove_session(&peer).await;
        assert_eq!(handler.session_count().await, 0);
    }

    #[tokio::test]
    async fn test_shutdown() {
        let (handler, _, _, _) = WebRtcProtocolHandler::with_defaults();
        assert!(handler.shutdown().await.is_ok());

        // Once shut down, every stream must be rejected.
        let peer = PeerId::from([6u8; 32]);
        let after = handler
            .handle_stream(peer, StreamType::WebRtcSignal, Bytes::new())
            .await;
        assert!(after.is_err());
    }

    #[tokio::test]
    async fn test_builder() {
        let (handler, _, _, _) = WebRtcProtocolHandlerBuilder::new()
            .signal_buffer_size(128)
            .media_buffer_size(512)
            .data_buffer_size(256)
            .build();

        assert_eq!(handler.name(), "WebRtcProtocolHandler");
    }

    #[tokio::test]
    async fn test_invalid_signal_message() {
        let (handler, _, _, _) = WebRtcProtocolHandler::with_defaults();
        let peer = PeerId::from([7u8; 32]);

        let outcome = handler
            .handle_stream(
                peer,
                StreamType::WebRtcSignal,
                Bytes::from_static(b"not valid json"),
            )
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_unexpected_stream_type() {
        let (handler, _, _, _) = WebRtcProtocolHandler::with_defaults();
        let peer = PeerId::from([8u8; 32]);

        // Membership is not one of the WebRTC stream types.
        let outcome = handler
            .handle_stream(peer, StreamType::Membership, Bytes::new())
            .await;
        assert!(outcome.is_err());
    }
}
650
651/// Stream routing for WebRTC media types
652///
653/// Provides mapping between packet types and QUIC stream types
654pub mod stream_routing {
655    use crate::link_transport::StreamType;
656
657    /// RTP payload types for audio codecs (96-127)
658    pub const AUDIO_PAYLOAD_TYPE_RANGE: (u8, u8) = (96, 127);
659
660    /// RTP payload types for video codecs (96-127)
661    pub const VIDEO_PAYLOAD_TYPE_RANGE: (u8, u8) = (96, 127);
662
663    /// RTCP payload types (200-211)
664    pub const RTCP_PAYLOAD_TYPE_RANGE: (u8, u8) = (200, 211);
665
666    /// Detect if a packet is RTP based on payload type
667    ///
668    /// # Arguments
669    ///
670    /// * `payload_type` - The RTP payload type (first 2 bits of first byte)
671    ///
672    /// # Returns
673    ///
674    /// `true` if this appears to be an RTP packet
675    pub fn is_rtp(payload_type: u8) -> bool {
676        payload_type < 128 || (96..=127).contains(&payload_type)
677    }
678
679    /// Detect if a packet is RTCP based on payload type
680    ///
681    /// # Arguments
682    ///
683    /// * `payload_type` - The RTCP payload type
684    ///
685    /// # Returns
686    ///
687    /// `true` if this appears to be an RTCP packet
688    pub fn is_rtcp(payload_type: u8) -> bool {
689        (200..=211).contains(&payload_type)
690    }
691
692    /// Detect if packet is audio based on codec hint
693    ///
694    /// # Arguments
695    ///
696    /// * `payload_type` - The RTP payload type
697    ///
698    /// # Returns
699    ///
700    /// `true` if this is likely an audio codec
701    pub fn is_audio_codec(payload_type: u8) -> bool {
702        // Dynamic payload types 96-127 need external SDP mapping
703        // But common audio PTs: 0-23 are static
704        matches!(
705            payload_type,
706            0 | 1
707                | 3
708                | 4
709                | 5
710                | 6
711                | 7
712                | 8
713                | 9
714                | 10
715                | 11
716                | 12
717                | 13
718                | 14
719                | 15
720                | 16
721                | 17
722                | 18
723                | 19
724                | 25
725                | 97
726        )
727    }
728
729    /// Detect if packet is video based on codec hint
730    ///
731    /// # Arguments
732    ///
733    /// * `payload_type` - The RTP payload type
734    ///
735    /// # Returns
736    ///
737    /// `true` if this is likely a video codec
738    pub fn is_video_codec(payload_type: u8) -> bool {
739        // Common video PTs: 26, 32-34, 96-127 (dynamic), etc.
740        matches!(
741            payload_type,
742            26 | 32 | 33 | 34 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105
743        )
744    }
745
746    /// Route media to appropriate stream based on type
747    ///
748    /// # Arguments
749    ///
750    /// * `payload_type` - The RTP/RTCP payload type
751    ///
752    /// # Returns
753    ///
754    /// The target `StreamType` for this packet
755    pub fn route_to_stream(payload_type: u8) -> StreamType {
756        if is_rtcp(payload_type) {
757            StreamType::RtcpFeedback
758        } else if is_audio_codec(payload_type) {
759            StreamType::Audio
760        } else if is_video_codec(payload_type) {
761            StreamType::Video
762        } else {
763            // Default to video for unknown dynamic types
764            StreamType::Video
765        }
766    }
767
768    /// Get all RTCP packet types
769    ///
770    /// # Returns
771    ///
772    /// A vector of all valid RTCP payload types
773    pub fn rtcp_feedback_types() -> Vec<u8> {
774        (200..=211).collect()
775    }
776
777    #[cfg(test)]
778    mod routing_tests {
779        use super::*;
780
781        #[test]
782        fn test_is_rtp() {
783            assert!(is_rtp(0));
784            assert!(is_rtp(96));
785            assert!(is_rtp(127));
786            assert!(!is_rtp(200));
787        }
788
789        #[test]
790        fn test_is_rtcp() {
791            assert!(is_rtcp(200));
792            assert!(is_rtcp(205));
793            assert!(is_rtcp(211));
794            assert!(!is_rtcp(199));
795            assert!(!is_rtcp(212));
796        }
797
798        #[test]
799        fn test_is_audio_codec() {
800            assert!(is_audio_codec(0)); // PCMU
801            assert!(is_audio_codec(8)); // PCMA
802            assert!(is_audio_codec(97)); // iLBC
803            assert!(!is_audio_codec(26)); // Video
804        }
805
806        #[test]
807        fn test_is_video_codec() {
808            assert!(is_video_codec(26)); // Motion JPEG
809            assert!(is_video_codec(32)); // MPV
810            assert!(is_video_codec(96)); // Dynamic
811            assert!(!is_video_codec(0)); // Audio
812        }
813
814        #[test]
815        fn test_route_to_stream_audio() {
816            let stream = route_to_stream(0); // PCMU
817            assert_eq!(stream, StreamType::Audio);
818        }
819
820        #[test]
821        fn test_route_to_stream_video() {
822            let stream = route_to_stream(26); // Motion JPEG
823            assert_eq!(stream, StreamType::Video);
824        }
825
826        #[test]
827        fn test_route_to_stream_rtcp() {
828            let stream = route_to_stream(200); // SR
829            assert_eq!(stream, StreamType::RtcpFeedback);
830        }
831
832        #[test]
833        fn test_rtcp_feedback_types() {
834            let types = rtcp_feedback_types();
835            assert_eq!(types.len(), 12);
836            assert_eq!(types[0], 200);
837            assert_eq!(types[11], 211);
838        }
839    }
840}
841
842impl WebRtcProtocolHandler {
843    /// Route an incoming packet to the correct stream type
844    ///
845    /// # Arguments
846    ///
847    /// * `payload` - The packet payload
848    ///
849    /// # Returns
850    ///
851    /// The target `StreamType` for this packet
852    pub fn route_packet_to_stream(payload: &[u8]) -> crate::link_transport::StreamType {
853        if payload.is_empty() {
854            return crate::link_transport::StreamType::Data;
855        }
856
857        // Extract payload type from RTP/RTCP header
858        // First check if this is RTCP (byte[1] >= 200)
859        if payload[1] >= 200 {
860            return crate::link_transport::StreamType::RtcpFeedback;
861        }
862
863        // For RTP, extract payload type from bits 1-7 of second byte
864        let pt = payload[1] & 0x7F;
865        stream_routing::route_to_stream(pt)
866    }
867
868    /// Get the media type for a stream
869    ///
870    /// # Arguments
871    ///
872    /// * `stream_type` - The stream type
873    ///
874    /// # Returns
875    ///
876    /// A description of the media type
877    pub fn stream_media_type(stream_type: crate::link_transport::StreamType) -> &'static str {
878        match stream_type {
879            crate::link_transport::StreamType::Audio => "Audio RTP",
880            crate::link_transport::StreamType::Video => "Video RTP",
881            crate::link_transport::StreamType::Screen => "Screen Share RTP",
882            crate::link_transport::StreamType::RtcpFeedback => "RTCP Feedback",
883            crate::link_transport::StreamType::Data => "Data Channel",
884        }
885    }
886}
887
#[cfg(test)]
mod routing_integration_tests {
    use super::*;
    use crate::link_transport::StreamType;

    /// Shorthand for the routing entry point under test.
    fn route(bytes: &[u8]) -> StreamType {
        WebRtcProtocolHandler::route_packet_to_stream(bytes)
    }

    #[test]
    fn test_route_packet_audio_rtp() {
        // V=2 RTP header carrying PCMU (PT=0).
        assert_eq!(route(&[0x80, 0x00, 0x00, 0x01]), StreamType::Audio);
    }

    #[test]
    fn test_route_packet_video_rtp() {
        // V=2 RTP header carrying Motion JPEG (PT=26 = 0x1A).
        assert_eq!(route(&[0x80, 0x1A, 0x00, 0x01]), StreamType::Video);
    }

    #[test]
    fn test_route_packet_rtcp() {
        // RTCP Sender Report (PT=200 = 0xC8).
        assert_eq!(route(&[0x80, 0xC8, 0x00, 0x01]), StreamType::RtcpFeedback);
    }

    #[test]
    fn test_route_packet_empty() {
        assert_eq!(route(&[]), StreamType::Data);
    }

    #[test]
    fn test_stream_media_type_descriptions() {
        let expectations = [
            (StreamType::Audio, "Audio RTP"),
            (StreamType::Video, "Video RTP"),
            (StreamType::Screen, "Screen Share RTP"),
            (StreamType::RtcpFeedback, "RTCP Feedback"),
            (StreamType::Data, "Data Channel"),
        ];
        for (stream, label) in expectations {
            assert_eq!(WebRtcProtocolHandler::stream_media_type(stream), label);
        }
    }
}