1#![warn(missing_docs)] #[macro_use]
8extern crate tracing;
9
10mod metrics;
11mod proxy;
12mod raw;
13mod retry;
14mod worker_registry;
15mod workflow_handle;
16
17pub use crate::{
18 proxy::HttpConnectProxyOptions,
19 retry::{CallType, RETRYABLE_ERROR_CODES, RetryClient},
20};
21pub use metrics::{LONG_REQUEST_LATENCY_HISTOGRAM_NAME, REQUEST_LATENCY_HISTOGRAM_NAME};
22pub use raw::{CloudService, HealthService, OperatorService, TestService, WorkflowService};
23pub use squads_temporal_sdk_core_protos::temporal::api::{
24 enums::v1::ArchivalState,
25 filter::v1::{StartTimeFilter, StatusFilter, WorkflowExecutionFilter, WorkflowTypeFilter},
26 workflowservice::v1::{
27 list_closed_workflow_executions_request::Filters as ListClosedFilters,
28 list_open_workflow_executions_request::Filters as ListOpenFilters,
29 },
30};
31pub use tonic;
32pub use worker_registry::{Slot, SlotManager, SlotProvider, WorkerKey};
33pub use workflow_handle::{
34 GetWorkflowResultOpts, WorkflowExecutionInfo, WorkflowExecutionResult, WorkflowHandle,
35};
36
37use crate::{
38 metrics::{GrpcMetricSvc, MetricsContext},
39 raw::{AttachMetricLabels, sealed::RawClientLike},
40 sealed::WfHandleClient,
41 workflow_handle::UntypedWorkflowHandle,
42};
43use backoff::{ExponentialBackoff, SystemClock, exponential};
44use http::{Uri, uri::InvalidUri};
45use parking_lot::RwLock;
46use std::{
47 collections::HashMap,
48 fmt::{Debug, Formatter},
49 ops::{Deref, DerefMut},
50 str::FromStr,
51 sync::{Arc, OnceLock},
52 time::{Duration, Instant},
53};
54use squads_temporal_sdk_core_api::telemetry::metrics::TemporalMeter;
55use squads_temporal_sdk_core_protos::{
56 TaskToken,
57 coresdk::IntoPayloadsExt,
58 grpc::health::v1::health_client::HealthClient,
59 temporal::api::{
60 cloud::cloudservice::v1::cloud_service_client::CloudServiceClient,
61 common,
62 common::v1::{Header, Payload, Payloads, RetryPolicy, WorkflowExecution, WorkflowType},
63 enums::v1::{TaskQueueKind, WorkflowIdConflictPolicy, WorkflowIdReusePolicy},
64 operatorservice::v1::operator_service_client::OperatorServiceClient,
65 query::v1::WorkflowQuery,
66 replication::v1::ClusterReplicationConfig,
67 taskqueue::v1::TaskQueue,
68 testservice::v1::test_service_client::TestServiceClient,
69 update,
70 workflowservice::v1::{workflow_service_client::WorkflowServiceClient, *},
71 },
72};
73use tonic::{
74 Code,
75 body::Body,
76 client::GrpcService,
77 codegen::InterceptedService,
78 metadata::{MetadataKey, MetadataMap, MetadataValue},
79 service::Interceptor,
80 transport::{Certificate, Channel, Endpoint, Identity},
81};
82use tower::ServiceBuilder;
83use url::Url;
84use uuid::Uuid;
85
/// gRPC metadata key under which `ServiceCallInterceptor` sends `ClientOptions::client_name`
/// (only when the request does not already carry this key).
static CLIENT_NAME_HEADER_KEY: &str = "client-name";
/// gRPC metadata key under which `ServiceCallInterceptor` sends `ClientOptions::client_version`
/// (only when the request does not already carry this key).
static CLIENT_VERSION_HEADER_KEY: &str = "client-version";
/// gRPC metadata key for the Temporal namespace.
/// NOTE(review): not referenced in this part of the file; confirm usage at its call sites.
static TEMPORAL_NAMESPACE_HEADER_KEY: &str = "temporal-namespace";

/// Public metadata key string `"message-too-large"`.
/// NOTE(review): not referenced in this part of the file; by its name it presumably marks a
/// response/error caused by an oversized gRPC message — confirm at its usage sites.
pub static MESSAGE_TOO_LARGE_KEY: &str = "message-too-large";

/// 70 second timeout. NOTE(review): not referenced in this part of the file; presumably applied
/// to long-poll calls — confirm at its usage sites.
const LONG_POLL_TIMEOUT: Duration = Duration::from_secs(70);
/// 30 second timeout passed by `ServiceCallInterceptor` to `set_default_timeout` on every request
/// (presumably only taking effect when the request has no timeout of its own — confirm).
const OTHER_CALL_TIMEOUT: Duration = Duration::from_secs(30);

/// Crate-local result alias whose error type defaults to `tonic::Status`.
type Result<T, E = tonic::Status> = std::result::Result<T, E>;
98
99#[derive(Clone, Debug, derive_builder::Builder)]
101#[non_exhaustive]
102pub struct ClientOptions {
103 #[builder(setter(into))]
105 pub target_url: Url,
106
107 #[builder(setter(into))]
110 pub client_name: String,
111
112 #[builder(setter(into))]
115 pub client_version: String,
116
117 #[builder(default)]
119 pub identity: String,
120
121 #[builder(setter(strip_option), default)]
125 pub tls_cfg: Option<TlsConfig>,
126
127 #[builder(default)]
129 pub retry_config: RetryConfig,
130
131 #[builder(default)]
137 pub override_origin: Option<Uri>,
138
139 #[builder(default = "Some(ClientKeepAliveConfig::default())")]
141 pub keep_alive: Option<ClientKeepAliveConfig>,
142
143 #[builder(default)]
145 pub headers: Option<HashMap<String, String>>,
146
147 #[builder(default)]
150 pub api_key: Option<String>,
151
152 #[builder(default)]
154 pub http_connect_proxy: Option<HttpConnectProxyOptions>,
155
156 #[builder(default)]
158 pub disable_error_code_metric_tags: bool,
159
160 #[builder(default)]
162 pub skip_get_system_info: bool,
163}
164
/// TLS settings applied to the endpoint by `ClientOptions` (native root certificates are always
/// enabled when this is set).
#[derive(Clone, Debug, Default)]
pub struct TlsConfig {
    /// PEM-encoded root CA certificate, added alongside the native roots.
    pub server_root_ca_cert: Option<Vec<u8>>,
    /// Domain name used for TLS verification. It also sets the channel origin to
    /// `https://<domain>`, unless `ClientOptions::override_origin` overrides it.
    pub domain: Option<String>,
    /// Client certificate and private key, used as the TLS client identity (mutual TLS).
    pub client_tls_config: Option<ClientTlsConfig>,
}
178
/// Client identity for mutual TLS.
///
/// `Debug` is implemented manually in this file and prints only `ClientTlsConfig(..)`, so key
/// material is never logged.
#[derive(Clone)]
pub struct ClientTlsConfig {
    /// PEM-encoded client certificate.
    pub client_cert: Vec<u8>,
    /// PEM-encoded private key matching `client_cert`.
    pub client_private_key: Vec<u8>,
}
187
/// HTTP/2 keep-alive settings applied to the connection when `ClientOptions::keep_alive` is set.
#[derive(Clone, Debug)]
pub struct ClientKeepAliveConfig {
    /// Interval between keep-alive pings (passed to `http2_keep_alive_interval`).
    pub interval: Duration,
    /// How long to wait for a ping acknowledgement (passed to `keep_alive_timeout`).
    pub timeout: Duration,
}

