1use std::collections::HashSet;
2use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
3use std::sync::Arc;
4use std::time::Duration;
5
6use dns_lookup::lookup_host;
7use get_if_addrs::{IfAddr, get_if_addrs};
8use socket2::{Domain, Protocol, Socket, Type};
9use tokio::io::AsyncWriteExt;
10use tokio::net::UdpSocket;
11use tracing::debug;
12
13use crate::auth::{default_authnz_host, default_authnz_user};
14use crate::transport::read_packet;
15use crate::types::{PvGetError, PvGetOptions};
16use spvirit_codec::epics_decode::{PvaPacket, PvaPacketCommand};
17use spvirit_codec::spvirit_encode::{
18 encode_client_connection_validation, encode_search_request, ip_to_bytes,
19 socket_addr_from_pva_bytes,
20};
21
/// One UDP search destination paired with the local address that the
/// sending socket binds to.
#[derive(Debug, Clone, Copy)]
pub struct SearchTarget {
    /// Destination IP of the search datagram (unicast, broadcast or multicast).
    pub target: IpAddr,
    /// Local IP the sending socket is bound to; an unspecified address
    /// (`0.0.0.0` / `::`) leaves the choice to the operating system.
    pub bind: IpAddr,
}
27
/// A PVA server found by a discovery search.
///
/// Two entries are equal only when both the GUID and the TCP address match.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct DiscoveredServer {
    /// GUID reported in the server's search response.
    pub guid: [u8; 12],
    /// TCP endpoint the server advertised (or derived from the UDP source).
    pub tcp_addr: SocketAddr,
}
33
34pub fn parse_addr_list(env: &str) -> Vec<IpAddr> {
35 env.split(|c| c == ',' || c == ' ' || c == '\t')
36 .filter(|s| !s.trim().is_empty())
37 .filter_map(|s| parse_search_target_ip(s.trim()))
38 .collect()
39}
40
41fn parse_search_target_ip(token: &str) -> Option<IpAddr> {
42 if token.is_empty() {
43 return None;
44 }
45
46 if let Ok(ip) = token.parse::<IpAddr>() {
47 return Some(ip);
48 }
49 if let Ok(sock) = token.parse::<SocketAddr>() {
50 return Some(sock.ip());
51 }
52
53 if let Some((host, port_str)) = token.rsplit_once(':') {
56 if !host.is_empty()
57 && !port_str.is_empty()
58 && port_str.chars().all(|c| c.is_ascii_digit())
59 && !host.contains(']')
60 {
61 if let Ok(ip) = host.parse::<IpAddr>() {
62 return Some(ip);
63 }
64 if let Ok(addrs) = lookup_host(host) {
65 let addrs: Vec<IpAddr> = addrs.collect();
67 if let Some(ip) = addrs
68 .iter()
69 .find(|ip| ip.is_ipv4())
70 .copied()
71 .or_else(|| addrs.into_iter().next())
72 {
73 return Some(ip);
74 }
75 }
76 }
77 }
78
79 if let Ok(addrs) = lookup_host(token) {
80 let addrs: Vec<IpAddr> = addrs.collect();
82 if let Some(ip) = addrs
83 .iter()
84 .find(|ip| ip.is_ipv4())
85 .copied()
86 .or_else(|| addrs.into_iter().next())
87 {
88 return Some(ip);
89 }
90 }
91
92 None
93}
94
/// Returns the unspecified ("any") address in the same family as `ip`:
/// `0.0.0.0` for IPv4, `::` for IPv6.
fn unspecified_for(ip: IpAddr) -> IpAddr {
    if ip.is_ipv4() {
        Ipv4Addr::UNSPECIFIED.into()
    } else {
        Ipv6Addr::UNSPECIFIED.into()
    }
}
102
103pub fn build_search_targets(
104 search_addr: Option<IpAddr>,
105 bind_addr: Option<IpAddr>,
106) -> Vec<SearchTarget> {
107 if let Some(ip) = search_addr {
109 return vec![SearchTarget {
110 target: ip,
111 bind: bind_addr.unwrap_or_else(|| unspecified_for(ip)),
112 }];
113 }
114
115 let mut targets = Vec::new();
116 let mut seen = HashSet::new();
117
118 if let Ok(env) = std::env::var("EPICS_PVA_ADDR_LIST") {
120 for ip in parse_addr_list(&env) {
121 if seen.insert(ip) {
122 targets.push(SearchTarget {
123 target: ip,
124 bind: bind_addr.unwrap_or_else(|| unspecified_for(ip)),
125 });
126 }
127 }
128 }
129
130 if is_auto_addr_list_enabled() {
133 for t in build_auto_broadcast_targets() {
134 if seen.insert(t.target) {
135 targets.push(SearchTarget {
136 target: t.target,
137 bind: bind_addr.unwrap_or(t.bind),
138 });
139 }
140 }
141 }
142
143 targets
144}
145
/// Reports whether the automatic (interface-derived) search address list is
/// enabled via `EPICS_PVA_AUTO_ADDR_LIST`.
///
/// An unset variable (or one that is not valid Unicode) enables it.
/// Otherwise only `YES`, `Y`, `1` or `TRUE` enable it, compared
/// case-insensitively after trimming surrounding whitespace; any other value
/// disables it.
pub fn is_auto_addr_list_enabled() -> bool {
    let Ok(raw) = std::env::var("EPICS_PVA_AUTO_ADDR_LIST") else {
        return true;
    };
    matches!(
        raw.trim().to_ascii_uppercase().as_str(),
        "YES" | "Y" | "1" | "TRUE"
    )
}
155
/// Returns `true` when `ip` lies in the IPv4 link-local block 169.254.0.0/16.
fn ipv4_is_link_local(ip: Ipv4Addr) -> bool {
    ip.is_link_local()
}
160
161fn choose_default_bind_v4() -> Option<Ipv4Addr> {
162 let ifaces = get_if_addrs().ok()?;
163 for iface in ifaces {
164 if let IfAddr::V4(v4) = iface.addr {
165 let ip = v4.ip;
166 if ip.is_loopback() || ipv4_is_link_local(ip) {
167 continue;
168 }
169 return Some(ip);
170 }
171 }
172 None
173}
174
175fn choose_default_bind_v6() -> Option<Ipv6Addr> {
176 let ifaces = get_if_addrs().ok()?;
177 for iface in ifaces {
178 if let IfAddr::V6(v6) = iface.addr {
179 let ip = v6.ip;
180 if ip.is_loopback() {
181 continue;
182 }
183 let segs = ip.segments();
185 if segs[0] & 0xffc0 == 0xfe80 {
186 continue;
187 }
188 return Some(ip);
189 }
190 }
191 None
192}
193
/// Computes the directed broadcast address of the subnet containing `ip`:
/// every host bit (zero bit of `netmask`) is set to one.
fn broadcast_for(ip: Ipv4Addr, netmask: Ipv4Addr) -> Ipv4Addr {
    let host = ip.octets();
    let mask = netmask.octets();
    Ipv4Addr::new(
        host[0] | !mask[0],
        host[1] | !mask[1],
        host[2] | !mask[2],
        host[3] | !mask[3],
    )
}
199
200fn discovery_target_for(ip: Ipv4Addr, netmask: Ipv4Addr) -> Ipv4Addr {
201 let limited_broadcast = Ipv4Addr::new(255, 255, 255, 255);
202 if netmask == Ipv4Addr::new(255, 255, 255, 255) || netmask.is_unspecified() {
203 return limited_broadcast;
204 }
205 let directed = broadcast_for(ip, netmask);
206 if directed == ip {
207 limited_broadcast
208 } else {
209 directed
210 }
211}
212
213pub fn build_auto_broadcast_targets() -> Vec<SearchTarget> {
214 let mut targets = Vec::new();
215 let mut fallback_targets = Vec::new();
216 let mut fallback_seen = HashSet::new();
217 let mut added_v4_multicast = false;
218 let mut added_v6_multicast = false;
219 let ifaces = match get_if_addrs() {
220 Ok(v) => v,
221 Err(_) => return targets,
222 };
223 for iface in &ifaces {
224 if let IfAddr::V4(v4) = &iface.addr {
225 let ip = v4.ip;
226 if ip.is_loopback() || ipv4_is_link_local(ip) {
227 continue;
228 }
229 let bcast = discovery_target_for(ip, v4.netmask);
230 targets.push(SearchTarget {
231 target: IpAddr::V4(bcast),
232 bind: IpAddr::V4(ip),
233 });
234 targets.push(SearchTarget {
237 target: IpAddr::V4(PVA_MULTICAST_V4),
238 bind: IpAddr::V4(ip),
239 });
240 if fallback_seen.insert(IpAddr::V4(bcast)) {
241 fallback_targets.push(SearchTarget {
242 target: IpAddr::V4(bcast),
243 bind: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
244 });
245 }
246 if !added_v4_multicast {
247 added_v4_multicast = true;
248 fallback_targets.push(SearchTarget {
249 target: IpAddr::V4(PVA_MULTICAST_V4),
250 bind: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
251 });
252 }
253 }
254 }
255 for iface in &ifaces {
257 if let IfAddr::V6(v6) = &iface.addr {
258 let ip = v6.ip;
259 if ip.is_loopback() {
260 continue;
261 }
262 let segs = ip.segments();
263 if segs[0] & 0xffc0 == 0xfe80 {
264 continue; }
266 let multicast_target = IpAddr::V6(PVA_MULTICAST_V6);
267 targets.push(SearchTarget {
268 target: multicast_target,
269 bind: IpAddr::V6(ip),
270 });
271 if !added_v6_multicast {
272 added_v6_multicast = true;
273 fallback_targets.push(SearchTarget {
274 target: multicast_target,
275 bind: IpAddr::V6(Ipv6Addr::UNSPECIFIED),
276 });
277 }
278 }
279 }
280 targets.extend(fallback_targets);
281 targets
282}
283
/// IPv4 multicast group (224.0.0.128) this module sends searches to and joins.
const PVA_MULTICAST_V4: Ipv4Addr = Ipv4Addr::new(0xE0, 0x00, 0x00, 0x80);

