Skip to main content

sip_core/
transport.rs

1use crate::message::{ParseError, SipMessage};
2use std::net::SocketAddr;
3use std::sync::Arc;
4use thiserror::Error;
5use tokio::net::UdpSocket;
6use tokio::sync::mpsc;
7
8#[derive(Debug, Error)]
9pub enum TransportError {
10    #[error("IO error: {0}")]
11    Io(#[from] std::io::Error),
12    #[error("Parse error: {0}")]
13    Parse(#[from] ParseError),
14    #[error("Transport not started")]
15    NotStarted,
16    #[error("Send failed: {0}")]
17    SendFailed(String),
18}
19
20#[derive(Debug, Clone)]
21pub struct IncomingMessage {
22    pub message: SipMessage,
23    pub source: SocketAddr,
24}
25
26pub struct SipTransport {
27    socket: Arc<UdpSocket>,
28    local_addr: SocketAddr,
29}
30
31impl SipTransport {
32    /// Bind to a local address for UDP transport
33    pub async fn bind(addr: &str) -> Result<Self, TransportError> {
34        let socket = UdpSocket::bind(addr).await?;
35        let local_addr = socket.local_addr()?;
36        Ok(Self {
37            socket: Arc::new(socket),
38            local_addr,
39        })
40    }
41
42    /// Bind to a specific socket address
43    pub async fn bind_addr(addr: SocketAddr) -> Result<Self, TransportError> {
44        let socket = UdpSocket::bind(addr).await?;
45        let local_addr = socket.local_addr()?;
46        Ok(Self {
47            socket: Arc::new(socket),
48            local_addr,
49        })
50    }
51
52    /// Get the local address this transport is bound to
53    pub fn local_addr(&self) -> SocketAddr {
54        self.local_addr
55    }
56
57    /// Send a SIP message to a specific address
58    pub async fn send_to(
59        &self,
60        message: &SipMessage,
61        addr: SocketAddr,
62    ) -> Result<usize, TransportError> {
63        let data = message.to_bytes();
64        let sent = self.socket.send_to(&data, addr).await?;
65        tracing::debug!("Sent {} bytes to {}", sent, addr);
66        Ok(sent)
67    }
68
69    /// Send raw bytes to a specific address
70    pub async fn send_raw(
71        &self,
72        data: &[u8],
73        addr: SocketAddr,
74    ) -> Result<usize, TransportError> {
75        let sent = self.socket.send_to(data, addr).await?;
76        Ok(sent)
77    }
78
79    /// Receive a single SIP message
80    pub async fn recv(&self) -> Result<IncomingMessage, TransportError> {
81        let mut buf = vec![0u8; 65535];
82        let (len, source) = self.socket.recv_from(&mut buf).await?;
83        let data = String::from_utf8_lossy(&buf[..len]);
84        let message = SipMessage::parse(&data)?;
85
86        Ok(IncomingMessage { message, source })
87    }
88
89    /// Start receiving messages into a channel
90    pub fn start_receiving(
91        &self,
92        buffer_size: usize,
93    ) -> (mpsc::Receiver<IncomingMessage>, mpsc::Sender<()>) {
94        let (tx, rx) = mpsc::channel(buffer_size);
95        let (stop_tx, mut stop_rx) = mpsc::channel::<()>(1);
96        let socket = self.socket.clone();
97
98        tokio::spawn(async move {
99            let mut buf = vec![0u8; 65535];
100            loop {
101                tokio::select! {
102                    result = socket.recv_from(&mut buf) => {
103                        match result {
104                            Ok((len, source)) => {
105                                let data = String::from_utf8_lossy(&buf[..len]);
106                                match SipMessage::parse(&data) {
107                                    Ok(message) => {
108                                        if tx
109                                            .send(IncomingMessage {
110                                                message,
111                                                source,
112                                            })
113                                            .await
114                                            .is_err()
115                                        {
116                                            break; // Channel closed
117                                        }
118                                    }
119                                    Err(e) => {
120                                        tracing::warn!("Failed to parse SIP message from {}: {}", source, e);
121                                    }
122                                }
123                            }
124                            Err(e) => {
125                                tracing::error!("UDP receive error: {}", e);
126                                break;
127                            }
128                        }
129                    }
130                    _ = stop_rx.recv() => {
131                        break;
132                    }
133                }
134            }
135        });
136
137        (rx, stop_tx)
138    }
139
140    /// Get the underlying socket reference (for advanced usage)
141    pub fn socket(&self) -> &Arc<UdpSocket> {
142        &self.socket
143    }
144}
145
/// Parse a SIP or SIPS URI into its host and port.
///
/// The scheme is matched case-insensitively, as URI schemes are. These parts
/// are ignored:
/// - any `user[:password]@` prefix;
/// - `;uri-parameters`;
/// - `?headers`.
///
/// IPv6 hosts may be bracketed (`[::1]:5070`); the brackets are stripped from
/// the returned host. A bare IPv6 literal, meaning more than one colon and no
/// brackets, is returned whole with the default port. When no port is given,
/// the scheme's default is used: 5060 for `sip:` and 5061 for `sips:`
/// (RFC 3261).
///
/// Returns `None` when:
/// - the scheme is not `sip`/`sips`;
/// - the host is empty;
/// - the port is not a valid `u16`;
/// - a bracketed IPv6 host is malformed.
pub fn parse_sip_uri(uri: &str) -> Option<(String, u16)> {
    const SIP_DEFAULT_PORT: u16 = 5060;
    const SIPS_DEFAULT_PORT: u16 = 5061;

    let (scheme, rest) = uri.split_once(':')?;
    let default_port = if scheme.eq_ignore_ascii_case("sip") {
        SIP_DEFAULT_PORT
    } else if scheme.eq_ignore_ascii_case("sips") {
        SIPS_DEFAULT_PORT
    } else {
        return None;
    };

    // Strip userinfo first. The user part may legally contain ';' and '?'
    // (e.g. `sip:alice;day=tuesday@atlanta.com`), so params and headers can
    // only be located reliably after the '@'.
    let host_port = rest.split_once('@').map_or(rest, |(_, host)| host);

    // Drop `;uri-parameters` and `?headers`.
    let host_port = host_port.split(|c: char| c == ';' || c == '?').next()?;

    let (host, port) = if let Some(bracketed) = host_port.strip_prefix('[') {
        // IPv6 reference: `[host]` or `[host]:port`; anything else after ']' is malformed.
        let (host, after) = bracketed.split_once(']')?;
        let port = if after.is_empty() {
            default_port
        } else {
            after.strip_prefix(':')?.parse().ok()?
        };
        (host, port)
    } else {
        match host_port.rsplit_once(':') {
            // Several colons without brackets: treat as a bare IPv6 literal.
            Some((head, _)) if head.contains(':') => (host_port, default_port),
            Some((head, port_str)) => (head, port_str.parse().ok()?),
            None => (host_port, default_port),
        }
    };

    if host.is_empty() {
        None
    } else {
        Some((host.to_string(), port))
    }
}
185
186/// Resolve a SIP URI to a socket address
187pub fn resolve_sip_uri(uri: &str) -> Option<SocketAddr> {
188    let (host, port) = parse_sip_uri(uri)?;
189    // For simplicity, try to parse as IP directly
190    if let Ok(ip) = host.parse::<std::net::IpAddr>() {
191        Some(SocketAddr::new(ip, port))
192    } else {
193        // DNS resolution would happen here in production
194        // For now, return None for hostnames
195        None
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::header::HeaderName;
    use crate::message::{RequestBuilder, SipMethod};
    use std::time::Duration;

    /// Assert `parse_sip_uri` over a table of `(input, expected)` cases.
    fn check_parse(cases: &[(&str, Option<(&str, u16)>)]) {
        for &(uri, expected) in cases {
            let expected = expected.map(|(host, port)| (host.to_string(), port));
            assert_eq!(parse_sip_uri(uri), expected, "parse_sip_uri({:?})", uri);
        }
    }

    #[test]
    fn test_parse_sip_uri() {
        check_parse(&[
            ("sip:bob@192.168.1.100:5060", Some(("192.168.1.100", 5060))),
            ("sip:bob@biloxi.com", Some(("biloxi.com", 5060))),
            ("sip:alice@10.0.0.1:5080", Some(("10.0.0.1", 5080))),
            ("sip:registrar.example.com", Some(("registrar.example.com", 5060))),
        ]);
    }

    #[test]
    fn test_parse_sip_uri_with_params() {
        check_parse(&[(
            "sip:bob@192.168.1.100:5060;transport=udp",
            Some(("192.168.1.100", 5060)),
        )]);
    }

    #[test]
    fn test_parse_sip_uri_invalid() {
        check_parse(&[("http://example.com", None), ("not-a-uri", None)]);
    }

    #[test]
    fn test_resolve_sip_uri_ip() {
        let resolved = resolve_sip_uri("sip:bob@192.168.1.100:5060").expect("IP literal resolves");
        assert_eq!(resolved.ip().to_string(), "192.168.1.100");
        assert_eq!(resolved.port(), 5060);
    }

    #[test]
    fn test_resolve_sip_uri_hostname() {
        // No DNS in this implementation, so hostnames never resolve.
        assert!(resolve_sip_uri("sip:bob@biloxi.com").is_none());
    }

    #[tokio::test]
    async fn test_transport_bind() {
        let transport = SipTransport::bind("127.0.0.1:0").await.unwrap();
        let bound = transport.local_addr();
        assert_eq!(bound.ip().to_string(), "127.0.0.1");
        assert!(bound.port() > 0);
    }

    #[tokio::test]
    async fn test_transport_send_receive() {
        let sender = SipTransport::bind("127.0.0.1:0").await.unwrap();
        let receiver = SipTransport::bind("127.0.0.1:0").await.unwrap();

        let via = format!(
            "SIP/2.0/UDP {};branch=z9hG4bKtest123",
            sender.local_addr()
        );
        let request = RequestBuilder::new(SipMethod::Register, "sip:registrar.example.com")
            .header(HeaderName::Via, via)
            .header(HeaderName::From, "<sip:alice@example.com>;tag=abc")
            .header(HeaderName::To, "<sip:alice@example.com>")
            .header(HeaderName::CallId, "test-transport-call")
            .header(HeaderName::CSeq, "1 REGISTER")
            .build();

        sender
            .send_to(&request, receiver.local_addr())
            .await
            .unwrap();

        let incoming = receiver.recv().await.unwrap();
        assert!(incoming.message.is_request());
        assert_eq!(incoming.source, sender.local_addr());
        assert_eq!(incoming.message.call_id().unwrap(), "test-transport-call");
    }

    #[tokio::test]
    async fn test_transport_channel_receive() {
        let sender = SipTransport::bind("127.0.0.1:0").await.unwrap();
        let receiver = SipTransport::bind("127.0.0.1:0").await.unwrap();

        // The stop handle must stay alive for the whole test: dropping it
        // stops the background receive task.
        let (mut inbox, _stop_tx) = receiver.start_receiving(16);

        let via = format!("SIP/2.0/UDP {};branch=z9hG4bKchan", sender.local_addr());
        let request = RequestBuilder::new(SipMethod::Options, "sip:bob@example.com")
            .header(HeaderName::Via, via)
            .header(HeaderName::From, "<sip:alice@example.com>;tag=ch1")
            .header(HeaderName::To, "<sip:bob@example.com>")
            .header(HeaderName::CallId, "channel-test")
            .header(HeaderName::CSeq, "1 OPTIONS")
            .build();

        sender
            .send_to(&request, receiver.local_addr())
            .await
            .unwrap();

        let incoming = tokio::time::timeout(Duration::from_secs(2), inbox.recv())
            .await
            .expect("timed out waiting for the message")
            .expect("receive channel closed unexpectedly");

        assert_eq!(incoming.message.call_id().unwrap(), "channel-test");
    }
}