impl Default for ClientKeepAliveConfig {
    /// A 30 second ping interval with a 15 second acknowledgement timeout.
    fn default() -> Self {
        const PING_INTERVAL: Duration = Duration::from_secs(30);
        const PING_TIMEOUT: Duration = Duration::from_secs(15);
        Self {
            interval: PING_INTERVAL,
            timeout: PING_TIMEOUT,
        }
    }
}
205
/// Exponential-backoff retry settings for client calls.
#[derive(Clone, Debug, PartialEq)]
pub struct RetryConfig {
    /// Delay before the first retry.
    pub initial_interval: Duration,
    /// Random jitter factor applied to each delay.
    pub randomization_factor: f64,
    /// Factor by which the delay grows between attempts.
    pub multiplier: f64,
    /// Upper bound for any single delay.
    pub max_interval: Duration,
    /// Total time budget for retrying; `None` means no elapsed-time limit.
    pub max_elapsed_time: Option<Duration>,
    /// Maximum number of retries. NOTE(review): how `0` is interpreted is decided by the retry
    /// logic (`RetryClient`), which lives outside this file.
    pub max_retries: usize,
}

impl Default for RetryConfig {
    /// Starts at 100ms, grows by 1.7x with 20% jitter, caps each delay at 5s, and stops after
    /// 10s or 10 retries.
    fn default() -> Self {
        Self {
            initial_interval: Duration::from_millis(100),
            randomization_factor: 0.2,
            multiplier: 1.7,
            max_interval: Duration::from_secs(5),
            max_elapsed_time: Some(Duration::from_secs(10)),
            max_retries: 10,
        }
    }
}
237
238impl RetryConfig {
239 pub(crate) const fn task_poll_retry_policy() -> Self {
240 Self {
241 initial_interval: Duration::from_millis(200),
242 randomization_factor: 0.2,
243 multiplier: 2.0,
244 max_interval: Duration::from_secs(10),
245 max_elapsed_time: None,
246 max_retries: 0,
247 }
248 }
249
250 pub(crate) const fn throttle_retry_policy() -> Self {
251 Self {
252 initial_interval: Duration::from_secs(1),
253 randomization_factor: 0.2,
254 multiplier: 2.0,
255 max_interval: Duration::from_secs(10),
256 max_elapsed_time: None,
257 max_retries: 0,
258 }
259 }
260
261 pub const fn no_retries() -> Self {
263 Self {
264 initial_interval: Duration::from_secs(0),
265 randomization_factor: 0.0,
266 multiplier: 1.0,
267 max_interval: Duration::from_secs(0),
268 max_elapsed_time: None,
269 max_retries: 1,
270 }
271 }
272
273 pub(crate) fn into_exp_backoff<C>(self, clock: C) -> exponential::ExponentialBackoff<C> {
274 exponential::ExponentialBackoff {
275 current_interval: self.initial_interval,
276 initial_interval: self.initial_interval,
277 randomization_factor: self.randomization_factor,
278 multiplier: self.multiplier,
279 max_interval: self.max_interval,
280 max_elapsed_time: self.max_elapsed_time,
281 clock,
282 start_time: Instant::now(),
283 }
284 }
285}
286
287impl From<RetryConfig> for ExponentialBackoff {
288 fn from(c: RetryConfig) -> Self {
289 c.into_exp_backoff(SystemClock::default())
290 }
291}
292
/// Unit marker type.
/// NOTE(review): not used in this part of the file; by its name it presumably flags a request as
/// a worker task long poll — confirm at its usage sites.
#[derive(Copy, Clone, Debug)]
pub struct IsWorkerTaskLongPoll;

