Skip to main content

spvirit_client/
search.rs

1use std::collections::HashSet;
2use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
3use std::sync::Arc;
4use std::time::Duration;
5
6use dns_lookup::lookup_host;
7use get_if_addrs::{IfAddr, get_if_addrs};
8use socket2::{Domain, Protocol, Socket, Type};
9use tokio::io::AsyncWriteExt;
10use tokio::net::UdpSocket;
11use tracing::debug;
12
13use crate::auth::{default_authnz_host, default_authnz_user};
14use crate::transport::read_packet;
15use crate::types::{PvGetError, PvGetOptions};
16use spvirit_codec::epics_decode::{PvaPacket, PvaPacketCommand};
17use spvirit_codec::spvirit_encode::{
18    encode_client_connection_validation, encode_search_request, ip_to_bytes,
19    socket_addr_from_pva_bytes,
20};
21
/// One UDP search destination together with the local address to send from.
///
/// Implements `Eq` and `Hash` so that whole targets (destination *and* bind
/// address) can be compared and deduplicated directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct SearchTarget {
    /// Destination IP the search request is sent to.
    pub target: IpAddr,
    /// Local IP the sending socket is bound to.
    pub bind: IpAddr,
}
27
/// A server found during discovery, identified by its GUID and TCP endpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DiscoveredServer {
    /// 12-byte server GUID.
    pub guid: [u8; 12],
    /// TCP socket address associated with the server.
    pub tcp_addr: SocketAddr,
}
33
34pub fn parse_addr_list(env: &str) -> Vec<IpAddr> {
35    env.split(|c| c == ',' || c == ' ' || c == '\t')
36        .filter(|s| !s.trim().is_empty())
37        .filter_map(|s| parse_search_target_ip(s.trim()))
38        .collect()
39}
40
41fn parse_search_target_ip(token: &str) -> Option<IpAddr> {
42    if token.is_empty() {
43        return None;
44    }
45
46    if let Ok(ip) = token.parse::<IpAddr>() {
47        return Some(ip);
48    }
49    if let Ok(sock) = token.parse::<SocketAddr>() {
50        return Some(sock.ip());
51    }
52
53    // Accept host:port where host may be a name or an IP literal.
54    // For IPv6 bracket notation [::1]:port, SocketAddr::parse above already handles it.
55    if let Some((host, port_str)) = token.rsplit_once(':') {
56        if !host.is_empty()
57            && !port_str.is_empty()
58            && port_str.chars().all(|c| c.is_ascii_digit())
59            && !host.contains(']')
60        {
61            if let Ok(ip) = host.parse::<IpAddr>() {
62                return Some(ip);
63            }
64            if let Ok(addrs) = lookup_host(host) {
65                // Prefer IPv4 for backward compat, fall back to first IPv6
66                let addrs: Vec<IpAddr> = addrs.collect();
67                if let Some(ip) = addrs
68                    .iter()
69                    .find(|ip| ip.is_ipv4())
70                    .copied()
71                    .or_else(|| addrs.into_iter().next())
72                {
73                    return Some(ip);
74                }
75            }
76        }
77    }
78
79    if let Ok(addrs) = lookup_host(token) {
80        // Prefer IPv4, fall back to first IPv6
81        let addrs: Vec<IpAddr> = addrs.collect();
82        if let Some(ip) = addrs
83            .iter()
84            .find(|ip| ip.is_ipv4())
85            .copied()
86            .or_else(|| addrs.into_iter().next())
87        {
88            return Some(ip);
89        }
90    }
91
92    None
93}
94
/// Wildcard ("any") address of the same family as `ip`:
/// `0.0.0.0` for IPv4, `::` for IPv6.
fn unspecified_for(ip: IpAddr) -> IpAddr {
    if ip.is_ipv4() {
        Ipv4Addr::UNSPECIFIED.into()
    } else {
        Ipv6Addr::UNSPECIFIED.into()
    }
}
102
103pub fn build_search_targets(
104    search_addr: Option<IpAddr>,
105    bind_addr: Option<IpAddr>,
106) -> Vec<SearchTarget> {
107    // Explicit --search-addr overrides everything (single target).
108    if let Some(ip) = search_addr {
109        return vec![SearchTarget {
110            target: ip,
111            bind: bind_addr.unwrap_or_else(|| unspecified_for(ip)),
112        }];
113    }
114
115    let mut targets = Vec::new();
116    let mut seen = HashSet::new();
117
118    // Addresses from EPICS_PVA_ADDR_LIST.
119    if let Ok(env) = std::env::var("EPICS_PVA_ADDR_LIST") {
120        for ip in parse_addr_list(&env) {
121            if seen.insert(ip) {
122                targets.push(SearchTarget {
123                    target: ip,
124                    bind: bind_addr.unwrap_or_else(|| unspecified_for(ip)),
125                });
126            }
127        }
128    }
129
130    // Merge auto-discovered broadcast addresses unless explicitly disabled.
131    // This matches EPICS Base behaviour: ADDR_LIST + auto-broadcast combined.
132    if is_auto_addr_list_enabled() {
133        for t in build_auto_broadcast_targets() {
134            if seen.insert(t.target) {
135                targets.push(SearchTarget {
136                    target: t.target,
137                    bind: bind_addr.unwrap_or(t.bind),
138                });
139            }
140        }
141    }
142
143    targets
144}
145
/// Whether auto-discovered targets should be added to the search list.
///
/// Reads `EPICS_PVA_AUTO_ADDR_LIST`. If the variable cannot be read (unset or
/// not valid Unicode) the answer is `true`. Otherwise the trimmed value must
/// equal `YES`, `Y`, `1` or `TRUE` (ASCII case-insensitive) to enable it.
pub fn is_auto_addr_list_enabled() -> bool {
    let Ok(raw) = std::env::var("EPICS_PVA_AUTO_ADDR_LIST") else {
        return true;
    };
    matches!(
        raw.trim().to_ascii_uppercase().as_str(),
        "YES" | "Y" | "1" | "TRUE"
    )
}
155
/// True when `ip` lies in the IPv4 link-local range `169.254.0.0/16`.
fn ipv4_is_link_local(ip: Ipv4Addr) -> bool {
    matches!(ip.octets(), [169, 254, _, _])
}
160
161fn choose_default_bind_v4() -> Option<Ipv4Addr> {
162    let ifaces = get_if_addrs().ok()?;
163    for iface in ifaces {
164        if let IfAddr::V4(v4) = iface.addr {
165            let ip = v4.ip;
166            if ip.is_loopback() || ipv4_is_link_local(ip) {
167                continue;
168            }
169            return Some(ip);
170        }
171    }
172    None
173}
174
175fn choose_default_bind_v6() -> Option<Ipv6Addr> {
176    let ifaces = get_if_addrs().ok()?;
177    for iface in ifaces {
178        if let IfAddr::V6(v6) = iface.addr {
179            let ip = v6.ip;
180            if ip.is_loopback() {
181                continue;
182            }
183            // Skip link-local (fe80::/10) — not routable without scope id
184            let segs = ip.segments();
185            if segs[0] & 0xffc0 == 0xfe80 {
186                continue;
187            }
188            return Some(ip);
189        }
190    }
191    None
192}
193
/// Directed broadcast address of the subnet described by `ip` and `netmask`
/// (all host bits set to one).
fn broadcast_for(ip: Ipv4Addr, netmask: Ipv4Addr) -> Ipv4Addr {
    let host_bits = !u32::from(netmask);
    (u32::from(ip) | host_bits).into()
}
199
200fn discovery_target_for(ip: Ipv4Addr, netmask: Ipv4Addr) -> Ipv4Addr {
201    let limited_broadcast = Ipv4Addr::new(255, 255, 255, 255);
202    if netmask == Ipv4Addr::new(255, 255, 255, 255) || netmask.is_unspecified() {
203        return limited_broadcast;
204    }
205    let directed = broadcast_for(ip, netmask);
206    if directed == ip {
207        limited_broadcast
208    } else {
209        directed
210    }
211}
212
213pub fn build_auto_broadcast_targets() -> Vec<SearchTarget> {
214    let mut targets = Vec::new();
215    let mut fallback_targets = Vec::new();
216    let mut fallback_seen = HashSet::new();
217    let mut added_v4_multicast = false;
218    let mut added_v6_multicast = false;
219    let ifaces = match get_if_addrs() {
220        Ok(v) => v,
221        Err(_) => return targets,
222    };
223    for iface in &ifaces {
224        if let IfAddr::V4(v4) = &iface.addr {
225            let ip = v4.ip;
226            if ip.is_loopback() || ipv4_is_link_local(ip) {
227                continue;
228            }
229            let bcast = discovery_target_for(ip, v4.netmask);
230            targets.push(SearchTarget {
231                target: IpAddr::V4(bcast),
232                bind: IpAddr::V4(ip),
233            });
234            // Also send to IPv4 multicast group (matching PVXS behaviour).
235            // Docker overlay networks may block broadcast but allow multicast.
236            targets.push(SearchTarget {
237                target: IpAddr::V4(PVA_MULTICAST_V4),
238                bind: IpAddr::V4(ip),
239            });
240            if fallback_seen.insert(IpAddr::V4(bcast)) {
241                fallback_targets.push(SearchTarget {
242                    target: IpAddr::V4(bcast),
243                    bind: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
244                });
245            }
246            if !added_v4_multicast {
247                added_v4_multicast = true;
248                fallback_targets.push(SearchTarget {
249                    target: IpAddr::V4(PVA_MULTICAST_V4),
250                    bind: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
251                });
252            }
253        }
254    }
255    // Add IPv6 multicast targets for each non-loopback, non-link-local v6 iface.
256    for iface in &ifaces {
257        if let IfAddr::V6(v6) = &iface.addr {
258            let ip = v6.ip;
259            if ip.is_loopback() {
260                continue;
261            }
262            let segs = ip.segments();
263            if segs[0] & 0xffc0 == 0xfe80 {
264                continue; // skip link-local
265            }
266            let multicast_target = IpAddr::V6(PVA_MULTICAST_V6);
267            targets.push(SearchTarget {
268                target: multicast_target,
269                bind: IpAddr::V6(ip),
270            });
271            if !added_v6_multicast {
272                added_v6_multicast = true;
273                fallback_targets.push(SearchTarget {
274                    target: multicast_target,
275                    bind: IpAddr::V6(Ipv6Addr::UNSPECIFIED),
276                });
277            }
278        }
279    }
280    targets.extend(fallback_targets);
281    targets
282}
283
/// IPv4 multicast group used for PVA searches: `224.0.0.128`.
const PVA_MULTICAST_V4: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 128);

