1use super::{
2 client_channel::*,
3 log_prefixes,
4 webrtc::{webrtc_action_with_timeout, Options},
5};
6use crate::gen::google;
7use crate::gen::proto::rpc::v1::{
8 auth_service_client::AuthServiceClient, AuthenticateRequest, Credentials,
9};
10use crate::gen::proto::rpc::webrtc::v1::{
11 call_response::Stage, call_update_request::Update,
12 signaling_service_client::SignalingServiceClient, CallUpdateRequest,
13 OptionalWebRtcConfigRequest, OptionalWebRtcConfigResponse,
14};
15use crate::gen::proto::rpc::webrtc::v1::{
16 CallRequest, IceCandidate, Metadata, RequestHeaders, Strings,
17};
18use crate::rpc::webrtc;
19use ::http::header::HeaderName;
20use ::http::{
21 uri::{Authority, Parts, PathAndQuery, Scheme},
22 HeaderValue, Version,
23};
24use ::viam_mdns::{discover, Response};
25use ::webrtc::ice_transport::{
26 ice_candidate::{RTCIceCandidate, RTCIceCandidateInit},
27 ice_connection_state::RTCIceConnectionState,
28};
29use ::webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
30use anyhow::{Context, Result};
31use core::fmt;
32use futures::stream::FuturesUnordered;
33use futures_util::{pin_mut, stream::StreamExt};
34use local_ip_address::list_afinet_netifas;
35use std::{
36 collections::HashMap,
37 net::{IpAddr, Ipv4Addr},
38 sync::{
39 atomic::{AtomicBool, Ordering},
40 Arc, Mutex, RwLock,
41 },
42 task::{Context as TaskContext, Poll},
43 time::{Duration, Instant},
44};
45use tokio::sync::{mpsc, watch};
46use tonic::codegen::BoxFuture;
47use tonic::transport::{Body, Channel, Uri};
48use tonic::{body::BoxBody, transport::ClientTlsConfig};
49use tower::{Service, ServiceBuilder};
50use tower_http::auth::AddAuthorization;
51use tower_http::auth::AddAuthorizationLayer;
52use tower_http::set_header::{SetRequestHeader, SetRequestHeaderLayer};
53
/// gRPC status code written to the `grpc-status` header when a call succeeds.
const STATUS_CODE_OK: i32 = 0;
/// gRPC status code written to the `grpc-status` header when a WebRTC stream operation fails.
const STATUS_CODE_UNKNOWN: i32 = 2;
/// gRPC status code written to the `grpc-status` header when no new WebRTC stream can be created.
const STATUS_CODE_RESOURCE_EXHAUSTED: i32 = 8;

/// Service name used for mDNS discovery queries.
pub const VIAM_MDNS_SERVICE_NAME: &str = "_rpc._tcp.local";

