1use crate::message::{ParseError, SipMessage};
2use std::net::SocketAddr;
3use std::sync::Arc;
4use thiserror::Error;
5use tokio::net::UdpSocket;
6use tokio::sync::mpsc;
7
8#[derive(Debug, Error)]
9pub enum TransportError {
10 #[error("IO error: {0}")]
11 Io(#[from] std::io::Error),
12 #[error("Parse error: {0}")]
13 Parse(#[from] ParseError),
14 #[error("Transport not started")]
15 NotStarted,
16 #[error("Send failed: {0}")]
17 SendFailed(String),
18}
19
20#[derive(Debug, Clone)]
21pub struct IncomingMessage {
22 pub message: SipMessage,
23 pub source: SocketAddr,
24}
25
26pub struct SipTransport {
27 socket: Arc<UdpSocket>,
28 local_addr: SocketAddr,
29}
30
31impl SipTransport {
32 pub async fn bind(addr: &str) -> Result<Self, TransportError> {
34 let socket = UdpSocket::bind(addr).await?;
35 let local_addr = socket.local_addr()?;
36 Ok(Self {
37 socket: Arc::new(socket),
38 local_addr,
39 })
40 }
41
42 pub async fn bind_addr(addr: SocketAddr) -> Result<Self, TransportError> {
44 let socket = UdpSocket::bind(addr).await?;
45 let local_addr = socket.local_addr()?;
46 Ok(Self {
47 socket: Arc::new(socket),
48 local_addr,
49 })
50 }
51
52 pub fn local_addr(&self) -> SocketAddr {
54 self.local_addr
55 }
56
57 pub async fn send_to(
59 &self,
60 message: &SipMessage,
61 addr: SocketAddr,
62 ) -> Result<usize, TransportError> {
63 let data = message.to_bytes();
64 let sent = self.socket.send_to(&data, addr).await?;
65 tracing::debug!("Sent {} bytes to {}", sent, addr);
66 Ok(sent)
67 }
68
69 pub async fn send_raw(
71 &self,
72 data: &[u8],
73 addr: SocketAddr,
74 ) -> Result<usize, TransportError> {
75 let sent = self.socket.send_to(data, addr).await?;
76 Ok(sent)
77 }
78
79 pub async fn recv(&self) -> Result<IncomingMessage, TransportError> {
81 let mut buf = vec![0u8; 65535];
82 let (len, source) = self.socket.recv_from(&mut buf).await?;
83 let data = String::from_utf8_lossy(&buf[..len]);
84 let message = SipMessage::parse(&data)?;
85
86 Ok(IncomingMessage { message, source })
87 }
88
89 pub fn start_receiving(
91 &self,
92 buffer_size: usize,
93 ) -> (mpsc::Receiver<IncomingMessage>, mpsc::Sender<()>) {
94 let (tx, rx) = mpsc::channel(buffer_size);
95 let (stop_tx, mut stop_rx) = mpsc::channel::<()>(1);
96 let socket = self.socket.clone();
97
98 tokio::spawn(async move {
99 let mut buf = vec![0u8; 65535];
100 loop {
101 tokio::select! {
102 result = socket.recv_from(&mut buf) => {
103 match result {
104 Ok((len, source)) => {
105 let data = String::from_utf8_lossy(&buf[..len]);
106 match SipMessage::parse(&data) {
107 Ok(message) => {
108 if tx
109 .send(IncomingMessage {
110 message,
111 source,
112 })
113 .await
114 .is_err()
115 {
116 break; }
118 }
119 Err(e) => {
120 tracing::warn!("Failed to parse SIP message from {}: {}", source, e);
121 }
122 }
123 }
124 Err(e) => {
125 tracing::error!("UDP receive error: {}", e);
126 break;
127 }
128 }
129 }
130 _ = stop_rx.recv() => {
131 break;
132 }
133 }
134 }
135 });
136
137 (rx, stop_tx)
138 }
139
140 pub fn socket(&self) -> &Arc<UdpSocket> {
142 &self.socket
143 }
144}
145
/// Extracts the host and port from a `sip:` or `sips:` URI.
///
/// Parsing rules:
/// - The scheme is matched case-insensitively.
/// - Any `user[:password]@` prefix, URI parameters (`;...`) and URI headers
///   (`?...`) are ignored.
/// - IPv6 references may be bracketed (`[::1]:5060`). An unbracketed IPv6
///   address is returned whole with the default port.
/// - Without an explicit port, the scheme default applies: 5060 for `sip:` and
///   5061 for `sips:` (RFC 3261 §19.1.2).
///
/// Returns `None` for any other scheme, an empty host, or a port that is not
/// a valid `u16`.
pub fn parse_sip_uri(uri: &str) -> Option<(String, u16)> {
    let (scheme, rest) = uri.split_once(':')?;
    let default_port = if scheme.eq_ignore_ascii_case("sip") {
        5060
    } else if scheme.eq_ignore_ascii_case("sips") {
        5061
    } else {
        return None;
    };

    // Drop the userinfo first: the user part may itself contain ';' or '?'
    // (RFC 3261 user-unreserved), but it can never contain an unescaped '@'.
    let hostport = rest.split_once('@').map_or(rest, |(_, host)| host);

    // Cut off URI parameters and headers; `split` always yields a first item.
    let hostport = hostport.split(|c: char| c == ';' || c == '?').next()?;

    let (host, port) = if let Some(bracketed) = hostport.strip_prefix('[') {
        // IPv6 reference: `[addr]`, optionally followed by `:port`.
        let (host, after) = bracketed.split_once(']')?;
        let port = match after.strip_prefix(':') {
            Some(port_str) => port_str.parse().ok()?,
            None => default_port,
        };
        (host, port)
    } else {
        match hostport.rsplit_once(':') {
            // Several colons without brackets: treat it as a bare IPv6 address.
            Some((host, _)) if host.contains(':') => (hostport, default_port),
            Some((host, port_str)) => (host, port_str.parse().ok()?),
            None => (hostport, default_port),
        }
    };

    if host.is_empty() {
        return None;
    }
    Some((host.to_string(), port))
}
185
186pub fn resolve_sip_uri(uri: &str) -> Option<SocketAddr> {
188 let (host, port) = parse_sip_uri(uri)?;
189 if let Ok(ip) = host.parse::<std::net::IpAddr>() {
191 Some(SocketAddr::new(ip, port))
192 } else {
193 None
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::header::HeaderName;
    use std::time::Duration;

    /// Upper bound on how long a socket test waits for a loopback datagram.
    /// A lost packet then fails the test instead of hanging the suite.
    const RECV_TIMEOUT: Duration = Duration::from_secs(2);

    #[test]
    fn test_parse_sip_uri() {
        assert_eq!(
            parse_sip_uri("sip:bob@192.168.1.100:5060"),
            Some(("192.168.1.100".to_string(), 5060))
        );
        assert_eq!(
            parse_sip_uri("sip:bob@biloxi.com"),
            Some(("biloxi.com".to_string(), 5060))
        );
        assert_eq!(
            parse_sip_uri("sip:alice@10.0.0.1:5080"),
            Some(("10.0.0.1".to_string(), 5080))
        );
        assert_eq!(
            parse_sip_uri("sip:registrar.example.com"),
            Some(("registrar.example.com".to_string(), 5060))
        );
    }

    #[test]
    fn test_parse_sip_uri_with_params() {
        assert_eq!(
            parse_sip_uri("sip:bob@192.168.1.100:5060;transport=udp"),
            Some(("192.168.1.100".to_string(), 5060))
        );
    }

    #[test]
    fn test_parse_sip_uri_ipv6() {
        assert_eq!(
            parse_sip_uri("sip:bob@[::1]:5070"),
            Some(("::1".to_string(), 5070))
        );
        assert_eq!(
            parse_sip_uri("sip:[2001:db8::1];transport=udp"),
            Some(("2001:db8::1".to_string(), 5060))
        );
    }

    #[test]
    fn test_parse_sip_uri_invalid() {
        assert_eq!(parse_sip_uri("http://example.com"), None);
        assert_eq!(parse_sip_uri("not-a-uri"), None);
    }

    #[test]
    fn test_resolve_sip_uri_ip() {
        let addr = resolve_sip_uri("sip:bob@192.168.1.100:5060").unwrap();
        assert_eq!(addr.ip().to_string(), "192.168.1.100");
        assert_eq!(addr.port(), 5060);
    }

    #[test]
    fn test_resolve_sip_uri_ipv6() {
        let addr = resolve_sip_uri("sip:bob@[::1]:5070").unwrap();
        assert!(addr.is_ipv6());
        assert_eq!(addr.ip().to_string(), "::1");
        assert_eq!(addr.port(), 5070);
    }

    #[test]
    fn test_resolve_sip_uri_hostname() {
        // Hostnames are not resolved; only IP literals produce an address.
        assert!(resolve_sip_uri("sip:bob@biloxi.com").is_none());
    }

    #[tokio::test]
    async fn test_transport_bind() {
        let transport = SipTransport::bind("127.0.0.1:0").await.unwrap();
        let addr = transport.local_addr();
        assert_eq!(addr.ip().to_string(), "127.0.0.1");
        assert!(addr.port() > 0);
    }

    #[tokio::test]
    async fn test_transport_send_receive() {
        let t1 = SipTransport::bind("127.0.0.1:0").await.unwrap();
        let t2 = SipTransport::bind("127.0.0.1:0").await.unwrap();

        let request = crate::message::RequestBuilder::new(
            crate::message::SipMethod::Register,
            "sip:registrar.example.com",
        )
        .header(
            HeaderName::Via,
            format!(
                "SIP/2.0/UDP {};branch=z9hG4bKtest123",
                t1.local_addr()
            ),
        )
        .header(HeaderName::From, "<sip:alice@example.com>;tag=abc")
        .header(HeaderName::To, "<sip:alice@example.com>")
        .header(HeaderName::CallId, "test-transport-call")
        .header(HeaderName::CSeq, "1 REGISTER")
        .build();

        t1.send_to(&request, t2.local_addr()).await.unwrap();

        let incoming = tokio::time::timeout(RECV_TIMEOUT, t2.recv())
            .await
            .expect("timed out waiting for the datagram")
            .unwrap();

        assert!(incoming.message.is_request());
        assert_eq!(incoming.source, t1.local_addr());
        assert_eq!(incoming.message.call_id().unwrap(), "test-transport-call");
    }

    #[tokio::test]
    async fn test_transport_channel_receive() {
        let t1 = SipTransport::bind("127.0.0.1:0").await.unwrap();
        let t2 = SipTransport::bind("127.0.0.1:0").await.unwrap();

        // Keep the stop handle alive: dropping it ends the receive task.
        let (mut rx, _stop_tx) = t2.start_receiving(16);

        let request = crate::message::RequestBuilder::new(
            crate::message::SipMethod::Options,
            "sip:bob@example.com",
        )
        .header(
            HeaderName::Via,
            format!("SIP/2.0/UDP {};branch=z9hG4bKchan", t1.local_addr()),
        )
        .header(HeaderName::From, "<sip:alice@example.com>;tag=ch1")
        .header(HeaderName::To, "<sip:bob@example.com>")
        .header(HeaderName::CallId, "channel-test")
        .header(HeaderName::CSeq, "1 OPTIONS")
        .build();

        t1.send_to(&request, t2.local_addr()).await.unwrap();

        let incoming = tokio::time::timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("timed out waiting for the channel message")
            .expect("receive channel closed unexpectedly");

        assert_eq!(incoming.message.call_id().unwrap(), "channel-test");
    }
}