1use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::fmt;
8use std::net::SocketAddr;
9use std::str::FromStr;
10use thiserror::Error;
11
12#[derive(Error, Debug)]
14pub enum SignalingError {
15 #[error("Invalid SDP: {0}")]
17 InvalidSdp(String),
18
19 #[error("Session not found: {0}")]
21 SessionNotFound(String),
22
23 #[error("Transport error: {0}")]
25 TransportError(String),
26}
27
28#[async_trait]
32pub trait SignalingTransport: Send + Sync {
33 type PeerId: Clone + Send + Sync + fmt::Debug + fmt::Display + FromStr;
35
36 type Error: std::error::Error + Send + Sync + 'static;
38
39 async fn send_message(
41 &self,
42 peer: &Self::PeerId,
43 message: SignalingMessage,
44 ) -> Result<(), Self::Error>;
45
46 async fn receive_message(&self) -> Result<(Self::PeerId, SignalingMessage), Self::Error>;
48
49 async fn discover_peer_endpoint(
51 &self,
52 peer: &Self::PeerId,
53 ) -> Result<Option<SocketAddr>, Self::Error>;
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
58#[serde(tag = "type", rename_all = "lowercase")]
59pub enum SignalingMessage {
60 Offer {
62 session_id: String,
64 sdp: String,
66 quic_endpoint: Option<SocketAddr>,
68 },
69
70 Answer {
72 session_id: String,
74 sdp: String,
76 quic_endpoint: Option<SocketAddr>,
78 },
79
80 IceCandidate {
82 session_id: String,
84 candidate: String,
86 sdp_mid: Option<String>,
88 sdp_mline_index: Option<u16>,
90 },
91
92 IceComplete {
94 session_id: String,
96 },
97
98 Bye {
100 session_id: String,
102 reason: Option<String>,
104 },
105}
106
107impl SignalingMessage {
108 #[must_use]
110 pub fn session_id(&self) -> &str {
111 match self {
112 Self::Offer { session_id, .. }
113 | Self::Answer { session_id, .. }
114 | Self::IceCandidate { session_id, .. }
115 | Self::IceComplete { session_id }
116 | Self::Bye { session_id, .. } => session_id,
117 }
118 }
119}
120
121pub struct SignalingHandler<T: SignalingTransport> {
123 transport: std::sync::Arc<T>,
124}
125
126impl<T: SignalingTransport> SignalingHandler<T> {
127 #[must_use]
129 pub fn new(transport: std::sync::Arc<T>) -> Self {
130 Self { transport }
131 }
132
133 pub async fn send_message(
139 &self,
140 peer: &T::PeerId,
141 message: SignalingMessage,
142 ) -> Result<(), T::Error> {
143 self.transport.send_message(peer, message).await
144 }
145
146 pub async fn receive_message(&self) -> Result<(T::PeerId, SignalingMessage), T::Error> {
152 self.transport.receive_message().await
153 }
154
155 pub async fn discover_peer_endpoint(
161 &self,
162 peer: &T::PeerId,
163 ) -> Result<Option<std::net::SocketAddr>, T::Error> {
164 self.transport.discover_peer_endpoint(peer).await
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    /// In-memory transport backed by a single FIFO queue.
    ///
    /// `send_message` enqueues and `receive_message` dequeues from the same
    /// queue, so sent messages loop back to the receiver.
    struct MockTransport {
        messages: Mutex<VecDeque<(String, SignalingMessage)>>,
    }

    /// Error returned by [`MockTransport`] when its queue is empty.
    #[derive(Debug)]
    struct MockError;

    impl std::fmt::Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "Mock error")
        }
    }

    impl std::error::Error for MockError {}

    impl MockTransport {
        fn new() -> Self {
            Self {
                messages: Mutex::new(VecDeque::new()),
            }
        }

        /// Pre-loads a message as if it had arrived from `peer`.
        fn add_message(&self, peer: String, message: SignalingMessage) {
            self.messages.lock().unwrap().push_back((peer, message));
        }
    }

    #[async_trait]
    impl SignalingTransport for MockTransport {
        type PeerId = String;
        type Error = MockError;

        async fn send_message(
            &self,
            peer: &String,
            message: SignalingMessage,
        ) -> Result<(), MockError> {
            self.messages.lock().unwrap().push_back((peer.clone(), message));
            Ok(())
        }

        async fn receive_message(&self) -> Result<(String, SignalingMessage), MockError> {
            // Bind first so the lock guard is released before returning;
            // an empty queue is reported as an error instead of blocking.
            let next = self.messages.lock().unwrap().pop_front();
            next.ok_or(MockError)
        }

        async fn discover_peer_endpoint(
            &self,
            _peer: &String,
        ) -> Result<Option<std::net::SocketAddr>, MockError> {
            Ok(Some("127.0.0.1:8080".parse().unwrap()))
        }
    }

    #[tokio::test]
    async fn test_signaling_handler_send_message() {
        let transport = Arc::new(MockTransport::new());
        let handler = SignalingHandler::new(transport.clone());

        let message = SignalingMessage::Offer {
            session_id: "test-session".to_string(),
            sdp: "test-sdp".to_string(),
            quic_endpoint: None,
        };

        handler
            .send_message(&"peer1".to_string(), message.clone())
            .await
            .expect("mock send never fails");

        let received = transport.messages.lock().unwrap().pop_front();
        assert_eq!(received, Some(("peer1".to_string(), message)));
    }

    #[tokio::test]
    async fn test_signaling_handler_receive_message() {
        let transport = Arc::new(MockTransport::new());
        let handler = SignalingHandler::new(transport.clone());

        let message = SignalingMessage::Answer {
            session_id: "test-session".to_string(),
            sdp: "test-sdp".to_string(),
            quic_endpoint: None,
        };

        transport.add_message("peer1".to_string(), message.clone());

        let (peer, received_message) = handler
            .receive_message()
            .await
            .expect("a message was queued before receiving");
        assert_eq!(peer, "peer1");
        assert_eq!(received_message, message);
    }

    #[tokio::test]
    async fn test_signaling_handler_receive_message_empty_queue() {
        let handler = SignalingHandler::new(Arc::new(MockTransport::new()));

        assert!(handler.receive_message().await.is_err());
    }

    #[tokio::test]
    async fn test_signaling_handler_discover_endpoint() {
        let transport = Arc::new(MockTransport::new());
        let handler = SignalingHandler::new(transport);

        let endpoint = handler
            .discover_peer_endpoint(&"peer1".to_string())
            .await
            .expect("mock discovery never fails");
        assert_eq!(endpoint, Some("127.0.0.1:8080".parse().unwrap()));
    }

    #[test]
    fn test_session_id_covers_every_variant() {
        let messages = [
            SignalingMessage::Offer {
                session_id: "s-offer".to_string(),
                sdp: String::new(),
                quic_endpoint: None,
            },
            SignalingMessage::Answer {
                session_id: "s-answer".to_string(),
                sdp: String::new(),
                quic_endpoint: None,
            },
            SignalingMessage::IceCandidate {
                session_id: "s-ice".to_string(),
                candidate: String::new(),
                sdp_mid: None,
                sdp_mline_index: None,
            },
            SignalingMessage::IceComplete {
                session_id: "s-complete".to_string(),
            },
            SignalingMessage::Bye {
                session_id: "s-bye".to_string(),
                reason: None,
            },
        ];

        let ids: Vec<&str> = messages.iter().map(SignalingMessage::session_id).collect();
        assert_eq!(ids, ["s-offer", "s-answer", "s-ice", "s-complete", "s-bye"]);
    }
}