1use std::fmt;
4
/// Error returned when a SIP `Via` header cannot be parsed.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum SipViaError {
    /// The header value was blank or produced no Via entries.
    Empty,
    /// The header value was malformed; the payload is a human-readable reason.
    InvalidFormat(String),
}

impl fmt::Display for SipViaError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SipViaError::Empty => f.write_str("Via header is empty"),
            SipViaError::InvalidFormat(reason) => write!(f, "Invalid Via format: {reason}"),
        }
    }
}

/// `SipViaError` carries no underlying source error, so the default trait methods suffice.
impl std::error::Error for SipViaError {}
25
/// One `sent-protocol sent-by *(;param)` element of a SIP `Via` header.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct SipViaEntry {
    /// Protocol name, e.g. `SIP`.
    protocol_name: String,
    /// Protocol version, e.g. `2.0`.
    protocol_version: String,
    /// Transport token, e.g. `UDP`, exactly as written in the header.
    transport: String,
    /// Sent-by host; bracketed IPv6 references are stored without brackets.
    host: String,
    /// Sent-by port, when one was given.
    port: Option<u16>,
    /// Parameters in header order; names are ASCII-lowercased, values verbatim.
    params: Vec<(String, Option<String>)>,
    /// Cached `rport` parameter: `None` when absent, `Some(None)` when present
    /// without a value, `Some(Some(port))` when it carries a port number.
    rport: Option<Option<u16>>,
}
38
39impl SipViaEntry {
40 pub fn protocol(&self) -> &str {
42 &self.protocol_name
43 }
44
45 pub fn version(&self) -> &str {
47 &self.protocol_version
48 }
49
50 pub fn transport(&self) -> &str {
52 &self.transport
53 }
54
55 pub fn host(&self) -> &str {
57 &self.host
58 }
59
60 pub fn port(&self) -> Option<u16> {
62 self.port
63 }
64
65 pub fn params(&self) -> &[(String, Option<String>)] {
67 &self.params
68 }
69
70 pub fn param(&self, key: &str) -> Option<Option<&str>> {
72 let key_lower = key.to_ascii_lowercase();
73 self.params
74 .iter()
75 .find(|(k, _)| k == &key_lower)
76 .map(|(_, v)| v.as_deref())
77 }
78
79 pub fn branch(&self) -> Option<&str> {
81 self.param("branch")
82 .flatten()
83 }
84
85 pub fn received(&self) -> Option<&str> {
87 self.param("received")
88 .flatten()
89 }
90
91 pub fn rport(&self) -> Option<Option<u16>> {
100 self.rport
101 }
102
103 fn parse(entry: &str) -> Result<Self, SipViaError> {
104 let trimmed = entry.trim();
105 if trimmed.is_empty() {
106 return Err(SipViaError::InvalidFormat("empty Via entry".to_string()));
107 }
108
109 let (main_part, params_part) = if let Some(semi_idx) = trimmed.find(';') {
111 (&trimmed[..semi_idx], Some(&trimmed[semi_idx + 1..]))
112 } else {
113 (trimmed, None)
114 };
115
116 let parts: Vec<&str> = main_part
118 .split_whitespace()
119 .collect();
120 if parts.len() != 2 {
121 return Err(SipViaError::InvalidFormat(format!(
122 "expected 'protocol/version/transport host[:port]', got '{}'",
123 main_part
124 )));
125 }
126
127 let sent_protocol = parts[0];
128 let sent_by = parts[1];
129
130 let protocol_parts: Vec<&str> = sent_protocol
132 .split('/')
133 .collect();
134 if protocol_parts.len() != 3 {
135 return Err(SipViaError::InvalidFormat(format!(
136 "expected 'protocol/version/transport', got '{}'",
137 sent_protocol
138 )));
139 }
140
141 let protocol_name = protocol_parts[0].to_string();
142 let protocol_version = protocol_parts[1].to_string();
143 let transport = protocol_parts[2].to_string();
144
145 let (host, port) = parse_host_port(sent_by)?;
148
149 let mut params = Vec::new();
151 if let Some(params_str) = params_part {
152 for param in params_str.split(';') {
153 let param = param.trim();
154 if param.is_empty() {
155 continue;
156 }
157
158 if let Some(eq_idx) = param.find('=') {
159 let key = param[..eq_idx]
160 .trim()
161 .to_ascii_lowercase();
162 let value = param[eq_idx + 1..]
163 .trim()
164 .to_string();
165 params.push((key, Some(value)));
166 } else {
167 params.push((param.to_ascii_lowercase(), None));
169 }
170 }
171 }
172
173 let rport = params
174 .iter()
175 .find(|(k, _)| k == "rport")
176 .map(|(_, v)| match v {
177 None => Ok(None),
178 Some(s) => s
179 .parse::<u16>()
180 .map(Some)
181 .map_err(|_| SipViaError::InvalidFormat(format!("invalid rport value: {s}"))),
182 })
183 .transpose()?;
184
185 Ok(Self {
186 protocol_name,
187 protocol_version,
188 transport,
189 host,
190 port,
191 params,
192 rport,
193 })
194 }
195}
196
197impl fmt::Display for SipViaEntry {
198 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199 write!(
200 f,
201 "{}/{}/{}",
202 self.protocol_name, self.protocol_version, self.transport
203 )?;
204
205 if self
207 .host
208 .contains(':')
209 && !self
210 .host
211 .starts_with('[')
212 {
213 write!(f, " [{}]", self.host)?;
214 } else {
215 write!(f, " {}", self.host)?;
216 }
217
218 if let Some(port) = self.port {
219 write!(f, ":{}", port)?;
220 }
221
222 for (key, value) in &self.params {
223 if let Some(val) = value {
224 write!(f, ";{}={}", key, val)?;
225 } else {
226 write!(f, ";{}", key)?;
227 }
228 }
229
230 Ok(())
231 }
232}
233
234#[derive(Debug, Clone, PartialEq, Eq)]
236#[non_exhaustive]
237pub struct SipVia {
238 entries: Vec<SipViaEntry>,
239}
240
241impl SipVia {
242 pub fn parse(raw: &str) -> Result<Self, SipViaError> {
244 let raw = raw.trim();
245 if raw.is_empty() {
246 return Err(SipViaError::Empty);
247 }
248
249 let parts = crate::split_comma_entries(raw);
250 let mut entries = Vec::new();
251
252 for part in parts {
253 entries.push(SipViaEntry::parse(part)?);
254 }
255
256 if entries.is_empty() {
257 return Err(SipViaError::Empty);
258 }
259
260 Ok(Self { entries })
261 }
262
263 pub fn entries(&self) -> &[SipViaEntry] {
265 &self.entries
266 }
267
268 pub fn into_entries(self) -> Vec<SipViaEntry> {
270 self.entries
271 }
272
273 pub fn len(&self) -> usize {
275 self.entries
276 .len()
277 }
278
279 pub fn is_empty(&self) -> bool {
281 self.entries
282 .is_empty()
283 }
284}
285
/// Formats the header by handing its entries to `crate::fmt_joined` with `", "`
/// as the separator.
///
/// NOTE(review): `fmt_joined` is defined elsewhere in the crate. It presumably
/// writes each entry via its `Display` impl with the separator between
/// entries; confirm against the helper.
impl fmt::Display for SipVia {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        crate::fmt_joined(f, &self.entries, ", ")
    }
}
291
// Provides `FromStr` for `SipVia`; the tests rely on `str::parse` producing a `SipVia`.
// The macro is defined elsewhere in the crate. It presumably forwards to
// `SipVia::parse` with `SipViaError` as the error type; confirm against its definition.
impl_from_str_via_parse!(SipVia, SipViaError);
293
294impl IntoIterator for SipVia {
295 type Item = SipViaEntry;
296 type IntoIter = std::vec::IntoIter<SipViaEntry>;
297
298 fn into_iter(self) -> Self::IntoIter {
299 self.entries
300 .into_iter()
301 }
302}
303
304impl<'a> IntoIterator for &'a SipVia {
305 type Item = &'a SipViaEntry;
306 type IntoIter = std::slice::Iter<'a, SipViaEntry>;
307
308 fn into_iter(self) -> Self::IntoIter {
309 self.entries
310 .iter()
311 }
312}
313
314fn parse_host_port(sent_by: &str) -> Result<(String, Option<u16>), SipViaError> {
315 if sent_by.starts_with('[') {
317 if let Some(close_bracket) = sent_by.find(']') {
319 let host = sent_by[1..close_bracket].to_string();
320 let remainder = &sent_by[close_bracket + 1..];
321
322 if remainder.is_empty() {
323 return Ok((host, None));
324 }
325
326 if let Some(port_str) = remainder.strip_prefix(':') {
327 let port = port_str
328 .parse::<u16>()
329 .map_err(|_| {
330 SipViaError::InvalidFormat(format!("invalid port: {}", port_str))
331 })?;
332 return Ok((host, Some(port)));
333 }
334
335 return Err(SipViaError::InvalidFormat(format!(
336 "unexpected characters after IPv6 address: {}",
337 remainder
338 )));
339 } else {
340 return Err(SipViaError::InvalidFormat(
341 "unclosed IPv6 bracket".to_string(),
342 ));
343 }
344 }
345
346 if let Some(colon_idx) = sent_by.rfind(':') {
349 let host = sent_by[..colon_idx].to_string();
350 let port_str = &sent_by[colon_idx + 1..];
351
352 if host.contains(':') {
354 return Ok((sent_by.to_string(), None));
356 }
357
358 let port = port_str
359 .parse::<u16>()
360 .map_err(|_| SipViaError::InvalidFormat(format!("invalid port: {}", port_str)))?;
361 Ok((host, Some(port)))
362 } else {
363 Ok((sent_by.to_string(), None))
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    /// A two-entry header (UDP with port, TCP without) shared by several tests.
    const UDP_AND_TCP: &str = "SIP/2.0/UDP 198.51.100.1:5060, SIP/2.0/TCP 203.0.113.5";

    /// Parses `raw`, panicking on failure, and returns a copy of its first entry.
    fn first_entry(raw: &str) -> SipViaEntry {
        SipVia::parse(raw).unwrap().entries()[0].clone()
    }

    #[test]
    fn test_single_via() {
        let via = SipVia::parse("SIP/2.0/UDP 198.51.100.1:5060").unwrap();
        assert_eq!(via.len(), 1);

        let entry = &via.entries()[0];
        assert_eq!(
            (entry.protocol(), entry.version(), entry.transport()),
            ("SIP", "2.0", "UDP")
        );
        assert_eq!((entry.host(), entry.port()), ("198.51.100.1", Some(5060)));
        assert!(entry.params().is_empty());
    }

    #[test]
    fn test_multiple_vias() {
        let via = SipVia::parse(UDP_AND_TCP).unwrap();
        assert_eq!(via.len(), 2);

        let expected: [(&str, Option<u16>, &str); 2] = [
            ("198.51.100.1", Some(5060), "UDP"),
            ("203.0.113.5", None, "TCP"),
        ];
        for (entry, (host, port, transport)) in via.entries().iter().zip(expected) {
            assert_eq!(entry.host(), host);
            assert_eq!(entry.port(), port);
            assert_eq!(entry.transport(), transport);
        }
    }

    #[test]
    fn test_via_with_params() {
        let entry = first_entry(
            "SIP/2.0/UDP 198.51.100.1:5060;branch=z9hG4bKnashds8;received=203.0.113.10;rport=5061",
        );
        assert_eq!(entry.branch(), Some("z9hG4bKnashds8"));
        assert_eq!(entry.received(), Some("203.0.113.10"));
        assert_eq!(entry.rport(), Some(Some(5061)));
    }

    #[test]
    fn test_via_with_rport_no_value() {
        assert_eq!(first_entry("SIP/2.0/UDP 198.51.100.1:5060;rport").rport(), Some(None));
    }

    #[test]
    fn test_via_without_rport() {
        assert_eq!(first_entry("SIP/2.0/UDP 198.51.100.1:5060").rport(), None);
    }

    #[test]
    fn test_via_ipv6() {
        let entry = first_entry("SIP/2.0/UDP [2001:db8::1]:5060");
        assert_eq!((entry.host(), entry.port()), ("2001:db8::1", Some(5060)));
    }

    #[test]
    fn test_via_ipv6_no_port() {
        let entry = first_entry("SIP/2.0/UDP [2001:db8::1]");
        assert_eq!((entry.host(), entry.port()), ("2001:db8::1", None));
    }

    #[test]
    fn test_via_hostname() {
        let entry = first_entry("SIP/2.0/TLS example.com:5061");
        assert_eq!(
            (entry.host(), entry.port(), entry.transport()),
            ("example.com", Some(5061), "TLS")
        );
    }

    #[test]
    fn test_empty_via() {
        assert!(matches!(SipVia::parse(""), Err(SipViaError::Empty)));
    }

    #[test]
    fn test_empty_via_whitespace() {
        assert!(matches!(SipVia::parse(" "), Err(SipViaError::Empty)));
    }

    #[test]
    fn test_invalid_format() {
        assert!(matches!(SipVia::parse("invalid"), Err(SipViaError::InvalidFormat(_))));
    }

    #[test]
    fn test_rport_invalid_value_is_error() {
        assert!(SipVia::parse("SIP/2.0/UDP 198.51.100.1:5060;rport=garbage").is_err());
    }

    #[test]
    fn test_display_roundtrip() {
        let via = SipVia::parse(
            "SIP/2.0/UDP 198.51.100.1:5060;branch=z9hG4bKnashds8;received=203.0.113.10;rport",
        )
        .unwrap();
        let reparsed = SipVia::parse(&via.to_string()).unwrap();
        assert_eq!(via, reparsed);
    }

    #[test]
    fn test_display_multiple_vias() {
        let displayed = SipVia::parse(UDP_AND_TCP).unwrap().to_string();
        for host in ["198.51.100.1", "203.0.113.5"] {
            assert!(displayed.contains(host));
        }
    }

    #[test]
    fn test_into_iterator() {
        let via = SipVia::parse(UDP_AND_TCP).unwrap();

        let mut seen = 0;
        for entry in &via {
            assert!(matches!(entry.host(), "198.51.100.1" | "203.0.113.5"));
            seen += 1;
        }
        assert_eq!(seen, 2);

        let owned: Vec<SipViaEntry> = via.into_iter().collect();
        assert_eq!(owned.len(), 2);
    }

    #[test]
    fn test_into_entries() {
        let entries = SipVia::parse(UDP_AND_TCP).unwrap().into_entries();
        let hosts: Vec<&str> = entries.iter().map(SipViaEntry::host).collect();
        assert_eq!(hosts, ["198.51.100.1", "203.0.113.5"]);
    }

    #[test]
    fn test_from_str() {
        let via = "SIP/2.0/UDP 198.51.100.1:5060".parse::<SipVia>().unwrap();
        assert_eq!(via.len(), 1);
    }

    #[test]
    fn test_param_case_insensitive() {
        let entry = first_entry("SIP/2.0/UDP 198.51.100.1:5060;Branch=test");
        for key in ["branch", "BRANCH"] {
            assert_eq!(entry.param(key), Some(Some("test")));
        }
    }

    #[test]
    fn test_display_ipv6() {
        let displayed = SipVia::parse("SIP/2.0/UDP [2001:db8::1]:5060").unwrap().to_string();
        assert!(displayed.contains("[2001:db8::1]"));
    }
}