/// Wraps a predicate over a returned `tonic::Status`.
/// NOTE(review): usage is outside this part of the file; by its name, statuses matching
/// `predicate` are presumably not retried — confirm in the retry logic.
#[derive(Copy, Clone, Debug)]
pub struct NoRetryOnMatching {
    /// Function evaluated against a returned status.
    pub predicate: fn(&tonic::Status) -> bool,
}
306
307impl Debug for ClientTlsConfig {
308 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
310 write!(f, "ClientTlsConfig(..)")
311 }
312}
313
/// Errors that can occur while connecting a client through `ClientOptions`.
#[derive(thiserror::Error, Debug)]
pub enum ClientInitError {
    /// The target URL, or the `https://<domain>` origin built from `TlsConfig::domain`, is not a
    /// valid URI.
    #[error("Invalid URI: {0:?}")]
    InvalidUri(#[from] InvalidUri),
    /// Configuring TLS on the endpoint or establishing the connection failed.
    #[error("Server connection error: {0:?}")]
    TonicTransportError(#[from] tonic::transport::Error),
    /// The `get_system_info` call made right after connecting failed with any status other than
    /// `Unimplemented`.
    #[error("`get_system_info` call error after connection: {0:?}")]
    SystemInfoCallError(tonic::Status),
}
328
/// A client `C` together with the options it was connected with.
///
/// Derefs to `C`. The header/API-key state is an `Arc<RwLock<..>>` shared with the request
/// interceptor, so `set_headers`/`set_api_key` affect later requests made through this client
/// and its clones.
#[derive(Clone, Debug)]
pub struct ConfiguredClient<C> {
    client: C,
    options: Arc<ClientOptions>,
    // Shared with `ServiceCallInterceptor`, which applies it to every outgoing request.
    headers: Arc<RwLock<ClientHeaders>>,
    // Filled from `get_system_info` after connecting. Stays `None` when the call is skipped,
    // returns `Unimplemented`, or reports no capabilities.
    capabilities: Option<get_system_info_response::Capabilities>,
    // Worker slot manager, shared by clones of this client (see `worker_registry`).
    workers: Arc<SlotManager>,
}
340
341impl<C> ConfiguredClient<C> {
342 pub fn set_headers(&self, headers: HashMap<String, String>) {
344 self.headers.write().user_headers = headers;
345 }
346
347 pub fn set_api_key(&self, api_key: Option<String>) {
349 self.headers.write().api_key = api_key;
350 }
351
352 pub fn options(&self) -> &ClientOptions {
354 &self.options
355 }
356
357 pub fn capabilities(&self) -> Option<&get_system_info_response::Capabilities> {
360 self.capabilities.as_ref()
361 }
362
363 pub fn workers(&self) -> Arc<SlotManager> {
365 self.workers.clone()
366 }
367}
368
/// Per-client header state shared between a `ConfiguredClient` and its interceptor.
///
/// `Debug` is implemented by hand so that the API key never appears in debug output.
/// `ConfiguredClient`, and therefore `Client`, derive `Debug` over this type.
struct ClientHeaders {
    // Extra metadata entries added to each request unless the key is already present.
    user_headers: HashMap<String, String>,
    // Sent as `authorization: Bearer <key>` unless an `authorization` header already exists.
    api_key: Option<String>,
}

impl Debug for ClientHeaders {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientHeaders")
            .field("user_headers", &self.user_headers)
            // Show only whether a key is configured, never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
374
375impl ClientHeaders {
376 fn apply_to_metadata(&self, metadata: &mut MetadataMap) {
377 for (key, val) in self.user_headers.iter() {
378 if !metadata.contains_key(key) {
380 if let (Ok(key), Ok(val)) = (MetadataKey::from_str(key), val.parse()) {
382 metadata.insert(key, val);
383 }
384 }
385 }
386 if let Some(api_key) = &self.api_key {
387 if !metadata.contains_key("authorization")
389 && let Ok(val) = format!("Bearer {api_key}").parse()
390 {
391 metadata.insert("authorization", val);
392 }
393 }
394 }
395}
396
397impl<C> Deref for ConfiguredClient<C> {
399 type Target = C;
400
401 fn deref(&self) -> &Self::Target {
402 &self.client
403 }
404}
405
406impl<C> DerefMut for ConfiguredClient<C> {
407 fn deref_mut(&mut self) -> &mut Self::Target {
408 &mut self.client
409 }
410}
411
412impl ClientOptions {
413 pub async fn connect(
416 &self,
417 namespace: impl Into<String>,
418 metrics_meter: Option<TemporalMeter>,
419 ) -> Result<RetryClient<Client>, ClientInitError> {
420 let client = self.connect_no_namespace(metrics_meter).await?.into_inner();
421 let client = Client::new(client, namespace.into());
422 let retry_client = RetryClient::new(client, self.retry_config.clone());
423 Ok(retry_client)
424 }
425
426 pub async fn connect_no_namespace(
431 &self,
432 metrics_meter: Option<TemporalMeter>,
433 ) -> Result<RetryClient<ConfiguredClient<TemporalServiceClientWithMetrics>>, ClientInitError>
434 {
435 let channel = Channel::from_shared(self.target_url.to_string())?;
436 let channel = self.add_tls_to_channel(channel).await?;
437 let channel = if let Some(keep_alive) = self.keep_alive.as_ref() {
438 channel
439 .keep_alive_while_idle(true)
440 .http2_keep_alive_interval(keep_alive.interval)
441 .keep_alive_timeout(keep_alive.timeout)
442 } else {
443 channel
444 };
445 let channel = if let Some(origin) = self.override_origin.clone() {
446 channel.origin(origin)
447 } else {
448 channel
449 };
450 let channel = if let Some(proxy) = self.http_connect_proxy.as_ref() {
452 proxy.connect_endpoint(&channel).await?
453 } else {
454 channel.connect().await?
455 };
456 let service = ServiceBuilder::new()
457 .layer_fn(move |channel| GrpcMetricSvc {
458 inner: channel,
459 metrics: metrics_meter.clone().map(MetricsContext::new),
460 disable_errcode_label: self.disable_error_code_metric_tags,
461 })
462 .service(channel);
463 let headers = Arc::new(RwLock::new(ClientHeaders {
464 user_headers: self.headers.clone().unwrap_or_default(),
465 api_key: self.api_key.clone(),
466 }));
467 let interceptor = ServiceCallInterceptor {
468 opts: self.clone(),
469 headers: headers.clone(),
470 };
471 let svc = InterceptedService::new(service, interceptor);
472
473 let mut client = ConfiguredClient {
474 headers,
475 client: TemporalServiceClient::new(svc),
476 options: Arc::new(self.clone()),
477 capabilities: None,
478 workers: Arc::new(SlotManager::new()),
479 };
480 if !self.skip_get_system_info {
481 match client
482 .get_system_info(GetSystemInfoRequest::default())
483 .await
484 {
485 Ok(sysinfo) => {
486 client.capabilities = sysinfo.into_inner().capabilities;
487 }
488 Err(status) => match status.code() {
489 Code::Unimplemented => {}
490 _ => return Err(ClientInitError::SystemInfoCallError(status)),
491 },
492 };
493 }
494 Ok(RetryClient::new(client, self.retry_config.clone()))
495 }
496
497 async fn add_tls_to_channel(&self, mut channel: Endpoint) -> Result<Endpoint, ClientInitError> {
500 if let Some(tls_cfg) = &self.tls_cfg {
501 let mut tls = tonic::transport::ClientTlsConfig::new().with_native_roots();
502
503 if let Some(root_cert) = &tls_cfg.server_root_ca_cert {
504 let server_root_ca_cert = Certificate::from_pem(root_cert);
505 tls = tls.ca_certificate(server_root_ca_cert);
506 }
507
508 if let Some(domain) = &tls_cfg.domain {
509 tls = tls.domain_name(domain);
510
511 let uri: Uri = format!("https://{domain}").parse()?;
516 channel = channel.origin(uri);
517 }
518
519 if let Some(client_opts) = &tls_cfg.client_tls_config {
520 let client_identity =
521 Identity::from_pem(&client_opts.client_cert, &client_opts.client_private_key);
522 tls = tls.identity(client_identity);
523 }
524
525 return channel.tls_config(tls).map_err(Into::into);
526 }
527 Ok(channel)
528 }
529}
530
531#[derive(Clone)]
533pub struct ServiceCallInterceptor {
534 opts: ClientOptions,
535 headers: Arc<RwLock<ClientHeaders>>,
537}
538
539impl Interceptor for ServiceCallInterceptor {
540 fn call(
543 &mut self,
544 mut request: tonic::Request<()>,
545 ) -> Result<tonic::Request<()>, tonic::Status> {
546 let metadata = request.metadata_mut();
547 if !metadata.contains_key(CLIENT_NAME_HEADER_KEY) {
548 metadata.insert(
549 CLIENT_NAME_HEADER_KEY,
550 self.opts
551 .client_name
552 .parse()
553 .unwrap_or_else(|_| MetadataValue::from_static("")),
554 );
555 }
556 if !metadata.contains_key(CLIENT_VERSION_HEADER_KEY) {
557 metadata.insert(
558 CLIENT_VERSION_HEADER_KEY,
559 self.opts
560 .client_version
561 .parse()
562 .unwrap_or_else(|_| MetadataValue::from_static("")),
563 );
564 }
565 self.headers.read().apply_to_metadata(metadata);
566 request.set_default_timeout(OTHER_CALL_TIMEOUT);
567
568 Ok(request)
569 }
570}
571
/// Holds one underlying gRPC service `T` and builds a typed client for each Temporal service
/// (workflow, operator, cloud, test, health) lazily, on first access.
#[derive(Debug, Clone)]
pub struct TemporalServiceClient<T> {
    svc: T,
    // Each cell is filled on first access via the matching `*_svc` accessor.
    workflow_svc_client: OnceLock<WorkflowServiceClient<T>>,
    operator_svc_client: OnceLock<OperatorServiceClient<T>>,
    cloud_svc_client: OnceLock<CloudServiceClient<T>>,
    test_svc_client: OnceLock<TestServiceClient<T>>,
    health_svc_client: OnceLock<HealthClient<T>>,
}
582
/// Maximum size in bytes of an incoming gRPC message that the service clients will decode.
///
/// Read once from `TEMPORAL_MAX_INCOMING_GRPC_BYTES` and cached for the life of the process.
/// A missing or unparsable value falls back to 128 MiB.
fn get_decode_max_size() -> usize {
    const DEFAULT_MAX_DECODE_BYTES: usize = 128 * 1024 * 1024;
    static DECODE_MAX_SIZE: OnceLock<usize> = OnceLock::new();
    *DECODE_MAX_SIZE.get_or_init(|| match std::env::var("TEMPORAL_MAX_INCOMING_GRPC_BYTES") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_MAX_DECODE_BYTES),
        Err(_) => DEFAULT_MAX_DECODE_BYTES,
    })
}
594
595impl<T> TemporalServiceClient<T>
596where
597 T: Clone,
598 T: GrpcService<Body> + Send + Clone + 'static,
599 T::ResponseBody: tonic::codegen::Body<Data = tonic::codegen::Bytes> + Send + 'static,
600 T::Error: Into<tonic::codegen::StdError>,
601 <T::ResponseBody as tonic::codegen::Body>::Error: Into<tonic::codegen::StdError> + Send,
602{
603 fn new(svc: T) -> Self {
604 Self {
605 svc,
606 workflow_svc_client: OnceLock::new(),
607 operator_svc_client: OnceLock::new(),
608 cloud_svc_client: OnceLock::new(),
609 test_svc_client: OnceLock::new(),
610 health_svc_client: OnceLock::new(),
611 }
612 }
613 pub fn workflow_svc(&self) -> &WorkflowServiceClient<T> {
615 self.workflow_svc_client.get_or_init(|| {
616 WorkflowServiceClient::new(self.svc.clone())
617 .max_decoding_message_size(get_decode_max_size())
618 })
619 }
620 pub fn operator_svc(&self) -> &OperatorServiceClient<T> {
622 self.operator_svc_client.get_or_init(|| {
623 OperatorServiceClient::new(self.svc.clone())
624 .max_decoding_message_size(get_decode_max_size())
625 })
626 }
627 pub fn cloud_svc(&self) -> &CloudServiceClient<T> {
629 self.cloud_svc_client.get_or_init(|| {
630 CloudServiceClient::new(self.svc.clone())
631 .max_decoding_message_size(get_decode_max_size())
632 })
633 }
634 pub fn test_svc(&self) -> &TestServiceClient<T> {
636 self.test_svc_client.get_or_init(|| {
637 TestServiceClient::new(self.svc.clone())
638 .max_decoding_message_size(get_decode_max_size())
639 })
640 }
641 pub fn health_svc(&self) -> &HealthClient<T> {
643 self.health_svc_client.get_or_init(|| {
644 HealthClient::new(self.svc.clone()).max_decoding_message_size(get_decode_max_size())
645 })
646 }
647 pub fn workflow_svc_mut(&mut self) -> &mut WorkflowServiceClient<T> {
649 let _ = self.workflow_svc();
650 self.workflow_svc_client.get_mut().unwrap()
651 }
652 pub fn operator_svc_mut(&mut self) -> &mut OperatorServiceClient<T> {
654 let _ = self.operator_svc();
655 self.operator_svc_client.get_mut().unwrap()
656 }
657 pub fn cloud_svc_mut(&mut self) -> &mut CloudServiceClient<T> {
659 let _ = self.cloud_svc();
660 self.cloud_svc_client.get_mut().unwrap()
661 }
662 pub fn test_svc_mut(&mut self) -> &mut TestServiceClient<T> {
664 let _ = self.test_svc();
665 self.test_svc_client.get_mut().unwrap()
666 }
667 pub fn health_svc_mut(&mut self) -> &mut HealthClient<T> {
669 let _ = self.health_svc();
670 self.health_svc_client.get_mut().unwrap()
671 }
672}
673
/// Workflow service client whose calls pass through the metrics layer and the interceptor.
pub type WorkflowServiceClientWithMetrics = WorkflowServiceClient<InterceptedMetricsSvc>;
/// Operator service client whose calls pass through the metrics layer and the interceptor.
pub type OperatorServiceClientWithMetrics = OperatorServiceClient<InterceptedMetricsSvc>;
/// Test service client whose calls pass through the metrics layer and the interceptor.
pub type TestServiceClientWithMetrics = TestServiceClient<InterceptedMetricsSvc>;
/// [`TemporalServiceClient`] over the metrics-instrumented, intercepted channel.
pub type TemporalServiceClientWithMetrics = TemporalServiceClient<InterceptedMetricsSvc>;
/// The metrics service (`GrpcMetricSvc`) wrapped by [`ServiceCallInterceptor`].
type InterceptedMetricsSvc = InterceptedService<GrpcMetricSvc, ServiceCallInterceptor>;
683
/// A [`ConfiguredClient`] bound to a single namespace. It implements [`NamespacedClient`].
#[derive(Debug, Clone)]
pub struct Client {
    // The connected client and its options.
    inner: ConfiguredClient<TemporalServiceClientWithMetrics>,
    // Namespace returned by `NamespacedClient::namespace`.
    namespace: String,
}
692
693impl Client {
694 pub fn new(
696 client: ConfiguredClient<TemporalServiceClientWithMetrics>,
697 namespace: String,
698 ) -> Self {
699 Client {
700 inner: client,
701 namespace,
702 }
703 }
704
705 pub fn raw_retry_client(&self) -> RetryClient<WorkflowServiceClientWithMetrics> {
711 RetryClient::new(
712 self.raw_client().clone(),
713 self.inner.options.retry_config.clone(),
714 )
715 }
716
717 pub fn raw_client(&self) -> &WorkflowServiceClientWithMetrics {
722 self.inner.workflow_svc()
723 }
724
725 pub fn options(&self) -> &ClientOptions {
727 &self.inner.options
728 }
729
730 pub fn options_mut(&mut self) -> &mut ClientOptions {
732 Arc::make_mut(&mut self.inner.options)
733 }
734
735 pub fn inner(&self) -> &ConfiguredClient<TemporalServiceClientWithMetrics> {
737 &self.inner
738 }
739
740 pub fn into_inner(self) -> ConfiguredClient<TemporalServiceClientWithMetrics> {
742 self.inner
743 }
744}
745
746impl NamespacedClient for Client {
747 fn namespace(&self) -> &str {
748 &self.namespace
749 }
750
751 fn get_identity(&self) -> &str {
752 &self.inner.options.identity
753 }
754}
755
756#[derive(Clone)]
758pub enum Namespace {
759 Name(String),
761 Id(String),
763}
764
765impl Namespace {
766 pub fn into_describe_namespace_request(self) -> DescribeNamespaceRequest {
768 let (namespace, id) = match self {
769 Namespace::Name(n) => (n, "".to_owned()),
770 Namespace::Id(n) => ("".to_owned(), n),
771 };
772 DescribeNamespaceRequest { namespace, id }
773 }
774}
775
776pub const DEFAULT_WORKFLOW_EXECUTION_RETENTION_PERIOD: Duration =
778 Duration::from_secs(60 * 60 * 24 * 3);
779
/// Options for registering a namespace. They are converted field by field into a
/// `RegisterNamespaceRequest`.
#[derive(Clone, derive_builder::Builder)]
pub struct RegisterNamespaceOptions {
    /// Name of the namespace to register.
    #[builder(setter(into))]
    pub namespace: String,
    /// Namespace description.
    #[builder(setter(into))]
    pub description: String,
    /// Owner email. Defaults to empty.
    #[builder(setter(into), default)]
    pub owner_email: String,
    /// Workflow execution retention period. Defaults to
    /// [`DEFAULT_WORKFLOW_EXECUTION_RETENTION_PERIOD`]. The request leaves it unset if it
    /// cannot be converted to the protobuf duration type.
    #[builder(default = "DEFAULT_WORKFLOW_EXECUTION_RETENTION_PERIOD")]
    pub workflow_execution_retention_period: Duration,
    /// Replication clusters. The builder setter is the hand-written
    /// `RegisterNamespaceOptionsBuilder::cluster_names`.
    #[builder(setter(strip_option, custom), default)]
    pub clusters: Vec<ClusterReplicationConfig>,
    /// Active cluster name. Defaults to empty.
    #[builder(setter(into), default)]
    pub active_cluster_name: String,
    /// Sent as the request's `data` map.
    #[builder(default)]
    pub data: HashMap<String, String>,
    /// Security token. Defaults to empty.
    #[builder(setter(into), default)]
    pub security_token: String,
    /// Whether to register the namespace as global.
    #[builder(default)]
    pub is_global_namespace: bool,
    /// History archival state. Defaults to `ArchivalState::Unspecified`.
    #[builder(setter(into), default = "ArchivalState::Unspecified")]
    pub history_archival_state: ArchivalState,
    /// History archival URI. Defaults to empty.
    #[builder(setter(into), default)]
    pub history_archival_uri: String,
    /// Visibility archival state. Defaults to `ArchivalState::Unspecified`.
    #[builder(setter(into), default = "ArchivalState::Unspecified")]
    pub visibility_archival_state: ArchivalState,
    /// Visibility archival URI. Defaults to empty.
    #[builder(setter(into), default)]
    pub visibility_archival_uri: String,
}
823
824impl RegisterNamespaceOptions {
825 pub fn builder() -> RegisterNamespaceOptionsBuilder {
827 Default::default()
828 }
829}
830
831impl From<RegisterNamespaceOptions> for RegisterNamespaceRequest {
832 fn from(val: RegisterNamespaceOptions) -> Self {
833 RegisterNamespaceRequest {
834 namespace: val.namespace,
835 description: val.description,
836 owner_email: val.owner_email,
837 workflow_execution_retention_period: val
838 .workflow_execution_retention_period
839 .try_into()
840 .ok(),
841 clusters: val.clusters,
842 active_cluster_name: val.active_cluster_name,
843 data: val.data,
844 security_token: val.security_token,
845 is_global_namespace: val.is_global_namespace,
846 history_archival_state: val.history_archival_state as i32,
847 history_archival_uri: val.history_archival_uri,
848 visibility_archival_state: val.visibility_archival_state as i32,
849 visibility_archival_uri: val.visibility_archival_uri,
850 }
851 }
852}
853
854impl RegisterNamespaceOptionsBuilder {
855 pub fn cluster_names(&mut self, clusters: Vec<String>) {
858 self.clusters = Some(
859 clusters
860 .into_iter()
861 .map(|s| ClusterReplicationConfig { cluster_name: s })
862 .collect(),
863 );
864 }
865}
866
/// Workflow and signal details for
/// `WorkflowClientTrait::signal_with_start_workflow_execution`.
#[derive(Clone, derive_builder::Builder)]
pub struct SignalWithStartOptions {
    /// Workflow input payloads.
    #[builder(setter(strip_option), default)]
    pub input: Option<Payloads>,
    /// Task queue name.
    #[builder(setter(into))]
    pub task_queue: String,
    /// Workflow id.
    #[builder(setter(into))]
    pub workflow_id: String,
    /// Workflow type name.
    #[builder(setter(into))]
    pub workflow_type: String,
    /// Request id. When `None`, the blanket implementation generates a random UUID v4.
    #[builder(setter(strip_option), default)]
    pub request_id: Option<String>,
    /// Name of the signal to send.
    #[builder(setter(into))]
    pub signal_name: String,
    /// Signal payloads.
    #[builder(default)]
    pub signal_input: Option<Payloads>,
    /// Sent as the request's `header`.
    #[builder(setter(strip_option), default)]
    pub signal_header: Option<Header>,
}
895
896impl SignalWithStartOptions {
897 pub fn builder() -> SignalWithStartOptionsBuilder {
899 Default::default()
900 }
901}
902
903#[async_trait::async_trait]
906pub trait WorkflowClientTrait: NamespacedClient {
907 async fn start_workflow(
909 &self,
910 input: Vec<Payload>,
911 task_queue: String,
912 workflow_id: String,
913 workflow_type: String,
914 request_id: Option<String>,
915 options: WorkflowOptions,
916 ) -> Result<StartWorkflowExecutionResponse>;
917
918 async fn reset_sticky_task_queue(
921 &self,
922 workflow_id: String,
923 run_id: String,
924 ) -> Result<ResetStickyTaskQueueResponse>;
925
926 async fn complete_activity_task(
930 &self,
931 task_token: TaskToken,
932 result: Option<Payloads>,
933 ) -> Result<RespondActivityTaskCompletedResponse>;
934
935 async fn record_activity_heartbeat(
940 &self,
941 task_token: TaskToken,
942 details: Option<Payloads>,
943 ) -> Result<RecordActivityTaskHeartbeatResponse>;
944
945 async fn cancel_activity_task(
949 &self,
950 task_token: TaskToken,
951 details: Option<Payloads>,
952 ) -> Result<RespondActivityTaskCanceledResponse>;
953
954 async fn signal_workflow_execution(
956 &self,
957 workflow_id: String,
958 run_id: String,
959 signal_name: String,
960 payloads: Option<Payloads>,
961 request_id: Option<String>,
962 ) -> Result<SignalWorkflowExecutionResponse>;
963
964 #[allow(clippy::too_many_arguments)]
967 async fn signal_with_start_workflow_execution(
968 &self,
969 options: SignalWithStartOptions,
970 workflow_options: WorkflowOptions,
971 ) -> Result<SignalWithStartWorkflowExecutionResponse>;
972
973 async fn query_workflow_execution(
975 &self,
976 workflow_id: String,
977 run_id: String,
978 query: WorkflowQuery,
979 ) -> Result<QueryWorkflowResponse>;
980
981 async fn describe_workflow_execution(
983 &self,
984 workflow_id: String,
985 run_id: Option<String>,
986 ) -> Result<DescribeWorkflowExecutionResponse>;
987
988 async fn get_workflow_execution_history(
990 &self,
991 workflow_id: String,
992 run_id: Option<String>,
993 page_token: Vec<u8>,
994 ) -> Result<GetWorkflowExecutionHistoryResponse>;
995
996 async fn cancel_workflow_execution(
998 &self,
999 workflow_id: String,
1000 run_id: Option<String>,
1001 reason: String,
1002 request_id: Option<String>,
1003 ) -> Result<RequestCancelWorkflowExecutionResponse>;
1004
1005 async fn terminate_workflow_execution(
1007 &self,
1008 workflow_id: String,
1009 run_id: Option<String>,
1010 ) -> Result<TerminateWorkflowExecutionResponse>;
1011
1012 async fn register_namespace(
1014 &self,
1015 options: RegisterNamespaceOptions,
1016 ) -> Result<RegisterNamespaceResponse>;
1017
1018 async fn list_namespaces(&self) -> Result<ListNamespacesResponse>;
1020
1021 async fn describe_namespace(&self, namespace: Namespace) -> Result<DescribeNamespaceResponse>;
1023
1024 async fn list_open_workflow_executions(
1026 &self,
1027 max_page_size: i32,
1028 next_page_token: Vec<u8>,
1029 start_time_filter: Option<StartTimeFilter>,
1030 filters: Option<ListOpenFilters>,
1031 ) -> Result<ListOpenWorkflowExecutionsResponse>;
1032
1033 async fn list_closed_workflow_executions(
1035 &self,
1036 max_page_size: i32,
1037 next_page_token: Vec<u8>,
1038 start_time_filter: Option<StartTimeFilter>,
1039 filters: Option<ListClosedFilters>,
1040 ) -> Result<ListClosedWorkflowExecutionsResponse>;
1041
1042 async fn list_workflow_executions(
1044 &self,
1045 page_size: i32,
1046 next_page_token: Vec<u8>,
1047 query: String,
1048 ) -> Result<ListWorkflowExecutionsResponse>;
1049
1050 async fn list_archived_workflow_executions(
1052 &self,
1053 page_size: i32,
1054 next_page_token: Vec<u8>,
1055 query: String,
1056 ) -> Result<ListArchivedWorkflowExecutionsResponse>;
1057
1058 async fn get_search_attributes(&self) -> Result<GetSearchAttributesResponse>;
1060
1061 async fn update_workflow_execution(
1063 &self,
1064 workflow_id: String,
1065 run_id: String,
1066 name: String,
1067 wait_policy: update::v1::WaitPolicy,
1068 args: Option<Payloads>,
1069 ) -> Result<UpdateWorkflowExecutionResponse>;
1070}
1071
/// A client bound to a single namespace and client identity.
pub trait NamespacedClient {
    /// The namespace that namespace-scoped requests from this client use.
    fn namespace(&self) -> &str;
    /// The identity string sent to the server, e.g. in activity responses and signals.
    fn get_identity(&self) -> &str;
}
1079
/// Start settings for a workflow execution, copied into the start request by
/// `WorkflowClientTrait::start_workflow`.
#[derive(Debug, Clone, Default)]
pub struct WorkflowOptions {
    /// Sent as `workflow_id_reuse_policy`.
    pub id_reuse_policy: WorkflowIdReusePolicy,

