saorsa_webrtc/
signaling.rs

//! WebRTC signaling protocol
//!
//! Handles SDP offer/answer exchange and the exchange of ICE candidates for WebRTC connections.

5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::fmt;
8use std::net::SocketAddr;
9use std::str::FromStr;
10use thiserror::Error;
11
12/// Signaling errors
13#[derive(Error, Debug)]
14pub enum SignalingError {
15    /// Invalid SDP
16    #[error("Invalid SDP: {0}")]
17    InvalidSdp(String),
18
19    /// Session not found
20    #[error("Session not found: {0}")]
21    SessionNotFound(String),
22
23    /// Transport error
24    #[error("Transport error: {0}")]
25    TransportError(String),
26}
27
28/// Signaling transport trait
29///
30/// Implement this for your specific transport (DHT, gossip, etc.)
31#[async_trait]
32pub trait SignalingTransport: Send + Sync {
33    /// Peer identifier type
34    type PeerId: Clone + Send + Sync + fmt::Debug + fmt::Display + FromStr;
35
36    /// Transport error type
37    type Error: std::error::Error + Send + Sync + 'static;
38
39    /// Send a signaling message
40    async fn send_message(
41        &self,
42        peer: &Self::PeerId,
43        message: SignalingMessage,
44    ) -> Result<(), Self::Error>;
45
46    /// Receive a signaling message
47    async fn receive_message(&self) -> Result<(Self::PeerId, SignalingMessage), Self::Error>;
48
49    /// Discover peer endpoint
50    async fn discover_peer_endpoint(
51        &self,
52        peer: &Self::PeerId,
53    ) -> Result<Option<SocketAddr>, Self::Error>;
54}
55
56/// Signaling message types
57#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
58#[serde(tag = "type", rename_all = "lowercase")]
59pub enum SignalingMessage {
60    /// SDP offer
61    Offer {
62        /// Session ID
63        session_id: String,
64        /// SDP content
65        sdp: String,
66        /// Optional QUIC endpoint
67        quic_endpoint: Option<SocketAddr>,
68    },
69
70    /// SDP answer
71    Answer {
72        /// Session ID
73        session_id: String,
74        /// SDP content
75        sdp: String,
76        /// Optional QUIC endpoint
77        quic_endpoint: Option<SocketAddr>,
78    },
79
80    /// ICE candidate
81    IceCandidate {
82        /// Session ID
83        session_id: String,
84        /// Candidate string
85        candidate: String,
86        /// SDP mid
87        sdp_mid: Option<String>,
88        /// SDP mline index
89        sdp_mline_index: Option<u16>,
90    },
91
92    /// ICE gathering complete
93    IceComplete {
94        /// Session ID
95        session_id: String,
96    },
97
98    /// Close session
99    Bye {
100        /// Session ID
101        session_id: String,
102        /// Optional reason
103        reason: Option<String>,
104    },
105}
106
107impl SignalingMessage {
108    /// Get the session ID
109    #[must_use]
110    pub fn session_id(&self) -> &str {
111        match self {
112            Self::Offer { session_id, .. }
113            | Self::Answer { session_id, .. }
114            | Self::IceCandidate { session_id, .. }
115            | Self::IceComplete { session_id }
116            | Self::Bye { session_id, .. } => session_id,
117        }
118    }
119}
120
/// Signaling handler.
///
/// Thin front-end that delegates signaling operations to a shared transport.
/// The struct definition deliberately places no bound on `T`. The
/// `SignalingTransport` requirement belongs on the `impl` blocks that actually
/// call transport methods, so merely naming `SignalingHandler<T>` does not
/// force the bound on every user.
pub struct SignalingHandler<T> {
    /// Shared transport that every operation is delegated to.
    transport: std::sync::Arc<T>,
}
125
126impl<T: SignalingTransport> SignalingHandler<T> {
127    /// Create new signaling handler
128    #[must_use]
129    pub fn new(transport: std::sync::Arc<T>) -> Self {
130        Self { transport }
131    }
132
133    /// Send a signaling message to a peer
134    ///
135    /// # Errors
136    ///
137    /// Returns error if sending fails
138    pub async fn send_message(
139        &self,
140        peer: &T::PeerId,
141        message: SignalingMessage,
142    ) -> Result<(), T::Error> {
143        self.transport.send_message(peer, message).await
144    }
145
146    /// Receive a signaling message
147    ///
148    /// # Errors
149    ///
150    /// Returns error if receiving fails
151    pub async fn receive_message(&self) -> Result<(T::PeerId, SignalingMessage), T::Error> {
152        self.transport.receive_message().await
153    }
154
155    /// Discover endpoint for a peer
156    ///
157    /// # Errors
158    ///
159    /// Returns error if discovery fails
160    pub async fn discover_peer_endpoint(
161        &self,
162        peer: &T::PeerId,
163    ) -> Result<Option<std::net::SocketAddr>, T::Error> {
164        self.transport.discover_peer_endpoint(peer).await
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    /// In-memory transport backed by a single FIFO queue.
    ///
    /// `send_message` pushes onto the back and `receive_message` pops from the
    /// front, so anything sent through the handler can be read straight back.
    struct MockTransport {
        messages: Mutex<VecDeque<(String, SignalingMessage)>>,
    }

    /// Error reported by the mock when `receive_message` finds the queue empty.
    #[derive(Debug)]
    struct MockError;

    impl std::fmt::Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "Mock error")
        }
    }

    impl std::error::Error for MockError {}

    impl MockTransport {
        fn new() -> Self {
            Self {
                messages: Mutex::new(VecDeque::new()),
            }
        }

        /// Enqueue a message as if it had arrived from `peer`.
        fn add_message(&self, peer: String, message: SignalingMessage) {
            self.messages.lock().unwrap().push_back((peer, message));
        }
    }

    #[async_trait]
    impl SignalingTransport for MockTransport {
        type PeerId = String;
        type Error = MockError;

        async fn send_message(
            &self,
            peer: &String,
            message: SignalingMessage,
        ) -> Result<(), MockError> {
            self.messages.lock().unwrap().push_back((peer.clone(), message));
            Ok(())
        }

        async fn receive_message(&self) -> Result<(String, SignalingMessage), MockError> {
            // Bind first so the lock guard is released at the end of this
            // statement rather than living through the whole expression.
            let next = self.messages.lock().unwrap().pop_front();
            next.ok_or(MockError)
        }

        async fn discover_peer_endpoint(
            &self,
            _peer: &String,
        ) -> Result<Option<std::net::SocketAddr>, MockError> {
            Ok(Some("127.0.0.1:8080".parse().unwrap()))
        }
    }

    #[tokio::test]
    async fn test_signaling_handler_send_message() {
        let transport = Arc::new(MockTransport::new());
        let handler = SignalingHandler::new(Arc::clone(&transport));

        let message = SignalingMessage::Offer {
            session_id: "test-session".to_string(),
            sdp: "test-sdp".to_string(),
            quic_endpoint: None,
        };

        handler
            .send_message(&"peer1".to_string(), message.clone())
            .await
            .expect("mock send never fails");

        // The message must have been queued on the shared transport.
        let received = transport.messages.lock().unwrap().pop_front();
        assert_eq!(received, Some(("peer1".to_string(), message)));
    }

    #[tokio::test]
    async fn test_signaling_handler_receive_message() {
        let transport = Arc::new(MockTransport::new());
        let handler = SignalingHandler::new(Arc::clone(&transport));

        let message = SignalingMessage::Answer {
            session_id: "test-session".to_string(),
            sdp: "test-sdp".to_string(),
            quic_endpoint: None,
        };

        transport.add_message("peer1".to_string(), message.clone());

        let (peer, received_message) = handler
            .receive_message()
            .await
            .expect("a message was queued");
        assert_eq!(peer, "peer1");
        assert_eq!(received_message, message);
    }

    #[tokio::test]
    async fn test_signaling_handler_round_trip_preserves_order() {
        let transport = Arc::new(MockTransport::new());
        let handler = SignalingHandler::new(Arc::clone(&transport));

        let candidate = SignalingMessage::IceCandidate {
            session_id: "test-session".to_string(),
            candidate: "candidate:1 1 udp 1 192.0.2.1 5000 typ host".to_string(),
            sdp_mid: Some("0".to_string()),
            sdp_mline_index: Some(0),
        };
        let complete = SignalingMessage::IceComplete {
            session_id: "test-session".to_string(),
        };

        handler
            .send_message(&"peer1".to_string(), candidate.clone())
            .await
            .expect("mock send never fails");
        handler
            .send_message(&"peer2".to_string(), complete.clone())
            .await
            .expect("mock send never fails");

        let first = handler.receive_message().await.expect("first message queued");
        assert_eq!(first, ("peer1".to_string(), candidate));
        let second = handler.receive_message().await.expect("second message queued");
        assert_eq!(second, ("peer2".to_string(), complete));

        // The queue is now drained, so the error path must surface.
        assert!(handler.receive_message().await.is_err());
    }

    #[tokio::test]
    async fn test_signaling_handler_discover_endpoint() {
        let transport = Arc::new(MockTransport::new());
        let handler = SignalingHandler::new(transport);

        let endpoint = handler
            .discover_peer_endpoint(&"peer1".to_string())
            .await
            .expect("mock discovery never fails");
        assert_eq!(endpoint, Some("127.0.0.1:8080".parse().unwrap()));
    }

    #[test]
    fn test_session_id_covers_every_variant() {
        let messages = [
            SignalingMessage::Offer {
                session_id: "s1".to_string(),
                sdp: "offer-sdp".to_string(),
                quic_endpoint: Some("127.0.0.1:9000".parse().unwrap()),
            },
            SignalingMessage::Answer {
                session_id: "s2".to_string(),
                sdp: "answer-sdp".to_string(),
                quic_endpoint: None,
            },
            SignalingMessage::IceCandidate {
                session_id: "s3".to_string(),
                candidate: "candidate:1 1 udp 1 192.0.2.1 5000 typ host".to_string(),
                sdp_mid: Some("0".to_string()),
                sdp_mline_index: Some(0),
            },
            SignalingMessage::IceComplete {
                session_id: "s4".to_string(),
            },
            SignalingMessage::Bye {
                session_id: "s5".to_string(),
                reason: Some("done".to_string()),
            },
        ];

        let ids: Vec<&str> = messages.iter().map(SignalingMessage::session_id).collect();
        assert_eq!(ids, ["s1", "s2", "s3", "s4", "s5"]);
    }
}