1#![warn(missing_docs)] #[macro_use]
8extern crate tracing;
9
10mod async_activity_handle;
11pub mod callback_based;
12mod dns;
13#[cfg(feature = "envconfig")]
15pub mod envconfig;
16pub mod errors;
17pub mod grpc;
18mod metrics;
19mod options_structs;
20#[doc(hidden)]
22pub mod proxy;
23mod replaceable;
24pub mod request_extensions;
25mod retry;
26pub mod schedules;
28pub mod worker;
29mod workflow_handle;
30
31pub use crate::{
32 proxy::HttpConnectProxyOptions,
33 retry::{CallType, RETRYABLE_ERROR_CODES},
34};
35pub use async_activity_handle::{
36 ActivityHeartbeatResponse, ActivityIdentifier, AsyncActivityHandle,
37};
38
39pub use metrics::{LONG_REQUEST_LATENCY_HISTOGRAM_NAME, REQUEST_LATENCY_HISTOGRAM_NAME};
40pub use options_structs::*;
41pub use replaceable::SharedReplaceableClient;
42pub use retry::RetryOptions;
43pub use tonic;
44pub use workflow_handle::{
45 UntypedQuery, UntypedSignal, UntypedUpdate, UntypedWorkflow, UntypedWorkflowHandle,
46 WorkflowExecutionDescription, WorkflowExecutionInfo, WorkflowExecutionResult, WorkflowHandle,
47 WorkflowHistory, WorkflowUpdateHandle,
48};
49
50use crate::{
51 grpc::{
52 AttachMetricLabels, CloudService, HealthService, OperatorService, TestService,
53 WorkflowService,
54 },
55 metrics::{ChannelOrGrpcOverride, GrpcMetricSvc, MetricsContext},
56 request_extensions::RequestExt,
57 worker::ClientWorkerSet,
58};
59use errors::*;
60use futures_util::{stream, stream::Stream};
61use http::Uri;
62use parking_lot::RwLock;
63use std::{
64 collections::{HashMap, VecDeque},
65 fmt::Debug,
66 pin::Pin,
67 str::FromStr,
68 sync::{Arc, OnceLock},
69 task::{Context, Poll},
70 time::{Duration, SystemTime},
71};
72use temporalio_common::{
73 HasWorkflowDefinition,
74 data_converters::{
75 DataConverter, GenericPayloadConverter, PayloadConverter, SerializationContext,
76 SerializationContextData,
77 },
78 protos::{
79 coresdk::IntoPayloadsExt,
80 grpc::health::v1::health_client::HealthClient,
81 proto_ts_to_system_time,
82 temporal::api::{
83 cloud::cloudservice::v1::cloud_service_client::CloudServiceClient,
84 common::v1::{Memo, Payload, SearchAttributes, WorkflowType},
85 enums::v1::{TaskQueueKind, WorkflowExecutionStatus},
86 errordetails::v1::WorkflowExecutionAlreadyStartedFailure,
87 operatorservice::v1::operator_service_client::OperatorServiceClient,
88 sdk::v1::UserMetadata,
89 taskqueue::v1::TaskQueue,
90 testservice::v1::test_service_client::TestServiceClient,
91 workflow::v1 as workflow,
92 workflowservice::v1::{
93 count_workflow_executions_response, workflow_service_client::WorkflowServiceClient,
94 *,
95 },
96 },
97 utilities::decode_status_detail,
98 },
99};
100use tonic::{
101 Code, IntoRequest,
102 body::Body,
103 client::GrpcService,
104 codegen::InterceptedService,
105 metadata::{
106 AsciiMetadataKey, AsciiMetadataValue, BinaryMetadataKey, BinaryMetadataValue, MetadataMap,
107 MetadataValue,
108 },
109 service::Interceptor,
110 transport::{Certificate, Channel, Endpoint, Identity},
111};
112use tower::ServiceBuilder;
113use uuid::Uuid;
114
/// gRPC metadata key under which the client implementation name is reported to the server.
static CLIENT_NAME_HEADER_KEY: &str = "client-name";
/// gRPC metadata key under which the client version is reported to the server.
static CLIENT_VERSION_HEADER_KEY: &str = "client-version";
/// gRPC metadata key carrying the namespace a request targets.
/// NOTE(review): not referenced in this chunk — presumably used elsewhere in the crate.
static TEMPORAL_NAMESPACE_HEADER_KEY: &str = "temporal-namespace";

#[doc(hidden)]
pub static MESSAGE_TOO_LARGE_KEY: &str = "message-too-large";
#[doc(hidden)]
pub static ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT: &str = "short-circuit";

