viam_rust_utils/rpc/dial.rs

1use super::{
2    client_channel::*,
3    log_prefixes,
4    webrtc::{webrtc_action_with_timeout, Options},
5};
6use crate::gen::google;
7use crate::gen::proto::rpc::v1::{
8    auth_service_client::AuthServiceClient, AuthenticateRequest, Credentials,
9};
10use crate::gen::proto::rpc::webrtc::v1::{
11    call_response::Stage, call_update_request::Update,
12    signaling_service_client::SignalingServiceClient, CallUpdateRequest,
13    OptionalWebRtcConfigRequest, OptionalWebRtcConfigResponse,
14};
15use crate::gen::proto::rpc::webrtc::v1::{
16    CallRequest, IceCandidate, Metadata, RequestHeaders, Strings,
17};
18use crate::rpc::webrtc;
19use ::http::header::HeaderName;
20use ::http::{
21    uri::{Authority, Parts, PathAndQuery, Scheme},
22    HeaderValue, Version,
23};
24use ::viam_mdns::{discover, Response};
25use ::webrtc::ice_transport::{
26    ice_candidate::{RTCIceCandidate, RTCIceCandidateInit},
27    ice_connection_state::RTCIceConnectionState,
28};
29use ::webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
30use anyhow::{Context, Result};
31use core::fmt;
32use futures::stream::FuturesUnordered;
33use futures_util::{pin_mut, stream::StreamExt};
34use local_ip_address::list_afinet_netifas;
35use std::{
36    collections::HashMap,
37    net::{IpAddr, Ipv4Addr},
38    sync::{
39        atomic::{AtomicBool, Ordering},
40        Arc, Mutex, RwLock,
41    },
42    task::{Context as TaskContext, Poll},
43    time::{Duration, Instant},
44};
45use tokio::sync::{mpsc, watch};
46use tonic::codegen::BoxFuture;
47use tonic::transport::{Body, Channel, Uri};
48use tonic::{body::BoxBody, transport::ClientTlsConfig};
49use tower::{Service, ServiceBuilder};
50use tower_http::auth::AddAuthorization;
51use tower_http::auth::AddAuthorizationLayer;
52use tower_http::set_header::{SetRequestHeader, SetRequestHeaderLayer};
53
// gRPC status codes
/// gRPC status code `OK`: the call completed successfully.
const STATUS_CODE_OK: i32 = 0;
/// gRPC status code `UNKNOWN`: used here when writing to or reading from a WebRTC stream fails.
const STATUS_CODE_UNKNOWN: i32 = 2;
/// gRPC status code `RESOURCE_EXHAUSTED`: used here when a new WebRTC stream cannot be opened.
const STATUS_CODE_RESOURCE_EXHAUSTED: i32 = 8;

/// Service name queried over mDNS when looking for a local address of the target machine.
// `&'static` is implied for consts, so the explicit lifetime is redundant.
pub const VIAM_MDNS_SERVICE_NAME: &str = "_rpc._tcp.local";

