Skip to main content

viam_rust_utils/rpc/
dial.rs

1use super::{
2    client_channel::*,
3    log_prefixes,
4    webrtc::{webrtc_action_with_timeout, Options},
5};
6use crate::gen::google;
7use crate::gen::proto::rpc::v1::{
8    auth_service_client::AuthServiceClient, AuthenticateRequest, Credentials,
9};
10use crate::gen::proto::rpc::webrtc::v1::{
11    call_response::Stage, call_update_request::Update,
12    signaling_service_client::SignalingServiceClient, CallUpdateRequest,
13    OptionalWebRtcConfigRequest, OptionalWebRtcConfigResponse,
14};
15use crate::gen::proto::rpc::webrtc::v1::{
16    CallRequest, IceCandidate, Metadata, RequestHeaders, Strings,
17};
18use crate::rpc::webrtc;
19use ::http::header::HeaderName;
20use ::http::{
21    uri::{Authority, Parts, PathAndQuery, Scheme},
22    HeaderValue, Version,
23};
24use ::viam_mdns::{discover, RecordKind, Response};
25use ::webrtc::ice_transport::{
26    ice_candidate::{RTCIceCandidate, RTCIceCandidateInit},
27    ice_connection_state::RTCIceConnectionState,
28};
29use ::webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
30use anyhow::{Context, Result};
31use core::fmt;
32use futures::stream::FuturesUnordered;
33use futures_util::{pin_mut, stream::StreamExt};
34use local_ip_address::list_afinet_netifas;
35use std::{
36    collections::HashMap,
37    net::{IpAddr, Ipv4Addr},
38    sync::{
39        atomic::{AtomicBool, Ordering},
40        Arc, Mutex, RwLock,
41    },
42    task::{Context as TaskContext, Poll},
43    time::{Duration, Instant},
44};
45use tokio::sync::{mpsc, watch};
46use tonic::body::BoxBody;
47use tonic::codegen::BoxFuture;
48use tonic::transport::{Body, Channel, ClientTlsConfig, Uri};
49
50use tower::{Service, ServiceBuilder};
51use tower_http::auth::AddAuthorization;
52use tower_http::auth::AddAuthorizationLayer;
53use tower_http::set_header::{SetRequestHeader, SetRequestHeaderLayer};
54
// gRPC status codes, reported to callers through the `grpc-status` response header.
/// The call completed successfully.
const STATUS_CODE_OK: i32 = 0;
/// The call failed for a reason that has no more specific status code.
const STATUS_CODE_UNKNOWN: i32 = 2;
/// A resource limit was reached.
const STATUS_CODE_RESOURCE_EXHAUSTED: i32 = 8;

/// Service name queried during mDNS discovery.
// `const` items are implicitly `'static`, so the lifetime does not need to be spelled out.
pub const VIAM_MDNS_SERVICE_NAME: &str = "_rpc._tcp.local";