/// Timeout for long-poll style RPCs. NOTE(review): applied outside this chunk — confirm at
/// the long-poll call sites.
const LONG_POLL_TIMEOUT: Duration = Duration::from_secs(70);
/// Default timeout stamped on all other RPCs (see [`ServiceCallInterceptor`]).
const OTHER_CALL_TIMEOUT: Duration = Duration::from_secs(30);
/// This crate's own version string, taken from Cargo metadata at compile time.
const VERSION: &str = env!("CARGO_PKG_VERSION");
130
/// A connection to a Temporal server.
///
/// Cheaply cloneable: clones share the same underlying transport, headers, capabilities, and
/// worker registry via an `Arc`.
#[derive(Clone)]
pub struct Connection {
    inner: Arc<ConnectionInner>,
}
139
/// Shared state behind a [`Connection`].
#[derive(Clone)]
struct ConnectionInner {
    // Per-service gRPC clients sharing one interceptor-wrapped transport.
    service: TemporalServiceClient,
    // Default retry behavior for calls made through this connection.
    retry_options: RetryOptions,
    // Identity string reported to the server with requests.
    identity: String,
    // Mutable header set (user headers + API key) applied by the interceptor on every call.
    headers: Arc<RwLock<ClientHeaders>>,
    client_name: String,
    client_version: String,
    // Server capabilities fetched at connect time; `None` if skipped or unsupported.
    capabilities: Option<get_system_info_response::Capabilities>,
    // Workers registered against this connection (e.g. for eager workflow start).
    workers: Arc<ClientWorkerSet>,
    // Keeps the DNS re-resolution background task alive for the connection's lifetime.
    _dns_task: Option<Arc<dns::DnsReresolutionHandle>>,
}
153
impl Connection {
    /// Establish a connection to a Temporal server as configured by `options`.
    ///
    /// Transport selection, in priority order:
    /// 1. A caller-supplied `service_override` (no network channel is built);
    /// 2. A DNS load-balanced channel when DNS LB options validate, plus a background
    ///    re-resolution task whose handle is held for the connection's lifetime;
    /// 3. A plain tonic channel, optionally configured with TLS, HTTP/2 keep-alive, an
    ///    origin override, and/or an HTTP CONNECT proxy.
    ///
    /// Unless `skip_get_system_info` is set, a `GetSystemInfo` call is issued eagerly to
    /// discover server capabilities; an `Unimplemented` response is tolerated (older
    /// servers), while any other error fails construction.
    pub async fn connect(options: ConnectionOptions) -> Result<Self, ClientConnectError> {
        let dns_lb_opts = dns::validate_and_get_dns_lb(&options)?.cloned();
        let (service, dns_task) = if let Some(service_override) = options.service_override {
            (
                GrpcMetricSvc {
                    inner: ChannelOrGrpcOverride::GrpcOverride(service_override),
                    metrics: options.metrics_meter.clone().map(MetricsContext::new),
                    disable_errcode_label: options.disable_error_code_metric_tags,
                },
                None,
            )
        } else if let Some(dns_opts) = &dns_lb_opts {
            let (channel, sender) = dns::create_balanced_channel(&options).await?;
            // Periodically re-resolves the target and feeds new endpoints to the balancer.
            let handle = dns::spawn_dns_reresolution(
                sender,
                options.target.clone(),
                options.tls_options.clone(),
                options.keep_alive.clone(),
                options.override_origin.clone(),
                dns_opts.resolution_interval,
            );
            (
                ServiceBuilder::new()
                    .layer_fn(move |channel| GrpcMetricSvc {
                        inner: ChannelOrGrpcOverride::Channel(channel),
                        metrics: options.metrics_meter.clone().map(MetricsContext::new),
                        disable_errcode_label: options.disable_error_code_metric_tags,
                    })
                    .service(channel),
                Some(handle),
            )
        } else {
            let channel = Channel::from_shared(options.target.to_string())?;
            let channel = add_tls_to_channel(options.tls_options.as_ref(), channel).await?;
            let channel = if let Some(keep_alive) = options.keep_alive.as_ref() {
                channel
                    .keep_alive_while_idle(true)
                    .http2_keep_alive_interval(keep_alive.interval)
                    .keep_alive_timeout(keep_alive.timeout)
            } else {
                channel
            };
            let channel = if let Some(origin) = options.override_origin.clone() {
                channel.origin(origin)
            } else {
                channel
            };
            // Either tunnel through the HTTP CONNECT proxy or connect directly.
            let channel = if let Some(proxy) = options.http_connect_proxy.as_ref() {
                proxy.connect_endpoint(&channel).await?
            } else {
                channel.connect().await?
            };
            (
                ServiceBuilder::new()
                    .layer_fn(move |channel| GrpcMetricSvc {
                        inner: ChannelOrGrpcOverride::Channel(channel),
                        metrics: options.metrics_meter.clone().map(MetricsContext::new),
                        disable_errcode_label: options.disable_error_code_metric_tags,
                    })
                    .service(channel),
                None,
            )
        };

        let headers = Arc::new(RwLock::new(ClientHeaders {
            user_headers: parse_ascii_headers(options.headers.clone().unwrap_or_default())?,
            user_binary_headers: parse_binary_headers(
                options.binary_headers.clone().unwrap_or_default(),
            )?,
            api_key: options.api_key.clone(),
        }));
        let interceptor = ServiceCallInterceptor {
            client_name: options.client_name.clone(),
            client_version: options.client_version.clone(),
            headers: headers.clone(),
        };
        let svc = InterceptedService::new(service, interceptor);
        let mut svc_client = TemporalServiceClient::new(svc);

        let capabilities = if !options.skip_get_system_info {
            match svc_client
                .get_system_info(GetSystemInfoRequest::default().into_request())
                .await
            {
                Ok(sysinfo) => sysinfo.into_inner().capabilities,
                Err(status) => match status.code() {
                    // Older servers don't implement GetSystemInfo; treat as "no capabilities".
                    Code::Unimplemented => None,
                    _ => return Err(ClientConnectError::SystemInfoCallError(status)),
                },
            }
        } else {
            None
        };
        Ok(Self {
            inner: Arc::new(ConnectionInner {
                service: svc_client,
                retry_options: options.retry_options,
                identity: options.identity,
                headers,
                client_name: options.client_name,
                client_version: options.client_version,
                capabilities,
                workers: Arc::new(ClientWorkerSet::new()),
                _dns_task: dns_task,
            }),
        })
    }

    /// Set (or, with `None`, clear) the API key used for the `authorization` header on
    /// subsequent requests. Affects all clones sharing this connection.
    pub fn set_api_key(&self, api_key: Option<String>) {
        self.inner.headers.write().api_key = api_key;
    }

    /// Replace the ASCII headers applied to every request. Fails if any key or value is not
    /// valid gRPC ASCII metadata.
    pub fn set_headers(&self, headers: HashMap<String, String>) -> Result<(), InvalidHeaderError> {
        self.inner.headers.write().user_headers = parse_ascii_headers(headers)?;
        Ok(())
    }

    /// Replace the binary (`-bin` suffixed) headers applied to every request. Fails if any
    /// key is not a valid binary metadata key.
    pub fn set_binary_headers(
        &self,
        binary_headers: HashMap<String, Vec<u8>>,
    ) -> Result<(), InvalidHeaderError> {
        self.inner.headers.write().user_binary_headers = parse_binary_headers(binary_headers)?;
        Ok(())
    }

    /// The client name reported to the server.
    pub fn client_name(&self) -> &str {
        &self.inner.client_name
    }

    /// The client version reported to the server.
    pub fn client_version(&self) -> &str {
        &self.inner.client_version
    }

    /// Server capabilities discovered at connect time, if the system-info call was made and
    /// supported by the server.
    pub fn capabilities(&self) -> Option<&get_system_info_response::Capabilities> {
        self.inner.capabilities.as_ref()
    }

    /// Mutable access to this connection's retry options. Copy-on-write: if other clones
    /// share the inner state it is cloned first, so the change affects only this handle.
    pub fn retry_options_mut(&mut self) -> &mut RetryOptions {
        &mut Arc::make_mut(&mut self.inner).retry_options
    }

    /// The identity string sent to the server with requests.
    pub fn identity(&self) -> &str {
        &self.inner.identity
    }

    /// Mutable access to the identity string. Copy-on-write, like [`Self::retry_options_mut`].
    pub fn identity_mut(&mut self) -> &mut String {
        &mut Arc::make_mut(&mut self.inner).identity
    }

    /// The set of workers registered against this connection.
    pub fn workers(&self) -> Arc<ClientWorkerSet> {
        self.inner.workers.clone()
    }

    /// Key identifying this connection's worker group.
    pub fn worker_grouping_key(&self) -> Uuid {
        self.inner.workers.worker_grouping_key()
    }

    /// A handle to the raw workflow service client.
    pub fn workflow_service(&self) -> Box<dyn WorkflowService> {
        self.inner.service.workflow_service()
    }

    /// A handle to the raw operator service client.
    pub fn operator_service(&self) -> Box<dyn OperatorService> {
        self.inner.service.operator_service()
    }

    /// A handle to the raw cloud service client.
    pub fn cloud_service(&self) -> Box<dyn CloudService> {
        self.inner.service.cloud_service()
    }

    /// A handle to the raw test service client.
    pub fn test_service(&self) -> Box<dyn TestService> {
        self.inner.service.test_service()
    }