/// The `type` field of a set of credentials, carried as a plain string.
type SecretType = String;
62
63#[derive(Clone)]
64pub enum ViamChannel {
67 Direct(Channel),
68 DirectPreAuthorized(AddAuthorization<SetRequestHeader<Channel, HeaderValue>>),
69 WebRTC(Arc<WebRTCClientChannel>),
70}
71
72#[derive(Debug, Clone)]
73pub struct RPCCredentials {
74 entity: Option<String>,
75 credentials: Credentials,
76}
77
78impl RPCCredentials {
79 pub fn new(entity: Option<String>, r#type: SecretType, payload: String) -> Self {
80 Self {
81 credentials: Credentials { r#type, payload },
82 entity,
83 }
84 }
85}
86
87impl ViamChannel {
88 async fn create_resp(
89 channel: &mut Arc<WebRTCClientChannel>,
90 stream: crate::gen::proto::rpc::webrtc::v1::Stream,
91 request: http::Request<BoxBody>,
92 response: http::response::Builder,
93 ) -> http::Response<Body> {
94 let (parts, body) = request.into_parts();
95 let mut status_code = STATUS_CODE_OK;
96 let stream_id = stream.id;
97 let metadata = Some(metadata_from_parts(&parts));
98 let headers = RequestHeaders {
99 method: parts
100 .uri
101 .path_and_query()
102 .map(PathAndQuery::to_string)
103 .unwrap_or_default(),
104 metadata,
105 timeout: None,
106 };
107
108 if let Err(e) = channel.write_headers(&stream, headers).await {
109 log::error!("error writing headers: {e}");
110 channel.close_stream_with_recv_error(stream_id, e);
111 status_code = STATUS_CODE_UNKNOWN;
112 }
113
114 let data = hyper::body::to_bytes(body).await.unwrap().to_vec();
115 if let Err(e) = channel.write_message(Some(stream), data).await {
116 log::error!("error sending message: {e}");
117 channel.close_stream_with_recv_error(stream_id, e);
118 status_code = STATUS_CODE_UNKNOWN;
119 };
120
121 let body = match channel.resp_body_from_stream(stream_id) {
122 Ok(body) => body,
123 Err(e) => {
124 log::error!("error receiving response from stream: {e}");
125 channel.close_stream_with_recv_error(stream_id, e);
126 status_code = STATUS_CODE_UNKNOWN;
127 Body::empty()
128 }
129 };
130
131 let response = if status_code != STATUS_CODE_OK {
132 response.header("grpc-status", &status_code.to_string())
133 } else {
134 response
135 };
136
137 response.body(body).unwrap()
138 }
139}
140
141impl Service<http::Request<BoxBody>> for ViamChannel {
142 type Response = http::Response<Body>;
143 type Error = tonic::transport::Error;
144 type Future = BoxFuture<Self::Response, Self::Error>;
145
146 fn poll_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Self::Error>> {
147 match self {
148 Self::Direct(channel) => channel.poll_ready(cx),
149 Self::DirectPreAuthorized(channel) => channel.poll_ready(cx),
150 Self::WebRTC(_channel) => Poll::Ready(Ok(())),
151 }
152 }
153
154 fn call(&mut self, request: http::Request<BoxBody>) -> Self::Future {
155 match self {
156 Self::Direct(channel) => Box::pin(channel.call(request)),
157 Self::DirectPreAuthorized(channel) => Box::pin(channel.call(request)),
158 Self::WebRTC(channel) => {
159 let mut channel = channel.clone();
160 let fut = async move {
161 let response = http::response::Response::builder()
162 .header("content-type", "application/grpc")
164 .version(Version::HTTP_2);
165
166 match channel.new_stream() {
167 Err(e) => {
168 log::error!("{e}");
169 let response = response
170 .header("grpc-status", &STATUS_CODE_RESOURCE_EXHAUSTED.to_string())
171 .body(Body::default())
172 .unwrap();
173
174 Ok(response)
175 }
176 Ok(stream) => {
177 Ok(Self::create_resp(&mut channel, stream, request, response).await)
178 }
179 }
180 };
181 Box::pin(fut)
182 }
183 }
184 }
185}
186
187#[derive(Debug)]
189pub struct DialOptions {
190 credentials: Option<RPCCredentials>,
191 webrtc_options: Option<Options>,
192 uri: Option<Parts>,
193 disable_mdns: bool,
194 allow_downgrade: bool,
195 insecure: bool,
196}
/// Builder state: a URI has been supplied, the credential choice is still pending.
#[derive(Clone)]
pub struct WantsCredentials(());

/// Builder state: nothing has been supplied yet, a URI is required next.
#[derive(Clone)]
pub struct WantsUri(());

/// Builder state: credentials were supplied.
#[derive(Clone)]
pub struct WithCredentials(());

/// Builder state: the caller explicitly chose to dial without credentials.
#[derive(Clone)]
pub struct WithoutCredentials(());
205
206pub trait AuthMethod {}
207impl AuthMethod for WithCredentials {}
208impl AuthMethod for WithoutCredentials {}
209#[allow(dead_code)]
211pub struct DialBuilder<T> {
212 state: T,
213 config: DialOptions,
214}
215
216impl<T> fmt::Debug for DialBuilder<T> {
217 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218 f.debug_struct("Dial")
219 .field("State", &format_args!("{}", &std::any::type_name::<T>()))
220 .field("Opt", &format_args!("{:?}", self.config))
221 .finish()
222 }
223}
224
225impl DialOptions {
226 pub fn builder() -> DialBuilder<WantsUri> {
228 DialBuilder {
229 state: WantsUri(()),
230 config: DialOptions {
231 credentials: None,
232 uri: None,
233 allow_downgrade: false,
234 disable_mdns: false,
235 insecure: false,
236 webrtc_options: None,
237 },
238 }
239 }
240}
241
242impl DialBuilder<WantsUri> {
243 pub fn uri(self, uri: &str) -> DialBuilder<WantsCredentials> {
245 let uri_parts = uri_parts_with_defaults(uri);
246 DialBuilder {
247 state: WantsCredentials(()),
248 config: DialOptions {
249 credentials: None,
250 uri: Some(uri_parts),
251 allow_downgrade: false,
252 disable_mdns: false,
253 insecure: false,
254 webrtc_options: None,
255 },
256 }
257 }
258}
259impl DialBuilder<WantsCredentials> {
260 pub fn without_credentials(self) -> DialBuilder<WithoutCredentials> {
262 DialBuilder {
263 state: WithoutCredentials(()),
264 config: DialOptions {
265 credentials: None,
266 uri: self.config.uri,
267 allow_downgrade: false,
268 disable_mdns: false,
269 insecure: false,
270 webrtc_options: None,
271 },
272 }
273 }
274 pub fn with_credentials(self, creds: RPCCredentials) -> DialBuilder<WithCredentials> {
276 DialBuilder {
277 state: WithCredentials(()),
278 config: DialOptions {
279 credentials: Some(creds),
280 uri: self.config.uri,
281 allow_downgrade: false,
282 disable_mdns: false,
283 insecure: false,
284 webrtc_options: None,
285 },
286 }
287 }
288}
289
290impl<T: AuthMethod> DialBuilder<T> {
291 pub fn insecure(mut self) -> Self {
293 self.config.insecure = true;
294 self
295 }
296 pub fn allow_downgrade(mut self) -> Self {
298 self.config.allow_downgrade = true;
299 self
300 }
301 pub fn disable_mdns(mut self) -> Self {
303 self.config.disable_mdns = true;
304 self
305 }
306
307 pub fn disable_webrtc(mut self) -> Self {
311 let webrtc_options = Options::default().disable_webrtc();
312 self.config.webrtc_options = Some(webrtc_options);
313 self
314 }
315
316 async fn get_addr_from_interface(
317 iface: (&str, Vec<&IpAddr>),
318 candidates: &Vec<String>,
319 ) -> Option<String> {
320 let addresses: Vec<Ipv4Addr> = iface
321 .1
322 .iter()
323 .filter_map(|ip| match ip {
324 IpAddr::V4(v4) => Some(*v4),
325 IpAddr::V6(_) => None,
326 })
327 .collect();
328
329 let mut resp: Option<Response> = None;
330 for ipv4 in addresses {
331 for candidate in candidates {
332 let discovery = discover::interface_with_loopback(
333 VIAM_MDNS_SERVICE_NAME,
334 Duration::from_millis(250),
335 ipv4,
336 )
337 .ok()?;
338 let stream = discovery.listen();
339 pin_mut!(stream);
340 while let Some(Ok(response)) = stream.next().await {
341 if let Some(hostname) = response.hostname() {
342 let local_agnostic_candidate = candidate.as_str().split("viam").next()?;
350 if hostname.contains(local_agnostic_candidate) {
351 resp = Some(response);
352 break;
353 }
354 }
355 if resp.is_some() {
356 break;
357 }
358 }
359 }
360 }
361
362 let resp = resp?;
363 let mut has_grpc = false;
364 let mut has_webrtc = false;
365 for field in resp.txt_records() {
366 has_grpc = has_grpc || field.contains("grpc");
367 has_webrtc = has_webrtc || field.contains("webrtc");
368 }
369
370 let ip_addr = match resp.ip_addr() {
371 Some(std::net::IpAddr::V4(ip_v4)) => Some(ip_v4),
372 Some(std::net::IpAddr::V6(_)) | None => None,
373 };
374
375 if !(has_grpc || has_webrtc) || ip_addr.is_none() {
376 return None;
377 }
378 let mut local_addr = ip_addr?.to_string();
379 local_addr.push(':');
380 local_addr.push_str(&resp.port()?.to_string());
381 Some(local_addr)
382 }
383
384 fn duplicate_uri(&self) -> Option<Parts> {
385 match &self.config.uri {
386 None => None,
387 Some(uri) => duplicate_uri(uri),
388 }
389 }
390
391 async fn get_mdns_uri(&self) -> Option<Parts> {
392 log::debug!("{}", log_prefixes::MDNS_QUERY_ATTEMPT);
393 if self.config.disable_mdns {
394 return None;
395 }
396
397 let mut uri = self.duplicate_uri()?;
398 let candidate = uri.authority.clone()?.to_string();
399
400 let candidates: Vec<String> = vec![candidate.replace('.', "-"), candidate];
401
402 let ifaces = list_afinet_netifas().ok()?;
403
404 let ifaces: HashMap<&str, Vec<&IpAddr>> =
405 ifaces.iter().fold(HashMap::new(), |mut map, (k, v)| {
406 map.entry(k).or_default().push(v);
407 map
408 });
409
410 let mut iface_futures = FuturesUnordered::new();
411 for iface in ifaces {
412 iface_futures.push(Self::get_addr_from_interface(iface, &candidates));
413 }
414
415 let mut local_addr: Option<String> = None;
416 while let Some(maybe_addr) = iface_futures.next().await {
417 if maybe_addr.is_some() {
418 local_addr = maybe_addr;
419 break;
420 }
421 }
422 let local_addr = match local_addr {
423 None => {
424 log::debug!("Unable to connect via mDNS");
425 return None;
426 }
427 Some(addr) => {
428 log::debug!("{}: {addr}", log_prefixes::MDNS_ADDRESS_FOUND);
429 addr
430 }
431 };
432
433 let auth = local_addr.parse::<Authority>().ok()?;
434 uri.authority = Some(auth);
435 uri.scheme = Some(Scheme::HTTP);
436
437 Some(uri)
438 }
439
440 async fn create_channel(
441 allow_downgrade: bool,
442 domain: &str,
443 uri: Uri,
444 for_mdns: bool,
445 ) -> Result<Channel> {
446 let mut chan = Channel::builder(uri.clone());
447 if for_mdns {
448 let tls_config = ClientTlsConfig::new().domain_name(domain);
449 chan = chan.tls_config(tls_config)?;
450 }
451 let chan = match chan
452 .connect()
453 .await
454 .with_context(|| format!("Connecting to {:?}", uri.clone()))
455 {
456 Ok(c) => c,
457 Err(e) => {
458 if allow_downgrade {
459 let mut uri_parts = uri.clone().into_parts();
460 uri_parts.scheme = Some(Scheme::HTTP);
461 let uri = Uri::from_parts(uri_parts)?;
462 Channel::builder(uri).connect().await?
463 } else {
464 return Err(anyhow::anyhow!(e));
465 }
466 }
467 };
468 Ok(chan)
469 }
470}
471
472impl DialBuilder<WithoutCredentials> {
473 fn clone(&self) -> Self {
474 DialBuilder {
475 state: WithoutCredentials(()),
476 config: DialOptions {
477 credentials: None,
478 webrtc_options: self.config.webrtc_options.clone(),
479 uri: self.duplicate_uri(),
480 disable_mdns: self.config.disable_mdns,
481 allow_downgrade: self.config.allow_downgrade,
482 insecure: self.config.insecure,
483 },
484 }
485 }
486
487 async fn connect_inner(
489 self,
490 mdns_uri: Option<Parts>,
491 mut original_uri_parts: Parts,
492 ) -> Result<ViamChannel> {
493 let webrtc_options = self.config.webrtc_options;
494 let disable_webrtc = match &webrtc_options {
495 Some(options) => options.disable_webrtc,
496 None => false,
497 };
498 if self.config.insecure {
499 original_uri_parts.scheme = Some(Scheme::HTTP);
500 }
501 let original_uri = Uri::from_parts(original_uri_parts)?;
502 let uri2 = original_uri.clone();
503 let uri = infer_remote_uri_from_authority(original_uri);
504 let domain = uri2.authority().to_owned().unwrap().as_str();
505
506 let mdns_uri = mdns_uri.and_then(|p| Uri::from_parts(p).ok());
507 let attempting_mdns = mdns_uri.is_some();
508 if attempting_mdns {
509 log::debug!("Attempting to connect via mDNS");
510 } else {
511 log::debug!("Attempting to connect");
512 }
513
514 let channel = match mdns_uri {
515 Some(uri) => Self::create_channel(self.config.allow_downgrade, domain, uri, true).await,
516 None => Err(anyhow::anyhow!("")),
519 };
520
521 let channel = match channel {
522 Ok(c) => {
523 log::debug!("Connected via mDNS");
524 c
525 }
526 Err(e) => {
527 if attempting_mdns {
528 log::debug!(
529 "Unable to connect via mDNS; falling back to robot URI. Error: {e}"
530 );
531 }
532 Self::create_channel(self.config.allow_downgrade, domain, uri.clone(), false)
533 .await?
534 }
535 };
536 let intercepted_channel = ServiceBuilder::new()
539 .layer(AddAuthorizationLayer::basic(
540 "fake username",
541 "fake password",
542 ))
543 .layer(SetRequestHeaderLayer::overriding(
544 HeaderName::from_static("rpc-host"),
545 HeaderValue::from_str(domain)?,
546 ))
547 .service(channel.clone());
548
549 if disable_webrtc {
550 log::debug!("{}", log_prefixes::DIALED_GRPC);
551 Ok(ViamChannel::Direct(channel.clone()))
552 } else {
553 match maybe_connect_via_webrtc(uri, intercepted_channel.clone(), webrtc_options).await {
554 Ok(webrtc_channel) => Ok(ViamChannel::WebRTC(webrtc_channel)),
555 Err(e) => {
556 log::error!("error connecting via webrtc: {e}. Attempting to connect directly");
557 log::debug!("{}", log_prefixes::DIALED_GRPC);
558 Ok(ViamChannel::Direct(channel.clone()))
559 }
560 }
561 }
562 }
563
564 async fn connect_mdns(self, original_uri: Parts) -> Result<ViamChannel> {
565 let mdns_uri =
566 webrtc::action_with_timeout(self.get_mdns_uri(), Duration::from_millis(1500))
567 .await
568 .ok()
569 .flatten()
570 .ok_or(anyhow::anyhow!(
571 "Unable to establish connection via mDNS; uri not found"
572 ))?;
573
574 self.connect_inner(Some(mdns_uri), original_uri).await
575 }
576
577 pub async fn connect(self) -> Result<ViamChannel> {
578 log::debug!("{}", log_prefixes::DIAL_ATTEMPT);
579 let original_uri = self.duplicate_uri().ok_or(anyhow::anyhow!(
580 "Attempting to connect but there was no uri"
581 ))?;
582 let original_uri2 = duplicate_uri(&original_uri).ok_or(anyhow::anyhow!(
583 "Attempting to connect but there was no uri"
584 ))?;
585 tokio::pin! {
592 let with_mdns = self.clone().connect_mdns(original_uri);
593 let without_mdns = self.connect_inner(None, original_uri2);
594 }
595 let mut with_mdns_err: Option<anyhow::Error> = None;
596 let mut without_mdns_err: Option<anyhow::Error> = None;
597 while with_mdns_err.is_none() || without_mdns_err.is_none() {
598 tokio::select! {
599 with_mdns = &mut with_mdns, if with_mdns_err.is_none() => {
600 match with_mdns {
601 Ok(chan) => return Ok(chan),
602 Err(e) => {
603 log::debug!("Error connecting with mdns: {e}");
604 with_mdns_err = Some(e);
605 }
606 }
607 }
608 without_mdns = &mut without_mdns, if without_mdns_err.is_none() => {
609 match without_mdns {
610 Ok(chan) => return Ok(chan),
611 Err(e) => {
612 log::debug!("Error connecting without mdns: {e}");
613 without_mdns_err = Some(e);
614 }
615 }
616 }
617 }
618 }
619 Err(anyhow::anyhow!(
620 "Unable to connect with or without mdns.
621 with_mdns err: {with_mdns_err:?}
622 without_mdns err: {without_mdns_err:?}"
623 ))
624 }
625}
626
627async fn get_auth_token(
628 channel: &mut Channel,
629 creds: Credentials,
630 entity: String,
631) -> Result<String> {
632 let mut auth_service = AuthServiceClient::new(channel);
633 let req = AuthenticateRequest {
634 entity,
635 credentials: Some(creds),
636 };
637
638 let rsp = auth_service.authenticate(req).await?;
639 Ok(rsp.into_inner().access_token)
640}
641
642impl DialBuilder<WithCredentials> {
643 fn clone(&self) -> Self {
644 DialBuilder {
645 state: WithCredentials(()),
646 config: DialOptions {
647 credentials: self.config.credentials.clone(),
648 webrtc_options: self.config.webrtc_options.clone(),
649 uri: self.duplicate_uri(),
650 disable_mdns: self.config.disable_mdns,
651 allow_downgrade: self.config.allow_downgrade,
652 insecure: self.config.insecure,
653 },
654 }
655 }
656
657 async fn connect_inner(
658 self,
659 mdns_uri: Option<Parts>,
660 mut original_uri_parts: Parts,
661 ) -> Result<ViamChannel> {
662 let is_insecure = self.config.insecure;
663
664 let webrtc_options = self.config.webrtc_options;
665 let disable_webrtc = match &webrtc_options {
666 Some(options) => options.disable_webrtc,
667 None => false,
668 };
669
670 if is_insecure {
671 original_uri_parts.scheme = Some(Scheme::HTTP);
672 }
673
674 let original_uri = Uri::from_parts(original_uri_parts)?;
675
676 let domain = original_uri.authority().unwrap().to_string();
677 let uri_for_auth = infer_remote_uri_from_authority(original_uri.clone());
678
679 let mdns_uri = mdns_uri.and_then(|p| Uri::from_parts(p).ok());
680 let attempting_mdns = mdns_uri.is_some();
681
682 let allow_downgrade = self.config.allow_downgrade;
683 if attempting_mdns {
684 log::debug!("Attempting to connect via mDNS");
685 } else {
686 log::debug!("Attempting to connect");
687 }
688 let channel = match mdns_uri {
689 Some(uri) => Self::create_channel(allow_downgrade, &domain, uri, true).await,
690 None => Err(anyhow::anyhow!("")),
693 };
694 let real_channel = match channel {
695 Ok(c) => {
696 log::debug!("Connected via mDNS");
697 c
698 }
699 Err(e) => {
700 if attempting_mdns {
701 log::debug!(
702 "Unable to connect via mDNS; falling back to robot URI. Error: {e}"
703 );
704 }
705 Self::create_channel(allow_downgrade, &domain, uri_for_auth, false).await?
706 }
707 };
708
709 log::debug!("{}", log_prefixes::ACQUIRING_AUTH_TOKEN);
710 let token = get_auth_token(
711 &mut real_channel.clone(),
712 self.config
713 .credentials
714 .as_ref()
715 .unwrap()
716 .credentials
717 .clone(),
718 self.config
719 .credentials
720 .unwrap()
721 .entity
722 .unwrap_or_else(|| domain.clone()),
723 )
724 .await?;
725 log::debug!("{}", log_prefixes::ACQUIRED_AUTH_TOKEN);
726
727 let channel = ServiceBuilder::new()
728 .layer(AddAuthorizationLayer::bearer(&token))
729 .layer(SetRequestHeaderLayer::overriding(
730 HeaderName::from_static("rpc-host"),
731 HeaderValue::from_str(domain.as_str())?,
732 ))
733 .service(real_channel);
734
735 if disable_webrtc {
736 log::debug!("Connected via gRPC");
737 Ok(ViamChannel::DirectPreAuthorized(channel))
738 } else {
739 match maybe_connect_via_webrtc(original_uri, channel.clone(), webrtc_options).await {
740 Ok(webrtc_channel) => Ok(ViamChannel::WebRTC(webrtc_channel)),
741 Err(e) => {
742 log::error!(
743 "Unable to establish webrtc connection due to error: [{e}]. Attempting direct connection."
744 );
745 log::debug!("Connected via gRPC");
746 Ok(ViamChannel::DirectPreAuthorized(channel))
747 }
748 }
749 }
750 }
751
752 async fn connect_mdns(self, original_uri: Parts) -> Result<ViamChannel> {
753 let mdns_uri =
758 webrtc::action_with_timeout(self.get_mdns_uri(), Duration::from_millis(1500))
759 .await
760 .ok()
761 .flatten()
762 .ok_or(anyhow::anyhow!(
763 "Unable to establish connection via mDNS; uri not found"
764 ))?;
765
766 self.connect_inner(Some(mdns_uri), original_uri).await
767 }
768
769 pub async fn connect(self) -> Result<ViamChannel> {
771 log::debug!("{}", log_prefixes::DIAL_ATTEMPT);
772 let original_uri = self.duplicate_uri().ok_or(anyhow::anyhow!(
773 "Attempting to connect but there was no uri"
774 ))?;
775 let original_uri2 = duplicate_uri(&original_uri).ok_or(anyhow::anyhow!(
776 "Attempting to connect but there was no uri"
777 ))?;
778
779 tokio::pin! {
786 let with_mdns = self.clone().connect_mdns(original_uri);
787 let without_mdns = self.connect_inner(None, original_uri2);
788 }
789 let mut with_mdns_err: Option<anyhow::Error> = None;
790 let mut without_mdns_err: Option<anyhow::Error> = None;
791 while with_mdns_err.is_none() || without_mdns_err.is_none() {
792 tokio::select! {
793 with_mdns = &mut with_mdns, if with_mdns_err.is_none() => {
794 match with_mdns {
795 Ok(chan) => return Ok(chan),
796 Err(e) => {
797 log::debug!("Error connecting with mdns: {e}");
798 with_mdns_err = Some(e);
799 }
800 }
801 }
802 without_mdns = &mut without_mdns, if without_mdns_err.is_none() => {
803 match without_mdns {
804 Ok(chan) => return Ok(chan),
805 Err(e) => {
806 log::debug!("Error connecting without mdns: {e}");
807 without_mdns_err = Some(e);
808 }
809 }
810 }
811 }
812 }
813 Err(anyhow::anyhow!(
814 "Unable to connect with or without mdns.
815 with_mdns err: {with_mdns_err:?}
816 without_mdns err: {without_mdns_err:?}"
817 ))
818 }
819}
820
821async fn send_done_or_error_update(
822 update: CallUpdateRequest,
823 channel: AddAuthorization<SetRequestHeader<Channel, HeaderValue>>,
824) {
825 let mut signaling_client = SignalingServiceClient::new(channel.clone());
826
827 if let Err(e) = signaling_client
828 .call_update(update)
829 .await
830 .map_err(anyhow::Error::from)
831 .map(|_| ())
832 {
833 log::error!("Error sending done or error update: {e}")
834 }
835}
836
837async fn send_error_once(
838 sent_error: Arc<AtomicBool>,
839 uuid: &String,
840 err: &anyhow::Error,
841 channel: AddAuthorization<SetRequestHeader<Channel, HeaderValue>>,
842) {
843 if sent_error.load(Ordering::Acquire) {
844 return;
845 }
846
847 let err = google::rpc::Status {
848 code: google::rpc::Code::Unknown.into(),
849 message: err.to_string(),
850 details: Vec::new(),
851 };
852 sent_error.store(true, Ordering::Release);
853 let update_request = CallUpdateRequest {
854 uuid: uuid.to_string(),
855 update: Some(Update::Error(err)),
856 };
857
858 send_done_or_error_update(update_request, channel).await
859}
860
861async fn send_done_once(
862 sent_done: Arc<AtomicBool>,
863 uuid: &String,
864 channel: AddAuthorization<SetRequestHeader<Channel, HeaderValue>>,
865) {
866 if sent_done.load(Ordering::Acquire) {
867 return;
868 }
869 sent_done.store(true, Ordering::Release);
870 let update_request = CallUpdateRequest {
871 uuid: uuid.to_string(),
872 update: Some(Update::Done(true)),
873 };
874
875 send_done_or_error_update(update_request, channel).await
876}
877
/// Running statistics about caller update requests.
#[derive(Default)]
struct CallerUpdateStats {
    /// Number of updates recorded.
    count: u128,
    /// Sum of the durations of all recorded updates.
    total_duration: Duration,
    /// Longest single update duration recorded.
    max_duration: Duration,
}

