1use ant_quic::{
7 LinkError as TransportError, LinkResult as TransportResult, PeerId, ProtocolHandler, StreamType,
8};
9use async_trait::async_trait;
10use bytes::Bytes;
11use std::collections::HashMap;
12use thiserror::Error;
13use tokio::sync::{mpsc, RwLock};
14use tracing::{debug, error, trace, warn};
15
16use crate::quic_bridge::RtpPacket;
17use crate::signaling::SignalingMessage;
18
19#[derive(Debug, Error)]
21pub enum WebRtcHandlerError {
22 #[error("failed to deserialize signaling message: {0}")]
24 SignalingDeserialize(String),
25
26 #[error("failed to deserialize media packet: {0}")]
28 MediaDeserialize(String),
29
30 #[error("failed to serialize response: {0}")]
32 Serialize(String),
33
34 #[error("failed to send to channel: {0}")]
36 ChannelSend(String),
37}
38
39#[derive(Debug, Clone)]
41pub enum WebRtcIncoming {
42 Signal {
44 peer: PeerId,
46 message: SignalingMessage,
48 },
49 Media {
51 peer: PeerId,
53 packet: RtpPacket,
55 },
56 Data {
58 peer: PeerId,
60 channel_id: u32,
62 data: Bytes,
64 },
65}
66
/// Capacities of the three inbound queues created by `WebRtcProtocolHandler::new`.
#[derive(Debug, Clone)]
pub struct WebRtcHandlerConfig {
    /// Capacity of the signaling queue.
    pub signal_buffer_size: usize,
    /// Capacity of the media queue; media is dropped rather than queued when it is full.
    pub media_buffer_size: usize,
    /// Capacity of the data-channel queue.
    pub data_buffer_size: usize,
}
77
78impl Default for WebRtcHandlerConfig {
79 fn default() -> Self {
80 Self {
81 signal_buffer_size: 256,
82 media_buffer_size: 1024,
83 data_buffer_size: 512,
84 }
85 }
86}
87
88pub struct WebRtcProtocolHandler {
96 signal_tx: mpsc::Sender<WebRtcIncoming>,
98 media_tx: mpsc::Sender<WebRtcIncoming>,
100 data_tx: mpsc::Sender<WebRtcIncoming>,
102
103 sessions: RwLock<HashMap<PeerId, PeerSession>>,
105
106 shutdown: RwLock<bool>,
108}
109
/// Bookkeeping kept for every peer that has sent traffic to the handler.
#[derive(Debug, Default)]
struct PeerSession {
    /// Distinct data-channel ids seen from this peer, in first-seen order.
    data_channels: Vec<u32>,
    /// Count of messages accepted from this peer.
    messages_received: u64,
    /// When the latest message arrived; `None` until the first one.
    last_activity: Option<std::time::Instant>,
}
120
121impl WebRtcProtocolHandler {
122 pub fn new(
126 config: WebRtcHandlerConfig,
127 ) -> (
128 Self,
129 mpsc::Receiver<WebRtcIncoming>,
130 mpsc::Receiver<WebRtcIncoming>,
131 mpsc::Receiver<WebRtcIncoming>,
132 ) {
133 let (signal_tx, signal_rx) = mpsc::channel(config.signal_buffer_size);
134 let (media_tx, media_rx) = mpsc::channel(config.media_buffer_size);
135 let (data_tx, data_rx) = mpsc::channel(config.data_buffer_size);
136
137 let handler = Self {
138 signal_tx,
139 media_tx,
140 data_tx,
141 sessions: RwLock::new(HashMap::new()),
142 shutdown: RwLock::new(false),
143 };
144
145 (handler, signal_rx, media_rx, data_rx)
146 }
147
148 pub fn with_defaults() -> (
150 Self,
151 mpsc::Receiver<WebRtcIncoming>,
152 mpsc::Receiver<WebRtcIncoming>,
153 mpsc::Receiver<WebRtcIncoming>,
154 ) {
155 Self::new(WebRtcHandlerConfig::default())
156 }
157
158 async fn handle_signal(&self, peer: PeerId, data: Bytes) -> TransportResult<Option<Bytes>> {
160 trace!(peer = ?peer, size = data.len(), "Processing WebRTC signal");
161
162 let message: SignalingMessage = serde_json::from_slice(&data).map_err(|e| {
164 TransportError::Internal(format!("Failed to deserialize signaling message: {}", e))
165 })?;
166
167 debug!(
168 peer = ?peer,
169 session_id = %message.session_id(),
170 "Received signaling message"
171 );
172
173 {
175 let mut sessions = self.sessions.write().await;
176 let session = sessions.entry(peer).or_default();
177 session.messages_received += 1;
178 session.last_activity = Some(std::time::Instant::now());
179 }
180
181 self.signal_tx
183 .send(WebRtcIncoming::Signal { peer, message })
184 .await
185 .map_err(|e| {
186 TransportError::Internal(format!("Failed to send to signal channel: {}", e))
187 })?;
188
189 Ok(None)
191 }
192
193 async fn handle_media(&self, peer: PeerId, data: Bytes) -> TransportResult<Option<Bytes>> {
195 trace!(peer = ?peer, size = data.len(), "Processing WebRTC media");
196
197 let packet = RtpPacket::from_bytes(&data).map_err(|e| {
199 TransportError::Internal(format!("Failed to deserialize RTP packet: {}", e))
200 })?;
201
202 trace!(
203 peer = ?peer,
204 stream_type = ?packet.stream_type,
205 seq = packet.sequence_number,
206 "Received media packet"
207 );
208
209 {
211 let mut sessions = self.sessions.write().await;
212 let session = sessions.entry(peer).or_default();
213 session.messages_received += 1;
214 session.last_activity = Some(std::time::Instant::now());
215 }
216
217 match self
219 .media_tx
220 .try_send(WebRtcIncoming::Media { peer, packet })
221 {
222 Ok(()) => {}
223 Err(mpsc::error::TrySendError::Full(_)) => {
224 warn!(peer = ?peer, "Media channel full, dropping packet");
225 }
226 Err(mpsc::error::TrySendError::Closed(_)) => {
227 return Err(TransportError::Shutdown);
228 }
229 }
230
231 Ok(None)
233 }
234
235 async fn handle_data(&self, peer: PeerId, data: Bytes) -> TransportResult<Option<Bytes>> {
237 trace!(peer = ?peer, size = data.len(), "Processing WebRTC data");
238
239 if data.len() < 4 {
241 return Err(TransportError::Internal(
242 "Data channel message too short".into(),
243 ));
244 }
245
246 let channel_id = u32::from_be_bytes([data[0], data[1], data[2], data[3]]);
247 let payload = data.slice(4..);
248
249 debug!(
250 peer = ?peer,
251 channel_id = channel_id,
252 payload_size = payload.len(),
253 "Received data channel message"
254 );
255
256 {
258 let mut sessions = self.sessions.write().await;
259 let session = sessions.entry(peer).or_default();
260 session.messages_received += 1;
261 session.last_activity = Some(std::time::Instant::now());
262 if !session.data_channels.contains(&channel_id) {
263 session.data_channels.push(channel_id);
264 }
265 }
266
267 self.data_tx
269 .send(WebRtcIncoming::Data {
270 peer,
271 channel_id,
272 data: payload,
273 })
274 .await
275 .map_err(|e| {
276 TransportError::Internal(format!("Failed to send to data channel: {}", e))
277 })?;
278
279 Ok(None)
281 }
282
283 pub async fn session_count(&self) -> usize {
285 self.sessions.read().await.len()
286 }
287
288 pub async fn get_session_stats(&self, peer: &PeerId) -> Option<(u64, Vec<u32>)> {
290 let sessions = self.sessions.read().await;
291 sessions
292 .get(peer)
293 .map(|s| (s.messages_received, s.data_channels.clone()))
294 }
295
296 pub async fn remove_session(&self, peer: &PeerId) {
298 let mut sessions = self.sessions.write().await;
299 if sessions.remove(peer).is_some() {
300 debug!(peer = ?peer, "Removed WebRTC session");
301 }
302 }
303}
304
305#[async_trait]
306impl ProtocolHandler for WebRtcProtocolHandler {
307 fn stream_types(&self) -> &[StreamType] {
308 StreamType::webrtc_types()
309 }
310
311 async fn handle_stream(
312 &self,
313 peer: PeerId,
314 stream_type: StreamType,
315 data: Bytes,
316 ) -> TransportResult<Option<Bytes>> {
317 if *self.shutdown.read().await {
319 return Err(TransportError::Shutdown);
320 }
321
322 match stream_type {
323 StreamType::WebRtcSignal => self.handle_signal(peer, data).await,
324 StreamType::WebRtcMedia => self.handle_media(peer, data).await,
325 StreamType::WebRtcData => self.handle_data(peer, data).await,
326 _ => {
327 error!(stream_type = %stream_type, "Unexpected stream type in WebRTC handler");
328 Err(TransportError::Internal(format!(
329 "Unknown stream type: {}",
330 stream_type
331 )))
332 }
333 }
334 }
335
336 async fn handle_datagram(
337 &self,
338 peer: PeerId,
339 stream_type: StreamType,
340 data: Bytes,
341 ) -> TransportResult<()> {
342 if stream_type == StreamType::WebRtcMedia {
344 trace!(peer = ?peer, size = data.len(), "Received media datagram");
345
346 if let Ok(packet) = RtpPacket::from_bytes(&data) {
348 let _ = self
349 .media_tx
350 .try_send(WebRtcIncoming::Media { peer, packet });
351 }
352 }
353 Ok(())
354 }
355
356 async fn shutdown(&self) -> TransportResult<()> {
357 debug!("Shutting down WebRTC protocol handler");
358
359 let mut shutdown = self.shutdown.write().await;
360 *shutdown = true;
361
362 self.sessions.write().await.clear();
364
365 Ok(())
366 }
367
368 fn name(&self) -> &str {
369 "WebRtcProtocolHandler"
370 }
371}
372
373pub struct WebRtcProtocolHandlerBuilder {
375 config: WebRtcHandlerConfig,
376}
377
378impl WebRtcProtocolHandlerBuilder {
379 pub fn new() -> Self {
381 Self {
382 config: WebRtcHandlerConfig::default(),
383 }
384 }
385
386 pub fn signal_buffer_size(mut self, size: usize) -> Self {
388 self.config.signal_buffer_size = size;
389 self
390 }
391
392 pub fn media_buffer_size(mut self, size: usize) -> Self {
394 self.config.media_buffer_size = size;
395 self
396 }
397
398 pub fn data_buffer_size(mut self, size: usize) -> Self {
400 self.config.data_buffer_size = size;
401 self
402 }
403
404 pub fn build(
406 self,
407 ) -> (
408 WebRtcProtocolHandler,
409 mpsc::Receiver<WebRtcIncoming>,
410 mpsc::Receiver<WebRtcIncoming>,
411 mpsc::Receiver<WebRtcIncoming>,
412 ) {
413 WebRtcProtocolHandler::new(self.config)
414 }
415}
416
417impl Default for WebRtcProtocolHandlerBuilder {
418 fn default() -> Self {
419 Self::new()
420 }
421}
422
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    /// Encodes a data-channel frame: big-endian channel id, then the payload.
    fn data_frame(channel_id: u32, payload: &[u8]) -> Bytes {
        let mut frame = channel_id.to_be_bytes().to_vec();
        frame.extend_from_slice(payload);
        Bytes::from(frame)
    }

    #[tokio::test]
    async fn test_handler_stream_types() {
        let (handler, ..) = WebRtcProtocolHandler::with_defaults();
        let types = handler.stream_types();

        for wanted in [
            StreamType::WebRtcSignal,
            StreamType::WebRtcMedia,
            StreamType::WebRtcData,
        ] {
            assert!(types.contains(&wanted));
        }
        assert_eq!(types.len(), 3);
    }

    #[tokio::test]
    async fn test_handler_name() {
        let (handler, ..) = WebRtcProtocolHandler::with_defaults();
        assert_eq!(handler.name(), "WebRtcProtocolHandler");
    }

    #[tokio::test]
    async fn test_handle_signal_message() {
        let (handler, mut signal_rx, _, _) = WebRtcProtocolHandler::with_defaults();
        let peer = PeerId::from([1u8; 32]);

        let offer = SignalingMessage::Offer {
            session_id: "test-session".to_string(),
            sdp: "v=0\r\no=- 123 1 IN IP4 127.0.0.1\r\n".to_string(),
            quic_endpoint: None,
        };
        let encoded = Bytes::from(serde_json::to_vec(&offer).unwrap());

        let outcome = handler
            .handle_stream(peer, StreamType::WebRtcSignal, encoded)
            .await;
        assert!(outcome.is_ok());

        // The decoded message must already be queued.
        let queued = signal_rx.try_recv();
        assert!(queued.is_ok());

        match queued.unwrap() {
            WebRtcIncoming::Signal {
                peer: from,
                message,
            } => {
                assert_eq!(from, peer);
                assert_eq!(message.session_id(), "test-session");
            }
            _ => panic!("Expected Signal message"),
        }
    }

    #[tokio::test]
    async fn test_handle_media_packet() {
        let (handler, _, mut media_rx, _) = WebRtcProtocolHandler::with_defaults();
        let peer = PeerId::from([2u8; 32]);

        let rtp = RtpPacket::new(
            96,
            1000,
            12345,
            0xDEADBEEF,
            vec![1, 2, 3, 4],
            crate::quic_bridge::StreamType::Audio,
        )
        .unwrap();
        let wire = Bytes::from(rtp.to_bytes().unwrap());

        let outcome = handler
            .handle_stream(peer, StreamType::WebRtcMedia, wire)
            .await;
        assert!(outcome.is_ok());

        let queued = media_rx.try_recv();
        assert!(queued.is_ok());

        match queued.unwrap() {
            WebRtcIncoming::Media { peer: from, packet } => {
                assert_eq!(from, peer);
                assert_eq!(packet.sequence_number, 1000);
            }
            _ => panic!("Expected Media message"),
        }
    }

    #[tokio::test]
    async fn test_handle_data_channel() {
        let (handler, _, _, mut data_rx) = WebRtcProtocolHandler::with_defaults();
        let peer = PeerId::from([3u8; 32]);
        let payload = b"hello world";

        let outcome = handler
            .handle_stream(peer, StreamType::WebRtcData, data_frame(42, payload))
            .await;
        assert!(outcome.is_ok());

        let queued = data_rx.try_recv();
        assert!(queued.is_ok());

        match queued.unwrap() {
            WebRtcIncoming::Data {
                peer: from,
                channel_id,
                data,
            } => {
                assert_eq!(from, peer);
                assert_eq!(channel_id, 42);
                assert_eq!(&data[..], payload);
            }
            _ => panic!("Expected Data message"),
        }
    }

    #[tokio::test]
    async fn test_data_channel_too_short() {
        let (handler, ..) = WebRtcProtocolHandler::with_defaults();
        let peer = PeerId::from([4u8; 32]);

        // Three bytes cannot hold the 4-byte channel id header.
        let truncated = Bytes::from_static(&[1, 2, 3]);
        let outcome = handler
            .handle_stream(peer, StreamType::WebRtcData, truncated)
            .await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_session_tracking() {
        let (handler, ..) = WebRtcProtocolHandler::with_defaults();
        let peer = PeerId::from([5u8; 32]);

        assert_eq!(handler.session_count().await, 0);

        let _ = handler
            .handle_stream(peer, StreamType::WebRtcData, data_frame(1, b"test"))
            .await;

        assert_eq!(handler.session_count().await, 1);

        let stats = handler.get_session_stats(&peer).await;
        assert!(stats.is_some());
        let (message_count, channel_ids) = stats.unwrap();
        assert_eq!(message_count, 1);
        assert!(channel_ids.contains(&1));

        handler.remove_session(&peer).await;
        assert_eq!(handler.session_count().await, 0);
    }

    #[tokio::test]
    async fn test_shutdown() {
        let (handler, ..) = WebRtcProtocolHandler::with_defaults();

        assert!(handler.shutdown().await.is_ok());

        // Any stream after shutdown is rejected, regardless of payload.
        let peer = PeerId::from([6u8; 32]);
        let outcome = handler
            .handle_stream(peer, StreamType::WebRtcSignal, Bytes::new())
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_builder() {
        let (handler, ..) = WebRtcProtocolHandlerBuilder::new()
            .signal_buffer_size(128)
            .media_buffer_size(512)
            .data_buffer_size(256)
            .build();

        assert_eq!(handler.name(), "WebRtcProtocolHandler");
    }

    #[tokio::test]
    async fn test_invalid_signal_message() {
        let (handler, ..) = WebRtcProtocolHandler::with_defaults();
        let peer = PeerId::from([7u8; 32]);

        let outcome = handler
            .handle_stream(
                peer,
                StreamType::WebRtcSignal,
                Bytes::from_static(b"not valid json"),
            )
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_unexpected_stream_type() {
        let (handler, ..) = WebRtcProtocolHandler::with_defaults();
        let peer = PeerId::from([8u8; 32]);

        // Membership traffic does not belong to the WebRTC handler.
        let outcome = handler
            .handle_stream(peer, StreamType::Membership, Bytes::new())
            .await;
        assert!(outcome.is_err());
    }
}
650
651pub mod stream_routing {
655 use crate::link_transport::StreamType;
656
657 pub const AUDIO_PAYLOAD_TYPE_RANGE: (u8, u8) = (96, 127);
659
660 pub const VIDEO_PAYLOAD_TYPE_RANGE: (u8, u8) = (96, 127);
662
663 pub const RTCP_PAYLOAD_TYPE_RANGE: (u8, u8) = (200, 211);
665
666 pub fn is_rtp(payload_type: u8) -> bool {
676 payload_type < 128 || (96..=127).contains(&payload_type)
677 }
678
679 pub fn is_rtcp(payload_type: u8) -> bool {
689 (200..=211).contains(&payload_type)
690 }
691
692 pub fn is_audio_codec(payload_type: u8) -> bool {
702 matches!(
705 payload_type,
706 0 | 1
707 | 3
708 | 4
709 | 5
710 | 6
711 | 7
712 | 8
713 | 9
714 | 10
715 | 11
716 | 12
717 | 13
718 | 14
719 | 15
720 | 16
721 | 17
722 | 18
723 | 19
724 | 25
725 | 97
726 )
727 }
728
729 pub fn is_video_codec(payload_type: u8) -> bool {
739 matches!(
741 payload_type,
742 26 | 32 | 33 | 34 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105
743 )
744 }
745
746 pub fn route_to_stream(payload_type: u8) -> StreamType {
756 if is_rtcp(payload_type) {
757 StreamType::RtcpFeedback
758 } else if is_audio_codec(payload_type) {
759 StreamType::Audio
760 } else if is_video_codec(payload_type) {
761 StreamType::Video
762 } else {
763 StreamType::Video
765 }
766 }
767
768 pub fn rtcp_feedback_types() -> Vec<u8> {
774 (200..=211).collect()
775 }
776
777 #[cfg(test)]
778 mod routing_tests {
779 use super::*;
780
781 #[test]
782 fn test_is_rtp() {
783 assert!(is_rtp(0));
784 assert!(is_rtp(96));
785 assert!(is_rtp(127));
786 assert!(!is_rtp(200));
787 }
788
789 #[test]
790 fn test_is_rtcp() {
791 assert!(is_rtcp(200));
792 assert!(is_rtcp(205));
793 assert!(is_rtcp(211));
794 assert!(!is_rtcp(199));
795 assert!(!is_rtcp(212));
796 }
797
798 #[test]
799 fn test_is_audio_codec() {
800 assert!(is_audio_codec(0)); assert!(is_audio_codec(8)); assert!(is_audio_codec(97)); assert!(!is_audio_codec(26)); }
805
806 #[test]
807 fn test_is_video_codec() {
808 assert!(is_video_codec(26)); assert!(is_video_codec(32)); assert!(is_video_codec(96)); assert!(!is_video_codec(0)); }
813
814 #[test]
815 fn test_route_to_stream_audio() {
816 let stream = route_to_stream(0); assert_eq!(stream, StreamType::Audio);
818 }
819
820 #[test]
821 fn test_route_to_stream_video() {
822 let stream = route_to_stream(26); assert_eq!(stream, StreamType::Video);
824 }
825
826 #[test]
827 fn test_route_to_stream_rtcp() {
828 let stream = route_to_stream(200); assert_eq!(stream, StreamType::RtcpFeedback);
830 }
831
832 #[test]
833 fn test_rtcp_feedback_types() {
834 let types = rtcp_feedback_types();
835 assert_eq!(types.len(), 12);
836 assert_eq!(types[0], 200);
837 assert_eq!(types[11], 211);
838 }
839 }
840}
841
842impl WebRtcProtocolHandler {
843 pub fn route_packet_to_stream(payload: &[u8]) -> crate::link_transport::StreamType {
853 if payload.is_empty() {
854 return crate::link_transport::StreamType::Data;
855 }
856
857 if payload[1] >= 200 {
860 return crate::link_transport::StreamType::RtcpFeedback;
861 }
862
863 let pt = payload[1] & 0x7F;
865 stream_routing::route_to_stream(pt)
866 }
867
868 pub fn stream_media_type(stream_type: crate::link_transport::StreamType) -> &'static str {
878 match stream_type {
879 crate::link_transport::StreamType::Audio => "Audio RTP",
880 crate::link_transport::StreamType::Video => "Video RTP",
881 crate::link_transport::StreamType::Screen => "Screen Share RTP",
882 crate::link_transport::StreamType::RtcpFeedback => "RTCP Feedback",
883 crate::link_transport::StreamType::Data => "Data Channel",
884 }
885 }
886}
887
#[cfg(test)]
mod routing_integration_tests {
    use super::*;
    use crate::link_transport::StreamType;

    /// Shorthand for the routing function under test.
    fn route(bytes: &[u8]) -> StreamType {
        WebRtcProtocolHandler::route_packet_to_stream(bytes)
    }

    #[test]
    fn test_route_packet_audio_rtp() {
        // V=2, PT=0 (PCMU).
        assert_eq!(route(&[0x80, 0x00, 0x00, 0x01]), StreamType::Audio);
    }

    #[test]
    fn test_route_packet_video_rtp() {
        // V=2, PT=26 (JPEG).
        assert_eq!(route(&[0x80, 0x1A, 0x00, 0x01]), StreamType::Video);
    }

    #[test]
    fn test_route_packet_rtcp() {
        // Second byte 0xC8 = 200, the RTCP sender-report packet type.
        assert_eq!(route(&[0x80, 0xC8, 0x00, 0x01]), StreamType::RtcpFeedback);
    }

    #[test]
    fn test_route_packet_empty() {
        assert_eq!(route(&[]), StreamType::Data);
    }

    #[test]
    fn test_stream_media_type_descriptions() {
        let expectations = [
            (StreamType::Audio, "Audio RTP"),
            (StreamType::Video, "Video RTP"),
            (StreamType::Screen, "Screen Share RTP"),
            (StreamType::RtcpFeedback, "RTCP Feedback"),
            (StreamType::Data, "Data Channel"),
        ];

        for (stream, label) in expectations {
            assert_eq!(WebRtcProtocolHandler::stream_media_type(stream), label);
        }
    }
}