/// The secret type carried in [`RPCCredentials`] (for example an API key type string).
type SecretType = String;
62
63#[derive(Clone)]
64/// A communication channel to a given uri. The channel is either a direct tonic channel,
65/// or a webRTC channel.
66pub enum ViamChannel {
67    Direct(Channel),
68    DirectPreAuthorized(AddAuthorization<SetRequestHeader<Channel, HeaderValue>>),
69    WebRTC(Arc<WebRTCClientChannel>),
70}
71
72#[derive(Debug, Clone)]
73pub struct RPCCredentials {
74    entity: Option<String>,
75    credentials: Credentials,
76}
77
78impl RPCCredentials {
79    pub fn new(entity: Option<String>, r#type: SecretType, payload: String) -> Self {
80        Self {
81            credentials: Credentials { r#type, payload },
82            entity,
83        }
84    }
85}
86
87impl ViamChannel {
88    async fn create_resp(
89        channel: &mut Arc<WebRTCClientChannel>,
90        stream: crate::gen::proto::rpc::webrtc::v1::Stream,
91        request: http::Request<BoxBody>,
92        response: http::response::Builder,
93    ) -> http::Response<Body> {
94        let (parts, body) = request.into_parts();
95        let mut status_code = STATUS_CODE_OK;
96        let stream_id = stream.id;
97        let metadata = Some(metadata_from_parts(&parts));
98        let headers = RequestHeaders {
99            method: parts
100                .uri
101                .path_and_query()
102                .map(PathAndQuery::to_string)
103                .unwrap_or_default(),
104            metadata,
105            timeout: None,
106        };
107
108        if let Err(e) = channel.write_headers(&stream, headers).await {
109            log::error!("error writing headers: {e}");
110            channel.close_stream_with_recv_error(stream_id, e);
111            status_code = STATUS_CODE_UNKNOWN;
112        }
113
114        let data = hyper::body::to_bytes(body).await.unwrap().to_vec();
115        if let Err(e) = channel.write_message(Some(stream), data).await {
116            log::error!("error sending message: {e}");
117            channel.close_stream_with_recv_error(stream_id, e);
118            status_code = STATUS_CODE_UNKNOWN;
119        };
120
121        let body = match channel.resp_body_from_stream(stream_id) {
122            Ok(body) => body,
123            Err(e) => {
124                log::error!("error receiving response from stream: {e}");
125                channel.close_stream_with_recv_error(stream_id, e);
126                status_code = STATUS_CODE_UNKNOWN;
127                Body::empty()
128            }
129        };
130
131        let response = if status_code != STATUS_CODE_OK {
132            response.header("grpc-status", &status_code.to_string())
133        } else {
134            response
135        };
136
137        response.body(body).unwrap()
138    }
139}
140
141impl Service<http::Request<BoxBody>> for ViamChannel {
142    type Response = http::Response<Body>;
143    type Error = tonic::transport::Error;
144    type Future = BoxFuture<Self::Response, Self::Error>;
145
146    fn poll_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Self::Error>> {
147        match self {
148            Self::Direct(channel) => channel.poll_ready(cx),
149            Self::DirectPreAuthorized(channel) => channel.poll_ready(cx),
150            Self::WebRTC(_channel) => Poll::Ready(Ok(())),
151        }
152    }
153
154    fn call(&mut self, request: http::Request<BoxBody>) -> Self::Future {
155        match self {
156            Self::Direct(channel) => Box::pin(channel.call(request)),
157            Self::DirectPreAuthorized(channel) => Box::pin(channel.call(request)),
158            Self::WebRTC(channel) => {
159                let mut channel = channel.clone();
160                let fut = async move {
161                    let response = http::response::Response::builder()
162                        // standardized gRPC headers.
163                        .header("content-type", "application/grpc")
164                        .version(Version::HTTP_2);
165
166                    match channel.new_stream() {
167                        Err(e) => {
168                            log::error!("{e}");
169                            let response = response
170                                .header("grpc-status", &STATUS_CODE_RESOURCE_EXHAUSTED.to_string())
171                                .body(Body::default())
172                                .unwrap();
173
174                            Ok(response)
175                        }
176                        Ok(stream) => {
177                            Ok(Self::create_resp(&mut channel, stream, request, response).await)
178                        }
179                    }
180                };
181                Box::pin(fut)
182            }
183        }
184    }
185}
186
187/// Options for modifying the connection parameters
188#[derive(Debug)]
189pub struct DialOptions {
190    credentials: Option<RPCCredentials>,
191    webrtc_options: Option<Options>,
192    uri: Option<Parts>,
193    disable_mdns: bool,
194    allow_downgrade: bool,
195    insecure: bool,
196}
197#[derive(Clone)]
198pub struct WantsCredentials(());
199#[derive(Clone)]
200pub struct WantsUri(());
201#[derive(Clone)]
202pub struct WithCredentials(());
203#[derive(Clone)]
204pub struct WithoutCredentials(());
205
206pub trait AuthMethod {}
207impl AuthMethod for WithCredentials {}
208impl AuthMethod for WithoutCredentials {}
209/// A DialBuilder allows us to set options before establishing a connection to a server
210#[allow(dead_code)]
211pub struct DialBuilder<T> {
212    state: T,
213    config: DialOptions,
214}
215
216impl<T> fmt::Debug for DialBuilder<T> {
217    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218        f.debug_struct("Dial")
219            .field("State", &format_args!("{}", &std::any::type_name::<T>()))
220            .field("Opt", &format_args!("{:?}", self.config))
221            .finish()
222    }
223}
224
225impl DialOptions {
226    /// Creates a new DialBuilder
227    pub fn builder() -> DialBuilder<WantsUri> {
228        DialBuilder {
229            state: WantsUri(()),
230            config: DialOptions {
231                credentials: None,
232                uri: None,
233                allow_downgrade: false,
234                disable_mdns: false,
235                insecure: false,
236                webrtc_options: None,
237            },
238        }
239    }
240}
241
242impl DialBuilder<WantsUri> {
243    /// Sets the uri to connect to
244    pub fn uri(self, uri: &str) -> DialBuilder<WantsCredentials> {
245        let uri_parts = uri_parts_with_defaults(uri);
246        DialBuilder {
247            state: WantsCredentials(()),
248            config: DialOptions {
249                credentials: None,
250                uri: Some(uri_parts),
251                allow_downgrade: false,
252                disable_mdns: false,
253                insecure: false,
254                webrtc_options: None,
255            },
256        }
257    }
258}
259impl DialBuilder<WantsCredentials> {
260    /// Tells connecting logic to not expect/require credentials
261    pub fn without_credentials(self) -> DialBuilder<WithoutCredentials> {
262        DialBuilder {
263            state: WithoutCredentials(()),
264            config: DialOptions {
265                credentials: None,
266                uri: self.config.uri,
267                allow_downgrade: false,
268                disable_mdns: false,
269                insecure: false,
270                webrtc_options: None,
271            },
272        }
273    }
274    /// Sets credentials to use when connecting
275    pub fn with_credentials(self, creds: RPCCredentials) -> DialBuilder<WithCredentials> {
276        DialBuilder {
277            state: WithCredentials(()),
278            config: DialOptions {
279                credentials: Some(creds),
280                uri: self.config.uri,
281                allow_downgrade: false,
282                disable_mdns: false,
283                insecure: false,
284                webrtc_options: None,
285            },
286        }
287    }
288}
289
290impl<T: AuthMethod> DialBuilder<T> {
291    /// Attempts to connect insecurely with scheme of HTTP as a default
292    pub fn insecure(mut self) -> Self {
293        self.config.insecure = true;
294        self
295    }
296    /// Allows for downgrading and attempting to connect via HTTP if HTTPS fails
297    pub fn allow_downgrade(mut self) -> Self {
298        self.config.allow_downgrade = true;
299        self
300    }
301    /// Disables connection via mDNS
302    pub fn disable_mdns(mut self) -> Self {
303        self.config.disable_mdns = true;
304        self
305    }
306
307    /// Overrides any default connection behavior, forcing direct connection. Note that
308    /// the connection itself will fail if it is between a client and server on separate
309    /// networks and not over webRTC
310    pub fn disable_webrtc(mut self) -> Self {
311        let webrtc_options = Options::default().disable_webrtc();
312        self.config.webrtc_options = Some(webrtc_options);
313        self
314    }
315
316    async fn get_addr_from_interface(
317        iface: (&str, Vec<&IpAddr>),
318        candidates: &Vec<String>,
319    ) -> Option<String> {
320        let addresses: Vec<Ipv4Addr> = iface
321            .1
322            .iter()
323            .filter_map(|ip| match ip {
324                IpAddr::V4(v4) => Some(*v4),
325                IpAddr::V6(_) => None,
326            })
327            .collect();
328
329        let mut resp: Option<Response> = None;
330        for ipv4 in addresses {
331            for candidate in candidates {
332                let discovery = discover::interface_with_loopback(
333                    VIAM_MDNS_SERVICE_NAME,
334                    Duration::from_millis(250),
335                    ipv4,
336                )
337                .ok()?;
338                let stream = discovery.listen();
339                pin_mut!(stream);
340                while let Some(Ok(response)) = stream.next().await {
341                    if let Some(hostname) = response.hostname() {
342                        // Machine uris come in local ("my-cool-robot.abcdefg.local.viam.cloud")
343                        // and non-local ("my-cool-robot.abcdefg.viam.cloud") forms. Sometimes
344                        // (namely with micro-rdk), our mdns query can only see one (the local) version.
345                        // However, users are typically passing the non-local version. By splitting at
346                        // "viam" and taking the only the first value, we can still search for
347                        // candidates based on the actual "my-cool-robot" name without being opinionated
348                        // on whether the candidate is locally named or not.
349                        let local_agnostic_candidate = candidate.as_str().split("viam").next()?;
350                        if hostname.contains(local_agnostic_candidate) {
351                            resp = Some(response);
352                            break;
353                        }
354                    }
355                    if resp.is_some() {
356                        break;
357                    }
358                }
359            }
360        }
361
362        let resp = resp?;
363        let mut has_grpc = false;
364        let mut has_webrtc = false;
365        for field in resp.txt_records() {
366            has_grpc = has_grpc || field.contains("grpc");
367            has_webrtc = has_webrtc || field.contains("webrtc");
368        }
369
370        let ip_addr = match resp.ip_addr() {
371            Some(std::net::IpAddr::V4(ip_v4)) => Some(ip_v4),
372            Some(std::net::IpAddr::V6(_)) | None => None,
373        };
374
375        if !(has_grpc || has_webrtc) || ip_addr.is_none() {
376            return None;
377        }
378        let mut local_addr = ip_addr?.to_string();
379        local_addr.push(':');
380        local_addr.push_str(&resp.port()?.to_string());
381        Some(local_addr)
382    }
383
384    fn duplicate_uri(&self) -> Option<Parts> {
385        match &self.config.uri {
386            None => None,
387            Some(uri) => duplicate_uri(uri),
388        }
389    }
390
391    async fn get_mdns_uri(&self) -> Option<Parts> {
392        log::debug!("{}", log_prefixes::MDNS_QUERY_ATTEMPT);
393        if self.config.disable_mdns {
394            return None;
395        }
396
397        let mut uri = self.duplicate_uri()?;
398        let candidate = uri.authority.clone()?.to_string();
399
400        let candidates: Vec<String> = vec![candidate.replace('.', "-"), candidate];
401
402        let ifaces = list_afinet_netifas().ok()?;
403
404        let ifaces: HashMap<&str, Vec<&IpAddr>> =
405            ifaces.iter().fold(HashMap::new(), |mut map, (k, v)| {
406                map.entry(k).or_default().push(v);
407                map
408            });
409
410        let mut iface_futures = FuturesUnordered::new();
411        for iface in ifaces {
412            iface_futures.push(Self::get_addr_from_interface(iface, &candidates));
413        }
414
415        let mut local_addr: Option<String> = None;
416        while let Some(maybe_addr) = iface_futures.next().await {
417            if maybe_addr.is_some() {
418                local_addr = maybe_addr;
419                break;
420            }
421        }
422        let local_addr = match local_addr {
423            None => {
424                log::debug!("Unable to connect via mDNS");
425                return None;
426            }
427            Some(addr) => {
428                log::debug!("{}: {addr}", log_prefixes::MDNS_ADDRESS_FOUND);
429                addr
430            }
431        };
432
433        let auth = local_addr.parse::<Authority>().ok()?;
434        uri.authority = Some(auth);
435        uri.scheme = Some(Scheme::HTTP);
436
437        Some(uri)
438    }
439
440    async fn create_channel(
441        allow_downgrade: bool,
442        domain: &str,
443        uri: Uri,
444        for_mdns: bool,
445    ) -> Result<Channel> {
446        let mut chan = Channel::builder(uri.clone());
447        if for_mdns {
448            let tls_config = ClientTlsConfig::new().domain_name(domain);
449            chan = chan.tls_config(tls_config)?;
450        }
451        let chan = match chan
452            .connect()
453            .await
454            .with_context(|| format!("Connecting to {:?}", uri.clone()))
455        {
456            Ok(c) => c,
457            Err(e) => {
458                if allow_downgrade {
459                    let mut uri_parts = uri.clone().into_parts();
460                    uri_parts.scheme = Some(Scheme::HTTP);
461                    let uri = Uri::from_parts(uri_parts)?;
462                    Channel::builder(uri).connect().await?
463                } else {
464                    return Err(anyhow::anyhow!(e));
465                }
466            }
467        };
468        Ok(chan)
469    }
470}
471
472impl DialBuilder<WithoutCredentials> {
473    fn clone(&self) -> Self {
474        DialBuilder {
475            state: WithoutCredentials(()),
476            config: DialOptions {
477                credentials: None,
478                webrtc_options: self.config.webrtc_options.clone(),
479                uri: self.duplicate_uri(),
480                disable_mdns: self.config.disable_mdns,
481                allow_downgrade: self.config.allow_downgrade,
482                insecure: self.config.insecure,
483            },
484        }
485    }
486
487    /// attempts to establish a connection without credentials to the DialBuilder's given uri
488    async fn connect_inner(
489        self,
490        mdns_uri: Option<Parts>,
491        mut original_uri_parts: Parts,
492    ) -> Result<ViamChannel> {
493        let webrtc_options = self.config.webrtc_options;
494        let disable_webrtc = match &webrtc_options {
495            Some(options) => options.disable_webrtc,
496            None => false,
497        };
498        if self.config.insecure {
499            original_uri_parts.scheme = Some(Scheme::HTTP);
500        }
501        let original_uri = Uri::from_parts(original_uri_parts)?;
502        let uri2 = original_uri.clone();
503        let uri = infer_remote_uri_from_authority(original_uri);
504        let domain = uri2.authority().to_owned().unwrap().as_str();
505
506        let mdns_uri = mdns_uri.and_then(|p| Uri::from_parts(p).ok());
507        let attempting_mdns = mdns_uri.is_some();
508        if attempting_mdns {
509            log::debug!("Attempting to connect via mDNS");
510        } else {
511            log::debug!("Attempting to connect");
512        }
513
514        let channel = match mdns_uri {
515            Some(uri) => Self::create_channel(self.config.allow_downgrade, domain, uri, true).await,
516            // not actually an error necessarily, but we want to ensure that a channel is still
517            // created with the default uri
518            None => Err(anyhow::anyhow!("")),
519        };
520
521        let channel = match channel {
522            Ok(c) => {
523                log::debug!("Connected via mDNS");
524                c
525            }
526            Err(e) => {
527                if attempting_mdns {
528                    log::debug!(
529                        "Unable to connect via mDNS; falling back to robot URI. Error: {e}"
530                    );
531                }
532                Self::create_channel(self.config.allow_downgrade, domain, uri.clone(), false)
533                    .await?
534            }
535        };
536        // TODO (RSDK-517) make maybe_connect_via_webrtc take a more generic type so we don't
537        // need to add these dummy layers.
538        let intercepted_channel = ServiceBuilder::new()
539            .layer(AddAuthorizationLayer::basic(
540                "fake username",
541                "fake password",
542            ))
543            .layer(SetRequestHeaderLayer::overriding(
544                HeaderName::from_static("rpc-host"),
545                HeaderValue::from_str(domain)?,
546            ))
547            .service(channel.clone());
548
549        if disable_webrtc {
550            log::debug!("{}", log_prefixes::DIALED_GRPC);
551            Ok(ViamChannel::Direct(channel.clone()))
552        } else {
553            match maybe_connect_via_webrtc(uri, intercepted_channel.clone(), webrtc_options).await {
554                Ok(webrtc_channel) => Ok(ViamChannel::WebRTC(webrtc_channel)),
555                Err(e) => {
556                    log::error!("error connecting via webrtc: {e}. Attempting to connect directly");
557                    log::debug!("{}", log_prefixes::DIALED_GRPC);
558                    Ok(ViamChannel::Direct(channel.clone()))
559                }
560            }
561        }
562    }
563
564    async fn connect_mdns(self, original_uri: Parts) -> Result<ViamChannel> {
565        let mdns_uri =
566            webrtc::action_with_timeout(self.get_mdns_uri(), Duration::from_millis(1500))
567                .await
568                .ok()
569                .flatten()
570                .ok_or(anyhow::anyhow!(
571                    "Unable to establish connection via mDNS; uri not found"
572                ))?;
573
574        self.connect_inner(Some(mdns_uri), original_uri).await
575    }
576
577    pub async fn connect(self) -> Result<ViamChannel> {
578        log::debug!("{}", log_prefixes::DIAL_ATTEMPT);
579        let original_uri = self.duplicate_uri().ok_or(anyhow::anyhow!(
580            "Attempting to connect but there was no uri"
581        ))?;
582        let original_uri2 = duplicate_uri(&original_uri).ok_or(anyhow::anyhow!(
583            "Attempting to connect but there was no uri"
584        ))?;
585        // We want to short circuit and return the first `Ok` result from our connection
586        // attempts, which `tokio::select!` does great. Buuuuut, we don't want to
587        // abandon the `Err` results, and we want to provide comprehensive logging for
588        // debugging purposes. Hence the loop and pinning. The pinning lets us reference
589        // the same future multiple times, while the loop lets us immediately return on the
590        // first `Ok` result while still seeing and logging any error results.
591        tokio::pin! {
592            let with_mdns = self.clone().connect_mdns(original_uri);
593            let without_mdns = self.connect_inner(None, original_uri2);
594        }
595        let mut with_mdns_err: Option<anyhow::Error> = None;
596        let mut without_mdns_err: Option<anyhow::Error> = None;
597        while with_mdns_err.is_none() || without_mdns_err.is_none() {
598            tokio::select! {
599                with_mdns = &mut with_mdns, if with_mdns_err.is_none() => {
600                    match with_mdns {
601                        Ok(chan) => return Ok(chan),
602                        Err(e) => {
603                            log::debug!("Error connecting with mdns: {e}");
604                            with_mdns_err = Some(e);
605                        }
606                    }
607                }
608                without_mdns = &mut without_mdns, if without_mdns_err.is_none() => {
609                    match without_mdns {
610                        Ok(chan) => return Ok(chan),
611                        Err(e) => {
612                            log::debug!("Error connecting without mdns: {e}");
613                            without_mdns_err = Some(e);
614                        }
615                    }
616                }
617            }
618        }
619        Err(anyhow::anyhow!(
620            "Unable to connect with or without mdns.
621                    with_mdns err: {with_mdns_err:?}
622                    without_mdns err: {without_mdns_err:?}"
623        ))
624    }
625}
626
627async fn get_auth_token(
628    channel: &mut Channel,
629    creds: Credentials,
630    entity: String,
631) -> Result<String> {
632    let mut auth_service = AuthServiceClient::new(channel);
633    let req = AuthenticateRequest {
634        entity,
635        credentials: Some(creds),
636    };
637
638    let rsp = auth_service.authenticate(req).await?;
639    Ok(rsp.into_inner().access_token)
640}
641
642impl DialBuilder<WithCredentials> {
643    fn clone(&self) -> Self {
644        DialBuilder {
645            state: WithCredentials(()),
646            config: DialOptions {
647                credentials: self.config.credentials.clone(),
648                webrtc_options: self.config.webrtc_options.clone(),
649                uri: self.duplicate_uri(),
650                disable_mdns: self.config.disable_mdns,
651                allow_downgrade: self.config.allow_downgrade,
652                insecure: self.config.insecure,
653            },
654        }
655    }
656
657    async fn connect_inner(
658        self,
659        mdns_uri: Option<Parts>,
660        mut original_uri_parts: Parts,
661    ) -> Result<ViamChannel> {
662        let is_insecure = self.config.insecure;
663
664        let webrtc_options = self.config.webrtc_options;
665        let disable_webrtc = match &webrtc_options {
666            Some(options) => options.disable_webrtc,
667            None => false,
668        };
669
670        if is_insecure {
671            original_uri_parts.scheme = Some(Scheme::HTTP);
672        }
673
674        let original_uri = Uri::from_parts(original_uri_parts)?;
675
676        let domain = original_uri.authority().unwrap().to_string();
677        let uri_for_auth = infer_remote_uri_from_authority(original_uri.clone());
678
679        let mdns_uri = mdns_uri.and_then(|p| Uri::from_parts(p).ok());
680        let attempting_mdns = mdns_uri.is_some();
681
682        let allow_downgrade = self.config.allow_downgrade;
683        if attempting_mdns {
684            log::debug!("Attempting to connect via mDNS");
685        } else {
686            log::debug!("Attempting to connect");
687        }
688        let channel = match mdns_uri {
689            Some(uri) => Self::create_channel(allow_downgrade, &domain, uri, true).await,
690            // not actually an error necessarily, but we want to ensure that a channel is still
691            // created with the default uri
692            None => Err(anyhow::anyhow!("")),
693        };
694        let real_channel = match channel {
695            Ok(c) => {
696                log::debug!("Connected via mDNS");
697                c
698            }
699            Err(e) => {
700                if attempting_mdns {
701                    log::debug!(
702                        "Unable to connect via mDNS; falling back to robot URI. Error: {e}"
703                    );
704                }
705                Self::create_channel(allow_downgrade, &domain, uri_for_auth, false).await?
706            }
707        };
708
709        log::debug!("{}", log_prefixes::ACQUIRING_AUTH_TOKEN);
710        let token = get_auth_token(
711            &mut real_channel.clone(),
712            self.config
713                .credentials
714                .as_ref()
715                .unwrap()
716                .credentials
717                .clone(),
718            self.config
719                .credentials
720                .unwrap()
721                .entity
722                .unwrap_or_else(|| domain.clone()),
723        )
724        .await?;
725        log::debug!("{}", log_prefixes::ACQUIRED_AUTH_TOKEN);
726
727        let channel = ServiceBuilder::new()
728            .layer(AddAuthorizationLayer::bearer(&token))
729            .layer(SetRequestHeaderLayer::overriding(
730                HeaderName::from_static("rpc-host"),
731                HeaderValue::from_str(domain.as_str())?,
732            ))
733            .service(real_channel);
734
735        if disable_webrtc {
736            log::debug!("Connected via gRPC");
737            Ok(ViamChannel::DirectPreAuthorized(channel))
738        } else {
739            match maybe_connect_via_webrtc(original_uri, channel.clone(), webrtc_options).await {
740                Ok(webrtc_channel) => Ok(ViamChannel::WebRTC(webrtc_channel)),
741                Err(e) => {
742                    log::error!(
743                    "Unable to establish webrtc connection due to error: [{e}]. Attempting direct connection."
744                );
745                    log::debug!("Connected via gRPC");
746                    Ok(ViamChannel::DirectPreAuthorized(channel))
747                }
748            }
749        }
750    }
751
752    async fn connect_mdns(self, original_uri: Parts) -> Result<ViamChannel> {
753        // NOTE(benjirewis): Use a duration of 1500ms for getting the mDNS URI. I've anecdotally
754        // seen times as great as 922ms to fetch a non-loopback mDNS URI. With an
755        // interface_with_loopback query interval of 250ms, 1500ms here should give us time for ~6
756        // queries.
757        let mdns_uri =
758            webrtc::action_with_timeout(self.get_mdns_uri(), Duration::from_millis(1500))
759                .await
760                .ok()
761                .flatten()
762                .ok_or(anyhow::anyhow!(
763                    "Unable to establish connection via mDNS; uri not found"
764                ))?;
765
766        self.connect_inner(Some(mdns_uri), original_uri).await
767    }
768
769    /// attempts to establish a connection with credentials to the DialBuilder's given uri
770    pub async fn connect(self) -> Result<ViamChannel> {
771        log::debug!("{}", log_prefixes::DIAL_ATTEMPT);
772        let original_uri = self.duplicate_uri().ok_or(anyhow::anyhow!(
773            "Attempting to connect but there was no uri"
774        ))?;
775        let original_uri2 = duplicate_uri(&original_uri).ok_or(anyhow::anyhow!(
776            "Attempting to connect but there was no uri"
777        ))?;
778
779        // We want to short circuit and return the first `Ok` result from our connection
780        // attempts, which `tokio::select!` does great. Buuuuut, we don't want to
781        // abandon the `Err` results, and we want to provide comprehensive logging for
782        // debugging purposes. Hence the loop and pinning. The pinning lets us reference
783        // the same future multiple times, while the loop lets us immediately return on the
784        // first `Ok` result while still seeing and logging any error results.
785        tokio::pin! {
786            let with_mdns = self.clone().connect_mdns(original_uri);
787            let without_mdns = self.connect_inner(None, original_uri2);
788        }
789        let mut with_mdns_err: Option<anyhow::Error> = None;
790        let mut without_mdns_err: Option<anyhow::Error> = None;
791        while with_mdns_err.is_none() || without_mdns_err.is_none() {
792            tokio::select! {
793                with_mdns = &mut with_mdns, if with_mdns_err.is_none() => {
794                    match with_mdns {
795                        Ok(chan) => return Ok(chan),
796                        Err(e) => {
797                            log::debug!("Error connecting with mdns: {e}");
798                            with_mdns_err = Some(e);
799                        }
800                    }
801                }
802                without_mdns = &mut without_mdns, if without_mdns_err.is_none() => {
803                    match without_mdns {
804                        Ok(chan) => return Ok(chan),
805                        Err(e) => {
806                            log::debug!("Error connecting without mdns: {e}");
807                            without_mdns_err = Some(e);
808                        }
809                    }
810                }
811            }
812        }
813        Err(anyhow::anyhow!(
814            "Unable to connect with or without mdns.
815                    with_mdns err: {with_mdns_err:?}
816                    without_mdns err: {without_mdns_err:?}"
817        ))
818    }
819}
820
821async fn send_done_or_error_update(
822    update: CallUpdateRequest,
823    channel: AddAuthorization<SetRequestHeader<Channel, HeaderValue>>,
824) {
825    let mut signaling_client = SignalingServiceClient::new(channel.clone());
826
827    if let Err(e) = signaling_client
828        .call_update(update)
829        .await
830        .map_err(anyhow::Error::from)
831        .map(|_| ())
832    {
833        log::error!("Error sending done or error update: {e}")
834    }
835}
836
837async fn send_error_once(
838    sent_error: Arc<AtomicBool>,
839    uuid: &String,
840    err: &anyhow::Error,
841    channel: AddAuthorization<SetRequestHeader<Channel, HeaderValue>>,
842) {
843    if sent_error.load(Ordering::Acquire) {
844        return;
845    }
846
847    let err = google::rpc::Status {
848        code: google::rpc::Code::Unknown.into(),
849        message: err.to_string(),
850        details: Vec::new(),
851    };
852    sent_error.store(true, Ordering::Release);
853    let update_request = CallUpdateRequest {
854        uuid: uuid.to_string(),
855        update: Some(Update::Error(err)),
856    };
857
858    send_done_or_error_update(update_request, channel).await
859}
860
861async fn send_done_once(
862    sent_done: Arc<AtomicBool>,
863    uuid: &String,
864    channel: AddAuthorization<SetRequestHeader<Channel, HeaderValue>>,
865) {
866    if sent_done.load(Ordering::Acquire) {
867        return;
868    }
869    sent_done.store(true, Ordering::Release);
870    let update_request = CallUpdateRequest {
871        uuid: uuid.to_string(),
872        update: Some(Update::Done(true)),
873    };
874
875    send_done_or_error_update(update_request, channel).await
876}
877
/// Timing statistics for the `CallUpdate` RPCs used to trickle local ICE candidates to the
/// signaling server during one WebRTC dial attempt.
#[derive(Default)]
struct CallerUpdateStats {
    /// Number of candidate update calls made, whether or not they succeeded.
    count: u128,
    /// Sum of the wall-clock durations of all update calls.
    total_duration: Duration,
    /// Longest single update call observed.
    max_duration: Duration,
}