/// Formats the statistics as a single line followed by a newline.
impl fmt::Display for CallerUpdateStats {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // With no updates recorded a plain division would panic; report 0ms instead.
        let average_duration = self
            .total_duration
            .as_millis()
            .checked_div(self.count)
            .unwrap_or(0);
        writeln!(
            f,
            "Caller update statistics: num_updates: {}, average_duration: {}ms, max_duration: {}ms",
            self.count,
            average_duration,
            self.max_duration.as_millis()
        )
    }
}
898
899async fn maybe_connect_via_webrtc(
900 uri: Uri,
901 channel: AddAuthorization<SetRequestHeader<Channel, HeaderValue>>,
902 webrtc_options: Option<Options>,
903) -> Result<Arc<WebRTCClientChannel>> {
904 let webrtc_options = webrtc_options.unwrap_or_else(|| Options::infer_from_uri(uri.clone()));
905 let mut signaling_client = SignalingServiceClient::new(channel.clone());
906 let response = match signaling_client
907 .optional_web_rtc_config(OptionalWebRtcConfigRequest::default())
908 .await
909 {
910 Ok(resp) => resp,
911 Err(e) => {
912 if e.code() == tonic::Code::Unimplemented {
913 tonic::Response::new(OptionalWebRtcConfigResponse::default())
914 } else {
915 return Err(anyhow::anyhow!(e));
916 }
917 }
918 };
919
920 let optional_config = response.into_inner().config;
921 let config = webrtc::extend_webrtc_config(webrtc_options.config, optional_config);
922
923 let (peer_connection, data_channel) =
924 webrtc::new_peer_connection_for_client(config, webrtc_options.disable_trickle_ice).await?;
925
926 let sent_done_or_error = Arc::new(AtomicBool::new(false));
927 let uuid_lock = Arc::new(RwLock::new("".to_string()));
928 let uuid_for_ice_gathering_thread = uuid_lock.clone();
929
930 let (is_open_s, mut is_open_r) = mpsc::channel(1);
934 let on_open_is_open = is_open_s.clone();
935
936 data_channel.on_open(Box::new(move || {
937 let _ = on_open_is_open.try_send(None); Box::pin(async move {})
939 }));
940
941 let exchange_done = Arc::new(AtomicBool::new(false));
942 let (remote_description_set_s, remote_description_set_r) = watch::channel(None);
943 let ice_done = Arc::new(tokio::sync::Notify::new());
944 let ice_done2 = ice_done.clone();
945 let caller_update_stats = Arc::new(Mutex::new(CallerUpdateStats::default()));
946
947 if !webrtc_options.disable_trickle_ice {
948 let offer = peer_connection.create_offer(None).await?;
949 let channel2 = channel.clone();
950 let uuid_lock2 = uuid_lock.clone();
951 let sent_done_or_error2 = sent_done_or_error.clone();
952
953 let exchange_done = exchange_done.clone();
954
955 let on_local_ice_candidate_failure = is_open_s.clone();
956
957 let caller_update_stats = caller_update_stats.clone();
958 let caller_update_stats2 = caller_update_stats.clone();
959 peer_connection.on_ice_connection_state_change(Box::new(
960 move |state: RTCIceConnectionState| {
961 let caller_update_stats = caller_update_stats.clone();
962 Box::pin(async move {
963 if state == RTCIceConnectionState::Completed {
964 let caller_update_stats_inner = caller_update_stats.lock().unwrap();
965 log::debug!("{}", caller_update_stats_inner);
966 }
967 })
968 },
969 ));
970 peer_connection.on_ice_candidate(Box::new(
971 move |ice_candidate: Option<RTCIceCandidate>| {
972 if exchange_done.load(Ordering::Acquire) {
973 return Box::pin(async move {});
974 }
975 let channel = channel2.clone();
976 let sent_done_or_error = sent_done_or_error2.clone();
977 let ice_done = ice_done.clone();
978 let uuid_lock = uuid_lock2.clone();
979 let on_local_ice_candidate_failure = on_local_ice_candidate_failure.clone();
980 let mut remote_description_set_r = remote_description_set_r.clone();
981 let caller_update_stats = caller_update_stats2.clone();
982 Box::pin(async move {
983 if remote_description_set_r.borrow().is_none() {
987 match webrtc_action_with_timeout(remote_description_set_r.changed()).await {
988 Ok(Err(e)) => {
989 let _ = on_local_ice_candidate_failure.try_send(Some(Box::new(
990 anyhow::anyhow!(
991 "remote description watch channel is closed with error {e}"
992 ),
993 )));
994 }
995 Err(_) => {
996 log::info!(
997 "timed out on_ice_candidate; remote description was never set"
998 );
999 let _ = on_local_ice_candidate_failure.try_send(Some(Box::new(
1000 anyhow::anyhow!("timed out waiting for remote description"),
1001 )));
1002 }
1003 _ => (),
1004 }
1005 }
1006
1007 let uuid = uuid_lock.read().unwrap().to_string();
1008 if uuid.is_empty() {
1020 log::debug!(
1021 "UUID never updated. This is likely because we never received a response \
1022 from the signaling client. This happens occasionally with parallel dialing \
1023 and isn't concerning provided connection still occurs."
1024 );
1025 return;
1026 }
1027 let mut signaling_client = SignalingServiceClient::new(channel.clone());
1028 match ice_candidate {
1029 Some(ice_candidate) => {
1030 log::debug!("Gathered local candidate of {ice_candidate}");
1031 if sent_done_or_error.load(Ordering::Acquire) {
1032 return;
1033 }
1034 let proto_candidate = ice_candidate_to_proto(ice_candidate).await;
1035 match proto_candidate {
1036 Ok(proto_candidate) => {
1037 let update_request = CallUpdateRequest {
1038 uuid: uuid.clone(),
1039 update: Some(Update::Candidate(proto_candidate)),
1040 };
1041 let call_update_start = Instant::now();
1042 if let Err(e) = webrtc_action_with_timeout(
1043 signaling_client.call_update(update_request),
1044 )
1045 .await
1046 .and_then(|resp| resp.map_err(anyhow::Error::from))
1047 {
1048 log::error!("Error sending ice candidate: {e}");
1049 let _ = on_local_ice_candidate_failure.try_send(Some(
1050 Box::new(anyhow::anyhow!(
1051 "Error sending ice candidate: {e}"
1052 )),
1053 ));
1054 }
1055 let mut caller_update_stats_inner =
1056 caller_update_stats.lock().unwrap();
1057 caller_update_stats_inner.count += 1;
1058 let call_update_duration = call_update_start.elapsed();
1059 if call_update_duration > caller_update_stats_inner.max_duration
1060 {
1061 caller_update_stats_inner.max_duration =
1062 call_update_duration;
1063 }
1064 caller_update_stats_inner.total_duration +=
1065 call_update_duration;
1066 }
1067 Err(e) => log::error!("Error parsing ice candidate: {e}"),
1068 }
1069 }
1070 None => {
1071 ice_done.notify_one();
1073 send_done_once(sent_done_or_error, &uuid, channel.clone()).await;
1074 }
1075 }
1076 })
1077 },
1078 ));
1079
1080 peer_connection.set_local_description(offer).await?;
1081 }
1082
1083 let local_description = peer_connection.local_description().await.unwrap();
1084
1085 log::debug!(
1087 "{}\n{}",
1088 log_prefixes::START_LOCAL_SESSION_DESCRIPTION,
1089 local_description.sdp
1090 );
1091 log::debug!("{}", log_prefixes::END_LOCAL_SESSION_DESCRIPTION);
1092
1093 let sdp = encode_sdp(local_description)?;
1094 let call_request = CallRequest {
1095 sdp,
1096 disable_trickle: webrtc_options.disable_trickle_ice,
1097 };
1098
1099 let client_channel = WebRTCClientChannel::new(peer_connection, data_channel).await;
1100 let client_channel_for_ice_gathering_thread = Arc::downgrade(&client_channel);
1101 let mut signaling_client = SignalingServiceClient::new(channel.clone());
1102 let mut call_client = signaling_client.call(call_request).await?.into_inner();
1103
1104 let channel2 = channel.clone();
1105 let sent_done_or_error2 = sent_done_or_error.clone();
1106 tokio::spawn(async move {
1107 let uuid = uuid_for_ice_gathering_thread;
1108 let client_channel = client_channel_for_ice_gathering_thread;
1109 let init_received = AtomicBool::new(false);
1110 let sent_done = sent_done_or_error2;
1111
1112 loop {
1113 let response = match webrtc_action_with_timeout(call_client.message())
1114 .await
1115 .and_then(|resp| resp.map_err(anyhow::Error::from))
1116 {
1117 Ok(cr) => match cr {
1118 Some(cr) => cr,
1119 None => {
1120 let _ = webrtc_action_with_timeout(ice_done2.notified()).await;
1123 let uuid = uuid.read().unwrap().to_string();
1124 send_done_once(sent_done.clone(), &uuid, channel2.clone()).await;
1125 break;
1126 }
1127 },
1128 Err(e) => {
1129 log::error!("Error processing call response: {e}");
1130 let _ = is_open_s.try_send(Some(Box::new(e)));
1131 break;
1132 }
1133 };
1134
1135 match response.stage {
1136 Some(Stage::Init(init)) => {
1137 if init_received.load(Ordering::Acquire) {
1138 let uuid = uuid.read().unwrap().to_string();
1139 let e = anyhow::anyhow!("Init received more than once");
1140 send_error_once(sent_done.clone(), &uuid, &e, channel2.clone()).await;
1141 let _ = is_open_s.try_send(Some(Box::new(e)));
1142 break;
1143 }
1144 init_received.store(true, Ordering::Release);
1145 {
1146 let mut uuid_s = uuid.write().unwrap();
1147 uuid_s.clone_from(&response.uuid);
1148 }
1149
1150 let answer = match decode_sdp(init.sdp) {
1151 Ok(a) => a,
1152 Err(e) => {
1153 send_error_once(
1154 sent_done.clone(),
1155 &response.uuid,
1156 &e,
1157 channel2.clone(),
1158 )
1159 .await;
1160 let _ = is_open_s.try_send(Some(Box::new(e)));
1161 break;
1162 }
1163 };
1164 {
1165 let cc = match client_channel.upgrade() {
1166 Some(cc) => cc,
1167 None => {
1168 break;
1169 }
1170 };
1171 if let Err(e) = cc
1172 .base_channel
1173 .peer_connection
1174 .set_remote_description(answer)
1175 .await
1176 {
1177 let e = anyhow::Error::from(e);
1178 send_error_once(
1179 sent_done.clone(),
1180 &response.uuid,
1181 &e,
1182 channel2.clone(),
1183 )
1184 .await;
1185 let _ = is_open_s.try_send(Some(Box::new(e)));
1186 break;
1187 }
1188 }
1189 let _ = remote_description_set_s.send_replace(Some(()));
1190 if webrtc_options.disable_trickle_ice {
1191 send_done_once(sent_done.clone(), &response.uuid, channel2.clone()).await;
1192 break;
1193 }
1194 }
1195
1196 Some(Stage::Update(update)) => {
1197 let uuid_s = uuid.read().unwrap().to_string();
1198 if !init_received.load(Ordering::Acquire) {
1199 let e = anyhow::anyhow!("Got update before init stage");
1200 send_error_once(sent_done.clone(), &uuid_s, &e, channel2.clone()).await;
1201 let _ = is_open_s.try_send(Some(Box::new(e)));
1202 break;
1203 }
1204
1205 if response.uuid != *uuid.read().unwrap() {
1206 let e = anyhow::anyhow!(
1207 "uuid mismatch: have {}, want {}",
1208 response.uuid,
1209 uuid_s,
1210 );
1211 send_error_once(sent_done.clone(), &uuid_s, &e, channel2.clone()).await;
1212 let _ = is_open_s.try_send(Some(Box::new(e)));
1213 break;
1214 }
1215 match ice_candidate_from_proto(update.candidate) {
1216 Ok(candidate) => {
1217 let client_channel = match client_channel.upgrade() {
1218 Some(cc) => cc,
1219 None => {
1220 break;
1221 }
1222 };
1223 log::debug!("Received remote ICE candidate of {candidate:#?}");
1224 if let Err(e) = client_channel
1225 .base_channel
1226 .peer_connection
1227 .add_ice_candidate(candidate)
1228 .await
1229 {
1230 let e = anyhow::Error::from(e);
1231 send_error_once(sent_done.clone(), &uuid_s, &e, channel2.clone())
1232 .await;
1233 let _ = is_open_s.try_send(Some(Box::new(e)));
1234 break;
1235 }
1236 }
1237 Err(e) => log::error!("Error parsing ice candidate: {e}"),
1238 }
1239 }
1240 None => continue,
1241 }
1242 }
1243 });
1244
1245 let is_open = webrtc_action_with_timeout(is_open_r.recv()).await;
1249 match is_open {
1250 Ok(is_open) => {
1251 if let Some(Some(e)) = is_open {
1252 return Err(anyhow::anyhow!("Couldn't connect to peer with error {e}"));
1253 }
1254 }
1255 Err(_) => {
1256 return Err(anyhow::anyhow!("Timed out opening data channel."));
1257 }
1258 }
1259
1260 exchange_done.store(true, Ordering::Release);
1261 let uuid = uuid_lock.read().unwrap().to_string();
1262 send_done_once(sent_done_or_error, &uuid, channel.clone()).await;
1263 Ok(client_channel)
1264}
1265
1266async fn ice_candidate_to_proto(ice_candidate: RTCIceCandidate) -> Result<IceCandidate> {
1267 let ice_candidate = ice_candidate.to_json()?;
1268 Ok(IceCandidate {
1269 candidate: ice_candidate.candidate,
1270 sdp_mid: ice_candidate.sdp_mid,
1271 sdpm_line_index: ice_candidate.sdp_mline_index.map(u32::from),
1272 username_fragment: ice_candidate.username_fragment,
1273 })
1274}
1275
1276fn ice_candidate_from_proto(proto: Option<IceCandidate>) -> Result<RTCIceCandidateInit> {
1277 match proto {
1278 Some(proto) => {
1279 let proto_sdpm: usize = proto.sdpm_line_index().try_into()?;
1280 let sdp_mline_index: Option<u16> = proto_sdpm.try_into().ok();
1281
1282 Ok(RTCIceCandidateInit {
1283 candidate: proto.candidate.clone(),
1284 sdp_mid: Some(proto.sdp_mid().to_string()),
1285 sdp_mline_index,
1286 username_fragment: Some(proto.username_fragment().to_string()),
1287 })
1288 }
1289 None => Err(anyhow::anyhow!("No ice candidate provided")),
1290 }
1291}
1292
1293fn decode_sdp(sdp: String) -> Result<RTCSessionDescription> {
1294 let sdp = String::from_utf8(base64::decode(sdp)?)?;
1295 Ok(serde_json::from_str::<RTCSessionDescription>(&sdp)?)
1296}
1297
1298fn encode_sdp(sdp: RTCSessionDescription) -> Result<String> {
1299 let sdp = serde_json::to_vec(&sdp)?;
1300 Ok(base64::encode(sdp))
1301}
1302
1303fn infer_remote_uri_from_authority(uri: Uri) -> Uri {
1304 let authority = uri.authority().map(Authority::as_str).unwrap_or_default();
1305 let is_local_connection = authority.contains(".local.viam.cloud")
1306 || authority.contains("localhost")
1307 || authority.contains("0.0.0.0");
1308
1309 if !is_local_connection {
1310 if let Some((new_uri, _)) = Options::infer_signaling_server_address(&uri) {
1311 return Uri::from_parts(uri_parts_with_defaults(&new_uri)).unwrap_or(uri);
1312 }
1313 }
1314 uri
1315}
1316
1317fn duplicate_uri(parts: &Parts) -> Option<Parts> {
1318 let uri = Uri::builder()
1319 .authority(parts.authority.clone()?)
1320 .path_and_query(parts.path_and_query.clone()?)
1321 .scheme(parts.scheme.clone()?);
1322 Some(uri.build().ok()?.into_parts())
1323}
1324
1325fn uri_parts_with_defaults(uri: &str) -> Parts {
1326 let mut uri_parts = uri.parse::<Uri>().unwrap().into_parts();
1327 uri_parts.scheme = Some(Scheme::HTTPS);
1328 uri_parts.path_and_query = Some(PathAndQuery::from_static(""));
1329 uri_parts
1330}
1331
1332fn metadata_from_parts(parts: &http::request::Parts) -> Metadata {
1333 let mut md = HashMap::new();
1334 for (k, v) in parts.headers.iter() {
1335 let k = k.to_string();
1336 let v = Strings {
1337 values: vec![HeaderValue::to_str(v).unwrap().to_string()],
1338 };
1339 md.insert(k, v);
1340 }
1341 Metadata { md }
1342}