/// IPv6 multicast group used for PVA searches: `ff02::42:1`
/// (the `ff02` prefix makes it link-local scope).
const PVA_MULTICAST_V6: Ipv6Addr = Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0x42, 1);
289
290/// Best-effort join the PVA multicast group appropriate for the bind address.
291fn join_multicast_any(socket: &std::net::UdpSocket, bind: IpAddr) {
292    match bind {
293        IpAddr::V4(iface) => {
294            let _ = socket.join_multicast_v4(&PVA_MULTICAST_V4, &iface);
295        }
296        IpAddr::V6(_) => {
297            // interface index 0 = OS picks the default interface
298            let _ = socket.join_multicast_v6(&PVA_MULTICAST_V6, 0);
299        }
300    }
301}
302
303fn decode_search_response_addr(addr: [u8; 16], port: u16, src: SocketAddr) -> SocketAddr {
304    socket_addr_from_pva_bytes(addr, port)
305        .filter(|a| !a.ip().is_unspecified())
306        .unwrap_or_else(|| SocketAddr::new(src.ip(), port))
307}
308
309fn normalize_discovered_servers(items: Vec<DiscoveredServer>) -> Vec<DiscoveredServer> {
310    let mut seen = HashSet::new();
311    let mut out = Vec::new();
312    for item in items {
313        if seen.insert((item.guid, item.tcp_addr)) {
314            out.push(item);
315        }
316    }
317    out.sort_by(|a, b| a.tcp_addr.to_string().cmp(&b.tcp_addr.to_string()));
318    out
319}
320
321/// Create a UDP socket with SO_REUSEADDR set (matching PVXS behaviour),
322/// allowing multiple processes to share the search port.
323///
324/// On Windows SO_REUSEADDR has different (unsafe) semantics — it allows
325/// a second socket to steal an actively-used port — so we only enable it
326/// on Unix where it merely permits rebinding during TIME_WAIT.
327fn bind_udp_reuse(addr: SocketAddr) -> std::io::Result<std::net::UdpSocket> {
328    let domain = if addr.is_ipv4() {
329        Domain::IPV4
330    } else {
331        Domain::IPV6
332    };
333    let sock = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
334    #[cfg(unix)]
335    sock.set_reuse_address(true)?;
336    sock.set_nonblocking(true)?;
337    sock.bind(&addr.into())?;
338    Ok(sock.into())
339}
340
341pub async fn search_pv(
342    pv_name: &str,
343    udp_port: u16,
344    timeout_dur: Duration,
345    targets: &[SearchTarget],
346    debug_enabled: bool,
347) -> Result<SocketAddr, PvGetError> {
348    if targets.is_empty() {
349        return Err(PvGetError::Search("no search targets"));
350    }
351
352    let now = std::time::SystemTime::now()
353        .duration_since(std::time::UNIX_EPOCH)
354        .unwrap_or_default();
355    let seq = (now.as_nanos() as u32).wrapping_add(std::process::id());
356    let cid = seq ^ 0x9E37_79B9;
357
358    let mut last_io_error: Option<std::io::Error> = None;
359    let deadline = tokio::time::Instant::now() + timeout_dur;
360
361    // Group targets by bind address so we can share a socket per bind.
362    let mut bind_groups: Vec<(IpAddr, Vec<IpAddr>)> = Vec::new();
363    for t in targets {
364        if let Some(group) = bind_groups.iter_mut().find(|(b, _)| *b == t.bind) {
365            group.1.push(t.target);
366        } else {
367            bind_groups.push((t.bind, vec![t.target]));
368        }
369    }
370
371    // Open sockets and send to all targets first, then collect responses.
372    // Store (socket, message, destinations) for retransmission.
373    let mut socket_info: Vec<(Arc<UdpSocket>, Vec<u8>, Vec<SocketAddr>)> = Vec::new();
374
375    for (bind_ip, group_targets) in &bind_groups {
376        let bind_addr = SocketAddr::new(*bind_ip, udp_port);
377        let (std_sock, actual_bind_addr) = match bind_udp_reuse(bind_addr) {
378            Ok(sock) => (sock, bind_addr),
379            Err(err) if err.kind() == std::io::ErrorKind::AddrInUse => {
380                let fallback = SocketAddr::new(*bind_ip, 0);
381                match bind_udp_reuse(fallback) {
382                    Ok(sock) => {
383                        let actual = sock.local_addr().unwrap_or(fallback);
384                        if debug_enabled {
385                            debug!(
386                                "pva search bind={} failed (in use), fallback bind={}",
387                                bind_addr, actual
388                            );
389                        }
390                        (sock, actual)
391                    }
392                    Err(fallback_err) => {
393                        if debug_enabled {
394                            debug!(
395                                "pva search skipping bind={} step=bind-fallback kind={:?} err={}",
396                                bind_addr,
397                                fallback_err.kind(),
398                                fallback_err
399                            );
400                        }
401                        last_io_error = Some(fallback_err);
402                        continue;
403                    }
404                }
405            }
406            Err(err) => {
407                if debug_enabled {
408                    debug!(
409                        "pva search skipping bind={} step=bind kind={:?} err={}",
410                        bind_addr,
411                        err.kind(),
412                        err
413                    );
414                }
415                last_io_error = Some(err);
416                continue;
417            }
418        };
419        if let Err(err) = std_sock.set_broadcast(true) {
420            if debug_enabled {
421                debug!(
422                    "pva search skipping bind={} step=set_broadcast kind={:?} err={}",
423                    bind_addr,
424                    err.kind(),
425                    err
426                );
427            }
428            last_io_error = Some(err);
429            continue;
430        }
431
432        join_multicast_any(&std_sock, *bind_ip);
433
434        let reply_addr = ip_to_bytes(*bind_ip);
435        let reply_port = match std_sock.local_addr() {
436            Ok(addr) => addr.port(),
437            Err(err) => {
438                if debug_enabled {
439                    debug!(
440                        "pva search skipping bind={} step=local_addr kind={:?} err={}",
441                        bind_addr,
442                        err.kind(),
443                        err
444                    );
445                }
446                last_io_error = Some(err);
447                continue;
448            }
449        };
450        let requests = [(cid, pv_name)];
451        let msg = encode_search_request(seq, 0x81, reply_port, reply_addr, &requests, 2, false);
452
453        let socket = match UdpSocket::from_std(std_sock) {
454            Ok(socket) => socket,
455            Err(err) => {
456                if debug_enabled {
457                    debug!(
458                        "pva search skipping bind={} step=from_std kind={:?} err={}",
459                        bind_addr,
460                        err.kind(),
461                        err
462                    );
463                }
464                last_io_error = Some(err);
465                continue;
466            }
467        };
468
469        let dests: Vec<SocketAddr> = group_targets
470            .iter()
471            .map(|ip| SocketAddr::new(*ip, udp_port))
472            .collect();
473
474        // Send to every target in this bind group immediately.
475        for dest in &dests {
476            if debug_enabled {
477                debug!(
478                    "pva search bind={} target={} server_port={} reply_port={}",
479                    actual_bind_addr,
480                    dest.ip(),
481                    udp_port,
482                    reply_port
483                );
484                debug!("pva search seq={} cid={}", seq, cid);
485                debug!("pva search send {} bytes to {}", msg.len(), dest);
486            }
487            if let Err(err) = socket.send_to(&msg, dest).await {
488                if debug_enabled {
489                    debug!(
490                        "pva search send_to target={} kind={:?} err={}",
491                        dest,
492                        err.kind(),
493                        err
494                    );
495                }
496                last_io_error = Some(err);
497            }
498        }
499
500        socket_info.push((Arc::new(socket), msg, dests));
501    }
502
503    if socket_info.is_empty() {
504        if let Some(err) = last_io_error {
505            return Err(PvGetError::Io(err));
506        }
507        return Err(PvGetError::Timeout("search response"));
508    }
509
510    // Spawn a receiver task per socket that forwards packets into a shared channel.
511    let (tx, mut rx) = tokio::sync::mpsc::channel::<(Vec<u8>, SocketAddr)>(64);
512    for (sock, _, _) in &socket_info {
513        let sock = Arc::clone(sock);
514        let tx = tx.clone();
515        tokio::spawn(async move {
516            loop {
517                let mut buf = vec![0u8; 2048];
518                match sock.recv_from(&mut buf).await {
519                    Ok((len, src)) => {
520                        buf.truncate(len);
521                        if tx.send((buf, src)).await.is_err() {
522                            break;
523                        }
524                    }
525                    Err(_) => break,
526                }
527            }
528        });
529    }
530    drop(tx); // Only spawned tasks hold senders; channel closes when they exit.
531
532    // Retransmit schedule: exponential backoff from start.
533    let retransmit_offsets = [100u64, 500, 1000, 2000];
534    let start = tokio::time::Instant::now();
535    let mut next_retransmit = 0usize;
536
537    loop {
538        // Compute the next wake-up: either the next retransmit or the deadline.
539        let next_retransmit_at = if next_retransmit < retransmit_offsets.len() {
540            start + Duration::from_millis(retransmit_offsets[next_retransmit])
541        } else {
542            deadline
543        };
544        let wake_at = next_retransmit_at.min(deadline);
545
546        tokio::select! {
547            recv = rx.recv() => {
548                let Some((buf, src)) = recv else { break };
549                let mut pkt = PvaPacket::new(&buf);
550                let cmd = pkt
551                    .decode_payload()
552                    .ok_or(PvGetError::Search("failed to decode search response"))?;
553                if let PvaPacketCommand::SearchResponse(payload) = cmd {
554                    if debug_enabled {
555                        debug!(
556                            "pva search response found={} cids={:?} addr={:?} port={}",
557                            payload.found, payload.cids, payload.addr, payload.port
558                        );
559                    }
560                    if payload.seq != seq {
561                        continue;
562                    }
563                    if !payload.protocol.is_empty() && !payload.protocol.eq_ignore_ascii_case("tcp") {
564                        continue;
565                    }
566                    if !payload.found {
567                        continue;
568                    }
569                    if !payload.cids.is_empty() && !payload.cids.contains(&cid) {
570                        continue;
571                    }
572
573                    let addr = decode_search_response_addr(payload.addr, payload.port, src);
574                    if debug_enabled {
575                        debug!("pva search response from {}", addr);
576                    }
577                    return Ok(addr);
578                }
579            }
580            _ = tokio::time::sleep_until(wake_at) => {
581                if tokio::time::Instant::now() >= deadline {
582                    break;
583                }
584                // Retransmit to all targets on all sockets.
585                if next_retransmit < retransmit_offsets.len() {
586                    if debug_enabled {
587                        debug!("pva search retransmit round {}", next_retransmit + 1);
588                    }
589                    for (sock, msg, dests) in &socket_info {
590                        for dest in dests {
591                            let _ = sock.send_to(msg, dest).await;
592                        }
593                    }
594                    next_retransmit += 1;
595                }
596            }
597        }
598    }
599
600    Err(PvGetError::Timeout("search response"))
601}
602
603pub fn default_bind_ip() -> Option<IpAddr> {
604    choose_default_bind_v4()
605        .map(IpAddr::V4)
606        .or_else(|| choose_default_bind_v6().map(IpAddr::V6))
607}
608
/// Parse an `EPICS_PVA_NAME_SERVERS` value into socket addresses.
///
/// Entries are separated by commas, spaces or tabs. Each entry may be
/// `ip:port`, `[ipv6]:port`, a bare IP, `host:port` or a bare host name.
/// Entries without a port get port 5075. Host names are resolved and the
/// first resulting address is used. Entries that cannot be parsed or resolved
/// are dropped.
pub fn parse_name_servers(env_val: &str) -> Vec<SocketAddr> {
    use std::net::ToSocketAddrs;

    // First address a `host:port` style spec resolves to, if any.
    let first_resolved =
        |spec: &str| spec.to_socket_addrs().ok().and_then(|mut found| found.next());

    env_val
        .split(&[',', ' ', '\t'][..])
        .map(str::trim)
        .filter(|token| !token.is_empty())
        .filter_map(|token| {
            if let Ok(addr) = token.parse::<SocketAddr>() {
                return Some(addr);
            }
            if let Ok(ip) = token.parse::<IpAddr>() {
                return Some(SocketAddr::new(ip, 5075));
            }
            // Try the token as given (`host:port`), then with the default port.
            first_resolved(token).or_else(|| first_resolved(&format!("{}:5075", token)))
        })
        .collect()
}
643
644/// Build a minimal PVA ConnectionValidation response for name server search.
645fn encode_search_validation(version: u8, is_be: bool) -> Vec<u8> {
646    let user = default_authnz_user();
647    let host = default_authnz_host();
648    encode_client_connection_validation(87_040, 32_767, 0, "ca", &user, &host, version, is_be)
649}
650
651/// Search for a PV via a TCP connection to a PVA name server.
652///
653/// Connects to the name server, performs the PVA handshake, sends a search
654/// request over TCP, and returns the server address from the search response.
655pub async fn search_pv_tcp(
656    pv_name: &str,
657    name_server: SocketAddr,
658    timeout_dur: Duration,
659    debug_enabled: bool,
660) -> Result<SocketAddr, PvGetError> {
661    let deadline = tokio::time::Instant::now() + timeout_dur;
662
663    let mut stream = tokio::time::timeout(timeout_dur, tokio::net::TcpStream::connect(name_server))
664        .await
665        .map_err(|_| PvGetError::Timeout("name server connect"))??;
666
667    let mut version = 2u8;
668    let mut is_be = false;
669
670    // Read SET_BYTE_ORDER + ConnectionValidation from name server.
671    for _ in 0..2 {
672        let now = tokio::time::Instant::now();
673        if now >= deadline {
674            return Err(PvGetError::Timeout("name server handshake"));
675        }
676        let remaining = deadline - now;
677        if let Ok(bytes) = read_packet(&mut stream, remaining).await {
678            let mut pkt = PvaPacket::new(&bytes);
679            if let Some(cmd) = pkt.decode_payload() {
680                match cmd {
681                    PvaPacketCommand::Control(payload) => {
682                        if payload.command == 2 {
683                            is_be = pkt.header.flags.is_msb;
684                        }
685                    }
686                    PvaPacketCommand::ConnectionValidation(_) => {
687                        version = pkt.header.version;
688                        is_be = pkt.header.flags.is_msb;
689                    }
690                    _ => {}
691                }
692            }
693        }
694    }
695
696    let validation = encode_search_validation(version, is_be);
697    stream.write_all(&validation).await?;
698
699    // Wait for ConnectionValidated.
700    loop {
701        let now = tokio::time::Instant::now();
702        if now >= deadline {
703            return Err(PvGetError::Timeout("name server validated"));
704        }
705        let remaining = deadline - now;
706        let bytes = read_packet(&mut stream, remaining).await?;
707        let mut pkt = PvaPacket::new(&bytes);
708        if let Some(cmd) = pkt.decode_payload() {
709            if matches!(cmd, PvaPacketCommand::ConnectionValidated(_)) {
710                break;
711            }
712        }
713    }
714
715    // Send search request over TCP.
716    let now_ts = std::time::SystemTime::now()
717        .duration_since(std::time::UNIX_EPOCH)
718        .unwrap_or_default();
719    let seq = (now_ts.as_nanos() as u32).wrapping_add(std::process::id());
720    let cid = seq ^ 0x9E37_79B9;
721    let requests = [(cid, pv_name)];
722    let msg = encode_search_request(seq, 0x80, 0, [0u8; 16], &requests, version, is_be);
723    stream.write_all(&msg).await?;
724
725    if debug_enabled {
726        debug!(
727            "pva tcp search sent to name_server={} pv={}",
728            name_server, pv_name
729        );
730    }
731
732    // Read search response.
733    loop {
734        let now = tokio::time::Instant::now();
735        if now >= deadline {
736            return Err(PvGetError::Timeout("name server search response"));
737        }
738        let remaining = deadline - now;
739        let bytes = read_packet(&mut stream, remaining).await?;
740        let mut pkt = PvaPacket::new(&bytes);
741        if let Some(cmd) = pkt.decode_payload() {
742            if let PvaPacketCommand::SearchResponse(payload) = cmd {
743                if !payload.found {
744                    continue;
745                }
746                if !payload.cids.is_empty() && !payload.cids.contains(&cid) {
747                    continue;
748                }
749                let addr = decode_search_response_addr(payload.addr, payload.port, name_server);
750                if debug_enabled {
751                    debug!(
752                        "pva tcp search response from name_server={}: {}",
753                        name_server, addr
754                    );
755                }
756                return Ok(addr);
757            }
758        }
759    }
760}
761
762/// Resolve the PVA server for a PV using name servers (TCP) and/or UDP search.
763///
764/// - If `opts.server_addr` is set, returns it directly.
765/// - Tries each name server from `opts.name_servers` and `EPICS_PVA_NAME_SERVERS`
766///   via TCP search.
767/// - Falls back to UDP search using `build_search_targets()`.
768pub async fn resolve_pv_server(opts: &PvGetOptions) -> Result<SocketAddr, PvGetError> {
769    if let Some(addr) = opts.server_addr {
770        return Ok(addr);
771    }
772
773    let mut name_servers = opts.name_servers.clone();
774    if let Ok(env) = std::env::var("EPICS_PVA_NAME_SERVERS") {
775        name_servers.extend(parse_name_servers(&env));
776    }
777
778    let no_broadcast = opts.no_broadcast;
779
780    // Fail fast when no search strategy is available.
781    if no_broadcast && name_servers.is_empty() {
782        return Err(PvGetError::Search(
783            "no search strategy: specify --name-server or --server when using --no-broadcast",
784        ));
785    }
786
787    // Launch all search strategies concurrently — TCP name servers + UDP broadcast.
788    // Return the first successful result.
789    let targets = build_search_targets(opts.search_addr, opts.bind_addr);
790
791    let pv = opts.pv_name.clone();
792    let timeout_dur = opts.timeout;
793    let debug_enabled = opts.debug;
794    let udp_port = opts.udp_port;
795
796    let mut set = tokio::task::JoinSet::new();
797
798    for ns in name_servers {
799        let pv = pv.clone();
800        set.spawn(async move {
801            let addr = search_pv_tcp(&pv, ns, timeout_dur, debug_enabled).await?;
802            Ok::<SocketAddr, PvGetError>(addr)
803        });
804    }
805
806    if !no_broadcast {
807        let pv = pv.clone();
808        let targets = targets.clone();
809        set.spawn(async move {
810            let addr = search_pv(&pv, udp_port, timeout_dur, &targets, debug_enabled).await?;
811            Ok(addr)
812        });
813    }
814
815    let mut last_err = None;
816    while let Some(result) = set.join_next().await {
817        match result {
818            Ok(Ok(addr)) => {
819                set.abort_all();
820                return Ok(addr);
821            }
822            Ok(Err(e)) => {
823                if debug_enabled {
824                    debug!("pva search strategy failed: {}", e);
825                }
826                last_err = Some(e);
827            }
828            Err(join_err) => {
829                if debug_enabled {
830                    debug!("pva search task panicked: {}", join_err);
831                }
832            }
833        }
834    }
835
836    Err(last_err.unwrap_or(PvGetError::Timeout("search response")))
837}
838
839pub async fn discover_servers(
840    udp_port: u16,
841    timeout_dur: Duration,
842    targets: &[SearchTarget],
843    debug_enabled: bool,
844) -> Result<Vec<DiscoveredServer>, PvGetError> {
845    if targets.is_empty() {
846        return Err(PvGetError::Search("no search targets"));
847    }
848
849    let now = std::time::SystemTime::now()
850        .duration_since(std::time::UNIX_EPOCH)
851        .unwrap_or_default();
852    let seq = (now.as_nanos() as u32).wrapping_add(std::process::id());
853
854    let mut found: Vec<DiscoveredServer> = Vec::new();
855    let mut last_io_error: Option<std::io::Error> = None;
856    let deadline = tokio::time::Instant::now() + timeout_dur;
857
858    // Group targets by bind address so we can share a socket per bind.
859    let mut bind_groups: Vec<(IpAddr, Vec<IpAddr>)> = Vec::new();
860    for t in targets {
861        if let Some(group) = bind_groups.iter_mut().find(|(b, _)| *b == t.bind) {
862            group.1.push(t.target);
863        } else {
864            bind_groups.push((t.bind, vec![t.target]));
865        }
866    }
867
868    // Open sockets and send to all targets first, then collect responses.
869    // Store (socket, message, destinations) for retransmission.
870    let mut socket_info: Vec<(Arc<UdpSocket>, Vec<u8>, Vec<SocketAddr>)> = Vec::new();
871
872    for (bind_ip, group_targets) in &bind_groups {
873        let bind_addr = SocketAddr::new(*bind_ip, udp_port);
874        let (std_sock, actual_bind_addr) = match bind_udp_reuse(bind_addr) {
875            Ok(sock) => (sock, bind_addr),
876            Err(err) if err.kind() == std::io::ErrorKind::AddrInUse => {
877                let fallback = SocketAddr::new(*bind_ip, 0);
878                match bind_udp_reuse(fallback) {
879                    Ok(sock) => {
880                        let actual = sock.local_addr().unwrap_or(fallback);
881                        if debug_enabled {
882                            debug!(
883                                "pva discover bind={} failed (in use), fallback bind={}",
884                                bind_addr, actual
885                            );
886                        }
887                        (sock, actual)
888                    }
889                    Err(fallback_err) => {
890                        if debug_enabled {
891                            debug!(
892                                "pva discover skipping bind={} step=bind-fallback kind={:?} err={}",
893                                bind_addr,
894                                fallback_err.kind(),
895                                fallback_err
896                            );
897                        }
898                        last_io_error = Some(fallback_err);
899                        continue;
900                    }
901                }
902            }
903            Err(err) => {
904                if debug_enabled {
905                    debug!(
906                        "pva discover skipping bind={} step=bind kind={:?} err={}",
907                        bind_addr,
908                        err.kind(),
909                        err
910                    );
911                }
912                last_io_error = Some(err);
913                continue;
914            }
915        };
916        if let Err(err) = std_sock.set_broadcast(true) {
917            if debug_enabled {
918                debug!(
919                    "pva discover skipping bind={} step=set_broadcast kind={:?} err={}",
920                    bind_addr,
921                    err.kind(),
922                    err
923                );
924            }
925            last_io_error = Some(err);
926            continue;
927        }
928
929        join_multicast_any(&std_sock, *bind_ip);
930
931        let reply_addr = ip_to_bytes(*bind_ip);
932        let reply_port = match std_sock.local_addr() {
933            Ok(addr) => addr.port(),
934            Err(err) => {
935                if debug_enabled {
936                    debug!(
937                        "pva discover skipping bind={} step=local_addr kind={:?} err={}",
938                        bind_addr,
939                        err.kind(),
940                        err
941                    );
942                }
943                last_io_error = Some(err);
944                continue;
945            }
946        };
947        let msg = encode_search_request(seq, 0x81, reply_port, reply_addr, &[], 2, false);
948
949        let socket = match UdpSocket::from_std(std_sock) {
950            Ok(socket) => socket,
951            Err(err) => {
952                if debug_enabled {
953                    debug!(
954                        "pva discover skipping bind={} step=from_std kind={:?} err={}",
955                        bind_addr,
956                        err.kind(),
957                        err
958                    );
959                }
960                last_io_error = Some(err);
961                continue;
962            }
963        };
964
965        let dests: Vec<SocketAddr> = group_targets
966            .iter()
967            .map(|ip| SocketAddr::new(*ip, udp_port))
968            .collect();
969
970        // Send to every target in this bind group immediately.
971        for dest in &dests {
972            if debug_enabled {
973                debug!(
974                    "pva discover bind={} target={} server_port={} reply_port={} seq={}",
975                    actual_bind_addr,
976                    dest.ip(),
977                    udp_port,
978                    reply_port,
979                    seq
980                );
981            }
982            if let Err(err) = socket.send_to(&msg, dest).await {
983                if debug_enabled {
984                    debug!(
985                        "pva discover send_to target={} kind={:?} err={}",
986                        dest,
987                        err.kind(),
988                        err
989                    );
990                }
991                last_io_error = Some(err);
992            }
993        }
994
995        socket_info.push((Arc::new(socket), msg, dests));
996    }
997
998    if socket_info.is_empty() {
999        if let Some(err) = last_io_error {
1000            return Err(PvGetError::Io(err));
1001        }
1002        return Err(PvGetError::Search("no search targets"));
1003    }
1004
1005    // Spawn a receiver task per socket that forwards packets into a shared channel.
1006    let (tx, mut rx) = tokio::sync::mpsc::channel::<(Vec<u8>, SocketAddr)>(64);
1007    for (sock, _, _) in &socket_info {
1008        let sock = Arc::clone(sock);
1009        let tx = tx.clone();
1010        tokio::spawn(async move {
1011            loop {
1012                let mut buf = vec![0u8; 2048];
1013                match sock.recv_from(&mut buf).await {
1014                    Ok((len, src)) => {
1015                        buf.truncate(len);
1016                        if tx.send((buf, src)).await.is_err() {
1017                            break;
1018                        }
1019                    }
1020                    Err(_) => break,
1021                }
1022            }
1023        });
1024    }
1025    drop(tx); // Only spawned tasks hold senders; channel closes when they exit.
1026
1027    // Retransmit schedule: exponential backoff from start.
1028    let retransmit_offsets = [100u64, 500, 1000, 2000];
1029    let start = tokio::time::Instant::now();
1030    let mut next_retransmit = 0usize;
1031
1032    loop {
1033        // Compute the next wake-up: either the next retransmit or the deadline.
1034        let next_retransmit_at = if next_retransmit < retransmit_offsets.len() {
1035            start + Duration::from_millis(retransmit_offsets[next_retransmit])
1036        } else {
1037            deadline
1038        };
1039        let wake_at = next_retransmit_at.min(deadline);
1040
1041        tokio::select! {
1042            recv = rx.recv() => {
1043                let Some((buf, src)) = recv else { break };
1044                let mut pkt = PvaPacket::new(&buf);
1045                let Some(cmd) = pkt.decode_payload() else {
1046                    continue;
1047                };
1048                if let PvaPacketCommand::SearchResponse(payload) = cmd {
1049                    if payload.seq != seq {
1050                        continue;
1051                    }
1052                    if !payload.protocol.is_empty() && !payload.protocol.eq_ignore_ascii_case("tcp") {
1053                        continue;
1054                    }
1055                    let tcp_addr = decode_search_response_addr(payload.addr, payload.port, src);
1056                    found.push(DiscoveredServer {
1057                        guid: payload.guid,
1058                        tcp_addr,
1059                    });
1060                }
1061            }
1062            _ = tokio::time::sleep_until(wake_at) => {
1063                if tokio::time::Instant::now() >= deadline {
1064                    break;
1065                }
1066                // Retransmit to all targets on all sockets.
1067                if next_retransmit < retransmit_offsets.len() {
1068                    if debug_enabled {
1069                        debug!("pva discover retransmit round {}", next_retransmit + 1);
1070                    }
1071                    for (sock, msg, dests) in &socket_info {
1072                        for dest in dests {
1073                            let _ = sock.send_to(msg, dest).await;
1074                        }
1075                    }
1076                    next_retransmit += 1;
1077                }
1078            }
1079        }
1080    }
1081
1082    Ok(normalize_discovered_servers(found))
1083}
1084
#[cfg(test)]
mod tests {
    use super::*;
    use spvirit_codec::epics_decode::{PvaPacket, PvaPacketCommand};