/// Type tag of a credential, stored in the `type` field of `Credentials`.
type SecretType = String;
63
64#[derive(Clone)]
65/// A communication channel to a given uri. The channel is either a direct tonic channel,
66/// or a webRTC channel.
67pub enum ViamChannel {
68    Direct(Channel),
69    DirectPreAuthorized(AddAuthorization<SetRequestHeader<Channel, HeaderValue>>),
70    WebRTC(Arc<WebRTCClientChannel>),
71}
72
73#[derive(Debug, Clone)]
74pub struct RPCCredentials {
75    entity: Option<String>,
76    credentials: Credentials,
77}
78
79impl RPCCredentials {
80    pub fn new(entity: Option<String>, r#type: SecretType, payload: String) -> Self {
81        Self {
82            credentials: Credentials { r#type, payload },
83            entity,
84        }
85    }
86}
87
88impl ViamChannel {
89    async fn create_resp(
90        channel: &mut Arc<WebRTCClientChannel>,
91        stream: crate::gen::proto::rpc::webrtc::v1::Stream,
92        request: http::Request<BoxBody>,
93        response: http::response::Builder,
94    ) -> http::Response<Body> {
95        let (parts, body) = request.into_parts();
96        let mut status_code = STATUS_CODE_OK;
97        let stream_id = stream.id;
98        let metadata = Some(metadata_from_parts(&parts));
99        let headers = RequestHeaders {
100            method: parts
101                .uri
102                .path_and_query()
103                .map(PathAndQuery::to_string)
104                .unwrap_or_default(),
105            metadata,
106            timeout: None,
107        };
108
109        if let Err(e) = channel.write_headers(&stream, headers).await {
110            log::error!("error writing headers: {e}");
111            channel.close_stream_with_recv_error(stream_id, e);
112            status_code = STATUS_CODE_UNKNOWN;
113        }
114
115        let data = hyper::body::to_bytes(body).await.unwrap().to_vec();
116        if let Err(e) = channel.write_message(Some(stream), data).await {
117            log::error!("error sending message: {e}");
118            channel.close_stream_with_recv_error(stream_id, e);
119            status_code = STATUS_CODE_UNKNOWN;
120        };
121
122        let body = match channel.resp_body_from_stream(stream_id) {
123            Ok(body) => body,
124            Err(e) => {
125                log::error!("error receiving response from stream: {e}");
126                channel.close_stream_with_recv_error(stream_id, e);
127                status_code = STATUS_CODE_UNKNOWN;
128                Body::empty()
129            }
130        };
131
132        let response = if status_code != STATUS_CODE_OK {
133            response.header("grpc-status", &status_code.to_string())
134        } else {
135            response
136        };
137
138        response.body(body).unwrap()
139    }
140}
141
142impl Service<http::Request<BoxBody>> for ViamChannel {
143    type Response = http::Response<Body>;
144    type Error = tonic::transport::Error;
145    type Future = BoxFuture<Self::Response, Self::Error>;
146
147    fn poll_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Self::Error>> {
148        match self {
149            Self::Direct(channel) => channel.poll_ready(cx),
150            Self::DirectPreAuthorized(channel) => channel.poll_ready(cx),
151            Self::WebRTC(_channel) => Poll::Ready(Ok(())),
152        }
153    }
154
155    fn call(&mut self, request: http::Request<BoxBody>) -> Self::Future {
156        match self {
157            Self::Direct(channel) => Box::pin(channel.call(request)),
158            Self::DirectPreAuthorized(channel) => Box::pin(channel.call(request)),
159            Self::WebRTC(channel) => {
160                let mut channel = channel.clone();
161                let fut = async move {
162                    let response = http::response::Response::builder()
163                        // standardized gRPC headers.
164                        .header("content-type", "application/grpc")
165                        .version(Version::HTTP_2);
166
167                    match channel.new_stream() {
168                        Err(e) => {
169                            log::error!("{e}");
170                            let response = response
171                                .header("grpc-status", &STATUS_CODE_RESOURCE_EXHAUSTED.to_string())
172                                .body(Body::default())
173                                .unwrap();
174
175                            Ok(response)
176                        }
177                        Ok(stream) => {
178                            Ok(Self::create_resp(&mut channel, stream, request, response).await)
179                        }
180                    }
181                };
182                Box::pin(fut)
183            }
184        }
185    }
186}
187
188/// Options for modifying the connection parameters
189#[derive(Debug)]
190pub struct DialOptions {
191    credentials: Option<RPCCredentials>,
192    webrtc_options: Option<Options>,
193    uri: Option<Parts>,
194    disable_mdns: bool,
195    allow_downgrade: bool,
196    insecure: bool,
197    signaling_server_override: Option<String>,
198}
/// Builder state: no uri has been provided yet.
#[derive(Clone)]
pub struct WantsUri(());
/// Builder state: a uri is set and the choice about credentials is still open.
#[derive(Clone)]
pub struct WantsCredentials(());
/// Builder state: credentials were provided.
#[derive(Clone)]
pub struct WithCredentials(());
/// Builder state: the connection is made without credentials.
#[derive(Clone)]
pub struct WithoutCredentials(());

/// Marker for the builder states in which the credential choice has been made.
pub trait AuthMethod {}
impl AuthMethod for WithoutCredentials {}
impl AuthMethod for WithCredentials {}
211/// A DialBuilder allows us to set options before establishing a connection to a server
212#[allow(dead_code)]
213pub struct DialBuilder<T> {
214    state: T,
215    config: DialOptions,
216}
217
218impl<T> fmt::Debug for DialBuilder<T> {
219    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
220        f.debug_struct("Dial")
221            .field("State", &format_args!("{}", &std::any::type_name::<T>()))
222            .field("Opt", &format_args!("{:?}", self.config))
223            .finish()
224    }
225}
226
227impl DialOptions {
228    /// Creates a new DialBuilder
229    pub fn builder() -> DialBuilder<WantsUri> {
230        DialBuilder {
231            state: WantsUri(()),
232            config: DialOptions {
233                credentials: None,
234                uri: None,
235                allow_downgrade: false,
236                disable_mdns: false,
237                insecure: false,
238                webrtc_options: None,
239                signaling_server_override: None,
240            },
241        }
242    }
243}
244
245impl DialBuilder<WantsUri> {
246    /// Sets the uri to connect to
247    pub fn uri(self, uri: &str) -> DialBuilder<WantsCredentials> {
248        let uri_parts = uri_parts_with_defaults(uri);
249        DialBuilder {
250            state: WantsCredentials(()),
251            config: DialOptions {
252                credentials: None,
253                uri: Some(uri_parts),
254                allow_downgrade: false,
255                disable_mdns: false,
256                insecure: false,
257                webrtc_options: None,
258                signaling_server_override: None,
259            },
260        }
261    }
262}
263impl DialBuilder<WantsCredentials> {
264    /// Tells connecting logic to not expect/require credentials
265    pub fn without_credentials(self) -> DialBuilder<WithoutCredentials> {
266        DialBuilder {
267            state: WithoutCredentials(()),
268            config: DialOptions {
269                credentials: None,
270                uri: self.config.uri,
271                allow_downgrade: false,
272                disable_mdns: false,
273                insecure: false,
274                webrtc_options: None,
275                signaling_server_override: None,
276            },
277        }
278    }
279    /// Sets credentials to use when connecting
280    pub fn with_credentials(self, creds: RPCCredentials) -> DialBuilder<WithCredentials> {
281        DialBuilder {
282            state: WithCredentials(()),
283            config: DialOptions {
284                credentials: Some(creds),
285                uri: self.config.uri,
286                allow_downgrade: false,
287                disable_mdns: false,
288                insecure: false,
289                webrtc_options: None,
290                signaling_server_override: None,
291            },
292        }
293    }
294}
295
296impl<T: AuthMethod> DialBuilder<T> {
297    /// Attempts to connect insecurely with scheme of HTTP as a default
298    pub fn insecure(mut self) -> Self {
299        self.config.insecure = true;
300        self
301    }
302    /// Allows for downgrading and attempting to connect via HTTP if HTTPS fails
303    pub fn allow_downgrade(mut self) -> Self {
304        self.config.allow_downgrade = true;
305        self
306    }
307    /// Disables connection via mDNS
308    pub fn disable_mdns(mut self) -> Self {
309        self.config.disable_mdns = true;
310        self
311    }
312
313    /// Overrides any default connection behavior, forcing direct connection. Note that
314    /// the connection itself will fail if it is between a client and server on separate
315    /// networks and not over webRTC
316    pub fn disable_webrtc(mut self) -> Self {
317        let webrtc_options = Options::default().disable_webrtc();
318        self.config.webrtc_options = Some(webrtc_options);
319        self
320    }
321
322    /// Forces ICE transport policy to relay-only so only TURN candidates are used.
323    /// Useful for testing relay connectivity through a TURN server.
324    pub fn force_relay(mut self) -> Self {
325        self.config
326            .webrtc_options
327            .get_or_insert_with(Options::default)
328            .force_relay = true;
329        self
330    }
331
332    /// Strips TURN servers from the ICE config so only host and server-reflexive
333    /// candidates are used. Useful for testing direct connectivity without relay fallback.
334    pub fn force_p2p(mut self) -> Self {
335        self.config
336            .webrtc_options
337            .get_or_insert_with(Options::default)
338            .force_p2p = true;
339        self
340    }
341
342    /// Filters the signaling server's TURN list to only the server whose parsed URI
343    /// matches (compared by scheme, host, port, and transport — defaulting transport
344    /// to UDP if unspecified). Example: "turn:turn.viam.com:443"
345    pub fn turn_uri(mut self, uri: String) -> Self {
346        self.config
347            .webrtc_options
348            .get_or_insert_with(Options::default)
349            .turn_uri = Some(uri);
350        self
351    }
352
353    /// Overrides the signaling server address used for WebRTC negotiation.
354    pub fn signaling_server(mut self, address: String) -> Self {
355        self.config.signaling_server_override = Some(address);
356        self
357    }
358
359    async fn get_addr_from_interface(
360        iface: (&str, Vec<&IpAddr>),
361        candidates: &Vec<String>,
362        local_ipv4s: &std::collections::HashSet<Ipv4Addr>,
363    ) -> Option<String> {
364        let addresses: Vec<Ipv4Addr> = iface
365            .1
366            .iter()
367            .filter_map(|ip| match ip {
368                IpAddr::V4(v4) => Some(*v4),
369                IpAddr::V6(_) => None,
370            })
371            .collect();
372
373        let mut resp: Option<Response> = None;
374        for ipv4 in addresses {
375            for candidate in candidates {
376                let discovery = match discover::interface_with_loopback(
377                    VIAM_MDNS_SERVICE_NAME,
378                    Duration::from_millis(250),
379                    ipv4,
380                ) {
381                    Ok(d) => d,
382                    Err(e) => {
383                        log::debug!("mDNS socket error on {ipv4}: {e}");
384                        continue;
385                    }
386                };
387                let stream = discovery.listen();
388                pin_mut!(stream);
389                while let Some(Ok(response)) = stream.next().await {
390                    if let Some(hostname) = response.hostname() {
391                        // Machine uris come in local ("my-cool-robot.abcdefg.local.viam.cloud")
392                        // and non-local ("my-cool-robot.abcdefg.viam.cloud") forms. Sometimes
393                        // (namely with micro-rdk), our mdns query can only see one (the local) version.
394                        // However, users are typically passing the non-local version. By splitting at
395                        // "viam" and taking the only the first value, we can still search for
396                        // candidates based on the actual "my-cool-robot" name without being opinionated
397                        // on whether the candidate is locally named or not.
398                        let local_agnostic_candidate = candidate.as_str().split("viam").next()?;
399                        log::debug!(
400                            "mDNS response on {ipv4}: hostname={hostname:?}, candidate={candidate:?}, local_agnostic={local_agnostic_candidate:?}, matches={}",
401                            hostname.contains(local_agnostic_candidate)
402                        );
403                        if hostname.contains(local_agnostic_candidate) {
404                            resp = Some(response);
405                            break;
406                        }
407                    } else {
408                        log::debug!(
409                            "mDNS response on {ipv4}: no hostname (no PTR record); answers={:?}",
410                            response.answers
411                        );
412                    }
413                    if resp.is_some() {
414                        break;
415                    }
416                }
417            }
418        }
419
420        let resp = resp?;
421        let mut has_grpc = false;
422        let mut has_webrtc = false;
423        for field in resp.txt_records() {
424            has_grpc = has_grpc || field.contains("grpc");
425            has_webrtc = has_webrtc || field.contains("webrtc");
426        }
427
428        // Log all records in the response for diagnostics.
429        log::debug!(
430            "mDNS matched response records: {:?}",
431            resp.records().collect::<Vec<_>>()
432        );
433
434        // Select the best IP from the mDNS response using a three-tier preference:
435        //
436        // 1. Non-loopback IP that is currently assigned to one of our own network
437        //    interfaces.  This handles the same-machine case (client and robot on
438        //    the same host) while avoiding stale IPs that were valid when
439        //    viam-server started but are now unreachable (e.g. a WiFi address after
440        //    WiFi was disconnected).
441        //
442        // 2. Any IP (including loopback) that is currently assigned to one of our
443        //    interfaces.  This catches 127.0.0.1 when offline on the same machine:
444        //    127.0.0.1 is excluded from tier 1 by the !is_loopback() guard, but it
445        //    is always in local_ipv4s, so it is correctly preferred here over a
446        //    stale non-loopback address that is no longer assigned.
447        //
448        // 3. Last resort: any advertised IPv4, for the common case of connecting to
449        //    a robot on a separate machine (its IP will never appear in local_ipv4s).
450        let ip_addr = resp
451            .records()
452            .filter_map(|r| match r.kind {
453                RecordKind::A(addr) if !addr.is_loopback() && local_ipv4s.contains(&addr) => {
454                    Some(addr)
455                }
456                _ => None,
457            })
458            .next()
459            .or_else(|| {
460                resp.records()
461                    .find_map(|r| match r.kind {
462                        RecordKind::A(addr) if local_ipv4s.contains(&addr) => Some(addr),
463                        _ => None,
464                    })
465                    .or_else(|| {
466                        resp.records().find_map(|r| match r.kind {
467                            RecordKind::A(addr) => Some(addr),
468                            _ => None,
469                        })
470                    })
471            });
472
473        if !(has_grpc || has_webrtc) || ip_addr.is_none() {
474            return None;
475        }
476        let mut local_addr = ip_addr?.to_string();
477        local_addr.push(':');
478        local_addr.push_str(&resp.port()?.to_string());
479        log::debug!("mDNS resolved address: {local_addr}");
480        Some(local_addr)
481    }
482
483    fn duplicate_uri(&self) -> Option<Parts> {
484        match &self.config.uri {
485            None => None,
486            Some(uri) => duplicate_uri(uri),
487        }
488    }
489
490    async fn get_mdns_uri(&self) -> Option<Parts> {
491        log::debug!("{}", log_prefixes::MDNS_QUERY_ATTEMPT);
492        if self.config.disable_mdns {
493            return None;
494        }
495
496        let mut uri = self.duplicate_uri()?;
497        let candidate = uri.authority.clone()?.to_string();
498
499        let candidates: Vec<String> = vec![candidate.replace('.', "-"), candidate];
500
501        let ifaces = list_afinet_netifas().ok()?;
502
503        // Collect all local IPv4 addresses for use in get_addr_from_interface, which prefers
504        // mDNS response IPs that are currently assigned to one of our own interfaces.
505        // viam-server may advertise a stale IP (e.g. a WiFi address from when it started,
506        // now unreachable because WiFi is off); filtering by current interface addresses
507        // avoids connecting to an unreachable address.
508        let local_ipv4s: std::collections::HashSet<Ipv4Addr> = ifaces
509            .iter()
510            .filter_map(|(_, ip)| match ip {
511                IpAddr::V4(v4) => Some(*v4),
512                _ => None,
513            })
514            .collect();
515
516        let ifaces: HashMap<&str, Vec<&IpAddr>> =
517            ifaces.iter().fold(HashMap::new(), |mut map, (k, v)| {
518                map.entry(k).or_default().push(v);
519                map
520            });
521
522        let mut iface_futures = FuturesUnordered::new();
523        for iface in ifaces {
524            iface_futures.push(Self::get_addr_from_interface(
525                iface,
526                &candidates,
527                &local_ipv4s,
528            ));
529        }
530
531        let mut local_addr: Option<String> = None;
532        while let Some(maybe_addr) = iface_futures.next().await {
533            if maybe_addr.is_some() {
534                local_addr = maybe_addr;
535                break;
536            }
537        }
538        let local_addr = match local_addr {
539            None => {
540                log::debug!("Unable to connect via mDNS");
541                return None;
542            }
543            Some(addr) => {
544                log::debug!("{}: {addr}", log_prefixes::MDNS_ADDRESS_FOUND);
545                addr
546            }
547        };
548
549        let auth = local_addr.parse::<Authority>().ok()?;
550        uri.authority = Some(auth);
551
552        Some(uri)
553    }
554
555    async fn create_channel(
556        allow_downgrade: bool,
557        domain: &str,
558        uri: Uri,
559        for_mdns: bool,
560    ) -> Result<Channel> {
561        if for_mdns {
562            let host = uri.host().unwrap_or("");
563            // viam-server serves TLS gRPC on the port it advertises via mDNS, including on
564            // loopback.  `domain` is the robot's canonical hostname (e.g.
565            // `my-robot.abcdefg.viam.cloud`) used for SNI and SAN verification.
566            log::debug!("mDNS create_channel: connecting to {host} with TLS");
567            let tls_config = ClientTlsConfig::new().domain_name(domain);
568            let mut parts = uri.into_parts();
569            parts.scheme = Some(Scheme::HTTPS);
570            let uri = Uri::from_parts(parts)?;
571            return Channel::builder(uri.clone())
572                .tls_config(tls_config)?
573                .connect()
574                .await
575                .with_context(|| format!("Connecting to {:?}", uri));
576        }
577
578        let chan = match Channel::builder(uri.clone())
579            .connect()
580            .await
581            .with_context(|| format!("Connecting to {:?}", uri.clone()))
582        {
583            Ok(c) => c,
584            Err(e) => {
585                if allow_downgrade {
586                    let mut uri_parts = uri.clone().into_parts();
587                    uri_parts.scheme = Some(Scheme::HTTP);
588                    let uri = Uri::from_parts(uri_parts)?;
589                    Channel::builder(uri).connect().await?
590                } else {
591                    return Err(anyhow::anyhow!(e));
592                }
593            }
594        };
595        Ok(chan)
596    }
597}
598
599impl DialBuilder<WithoutCredentials> {
600    fn clone(&self) -> Self {
601        DialBuilder {
602            state: WithoutCredentials(()),
603            config: DialOptions {
604                credentials: None,
605                webrtc_options: self.config.webrtc_options.clone(),
606                uri: self.duplicate_uri(),
607                disable_mdns: self.config.disable_mdns,
608                allow_downgrade: self.config.allow_downgrade,
609                insecure: self.config.insecure,
610                signaling_server_override: self.config.signaling_server_override.clone(),
611            },
612        }
613    }
614
615    /// attempts to establish a connection without credentials to the DialBuilder's given uri
616    async fn connect_inner(
617        self,
618        mdns_uri: Option<Parts>,
619        mut original_uri_parts: Parts,
620    ) -> Result<ViamChannel> {
621        let webrtc_options = self.config.webrtc_options;
622        let disable_webrtc = match &webrtc_options {
623            Some(options) => options.disable_webrtc,
624            None => false,
625        };
626        if self.config.insecure {
627            original_uri_parts.scheme = Some(Scheme::HTTP);
628        }
629        let original_uri = Uri::from_parts(original_uri_parts)?;
630        let uri2 = original_uri.clone();
631        let uri = infer_remote_uri_from_authority(
632            original_uri,
633            self.config.signaling_server_override.as_deref(),
634        );
635        let domain = uri2.authority().to_owned().unwrap().as_str();
636
637        let mdns_uri = mdns_uri.and_then(|p| Uri::from_parts(p).ok());
638        let attempting_mdns = mdns_uri.is_some();
639        if attempting_mdns {
640            log::debug!("Attempting to connect via mDNS");
641        } else {
642            log::debug!("Attempting to connect");
643        }
644
645        let channel = match mdns_uri {
646            Some(uri) => Self::create_channel(self.config.allow_downgrade, domain, uri, true).await,
647            // not actually an error necessarily, but we want to ensure that a channel is still
648            // created with the default uri
649            None => Err(anyhow::anyhow!("")),
650        };
651
652        let channel = match channel {
653            Ok(c) => {
654                log::debug!("Connected via mDNS");
655                c
656            }
657            Err(e) => {
658                if attempting_mdns {
659                    // mDNS found the robot but the connection failed — don't fall through to
660                    // the remote/signaling URI here.  The parallel `without_mdns` branch in
661                    // `connect()` is already handling that case.  Returning an error lets the
662                    // select! loop surface whichever branch succeeds, and avoids a spurious
663                    // connection attempt to app.viam.com when the device is offline.
664                    log::debug!("Unable to connect via mDNS. Error: {e:#}");
665                    return Err(e);
666                }
667                Self::create_channel(self.config.allow_downgrade, domain, uri.clone(), false)
668                    .await?
669            }
670        };
671
672        // TODO (RSDK-517) make maybe_connect_via_webrtc take a more generic type so we don't
673        // need to add these dummy layers.
674        let intercepted_channel = ServiceBuilder::new()
675            .layer(AddAuthorizationLayer::basic(
676                "fake username",
677                "fake password",
678            ))
679            .layer(SetRequestHeaderLayer::overriding(
680                HeaderName::from_static("rpc-host"),
681                HeaderValue::from_str(domain)?,
682            ))
683            .service(channel.clone());
684
685        // TODO (RSDK-14026): support WebRTC over mDNS for offline connections (e.g. video streaming).
686        if disable_webrtc || attempting_mdns {
687            log::debug!("{}", log_prefixes::DIALED_GRPC);
688            Ok(ViamChannel::Direct(channel.clone()))
689        } else {
690            match maybe_connect_via_webrtc(uri, intercepted_channel.clone(), webrtc_options).await {
691                Ok(webrtc_channel) => Ok(ViamChannel::WebRTC(webrtc_channel)),
692                Err(e) => {
693                    log::error!("error connecting via webrtc: {e}. Attempting to connect directly");
694                    log::debug!("{}", log_prefixes::DIALED_GRPC);
695                    Ok(ViamChannel::Direct(channel.clone()))
696                }
697            }
698        }
699    }
700
701    async fn connect_mdns(self, original_uri: Parts) -> Result<ViamChannel> {
702        let mdns_uri =
703            webrtc::action_with_timeout(self.get_mdns_uri(), Duration::from_millis(1500))
704                .await
705                .ok()
706                .flatten()
707                .ok_or(anyhow::anyhow!(
708                    "Unable to establish connection via mDNS; uri not found"
709                ))?;
710
711        self.connect_inner(Some(mdns_uri), original_uri).await
712    }
713
714    pub async fn connect(self) -> Result<ViamChannel> {
715        log::debug!("{}", log_prefixes::DIAL_ATTEMPT);
716        let original_uri = self.duplicate_uri().ok_or(anyhow::anyhow!(
717            "Attempting to connect but there was no uri"
718        ))?;
719        let original_uri2 = duplicate_uri(&original_uri).ok_or(anyhow::anyhow!(
720            "Attempting to connect but there was no uri"
721        ))?;
722
723        let skip_mdns = self.config.disable_mdns;
724
725        // We want to short circuit and return the first `Ok` result from our connection
726        // attempts, which `tokio::select!` does great. Buuuuut, we don't want to
727        // abandon the `Err` results, and we want to provide comprehensive logging for
728        // debugging purposes. Hence the loop and pinning. The pinning lets us reference
729        // the same future multiple times, while the loop lets us immediately return on the
730        // first `Ok` result while still seeing and logging any error results.
731        //
732        // When mDNS is skipped (disable_mdns), with_mdns_err is pre-set so
733        // the select guard disables that branch and only the direct connection is attempted.
734        tokio::pin! {
735            let with_mdns = self.clone().connect_mdns(original_uri);
736            let without_mdns = self.connect_inner(None, original_uri2);
737        }
738        let mut with_mdns_err: Option<anyhow::Error> =
739            skip_mdns.then(|| anyhow::anyhow!("mDNS skipped"));
740        let mut without_mdns_err: Option<anyhow::Error> = None;
741        while with_mdns_err.is_none() || without_mdns_err.is_none() {
742            tokio::select! {
743                with_mdns = &mut with_mdns, if with_mdns_err.is_none() => {
744                    match with_mdns {
745                        Ok(chan) => return Ok(chan),
746                        Err(e) => {
747                            log::debug!("Error connecting with mdns: {e}");
748                            with_mdns_err = Some(e);
749                        }
750                    }
751                }
752                without_mdns = &mut without_mdns, if without_mdns_err.is_none() => {
753                    match without_mdns {
754                        Ok(chan) => return Ok(chan),
755                        Err(e) => {
756                            log::debug!("Error connecting without mdns: {e}");
757                            without_mdns_err = Some(e);
758                        }
759                    }
760                }
761            }
762        }
763        Err(anyhow::anyhow!(
764            "Unable to connect with or without mdns.
765                    with_mdns err: {with_mdns_err:?}
766                    without_mdns err: {without_mdns_err:?}"
767        ))
768    }
769}
770
771async fn get_auth_token(
772    channel: &mut Channel,
773    creds: Credentials,
774    entity: String,
775) -> Result<String> {
776    let mut auth_service = AuthServiceClient::new(channel);
777    let req = AuthenticateRequest {
778        entity,
779        credentials: Some(creds),
780    };
781
782    let rsp = auth_service.authenticate(req).await?;
783    Ok(rsp.into_inner().access_token)
784}
785
786impl DialBuilder<WithCredentials> {
787    fn clone(&self) -> Self {
788        DialBuilder {
789            state: WithCredentials(()),
790            config: DialOptions {
791                credentials: self.config.credentials.clone(),
792                webrtc_options: self.config.webrtc_options.clone(),
793                uri: self.duplicate_uri(),
794                disable_mdns: self.config.disable_mdns,
795                allow_downgrade: self.config.allow_downgrade,
796                insecure: self.config.insecure,
797                signaling_server_override: self.config.signaling_server_override.clone(),
798            },
799        }
800    }
801
802    async fn connect_inner(
803        self,
804        mdns_uri: Option<Parts>,
805        mut original_uri_parts: Parts,
806    ) -> Result<ViamChannel> {
807        let is_insecure = self.config.insecure;
808
809        let webrtc_options = self.config.webrtc_options;
810        let disable_webrtc = match &webrtc_options {
811            Some(options) => options.disable_webrtc,
812            None => false,
813        };
814
815        if is_insecure {
816            original_uri_parts.scheme = Some(Scheme::HTTP);
817        }
818
819        let original_uri = Uri::from_parts(original_uri_parts)?;
820
821        let domain = original_uri.authority().unwrap().to_string();
822        let uri_for_auth = infer_remote_uri_from_authority(
823            original_uri.clone(),
824            self.config.signaling_server_override.as_deref(),
825        );
826
827        let mdns_uri = mdns_uri.and_then(|p| Uri::from_parts(p).ok());
828        let attempting_mdns = mdns_uri.is_some();
829
830        let allow_downgrade = self.config.allow_downgrade;
831        if attempting_mdns {
832            log::debug!("Attempting to connect via mDNS");
833        } else {
834            log::debug!("Attempting to connect");
835        }
836        let channel = match mdns_uri {
837            Some(uri) => Self::create_channel(allow_downgrade, &domain, uri, true).await,
838            // not actually an error necessarily, but we want to ensure that a channel is still
839            // created with the default uri
840            None => Err(anyhow::anyhow!("")),
841        };
842        let real_channel = match channel {
843            Ok(c) => {
844                log::debug!("Connected via mDNS");
845                c
846            }
847            Err(e) => {
848                if attempting_mdns {
849                    // mDNS found the robot but the connection failed — don't fall through to
850                    // the remote/signaling URI here.  The parallel `without_mdns` branch in
851                    // `connect()` is already handling that case.  Returning an error lets the
852                    // select! loop surface whichever branch succeeds, and avoids a spurious
853                    // auth attempt against app.viam.com when the device is offline.
854                    log::debug!("Unable to connect via mDNS. Error: {e:#}");
855                    return Err(e);
856                }
857                Self::create_channel(allow_downgrade, &domain, uri_for_auth, false).await?
858            }
859        };
860
861        log::debug!("{}", log_prefixes::ACQUIRING_AUTH_TOKEN);
862        let token = get_auth_token(
863            &mut real_channel.clone(),
864            self.config
865                .credentials
866                .as_ref()
867                .unwrap()
868                .credentials
869                .clone(),
870            self.config
871                .credentials
872                .unwrap()
873                .entity
874                .unwrap_or_else(|| domain.clone()),
875        )
876        .await?;
877        log::debug!("{}", log_prefixes::ACQUIRED_AUTH_TOKEN);
878
879        let channel = ServiceBuilder::new()
880            .layer(AddAuthorizationLayer::bearer(&token))
881            .layer(SetRequestHeaderLayer::overriding(
882                HeaderName::from_static("rpc-host"),
883                HeaderValue::from_str(domain.as_str())?,
884            ))
885            .service(real_channel);
886
887        // TODO (RSDK-14026): support WebRTC over mDNS for offline connections (e.g. video streaming).
888        if disable_webrtc || attempting_mdns {
889            log::debug!("Connected via gRPC");
890            Ok(ViamChannel::DirectPreAuthorized(channel))
891        } else {
892            match maybe_connect_via_webrtc(original_uri, channel.clone(), webrtc_options).await {
893                Ok(webrtc_channel) => Ok(ViamChannel::WebRTC(webrtc_channel)),
894                Err(e) => {
895                    log::error!(
896                    "Unable to establish webrtc connection due to error: [{e}]. Attempting direct connection."
897                );
898                    log::debug!("Connected via gRPC");
899                    Ok(ViamChannel::DirectPreAuthorized(channel))
900                }
901            }
902        }
903    }
904
905    async fn connect_mdns(self, original_uri: Parts) -> Result<ViamChannel> {
906        // NOTE(benjirewis): Use a duration of 1500ms for getting the mDNS URI. I've anecdotally
907        // seen times as great as 922ms to fetch a non-loopback mDNS URI. With an
908        // interface_with_loopback query interval of 250ms, 1500ms here should give us time for ~6
909        // queries.
910        let mdns_uri =
911            webrtc::action_with_timeout(self.get_mdns_uri(), Duration::from_millis(1500))
912                .await
913                .ok()
914                .flatten()
915                .ok_or(anyhow::anyhow!(
916                    "Unable to establish connection via mDNS; uri not found"
917                ))?;
918
919        self.connect_inner(Some(mdns_uri), original_uri).await
920    }
921
922    /// attempts to establish a connection with credentials to the DialBuilder's given uri
923    pub async fn connect(self) -> Result<ViamChannel> {
924        log::debug!("{}", log_prefixes::DIAL_ATTEMPT);
925        let original_uri = self.duplicate_uri().ok_or(anyhow::anyhow!(
926            "Attempting to connect but there was no uri"
927        ))?;
928        let original_uri2 = duplicate_uri(&original_uri).ok_or(anyhow::anyhow!(
929            "Attempting to connect but there was no uri"
930        ))?;
931
932        let skip_mdns = self.config.disable_mdns;
933
934        // We want to short circuit and return the first `Ok` result from our connection
935        // attempts, which `tokio::select!` does great. Buuuuut, we don't want to
936        // abandon the `Err` results, and we want to provide comprehensive logging for
937        // debugging purposes. Hence the loop and pinning. The pinning lets us reference
938        // the same future multiple times, while the loop lets us immediately return on the
939        // first `Ok` result while still seeing and logging any error results.
940        //
941        // When mDNS is skipped (disable_mdns), with_mdns_err is pre-set so
942        // the select guard disables that branch and only the direct connection is attempted.
943        tokio::pin! {
944            let with_mdns = self.clone().connect_mdns(original_uri);
945            let without_mdns = self.connect_inner(None, original_uri2);
946        }
947        let mut with_mdns_err: Option<anyhow::Error> =
948            skip_mdns.then(|| anyhow::anyhow!("mDNS skipped"));
949        let mut without_mdns_err: Option<anyhow::Error> = None;
950        while with_mdns_err.is_none() || without_mdns_err.is_none() {
951            tokio::select! {
952                with_mdns = &mut with_mdns, if with_mdns_err.is_none() => {
953                    match with_mdns {
954                        Ok(chan) => return Ok(chan),
955                        Err(e) => {
956                            log::debug!("Error connecting with mdns: {e}");
957                            with_mdns_err = Some(e);
958                        }
959                    }
960                }
961                without_mdns = &mut without_mdns, if without_mdns_err.is_none() => {
962                    match without_mdns {
963                        Ok(chan) => return Ok(chan),
964                        Err(e) => {
965                            log::debug!("Error connecting without mdns: {e}");
966                            without_mdns_err = Some(e);
967                        }
968                    }
969                }
970            }
971        }
972        Err(anyhow::anyhow!(
973            "Unable to connect with or without mdns.
974                    with_mdns err: {with_mdns_err:?}
975                    without_mdns err: {without_mdns_err:?}"
976        ))
977    }
978}
979
980async fn send_done_or_error_update(
981    update: CallUpdateRequest,
982    channel: AddAuthorization<SetRequestHeader<Channel, HeaderValue>>,
983) {
984    let mut signaling_client = SignalingServiceClient::new(channel.clone());
985
986    if let Err(e) = signaling_client
987        .call_update(update)
988        .await
989        .map_err(anyhow::Error::from)
990        .map(|_| ())
991    {
992        log::error!("Error sending done or error update: {e}")
993    }
994}
995
996async fn send_error_once(
997    sent_error: Arc<AtomicBool>,
998    uuid: &String,
999    err: &anyhow::Error,
1000    channel: AddAuthorization<SetRequestHeader<Channel, HeaderValue>>,
1001) {
1002    if sent_error.load(Ordering::Acquire) {
1003        return;
1004    }
1005
1006    let err = google::rpc::Status {
1007        code: google::rpc::Code::Unknown.into(),
1008        message: err.to_string(),
1009        details: Vec::new(),
1010    };
1011    sent_error.store(true, Ordering::Release);
1012    let update_request = CallUpdateRequest {
1013        uuid: uuid.to_string(),
1014        update: Some(Update::Error(err)),
1015    };
1016
1017    send_done_or_error_update(update_request, channel).await
1018}
1019
1020async fn send_done_once(
1021    sent_done: Arc<AtomicBool>,
1022    uuid: &String,
1023    channel: AddAuthorization<SetRequestHeader<Channel, HeaderValue>>,
1024) {
1025    if sent_done.load(Ordering::Acquire) {
1026        return;
1027    }
1028    sent_done.store(true, Ordering::Release);
1029    let update_request = CallUpdateRequest {
1030        uuid: uuid.to_string(),
1031        update: Some(Update::Done(true)),
1032    };
1033
1034    send_done_or_error_update(update_request, channel).await
1035}
1036
/// Accumulated timing statistics for caller update requests.
#[derive(Default)]
struct CallerUpdateStats {
    /// Number of updates recorded.
    count: u128,
    /// Sum of the durations of all recorded updates.
    total_duration: Duration,
    /// Longest single recorded update.
    max_duration: Duration,
}