    /// A handle to the raw health service client.
    pub fn health_service(&self) -> Box<dyn HealthService> {
        self.inner.service.health_service()
    }
}
371
/// The mutable header set applied to every outgoing request by [`ServiceCallInterceptor`].
#[derive(Debug)]
struct ClientHeaders {
    // User-supplied ASCII headers.
    user_headers: HashMap<AsciiMetadataKey, AsciiMetadataValue>,
    // User-supplied binary (`-bin`) headers.
    user_binary_headers: HashMap<BinaryMetadataKey, BinaryMetadataValue>,
    // When set, sent as `authorization: Bearer <key>` unless the request already has one.
    api_key: Option<String>,
}
378
379impl ClientHeaders {
380 fn apply_to_metadata(&self, metadata: &mut MetadataMap) {
381 for (key, val) in self.user_headers.iter() {
382 if !metadata.contains_key(key) {
384 metadata.insert(key, val.clone());
385 }
386 }
387 for (key, val) in self.user_binary_headers.iter() {
388 if !metadata.contains_key(key) {
390 metadata.insert_bin(key, val.clone());
391 }
392 }
393 if let Some(api_key) = &self.api_key {
394 if !metadata.contains_key("authorization")
396 && let Ok(val) = format!("Bearer {api_key}").parse()
397 {
398 metadata.insert("authorization", val);
399 }
400 }
401 }
402}
403
404async fn add_tls_to_channel(
407 tls_options: Option<&TlsOptions>,
408 mut channel: Endpoint,
409) -> Result<Endpoint, ClientConnectError> {
410 if let Some(tls_cfg) = tls_options {
411 let mut tls = tonic::transport::ClientTlsConfig::new();
412
413 if let Some(root_cert) = &tls_cfg.server_root_ca_cert {
414 let server_root_ca_cert = Certificate::from_pem(root_cert);
415 tls = tls.ca_certificate(server_root_ca_cert);
416 } else {
417 tls = tls.with_native_roots();
418 }
419
420 if let Some(domain) = &tls_cfg.domain {
421 tls = tls.domain_name(domain);
422
423 let uri: Uri = format!("https://{domain}").parse()?;
428 channel = channel.origin(uri);
429 }
430
431 if let Some(client_opts) = &tls_cfg.client_tls_options {
432 let client_identity =
433 Identity::from_pem(&client_opts.client_cert, &client_opts.client_private_key);
434 tls = tls.identity(client_identity);
435 }
436
437 return channel.tls_config(tls).map_err(Into::into);
438 }
439 Ok(channel)
440}
441
442fn parse_ascii_headers(
443 headers: HashMap<String, String>,
444) -> Result<HashMap<AsciiMetadataKey, AsciiMetadataValue>, InvalidHeaderError> {
445 let mut parsed_headers = HashMap::with_capacity(headers.len());
446 for (k, v) in headers.into_iter() {
447 let key = match AsciiMetadataKey::from_str(&k) {
448 Ok(key) => key,
449 Err(err) => {
450 return Err(InvalidHeaderError::InvalidAsciiHeaderKey {
451 key: k,
452 source: err,
453 });
454 }
455 };
456 let value = match MetadataValue::from_str(&v) {
457 Ok(value) => value,
458 Err(err) => {
459 return Err(InvalidHeaderError::InvalidAsciiHeaderValue {
460 key: k,
461 value: v,
462 source: err,
463 });
464 }
465 };
466 parsed_headers.insert(key, value);
467 }
468
469 Ok(parsed_headers)
470}
471
472fn parse_binary_headers(
473 headers: HashMap<String, Vec<u8>>,
474) -> Result<HashMap<BinaryMetadataKey, BinaryMetadataValue>, InvalidHeaderError> {
475 let mut parsed_headers = HashMap::with_capacity(headers.len());
476 for (k, v) in headers.into_iter() {
477 let key = match BinaryMetadataKey::from_str(&k) {
478 Ok(key) => key,
479 Err(err) => {
480 return Err(InvalidHeaderError::InvalidBinaryHeaderKey {
481 key: k,
482 source: err,
483 });
484 }
485 };
486 let value = BinaryMetadataValue::from_bytes(&v);
487 parsed_headers.insert(key, value);
488 }
489
490 Ok(parsed_headers)
491}
492
/// A tonic [`Interceptor`] that stamps every outgoing request with client identification
/// headers, the shared user headers/API key, and a default timeout.
#[derive(Clone)]
pub struct ServiceCallInterceptor {
    client_name: String,
    client_version: String,
    // Shared with the owning `Connection` so header updates apply to in-flight clients.
    headers: Arc<RwLock<ClientHeaders>>,
}
501
502impl Interceptor for ServiceCallInterceptor {
503 fn call(
506 &mut self,
507 mut request: tonic::Request<()>,
508 ) -> Result<tonic::Request<()>, tonic::Status> {
509 let metadata = request.metadata_mut();
510 if !metadata.contains_key(CLIENT_NAME_HEADER_KEY) {
511 metadata.insert(
512 CLIENT_NAME_HEADER_KEY,
513 self.client_name
514 .parse()
515 .unwrap_or_else(|_| MetadataValue::from_static("")),
516 );
517 }
518 if !metadata.contains_key(CLIENT_VERSION_HEADER_KEY) {
519 metadata.insert(
520 CLIENT_VERSION_HEADER_KEY,
521 self.client_version
522 .parse()
523 .unwrap_or_else(|_| MetadataValue::from_static("")),
524 );
525 }
526 self.headers.read().apply_to_metadata(metadata);
527 request.set_default_timeout(OTHER_CALL_TIMEOUT);
528
529 Ok(request)
530 }
531}
532
/// Bundles one client per Temporal gRPC service, all backed by the same transport.
#[derive(Clone)]
pub struct TemporalServiceClient {
    workflow_svc_client: Box<dyn WorkflowService>,
    operator_svc_client: Box<dyn OperatorService>,
    cloud_svc_client: Box<dyn CloudService>,
    test_svc_client: Box<dyn TestService>,
    health_svc_client: Box<dyn HealthService>,
}
542
/// Maximum size, in bytes, accepted for incoming gRPC messages.
///
/// Defaults to 128 MiB and can be overridden via the `TEMPORAL_MAX_INCOMING_GRPC_BYTES`
/// environment variable (an unparsable value falls back to the default). Read once and
/// cached for the lifetime of the process.
fn get_decode_max_size() -> usize {
    const DEFAULT_MAX: usize = 128 * 1024 * 1024;
    static DECODE_MAX_SIZE: OnceLock<usize> = OnceLock::new();
    *DECODE_MAX_SIZE.get_or_init(|| match std::env::var("TEMPORAL_MAX_INCOMING_GRPC_BYTES") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_MAX),
        Err(_) => DEFAULT_MAX,
    })
}
554
555impl TemporalServiceClient {
556 fn new<T>(svc: T) -> Self
557 where
558 T: GrpcService<Body> + Send + Sync + Clone + 'static,
559 T::ResponseBody: tonic::codegen::Body<Data = tonic::codegen::Bytes> + Send + 'static,
560 T::Error: Into<tonic::codegen::StdError>,
561 <T::ResponseBody as tonic::codegen::Body>::Error: Into<tonic::codegen::StdError> + Send,
562 <T as GrpcService<Body>>::Future: Send,
563 {
564 let workflow_svc_client = Box::new(
565 WorkflowServiceClient::new(svc.clone())
566 .max_decoding_message_size(get_decode_max_size()),
567 );
568 let operator_svc_client = Box::new(
569 OperatorServiceClient::new(svc.clone())
570 .max_decoding_message_size(get_decode_max_size()),
571 );
572 let cloud_svc_client = Box::new(
573 CloudServiceClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()),
574 );
575 let test_svc_client = Box::new(
576 TestServiceClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()),
577 );
578 let health_svc_client = Box::new(
579 HealthClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()),
580 );
581
582 Self {
583 workflow_svc_client,
584 operator_svc_client,
585 cloud_svc_client,
586 test_svc_client,
587 health_svc_client,
588 }
589 }
590
591 pub fn from_services(
594 workflow: Box<dyn WorkflowService>,
595 operator: Box<dyn OperatorService>,
596 cloud: Box<dyn CloudService>,
597 test: Box<dyn TestService>,
598 health: Box<dyn HealthService>,
599 ) -> Self {
600 Self {
601 workflow_svc_client: workflow,
602 operator_svc_client: operator,
603 cloud_svc_client: cloud,
604 test_svc_client: test,
605 health_svc_client: health,
606 }
607 }
608
609 pub fn workflow_service(&self) -> Box<dyn WorkflowService> {
611 self.workflow_svc_client.clone()
612 }
613 pub fn operator_service(&self) -> Box<dyn OperatorService> {
615 self.operator_svc_client.clone()
616 }
617 pub fn cloud_service(&self) -> Box<dyn CloudService> {
619 self.cloud_svc_client.clone()
620 }
621 pub fn test_service(&self) -> Box<dyn TestService> {
623 self.test_svc_client.clone()
624 }
625 pub fn health_service(&self) -> Box<dyn HealthService> {
627 self.health_svc_client.clone()
628 }
629}
630
/// A high-level Temporal client: a [`Connection`] scoped by [`ClientOptions`] (namespace,
/// data converter, …). Cheap to clone; the options are shared via `Arc`.
#[derive(Clone)]
pub struct Client {
    connection: Connection,
    options: Arc<ClientOptions>,
}
638
impl Client {
    /// Create a client over an existing `connection`, scoped by `options`.
    pub fn new(connection: Connection, options: ClientOptions) -> Result<Self, ClientNewError> {
        Ok(Client {
            connection,
            options: Arc::new(options),
        })
    }