/// IPv6 multicast group (ff02::42:1, link-local scope) this module sends
/// searches to and joins.
const PVA_MULTICAST_V6: Ipv6Addr = Ipv6Addr::new(0xFF02, 0x0, 0x0, 0x0, 0x0, 0x0, 0x42, 0x1);
289
290fn join_multicast_any(socket: &std::net::UdpSocket, bind: IpAddr) {
292 match bind {
293 IpAddr::V4(iface) => {
294 let _ = socket.join_multicast_v4(&PVA_MULTICAST_V4, &iface);
295 }
296 IpAddr::V6(_) => {
297 let _ = socket.join_multicast_v6(&PVA_MULTICAST_V6, 0);
299 }
300 }
301}
302
303fn decode_search_response_addr(addr: [u8; 16], port: u16, src: SocketAddr) -> SocketAddr {
304 socket_addr_from_pva_bytes(addr, port)
305 .filter(|a| !a.ip().is_unspecified())
306 .unwrap_or_else(|| SocketAddr::new(src.ip(), port))
307}
308
309fn normalize_discovered_servers(items: Vec<DiscoveredServer>) -> Vec<DiscoveredServer> {
310 let mut seen = HashSet::new();
311 let mut out = Vec::new();
312 for item in items {
313 if seen.insert((item.guid, item.tcp_addr)) {
314 out.push(item);
315 }
316 }
317 out.sort_by(|a, b| a.tcp_addr.to_string().cmp(&b.tcp_addr.to_string()));
318 out
319}
320
321fn bind_udp_reuse(addr: SocketAddr) -> std::io::Result<std::net::UdpSocket> {
328 let domain = if addr.is_ipv4() {
329 Domain::IPV4
330 } else {
331 Domain::IPV6
332 };
333 let sock = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
334 #[cfg(unix)]
335 sock.set_reuse_address(true)?;
336 sock.set_nonblocking(true)?;
337 sock.bind(&addr.into())?;
338 Ok(sock.into())
339}
340
341pub async fn search_pv(
342 pv_name: &str,
343 udp_port: u16,
344 timeout_dur: Duration,
345 targets: &[SearchTarget],
346 debug_enabled: bool,
347) -> Result<SocketAddr, PvGetError> {
348 if targets.is_empty() {
349 return Err(PvGetError::Search("no search targets"));
350 }
351
352 let now = std::time::SystemTime::now()
353 .duration_since(std::time::UNIX_EPOCH)
354 .unwrap_or_default();
355 let seq = (now.as_nanos() as u32).wrapping_add(std::process::id());
356 let cid = seq ^ 0x9E37_79B9;
357
358 let mut last_io_error: Option<std::io::Error> = None;
359 let deadline = tokio::time::Instant::now() + timeout_dur;
360
361 let mut bind_groups: Vec<(IpAddr, Vec<IpAddr>)> = Vec::new();
363 for t in targets {
364 if let Some(group) = bind_groups.iter_mut().find(|(b, _)| *b == t.bind) {
365 group.1.push(t.target);
366 } else {
367 bind_groups.push((t.bind, vec![t.target]));
368 }
369 }
370
371 let mut socket_info: Vec<(Arc<UdpSocket>, Vec<u8>, Vec<SocketAddr>)> = Vec::new();
374
375 for (bind_ip, group_targets) in &bind_groups {
376 let bind_addr = SocketAddr::new(*bind_ip, udp_port);
377 let (std_sock, actual_bind_addr) = match bind_udp_reuse(bind_addr) {
378 Ok(sock) => (sock, bind_addr),
379 Err(err) if err.kind() == std::io::ErrorKind::AddrInUse => {
380 let fallback = SocketAddr::new(*bind_ip, 0);
381 match bind_udp_reuse(fallback) {
382 Ok(sock) => {
383 let actual = sock.local_addr().unwrap_or(fallback);
384 if debug_enabled {
385 debug!(
386 "pva search bind={} failed (in use), fallback bind={}",
387 bind_addr, actual
388 );
389 }
390 (sock, actual)
391 }
392 Err(fallback_err) => {
393 if debug_enabled {
394 debug!(
395 "pva search skipping bind={} step=bind-fallback kind={:?} err={}",
396 bind_addr,
397 fallback_err.kind(),
398 fallback_err
399 );
400 }
401 last_io_error = Some(fallback_err);
402 continue;
403 }
404 }
405 }
406 Err(err) => {
407 if debug_enabled {
408 debug!(
409 "pva search skipping bind={} step=bind kind={:?} err={}",
410 bind_addr,
411 err.kind(),
412 err
413 );
414 }
415 last_io_error = Some(err);
416 continue;
417 }
418 };
419 if let Err(err) = std_sock.set_broadcast(true) {
420 if debug_enabled {
421 debug!(
422 "pva search skipping bind={} step=set_broadcast kind={:?} err={}",
423 bind_addr,
424 err.kind(),
425 err
426 );
427 }
428 last_io_error = Some(err);
429 continue;
430 }
431
432 join_multicast_any(&std_sock, *bind_ip);
433
434 let reply_addr = ip_to_bytes(*bind_ip);
435 let reply_port = match std_sock.local_addr() {
436 Ok(addr) => addr.port(),
437 Err(err) => {
438 if debug_enabled {
439 debug!(
440 "pva search skipping bind={} step=local_addr kind={:?} err={}",
441 bind_addr,
442 err.kind(),
443 err
444 );
445 }
446 last_io_error = Some(err);
447 continue;
448 }
449 };
450 let requests = [(cid, pv_name)];
451 let msg = encode_search_request(seq, 0x81, reply_port, reply_addr, &requests, 2, false);
452
453 let socket = match UdpSocket::from_std(std_sock) {
454 Ok(socket) => socket,
455 Err(err) => {
456 if debug_enabled {
457 debug!(
458 "pva search skipping bind={} step=from_std kind={:?} err={}",
459 bind_addr,
460 err.kind(),
461 err
462 );
463 }
464 last_io_error = Some(err);
465 continue;
466 }
467 };
468
469 let dests: Vec<SocketAddr> = group_targets
470 .iter()
471 .map(|ip| SocketAddr::new(*ip, udp_port))
472 .collect();
473
474 for dest in &dests {
476 if debug_enabled {
477 debug!(
478 "pva search bind={} target={} server_port={} reply_port={}",
479 actual_bind_addr,
480 dest.ip(),
481 udp_port,
482 reply_port
483 );
484 debug!("pva search seq={} cid={}", seq, cid);
485 debug!("pva search send {} bytes to {}", msg.len(), dest);
486 }
487 if let Err(err) = socket.send_to(&msg, dest).await {
488 if debug_enabled {
489 debug!(
490 "pva search send_to target={} kind={:?} err={}",
491 dest,
492 err.kind(),
493 err
494 );
495 }
496 last_io_error = Some(err);
497 }
498 }
499
500 socket_info.push((Arc::new(socket), msg, dests));
501 }
502
503 if socket_info.is_empty() {
504 if let Some(err) = last_io_error {
505 return Err(PvGetError::Io(err));
506 }
507 return Err(PvGetError::Timeout("search response"));
508 }
509
510 let (tx, mut rx) = tokio::sync::mpsc::channel::<(Vec<u8>, SocketAddr)>(64);
512 for (sock, _, _) in &socket_info {
513 let sock = Arc::clone(sock);
514 let tx = tx.clone();
515 tokio::spawn(async move {
516 loop {
517 let mut buf = vec![0u8; 2048];
518 match sock.recv_from(&mut buf).await {
519 Ok((len, src)) => {
520 buf.truncate(len);
521 if tx.send((buf, src)).await.is_err() {
522 break;
523 }
524 }
525 Err(_) => break,
526 }
527 }
528 });
529 }
530 drop(tx); let retransmit_offsets = [100u64, 500, 1000, 2000];
534 let start = tokio::time::Instant::now();
535 let mut next_retransmit = 0usize;
536
537 loop {
538 let next_retransmit_at = if next_retransmit < retransmit_offsets.len() {
540 start + Duration::from_millis(retransmit_offsets[next_retransmit])
541 } else {
542 deadline
543 };
544 let wake_at = next_retransmit_at.min(deadline);
545
546 tokio::select! {
547 recv = rx.recv() => {
548 let Some((buf, src)) = recv else { break };
549 let mut pkt = PvaPacket::new(&buf);
550 let cmd = pkt
551 .decode_payload()
552 .ok_or(PvGetError::Search("failed to decode search response"))?;
553 if let PvaPacketCommand::SearchResponse(payload) = cmd {
554 if debug_enabled {
555 debug!(
556 "pva search response found={} cids={:?} addr={:?} port={}",
557 payload.found, payload.cids, payload.addr, payload.port
558 );
559 }
560 if payload.seq != seq {
561 continue;
562 }
563 if !payload.protocol.is_empty() && !payload.protocol.eq_ignore_ascii_case("tcp") {
564 continue;
565 }
566 if !payload.found {
567 continue;
568 }
569 if !payload.cids.is_empty() && !payload.cids.contains(&cid) {
570 continue;
571 }
572
573 let addr = decode_search_response_addr(payload.addr, payload.port, src);
574 if debug_enabled {
575 debug!("pva search response from {}", addr);
576 }
577 return Ok(addr);
578 }
579 }
580 _ = tokio::time::sleep_until(wake_at) => {
581 if tokio::time::Instant::now() >= deadline {
582 break;
583 }
584 if next_retransmit < retransmit_offsets.len() {
586 if debug_enabled {
587 debug!("pva search retransmit round {}", next_retransmit + 1);
588 }
589 for (sock, msg, dests) in &socket_info {
590 for dest in dests {
591 let _ = sock.send_to(msg, dest).await;
592 }
593 }
594 next_retransmit += 1;
595 }
596 }
597 }
598 }
599
600 Err(PvGetError::Timeout("search response"))
601}
602
603pub fn default_bind_ip() -> Option<IpAddr> {
604 choose_default_bind_v4()
605 .map(IpAddr::V4)
606 .or_else(|| choose_default_bind_v6().map(IpAddr::V6))
607}
608
/// Parses an `EPICS_PVA_NAME_SERVERS`-style list into TCP socket addresses.
///
/// Entries are separated by commas, spaces or tabs and trimmed. For each
/// entry the first of these that succeeds is used:
/// 1. parse as a socket address (`ip:port`, `[v6]:port`);
/// 2. parse as a bare IP, using port 5075;
/// 3. resolve as `host:port` and take the first result;
/// 4. resolve as `entry:5075` and take the first result.
///
/// Entries for which all of these fail are skipped.
pub fn parse_name_servers(env_val: &str) -> Vec<SocketAddr> {
    /// First address `spec` resolves to, or `None` on error / no results.
    fn first_resolved(spec: &str) -> Option<SocketAddr> {
        use std::net::ToSocketAddrs;
        spec.to_socket_addrs().ok()?.next()
    }

    env_val
        .split(|c| c == ',' || c == ' ' || c == '\t')
        .map(str::trim)
        .filter(|entry| !entry.is_empty())
        .filter_map(|entry| {
            entry
                .parse::<SocketAddr>()
                .ok()
                .or_else(|| {
                    entry
                        .parse::<IpAddr>()
                        .ok()
                        .map(|ip| SocketAddr::new(ip, 5075))
                })
                .or_else(|| first_resolved(entry))
                .or_else(|| first_resolved(&format!("{}:5075", entry)))
        })
        .collect()
}
643
644fn encode_search_validation(version: u8, is_be: bool) -> Vec<u8> {
646 let user = default_authnz_user();
647 let host = default_authnz_host();
648 encode_client_connection_validation(87_040, 32_767, 0, "ca", &user, &host, version, is_be)
649}
650
651pub async fn search_pv_tcp(
656 pv_name: &str,
657 name_server: SocketAddr,
658 timeout_dur: Duration,
659 debug_enabled: bool,
660) -> Result<SocketAddr, PvGetError> {
661 let deadline = tokio::time::Instant::now() + timeout_dur;
662
663 let mut stream = tokio::time::timeout(timeout_dur, tokio::net::TcpStream::connect(name_server))
664 .await
665 .map_err(|_| PvGetError::Timeout("name server connect"))??;
666
667 let mut version = 2u8;
668 let mut is_be = false;
669
670 for _ in 0..2 {
672 let now = tokio::time::Instant::now();
673 if now >= deadline {
674 return Err(PvGetError::Timeout("name server handshake"));
675 }
676 let remaining = deadline - now;
677 if let Ok(bytes) = read_packet(&mut stream, remaining).await {
678 let mut pkt = PvaPacket::new(&bytes);
679 if let Some(cmd) = pkt.decode_payload() {
680 match cmd {
681 PvaPacketCommand::Control(payload) => {
682 if payload.command == 2 {
683 is_be = pkt.header.flags.is_msb;
684 }
685 }
686 PvaPacketCommand::ConnectionValidation(_) => {
687 version = pkt.header.version;
688 is_be = pkt.header.flags.is_msb;
689 }
690 _ => {}
691 }
692 }
693 }
694 }
695
696 let validation = encode_search_validation(version, is_be);
697 stream.write_all(&validation).await?;
698
699 loop {
701 let now = tokio::time::Instant::now();
702 if now >= deadline {
703 return Err(PvGetError::Timeout("name server validated"));
704 }
705 let remaining = deadline - now;
706 let bytes = read_packet(&mut stream, remaining).await?;
707 let mut pkt = PvaPacket::new(&bytes);
708 if let Some(cmd) = pkt.decode_payload() {
709 if matches!(cmd, PvaPacketCommand::ConnectionValidated(_)) {
710 break;
711 }
712 }
713 }
714
715 let now_ts = std::time::SystemTime::now()
717 .duration_since(std::time::UNIX_EPOCH)
718 .unwrap_or_default();
719 let seq = (now_ts.as_nanos() as u32).wrapping_add(std::process::id());
720 let cid = seq ^ 0x9E37_79B9;
721 let requests = [(cid, pv_name)];
722 let msg = encode_search_request(seq, 0x80, 0, [0u8; 16], &requests, version, is_be);
723 stream.write_all(&msg).await?;
724
725 if debug_enabled {
726 debug!(
727 "pva tcp search sent to name_server={} pv={}",
728 name_server, pv_name
729 );
730 }
731
732 loop {
734 let now = tokio::time::Instant::now();
735 if now >= deadline {
736 return Err(PvGetError::Timeout("name server search response"));
737 }
738 let remaining = deadline - now;
739 let bytes = read_packet(&mut stream, remaining).await?;
740 let mut pkt = PvaPacket::new(&bytes);
741 if let Some(cmd) = pkt.decode_payload() {
742 if let PvaPacketCommand::SearchResponse(payload) = cmd {
743 if !payload.found {
744 continue;
745 }
746 if !payload.cids.is_empty() && !payload.cids.contains(&cid) {
747 continue;
748 }
749 let addr = decode_search_response_addr(payload.addr, payload.port, name_server);
750 if debug_enabled {
751 debug!(
752 "pva tcp search response from name_server={}: {}",
753 name_server, addr
754 );
755 }
756 return Ok(addr);
757 }
758 }
759 }
760}
761
762pub async fn resolve_pv_server(opts: &PvGetOptions) -> Result<SocketAddr, PvGetError> {
769 if let Some(addr) = opts.server_addr {
770 return Ok(addr);
771 }
772
773 let mut name_servers = opts.name_servers.clone();
774 if let Ok(env) = std::env::var("EPICS_PVA_NAME_SERVERS") {
775 name_servers.extend(parse_name_servers(&env));
776 }
777
778 let no_broadcast = opts.no_broadcast;
779
780 if no_broadcast && name_servers.is_empty() {
782 return Err(PvGetError::Search(
783 "no search strategy: specify --name-server or --server when using --no-broadcast",
784 ));
785 }
786
787 let targets = build_search_targets(opts.search_addr, opts.bind_addr);
790
791 let pv = opts.pv_name.clone();
792 let timeout_dur = opts.timeout;
793 let debug_enabled = opts.debug;
794 let udp_port = opts.udp_port;
795
796 let mut set = tokio::task::JoinSet::new();
797
798 for ns in name_servers {
799 let pv = pv.clone();
800 set.spawn(async move {
801 let addr = search_pv_tcp(&pv, ns, timeout_dur, debug_enabled).await?;
802 Ok::<SocketAddr, PvGetError>(addr)
803 });
804 }
805
806 if !no_broadcast {
807 let pv = pv.clone();
808 let targets = targets.clone();
809 set.spawn(async move {
810 let addr = search_pv(&pv, udp_port, timeout_dur, &targets, debug_enabled).await?;
811 Ok(addr)
812 });
813 }
814
815 let mut last_err = None;
816 while let Some(result) = set.join_next().await {
817 match result {
818 Ok(Ok(addr)) => {
819 set.abort_all();
820 return Ok(addr);
821 }
822 Ok(Err(e)) => {
823 if debug_enabled {
824 debug!("pva search strategy failed: {}", e);
825 }
826 last_err = Some(e);
827 }
828 Err(join_err) => {
829 if debug_enabled {
830 debug!("pva search task panicked: {}", join_err);
831 }
832 }
833 }
834 }
835
836 Err(last_err.unwrap_or(PvGetError::Timeout("search response")))
837}
838
839pub async fn discover_servers(
840 udp_port: u16,
841 timeout_dur: Duration,
842 targets: &[SearchTarget],
843 debug_enabled: bool,
844) -> Result<Vec<DiscoveredServer>, PvGetError> {
845 if targets.is_empty() {
846 return Err(PvGetError::Search("no search targets"));
847 }
848
849 let now = std::time::SystemTime::now()
850 .duration_since(std::time::UNIX_EPOCH)
851 .unwrap_or_default();
852 let seq = (now.as_nanos() as u32).wrapping_add(std::process::id());
853
854 let mut found: Vec<DiscoveredServer> = Vec::new();
855 let mut last_io_error: Option<std::io::Error> = None;
856 let deadline = tokio::time::Instant::now() + timeout_dur;
857
858 let mut bind_groups: Vec<(IpAddr, Vec<IpAddr>)> = Vec::new();
860 for t in targets {
861 if let Some(group) = bind_groups.iter_mut().find(|(b, _)| *b == t.bind) {
862 group.1.push(t.target);
863 } else {
864 bind_groups.push((t.bind, vec![t.target]));
865 }
866 }
867
868 let mut socket_info: Vec<(Arc<UdpSocket>, Vec<u8>, Vec<SocketAddr>)> = Vec::new();
871
872 for (bind_ip, group_targets) in &bind_groups {
873 let bind_addr = SocketAddr::new(*bind_ip, udp_port);
874 let (std_sock, actual_bind_addr) = match bind_udp_reuse(bind_addr) {
875 Ok(sock) => (sock, bind_addr),
876 Err(err) if err.kind() == std::io::ErrorKind::AddrInUse => {
877 let fallback = SocketAddr::new(*bind_ip, 0);
878 match bind_udp_reuse(fallback) {
879 Ok(sock) => {
880 let actual = sock.local_addr().unwrap_or(fallback);
881 if debug_enabled {
882 debug!(
883 "pva discover bind={} failed (in use), fallback bind={}",
884 bind_addr, actual
885 );
886 }
887 (sock, actual)
888 }
889 Err(fallback_err) => {
890 if debug_enabled {
891 debug!(
892 "pva discover skipping bind={} step=bind-fallback kind={:?} err={}",
893 bind_addr,
894 fallback_err.kind(),
895 fallback_err
896 );
897 }
898 last_io_error = Some(fallback_err);
899 continue;
900 }
901 }
902 }
903 Err(err) => {
904 if debug_enabled {
905 debug!(
906 "pva discover skipping bind={} step=bind kind={:?} err={}",
907 bind_addr,
908 err.kind(),
909 err
910 );
911 }
912 last_io_error = Some(err);
913 continue;
914 }
915 };
916 if let Err(err) = std_sock.set_broadcast(true) {
917 if debug_enabled {
918 debug!(
919 "pva discover skipping bind={} step=set_broadcast kind={:?} err={}",
920 bind_addr,
921 err.kind(),
922 err
923 );
924 }
925 last_io_error = Some(err);
926 continue;
927 }
928
929 join_multicast_any(&std_sock, *bind_ip);
930
931 let reply_addr = ip_to_bytes(*bind_ip);
932 let reply_port = match std_sock.local_addr() {
933 Ok(addr) => addr.port(),
934 Err(err) => {
935 if debug_enabled {
936 debug!(
937 "pva discover skipping bind={} step=local_addr kind={:?} err={}",
938 bind_addr,
939 err.kind(),
940 err
941 );
942 }
943 last_io_error = Some(err);
944 continue;
945 }
946 };
947 let msg = encode_search_request(seq, 0x81, reply_port, reply_addr, &[], 2, false);
948
949 let socket = match UdpSocket::from_std(std_sock) {
950 Ok(socket) => socket,
951 Err(err) => {
952 if debug_enabled {
953 debug!(
954 "pva discover skipping bind={} step=from_std kind={:?} err={}",
955 bind_addr,
956 err.kind(),
957 err
958 );
959 }
960 last_io_error = Some(err);
961 continue;
962 }
963 };
964
965 let dests: Vec<SocketAddr> = group_targets
966 .iter()
967 .map(|ip| SocketAddr::new(*ip, udp_port))
968 .collect();
969
970 for dest in &dests {
972 if debug_enabled {
973 debug!(
974 "pva discover bind={} target={} server_port={} reply_port={} seq={}",
975 actual_bind_addr,
976 dest.ip(),
977 udp_port,
978 reply_port,
979 seq
980 );
981 }
982 if let Err(err) = socket.send_to(&msg, dest).await {
983 if debug_enabled {
984 debug!(
985 "pva discover send_to target={} kind={:?} err={}",
986 dest,
987 err.kind(),
988 err
989 );
990 }
991 last_io_error = Some(err);
992 }
993 }
994
995 socket_info.push((Arc::new(socket), msg, dests));
996 }
997
998 if socket_info.is_empty() {
999 if let Some(err) = last_io_error {
1000 return Err(PvGetError::Io(err));
1001 }
1002 return Err(PvGetError::Search("no search targets"));
1003 }
1004
1005 let (tx, mut rx) = tokio::sync::mpsc::channel::<(Vec<u8>, SocketAddr)>(64);
1007 for (sock, _, _) in &socket_info {
1008 let sock = Arc::clone(sock);
1009 let tx = tx.clone();
1010 tokio::spawn(async move {
1011 loop {
1012 let mut buf = vec![0u8; 2048];
1013 match sock.recv_from(&mut buf).await {
1014 Ok((len, src)) => {
1015 buf.truncate(len);
1016 if tx.send((buf, src)).await.is_err() {
1017 break;
1018 }
1019 }
1020 Err(_) => break,
1021 }
1022 }
1023 });
1024 }
1025 drop(tx); let retransmit_offsets = [100u64, 500, 1000, 2000];
1029 let start = tokio::time::Instant::now();
1030 let mut next_retransmit = 0usize;
1031
1032 loop {
1033 let next_retransmit_at = if next_retransmit < retransmit_offsets.len() {
1035 start + Duration::from_millis(retransmit_offsets[next_retransmit])
1036 } else {
1037 deadline
1038 };
1039 let wake_at = next_retransmit_at.min(deadline);
1040
1041 tokio::select! {
1042 recv = rx.recv() => {
1043 let Some((buf, src)) = recv else { break };
1044 let mut pkt = PvaPacket::new(&buf);
1045 let Some(cmd) = pkt.decode_payload() else {
1046 continue;
1047 };
1048 if let PvaPacketCommand::SearchResponse(payload) = cmd {
1049 if payload.seq != seq {
1050 continue;
1051 }
1052 if !payload.protocol.is_empty() && !payload.protocol.eq_ignore_ascii_case("tcp") {
1053 continue;
1054 }
1055 let tcp_addr = decode_search_response_addr(payload.addr, payload.port, src);
1056 found.push(DiscoveredServer {
1057 guid: payload.guid,
1058 tcp_addr,
1059 });
1060 }
1061 }
1062 _ = tokio::time::sleep_until(wake_at) => {
1063 if tokio::time::Instant::now() >= deadline {
1064 break;
1065 }
1066 if next_retransmit < retransmit_offsets.len() {
1068 if debug_enabled {
1069 debug!("pva discover retransmit round {}", next_retransmit + 1);
1070 }
1071 for (sock, msg, dests) in &socket_info {
1072 for dest in dests {
1073 let _ = sock.send_to(msg, dest).await;
1074 }
1075 }
1076 next_retransmit += 1;
1077 }
1078 }
1079 }
1080 }
1081
1082 Ok(normalize_discovered_servers(found))
1083}
1084
#[cfg(test)]
mod tests {
    //! Unit tests for the pure (network-free) helpers of the search module.