impl fmt::Display for CallerUpdateStats {
    /// Formats the statistics as a single human-readable line, with durations in ms.
    ///
    /// When no updates have been recorded, the average is reported as 0ms instead of
    /// panicking on a division by zero. No trailing newline is emitted; the logger adds
    /// its own.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let average_duration = self
            .total_duration
            .as_millis()
            .checked_div(self.count)
            .unwrap_or(0);
        write!(
            f,
            "Caller update statistics: num_updates: {}, average_duration: {}ms, max_duration: {}ms",
            self.count,
            average_duration,
            self.max_duration.as_millis()
        )
    }
}
898
/// Attempts to establish a WebRTC connection to the peer behind `uri`, using the
/// signaling service reachable over `channel`.
///
/// Steps, in order:
/// 1. Resolve WebRTC options: use `webrtc_options`, or infer them from `uri` when `None`.
/// 2. Ask the signaling service for an optional extra WebRTC config. An `Unimplemented`
///    response is treated as "no extra config".
/// 3. Create the peer connection and data channel. When trickle ICE is enabled, also:
///    - create an offer,
///    - register ICE callbacks that send local candidates to the signaling server,
///    - set the offer as the local description.
/// 4. Send the encoded local session description in a `CallRequest`. Then spawn a
///    background task that consumes the signaling response stream: the `Init` answer,
///    followed by remote ICE candidates.
/// 5. Wait, bounded by `webrtc_action_with_timeout`, until the data channel opens or an
///    unrecoverable signaling error is reported. Then send a "done" update and return the
///    client channel.
///
/// # Errors
/// Returns an error if:
/// - fetching the optional config fails with anything other than `Unimplemented`;
/// - creating the peer connection or offer fails, or setting the local description fails;
/// - the local description cannot be encoded;
/// - the signaling `call` RPC fails;
/// - a signaling error is reported before the data channel opens;
/// - opening the data channel times out.
///
/// # Panics
/// Panics if the peer connection has no local description after setup (see the `unwrap`
/// below), or if one of the internal `Mutex`/`RwLock`s is poisoned.
async fn maybe_connect_via_webrtc(
    uri: Uri,
    channel: AddAuthorization<SetRequestHeader<Channel, HeaderValue>>,
    webrtc_options: Option<Options>,
) -> Result<Arc<WebRTCClientChannel>> {
    let webrtc_options = webrtc_options.unwrap_or_else(|| Options::infer_from_uri(uri.clone()));
    let mut signaling_client = SignalingServiceClient::new(channel.clone());
    // Fetch any additional WebRTC config from the signaling server. A server that does not
    // implement this RPC is treated as having no extra config.
    let response = match signaling_client
        .optional_web_rtc_config(OptionalWebRtcConfigRequest::default())
        .await
    {
        Ok(resp) => resp,
        Err(e) => {
            if e.code() == tonic::Code::Unimplemented {
                tonic::Response::new(OptionalWebRtcConfigResponse::default())
            } else {
                return Err(anyhow::anyhow!(e));
            }
        }
    };

    let optional_config = response.into_inner().config;
    let config = webrtc::extend_webrtc_config(webrtc_options.config, optional_config);

    let (peer_connection, data_channel) =
        webrtc::new_peer_connection_for_client(config, webrtc_options.disable_trickle_ice).await?;

    // Shared by every path that sends a terminal update. `send_done_once` and
    // `send_error_once` skip sending once this flag is set.
    let sent_done_or_error = Arc::new(AtomicBool::new(false));
    // The call UUID is the empty string until the signaling server's `Init` response
    // arrives. The background task below writes it.
    let uuid_lock = Arc::new(RwLock::new("".to_string()));
    let uuid_for_ice_gathering_thread = uuid_lock.clone();

    // Using an mpsc channel to report unrecoverable errors during Signaling, so we
    // don't have to wait until the timeout expires before giving up on this attempt.
    // The size of the channel is set to 1 since any error (or success) should terminate the function
    let (is_open_s, mut is_open_r) = mpsc::channel(1);
    let on_open_is_open = is_open_s.clone();

    data_channel.on_open(Box::new(move || {
        let _ = on_open_is_open.try_send(None); // ignore sending errors, either an error (or success) was already sent or the operation will succeed
        Box::pin(async move {})
    }));

    // Set once the data channel has opened. From then on, newly gathered local candidates
    // are ignored by the `on_ice_candidate` callback.
    let exchange_done = Arc::new(AtomicBool::new(false));
    // Becomes `Some(())` once the remote answer has been applied, so candidate trickling
    // can wait for it.
    let (remote_description_set_s, remote_description_set_r) = watch::channel(None);
    // Notified when local ICE gathering finishes (`on_ice_candidate` receives `None`).
    let ice_done = Arc::new(tokio::sync::Notify::new());
    let ice_done2 = ice_done.clone();
    // Timing of candidate `CallUpdate` RPCs; logged when ICE reaches `Completed`.
    let caller_update_stats = Arc::new(Mutex::new(CallerUpdateStats::default()));

    // Trickle ICE: create the offer now and send each local candidate to the signaling
    // server as it is gathered.
    if !webrtc_options.disable_trickle_ice {
        let offer = peer_connection.create_offer(None).await?;
        let channel2 = channel.clone();
        let uuid_lock2 = uuid_lock.clone();
        let sent_done_or_error2 = sent_done_or_error.clone();

        let exchange_done = exchange_done.clone();

        let on_local_ice_candidate_failure = is_open_s.clone();

        let caller_update_stats = caller_update_stats.clone();
        let caller_update_stats2 = caller_update_stats.clone();
        // Log candidate-update timing once the ICE connection completes.
        peer_connection.on_ice_connection_state_change(Box::new(
            move |state: RTCIceConnectionState| {
                let caller_update_stats = caller_update_stats.clone();
                Box::pin(async move {
                    if state == RTCIceConnectionState::Completed {
                        let caller_update_stats_inner = caller_update_stats.lock().unwrap();
                        log::debug!("{}", caller_update_stats_inner);
                    }
                })
            },
        ));
        // Send each gathered local candidate to the signaling server, but only once the
        // remote description is set and the call UUID is known. A `None` candidate marks
        // the end of gathering.
        peer_connection.on_ice_candidate(Box::new(
            move |ice_candidate: Option<RTCIceCandidate>| {
                if exchange_done.load(Ordering::Acquire) {
                    return Box::pin(async move {});
                }
                let channel = channel2.clone();
                let sent_done_or_error = sent_done_or_error2.clone();
                let ice_done = ice_done.clone();
                let uuid_lock = uuid_lock2.clone();
                let on_local_ice_candidate_failure = on_local_ice_candidate_failure.clone();
                let mut remote_description_set_r = remote_description_set_r.clone();
                let caller_update_stats = caller_update_stats2.clone();
                Box::pin(async move {
                    // If the value in the watch channel has not been set yet, wait until it
                    // is. Afterwards Some(()) is visible to all watchers, and any watcher
                    // that is waiting will return.
                    if remote_description_set_r.borrow().is_none() {
                        match webrtc_action_with_timeout(remote_description_set_r.changed()).await {
                            Ok(Err(e)) => {
                                let _ = on_local_ice_candidate_failure.try_send(Some(Box::new(
                                    anyhow::anyhow!(
                                        "remote description watch channel is closed with error {e}"
                                    ),
                                )));
                            }
                            Err(_) => {
                                log::info!(
                                    "timed out on_ice_candidate; remote description was never set"
                                );
                                let _ = on_local_ice_candidate_failure.try_send(Some(Box::new(
                                    anyhow::anyhow!("timed out waiting for remote description"),
                                )));
                            }
                            _ => (),
                        }
                    }
                    // NOTE(review): when the wait above fails, the failure is reported but
                    // execution still falls through and tries to send this candidate.
                    // Confirm this is intended.

                    let uuid = uuid_lock.read().unwrap().to_string();
                    // Note(ethan): for reasons that aren't entirely clear to me, parallel dialing
                    // occasionally causes us to not receive a signaling client response when
                    // trying to establish a connection. This results in noisy error messages that
                    // fortunately are harmless (this problem seems to only ever affect one branch
                    // of the parallel dial, so we still end up with a successful connection).
                    // By checking if the `uuid` is empty, we can tell if we're in such a case and
                    // exit out before it results in logging noisy error messages.
                    //
                    // It would be lovely to understand this problem better, but given that it's
                    // not actually causing performance failures it's probably not worth the effort
                    // at this time.
                    if uuid.is_empty() {
                        log::debug!(
                            "UUID never updated. This is likely because we never received a response \
                            from the signaling client. This happens occasionally with parallel dialing \
                            and isn't concerning provided connection still occurs."
                        );
                        return;
                    }
                    let mut signaling_client = SignalingServiceClient::new(channel.clone());
                    match ice_candidate {
                        Some(ice_candidate) => {
                            log::debug!("Gathered local candidate of {ice_candidate}");
                            // Once a terminal update has been sent, stop sending candidates.
                            if sent_done_or_error.load(Ordering::Acquire) {
                                return;
                            }
                            let proto_candidate = ice_candidate_to_proto(ice_candidate).await;
                            match proto_candidate {
                                Ok(proto_candidate) => {
                                    let update_request = CallUpdateRequest {
                                        uuid: uuid.clone(),
                                        update: Some(Update::Candidate(proto_candidate)),
                                    };
                                    let call_update_start = Instant::now();
                                    if let Err(e) = webrtc_action_with_timeout(
                                        signaling_client.call_update(update_request),
                                    )
                                    .await
                                    .and_then(|resp| resp.map_err(anyhow::Error::from))
                                    {
                                        log::error!("Error sending ice candidate: {e}");
                                        let _ = on_local_ice_candidate_failure.try_send(Some(
                                            Box::new(anyhow::anyhow!(
                                                "Error sending ice candidate: {e}"
                                            )),
                                        ));
                                    }
                                    // Stats are recorded for failed calls as well as
                                    // successful ones.
                                    let mut caller_update_stats_inner =
                                        caller_update_stats.lock().unwrap();
                                    caller_update_stats_inner.count += 1;
                                    let call_update_duration = call_update_start.elapsed();
                                    if call_update_duration > caller_update_stats_inner.max_duration
                                    {
                                        caller_update_stats_inner.max_duration =
                                            call_update_duration;
                                    }
                                    caller_update_stats_inner.total_duration +=
                                        call_update_duration;
                                }
                                Err(e) => log::error!("Error parsing ice candidate: {e}"),
                            }
                        }
                        None => {
                            // will only be executed once when gathering is finished
                            ice_done.notify_one();
                            send_done_once(sent_done_or_error, &uuid, channel.clone()).await;
                        }
                    }
                })
            },
        ));

        peer_connection.set_local_description(offer).await?;
    }

    // NOTE(review): in the trickle case the offer was just set as the local description.
    // When trickle ICE is disabled, this presumably relies on
    // `new_peer_connection_for_client` having set it. Confirm, because `unwrap` panics
    // otherwise.
    let local_description = peer_connection.local_description().await.unwrap();

    // Local SD will be multi-line, so use two log messages to indicate start, SD and end.
    log::debug!(
        "{}\n{}",
        log_prefixes::START_LOCAL_SESSION_DESCRIPTION,
        local_description.sdp
    );
    log::debug!("{}", log_prefixes::END_LOCAL_SESSION_DESCRIPTION);

    let sdp = encode_sdp(local_description)?;
    let call_request = CallRequest {
        sdp,
        disable_trickle: webrtc_options.disable_trickle_ice,
    };

    let client_channel = WebRTCClientChannel::new(peer_connection, data_channel).await;
    // Only a weak reference goes to the background task. The task stops when it finds the
    // client channel has been dropped.
    let client_channel_for_ice_gathering_thread = Arc::downgrade(&client_channel);
    let mut signaling_client = SignalingServiceClient::new(channel.clone());
    let mut call_client = signaling_client.call(call_request).await?.into_inner();

    // Background task that reads the signaling response stream:
    // - the first `Init` carries the call UUID and the remote answer;
    // - subsequent `Update`s carry remote ICE candidates.
    // Unrecoverable errors are forwarded through `is_open_s`, so the wait below can fail
    // fast instead of timing out.
    let channel2 = channel.clone();
    let sent_done_or_error2 = sent_done_or_error.clone();
    tokio::spawn(async move {
        let uuid = uuid_for_ice_gathering_thread;
        let client_channel = client_channel_for_ice_gathering_thread;
        let init_received = AtomicBool::new(false);
        let sent_done = sent_done_or_error2;

        loop {
            let response = match webrtc_action_with_timeout(call_client.message())
                .await
                .and_then(|resp| resp.map_err(anyhow::Error::from))
            {
                Ok(cr) => match cr {
                    Some(cr) => cr,
                    None => {
                        // want to delay sending done until we either are actually done, or
                        // we hit a timeout
                        let _ = webrtc_action_with_timeout(ice_done2.notified()).await;
                        let uuid = uuid.read().unwrap().to_string();
                        send_done_once(sent_done.clone(), &uuid, channel2.clone()).await;
                        break;
                    }
                },
                Err(e) => {
                    log::error!("Error processing call response: {e}");
                    let _ = is_open_s.try_send(Some(Box::new(e)));
                    break;
                }
            };

            match response.stage {
                // Init: record the call UUID and apply the remote answer. Receiving it
                // twice is a protocol error.
                Some(Stage::Init(init)) => {
                    if init_received.load(Ordering::Acquire) {
                        let uuid = uuid.read().unwrap().to_string();
                        let e = anyhow::anyhow!("Init received more than once");
                        send_error_once(sent_done.clone(), &uuid, &e, channel2.clone()).await;
                        let _ = is_open_s.try_send(Some(Box::new(e)));
                        break;
                    }
                    init_received.store(true, Ordering::Release);
                    {
                        let mut uuid_s = uuid.write().unwrap();
                        uuid_s.clone_from(&response.uuid);
                    }

                    let answer = match decode_sdp(init.sdp) {
                        Ok(a) => a,
                        Err(e) => {
                            send_error_once(
                                sent_done.clone(),
                                &response.uuid,
                                &e,
                                channel2.clone(),
                            )
                            .await;
                            let _ = is_open_s.try_send(Some(Box::new(e)));
                            break;
                        }
                    };
                    {
                        let cc = match client_channel.upgrade() {
                            Some(cc) => cc,
                            None => {
                                break;
                            }
                        };
                        if let Err(e) = cc
                            .base_channel
                            .peer_connection
                            .set_remote_description(answer)
                            .await
                        {
                            let e = anyhow::Error::from(e);
                            send_error_once(
                                sent_done.clone(),
                                &response.uuid,
                                &e,
                                channel2.clone(),
                            )
                            .await;
                            let _ = is_open_s.try_send(Some(Box::new(e)));
                            break;
                        }
                    }
                    // Release any `on_ice_candidate` callbacks waiting for the answer.
                    let _ = remote_description_set_s.send_replace(Some(()));
                    // Without trickle ICE no candidate updates follow, so finish here.
                    if webrtc_options.disable_trickle_ice {
                        send_done_once(sent_done.clone(), &response.uuid, channel2.clone()).await;
                        break;
                    }
                }

                // Update: a remote ICE candidate. It must follow Init and carry the same
                // call UUID.
                Some(Stage::Update(update)) => {
                    let uuid_s = uuid.read().unwrap().to_string();
                    if !init_received.load(Ordering::Acquire) {
                        let e = anyhow::anyhow!("Got update before init stage");
                        send_error_once(sent_done.clone(), &uuid_s, &e, channel2.clone()).await;
                        let _ = is_open_s.try_send(Some(Box::new(e)));
                        break;
                    }

                    if response.uuid != *uuid.read().unwrap() {
                        let e = anyhow::anyhow!(
                            "uuid mismatch: have {}, want {}",
                            response.uuid,
                            uuid_s,
                        );
                        send_error_once(sent_done.clone(), &uuid_s, &e, channel2.clone()).await;
                        let _ = is_open_s.try_send(Some(Box::new(e)));
                        break;
                    }
                    match ice_candidate_from_proto(update.candidate) {
                        Ok(candidate) => {
                            let client_channel = match client_channel.upgrade() {
                                Some(cc) => cc,
                                None => {
                                    break;
                                }
                            };
                            log::debug!("Received remote ICE candidate of {candidate:#?}");
                            if let Err(e) = client_channel
                                .base_channel
                                .peer_connection
                                .add_ice_candidate(candidate)
                                .await
                            {
                                let e = anyhow::Error::from(e);
                                send_error_once(sent_done.clone(), &uuid_s, &e, channel2.clone())
                                    .await;
                                let _ = is_open_s.try_send(Some(Box::new(e)));
                                break;
                            }
                        }
                        Err(e) => log::error!("Error parsing ice candidate: {e}"),
                    }
                }
                None => continue,
            }
        }
    });

    // TODO (GOUT-11): create separate authorization if external_auth_addr and/or creds.Type is `Some`

    // Delay returning the client channel until data channel is open, so we don't lose messages
    // NOTE(review): `recv()` yields `None` if every sender has been dropped. Only
    // `Some(Some(err))` is treated as failure below, so that case counts as success.
    // Confirm this is intended.
    let is_open = webrtc_action_with_timeout(is_open_r.recv()).await;
    match is_open {
        Ok(is_open) => {
            if let Some(Some(e)) = is_open {
                return Err(anyhow::anyhow!("Couldn't connect to peer with error {e}"));
            }
        }
        Err(_) => {
            return Err(anyhow::anyhow!("Timed out opening data channel."));
        }
    }

    // The data channel is open. Stop sending new local candidates and tell the signaling
    // server this exchange is done.
    exchange_done.store(true, Ordering::Release);
    let uuid = uuid_lock.read().unwrap().to_string();
    send_done_once(sent_done_or_error, &uuid, channel.clone()).await;
    Ok(client_channel)
}
1265
1266async fn ice_candidate_to_proto(ice_candidate: RTCIceCandidate) -> Result<IceCandidate> {
1267    let ice_candidate = ice_candidate.to_json()?;
1268    Ok(IceCandidate {
1269        candidate: ice_candidate.candidate,
1270        sdp_mid: ice_candidate.sdp_mid,
1271        sdpm_line_index: ice_candidate.sdp_mline_index.map(u32::from),
1272        username_fragment: ice_candidate.username_fragment,
1273    })
1274}
1275
1276fn ice_candidate_from_proto(proto: Option<IceCandidate>) -> Result<RTCIceCandidateInit> {
1277    match proto {
1278        Some(proto) => {
1279            let proto_sdpm: usize = proto.sdpm_line_index().try_into()?;
1280            let sdp_mline_index: Option<u16> = proto_sdpm.try_into().ok();
1281
1282            Ok(RTCIceCandidateInit {
1283                candidate: proto.candidate.clone(),
1284                sdp_mid: Some(proto.sdp_mid().to_string()),
1285                sdp_mline_index,
1286                username_fragment: Some(proto.username_fragment().to_string()),
1287            })
1288        }
1289        None => Err(anyhow::anyhow!("No ice candidate provided")),
1290    }
1291}
1292
1293fn decode_sdp(sdp: String) -> Result<RTCSessionDescription> {
1294    let sdp = String::from_utf8(base64::decode(sdp)?)?;
1295    Ok(serde_json::from_str::<RTCSessionDescription>(&sdp)?)
1296}
1297
1298fn encode_sdp(sdp: RTCSessionDescription) -> Result<String> {
1299    let sdp = serde_json::to_vec(&sdp)?;
1300    Ok(base64::encode(sdp))
1301}
1302
1303fn infer_remote_uri_from_authority(uri: Uri) -> Uri {
1304    let authority = uri.authority().map(Authority::as_str).unwrap_or_default();
1305    let is_local_connection = authority.contains(".local.viam.cloud")
1306        || authority.contains("localhost")
1307        || authority.contains("0.0.0.0");
1308
1309    if !is_local_connection {
1310        if let Some((new_uri, _)) = Options::infer_signaling_server_address(&uri) {
1311            return Uri::from_parts(uri_parts_with_defaults(&new_uri)).unwrap_or(uri);
1312        }
1313    }
1314    uri
1315}
1316
1317fn duplicate_uri(parts: &Parts) -> Option<Parts> {
1318    let uri = Uri::builder()
1319        .authority(parts.authority.clone()?)
1320        .path_and_query(parts.path_and_query.clone()?)
1321        .scheme(parts.scheme.clone()?);
1322    Some(uri.build().ok()?.into_parts())
1323}
1324
1325fn uri_parts_with_defaults(uri: &str) -> Parts {
1326    let mut uri_parts = uri.parse::<Uri>().unwrap().into_parts();
1327    uri_parts.scheme = Some(Scheme::HTTPS);
1328    uri_parts.path_and_query = Some(PathAndQuery::from_static(""));
1329    uri_parts
1330}
1331
1332fn metadata_from_parts(parts: &http::request::Parts) -> Metadata {
1333    let mut md = HashMap::new();
1334    for (k, v) in parts.headers.iter() {
1335        let k = k.to_string();
1336        let v = Strings {
1337            values: vec![HeaderValue::to_str(v).unwrap().to_string()],
1338        };
1339        md.insert(k, v);
1340    }
1341    Metadata { md }
1342}