    /// Sent as `workflow_id_conflict_policy`.
    pub id_conflict_policy: WorkflowIdConflictPolicy,

    /// Sent as `workflow_execution_timeout`. Silently left unset if it cannot be converted to
    /// the protobuf duration type.
    pub execution_timeout: Option<Duration>,

    /// Sent as `workflow_run_timeout`, with the same conversion rule as `execution_timeout`.
    pub run_timeout: Option<Duration>,

    /// Sent as `workflow_task_timeout`, with the same conversion rule as `execution_timeout`.
    pub task_timeout: Option<Duration>,

    /// Sent as `cron_schedule`; `None` is sent as an empty string.
    pub cron_schedule: Option<String>,

    /// Sent as `search_attributes`.
    pub search_attributes: Option<HashMap<String, Payload>>,

    /// Sent by `start_workflow` as `request_eager_execution`.
    pub enable_eager_workflow_start: bool,

    /// Sent by `start_workflow` as `retry_policy`.
    pub retry_policy: Option<RetryPolicy>,

    /// Sent by `start_workflow` as `links`.
    pub links: Vec<common::v1::Link>,

    /// Sent by `start_workflow` as `completion_callbacks`.
    pub completion_callbacks: Vec<common::v1::Callback>,

    /// Converted to the protobuf `Priority` and sent by `start_workflow` as `priority`.
    pub priority: Option<Priority>,
}
1123
/// Scheduling priority. It mirrors the protobuf `common::v1::Priority`; the conversions are
/// defined in this file.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Priority {
    /// Priority key; `0` is the `Default` value. The wire field is an `i32`, so this value is
    /// converted when crossing to or from the protobuf type.
    pub priority_key: u32,
}
1157
1158impl From<Priority> for common::v1::Priority {
1159 fn from(priority: Priority) -> Self {
1160 common::v1::Priority {
1161 priority_key: priority.priority_key as i32,
1162 }
1163 }
1164}
1165
1166impl From<common::v1::Priority> for Priority {
1167 fn from(priority: common::v1::Priority) -> Self {
1168 Self {
1169 priority_key: priority.priority_key as u32,
1170 }
1171 }
1172}
1173
1174#[async_trait::async_trait]
1175impl<T> WorkflowClientTrait for T
1176where
1177 T: RawClientLike + NamespacedClient + Clone + Send + Sync + 'static,
1178 <Self as RawClientLike>::SvcType: GrpcService<Body> + Send + Clone + 'static,
1179 <<Self as RawClientLike>::SvcType as GrpcService<Body>>::ResponseBody:
1180 tonic::codegen::Body<Data = tonic::codegen::Bytes> + Send + 'static,
1181 <<Self as RawClientLike>::SvcType as GrpcService<Body>>::Error:
1182 Into<tonic::codegen::StdError>,
1183 <<Self as RawClientLike>::SvcType as GrpcService<Body>>::Future: Send,
1184 <<<Self as RawClientLike>::SvcType as GrpcService<Body>>::ResponseBody
1185 as tonic::codegen::Body>::Error: Into<tonic::codegen::StdError> + Send,
1186{
1187 async fn start_workflow(
1188 &self,
1189 input: Vec<Payload>,
1190 task_queue: String,
1191 workflow_id: String,
1192 workflow_type: String,
1193 request_id: Option<String>,
1194 options: WorkflowOptions,
1195 ) -> Result<StartWorkflowExecutionResponse> {
1196 Ok(self
1197 .clone()
1198 .start_workflow_execution(StartWorkflowExecutionRequest {
1199 namespace: self.namespace().to_owned(),
1200 input: input.into_payloads(),
1201 workflow_id,
1202 workflow_type: Some(WorkflowType {
1203 name: workflow_type,
1204 }),
1205 task_queue: Some(TaskQueue {
1206 name: task_queue,
1207 kind: TaskQueueKind::Unspecified as i32,
1208 normal_name: "".to_string(),
1209 }),
1210 request_id: request_id.unwrap_or_else(|| Uuid::new_v4().to_string()),
1211 workflow_id_reuse_policy: options.id_reuse_policy as i32,
1212 workflow_id_conflict_policy: options.id_conflict_policy as i32,
1213 workflow_execution_timeout: options
1214 .execution_timeout
1215 .and_then(|d| d.try_into().ok()),
1216 workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()),
1217 workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()),
1218 search_attributes: options.search_attributes.map(|d| d.into()),
1219 cron_schedule: options.cron_schedule.unwrap_or_default(),
1220 request_eager_execution: options.enable_eager_workflow_start,
1221 retry_policy: options.retry_policy,
1222 links: options.links,
1223 completion_callbacks: options.completion_callbacks,
1224 priority: options.priority.map(Into::into),
1225 ..Default::default()
1226 })
1227 .await?
1228 .into_inner())
1229 }
1230
1231 async fn reset_sticky_task_queue(
1232 &self,
1233 workflow_id: String,
1234 run_id: String,
1235 ) -> Result<ResetStickyTaskQueueResponse> {
1236 let request = ResetStickyTaskQueueRequest {
1237 namespace: self.namespace().to_owned(),
1238 execution: Some(WorkflowExecution {
1239 workflow_id,
1240 run_id,
1241 }),
1242 };
1243 Ok(
1244 WorkflowService::reset_sticky_task_queue(&mut self.clone(), request)
1245 .await?
1246 .into_inner(),
1247 )
1248 }
1249
1250 async fn complete_activity_task(
1251 &self,
1252 task_token: TaskToken,
1253 result: Option<Payloads>,
1254 ) -> Result<RespondActivityTaskCompletedResponse> {
1255 Ok(self.clone().respond_activity_task_completed(
1256 RespondActivityTaskCompletedRequest {
1257 task_token: task_token.0,
1258 result,
1259 identity: self.get_identity().to_owned(),
1260 namespace: self.namespace().to_owned(),
1261 ..Default::default()
1262 },
1263 )
1264 .await?
1265 .into_inner())
1266 }
1267
1268 async fn record_activity_heartbeat(
1269 &self,
1270 task_token: TaskToken,
1271 details: Option<Payloads>,
1272 ) -> Result<RecordActivityTaskHeartbeatResponse> {
1273 Ok(self.clone().record_activity_task_heartbeat(
1274 RecordActivityTaskHeartbeatRequest {
1275 task_token: task_token.0,
1276 details,
1277 identity: self.get_identity().to_owned(),
1278 namespace: self.namespace().to_owned(),
1279 },
1280 )
1281 .await?
1282 .into_inner())
1283 }
1284
1285 async fn cancel_activity_task(
1286 &self,
1287 task_token: TaskToken,
1288 details: Option<Payloads>,
1289 ) -> Result<RespondActivityTaskCanceledResponse> {
1290 Ok(self.clone().respond_activity_task_canceled(
1291 RespondActivityTaskCanceledRequest {
1292 task_token: task_token.0,
1293 details,
1294 identity: self.get_identity().to_owned(),
1295 namespace: self.namespace().to_owned(),
1296 ..Default::default()
1297 },
1298 )
1299 .await?
1300 .into_inner())
1301 }
1302
1303 async fn signal_workflow_execution(
1304 &self,
1305 workflow_id: String,
1306 run_id: String,
1307 signal_name: String,
1308 payloads: Option<Payloads>,
1309 request_id: Option<String>,
1310 ) -> Result<SignalWorkflowExecutionResponse> {
1311 Ok(WorkflowService::signal_workflow_execution(&mut self.clone(),
1312 SignalWorkflowExecutionRequest {
1313 namespace: self.namespace().to_owned(),
1314 workflow_execution: Some(WorkflowExecution {
1315 workflow_id,
1316 run_id,
1317 }),
1318 signal_name,
1319 input: payloads,
1320 identity: self.get_identity().to_owned(),
1321 request_id: request_id.unwrap_or_else(|| Uuid::new_v4().to_string()),
1322 ..Default::default()
1323 },
1324 )
1325 .await?
1326 .into_inner())
1327 }
1328
1329 async fn signal_with_start_workflow_execution(
1330 &self,
1331 options: SignalWithStartOptions,
1332 workflow_options: WorkflowOptions,
1333 ) -> Result<SignalWithStartWorkflowExecutionResponse> {
1334 Ok(WorkflowService::signal_with_start_workflow_execution(&mut self.clone(),
1335 SignalWithStartWorkflowExecutionRequest {
1336 namespace: self.namespace().to_owned(),
1337 workflow_id: options.workflow_id,
1338 workflow_type: Some(WorkflowType {
1339 name: options.workflow_type,
1340 }),
1341 task_queue: Some(TaskQueue {
1342 name: options.task_queue,
1343 kind: TaskQueueKind::Normal as i32,
1344 normal_name: "".to_string(),
1345 }),
1346 input: options.input,
1347 signal_name: options.signal_name,
1348 signal_input: options.signal_input,
1349 identity: self.get_identity().to_owned(),
1350 request_id: options
1351 .request_id
1352 .unwrap_or_else(|| Uuid::new_v4().to_string()),
1353 workflow_id_reuse_policy: workflow_options.id_reuse_policy as i32,
1354 workflow_id_conflict_policy: workflow_options.id_conflict_policy as i32,
1355 workflow_execution_timeout: workflow_options
1356 .execution_timeout
1357 .and_then(|d| d.try_into().ok()),
1358 workflow_run_timeout: workflow_options.run_timeout.and_then(|d| d.try_into().ok()),
1359 workflow_task_timeout: workflow_options
1360 .task_timeout
1361 .and_then(|d| d.try_into().ok()),
1362 search_attributes: workflow_options.search_attributes.map(|d| d.into()),
1363 cron_schedule: workflow_options.cron_schedule.unwrap_or_default(),
1364 header: options.signal_header,
1365 ..Default::default()
1366 },
1367 )
1368 .await?
1369 .into_inner())
1370 }
1371
1372 async fn query_workflow_execution(
1373 &self,
1374 workflow_id: String,
1375 run_id: String,
1376 query: WorkflowQuery,
1377 ) -> Result<QueryWorkflowResponse> {
1378 Ok(self.clone().query_workflow(
1379 QueryWorkflowRequest {
1380 namespace: self.namespace().to_owned(),
1381 execution: Some(WorkflowExecution {
1382 workflow_id,
1383 run_id,
1384 }),
1385 query: Some(query),
1386 query_reject_condition: 1,
1387 },
1388 )
1389 .await?
1390 .into_inner())
1391 }
1392
1393 async fn describe_workflow_execution(
1394 &self,
1395 workflow_id: String,
1396 run_id: Option<String>,
1397 ) -> Result<DescribeWorkflowExecutionResponse> {
1398 Ok(WorkflowService::describe_workflow_execution(&mut self.clone(),
1399 DescribeWorkflowExecutionRequest {
1400 namespace: self.namespace().to_owned(),
1401 execution: Some(WorkflowExecution {
1402 workflow_id,
1403 run_id: run_id.unwrap_or_default(),
1404 }),
1405 },
1406 )
1407 .await?
1408 .into_inner())
1409 }
1410
1411 async fn get_workflow_execution_history(
1412 &self,
1413 workflow_id: String,
1414 run_id: Option<String>,
1415 page_token: Vec<u8>,
1416 ) -> Result<GetWorkflowExecutionHistoryResponse> {
1417 Ok(WorkflowService::get_workflow_execution_history(&mut self.clone(),
1418 GetWorkflowExecutionHistoryRequest {
1419 namespace: self.namespace().to_owned(),
1420 execution: Some(WorkflowExecution {
1421 workflow_id,
1422 run_id: run_id.unwrap_or_default(),
1423 }),
1424 next_page_token: page_token,
1425 ..Default::default()
1426 },
1427 )
1428 .await?
1429 .into_inner())
1430 }
1431
1432 async fn cancel_workflow_execution(
1433 &self,
1434 workflow_id: String,
1435 run_id: Option<String>,
1436 reason: String,
1437 request_id: Option<String>,
1438 ) -> Result<RequestCancelWorkflowExecutionResponse> {
1439 Ok(self.clone().request_cancel_workflow_execution(
1440 RequestCancelWorkflowExecutionRequest {
1441 namespace: self.namespace().to_owned(),
1442 workflow_execution: Some(WorkflowExecution {
1443 workflow_id,
1444 run_id: run_id.unwrap_or_default(),
1445 }),
1446 identity: self.get_identity().to_owned(),
1447 request_id: request_id.unwrap_or_else(|| Uuid::new_v4().to_string()),
1448 first_execution_run_id: "".to_string(),
1449 reason,
1450 links: vec![],
1451 },
1452 )
1453 .await?
1454 .into_inner())
1455 }
1456
1457 async fn terminate_workflow_execution(
1458 &self,
1459 workflow_id: String,
1460 run_id: Option<String>,
1461 ) -> Result<TerminateWorkflowExecutionResponse> {
1462 Ok(WorkflowService::terminate_workflow_execution(&mut self.clone(),
1463 TerminateWorkflowExecutionRequest {
1464 namespace: self.namespace().to_owned(),
1465 workflow_execution: Some(WorkflowExecution {
1466 workflow_id,
1467 run_id: run_id.unwrap_or_default(),
1468 }),
1469 reason: "".to_string(),
1470 details: None,
1471 identity: self.get_identity().to_owned(),
1472 first_execution_run_id: "".to_string(),
1473 links: vec![],
1474 },
1475 )
1476 .await?
1477 .into_inner())
1478 }
1479
1480 async fn register_namespace(
1481 &self,
1482 options: RegisterNamespaceOptions,
1483 ) -> Result<RegisterNamespaceResponse> {
1484 let req = Into::<RegisterNamespaceRequest>::into(options);
1485 Ok(
1486 WorkflowService::register_namespace(&mut self.clone(),req)
1487 .await?
1488 .into_inner(),
1489 )
1490 }
1491
1492 async fn list_namespaces(&self) -> Result<ListNamespacesResponse> {
1493 Ok(WorkflowService::list_namespaces(&mut self.clone(),
1494 ListNamespacesRequest::default(),
1495 )
1496 .await?
1497 .into_inner())
1498 }
1499
1500 async fn describe_namespace(&self, namespace: Namespace) -> Result<DescribeNamespaceResponse> {
1501 Ok(WorkflowService::describe_namespace(&mut self.clone(),
1502 namespace.into_describe_namespace_request(),
1503 )
1504 .await?
1505 .into_inner())
1506 }
1507
1508 async fn list_open_workflow_executions(
1509 &self,
1510 maximum_page_size: i32,
1511 next_page_token: Vec<u8>,
1512 start_time_filter: Option<StartTimeFilter>,
1513 filters: Option<ListOpenFilters>,
1514 ) -> Result<ListOpenWorkflowExecutionsResponse> {
1515 Ok(WorkflowService::list_open_workflow_executions(&mut self.clone(),
1516 ListOpenWorkflowExecutionsRequest {
1517 namespace: self.namespace().to_owned(),
1518 maximum_page_size,
1519 next_page_token,
1520 start_time_filter,
1521 filters,
1522 },
1523 )
1524 .await?
1525 .into_inner())
1526 }
1527
1528 async fn list_closed_workflow_executions(
1529 &self,
1530 maximum_page_size: i32,
1531 next_page_token: Vec<u8>,
1532 start_time_filter: Option<StartTimeFilter>,
1533 filters: Option<ListClosedFilters>,
1534 ) -> Result<ListClosedWorkflowExecutionsResponse> {
1535 Ok(WorkflowService::list_closed_workflow_executions(&mut self.clone(),
1536 ListClosedWorkflowExecutionsRequest {
1537 namespace: self.namespace().to_owned(),
1538 maximum_page_size,
1539 next_page_token,
1540 start_time_filter,
1541 filters,
1542 },
1543 )
1544 .await?
1545 .into_inner())
1546 }
1547
1548 async fn list_workflow_executions(
1549 &self,
1550 page_size: i32,
1551 next_page_token: Vec<u8>,
1552 query: String,
1553 ) -> Result<ListWorkflowExecutionsResponse> {
1554 Ok(WorkflowService::list_workflow_executions(&mut self.clone(),
1555 ListWorkflowExecutionsRequest {
1556 namespace: self.namespace().to_owned(),
1557 page_size,
1558 next_page_token,
1559 query,
1560 },
1561 )
1562 .await?
1563 .into_inner())
1564 }
1565
1566 async fn list_archived_workflow_executions(
1567 &self,
1568 page_size: i32,
1569 next_page_token: Vec<u8>,
1570 query: String,
1571 ) -> Result<ListArchivedWorkflowExecutionsResponse> {
1572 Ok(WorkflowService::list_archived_workflow_executions(&mut self.clone(),
1573 ListArchivedWorkflowExecutionsRequest {
1574 namespace: self.namespace().to_owned(),
1575 page_size,
1576 next_page_token,
1577 query,
1578 },
1579 )
1580 .await?
1581 .into_inner())
1582 }
1583
1584 async fn get_search_attributes(&self) -> Result<GetSearchAttributesResponse> {
1585 Ok(WorkflowService::get_search_attributes(&mut self.clone(),
1586 GetSearchAttributesRequest {},
1587 )
1588 .await?
1589 .into_inner())
1590 }
1591
1592 async fn update_workflow_execution(
1593 &self,
1594 workflow_id: String,
1595 run_id: String,
1596 name: String,
1597 wait_policy: update::v1::WaitPolicy,
1598 args: Option<Payloads>,
1599 ) -> Result<UpdateWorkflowExecutionResponse> {
1600 Ok(WorkflowService::update_workflow_execution(&mut self.clone(),
1601 UpdateWorkflowExecutionRequest {
1602 namespace: self.namespace().to_owned(),
1603 workflow_execution: Some(WorkflowExecution {
1604 workflow_id,
1605 run_id,
1606 }),
1607 wait_policy: Some(wait_policy),
1608 request: Some(update::v1::Request {
1609 meta: Some(update::v1::Meta {
1610 update_id: "".into(),
1611 identity: self.get_identity().to_owned(),
1612 }),
1613 input: Some(update::v1::Input {
1614 header: None,
1615 name,
1616 args,
1617 }),
1618 }),
1619 ..Default::default()
1620 },
1621 )
1622 .await?
1623 .into_inner())
1624 }
1625}
1626
1627mod sealed {
1628 use crate::{InterceptedMetricsSvc, RawClientLike, WorkflowClientTrait};
1629
1630 pub trait WfHandleClient:
1631 WorkflowClientTrait + RawClientLike<SvcType = InterceptedMetricsSvc>
1632 {
1633 }
1634
1635 impl<T> WfHandleClient for T where
1636 T: WorkflowClientTrait + RawClientLike<SvcType = InterceptedMetricsSvc>
1637 {
1638 }
1639}
1640
1641pub trait WfClientExt: WfHandleClient + Sized + Clone {
1643 fn get_untyped_workflow_handle(
1646 &self,
1647 workflow_id: impl Into<String>,
1648 run_id: impl Into<String>,
1649 ) -> UntypedWorkflowHandle<Self> {
1650 let rid = run_id.into();
1651 UntypedWorkflowHandle::new(
1652 self.clone(),
1653 WorkflowExecutionInfo {
1654 namespace: self.namespace().to_string(),
1655 workflow_id: workflow_id.into(),
1656 run_id: if rid.is_empty() { None } else { Some(rid) },
1657 },
1658 )
1659 }
1660}
1661
1662impl<T> WfClientExt for T where T: WfHandleClient + Clone + Sized {}
1663
1664trait RequestExt {
1665 fn set_default_timeout(&mut self, duration: Duration);
1667}
1668impl<T> RequestExt for tonic::Request<T> {
1669 fn set_default_timeout(&mut self, duration: Duration) {
1670 if !self.metadata().contains_key("grpc-timeout") {
1671 self.set_timeout(duration)
1672 }
1673 }
1674}
1675
/// Logs an error via `tracing` and, when debug assertions are enabled, also panics with the
/// same format-style arguments. With debug assertions disabled, it only logs.
///
/// The arguments are expanded twice: once for the log call and once for the assertion.
///
/// The expansion is wrapped in a block. This lets the macro appear in expression position
/// (e.g. as a `match` arm). It also avoids the duplicate-import error (E0252) that the
/// previous `use tracing::error;` expansion caused when invoked more than once in a scope.
macro_rules! dbg_panic {
    ($($arg:tt)*) => {{
        ::tracing::error!($($arg)*);
        debug_assert!(false, $($arg)*);
    }};
}
pub(crate) use dbg_panic;
1684
#[cfg(test)]
mod tests {
    use super::*;
    use tonic::metadata::Ascii;