    /// The options this client was built with.
    pub fn options(&self) -> &ClientOptions {
        &self.options
    }

    /// Mutable access to the options. Copy-on-write: if other clones share the options they
    /// are cloned first, so the change affects only this handle.
    pub fn options_mut(&mut self) -> &mut ClientOptions {
        Arc::make_mut(&mut self.options)
    }

    /// The underlying server connection.
    pub fn connection(&self) -> &Connection {
        &self.connection
    }

    /// Mutable access to the underlying server connection.
    pub fn connection_mut(&mut self) -> &mut Connection {
        &mut self.connection
    }
}
674
impl Client {
    /// Start a new workflow execution and return a handle to it.
    ///
    /// Delegates to [`WorkflowClientTrait::start_workflow`]; see that method for semantics
    /// (including signal-with-start when `options.start_signal` is set).
    pub async fn start_workflow<W>(
        &self,
        workflow: W,
        input: W::Input,
        options: WorkflowStartOptions,
    ) -> Result<WorkflowHandle<Self, W>, WorkflowStartError>
    where
        W: HasWorkflowDefinition,
        W::Input: Send,
    {
        WorkflowClientTrait::start_workflow(self, workflow, input, options).await
    }

    /// Get a handle to an existing workflow by id (latest run; no server call is made).
    pub fn get_workflow_handle<W: HasWorkflowDefinition>(
        &self,
        workflow_id: impl Into<String>,
    ) -> WorkflowHandle<Self, W> {
        WorkflowClientTrait::get_workflow_handle(self, workflow_id)
    }

    /// Lazily list workflow executions matching the visibility `query`.
    pub fn list_workflows(
        &self,
        query: impl Into<String>,
        opts: WorkflowListOptions,
    ) -> ListWorkflowsStream {
        WorkflowClientTrait::list_workflows(self, query, opts)
    }

    /// Count workflow executions matching the visibility `query`.
    pub async fn count_workflows(
        &self,
        query: impl Into<String>,
        opts: WorkflowCountOptions,
    ) -> Result<WorkflowExecutionCount, ClientError> {
        WorkflowClientTrait::count_workflows(self, query, opts).await
    }

    /// Get a handle for completing/failing/heartbeating an activity asynchronously.
    pub fn get_async_activity_handle(
        &self,
        identifier: ActivityIdentifier,
    ) -> AsyncActivityHandle<Self> {
        WorkflowClientTrait::get_async_activity_handle(self, identifier)
    }
}
737
impl NamespacedClient for Client {
    /// The namespace this client was configured with.
    fn namespace(&self) -> String {
        self.options.namespace.clone()
    }

    /// The identity reported by the underlying connection.
    fn identity(&self) -> String {
        self.connection.identity().to_owned()
    }

    /// The configured data converter (overrides the trait's default-converter fallback).
    fn data_converter(&self) -> &DataConverter {
        &self.options.data_converter
    }
}
751
/// Identifies a namespace either by its human-readable name or by its server-assigned id.
#[derive(Clone)]
pub enum Namespace {
    /// Refer to the namespace by name.
    Name(String),
    /// Refer to the namespace by id.
    Id(String),
}
760
761impl Namespace {
762 pub fn into_describe_namespace_request(self) -> DescribeNamespaceRequest {
764 let (namespace, id) = match self {
765 Namespace::Name(n) => (n, "".to_owned()),
766 Namespace::Id(n) => ("".to_owned(), n),
767 };
768 DescribeNamespaceRequest { namespace, id }
769 }
770}
771
/// Workflow-level operations available on any namespaced client. Blanket-implemented below
/// for all types that are both a [`WorkflowService`] and a [`NamespacedClient`].
pub(crate) trait WorkflowClientTrait: NamespacedClient {
    /// Start a workflow execution (or signal-with-start, when the options carry a signal)
    /// and return a handle bound to the started run.
    fn start_workflow<W>(
        &self,
        workflow: W,
        input: W::Input,
        options: WorkflowStartOptions,
    ) -> impl Future<Output = Result<WorkflowHandle<Self, W>, WorkflowStartError>>
    where
        Self: Sized,
        W: HasWorkflowDefinition,
        W::Input: Send;

    /// Get a handle to an existing workflow by id, without contacting the server.
    fn get_workflow_handle<W: HasWorkflowDefinition>(
        &self,
        workflow_id: impl Into<String>,
    ) -> WorkflowHandle<Self, W>
    where
        Self: Sized;

    /// Lazily stream workflow executions matching a visibility `query`.
    fn list_workflows(
        &self,
        query: impl Into<String>,
        opts: WorkflowListOptions,
    ) -> ListWorkflowsStream;

    /// Count workflow executions matching a visibility `query`.
    fn count_workflows(
        &self,
        query: impl Into<String>,
        opts: WorkflowCountOptions,
    ) -> impl Future<Output = Result<WorkflowExecutionCount, ClientError>>;