    use super::*;
    use spvirit_codec::epics_decode::{PvaPacket, PvaPacketCommand};

    #[test]
    fn encode_decode_search_request_roundtrip() {
        let seq = 1234;
        let cid = 42;
        let port = 5076;
        let pv_name = "TEST:PV";
        let reply_addr = ip_to_bytes(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 20)));
        let requests = [(cid, pv_name)];
        let msg = encode_search_request(seq, 0x81, port, reply_addr, &requests, 2, false);
        let mut pkt = PvaPacket::new(&msg);
        let cmd = pkt.decode_payload().expect("decoded");
        match cmd {
            PvaPacketCommand::Search(payload) => {
                assert_eq!(payload.seq, seq);
                assert_eq!(payload.mask, 0x81);
                assert_eq!(payload.addr, reply_addr);
                assert_eq!(payload.port, port);
                assert_eq!(payload.protocols, vec!["tcp".to_string()]);
                assert_eq!(payload.pv_requests.len(), 1);
                assert_eq!(payload.pv_requests[0].0, cid);
                assert_eq!(payload.pv_requests[0].1, pv_name.to_string());
            }
            other => panic!("unexpected decode: {:?}", other),
        }
    }

    #[test]
    fn encode_decode_server_discovery_request_roundtrip() {
        let seq = 4321;
        let port = 5076;
        let reply_addr = ip_to_bytes(IpAddr::V4(Ipv4Addr::new(10, 20, 30, 40)));
        let msg = encode_search_request(seq, 0x81, port, reply_addr, &[], 2, false);
        let mut pkt = PvaPacket::new(&msg);
        let cmd = pkt.decode_payload().expect("decoded");
        match cmd {
            PvaPacketCommand::Search(payload) => {
                assert_eq!(payload.seq, seq);
                assert_eq!(payload.pv_requests.len(), 0);
                assert_eq!(payload.protocols, vec!["tcp".to_string()]);
            }
            other => panic!("unexpected decode: {:?}", other),
        }
    }

    #[test]
    fn normalize_discovered_servers_deduplicates_by_guid_and_addr() {
        let guid = [1u8; 12];
        let s1 = DiscoveredServer {
            guid,
            tcp_addr: "127.0.0.1:5075".parse().unwrap(),
        };
        let s2 = DiscoveredServer {
            guid,
            tcp_addr: "127.0.0.1:5075".parse().unwrap(),
        };
        let s3 = DiscoveredServer {
            guid: [2u8; 12],
            tcp_addr: "127.0.0.1:5075".parse().unwrap(),
        };
        let normalized = normalize_discovered_servers(vec![s1, s2, s3]);
        assert_eq!(normalized.len(), 2);
    }

    /// Ordering is lexicographic on the address text and stable for ties.
    #[test]
    fn normalize_discovered_servers_orders_by_address_text() {
        let a = DiscoveredServer {
            guid: [1u8; 12],
            tcp_addr: "10.0.0.2:5075".parse().unwrap(),
        };
        let b = DiscoveredServer {
            guid: [2u8; 12],
            tcp_addr: "10.0.0.10:5075".parse().unwrap(),
        };
        let c = DiscoveredServer {
            guid: [3u8; 12],
            tcp_addr: "10.0.0.2:5075".parse().unwrap(),
        };
        assert_eq!(normalize_discovered_servers(vec![a, b, c]), vec![b, a, c]);
        assert!(normalize_discovered_servers(Vec::new()).is_empty());
    }

    #[test]
    fn parse_addr_list_accepts_ip_and_ip_port() {
        let items = parse_addr_list("192.168.1.10 10.0.0.1:5076");
        assert!(items.contains(&IpAddr::V4(Ipv4Addr::new(192, 168, 1, 10))));
        assert!(items.contains(&IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));
    }

    #[test]
    fn parse_addr_list_skips_empty_tokens_and_accepts_ipv6() {
        let items = parse_addr_list(" ,,\t10.0.0.1, ::1 [::2]:5076 ");
        assert_eq!(
            items,
            vec![
                IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
                IpAddr::V6(Ipv6Addr::LOCALHOST),
                IpAddr::V6("::2".parse().unwrap()),
            ]
        );
    }

    #[test]
    fn unspecified_for_preserves_address_family() {
        assert_eq!(
            unspecified_for(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))),
            IpAddr::V4(Ipv4Addr::UNSPECIFIED)
        );
        assert_eq!(
            unspecified_for(IpAddr::V6(Ipv6Addr::LOCALHOST)),
            IpAddr::V6(Ipv6Addr::UNSPECIFIED)
        );
    }

    #[test]
    fn ipv4_link_local_detects_169_254_block() {
        assert!(ipv4_is_link_local(Ipv4Addr::new(169, 254, 0, 1)));
        assert!(!ipv4_is_link_local(Ipv4Addr::new(169, 253, 0, 1)));
        assert!(!ipv4_is_link_local(Ipv4Addr::new(10, 0, 0, 1)));
    }

    #[test]
    fn broadcast_for_sets_all_host_bits() {
        assert_eq!(
            broadcast_for(Ipv4Addr::new(10, 1, 2, 3), Ipv4Addr::new(255, 0, 0, 0)),
            Ipv4Addr::new(10, 255, 255, 255)
        );
        assert_eq!(
            broadcast_for(Ipv4Addr::new(172, 16, 5, 9), Ipv4Addr::new(255, 255, 252, 0)),
            Ipv4Addr::new(172, 16, 7, 255)
        );
    }

    #[test]
    fn pva_multicast_groups_are_multicast() {
        assert!(PVA_MULTICAST_V4.is_multicast());
        assert!(PVA_MULTICAST_V6.is_multicast());
    }

    #[test]
    fn discovery_target_falls_back_to_limited_broadcast_for_invalid_netmask() {
        let ip = Ipv4Addr::new(130, 246, 90, 92);
        assert_eq!(
            discovery_target_for(ip, Ipv4Addr::new(255, 255, 255, 255)),
            Ipv4Addr::new(255, 255, 255, 255)
        );
        assert_eq!(
            discovery_target_for(ip, Ipv4Addr::new(0, 0, 0, 0)),
            Ipv4Addr::new(255, 255, 255, 255)
        );
    }

    #[test]
    fn discovery_target_uses_directed_broadcast_for_normal_subnet() {
        let ip = Ipv4Addr::new(192, 168, 56, 1);
        let netmask = Ipv4Addr::new(255, 255, 255, 0);
        assert_eq!(
            discovery_target_for(ip, netmask),
            Ipv4Addr::new(192, 168, 56, 255)
        );
    }

    #[test]
    fn parse_name_servers_ip_with_port() {
        let addrs = parse_name_servers("192.168.1.10:5075");
        assert_eq!(
            addrs,
            vec!["192.168.1.10:5075".parse::<SocketAddr>().unwrap()]
        );
    }

    #[test]
    fn parse_name_servers_ip_without_port_defaults_to_5075() {
        let addrs = parse_name_servers("10.0.0.1");
        assert_eq!(
            addrs,
            vec![SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
                5075
            )]
        );
    }

    #[test]
    fn parse_name_servers_multiple_comma_separated() {
        let addrs = parse_name_servers("10.0.0.1:5075,10.0.0.2:9876");
        assert_eq!(addrs.len(), 2);
        assert_eq!(addrs[0], "10.0.0.1:5075".parse::<SocketAddr>().unwrap());
        assert_eq!(addrs[1], "10.0.0.2:9876".parse::<SocketAddr>().unwrap());
    }

    #[test]
    fn parse_name_servers_multiple_space_separated() {
        let addrs = parse_name_servers("10.0.0.1 10.0.0.2:5075");
        assert_eq!(addrs.len(), 2);
        assert_eq!(
            addrs[0],
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 5075)
        );
        assert_eq!(addrs[1], "10.0.0.2:5075".parse::<SocketAddr>().unwrap());
    }

    #[test]
    fn parse_name_servers_empty_string() {
        let addrs = parse_name_servers("");
        assert!(addrs.is_empty());
    }

    #[test]
    fn parse_name_servers_whitespace_only() {
        let addrs = parse_name_servers(" \t ");
        assert!(addrs.is_empty());
    }

    #[test]
    fn parse_name_servers_mixed_separators() {
        let addrs = parse_name_servers("10.0.0.1:5075, 10.0.0.2 , 10.0.0.3:9999");
        assert_eq!(addrs.len(), 3);
        assert_eq!(addrs[0], "10.0.0.1:5075".parse::<SocketAddr>().unwrap());
        assert_eq!(
            addrs[1],
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)), 5075)
        );
        assert_eq!(addrs[2], "10.0.0.3:9999".parse::<SocketAddr>().unwrap());
    }

    #[test]
    fn parse_name_servers_ipv6_with_port() {
        let addrs = parse_name_servers("[::1]:5075");
        assert_eq!(
            addrs,
            vec![SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 5075)]
        );
    }

    #[test]
    fn parse_name_servers_ipv6_without_port() {
        let addrs = parse_name_servers("::1");
        assert_eq!(
            addrs,
            vec![SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 5075)]
        );
    }

    #[test]
    fn decode_search_response_addr_falls_back_to_udp_source_when_unspecified() {
        let src: SocketAddr = "192.168.1.20:5076".parse().unwrap();
        let decoded = decode_search_response_addr([0u8; 16], 5075, src);
        assert_eq!(decoded, "192.168.1.20:5075".parse().unwrap());
    }
}