    /// Checks how `ServiceCallInterceptor` merges the shared client headers (user headers
    /// and API key) into outgoing request metadata.
    #[test]
    fn applies_headers() {
        let opts = ClientOptionsBuilder::default()
            .identity("enchicat".to_string())
            .target_url(Url::parse("https://smolkitty").unwrap())
            .client_name("cute-kitty".to_string())
            .client_version("0.1.0".to_string())
            .build()
            .unwrap();

        // Shared header state: an API key plus one user header.
        let headers = Arc::new(RwLock::new(ClientHeaders {
            user_headers: HashMap::new(),
            api_key: Some("my-api-key".to_owned()),
        }));
        headers
            .write()
            .user_headers
            .insert("my-meta-key".to_owned(), "my-meta-val".to_owned());
        let mut interceptor = ServiceCallInterceptor {
            opts,
            headers: Arc::clone(&headers),
        };

        // Configured headers are applied; the API key becomes a bearer authorization value.
        let req = interceptor.call(tonic::Request::new(())).unwrap();
        assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val");
        assert_eq!(
            req.metadata().get("authorization").unwrap(),
            "Bearer my-api-key"
        );

        // Metadata already present on the request is kept over configured values.
        let mut req = tonic::Request::new(());
        let md = req.metadata_mut();
        md.insert("my-meta-key", "my-meta-val2".parse().unwrap());
        md.insert("authorization", "my-api-key2".parse().unwrap());
        let req = interceptor.call(req).unwrap();
        assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val2");
        assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key2");

        // An "authorization" user header takes precedence over the API key.
        headers
            .write()
            .user_headers
            .insert("authorization".to_owned(), "my-api-key3".to_owned());
        let req = interceptor.call(tonic::Request::new(())).unwrap();
        assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val");
        assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key3");

        // With no user headers and no API key, nothing is added.
        {
            let mut state = headers.write();
            state.user_headers.clear();
            state.api_key = None;
        }
        let req = interceptor.call(tonic::Request::new(())).unwrap();
        assert!(!req.metadata().contains_key("my-meta-key"));
        assert!(!req.metadata().contains_key("authorization"));

        // A grpc-timeout already set on the request survives interception unchanged.
        let mut req = tonic::Request::new(());
        req.metadata_mut()
            .insert("grpc-timeout", "1S".parse().unwrap());
        let req = interceptor.call(req).unwrap();
        assert_eq!(
            req.metadata().get("grpc-timeout").unwrap(),
            "1S".parse::<MetadataValue<Ascii>>().unwrap()
        );
    }

    /// Keep-alive defaults to `ClientKeepAliveConfig::default()` values and can be disabled
    /// by passing `None`.
    #[test]
    fn keep_alive_defaults() {
        let mut builder = ClientOptionsBuilder::default();
        builder
            .identity("enchicat".to_string())
            .target_url(Url::parse("https://smolkitty").unwrap())
            .client_name("cute-kitty".to_string())
            .client_version("0.1.0".to_string());

        let opts = builder.build().unwrap();
        let expected = ClientKeepAliveConfig::default();
        let actual = opts.keep_alive.clone().unwrap();
        assert_eq!(actual.interval, expected.interval);
        assert_eq!(actual.timeout, expected.timeout);

        // Explicitly passing `None` turns keep-alive off.
        let opts = builder.keep_alive(None).build().unwrap();
        assert!(opts.keep_alive.is_none());
    }
}