    /// Get a handle for asynchronous activity completion.
    fn get_async_activity_handle(
        &self,
        identifier: ActivityIdentifier,
    ) -> AsyncActivityHandle<Self>
    where
        Self: Sized;
}
826
/// A client scoped to a single Temporal namespace and caller identity.
pub trait NamespacedClient {
    /// The namespace this client operates on.
    fn namespace(&self) -> String;
    /// The identity reported to the server for requests made by this client.
    fn identity(&self) -> String;
    /// The data converter used to (de)serialize payloads. Defaults to a lazily-initialized,
    /// process-wide default converter; implementors typically override this.
    fn data_converter(&self) -> &DataConverter {
        static DEFAULT: OnceLock<DataConverter> = OnceLock::new();
        DEFAULT.get_or_init(DataConverter::default)
    }
}
840
/// A typed view over a raw proto [`workflow::WorkflowExecutionInfo`], as returned by
/// visibility queries such as list-workflows.
#[derive(Debug, Clone)]
pub struct WorkflowExecution {
    raw: workflow::WorkflowExecutionInfo,
}
847
848impl WorkflowExecution {
849 pub fn new(raw: workflow::WorkflowExecutionInfo) -> Self {
851 Self { raw }
852 }
853
854 pub fn id(&self) -> &str {
856 self.raw
857 .execution
858 .as_ref()
859 .map(|e| e.workflow_id.as_str())
860 .unwrap_or("")
861 }
862
863 pub fn run_id(&self) -> &str {
865 self.raw
866 .execution
867 .as_ref()
868 .map(|e| e.run_id.as_str())
869 .unwrap_or("")
870 }
871
872 pub fn workflow_type(&self) -> &str {
874 self.raw
875 .r#type
876 .as_ref()
877 .map(|t| t.name.as_str())
878 .unwrap_or("")
879 }
880
881 pub fn status(&self) -> WorkflowExecutionStatus {
883 self.raw.status()
884 }
885
886 pub fn start_time(&self) -> Option<SystemTime> {
888 self.raw
889 .start_time
890 .as_ref()
891 .and_then(proto_ts_to_system_time)
892 }
893
894 pub fn execution_time(&self) -> Option<SystemTime> {
896 self.raw
897 .execution_time
898 .as_ref()
899 .and_then(proto_ts_to_system_time)
900 }
901
902 pub fn close_time(&self) -> Option<SystemTime> {
904 self.raw
905 .close_time
906 .as_ref()
907 .and_then(proto_ts_to_system_time)
908 }
909
910 pub fn task_queue(&self) -> &str {
912 &self.raw.task_queue
913 }
914
915 pub fn history_length(&self) -> i64 {
917 self.raw.history_length
918 }
919
920 pub fn memo(&self) -> Option<&Memo> {
922 self.raw.memo.as_ref()
923 }
924
925 pub fn parent_id(&self) -> Option<&str> {
927 self.raw
928 .parent_execution
929 .as_ref()
930 .map(|e| e.workflow_id.as_str())
931 }
932
933 pub fn parent_run_id(&self) -> Option<&str> {
935 self.raw
936 .parent_execution
937 .as_ref()
938 .map(|e| e.run_id.as_str())
939 }
940
941 pub fn search_attributes(&self) -> Option<&SearchAttributes> {
943 self.raw.search_attributes.as_ref()
944 }
945
946 pub fn raw(&self) -> &workflow::WorkflowExecutionInfo {
948 &self.raw
949 }
950
951 pub fn into_raw(self) -> workflow::WorkflowExecutionInfo {
953 self.raw
954 }
955}
956
957impl From<workflow::WorkflowExecutionInfo> for WorkflowExecution {
958 fn from(raw: workflow::WorkflowExecutionInfo) -> Self {
959 Self::new(raw)
960 }
961}
962
/// A stream of [`WorkflowExecution`]s produced by a list-workflows query, fetching result
/// pages from the server lazily as it is polled.
pub struct ListWorkflowsStream {
    // Boxed and pinned so the concrete unfold-based stream type stays private.
    inner: Pin<Box<dyn Stream<Item = Result<WorkflowExecution, ClientError>> + Send>>,
}
968
impl ListWorkflowsStream {
    /// Wrap an already-pinned stream of execution results.
    fn new(
        inner: Pin<Box<dyn Stream<Item = Result<WorkflowExecution, ClientError>> + Send>>,
    ) -> Self {
        Self { inner }
    }
}
976
impl Stream for ListWorkflowsStream {
    type Item = Result<WorkflowExecution, ClientError>;

