1#![warn(missing_docs)] #[macro_use]
8extern crate tracing;
9
10mod async_activity_handle;
11pub mod callback_based;
12pub mod errors;
13pub mod grpc;
14mod metrics;
15mod options_structs;
16#[doc(hidden)]
18pub mod proxy;
19mod replaceable;
20pub mod request_extensions;
21mod retry;
22pub mod worker;
23mod workflow_handle;
24
25pub use crate::{
26 proxy::HttpConnectProxyOptions,
27 retry::{CallType, RETRYABLE_ERROR_CODES},
28};
29pub use async_activity_handle::{
30 ActivityHeartbeatResponse, ActivityIdentifier, AsyncActivityHandle,
31};
32
33pub use metrics::{LONG_REQUEST_LATENCY_HISTOGRAM_NAME, REQUEST_LATENCY_HISTOGRAM_NAME};
34pub use options_structs::*;
35pub use replaceable::SharedReplaceableClient;
36pub use retry::RetryOptions;
37pub use tonic;
38pub use workflow_handle::{
39 UntypedQuery, UntypedSignal, UntypedUpdate, UntypedWorkflow, UntypedWorkflowHandle,
40 WorkflowExecutionDescription, WorkflowExecutionInfo, WorkflowExecutionResult, WorkflowHandle,
41 WorkflowHistory, WorkflowUpdateHandle,
42};
43
44use crate::{
45 grpc::{
46 AttachMetricLabels, CloudService, HealthService, OperatorService, TestService,
47 WorkflowService,
48 },
49 metrics::{ChannelOrGrpcOverride, GrpcMetricSvc, MetricsContext},
50 request_extensions::RequestExt,
51 worker::ClientWorkerSet,
52};
53use errors::*;
54use futures_util::{stream, stream::Stream};
55use http::Uri;
56use parking_lot::RwLock;
57use std::{
58 collections::{HashMap, VecDeque},
59 fmt::Debug,
60 pin::Pin,
61 str::FromStr,
62 sync::{Arc, OnceLock},
63 task::{Context, Poll},
64 time::{Duration, SystemTime},
65};
66use temporalio_common::{
67 WorkflowDefinition,
68 data_converters::{DataConverter, SerializationContextData},
69 protos::{
70 coresdk::IntoPayloadsExt,
71 grpc::health::v1::health_client::HealthClient,
72 proto_ts_to_system_time,
73 temporal::api::{
74 cloud::cloudservice::v1::cloud_service_client::CloudServiceClient,
75 common::v1::{Memo, Payload, SearchAttributes, WorkflowType},
76 enums::v1::{TaskQueueKind, WorkflowExecutionStatus},
77 errordetails::v1::WorkflowExecutionAlreadyStartedFailure,
78 operatorservice::v1::operator_service_client::OperatorServiceClient,
79 taskqueue::v1::TaskQueue,
80 testservice::v1::test_service_client::TestServiceClient,
81 workflow::v1 as workflow,
82 workflowservice::v1::{
83 count_workflow_executions_response, workflow_service_client::WorkflowServiceClient,
84 *,
85 },
86 },
87 utilities::decode_status_detail,
88 },
89};
90use tonic::{
91 Code, IntoRequest,
92 body::Body,
93 client::GrpcService,
94 codegen::InterceptedService,
95 metadata::{
96 AsciiMetadataKey, AsciiMetadataValue, BinaryMetadataKey, BinaryMetadataValue, MetadataMap,
97 MetadataValue,
98 },
99 service::Interceptor,
100 transport::{Certificate, Channel, Endpoint, Identity},
101};
102use tower::ServiceBuilder;
103use uuid::Uuid;
104
105static CLIENT_NAME_HEADER_KEY: &str = "client-name";
106static CLIENT_VERSION_HEADER_KEY: &str = "client-version";
107static TEMPORAL_NAMESPACE_HEADER_KEY: &str = "temporal-namespace";
108
109#[doc(hidden)]
110pub static MESSAGE_TOO_LARGE_KEY: &str = "message-too-large";
112#[doc(hidden)]
113pub static ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT: &str = "short-circuit";
115
116const LONG_POLL_TIMEOUT: Duration = Duration::from_secs(70);
118const OTHER_CALL_TIMEOUT: Duration = Duration::from_secs(30);
119const VERSION: &str = env!("CARGO_PKG_VERSION");
120
121#[derive(Clone)]
126pub struct Connection {
127 inner: Arc<ConnectionInner>,
128}
129
130#[derive(Clone)]
131struct ConnectionInner {
132 service: TemporalServiceClient,
133 retry_options: RetryOptions,
134 identity: String,
135 headers: Arc<RwLock<ClientHeaders>>,
136 client_name: String,
137 client_version: String,
138 capabilities: Option<get_system_info_response::Capabilities>,
140 workers: Arc<ClientWorkerSet>,
141}
142
143impl Connection {
144 pub async fn connect(options: ConnectionOptions) -> Result<Self, ClientConnectError> {
146 let service = if let Some(service_override) = options.service_override {
147 GrpcMetricSvc {
148 inner: ChannelOrGrpcOverride::GrpcOverride(service_override),
149 metrics: options.metrics_meter.clone().map(MetricsContext::new),
150 disable_errcode_label: options.disable_error_code_metric_tags,
151 }
152 } else {
153 let channel = Channel::from_shared(options.target.to_string())?;
154 let channel = add_tls_to_channel(options.tls_options.as_ref(), channel).await?;
155 let channel = if let Some(keep_alive) = options.keep_alive.as_ref() {
156 channel
157 .keep_alive_while_idle(true)
158 .http2_keep_alive_interval(keep_alive.interval)
159 .keep_alive_timeout(keep_alive.timeout)
160 } else {
161 channel
162 };
163 let channel = if let Some(origin) = options.override_origin.clone() {
164 channel.origin(origin)
165 } else {
166 channel
167 };
168 let channel = if let Some(proxy) = options.http_connect_proxy.as_ref() {
170 proxy.connect_endpoint(&channel).await?
171 } else {
172 channel.connect().await?
173 };
174 ServiceBuilder::new()
175 .layer_fn(move |channel| GrpcMetricSvc {
176 inner: ChannelOrGrpcOverride::Channel(channel),
177 metrics: options.metrics_meter.clone().map(MetricsContext::new),
178 disable_errcode_label: options.disable_error_code_metric_tags,
179 })
180 .service(channel)
181 };
182
183 let headers = Arc::new(RwLock::new(ClientHeaders {
184 user_headers: parse_ascii_headers(options.headers.clone().unwrap_or_default())?,
185 user_binary_headers: parse_binary_headers(
186 options.binary_headers.clone().unwrap_or_default(),
187 )?,
188 api_key: options.api_key.clone(),
189 }));
190 let interceptor = ServiceCallInterceptor {
191 client_name: options.client_name.clone(),
192 client_version: options.client_version.clone(),
193 headers: headers.clone(),
194 };
195 let svc = InterceptedService::new(service, interceptor);
196 let mut svc_client = TemporalServiceClient::new(svc);
197
198 let capabilities = if !options.skip_get_system_info {
199 match svc_client
200 .get_system_info(GetSystemInfoRequest::default().into_request())
201 .await
202 {
203 Ok(sysinfo) => sysinfo.into_inner().capabilities,
204 Err(status) => match status.code() {
205 Code::Unimplemented => None,
206 _ => return Err(ClientConnectError::SystemInfoCallError(status)),
207 },
208 }
209 } else {
210 None
211 };
212 Ok(Self {
213 inner: Arc::new(ConnectionInner {
214 service: svc_client,
215 retry_options: options.retry_options,
216 identity: options.identity,
217 headers,
218 client_name: options.client_name,
219 client_version: options.client_version,
220 capabilities,
221 workers: Arc::new(ClientWorkerSet::new()),
222 }),
223 })
224 }
225
226 pub fn set_api_key(&self, api_key: Option<String>) {
228 self.inner.headers.write().api_key = api_key;
229 }
230
231 pub fn set_headers(&self, headers: HashMap<String, String>) -> Result<(), InvalidHeaderError> {
240 self.inner.headers.write().user_headers = parse_ascii_headers(headers)?;
241 Ok(())
242 }
243
244 pub fn set_binary_headers(
253 &self,
254 binary_headers: HashMap<String, Vec<u8>>,
255 ) -> Result<(), InvalidHeaderError> {
256 self.inner.headers.write().user_binary_headers = parse_binary_headers(binary_headers)?;
257 Ok(())
258 }
259
260 pub fn client_name(&self) -> &str {
262 &self.inner.client_name
263 }
264
265 pub fn client_version(&self) -> &str {
267 &self.inner.client_version
268 }
269
270 pub fn capabilities(&self) -> Option<&get_system_info_response::Capabilities> {
273 self.inner.capabilities.as_ref()
274 }
275
276 pub fn retry_options_mut(&mut self) -> &mut RetryOptions {
281 &mut Arc::make_mut(&mut self.inner).retry_options
282 }
283
284 pub fn identity(&self) -> &str {
286 &self.inner.identity
287 }
288
289 pub fn identity_mut(&mut self) -> &mut String {
294 &mut Arc::make_mut(&mut self.inner).identity
295 }
296
297 pub fn workers(&self) -> Arc<ClientWorkerSet> {
299 self.inner.workers.clone()
300 }
301
302 pub fn worker_grouping_key(&self) -> Uuid {
304 self.inner.workers.worker_grouping_key()
305 }
306
307 pub fn workflow_service(&self) -> Box<dyn WorkflowService> {
309 self.inner.service.workflow_service()
310 }
311
312 pub fn operator_service(&self) -> Box<dyn OperatorService> {
314 self.inner.service.operator_service()
315 }
316
317 pub fn cloud_service(&self) -> Box<dyn CloudService> {
319 self.inner.service.cloud_service()
320 }
321
322 pub fn test_service(&self) -> Box<dyn TestService> {
324 self.inner.service.test_service()
325 }
326
327 pub fn health_service(&self) -> Box<dyn HealthService> {
329 self.inner.service.health_service()
330 }
331}
332
333#[derive(Debug)]
334struct ClientHeaders {
335 user_headers: HashMap<AsciiMetadataKey, AsciiMetadataValue>,
336 user_binary_headers: HashMap<BinaryMetadataKey, BinaryMetadataValue>,
337 api_key: Option<String>,
338}
339
340impl ClientHeaders {
341 fn apply_to_metadata(&self, metadata: &mut MetadataMap) {
342 for (key, val) in self.user_headers.iter() {
343 if !metadata.contains_key(key) {
345 metadata.insert(key, val.clone());
346 }
347 }
348 for (key, val) in self.user_binary_headers.iter() {
349 if !metadata.contains_key(key) {
351 metadata.insert_bin(key, val.clone());
352 }
353 }
354 if let Some(api_key) = &self.api_key {
355 if !metadata.contains_key("authorization")
357 && let Ok(val) = format!("Bearer {api_key}").parse()
358 {
359 metadata.insert("authorization", val);
360 }
361 }
362 }
363}
364
365async fn add_tls_to_channel(
368 tls_options: Option<&TlsOptions>,
369 mut channel: Endpoint,
370) -> Result<Endpoint, ClientConnectError> {
371 if let Some(tls_cfg) = tls_options {
372 let mut tls = tonic::transport::ClientTlsConfig::new();
373
374 if let Some(root_cert) = &tls_cfg.server_root_ca_cert {
375 let server_root_ca_cert = Certificate::from_pem(root_cert);
376 tls = tls.ca_certificate(server_root_ca_cert);
377 } else {
378 tls = tls.with_native_roots();
379 }
380
381 if let Some(domain) = &tls_cfg.domain {
382 tls = tls.domain_name(domain);
383
384 let uri: Uri = format!("https://{domain}").parse()?;
389 channel = channel.origin(uri);
390 }
391
392 if let Some(client_opts) = &tls_cfg.client_tls_options {
393 let client_identity =
394 Identity::from_pem(&client_opts.client_cert, &client_opts.client_private_key);
395 tls = tls.identity(client_identity);
396 }
397
398 return channel.tls_config(tls).map_err(Into::into);
399 }
400 Ok(channel)
401}
402
403fn parse_ascii_headers(
404 headers: HashMap<String, String>,
405) -> Result<HashMap<AsciiMetadataKey, AsciiMetadataValue>, InvalidHeaderError> {
406 let mut parsed_headers = HashMap::with_capacity(headers.len());
407 for (k, v) in headers.into_iter() {
408 let key = match AsciiMetadataKey::from_str(&k) {
409 Ok(key) => key,
410 Err(err) => {
411 return Err(InvalidHeaderError::InvalidAsciiHeaderKey {
412 key: k,
413 source: err,
414 });
415 }
416 };
417 let value = match MetadataValue::from_str(&v) {
418 Ok(value) => value,
419 Err(err) => {
420 return Err(InvalidHeaderError::InvalidAsciiHeaderValue {
421 key: k,
422 value: v,
423 source: err,
424 });
425 }
426 };
427 parsed_headers.insert(key, value);
428 }
429
430 Ok(parsed_headers)
431}
432
433fn parse_binary_headers(
434 headers: HashMap<String, Vec<u8>>,
435) -> Result<HashMap<BinaryMetadataKey, BinaryMetadataValue>, InvalidHeaderError> {
436 let mut parsed_headers = HashMap::with_capacity(headers.len());
437 for (k, v) in headers.into_iter() {
438 let key = match BinaryMetadataKey::from_str(&k) {
439 Ok(key) => key,
440 Err(err) => {
441 return Err(InvalidHeaderError::InvalidBinaryHeaderKey {
442 key: k,
443 source: err,
444 });
445 }
446 };
447 let value = BinaryMetadataValue::from_bytes(&v);
448 parsed_headers.insert(key, value);
449 }
450
451 Ok(parsed_headers)
452}
453
454#[derive(Clone)]
456pub struct ServiceCallInterceptor {
457 client_name: String,
458 client_version: String,
459 headers: Arc<RwLock<ClientHeaders>>,
461}
462
463impl Interceptor for ServiceCallInterceptor {
464 fn call(
467 &mut self,
468 mut request: tonic::Request<()>,
469 ) -> Result<tonic::Request<()>, tonic::Status> {
470 let metadata = request.metadata_mut();
471 if !metadata.contains_key(CLIENT_NAME_HEADER_KEY) {
472 metadata.insert(
473 CLIENT_NAME_HEADER_KEY,
474 self.client_name
475 .parse()
476 .unwrap_or_else(|_| MetadataValue::from_static("")),
477 );
478 }
479 if !metadata.contains_key(CLIENT_VERSION_HEADER_KEY) {
480 metadata.insert(
481 CLIENT_VERSION_HEADER_KEY,
482 self.client_version
483 .parse()
484 .unwrap_or_else(|_| MetadataValue::from_static("")),
485 );
486 }
487 self.headers.read().apply_to_metadata(metadata);
488 request.set_default_timeout(OTHER_CALL_TIMEOUT);
489
490 Ok(request)
491 }
492}
493
494#[derive(Clone)]
496pub struct TemporalServiceClient {
497 workflow_svc_client: Box<dyn WorkflowService>,
498 operator_svc_client: Box<dyn OperatorService>,
499 cloud_svc_client: Box<dyn CloudService>,
500 test_svc_client: Box<dyn TestService>,
501 health_svc_client: Box<dyn HealthService>,
502}
503
/// Maximum size, in bytes, of an incoming gRPC message.
///
/// Defaults to 128 MiB. The `TEMPORAL_MAX_INCOMING_GRPC_BYTES` environment variable overrides
/// it when set to a value that parses as `usize`; an unset or unparsable variable falls back to
/// the default. The value is computed on first use and cached for the life of the process.
fn get_decode_max_size() -> usize {
    const DEFAULT_LIMIT: usize = 128 * 1024 * 1024;
    static CACHED_LIMIT: OnceLock<usize> = OnceLock::new();

    let compute = || match std::env::var("TEMPORAL_MAX_INCOMING_GRPC_BYTES") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_LIMIT),
        Err(_) => DEFAULT_LIMIT,
    };
    *CACHED_LIMIT.get_or_init(compute)
}
515
516impl TemporalServiceClient {
517 fn new<T>(svc: T) -> Self
518 where
519 T: GrpcService<Body> + Send + Sync + Clone + 'static,
520 T::ResponseBody: tonic::codegen::Body<Data = tonic::codegen::Bytes> + Send + 'static,
521 T::Error: Into<tonic::codegen::StdError>,
522 <T::ResponseBody as tonic::codegen::Body>::Error: Into<tonic::codegen::StdError> + Send,
523 <T as GrpcService<Body>>::Future: Send,
524 {
525 let workflow_svc_client = Box::new(
526 WorkflowServiceClient::new(svc.clone())
527 .max_decoding_message_size(get_decode_max_size()),
528 );
529 let operator_svc_client = Box::new(
530 OperatorServiceClient::new(svc.clone())
531 .max_decoding_message_size(get_decode_max_size()),
532 );
533 let cloud_svc_client = Box::new(
534 CloudServiceClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()),
535 );
536 let test_svc_client = Box::new(
537 TestServiceClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()),
538 );
539 let health_svc_client = Box::new(
540 HealthClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()),
541 );
542
543 Self {
544 workflow_svc_client,
545 operator_svc_client,
546 cloud_svc_client,
547 test_svc_client,
548 health_svc_client,
549 }
550 }
551
552 pub fn from_services(
555 workflow: Box<dyn WorkflowService>,
556 operator: Box<dyn OperatorService>,
557 cloud: Box<dyn CloudService>,
558 test: Box<dyn TestService>,
559 health: Box<dyn HealthService>,
560 ) -> Self {
561 Self {
562 workflow_svc_client: workflow,
563 operator_svc_client: operator,
564 cloud_svc_client: cloud,
565 test_svc_client: test,
566 health_svc_client: health,
567 }
568 }
569
570 pub fn workflow_service(&self) -> Box<dyn WorkflowService> {
572 self.workflow_svc_client.clone()
573 }
574 pub fn operator_service(&self) -> Box<dyn OperatorService> {
576 self.operator_svc_client.clone()
577 }
578 pub fn cloud_service(&self) -> Box<dyn CloudService> {
580 self.cloud_svc_client.clone()
581 }
582 pub fn test_service(&self) -> Box<dyn TestService> {
584 self.test_svc_client.clone()
585 }
586 pub fn health_service(&self) -> Box<dyn HealthService> {
588 self.health_svc_client.clone()
589 }
590}
591
592#[derive(Clone)]
595pub struct Client {
596 connection: Connection,
597 options: Arc<ClientOptions>,
598}
599
600impl Client {
601 pub fn new(connection: Connection, options: ClientOptions) -> Result<Self, ClientNewError> {
606 Ok(Client {
607 connection,
608 options: Arc::new(options),
609 })
610 }
611
612 pub fn options(&self) -> &ClientOptions {
614 &self.options
615 }
616
617 pub fn options_mut(&mut self) -> &mut ClientOptions {
622 Arc::make_mut(&mut self.options)
623 }
624
625 pub fn connection(&self) -> &Connection {
627 &self.connection
628 }
629
630 pub fn connection_mut(&mut self) -> &mut Connection {
632 &mut self.connection
633 }
634}
635
636impl Client {
640 pub async fn start_workflow<W>(
645 &self,
646 workflow: W,
647 input: W::Input,
648 options: WorkflowStartOptions,
649 ) -> Result<WorkflowHandle<Self, W>, WorkflowStartError>
650 where
651 W: WorkflowDefinition,
652 W::Input: Send,
653 {
654 WorkflowClientTrait::start_workflow(self, workflow, input, options).await
655 }
656
657 pub fn get_workflow_handle<W: WorkflowDefinition>(
661 &self,
662 workflow_id: impl Into<String>,
663 ) -> WorkflowHandle<Self, W> {
664 WorkflowClientTrait::get_workflow_handle(self, workflow_id)
665 }
666
667 pub fn list_workflows(
672 &self,
673 query: impl Into<String>,
674 opts: WorkflowListOptions,
675 ) -> ListWorkflowsStream {
676 WorkflowClientTrait::list_workflows(self, query, opts)
677 }
678
679 pub async fn count_workflows(
681 &self,
682 query: impl Into<String>,
683 opts: WorkflowCountOptions,
684 ) -> Result<WorkflowExecutionCount, ClientError> {
685 WorkflowClientTrait::count_workflows(self, query, opts).await
686 }
687
688 pub fn get_async_activity_handle(
692 &self,
693 identifier: ActivityIdentifier,
694 ) -> AsyncActivityHandle<Self> {
695 WorkflowClientTrait::get_async_activity_handle(self, identifier)
696 }
697}
698
699impl NamespacedClient for Client {
700 fn namespace(&self) -> String {
701 self.options.namespace.clone()
702 }
703
704 fn identity(&self) -> String {
705 self.connection.identity().to_owned()
706 }
707
708 fn data_converter(&self) -> &DataConverter {
709 &self.options.data_converter
710 }
711}
712
713#[derive(Clone)]
715pub enum Namespace {
716 Name(String),
718 Id(String),
720}
721
722impl Namespace {
723 pub fn into_describe_namespace_request(self) -> DescribeNamespaceRequest {
725 let (namespace, id) = match self {
726 Namespace::Name(n) => (n, "".to_owned()),
727 Namespace::Id(n) => ("".to_owned(), n),
728 };
729 DescribeNamespaceRequest { namespace, id }
730 }
731}
732
733pub(crate) trait WorkflowClientTrait: NamespacedClient {
736 fn start_workflow<W>(
738 &self,
739 workflow: W,
740 input: W::Input,
741 options: WorkflowStartOptions,
742 ) -> impl Future<Output = Result<WorkflowHandle<Self, W>, WorkflowStartError>>
743 where
744 Self: Sized,
745 W: WorkflowDefinition,
746 W::Input: Send;
747
748 fn get_workflow_handle<W: WorkflowDefinition>(
755 &self,
756 workflow_id: impl Into<String>,
757 ) -> WorkflowHandle<Self, W>
758 where
759 Self: Sized;
760
761 fn list_workflows(
765 &self,
766 query: impl Into<String>,
767 opts: WorkflowListOptions,
768 ) -> ListWorkflowsStream;
769
770 fn count_workflows(
772 &self,
773 query: impl Into<String>,
774 opts: WorkflowCountOptions,
775 ) -> impl Future<Output = Result<WorkflowExecutionCount, ClientError>>;
776
777 fn get_async_activity_handle(
781 &self,
782 identifier: ActivityIdentifier,
783 ) -> AsyncActivityHandle<Self>
784 where
785 Self: Sized;
786}
787
788pub trait NamespacedClient {
790 fn namespace(&self) -> String;
792 fn identity(&self) -> String;
794 fn data_converter(&self) -> &DataConverter {
797 static DEFAULT: OnceLock<DataConverter> = OnceLock::new();
798 DEFAULT.get_or_init(DataConverter::default)
799 }
800}
801
802#[derive(Debug, Clone)]
805pub struct WorkflowExecution {
806 raw: workflow::WorkflowExecutionInfo,
807}
808
809impl WorkflowExecution {
810 pub fn new(raw: workflow::WorkflowExecutionInfo) -> Self {
812 Self { raw }
813 }
814
815 pub fn id(&self) -> &str {
817 self.raw
818 .execution
819 .as_ref()
820 .map(|e| e.workflow_id.as_str())
821 .unwrap_or("")
822 }
823
824 pub fn run_id(&self) -> &str {
826 self.raw
827 .execution
828 .as_ref()
829 .map(|e| e.run_id.as_str())
830 .unwrap_or("")
831 }
832
833 pub fn workflow_type(&self) -> &str {
835 self.raw
836 .r#type
837 .as_ref()
838 .map(|t| t.name.as_str())
839 .unwrap_or("")
840 }
841
842 pub fn status(&self) -> WorkflowExecutionStatus {
844 self.raw.status()
845 }
846
847 pub fn start_time(&self) -> Option<SystemTime> {
849 self.raw
850 .start_time
851 .as_ref()
852 .and_then(proto_ts_to_system_time)
853 }
854
855 pub fn execution_time(&self) -> Option<SystemTime> {
857 self.raw
858 .execution_time
859 .as_ref()
860 .and_then(proto_ts_to_system_time)
861 }
862
863 pub fn close_time(&self) -> Option<SystemTime> {
865 self.raw
866 .close_time
867 .as_ref()
868 .and_then(proto_ts_to_system_time)
869 }
870
871 pub fn task_queue(&self) -> &str {
873 &self.raw.task_queue
874 }
875
876 pub fn history_length(&self) -> i64 {
878 self.raw.history_length
879 }
880
881 pub fn memo(&self) -> Option<&Memo> {
883 self.raw.memo.as_ref()
884 }
885
886 pub fn parent_id(&self) -> Option<&str> {
888 self.raw
889 .parent_execution
890 .as_ref()
891 .map(|e| e.workflow_id.as_str())
892 }
893
894 pub fn parent_run_id(&self) -> Option<&str> {
896 self.raw
897 .parent_execution
898 .as_ref()
899 .map(|e| e.run_id.as_str())
900 }
901
902 pub fn search_attributes(&self) -> Option<&SearchAttributes> {
904 self.raw.search_attributes.as_ref()
905 }
906
907 pub fn raw(&self) -> &workflow::WorkflowExecutionInfo {
909 &self.raw
910 }
911
912 pub fn into_raw(self) -> workflow::WorkflowExecutionInfo {
914 self.raw
915 }
916}
917
918impl From<workflow::WorkflowExecutionInfo> for WorkflowExecution {
919 fn from(raw: workflow::WorkflowExecutionInfo) -> Self {
920 Self::new(raw)
921 }
922}
923
924pub struct ListWorkflowsStream {
927 inner: Pin<Box<dyn Stream<Item = Result<WorkflowExecution, ClientError>> + Send>>,
928}
929
930impl ListWorkflowsStream {
931 fn new(
932 inner: Pin<Box<dyn Stream<Item = Result<WorkflowExecution, ClientError>> + Send>>,
933 ) -> Self {
934 Self { inner }
935 }
936}
937
938impl Stream for ListWorkflowsStream {
939 type Item = Result<WorkflowExecution, ClientError>;
940
941 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
942 self.inner.as_mut().poll_next(cx)
943 }
944}
945
946#[derive(Debug, Clone)]
951pub struct WorkflowExecutionCount {
952 count: usize,
953 groups: Vec<WorkflowCountAggregationGroup>,
954}
955
956impl WorkflowExecutionCount {
957 pub(crate) fn from_response(resp: CountWorkflowExecutionsResponse) -> Self {
958 Self {
959 count: resp.count as usize,
960 groups: resp
961 .groups
962 .into_iter()
963 .map(WorkflowCountAggregationGroup::from_proto)
964 .collect(),
965 }
966 }
967
968 pub fn count(&self) -> usize {
971 self.count
972 }
973
974 pub fn groups(&self) -> &[WorkflowCountAggregationGroup] {
976 &self.groups
977 }
978}
979
980#[derive(Debug, Clone)]
982pub struct WorkflowCountAggregationGroup {
983 group_values: Vec<Payload>,
984 count: usize,
985}
986
987impl WorkflowCountAggregationGroup {
988 fn from_proto(proto: count_workflow_executions_response::AggregationGroup) -> Self {
989 Self {
990 group_values: proto.group_values,
991 count: proto.count as usize,
992 }
993 }
994
995 pub fn group_values(&self) -> &[Payload] {
997 &self.group_values
998 }
999
1000 pub fn count(&self) -> usize {
1002 self.count
1003 }
1004}
1005
1006impl<T> WorkflowClientTrait for T
1007where
1008 T: WorkflowService + NamespacedClient + Clone + Send + Sync + 'static,
1009{
1010 async fn start_workflow<W>(
1011 &self,
1012 workflow: W,
1013 input: W::Input,
1014 options: WorkflowStartOptions,
1015 ) -> Result<WorkflowHandle<Self, W>, WorkflowStartError>
1016 where
1017 W: WorkflowDefinition,
1018 W::Input: Send,
1019 {
1020 let payloads = self
1021 .data_converter()
1022 .to_payloads(&SerializationContextData::Workflow, &input)
1023 .await?;
1024 let namespace = self.namespace();
1025 let workflow_id = options.workflow_id.clone();
1026 let task_queue_name = options.task_queue.clone();
1027
1028 let run_id = if let Some(start_signal) = options.start_signal {
1029 let res = WorkflowService::signal_with_start_workflow_execution(
1031 &mut self.clone(),
1032 SignalWithStartWorkflowExecutionRequest {
1033 namespace: namespace.clone(),
1034 workflow_id: workflow_id.clone(),
1035 workflow_type: Some(WorkflowType {
1036 name: workflow.name().to_string(),
1037 }),
1038 task_queue: Some(TaskQueue {
1039 name: task_queue_name,
1040 kind: TaskQueueKind::Normal as i32,
1041 normal_name: "".to_string(),
1042 }),
1043 input: payloads.into_payloads(),
1044 signal_name: start_signal.signal_name,
1045 signal_input: start_signal.input,
1046 identity: self.identity(),
1047 request_id: Uuid::new_v4().to_string(),
1048 workflow_id_reuse_policy: options.id_reuse_policy as i32,
1049 workflow_id_conflict_policy: options.id_conflict_policy as i32,
1050 workflow_execution_timeout: options
1051 .execution_timeout
1052 .and_then(|d| d.try_into().ok()),
1053 workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()),
1054 workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()),
1055 search_attributes: options.search_attributes.map(|d| d.into()),
1056 cron_schedule: options.cron_schedule.unwrap_or_default(),
1057 header: options.header.or(start_signal.header),
1058 ..Default::default()
1059 }
1060 .into_request(),
1061 )
1062 .await?
1063 .into_inner();
1064 res.run_id
1065 } else {
1066 let res = self
1068 .clone()
1069 .start_workflow_execution(
1070 StartWorkflowExecutionRequest {
1071 namespace: namespace.clone(),
1072 input: payloads.into_payloads(),
1073 workflow_id: workflow_id.clone(),
1074 workflow_type: Some(WorkflowType {
1075 name: workflow.name().to_string(),
1076 }),
1077 task_queue: Some(TaskQueue {
1078 name: task_queue_name,
1079 kind: TaskQueueKind::Unspecified as i32,
1080 normal_name: "".to_string(),
1081 }),
1082 request_id: Uuid::new_v4().to_string(),
1083 workflow_id_reuse_policy: options.id_reuse_policy as i32,
1084 workflow_id_conflict_policy: options.id_conflict_policy as i32,
1085 workflow_execution_timeout: options
1086 .execution_timeout
1087 .and_then(|d| d.try_into().ok()),
1088 workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()),
1089 workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()),
1090 search_attributes: options.search_attributes.map(|d| d.into()),
1091 cron_schedule: options.cron_schedule.unwrap_or_default(),
1092 request_eager_execution: options.enable_eager_workflow_start,
1093 retry_policy: options.retry_policy,
1094 links: options.links,
1095 completion_callbacks: options.completion_callbacks,
1096 priority: Some(options.priority.into()),
1097 header: options.header,
1098 ..Default::default()
1099 }
1100 .into_request(),
1101 )
1102 .await
1103 .map_err(|status| {
1104 if status.code() == Code::AlreadyExists {
1105 let run_id =
1106 decode_status_detail::<WorkflowExecutionAlreadyStartedFailure>(
1107 status.details(),
1108 )
1109 .map(|f| f.run_id);
1110 WorkflowStartError::AlreadyStarted {
1111 run_id,
1112 source: status,
1113 }
1114 } else {
1115 WorkflowStartError::Rpc(status)
1116 }
1117 })?
1118 .into_inner();
1119 res.run_id
1120 };
1121
1122 Ok(WorkflowHandle::new(
1123 self.clone(),
1124 WorkflowExecutionInfo {
1125 namespace,
1126 workflow_id,
1127 run_id: Some(run_id.clone()),
1128 first_execution_run_id: Some(run_id),
1129 },
1130 ))
1131 }
1132
1133 fn get_workflow_handle<W: WorkflowDefinition>(
1134 &self,
1135 workflow_id: impl Into<String>,
1136 ) -> WorkflowHandle<Self, W>
1137 where
1138 Self: Sized,
1139 {
1140 WorkflowHandle::new(
1141 self.clone(),
1142 WorkflowExecutionInfo {
1143 namespace: self.namespace(),
1144 workflow_id: workflow_id.into(),
1145 run_id: None,
1146 first_execution_run_id: None,
1147 },
1148 )
1149 }
1150
1151 fn list_workflows(
1152 &self,
1153 query: impl Into<String>,
1154 opts: WorkflowListOptions,
1155 ) -> ListWorkflowsStream {
1156 let client = self.clone();
1157 let namespace = self.namespace();
1158 let query = query.into();
1159 let limit = opts.limit;
1160
1161 let initial_state = (Vec::new(), VecDeque::new(), 0, false);
1163
1164 let stream = stream::unfold(
1165 initial_state,
1166 move |(next_page_token, mut buffer, mut yielded, exhausted)| {
1167 let mut client = client.clone();
1168 let namespace = namespace.clone();
1169 let query = query.clone();
1170
1171 async move {
1172 if let Some(l) = limit
1173 && yielded >= l
1174 {
1175 return None;
1176 }
1177
1178 if let Some(exec) = buffer.pop_front() {
1179 yielded += 1;
1180 return Some((Ok(exec), (next_page_token, buffer, yielded, exhausted)));
1181 }
1182
1183 if exhausted {
1184 return None;
1185 }
1186
1187 let response = WorkflowService::list_workflow_executions(
1188 &mut client,
1189 ListWorkflowExecutionsRequest {
1190 namespace,
1191 page_size: 0, next_page_token: next_page_token.clone(),
1193 query,
1194 }
1195 .into_request(),
1196 )
1197 .await;
1198
1199 match response {
1200 Ok(resp) => {
1201 let resp = resp.into_inner();
1202 let new_exhausted = resp.next_page_token.is_empty();
1203 let new_token = resp.next_page_token;
1204
1205 buffer = resp
1206 .executions
1207 .into_iter()
1208 .map(WorkflowExecution::from)
1209 .collect();
1210
1211 if let Some(exec) = buffer.pop_front() {
1212 yielded += 1;
1213 Some((Ok(exec), (new_token, buffer, yielded, new_exhausted)))
1214 } else {
1215 None
1216 }
1217 }
1218 Err(e) => Some((Err(e.into()), (next_page_token, buffer, yielded, true))),
1219 }
1220 }
1221 },
1222 );
1223
1224 ListWorkflowsStream::new(Box::pin(stream))
1225 }
1226
1227 async fn count_workflows(
1228 &self,
1229 query: impl Into<String>,
1230 _opts: WorkflowCountOptions,
1231 ) -> Result<WorkflowExecutionCount, ClientError> {
1232 let resp = WorkflowService::count_workflow_executions(
1233 &mut self.clone(),
1234 CountWorkflowExecutionsRequest {
1235 namespace: self.namespace(),
1236 query: query.into(),
1237 }
1238 .into_request(),
1239 )
1240 .await?
1241 .into_inner();
1242
1243 Ok(WorkflowExecutionCount::from_response(resp))
1244 }
1245
1246 fn get_async_activity_handle(&self, identifier: ActivityIdentifier) -> AsyncActivityHandle<Self>
1247 where
1248 Self: Sized,
1249 {
1250 AsyncActivityHandle::new(self.clone(), identifier)
1251 }
1252}
1253
/// Logs the formatted message at `error` level and then fails a `debug_assert!` with the same
/// message, so builds with debug assertions panic while other builds only log.
///
/// The expansion is a single block expression of type `()`. That keeps the macro usable in
/// expression position (for example as a `match` arm) and lets it be invoked more than once in
/// the same scope; the previous expansion emitted a bare `use tracing::error;`, which leaked
/// into the caller's scope and collided on a second invocation.
macro_rules! dbg_panic {
    ($($arg:tt)*) => {{
        tracing::error!($($arg)*);
        debug_assert!(false, $($arg)*);
    }};
}
pub(crate) use dbg_panic;
1262
1263#[cfg(test)]
1264mod tests {
1265 use super::*;
1266 use tonic::metadata::Ascii;
1267 use url::Url;
1268
/// Exercises `ServiceCallInterceptor::call` against a shared `ClientHeaders`, checking which
/// metadata ends up on the outgoing request as the shared headers change.
#[test]
fn applies_headers() {
    // The interceptor and the test share these headers; the test mutates them through the lock.
    let headers = Arc::new(RwLock::new(ClientHeaders {
        user_headers: HashMap::new(),
        user_binary_headers: HashMap::new(),
        api_key: Some("my-api-key".to_owned()),
    }));
    // `Arc` derefs to the lock, so a write guard can be taken without cloning the `Arc`.
    headers.write().user_headers.insert(
        "my-meta-key".parse().unwrap(),
        "my-meta-val".parse().unwrap(),
    );
    headers.write().user_binary_headers.insert(
        "my-bin-meta-key-bin".parse().unwrap(),
        vec![1, 2, 3].try_into().unwrap(),
    );
    let mut interceptor = ServiceCallInterceptor {
        client_name: "cute-kitty".to_string(),
        client_version: "0.1.0".to_string(),
        headers: Arc::clone(&headers),
    };

    // A bare request picks up the user headers and the API key as a bearer token.
    let req = interceptor.call(tonic::Request::new(())).unwrap();
    assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val");
    assert_eq!(
        req.metadata().get("authorization").unwrap(),
        "Bearer my-api-key"
    );
    assert_eq!(
        req.metadata().get_bin("my-bin-meta-key-bin").unwrap(),
        vec![1, 2, 3].as_slice()
    );

    // Metadata already present on the request must survive the interceptor untouched.
    let mut req = tonic::Request::new(());
    req.metadata_mut()
        .insert("my-meta-key", "my-meta-val2".parse().unwrap());
    req.metadata_mut()
        .insert("authorization", "my-api-key2".parse().unwrap());
    req.metadata_mut()
        .insert_bin("my-bin-meta-key-bin", vec![4, 5, 6].try_into().unwrap());
    let req = interceptor.call(req).unwrap();
    assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val2");
    assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key2");
    assert_eq!(
        req.metadata().get_bin("my-bin-meta-key-bin").unwrap(),
        vec![4, 5, 6].as_slice()
    );

    // A user-supplied `authorization` header is sent as-is even though an API key is still set.
    headers.write().user_headers.insert(
        "authorization".parse().unwrap(),
        "my-api-key3".parse().unwrap(),
    );
    let req = interceptor.call(tonic::Request::new(())).unwrap();
    assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val");
    assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key3");

    // With every shared header removed, none of them may appear on the request.
    headers.write().user_headers.clear();
    headers.write().user_binary_headers.clear();
    headers.write().api_key.take();
    let req = interceptor.call(tonic::Request::new(())).unwrap();
    assert!(!req.metadata().contains_key("my-meta-key"));
    assert!(!req.metadata().contains_key("authorization"));
    assert!(!req.metadata().contains_key("my-bin-meta-key-bin"));

    // A `grpc-timeout` already set on the request is preserved.
    let mut req = tonic::Request::new(());
    req.metadata_mut()
        .insert("grpc-timeout", "1S".parse().unwrap());
    let req = interceptor.call(req).unwrap();
    assert_eq!(
        req.metadata().get("grpc-timeout").unwrap(),
        "1S".parse::<MetadataValue<Ascii>>().unwrap()
    );
}
1347
/// `parse_ascii_headers` must reject a key carrying the `-bin` suffix and name the offending key
/// in its error message.
#[test]
fn invalid_ascii_header_key() {
    let bad_headers = HashMap::from([("x-binary-key-bin".to_owned(), "value".to_owned())]);

    let outcome = parse_ascii_headers(bad_headers);
    assert!(outcome.is_err());

    let message = outcome.err().unwrap().to_string();
    assert_eq!(
        message,
        "Invalid ASCII header key 'x-binary-key-bin': invalid gRPC metadata key name"
    );
}
1363
/// `parse_ascii_headers` must reject a value that starts with a NUL byte and report which key the
/// bad value belonged to.
#[test]
fn invalid_ascii_header_value() {
    // The key itself is acceptable; only the value (leading `\x00`) should trip the parser.
    let bad_headers = HashMap::from([("x-ascii-key".to_owned(), "\x00value".to_owned())]);

    let outcome = parse_ascii_headers(bad_headers);
    assert!(outcome.is_err());

    let message = outcome.err().unwrap().to_string();
    assert_eq!(
        message,
        "Invalid ASCII header value for key 'x-ascii-key': failed to parse metadata value"
    );
}
1380
/// `parse_binary_headers` must reject a key without the `-bin` suffix and name the offending key
/// in its error message.
#[test]
fn invalid_binary_header_key() {
    let bad_headers = HashMap::from([("x-ascii-key".to_owned(), vec![1, 2, 3])]);

    let outcome = parse_binary_headers(bad_headers);
    assert!(outcome.is_err());

    let message = outcome.err().unwrap().to_string();
    assert_eq!(
        message,
        "Invalid binary header key 'x-ascii-key': invalid gRPC metadata key name"
    );
}
1396
/// Checks that `ConnectionOptions` carries the default keep-alive settings when none are given,
/// and no keep-alive at all when it is explicitly set to `None`.
#[test]
fn keep_alive_defaults() {
    let opts = ConnectionOptions::new(Url::parse("https://smolkitty").unwrap())
        .identity("enchicat".to_string())
        .client_name("cute-kitty".to_string())
        .client_version("0.1.0".to_string())
        .build();
    let defaults = ClientKeepAliveOptions::default();
    // Borrow the configured options once instead of cloning the `Option` for every field read.
    let keep_alive = opts
        .keep_alive
        .as_ref()
        .expect("keep-alive should be set when the builder is not told otherwise");
    assert_eq!(keep_alive.interval, defaults.interval);
    assert_eq!(keep_alive.timeout, defaults.timeout);

    // Passing `None` to the builder must switch keep-alive off.
    let opts = ConnectionOptions::new(Url::parse("https://smolkitty").unwrap())
        .identity("enchicat".to_string())
        .client_name("cute-kitty".to_string())
        .client_version("0.1.0".to_string())
        .keep_alive(None)
        .build();
    assert!(opts.keep_alive.is_none());
}
1423
1424 mod list_workflows_tests {
1425 use super::*;
1426 use futures_util::{FutureExt, StreamExt};
1427 use std::sync::atomic::{AtomicUsize, Ordering};
1428 use temporalio_common::protos::temporal::api::common::v1::WorkflowExecution as ProtoWorkflowExecution;
1429 use tonic::{Request, Response};
1430
/// In-memory stand-in for a workflow service client, used to drive the `list_workflows`
/// pagination tests without a server.
#[derive(Clone)]
struct MockListWorkflowsClient {
    /// Incremented once for every `list_workflow_executions` call the mock receives.
    call_count: Arc<AtomicUsize>,
    /// Upper bound on the number of executions returned in a single response.
    page_size: usize,
    /// Total number of synthetic executions the mock serves across all pages.
    total_workflows: usize,
}
1439
1440 impl NamespacedClient for MockListWorkflowsClient {
1441 fn namespace(&self) -> String {
1442 "test-namespace".to_string()
1443 }
1444 fn identity(&self) -> String {
1445 "test-identity".to_string()
1446 }
1447 }
1448
1449 impl WorkflowService for MockListWorkflowsClient {
1450 fn list_workflow_executions(
1451 &mut self,
1452 request: Request<ListWorkflowExecutionsRequest>,
1453 ) -> futures_util::future::BoxFuture<
1454 '_,
1455 Result<Response<ListWorkflowExecutionsResponse>, tonic::Status>,
1456 > {
1457 self.call_count.fetch_add(1, Ordering::SeqCst);
1458 let req = request.into_inner();
1459
1460 let offset: usize = if req.next_page_token.is_empty() {
1462 0
1463 } else {
1464 String::from_utf8(req.next_page_token)
1465 .unwrap()
1466 .parse()
1467 .unwrap()
1468 };
1469
1470 let remaining = self.total_workflows.saturating_sub(offset);
1471 let count = remaining.min(self.page_size);
1472 let new_offset = offset + count;
1473
1474 let executions: Vec<_> = (offset..offset + count)
1475 .map(|i| workflow::WorkflowExecutionInfo {
1476 execution: Some(ProtoWorkflowExecution {
1477 workflow_id: format!("wf-{i}"),
1478 run_id: format!("run-{i}"),
1479 }),
1480 r#type: Some(WorkflowType {
1481 name: "TestWorkflow".to_string(),
1482 }),
1483 task_queue: "test-queue".to_string(),
1484 ..Default::default()
1485 })
1486 .collect();
1487
1488 let next_page_token = if new_offset < self.total_workflows {
1489 new_offset.to_string().into_bytes()
1490 } else {
1491 vec![]
1492 };
1493
1494 async move {
1495 Ok(Response::new(ListWorkflowExecutionsResponse {
1496 executions,
1497 next_page_token,
1498 }))
1499 }
1500 .boxed()
1501 }
1502 }
1503
1504 #[tokio::test]
1505 async fn list_workflows_paginates_through_all_results() {
1506 let call_count = Arc::new(AtomicUsize::new(0));
1507 let client = MockListWorkflowsClient {
1508 call_count: call_count.clone(),
1509 page_size: 3,
1510 total_workflows: 10,
1511 };
1512
1513 let stream = client.list_workflows("", WorkflowListOptions::default());
1514 let results: Vec<_> = stream.collect().await;
1515
1516 assert_eq!(results.len(), 10);
1517 for (i, result) in results.iter().enumerate() {
1518 let wf = result.as_ref().unwrap();
1519 assert_eq!(wf.id(), format!("wf-{i}"));
1520 assert_eq!(wf.run_id(), format!("run-{i}"));
1521 }
1522 assert_eq!(call_count.load(Ordering::SeqCst), 4);
1524 }
1525
1526 #[tokio::test]
1527 async fn list_workflows_respects_limit() {
1528 let call_count = Arc::new(AtomicUsize::new(0));
1529 let client = MockListWorkflowsClient {
1530 call_count: call_count.clone(),
1531 page_size: 3,
1532 total_workflows: 10,
1533 };
1534
1535 let opts = WorkflowListOptions::builder().limit(5).build();
1536 let stream = client.list_workflows("", opts);
1537 let results: Vec<_> = stream.collect().await;
1538
1539 assert_eq!(results.len(), 5);
1540 for (i, result) in results.iter().enumerate() {
1541 let wf = result.as_ref().unwrap();
1542 assert_eq!(wf.id(), format!("wf-{i}"));
1543 }
1544 assert_eq!(call_count.load(Ordering::SeqCst), 2);
1546 }
1547
1548 #[tokio::test]
1549 async fn list_workflows_limit_less_than_page_size() {
1550 let call_count = Arc::new(AtomicUsize::new(0));
1551 let client = MockListWorkflowsClient {
1552 call_count: call_count.clone(),
1553 page_size: 10,
1554 total_workflows: 100,
1555 };
1556
1557 let opts = WorkflowListOptions::builder().limit(3).build();
1558 let stream = client.list_workflows("", opts);
1559 let results: Vec<_> = stream.collect().await;
1560
1561 assert_eq!(results.len(), 3);
1562 assert_eq!(call_count.load(Ordering::SeqCst), 1);
1564 }
1565
1566 #[tokio::test]
1567 async fn list_workflows_empty_results() {
1568 let call_count = Arc::new(AtomicUsize::new(0));
1569 let client = MockListWorkflowsClient {
1570 call_count: call_count.clone(),
1571 page_size: 10,
1572 total_workflows: 0,
1573 };
1574
1575 let stream = client.list_workflows("", WorkflowListOptions::default());
1576 let results: Vec<_> = stream.collect().await;
1577
1578 assert_eq!(results.len(), 0);
1579 assert_eq!(call_count.load(Ordering::SeqCst), 1);
1580 }
1581 }
1582}