    /// Parses a socket-address literal used as an expected value in the tests below.
    fn sock(text: &str) -> SocketAddr {
        text.parse::<SocketAddr>().unwrap()
    }

    /// Builds an IPv4 socket address on port 5075, the port the name-server tests
    /// expect when the input omits one.
    fn v4_on_5075(a: u8, b: u8, c: u8, d: u8) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(a, b, c, d)), 5075)
    }

    // A search request carrying one PV must decode back to the same field values.
    #[test]
    fn encode_decode_search_request_roundtrip() {
        let (seq, cid, port) = (1234, 42, 5076);
        let pv_name = "TEST:PV";
        let reply_addr = ip_to_bytes(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 20)));

        let encoded =
            encode_search_request(seq, 0x81, port, reply_addr, &[(cid, pv_name)], 2, false);
        let mut packet = PvaPacket::new(&encoded);
        let payload = match packet.decode_payload().expect("decoded") {
            PvaPacketCommand::Search(payload) => payload,
            other => panic!("unexpected decode: {:?}", other),
        };

        assert_eq!(payload.seq, seq);
        assert_eq!(payload.mask, 0x81);
        assert_eq!(payload.addr, reply_addr);
        assert_eq!(payload.port, port);
        assert_eq!(payload.protocols, vec![String::from("tcp")]);
        assert_eq!(payload.pv_requests.len(), 1);
        let (decoded_cid, decoded_name) = &payload.pv_requests[0];
        assert_eq!(*decoded_cid, cid);
        assert_eq!(*decoded_name, String::from(pv_name));
    }

    // A discovery request is a search request with no PV names.
    #[test]
    fn encode_decode_server_discovery_request_roundtrip() {
        let (seq, port) = (4321, 5076);
        let reply_addr = ip_to_bytes(IpAddr::V4(Ipv4Addr::new(10, 20, 30, 40)));

        let encoded = encode_search_request(seq, 0x81, port, reply_addr, &[], 2, false);
        let mut packet = PvaPacket::new(&encoded);
        let payload = match packet.decode_payload().expect("decoded") {
            PvaPacketCommand::Search(payload) => payload,
            other => panic!("unexpected decode: {:?}", other),
        };

        assert_eq!(payload.seq, seq);
        assert!(payload.pv_requests.is_empty());
        assert_eq!(payload.protocols, vec![String::from("tcp")]);
    }

    // Two identical entries collapse into one; a different GUID stays separate.
    #[test]
    fn normalize_discovered_servers_deduplicates_by_guid_and_addr() {
        let tcp_addr = sock("127.0.0.1:5075");
        let duplicated = DiscoveredServer {
            guid: [1u8; 12],
            tcp_addr,
        };
        let distinct = DiscoveredServer {
            guid: [2u8; 12],
            tcp_addr,
        };

        let normalized = normalize_discovered_servers(vec![duplicated, duplicated, distinct]);
        assert_eq!(normalized.len(), 2);
    }

    // Both a bare IP and an `ip:port` token must yield the IP.
    #[test]
    fn parse_addr_list_accepts_ip_and_ip_port() {
        let parsed = parse_addr_list("192.168.1.10 10.0.0.1:5076");
        for expected in [Ipv4Addr::new(192, 168, 1, 10), Ipv4Addr::new(10, 0, 0, 1)] {
            assert!(parsed.contains(&IpAddr::V4(expected)));
        }
    }

    // All-ones and all-zero netmasks both map to the limited broadcast address.
    #[test]
    fn discovery_target_falls_back_to_limited_broadcast_for_invalid_netmask() {
        let ip = Ipv4Addr::new(130, 246, 90, 92);
        for netmask in [Ipv4Addr::BROADCAST, Ipv4Addr::UNSPECIFIED] {
            assert_eq!(discovery_target_for(ip, netmask), Ipv4Addr::BROADCAST);
        }
    }

    // A /24 interface yields the directed broadcast address of its subnet.
    #[test]
    fn discovery_target_uses_directed_broadcast_for_normal_subnet() {
        let target = discovery_target_for(
            Ipv4Addr::new(192, 168, 56, 1),
            Ipv4Addr::new(255, 255, 255, 0),
        );
        assert_eq!(target, Ipv4Addr::new(192, 168, 56, 255));
    }

    #[test]
    fn parse_name_servers_ip_with_port() {
        let parsed = parse_name_servers("192.168.1.10:5075");
        assert_eq!(parsed, vec![sock("192.168.1.10:5075")]);
    }

    #[test]
    fn parse_name_servers_ip_without_port_defaults_to_5075() {
        let parsed = parse_name_servers("10.0.0.1");
        assert_eq!(parsed, vec![v4_on_5075(10, 0, 0, 1)]);
    }

    #[test]
    fn parse_name_servers_multiple_comma_separated() {
        let parsed = parse_name_servers("10.0.0.1:5075,10.0.0.2:9876");
        assert_eq!(parsed, vec![sock("10.0.0.1:5075"), sock("10.0.0.2:9876")]);
    }

    #[test]
    fn parse_name_servers_multiple_space_separated() {
        let parsed = parse_name_servers("10.0.0.1 10.0.0.2:5075");
        assert_eq!(parsed, vec![v4_on_5075(10, 0, 0, 1), sock("10.0.0.2:5075")]);
    }

    #[test]
    fn parse_name_servers_empty_string() {
        assert!(parse_name_servers("").is_empty());
    }

    #[test]
    fn parse_name_servers_whitespace_only() {
        assert!(parse_name_servers("  \t  ").is_empty());
    }

    // Commas and runs of spaces may be mixed freely between entries.
    #[test]
    fn parse_name_servers_mixed_separators() {
        let parsed = parse_name_servers("10.0.0.1:5075, 10.0.0.2  ,  10.0.0.3:9999");
        let expected = vec![
            sock("10.0.0.1:5075"),
            v4_on_5075(10, 0, 0, 2),
            sock("10.0.0.3:9999"),
        ];
        assert_eq!(parsed, expected);
    }

    #[test]
    fn parse_name_servers_ipv6_with_port() {
        let loopback = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 5075);
        assert_eq!(parse_name_servers("[::1]:5075"), vec![loopback]);
    }

    #[test]
    fn parse_name_servers_ipv6_without_port() {
        let loopback = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 5075);
        assert_eq!(parse_name_servers("::1"), vec![loopback]);
    }

    // An all-zero response address means "use the UDP source IP" with the advertised port.
    #[test]
    fn decode_search_response_addr_falls_back_to_udp_source_when_unspecified() {
        let udp_source = sock("192.168.1.20:5076");
        let resolved = decode_search_response_addr([0u8; 16], 5075, udp_source);
        assert_eq!(resolved, sock("192.168.1.20:5075"));
    }
}