    // Delegate polling to the boxed inner stream.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.inner.as_mut().poll_next(cx)
    }
}
984
/// Result of a count-workflows visibility query.
#[derive(Debug, Clone)]
pub struct WorkflowExecutionCount {
    // Total number of matching executions (proto i64 cast to usize).
    count: usize,
    // Aggregation groups returned by the server, when the query grouped results.
    groups: Vec<WorkflowCountAggregationGroup>,
}
994
995impl WorkflowExecutionCount {
996 pub(crate) fn from_response(resp: CountWorkflowExecutionsResponse) -> Self {
997 Self {
998 count: resp.count as usize,
999 groups: resp
1000 .groups
1001 .into_iter()
1002 .map(WorkflowCountAggregationGroup::from_proto)
1003 .collect(),
1004 }
1005 }
1006
1007 pub fn count(&self) -> usize {
1010 self.count
1011 }
1012
1013 pub fn groups(&self) -> &[WorkflowCountAggregationGroup] {
1015 &self.groups
1016 }
1017}
1018
/// One aggregation bucket of a grouped count-workflows query.
#[derive(Debug, Clone)]
pub struct WorkflowCountAggregationGroup {
    // Values of the group-by fields identifying this bucket.
    group_values: Vec<Payload>,
    // Number of executions in this bucket (proto i64 cast to usize).
    count: usize,
}
1025
1026impl WorkflowCountAggregationGroup {
1027 fn from_proto(proto: count_workflow_executions_response::AggregationGroup) -> Self {
1028 Self {
1029 group_values: proto.group_values,
1030 count: proto.count as usize,
1031 }
1032 }
1033
1034 pub fn group_values(&self) -> &[Payload] {
1036 &self.group_values
1037 }
1038
1039 pub fn count(&self) -> usize {
1041 self.count
1042 }
1043}
1044
// Blanket implementation: any cloneable workflow-service handle that knows its namespace and
// identity gets the high-level workflow operations for free.
impl<T> WorkflowClientTrait for T
where
    T: WorkflowService + NamespacedClient + Clone + Send + Sync + 'static,
{
    /// Start a workflow. When `options.start_signal` is set, uses signal-with-start so the
    /// start and the signal are delivered atomically; otherwise issues a plain start.
    async fn start_workflow<W>(
        &self,
        workflow: W,
        input: W::Input,
        options: WorkflowStartOptions,
    ) -> Result<WorkflowHandle<Self, W>, WorkflowStartError>
    where
        W: HasWorkflowDefinition,
        W::Input: Send,
    {
        // Serialize the input with the client's configured data converter.
        let payloads = self
            .data_converter()
            .to_payloads(&SerializationContextData::Workflow, &input)
            .await?;
        let namespace = self.namespace();
        let workflow_id = options.workflow_id.clone();
        let task_queue_name = options.task_queue.clone();

        // Static summary/details are encoded with the *default* payload converter, not the
        // client's configured one.
        let user_metadata = if options.static_summary.is_some() || options.static_details.is_some()
        {
            let payload_converter = PayloadConverter::default();
            let context = SerializationContext {
                data: &SerializationContextData::Workflow,
                converter: &payload_converter,
            };
            Some(UserMetadata {
                summary: options.static_summary.map(|s| {
                    payload_converter
                        .to_payload(&context, &s)
                        .expect("String-to-JSON payload serialization is infallible")
                }),
                details: options.static_details.map(|s| {
                    payload_converter
                        .to_payload(&context, &s)
                        .expect("String-to-JSON payload serialization is infallible")
                }),
            })
        } else {
            None
        };

        let run_id = if let Some(start_signal) = options.start_signal {
            let res = WorkflowService::signal_with_start_workflow_execution(
                &mut self.clone(),
                SignalWithStartWorkflowExecutionRequest {
                    namespace: namespace.clone(),
                    workflow_id: workflow_id.clone(),
                    workflow_type: Some(WorkflowType {
                        name: workflow.name().to_string(),
                    }),
                    task_queue: Some(TaskQueue {
                        name: task_queue_name,
                        // NOTE(review): this path uses TaskQueueKind::Normal while the plain
                        // start below uses Unspecified — confirm the divergence is intended.
                        kind: TaskQueueKind::Normal as i32,
                        normal_name: "".to_string(),
                    }),
                    input: payloads.into_payloads(),
                    signal_name: start_signal.signal_name,
                    signal_input: start_signal.input,
                    identity: self.identity(),
                    request_id: Uuid::new_v4().to_string(),
                    workflow_id_reuse_policy: options.id_reuse_policy as i32,
                    workflow_id_conflict_policy: options.id_conflict_policy as i32,
                    workflow_execution_timeout: options
                        .execution_timeout
                        .and_then(|d| d.try_into().ok()),
                    workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()),
                    workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()),
                    search_attributes: options.search_attributes.map(|d| d.into()),
                    cron_schedule: options.cron_schedule.unwrap_or_default(),
                    // The explicitly provided header wins; fall back to the signal's header.
                    header: options.header.or(start_signal.header),
                    user_metadata,
                    ..Default::default()
                }
                .into_request(),
            )
            .await?
            .into_inner();
            res.run_id
        } else {
            // NOTE(review): unlike the signal-with-start request above, `identity` is not set
            // on this request — confirm whether that is intentional.
            let res = self
                .clone()
                .start_workflow_execution(
                    StartWorkflowExecutionRequest {
                        namespace: namespace.clone(),
                        input: payloads.into_payloads(),
                        workflow_id: workflow_id.clone(),
                        workflow_type: Some(WorkflowType {
                            name: workflow.name().to_string(),
                        }),
                        task_queue: Some(TaskQueue {
                            name: task_queue_name,
                            kind: TaskQueueKind::Unspecified as i32,
                            normal_name: "".to_string(),
                        }),
                        request_id: Uuid::new_v4().to_string(),
                        workflow_id_reuse_policy: options.id_reuse_policy as i32,
                        workflow_id_conflict_policy: options.id_conflict_policy as i32,
                        workflow_execution_timeout: options
                            .execution_timeout
                            .and_then(|d| d.try_into().ok()),
                        workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()),
                        workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()),
                        search_attributes: options.search_attributes.map(|d| d.into()),
                        cron_schedule: options.cron_schedule.unwrap_or_default(),
                        request_eager_execution: options.enable_eager_workflow_start,
                        retry_policy: options.retry_policy,
                        links: options.links,
                        completion_callbacks: options.completion_callbacks,
                        priority: Some(options.priority.into()),
                        header: options.header,
                        user_metadata,
                        ..Default::default()
                    }
                    .into_request(),
                )
                .await
                .map_err(|status| {
                    if status.code() == Code::AlreadyExists {
                        // Decode the server-provided failure detail to surface the run id of
                        // the already-running execution.
                        let run_id =
                            decode_status_detail::<WorkflowExecutionAlreadyStartedFailure>(
                                status.details(),
                            )
                            .map(|f| f.run_id);
                        WorkflowStartError::AlreadyStarted {
                            run_id,
                            source: status,
                        }
                    } else {
                        WorkflowStartError::Rpc(status)
                    }
                })?
                .into_inner();
            res.run_id
        };

        // Bind the handle to the run that was just started (or signaled).
        Ok(WorkflowHandle::new(
            self.clone(),
            WorkflowExecutionInfo {
                namespace,
                workflow_id,
                run_id: Some(run_id.clone()),
                first_execution_run_id: Some(run_id),
            },
        ))
    }

    /// Build a handle to an existing workflow by id, targeting the latest run. No server
    /// call is made.
    fn get_workflow_handle<W: HasWorkflowDefinition>(
        &self,
        workflow_id: impl Into<String>,
    ) -> WorkflowHandle<Self, W>
    where
        Self: Sized,
    {
        WorkflowHandle::new(
            self.clone(),
            WorkflowExecutionInfo {
                namespace: self.namespace(),
                workflow_id: workflow_id.into(),
                run_id: None,
                first_execution_run_id: None,
            },
        )
    }

    /// Lazily stream executions matching `query`, fetching pages on demand. `opts.limit`
    /// caps the total number of yielded items.
    fn list_workflows(
        &self,
        query: impl Into<String>,
        opts: WorkflowListOptions,
    ) -> ListWorkflowsStream {
        let client = self.clone();
        let namespace = self.namespace();
        let query = query.into();
        let limit = opts.limit;

        // Unfold state: (next_page_token, buffered executions, total yielded, exhausted).
        let initial_state = (Vec::new(), VecDeque::new(), 0, false);

        let stream = stream::unfold(
            initial_state,
            move |(next_page_token, mut buffer, mut yielded, exhausted)| {
                let mut client = client.clone();
                let namespace = namespace.clone();
                let query = query.clone();

                async move {
                    // Stop once the caller-requested limit has been reached.
                    if let Some(l) = limit
                        && yielded >= l
                    {
                        return None;
                    }

                    // Drain the currently buffered page before fetching another.
                    if let Some(exec) = buffer.pop_front() {
                        yielded += 1;
                        return Some((Ok(exec), (next_page_token, buffer, yielded, exhausted)));
                    }

                    if exhausted {
                        return None;
                    }

                    let response = WorkflowService::list_workflow_executions(
                        &mut client,
                        ListWorkflowExecutionsRequest {
                            namespace,
                            // page_size 0 lets the server pick its default page size.
                            page_size: 0, next_page_token: next_page_token.clone(),
                            query,
                        }
                        .into_request(),
                    )
                    .await;

                    match response {
                        Ok(resp) => {
                            let resp = resp.into_inner();
                            // An empty token marks the final page.
                            let new_exhausted = resp.next_page_token.is_empty();
                            let new_token = resp.next_page_token;

                            buffer = resp
                                .executions
                                .into_iter()
                                .map(WorkflowExecution::from)
                                .collect();

                            if let Some(exec) = buffer.pop_front() {
                                yielded += 1;
                                Some((Ok(exec), (new_token, buffer, yielded, new_exhausted)))
                            } else {
                                // NOTE(review): an empty page with a non-empty
                                // next_page_token terminates the stream here rather than
                                // fetching the next page — confirm the server never returns
                                // such pages, or loop the fetch.
                                None
                            }
                        }
                        // An error is yielded once; marking exhausted ends the stream after.
                        Err(e) => Some((Err(e.into()), (next_page_token, buffer, yielded, true))),
                    }
                }
            },
        );

        ListWorkflowsStream::new(Box::pin(stream))
    }

    /// Count executions matching `query`. `_opts` is currently unused.
    async fn count_workflows(
        &self,
        query: impl Into<String>,
        _opts: WorkflowCountOptions,
    ) -> Result<WorkflowExecutionCount, ClientError> {
        let resp = WorkflowService::count_workflow_executions(
            &mut self.clone(),
            CountWorkflowExecutionsRequest {
                namespace: self.namespace(),
                query: query.into(),
            }
            .into_request(),
        )
        .await?
        .into_inner();

        Ok(WorkflowExecutionCount::from_response(resp))
    }

    /// Build a handle for asynchronous activity completion against this client.
    fn get_async_activity_handle(&self, identifier: ActivityIdentifier) -> AsyncActivityHandle<Self>
    where
        Self: Sized,
    {
        AsyncActivityHandle::new(self.clone(), identifier)
    }
}
1317
/// Log an error and, in debug builds, panic via `debug_assert!`. In release builds only the
/// error is logged, so this is suitable for "should never happen" conditions that must not
/// crash production.
macro_rules! dbg_panic {
    ($($arg:tt)*) => {{
        // Expand inside a block so the macro works in any statement/expression position and
        // does not inject names into the caller's scope. The previous expansion emitted a
        // bare `use tracing::error;` into the invoking scope, which collides (E0252) with an
        // existing `use tracing::error;` there and breaks double invocation in one scope;
        // the absolute `::tracing::error!` path needs no import at all.
        ::tracing::error!($($arg)*);
        debug_assert!(false, $($arg)*);
    }};
}
pub(crate) use dbg_panic;
1326
1327#[cfg(test)]
1328mod tests {
1329 use super::*;
1330 use tonic::metadata::Ascii;
1331 use url::Url;
1332
1333 #[test]
1334 fn applies_headers() {
1335 let headers = Arc::new(RwLock::new(ClientHeaders {
1337 user_headers: HashMap::new(),
1338 user_binary_headers: HashMap::new(),
1339 api_key: Some("my-api-key".to_owned()),
1340 }));
1341 headers.clone().write().user_headers.insert(
1342 "my-meta-key".parse().unwrap(),
1343 "my-meta-val".parse().unwrap(),
1344 );
1345 headers.clone().write().user_binary_headers.insert(
1346 "my-bin-meta-key-bin".parse().unwrap(),
1347 vec![1, 2, 3].try_into().unwrap(),
1348 );
1349 let mut interceptor = ServiceCallInterceptor {
1350 client_name: "cute-kitty".to_string(),
1351 client_version: "0.1.0".to_string(),
1352 headers: headers.clone(),
1353 };
1354
1355 let req = interceptor.call(tonic::Request::new(())).unwrap();
1357 assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val");
1358 assert_eq!(
1359 req.metadata().get("authorization").unwrap(),
1360 "Bearer my-api-key"
1361 );
1362 assert_eq!(
1363 req.metadata().get_bin("my-bin-meta-key-bin").unwrap(),
1364 vec![1, 2, 3].as_slice()
1365 );
1366
1367 let mut req = tonic::Request::new(());
1369 req.metadata_mut()
1370 .insert("my-meta-key", "my-meta-val2".parse().unwrap());
1371 req.metadata_mut()
1372 .insert("authorization", "my-api-key2".parse().unwrap());
1373 req.metadata_mut()
1374 .insert_bin("my-bin-meta-key-bin", vec![4, 5, 6].try_into().unwrap());
1375 let req = interceptor.call(req).unwrap();
1376 assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val2");
1377 assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key2");
1378 assert_eq!(
1379 req.metadata().get_bin("my-bin-meta-key-bin").unwrap(),
1380 vec![4, 5, 6].as_slice()
1381 );
1382
1383 headers.clone().write().user_headers.insert(
1385 "authorization".parse().unwrap(),
1386 "my-api-key3".parse().unwrap(),
1387 );
1388 let req = interceptor.call(tonic::Request::new(())).unwrap();
1389 assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val");
1390 assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key3");
1391
1392 headers.clone().write().user_headers.clear();
1394 headers.clone().write().user_binary_headers.clear();
1395 headers.clone().write().api_key.take();
1396 let req = interceptor.call(tonic::Request::new(())).unwrap();
1397 assert!(!req.metadata().contains_key("my-meta-key"));
1398 assert!(!req.metadata().contains_key("authorization"));
1399 assert!(!req.metadata().contains_key("my-bin-meta-key-bin"));
1400
1401 let mut req = tonic::Request::new(());
1403 req.metadata_mut()
1404 .insert("grpc-timeout", "1S".parse().unwrap());
1405 let req = interceptor.call(req).unwrap();
1406 assert_eq!(
1407 req.metadata().get("grpc-timeout").unwrap(),
1408 "1S".parse::<MetadataValue<Ascii>>().unwrap()
1409 );
1410 }
1411
1412 #[test]
1413 fn invalid_ascii_header_key() {
1414 let invalid_headers = {
1415 let mut h = HashMap::new();
1416 h.insert("x-binary-key-bin".to_owned(), "value".to_owned());
1417 h
1418 };
1419
1420 let result = parse_ascii_headers(invalid_headers);
1421 assert!(result.is_err());
1422 assert_eq!(
1423 result.err().unwrap().to_string(),
1424 "Invalid ASCII header key 'x-binary-key-bin': invalid gRPC metadata key name"
1425 );
1426 }
1427
1428 #[test]
1429 fn invalid_ascii_header_value() {
1430 let invalid_headers = {
1431 let mut h = HashMap::new();
1432 h.insert("x-ascii-key".to_owned(), "\x00value".to_owned());
1434 h
1435 };
1436
1437 let result = parse_ascii_headers(invalid_headers);
1438 assert!(result.is_err());
1439 assert_eq!(
1440 result.err().unwrap().to_string(),
1441 "Invalid ASCII header value for key 'x-ascii-key': failed to parse metadata value"
1442 );
1443 }
1444
1445 #[test]
1446 fn invalid_binary_header_key() {
1447 let invalid_headers = {
1448 let mut h = HashMap::new();
1449 h.insert("x-ascii-key".to_owned(), vec![1, 2, 3]);
1450 h
1451 };
1452
1453 let result = parse_binary_headers(invalid_headers);
1454 assert!(result.is_err());
1455 assert_eq!(
1456 result.err().unwrap().to_string(),
1457 "Invalid binary header key 'x-ascii-key': invalid gRPC metadata key name"
1458 );
1459 }
1460
1461 #[test]
1462 fn keep_alive_defaults() {
1463 let opts = ConnectionOptions::new(Url::parse("https://smolkitty").unwrap())
1464 .identity("enchicat".to_string())
1465 .client_name("cute-kitty".to_string())
1466 .client_version("0.1.0".to_string())
1467 .build();
1468 assert_eq!(
1469 opts.keep_alive.clone().unwrap().interval,
1470 ClientKeepAliveOptions::default().interval
1471 );
1472 assert_eq!(
1473 opts.keep_alive.clone().unwrap().timeout,
1474 ClientKeepAliveOptions::default().timeout
1475 );
1476
1477 let opts = ConnectionOptions::new(Url::parse("https://smolkitty").unwrap())
1479 .identity("enchicat".to_string())
1480 .client_name("cute-kitty".to_string())
1481 .client_version("0.1.0".to_string())
1482 .keep_alive(None)
1483 .build();
1484 dbg!(&opts.keep_alive);
1485 assert!(opts.keep_alive.is_none());
1486 }
1487
    /// Tests for the `list_workflows` pagination stream, driven by a mock
    /// `WorkflowService` so no real server is needed.
    mod list_workflows_tests {
        use super::*;
        use futures_util::{FutureExt, StreamExt};
        use std::sync::atomic::{AtomicUsize, Ordering};
        use temporalio_common::protos::temporal::api::common::v1::WorkflowExecution as ProtoWorkflowExecution;
        use tonic::{Request, Response};

        /// Mock client that serves `total_workflows` synthetic executions in
        /// pages of at most `page_size`, counting each list RPC.
        #[derive(Clone)]
        struct MockListWorkflowsClient {
            // Number of list RPCs issued; shared so tests can assert on it.
            call_count: Arc<AtomicUsize>,
            // Maximum executions returned per page.
            page_size: usize,
            // Total executions the mock pretends to have.
            total_workflows: usize,
        }

        impl NamespacedClient for MockListWorkflowsClient {
            fn namespace(&self) -> String {
                "test-namespace".to_string()
            }
            fn identity(&self) -> String {
                "test-identity".to_string()
            }
        }

        impl WorkflowService for MockListWorkflowsClient {
            /// Serves one page. The page token encodes the numeric offset as
            /// a UTF-8 decimal string; an empty request token means "start at
            /// 0", and an empty response token means "no more pages".
            fn list_workflow_executions(
                &mut self,
                request: Request<ListWorkflowExecutionsRequest>,
            ) -> futures_util::future::BoxFuture<
                '_,
                Result<Response<ListWorkflowExecutionsResponse>, tonic::Status>,
            > {
                self.call_count.fetch_add(1, Ordering::SeqCst);
                let req = request.into_inner();

                // Decode the offset from the page token (empty => 0).
                let offset: usize = if req.next_page_token.is_empty() {
                    0
                } else {
                    String::from_utf8(req.next_page_token)
                        .unwrap()
                        .parse()
                        .unwrap()
                };

                // Clamp the page to whatever remains past `offset`.
                let remaining = self.total_workflows.saturating_sub(offset);
                let count = remaining.min(self.page_size);
                let new_offset = offset + count;

                // Fabricate `count` executions with predictable ids so tests
                // can assert ordering.
                let executions: Vec<_> = (offset..offset + count)
                    .map(|i| workflow::WorkflowExecutionInfo {
                        execution: Some(ProtoWorkflowExecution {
                            workflow_id: format!("wf-{i}"),
                            run_id: format!("run-{i}"),
                        }),
                        r#type: Some(WorkflowType {
                            name: "TestWorkflow".to_string(),
                        }),
                        task_queue: "test-queue".to_string(),
                        ..Default::default()
                    })
                    .collect();

                // Empty token marks the final page.
                let next_page_token = if new_offset < self.total_workflows {
                    new_offset.to_string().into_bytes()
                } else {
                    vec![]
                };

                async move {
                    Ok(Response::new(ListWorkflowExecutionsResponse {
                        executions,
                        next_page_token,
                    }))
                }
                .boxed()
            }
        }

        /// 10 workflows at page size 3 should yield all 10 in order across
        /// 4 RPCs (pages of 3 + 3 + 3 + 1).
        #[tokio::test]
        async fn list_workflows_paginates_through_all_results() {
            let call_count = Arc::new(AtomicUsize::new(0));
            let client = MockListWorkflowsClient {
                call_count: call_count.clone(),
                page_size: 3,
                total_workflows: 10,
            };

            let stream = client.list_workflows("", WorkflowListOptions::default());
            let results: Vec<_> = stream.collect().await;

            assert_eq!(results.len(), 10);
            for (i, result) in results.iter().enumerate() {
                let wf = result.as_ref().unwrap();
                assert_eq!(wf.id(), format!("wf-{i}"));
                assert_eq!(wf.run_id(), format!("run-{i}"));
            }
            assert_eq!(call_count.load(Ordering::SeqCst), 4);
        }

        /// With `limit(5)` over pages of 3, the stream stops at 5 results and
        /// only 2 RPCs are made — no fetching past the limit.
        #[tokio::test]
        async fn list_workflows_respects_limit() {
            let call_count = Arc::new(AtomicUsize::new(0));
            let client = MockListWorkflowsClient {
                call_count: call_count.clone(),
                page_size: 3,
                total_workflows: 10,
            };

            let opts = WorkflowListOptions::builder().limit(5).build();
            let stream = client.list_workflows("", opts);
            let results: Vec<_> = stream.collect().await;

            assert_eq!(results.len(), 5);
            for (i, result) in results.iter().enumerate() {
                let wf = result.as_ref().unwrap();
                assert_eq!(wf.id(), format!("wf-{i}"));
            }
            assert_eq!(call_count.load(Ordering::SeqCst), 2);
        }

        /// A limit smaller than one page (3 < 10) is satisfied by a single
        /// RPC.
        #[tokio::test]
        async fn list_workflows_limit_less_than_page_size() {
            let call_count = Arc::new(AtomicUsize::new(0));
            let client = MockListWorkflowsClient {
                call_count: call_count.clone(),
                page_size: 10,
                total_workflows: 100,
            };

            let opts = WorkflowListOptions::builder().limit(3).build();
            let stream = client.list_workflows("", opts);
            let results: Vec<_> = stream.collect().await;

            assert_eq!(results.len(), 3);
            assert_eq!(call_count.load(Ordering::SeqCst), 1);
        }

        /// Zero matching workflows: the stream ends empty after exactly one
        /// RPC (the initial fetch).
        #[tokio::test]
        async fn list_workflows_empty_results() {
            let call_count = Arc::new(AtomicUsize::new(0));
            let client = MockListWorkflowsClient {
                call_count: call_count.clone(),
                page_size: 10,
                total_workflows: 0,
            };

            let stream = client.list_workflows("", WorkflowListOptions::default());
            let results: Vec<_> = stream.collect().await;

            assert_eq!(results.len(), 0);
            assert_eq!(call_count.load(Ordering::SeqCst), 1);
        }
    }
1646}