impl fmt::Display for CallerUpdateStats {
    /// Writes a one-line summary (terminated by a newline) with the update count, the average
    /// duration, and the maximum duration, both in whole milliseconds.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // `count` is zero until the first update is recorded; report a 0ms average rather
        // than panicking on a division by zero.
        let average_duration = self
            .total_duration
            .as_millis()
            .checked_div(self.count)
            .unwrap_or(0);
        writeln!(
            f,
            "Caller update statistics: num_updates: {}, average_duration: {}ms, max_duration: {}ms",
            self.count,
            average_duration,
            self.max_duration.as_millis()
        )
    }
}
1057
/// Attempts to establish a WebRTC client channel, using `channel` to reach the signaling
/// service.
///
/// The function:
/// 1. resolves the WebRTC options (inferring them from `uri` when `webrtc_options` is `None`)
///    and asks the signaling service for its optional WebRTC config;
/// 2. applies the ICE policy and TURN options and creates the peer connection and data
///    channel;
/// 3. unless trickle ICE is disabled, registers callbacks that forward locally gathered ICE
///    candidates to the signaling service and sets the local description;
/// 4. sends a `CallRequest` carrying the encoded local session description and spawns a task
///    that processes the call's `Init` and `Update` responses;
/// 5. waits (with a timeout) for the data channel to open or for an error to be reported, then
///    sends a "done" update and returns the client channel.
///
/// # Errors
/// Returns an error if the optional config request fails with anything other than
/// `Unimplemented`, if creating the peer connection / offer / local description fails, if the
/// SDP cannot be encoded, if the `call` request fails, if any signaling step reports an error
/// through the `is_open` channel, or if waiting for the data channel times out.
///
/// # Panics
/// Panics if the peer connection has no local description after setup, or if one of the
/// standard-library locks used here is poisoned.
async fn maybe_connect_via_webrtc(
    uri: Uri,
    channel: AddAuthorization<SetRequestHeader<Channel, HeaderValue>>,
    webrtc_options: Option<Options>,
) -> Result<Arc<WebRTCClientChannel>> {
    let webrtc_options = webrtc_options.unwrap_or_else(|| Options::infer_from_uri(uri.clone()));
    let mut signaling_client = SignalingServiceClient::new(channel.clone());
    // An `Unimplemented` status is treated as "no optional config"; any other failure aborts.
    let response = match signaling_client
        .optional_web_rtc_config(OptionalWebRtcConfigRequest::default())
        .await
    {
        Ok(resp) => resp,
        Err(e) => {
            if e.code() == tonic::Code::Unimplemented {
                tonic::Response::new(OptionalWebRtcConfigResponse::default())
            } else {
                return Err(anyhow::anyhow!(e));
            }
        }
    };

    let optional_config = response.into_inner().config;

    if webrtc_options.force_relay && webrtc_options.force_p2p {
        log::warn!(
            "force_relay and force_p2p are both set; forceP2P strips TURN servers that forceRelay requires so the connection will fail");
    }

    let (base_config, optional_config) = webrtc::apply_ice_policy(
        webrtc_options.config,
        optional_config,
        webrtc_options.force_relay,
        webrtc_options.force_p2p,
    );

    if webrtc_options.force_relay {
        log::debug!("force relay enabled; using relay-only ICE transport policy");
    }

    if webrtc_options.force_p2p {
        log::debug!(
            "force P2P enabled; stripping TURN servers and ignoring signaling server ICE config"
        );
    }

    let mut config = webrtc::extend_webrtc_config(base_config, optional_config);

    if webrtc_options.force_p2p && webrtc_options.turn_uri.is_some() {
        log::warn!("force_p2p is set alongside turn_uri; the TURN filter will have no effect since TURN servers were already stripped");
    }
    // An unparseable `turn_uri` is logged and ignored rather than treated as an error.
    let turn_uri = webrtc_options.turn_uri.as_deref().and_then(|s| {
        let parsed = webrtc::TurnUri::parse(s);
        if parsed.is_none() {
            log::warn!("Failed to parse turn_uri, ignoring: {s:?}");
        }
        parsed
    });
    config = webrtc::apply_turn_options(config, turn_uri.as_ref());
    if let Some(ref uri) = turn_uri {
        log::debug!("TURN filter options set: turn_uri={uri:?}");
    }

    let (peer_connection, data_channel) =
        webrtc::new_peer_connection_for_client(config, webrtc_options.disable_trickle_ice).await?;

    // Shared by every path that sends a "done" or "error" update for this call.
    let sent_done_or_error = Arc::new(AtomicBool::new(false));
    // Holds the call UUID; it stays empty until the `Init` response writes it.
    let uuid_lock = Arc::new(RwLock::new("".to_string()));
    let uuid_for_ice_gathering_thread = uuid_lock.clone();

    // Using an mpsc channel to report unrecoverable errors during Signaling, so we
    // don't have to wait until the timeout expires before giving up on this attempt.
    // The size of the channel is set to 1 since any error (or success) should terminate the function
    let (is_open_s, mut is_open_r) = mpsc::channel(1);
    let on_open_is_open = is_open_s.clone();

    // `None` on the channel means the data channel opened; `Some(error)` means a failure.
    data_channel.on_open(Box::new(move || {
        let _ = on_open_is_open.try_send(None); // ignore sending errors, either an error (or success) was already sent or the operation will succeed
        Box::pin(async move {})
    }));

    // Once set, `on_ice_candidate` invocations return immediately.
    let exchange_done = Arc::new(AtomicBool::new(false));
    let (remote_description_set_s, remote_description_set_r) = watch::channel(None);
    // Notified when local ICE gathering reports its final (`None`) candidate.
    let ice_done = Arc::new(tokio::sync::Notify::new());
    let ice_done2 = ice_done.clone();
    let caller_update_stats = Arc::new(Mutex::new(CallerUpdateStats::default()));

    if !webrtc_options.disable_trickle_ice {
        let offer = peer_connection.create_offer(None).await?;
        let channel2 = channel.clone();
        let uuid_lock2 = uuid_lock.clone();
        let sent_done_or_error2 = sent_done_or_error.clone();

        let exchange_done = exchange_done.clone();

        let on_local_ice_candidate_failure = is_open_s.clone();

        let caller_update_stats = caller_update_stats.clone();
        let caller_update_stats2 = caller_update_stats.clone();
        // Log the accumulated caller update statistics once ICE reaches `Completed`.
        peer_connection.on_ice_connection_state_change(Box::new(
            move |state: RTCIceConnectionState| {
                let caller_update_stats = caller_update_stats.clone();
                Box::pin(async move {
                    if state == RTCIceConnectionState::Completed {
                        let caller_update_stats_inner = caller_update_stats.lock().unwrap();
                        log::debug!("{}", caller_update_stats_inner);
                    }
                })
            },
        ));
        // Forward each locally gathered candidate to the signaling service; a `None`
        // candidate marks the end of gathering.
        peer_connection.on_ice_candidate(Box::new(
            move |ice_candidate: Option<RTCIceCandidate>| {
                if exchange_done.load(Ordering::Acquire) {
                    return Box::pin(async move {});
                }
                let channel = channel2.clone();
                let sent_done_or_error = sent_done_or_error2.clone();
                let ice_done = ice_done.clone();
                let uuid_lock = uuid_lock2.clone();
                let on_local_ice_candidate_failure = on_local_ice_candidate_failure.clone();
                let mut remote_description_set_r = remote_description_set_r.clone();
                let caller_update_stats = caller_update_stats2.clone();
                Box::pin(async move {
                    // If the value in the watch channel has not been set yet, we wait until it is.
                    // Afterwards Some(()) should be visible to all watchers and any watcher that is
                    // waiting will return.
                    if remote_description_set_r.borrow().is_none() {
                        match webrtc_action_with_timeout(remote_description_set_r.changed()).await {
                            Ok(Err(e)) => {
                                let _ = on_local_ice_candidate_failure.try_send(Some(Box::new(
                                    anyhow::anyhow!(
                                        "remote description watch channel is closed with error {e}"
                                    ),
                                )));
                            }
                            Err(_) => {
                                log::info!(
                                    "timed out on_ice_candidate; remote description was never set"
                                );
                                let _ = on_local_ice_candidate_failure.try_send(Some(Box::new(
                                    anyhow::anyhow!("timed out waiting for remote description"),
                                )));
                            }
                            _ => (),
                        }
                    }

                    let uuid = uuid_lock.read().unwrap().to_string();
                    // Note(ethan): for reasons that aren't entirely clear to me, parallel dialing
                    // occasionally causes us to not receive a signaling client response when
                    // trying to establish a connection. This results in noisy error messages that
                    // fortunately are harmless (this problem seems to only ever affect one branch
                    // of the parallel dial, so we still end up with a successful connection).
                    // By checking if the `uuid` is empty, we can tell if we're in such a case and
                    // exit out before it results in logging noisy error messages.
                    //
                    // It would be lovely to understand this problem better, but given that it's
                    // not actually causing performance failures it's probably not worth the effort
                    // at this time.
                    if uuid.is_empty() {
                        log::debug!(
                            "UUID never updated. This is likely because we never received a response \
                            from the signaling client. This happens occasionally with parallel dialing \
                            and isn't concerning provided connection still occurs."
                        );
                        return;
                    }
                    let mut signaling_client = SignalingServiceClient::new(channel.clone());
                    match ice_candidate {
                        Some(ice_candidate) => {
                            log::debug!("Gathered local candidate of {ice_candidate}");
                            if sent_done_or_error.load(Ordering::Acquire) {
                                return;
                            }
                            let proto_candidate = ice_candidate_to_proto(ice_candidate).await;
                            match proto_candidate {
                                Ok(proto_candidate) => {
                                    let update_request = CallUpdateRequest {
                                        uuid: uuid.clone(),
                                        update: Some(Update::Candidate(proto_candidate)),
                                    };
                                    let call_update_start = Instant::now();
                                    if let Err(e) = webrtc_action_with_timeout(
                                        signaling_client.call_update(update_request),
                                    )
                                    .await
                                    .and_then(|resp| resp.map_err(anyhow::Error::from))
                                    {
                                        log::error!("Error sending ice candidate: {e}");
                                        let _ = on_local_ice_candidate_failure.try_send(Some(
                                            Box::new(anyhow::anyhow!(
                                                "Error sending ice candidate: {e}"
                                            )),
                                        ));
                                    }
                                    // The update is counted and timed whether or not it succeeded.
                                    let mut caller_update_stats_inner =
                                        caller_update_stats.lock().unwrap();
                                    caller_update_stats_inner.count += 1;
                                    let call_update_duration = call_update_start.elapsed();
                                    if call_update_duration > caller_update_stats_inner.max_duration
                                    {
                                        caller_update_stats_inner.max_duration =
                                            call_update_duration;
                                    }
                                    caller_update_stats_inner.total_duration +=
                                        call_update_duration;
                                }
                                Err(e) => log::error!("Error parsing ice candidate: {e}"),
                            }
                        }
                        None => {
                            // will only be executed once when gathering is finished
                            ice_done.notify_one();
                            send_done_once(sent_done_or_error, &uuid, channel.clone()).await;
                        }
                    }
                })
            },
        ));

        peer_connection.set_local_description(offer).await?;
    }

    let local_description = peer_connection.local_description().await.unwrap();

    // Local SD will be multi-line, so use two log messages to indicate start, SD and end.
    log::debug!(
        "{}\n{}",
        log_prefixes::START_LOCAL_SESSION_DESCRIPTION,
        local_description.sdp
    );
    log::debug!("{}", log_prefixes::END_LOCAL_SESSION_DESCRIPTION);

    let sdp = encode_sdp(local_description)?;
    let call_request = CallRequest {
        sdp,
        disable_trickle: webrtc_options.disable_trickle_ice,
    };

    let client_channel = WebRTCClientChannel::new(peer_connection, data_channel).await;
    // The spawned task only holds a weak reference and stops once the channel can no longer
    // be upgraded.
    let client_channel_for_ice_gathering_thread = Arc::downgrade(&client_channel);
    let mut signaling_client = SignalingServiceClient::new(channel.clone());
    let mut call_client = signaling_client.call(call_request).await?.into_inner();

    let channel2 = channel.clone();
    let sent_done_or_error2 = sent_done_or_error.clone();
    // Background task that consumes the call's response stream: `Init` carries the remote
    // answer and the call UUID, `Update` carries remote ICE candidates.
    tokio::spawn(async move {
        let uuid = uuid_for_ice_gathering_thread;
        let client_channel = client_channel_for_ice_gathering_thread;
        let init_received = AtomicBool::new(false);
        let sent_done = sent_done_or_error2;

        loop {
            let response = match webrtc_action_with_timeout(call_client.message())
                .await
                .and_then(|resp| resp.map_err(anyhow::Error::from))
            {
                Ok(cr) => match cr {
                    Some(cr) => cr,
                    None => {
                        // want to delay sending done until we either are actually done, or
                        // we hit a timeout
                        let _ = webrtc_action_with_timeout(ice_done2.notified()).await;
                        let uuid = uuid.read().unwrap().to_string();
                        send_done_once(sent_done.clone(), &uuid, channel2.clone()).await;
                        break;
                    }
                },
                Err(e) => {
                    log::error!("Error processing call response: {e}");
                    let _ = is_open_s.try_send(Some(Box::new(e)));
                    break;
                }
            };

            match response.stage {
                Some(Stage::Init(init)) => {
                    if init_received.load(Ordering::Acquire) {
                        let uuid = uuid.read().unwrap().to_string();
                        let e = anyhow::anyhow!("Init received more than once");
                        send_error_once(sent_done.clone(), &uuid, &e, channel2.clone()).await;
                        let _ = is_open_s.try_send(Some(Box::new(e)));
                        break;
                    }
                    init_received.store(true, Ordering::Release);
                    {
                        let mut uuid_s = uuid.write().unwrap();
                        uuid_s.clone_from(&response.uuid);
                    }

                    let answer = match decode_sdp(init.sdp) {
                        Ok(a) => a,
                        Err(e) => {
                            send_error_once(
                                sent_done.clone(),
                                &response.uuid,
                                &e,
                                channel2.clone(),
                            )
                            .await;
                            let _ = is_open_s.try_send(Some(Box::new(e)));
                            break;
                        }
                    };
                    {
                        let cc = match client_channel.upgrade() {
                            Some(cc) => cc,
                            None => {
                                break;
                            }
                        };
                        if let Err(e) = cc
                            .base_channel
                            .peer_connection
                            .set_remote_description(answer)
                            .await
                        {
                            let e = anyhow::Error::from(e);
                            send_error_once(
                                sent_done.clone(),
                                &response.uuid,
                                &e,
                                channel2.clone(),
                            )
                            .await;
                            let _ = is_open_s.try_send(Some(Box::new(e)));
                            break;
                        }
                    }
                    // Wake any `on_ice_candidate` invocation waiting for the remote description.
                    let _ = remote_description_set_s.send_replace(Some(()));
                    if webrtc_options.disable_trickle_ice {
                        send_done_once(sent_done.clone(), &response.uuid, channel2.clone()).await;
                        break;
                    }
                }

                Some(Stage::Update(update)) => {
                    let uuid_s = uuid.read().unwrap().to_string();
                    if !init_received.load(Ordering::Acquire) {
                        let e = anyhow::anyhow!("Got update before init stage");
                        send_error_once(sent_done.clone(), &uuid_s, &e, channel2.clone()).await;
                        let _ = is_open_s.try_send(Some(Box::new(e)));
                        break;
                    }

                    if response.uuid != *uuid.read().unwrap() {
                        let e = anyhow::anyhow!(
                            "uuid mismatch: have {}, want {}",
                            response.uuid,
                            uuid_s,
                        );
                        send_error_once(sent_done.clone(), &uuid_s, &e, channel2.clone()).await;
                        let _ = is_open_s.try_send(Some(Box::new(e)));
                        break;
                    }
                    match ice_candidate_from_proto(update.candidate) {
                        Ok(candidate) => {
                            let client_channel = match client_channel.upgrade() {
                                Some(cc) => cc,
                                None => {
                                    break;
                                }
                            };
                            log::debug!("Received remote ICE candidate of {candidate:#?}");
                            if let Err(e) = client_channel
                                .base_channel
                                .peer_connection
                                .add_ice_candidate(candidate)
                                .await
                            {
                                let e = anyhow::Error::from(e);
                                send_error_once(sent_done.clone(), &uuid_s, &e, channel2.clone())
                                    .await;
                                let _ = is_open_s.try_send(Some(Box::new(e)));
                                break;
                            }
                        }
                        Err(e) => log::error!("Error parsing ice candidate: {e}"),
                    }
                }
                None => continue,
            }
        }
    });

    // TODO (GOUT-11): create separate authorization if external_auth_addr and/or creds.Type is `Some`

    // Delay returning the client channel until data channel is open, so we don't lose messages
    let is_open = webrtc_action_with_timeout(is_open_r.recv()).await;
    match is_open {
        Ok(is_open) => {
            if let Some(Some(e)) = is_open {
                return Err(anyhow::anyhow!("Couldn't connect to peer with error {e}"));
            }
        }
        Err(_) => {
            return Err(anyhow::anyhow!("Timed out opening data channel."));
        }
    }

    // From here on `on_ice_candidate` invocations return without forwarding anything.
    exchange_done.store(true, Ordering::Release);
    let uuid = uuid_lock.read().unwrap().to_string();
    send_done_once(sent_done_or_error, &uuid, channel.clone()).await;
    Ok(client_channel)
}
1462
1463async fn ice_candidate_to_proto(ice_candidate: RTCIceCandidate) -> Result<IceCandidate> {
1464    let ice_candidate = ice_candidate.to_json()?;
1465    Ok(IceCandidate {
1466        candidate: ice_candidate.candidate,
1467        sdp_mid: ice_candidate.sdp_mid,
1468        sdpm_line_index: ice_candidate.sdp_mline_index.map(u32::from),
1469        username_fragment: ice_candidate.username_fragment,
1470    })
1471}
1472
1473fn ice_candidate_from_proto(proto: Option<IceCandidate>) -> Result<RTCIceCandidateInit> {
1474    match proto {
1475        Some(proto) => {
1476            let proto_sdpm: usize = proto.sdpm_line_index().try_into()?;
1477            let sdp_mline_index: Option<u16> = proto_sdpm.try_into().ok();
1478
1479            Ok(RTCIceCandidateInit {
1480                candidate: proto.candidate.clone(),
1481                sdp_mid: Some(proto.sdp_mid().to_string()),
1482                sdp_mline_index,
1483                username_fragment: Some(proto.username_fragment().to_string()),
1484            })
1485        }
1486        None => Err(anyhow::anyhow!("No ice candidate provided")),
1487    }
1488}
1489
1490fn decode_sdp(sdp: String) -> Result<RTCSessionDescription> {
1491    let sdp = String::from_utf8(base64::decode(sdp)?)?;
1492    Ok(serde_json::from_str::<RTCSessionDescription>(&sdp)?)
1493}
1494
1495fn encode_sdp(sdp: RTCSessionDescription) -> Result<String> {
1496    let sdp = serde_json::to_vec(&sdp)?;
1497    Ok(base64::encode(sdp))
1498}
1499
1500fn infer_remote_uri_from_authority(uri: Uri, override_addr: Option<&str>) -> Uri {
1501    if let Some(addr) = override_addr {
1502        return Uri::from_parts(uri_parts_with_defaults(addr)).unwrap_or_else(|e| {
1503            log::warn!("Failed to parse signaling server override {addr:?}: {e}; falling back to original URI");
1504            uri
1505        });
1506    }
1507    let authority = uri.authority().map(Authority::as_str).unwrap_or_default();
1508    let is_local_connection = authority.contains(".local.viam.cloud")
1509        || authority.contains("localhost")
1510        || authority.contains("0.0.0.0")
1511        || authority.contains("127.0.0.1");
1512
1513    if !is_local_connection {
1514        if let Some((new_uri, _)) = Options::infer_signaling_server_address(&uri) {
1515            return Uri::from_parts(uri_parts_with_defaults(&new_uri)).unwrap_or(uri);
1516        }
1517    }
1518    uri
1519}
1520
1521fn duplicate_uri(parts: &Parts) -> Option<Parts> {
1522    let uri = Uri::builder()
1523        .authority(parts.authority.clone()?)
1524        .path_and_query(parts.path_and_query.clone()?)
1525        .scheme(parts.scheme.clone()?);
1526    Some(uri.build().ok()?.into_parts())
1527}
1528
1529fn uri_parts_with_defaults(uri: &str) -> Parts {
1530    let mut uri_parts = uri.parse::<Uri>().unwrap().into_parts();
1531    uri_parts.scheme = Some(Scheme::HTTPS);
1532    uri_parts.path_and_query = Some(PathAndQuery::from_static(""));
1533    uri_parts
1534}
1535
1536fn metadata_from_parts(parts: &http::request::Parts) -> Metadata {
1537    let mut md = HashMap::new();
1538    for (k, v) in parts.headers.iter() {
1539        let k = k.to_string();
1540        let v = Strings {
1541            values: vec![HeaderValue::to_str(v).unwrap().to_string()],
1542        };
1543        md.insert(k, v);
1544    }
1545    Metadata { md }
1546}