1#![warn(missing_docs)] #[macro_use]
8extern crate tracing;
9
10mod async_activity_handle;
11pub mod callback_based;
12pub mod errors;
13pub mod grpc;
14mod metrics;
15mod options_structs;
16#[doc(hidden)]
18pub mod proxy;
19mod replaceable;
20pub mod request_extensions;
21mod retry;
22pub mod schedules;
24pub mod worker;
25mod workflow_handle;
26
27pub use crate::{
28 proxy::HttpConnectProxyOptions,
29 retry::{CallType, RETRYABLE_ERROR_CODES},
30};
31pub use async_activity_handle::{
32 ActivityHeartbeatResponse, ActivityIdentifier, AsyncActivityHandle,
33};
34
35pub use metrics::{LONG_REQUEST_LATENCY_HISTOGRAM_NAME, REQUEST_LATENCY_HISTOGRAM_NAME};
36pub use options_structs::*;
37pub use replaceable::SharedReplaceableClient;
38pub use retry::RetryOptions;
39pub use tonic;
40pub use workflow_handle::{
41 UntypedQuery, UntypedSignal, UntypedUpdate, UntypedWorkflow, UntypedWorkflowHandle,
42 WorkflowExecutionDescription, WorkflowExecutionInfo, WorkflowExecutionResult, WorkflowHandle,
43 WorkflowHistory, WorkflowUpdateHandle,
44};
45
46use crate::{
47 grpc::{
48 AttachMetricLabels, CloudService, HealthService, OperatorService, TestService,
49 WorkflowService,
50 },
51 metrics::{ChannelOrGrpcOverride, GrpcMetricSvc, MetricsContext},
52 request_extensions::RequestExt,
53 worker::ClientWorkerSet,
54};
55use errors::*;
56use futures_util::{stream, stream::Stream};
57use http::Uri;
58use parking_lot::RwLock;
59use std::{
60 collections::{HashMap, VecDeque},
61 fmt::Debug,
62 pin::Pin,
63 str::FromStr,
64 sync::{Arc, OnceLock},
65 task::{Context, Poll},
66 time::{Duration, SystemTime},
67};
68use temporalio_common::{
69 WorkflowDefinition,
70 data_converters::{DataConverter, SerializationContextData},
71 protos::{
72 coresdk::IntoPayloadsExt,
73 grpc::health::v1::health_client::HealthClient,
74 proto_ts_to_system_time,
75 temporal::api::{
76 cloud::cloudservice::v1::cloud_service_client::CloudServiceClient,
77 common::v1::{Memo, Payload, SearchAttributes, WorkflowType},
78 enums::v1::{TaskQueueKind, WorkflowExecutionStatus},
79 errordetails::v1::WorkflowExecutionAlreadyStartedFailure,
80 operatorservice::v1::operator_service_client::OperatorServiceClient,
81 taskqueue::v1::TaskQueue,
82 testservice::v1::test_service_client::TestServiceClient,
83 workflow::v1 as workflow,
84 workflowservice::v1::{
85 count_workflow_executions_response, workflow_service_client::WorkflowServiceClient,
86 *,
87 },
88 },
89 utilities::decode_status_detail,
90 },
91};
92use tonic::{
93 Code, IntoRequest,
94 body::Body,
95 client::GrpcService,
96 codegen::InterceptedService,
97 metadata::{
98 AsciiMetadataKey, AsciiMetadataValue, BinaryMetadataKey, BinaryMetadataValue, MetadataMap,
99 MetadataValue,
100 },
101 service::Interceptor,
102 transport::{Certificate, Channel, Endpoint, Identity},
103};
104use tower::ServiceBuilder;
105use uuid::Uuid;
106
/// gRPC metadata key carrying the client's name. [`ServiceCallInterceptor`] sets it on every
/// request that does not already carry it.
static CLIENT_NAME_HEADER_KEY: &str = "client-name";
/// gRPC metadata key carrying the client's version. [`ServiceCallInterceptor`] sets it on every
/// request that does not already carry it.
static CLIENT_VERSION_HEADER_KEY: &str = "client-version";
/// gRPC metadata key for the target namespace.
// NOTE(review): not referenced in this file; presumably used by request helpers elsewhere in the
// crate — confirm.
static TEMPORAL_NAMESPACE_HEADER_KEY: &str = "temporal-namespace";

/// Marker string associated with "message too large" failures.
// NOTE(review): usage is not visible here; presumably matched against error metadata elsewhere
// in the crate — confirm before relying on its exact meaning.
#[doc(hidden)]
pub static MESSAGE_TOO_LARGE_KEY: &str = "message-too-large";
/// Marker string associated with errors returned due to a short circuit.
// NOTE(review): usage is not visible here — confirm where it is produced/consumed.
#[doc(hidden)]
pub static ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT: &str = "short-circuit";

/// Timeout intended for long-poll RPCs (70 seconds).
// NOTE(review): not referenced in this file; presumably applied by the retry/grpc modules.
const LONG_POLL_TIMEOUT: Duration = Duration::from_secs(70);
/// Default timeout that [`ServiceCallInterceptor`] applies to calls that don't already set one.
const OTHER_CALL_TIMEOUT: Duration = Duration::from_secs(30);
/// This crate's version, captured from Cargo at compile time.
const VERSION: &str = env!("CARGO_PKG_VERSION");
122
/// A cheaply cloneable connection to a Temporal server.
///
/// All clones share the same state behind an `Arc`, including the mutable header set. Changes
/// made via [`Connection::set_headers`], [`Connection::set_binary_headers`] or
/// [`Connection::set_api_key`] are therefore seen by every clone.
///
/// Methods returning `&mut` (e.g. [`Connection::identity_mut`]) copy the shared state first when
/// it is not uniquely owned.
#[derive(Clone)]
pub struct Connection {
    inner: Arc<ConnectionInner>,
}
131
/// State shared by all clones of a [`Connection`].
#[derive(Clone)]
struct ConnectionInner {
    /// Service clients, all routed through the header-stamping [`ServiceCallInterceptor`].
    service: TemporalServiceClient,
    /// Retry settings for this connection (see [`RetryOptions`]).
    retry_options: RetryOptions,
    /// Identity string exposed via [`Connection::identity`].
    identity: String,
    /// Headers shared with the interceptor; updates apply to subsequent calls. Cloning this
    /// struct clones the `Arc`, so the header set stays shared rather than copied.
    headers: Arc<RwLock<ClientHeaders>>,
    /// Value sent in the `client-name` header.
    client_name: String,
    /// Value sent in the `client-version` header.
    client_version: String,
    /// Capabilities from `GetSystemInfo`, if the call was made and the server implements it.
    capabilities: Option<get_system_info_response::Capabilities>,
    /// Worker set associated with this connection (see `ClientWorkerSet`).
    workers: Arc<ClientWorkerSet>,
}
144
impl Connection {
    /// Connects to a Temporal server using `options`.
    ///
    /// If `options.service_override` is set, it is used as the transport. In that case the TLS,
    /// keep-alive, origin and proxy settings are not consulted.
    ///
    /// Otherwise a tonic channel to `options.target` is built with those settings applied and
    /// then connected, either directly or through the HTTP CONNECT proxy. In both cases the
    /// transport is wrapped with gRPC metrics and a [`ServiceCallInterceptor`] that stamps
    /// client headers onto every call.
    ///
    /// Unless `options.skip_get_system_info` is set, `GetSystemInfo` is called once to fetch
    /// server capabilities. An `Unimplemented` response is treated as "no capabilities".
    ///
    /// # Errors
    ///
    /// Returns a [`ClientConnectError`] if:
    /// - the target URI or TLS domain cannot be parsed;
    /// - the channel cannot be configured or connected;
    /// - the configured headers are invalid;
    /// - `GetSystemInfo` fails with any code other than `Unimplemented`.
    pub async fn connect(options: ConnectionOptions) -> Result<Self, ClientConnectError> {
        let service = if let Some(service_override) = options.service_override {
            GrpcMetricSvc {
                inner: ChannelOrGrpcOverride::GrpcOverride(service_override),
                metrics: options.metrics_meter.clone().map(MetricsContext::new),
                disable_errcode_label: options.disable_error_code_metric_tags,
            }
        } else {
            let channel = Channel::from_shared(options.target.to_string())?;
            let channel = add_tls_to_channel(options.tls_options.as_ref(), channel).await?;
            let channel = if let Some(keep_alive) = options.keep_alive.as_ref() {
                channel
                    .keep_alive_while_idle(true)
                    .http2_keep_alive_interval(keep_alive.interval)
                    .keep_alive_timeout(keep_alive.timeout)
            } else {
                channel
            };
            // An explicit origin override wins over the origin derived from the TLS domain.
            let channel = if let Some(origin) = options.override_origin.clone() {
                channel.origin(origin)
            } else {
                channel
            };
            let channel = if let Some(proxy) = options.http_connect_proxy.as_ref() {
                proxy.connect_endpoint(&channel).await?
            } else {
                channel.connect().await?
            };
            // Wrap the connected channel so every call is recorded by the metrics layer.
            ServiceBuilder::new()
                .layer_fn(move |channel| GrpcMetricSvc {
                    inner: ChannelOrGrpcOverride::Channel(channel),
                    metrics: options.metrics_meter.clone().map(MetricsContext::new),
                    disable_errcode_label: options.disable_error_code_metric_tags,
                })
                .service(channel)
        };

        // The header set is shared between the interceptor and the connection so that the
        // `set_*` methods below affect subsequent calls.
        let headers = Arc::new(RwLock::new(ClientHeaders {
            user_headers: parse_ascii_headers(options.headers.clone().unwrap_or_default())?,
            user_binary_headers: parse_binary_headers(
                options.binary_headers.clone().unwrap_or_default(),
            )?,
            api_key: options.api_key.clone(),
        }));
        let interceptor = ServiceCallInterceptor {
            client_name: options.client_name.clone(),
            client_version: options.client_version.clone(),
            headers: headers.clone(),
        };
        let svc = InterceptedService::new(service, interceptor);
        let mut svc_client = TemporalServiceClient::new(svc);

        let capabilities = if !options.skip_get_system_info {
            match svc_client
                .get_system_info(GetSystemInfoRequest::default().into_request())
                .await
            {
                Ok(sysinfo) => sysinfo.into_inner().capabilities,
                Err(status) => match status.code() {
                    // Servers that don't implement the call simply report no capabilities.
                    Code::Unimplemented => None,
                    _ => return Err(ClientConnectError::SystemInfoCallError(status)),
                },
            }
        } else {
            None
        };
        Ok(Self {
            inner: Arc::new(ConnectionInner {
                service: svc_client,
                retry_options: options.retry_options,
                identity: options.identity,
                headers,
                client_name: options.client_name,
                client_version: options.client_version,
                capabilities,
                workers: Arc::new(ClientWorkerSet::new()),
            }),
        })
    }

    /// Sets or clears the API key, which is sent as `authorization: Bearer <key>`.
    ///
    /// Pass `None` to stop sending the API key. The change applies to subsequent calls on this
    /// connection and all its clones. An explicit `authorization` header, set on the request
    /// or in the user headers, takes precedence.
    pub fn set_api_key(&self, api_key: Option<String>) {
        self.inner.headers.write().api_key = api_key;
    }

    /// Replaces all user-supplied ASCII headers sent with every call.
    ///
    /// Headers already present on an individual request are not overridden.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidHeaderError`] if any key or value is not valid ASCII metadata. Parsing
    /// happens before the lock is taken, so on error the existing headers are left unchanged.
    pub fn set_headers(&self, headers: HashMap<String, String>) -> Result<(), InvalidHeaderError> {
        self.inner.headers.write().user_headers = parse_ascii_headers(headers)?;
        Ok(())
    }

    /// Replaces all user-supplied binary headers sent with every call.
    ///
    /// Keys must be valid binary metadata keys (tonic requires a `-bin` suffix). Headers already
    /// present on an individual request are not overridden.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidHeaderError`] if any key is invalid. On error the existing headers are
    /// left unchanged.
    pub fn set_binary_headers(
        &self,
        binary_headers: HashMap<String, Vec<u8>>,
    ) -> Result<(), InvalidHeaderError> {
        self.inner.headers.write().user_binary_headers = parse_binary_headers(binary_headers)?;
        Ok(())
    }

    /// The value sent in the `client-name` header.
    pub fn client_name(&self) -> &str {
        &self.inner.client_name
    }

    /// The value sent in the `client-version` header.
    pub fn client_version(&self) -> &str {
        &self.inner.client_version
    }

    /// Server capabilities reported by `GetSystemInfo` when connecting.
    ///
    /// Returns `None` if the call was skipped, the server doesn't implement it, or it reported
    /// no capabilities.
    pub fn capabilities(&self) -> Option<&get_system_info_response::Capabilities> {
        self.inner.capabilities.as_ref()
    }

    /// Mutable access to this connection's retry options.
    ///
    /// If the shared state is also held by other clones, it is copied first (`Arc::make_mut`),
    /// so the change is local to this handle. The header set remains shared.
    pub fn retry_options_mut(&mut self) -> &mut RetryOptions {
        &mut Arc::make_mut(&mut self.inner).retry_options
    }

    /// The identity configured for this connection.
    pub fn identity(&self) -> &str {
        &self.inner.identity
    }

    /// Mutable access to this connection's identity.
    ///
    /// Like [`Connection::retry_options_mut`], this copies shared state first when other
    /// clones exist.
    pub fn identity_mut(&mut self) -> &mut String {
        &mut Arc::make_mut(&mut self.inner).identity
    }

    /// Returns the worker set associated with this connection (shared, not copied).
    pub fn workers(&self) -> Arc<ClientWorkerSet> {
        self.inner.workers.clone()
    }

    /// Returns the grouping key of this connection's worker set.
    pub fn worker_grouping_key(&self) -> Uuid {
        self.inner.workers.worker_grouping_key()
    }

    /// Returns a boxed clone of the workflow service client.
    pub fn workflow_service(&self) -> Box<dyn WorkflowService> {
        self.inner.service.workflow_service()
    }

    /// Returns a boxed clone of the operator service client.
    pub fn operator_service(&self) -> Box<dyn OperatorService> {
        self.inner.service.operator_service()
    }

    /// Returns a boxed clone of the cloud service client.
    pub fn cloud_service(&self) -> Box<dyn CloudService> {
        self.inner.service.cloud_service()
    }

    /// Returns a boxed clone of the test service client.
    pub fn test_service(&self) -> Box<dyn TestService> {
        self.inner.service.test_service()
    }

    /// Returns a boxed clone of the health service client.
    pub fn health_service(&self) -> Box<dyn HealthService> {
        self.inner.service.health_service()
    }
}
334
/// Connection-wide headers attached to every outgoing request by [`ServiceCallInterceptor`].
#[derive(Debug)]
struct ClientHeaders {
    /// User-supplied ASCII metadata.
    user_headers: HashMap<AsciiMetadataKey, AsciiMetadataValue>,
    /// User-supplied binary (`-bin`-suffixed) metadata.
    user_binary_headers: HashMap<BinaryMetadataKey, BinaryMetadataValue>,
    /// API key, sent as `authorization: Bearer <key>` when set.
    api_key: Option<String>,
}
341
342impl ClientHeaders {
343 fn apply_to_metadata(&self, metadata: &mut MetadataMap) {
344 for (key, val) in self.user_headers.iter() {
345 if !metadata.contains_key(key) {
347 metadata.insert(key, val.clone());
348 }
349 }
350 for (key, val) in self.user_binary_headers.iter() {
351 if !metadata.contains_key(key) {
353 metadata.insert_bin(key, val.clone());
354 }
355 }
356 if let Some(api_key) = &self.api_key {
357 if !metadata.contains_key("authorization")
359 && let Ok(val) = format!("Bearer {api_key}").parse()
360 {
361 metadata.insert("authorization", val);
362 }
363 }
364 }
365}
366
367async fn add_tls_to_channel(
370 tls_options: Option<&TlsOptions>,
371 mut channel: Endpoint,
372) -> Result<Endpoint, ClientConnectError> {
373 if let Some(tls_cfg) = tls_options {
374 let mut tls = tonic::transport::ClientTlsConfig::new();
375
376 if let Some(root_cert) = &tls_cfg.server_root_ca_cert {
377 let server_root_ca_cert = Certificate::from_pem(root_cert);
378 tls = tls.ca_certificate(server_root_ca_cert);
379 } else {
380 tls = tls.with_native_roots();
381 }
382
383 if let Some(domain) = &tls_cfg.domain {
384 tls = tls.domain_name(domain);
385
386 let uri: Uri = format!("https://{domain}").parse()?;
391 channel = channel.origin(uri);
392 }
393
394 if let Some(client_opts) = &tls_cfg.client_tls_options {
395 let client_identity =
396 Identity::from_pem(&client_opts.client_cert, &client_opts.client_private_key);
397 tls = tls.identity(client_identity);
398 }
399
400 return channel.tls_config(tls).map_err(Into::into);
401 }
402 Ok(channel)
403}
404
405fn parse_ascii_headers(
406 headers: HashMap<String, String>,
407) -> Result<HashMap<AsciiMetadataKey, AsciiMetadataValue>, InvalidHeaderError> {
408 let mut parsed_headers = HashMap::with_capacity(headers.len());
409 for (k, v) in headers.into_iter() {
410 let key = match AsciiMetadataKey::from_str(&k) {
411 Ok(key) => key,
412 Err(err) => {
413 return Err(InvalidHeaderError::InvalidAsciiHeaderKey {
414 key: k,
415 source: err,
416 });
417 }
418 };
419 let value = match MetadataValue::from_str(&v) {
420 Ok(value) => value,
421 Err(err) => {
422 return Err(InvalidHeaderError::InvalidAsciiHeaderValue {
423 key: k,
424 value: v,
425 source: err,
426 });
427 }
428 };
429 parsed_headers.insert(key, value);
430 }
431
432 Ok(parsed_headers)
433}
434
435fn parse_binary_headers(
436 headers: HashMap<String, Vec<u8>>,
437) -> Result<HashMap<BinaryMetadataKey, BinaryMetadataValue>, InvalidHeaderError> {
438 let mut parsed_headers = HashMap::with_capacity(headers.len());
439 for (k, v) in headers.into_iter() {
440 let key = match BinaryMetadataKey::from_str(&k) {
441 Ok(key) => key,
442 Err(err) => {
443 return Err(InvalidHeaderError::InvalidBinaryHeaderKey {
444 key: k,
445 source: err,
446 });
447 }
448 };
449 let value = BinaryMetadataValue::from_bytes(&v);
450 parsed_headers.insert(key, value);
451 }
452
453 Ok(parsed_headers)
454}
455
/// Tonic interceptor applied to every call made through a [`Connection`].
///
/// It adds `client-name` / `client-version` metadata plus the connection's user headers and API
/// key, never overriding values already present on the request. It also applies a default
/// call timeout.
#[derive(Clone)]
pub struct ServiceCallInterceptor {
    client_name: String,
    client_version: String,
    // Shared with the owning connection, so header updates apply to subsequent calls.
    headers: Arc<RwLock<ClientHeaders>>,
}
464
465impl Interceptor for ServiceCallInterceptor {
466 fn call(
469 &mut self,
470 mut request: tonic::Request<()>,
471 ) -> Result<tonic::Request<()>, tonic::Status> {
472 let metadata = request.metadata_mut();
473 if !metadata.contains_key(CLIENT_NAME_HEADER_KEY) {
474 metadata.insert(
475 CLIENT_NAME_HEADER_KEY,
476 self.client_name
477 .parse()
478 .unwrap_or_else(|_| MetadataValue::from_static("")),
479 );
480 }
481 if !metadata.contains_key(CLIENT_VERSION_HEADER_KEY) {
482 metadata.insert(
483 CLIENT_VERSION_HEADER_KEY,
484 self.client_version
485 .parse()
486 .unwrap_or_else(|_| MetadataValue::from_static("")),
487 );
488 }
489 self.headers.read().apply_to_metadata(metadata);
490 request.set_default_timeout(OTHER_CALL_TIMEOUT);
491
492 Ok(request)
493 }
494}
495
/// Bundle of boxed clients, one per Temporal gRPC service: workflow, operator, cloud, test and
/// health.
#[derive(Clone)]
pub struct TemporalServiceClient {
    workflow_svc_client: Box<dyn WorkflowService>,
    operator_svc_client: Box<dyn OperatorService>,
    cloud_svc_client: Box<dyn CloudService>,
    test_svc_client: Box<dyn TestService>,
    health_svc_client: Box<dyn HealthService>,
}
505
/// Returns the maximum size, in bytes, of a single incoming gRPC message the client will decode.
///
/// The value is read once from `TEMPORAL_MAX_INCOMING_GRPC_BYTES` and cached for the life of
/// the process. If the variable is unset, not Unicode, or not a valid `usize`, the default is
/// 128 MiB.
fn get_decode_max_size() -> usize {
    const DEFAULT_MAX_BYTES: usize = 128 * 1024 * 1024;
    static DECODE_MAX_SIZE: OnceLock<usize> = OnceLock::new();

    *DECODE_MAX_SIZE.get_or_init(|| match std::env::var("TEMPORAL_MAX_INCOMING_GRPC_BYTES") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_MAX_BYTES),
        Err(_) => DEFAULT_MAX_BYTES,
    })
}
517
518impl TemporalServiceClient {
519 fn new<T>(svc: T) -> Self
520 where
521 T: GrpcService<Body> + Send + Sync + Clone + 'static,
522 T::ResponseBody: tonic::codegen::Body<Data = tonic::codegen::Bytes> + Send + 'static,
523 T::Error: Into<tonic::codegen::StdError>,
524 <T::ResponseBody as tonic::codegen::Body>::Error: Into<tonic::codegen::StdError> + Send,
525 <T as GrpcService<Body>>::Future: Send,
526 {
527 let workflow_svc_client = Box::new(
528 WorkflowServiceClient::new(svc.clone())
529 .max_decoding_message_size(get_decode_max_size()),
530 );
531 let operator_svc_client = Box::new(
532 OperatorServiceClient::new(svc.clone())
533 .max_decoding_message_size(get_decode_max_size()),
534 );
535 let cloud_svc_client = Box::new(
536 CloudServiceClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()),
537 );
538 let test_svc_client = Box::new(
539 TestServiceClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()),
540 );
541 let health_svc_client = Box::new(
542 HealthClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()),
543 );
544
545 Self {
546 workflow_svc_client,
547 operator_svc_client,
548 cloud_svc_client,
549 test_svc_client,
550 health_svc_client,
551 }
552 }
553
554 pub fn from_services(
557 workflow: Box<dyn WorkflowService>,
558 operator: Box<dyn OperatorService>,
559 cloud: Box<dyn CloudService>,
560 test: Box<dyn TestService>,
561 health: Box<dyn HealthService>,
562 ) -> Self {
563 Self {
564 workflow_svc_client: workflow,
565 operator_svc_client: operator,
566 cloud_svc_client: cloud,
567 test_svc_client: test,
568 health_svc_client: health,
569 }
570 }
571
572 pub fn workflow_service(&self) -> Box<dyn WorkflowService> {
574 self.workflow_svc_client.clone()
575 }
576 pub fn operator_service(&self) -> Box<dyn OperatorService> {
578 self.operator_svc_client.clone()
579 }
580 pub fn cloud_service(&self) -> Box<dyn CloudService> {
582 self.cloud_svc_client.clone()
583 }
584 pub fn test_service(&self) -> Box<dyn TestService> {
586 self.test_svc_client.clone()
587 }
588 pub fn health_service(&self) -> Box<dyn HealthService> {
590 self.health_svc_client.clone()
591 }
592}
593
/// High-level Temporal client: a [`Connection`] plus per-client [`ClientOptions`].
///
/// Cloning is cheap. The connection's state and the options are both reference-counted, and
/// [`Client::options_mut`] copies the options on write when they are shared.
#[derive(Clone)]
pub struct Client {
    connection: Connection,
    options: Arc<ClientOptions>,
}
601
602impl Client {
603 pub fn new(connection: Connection, options: ClientOptions) -> Result<Self, ClientNewError> {
608 Ok(Client {
609 connection,
610 options: Arc::new(options),
611 })
612 }
613
614 pub fn options(&self) -> &ClientOptions {
616 &self.options
617 }
618
619 pub fn options_mut(&mut self) -> &mut ClientOptions {
624 Arc::make_mut(&mut self.options)
625 }
626
627 pub fn connection(&self) -> &Connection {
629 &self.connection
630 }
631
632 pub fn connection_mut(&mut self) -> &mut Connection {
634 &mut self.connection
635 }
636}
637
638impl Client {
642 pub async fn start_workflow<W>(
647 &self,
648 workflow: W,
649 input: W::Input,
650 options: WorkflowStartOptions,
651 ) -> Result<WorkflowHandle<Self, W>, WorkflowStartError>
652 where
653 W: WorkflowDefinition,
654 W::Input: Send,
655 {
656 WorkflowClientTrait::start_workflow(self, workflow, input, options).await
657 }
658
659 pub fn get_workflow_handle<W: WorkflowDefinition>(
663 &self,
664 workflow_id: impl Into<String>,
665 ) -> WorkflowHandle<Self, W> {
666 WorkflowClientTrait::get_workflow_handle(self, workflow_id)
667 }
668
669 pub fn list_workflows(
674 &self,
675 query: impl Into<String>,
676 opts: WorkflowListOptions,
677 ) -> ListWorkflowsStream {
678 WorkflowClientTrait::list_workflows(self, query, opts)
679 }
680
681 pub async fn count_workflows(
683 &self,
684 query: impl Into<String>,
685 opts: WorkflowCountOptions,
686 ) -> Result<WorkflowExecutionCount, ClientError> {
687 WorkflowClientTrait::count_workflows(self, query, opts).await
688 }
689
690 pub fn get_async_activity_handle(
694 &self,
695 identifier: ActivityIdentifier,
696 ) -> AsyncActivityHandle<Self> {
697 WorkflowClientTrait::get_async_activity_handle(self, identifier)
698 }
699}
700
701impl NamespacedClient for Client {
702 fn namespace(&self) -> String {
703 self.options.namespace.clone()
704 }
705
706 fn identity(&self) -> String {
707 self.connection.identity().to_owned()
708 }
709
710 fn data_converter(&self) -> &DataConverter {
711 &self.options.data_converter
712 }
713}
714
/// Identifies a namespace either by name or by id.
#[derive(Clone)]
pub enum Namespace {
    /// The namespace's name.
    Name(String),
    /// The namespace's id.
    Id(String),
}
723
724impl Namespace {
725 pub fn into_describe_namespace_request(self) -> DescribeNamespaceRequest {
727 let (namespace, id) = match self {
728 Namespace::Name(n) => (n, "".to_owned()),
729 Namespace::Id(n) => ("".to_owned(), n),
730 };
731 DescribeNamespaceRequest { namespace, id }
732 }
733}
734
/// Workflow-level operations, implemented for every namespaced workflow-service client.
pub(crate) trait WorkflowClientTrait: NamespacedClient {
    /// Starts a workflow execution and returns a handle bound to the new run.
    fn start_workflow<W>(
        &self,
        workflow: W,
        input: W::Input,
        options: WorkflowStartOptions,
    ) -> impl Future<Output = Result<WorkflowHandle<Self, W>, WorkflowStartError>>
    where
        Self: Sized,
        W: WorkflowDefinition,
        W::Input: Send;

    /// Returns a handle to the workflow with `workflow_id` without making any RPC.
    ///
    /// The handle is not bound to a specific run.
    fn get_workflow_handle<W: WorkflowDefinition>(
        &self,
        workflow_id: impl Into<String>,
    ) -> WorkflowHandle<Self, W>
    where
        Self: Sized;

    /// Returns a lazily paginated stream of executions matching `query`.
    fn list_workflows(
        &self,
        query: impl Into<String>,
        opts: WorkflowListOptions,
    ) -> ListWorkflowsStream;

    /// Counts executions matching `query`.
    fn count_workflows(
        &self,
        query: impl Into<String>,
        opts: WorkflowCountOptions,
    ) -> impl Future<Output = Result<WorkflowExecutionCount, ClientError>>;

    /// Returns a handle for an activity identified by `identifier`.
    fn get_async_activity_handle(
        &self,
        identifier: ActivityIdentifier,
    ) -> AsyncActivityHandle<Self>
    where
        Self: Sized;
}
789
/// A client bound to a namespace and an identity.
pub trait NamespacedClient {
    /// The namespace requests are made against.
    fn namespace(&self) -> String;
    /// The identity string attached to requests that carry one (e.g. signal-with-start).
    fn identity(&self) -> String;
    /// The data converter used to serialize payloads such as workflow input.
    ///
    /// The default is a single, lazily created, process-wide `DataConverter::default()`.
    fn data_converter(&self) -> &DataConverter {
        static DEFAULT: OnceLock<DataConverter> = OnceLock::new();
        DEFAULT.get_or_init(DataConverter::default)
    }
}
803
/// Summary information about a single workflow execution, as yielded by list queries.
///
/// This is a thin wrapper over the raw proto. Use [`WorkflowExecution::raw`] for fields without
/// a dedicated accessor.
#[derive(Debug, Clone)]
pub struct WorkflowExecution {
    raw: workflow::WorkflowExecutionInfo,
}
810
811impl WorkflowExecution {
812 pub fn new(raw: workflow::WorkflowExecutionInfo) -> Self {
814 Self { raw }
815 }
816
817 pub fn id(&self) -> &str {
819 self.raw
820 .execution
821 .as_ref()
822 .map(|e| e.workflow_id.as_str())
823 .unwrap_or("")
824 }
825
826 pub fn run_id(&self) -> &str {
828 self.raw
829 .execution
830 .as_ref()
831 .map(|e| e.run_id.as_str())
832 .unwrap_or("")
833 }
834
835 pub fn workflow_type(&self) -> &str {
837 self.raw
838 .r#type
839 .as_ref()
840 .map(|t| t.name.as_str())
841 .unwrap_or("")
842 }
843
844 pub fn status(&self) -> WorkflowExecutionStatus {
846 self.raw.status()
847 }
848
849 pub fn start_time(&self) -> Option<SystemTime> {
851 self.raw
852 .start_time
853 .as_ref()
854 .and_then(proto_ts_to_system_time)
855 }
856
857 pub fn execution_time(&self) -> Option<SystemTime> {
859 self.raw
860 .execution_time
861 .as_ref()
862 .and_then(proto_ts_to_system_time)
863 }
864
865 pub fn close_time(&self) -> Option<SystemTime> {
867 self.raw
868 .close_time
869 .as_ref()
870 .and_then(proto_ts_to_system_time)
871 }
872
873 pub fn task_queue(&self) -> &str {
875 &self.raw.task_queue
876 }
877
878 pub fn history_length(&self) -> i64 {
880 self.raw.history_length
881 }
882
883 pub fn memo(&self) -> Option<&Memo> {
885 self.raw.memo.as_ref()
886 }
887
888 pub fn parent_id(&self) -> Option<&str> {
890 self.raw
891 .parent_execution
892 .as_ref()
893 .map(|e| e.workflow_id.as_str())
894 }
895
896 pub fn parent_run_id(&self) -> Option<&str> {
898 self.raw
899 .parent_execution
900 .as_ref()
901 .map(|e| e.run_id.as_str())
902 }
903
904 pub fn search_attributes(&self) -> Option<&SearchAttributes> {
906 self.raw.search_attributes.as_ref()
907 }
908
909 pub fn raw(&self) -> &workflow::WorkflowExecutionInfo {
911 &self.raw
912 }
913
914 pub fn into_raw(self) -> workflow::WorkflowExecutionInfo {
916 self.raw
917 }
918}
919
920impl From<workflow::WorkflowExecutionInfo> for WorkflowExecution {
921 fn from(raw: workflow::WorkflowExecutionInfo) -> Self {
922 Self::new(raw)
923 }
924}
925
926pub struct ListWorkflowsStream {
929 inner: Pin<Box<dyn Stream<Item = Result<WorkflowExecution, ClientError>> + Send>>,
930}
931
932impl ListWorkflowsStream {
933 fn new(
934 inner: Pin<Box<dyn Stream<Item = Result<WorkflowExecution, ClientError>> + Send>>,
935 ) -> Self {
936 Self { inner }
937 }
938}
939
940impl Stream for ListWorkflowsStream {
941 type Item = Result<WorkflowExecution, ClientError>;
942
943 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
944 self.inner.as_mut().poll_next(cx)
945 }
946}
947
948#[derive(Debug, Clone)]
953pub struct WorkflowExecutionCount {
954 count: usize,
955 groups: Vec<WorkflowCountAggregationGroup>,
956}
957
958impl WorkflowExecutionCount {
959 pub(crate) fn from_response(resp: CountWorkflowExecutionsResponse) -> Self {
960 Self {
961 count: resp.count as usize,
962 groups: resp
963 .groups
964 .into_iter()
965 .map(WorkflowCountAggregationGroup::from_proto)
966 .collect(),
967 }
968 }
969
970 pub fn count(&self) -> usize {
973 self.count
974 }
975
976 pub fn groups(&self) -> &[WorkflowCountAggregationGroup] {
978 &self.groups
979 }
980}
981
982#[derive(Debug, Clone)]
984pub struct WorkflowCountAggregationGroup {
985 group_values: Vec<Payload>,
986 count: usize,
987}
988
989impl WorkflowCountAggregationGroup {
990 fn from_proto(proto: count_workflow_executions_response::AggregationGroup) -> Self {
991 Self {
992 group_values: proto.group_values,
993 count: proto.count as usize,
994 }
995 }
996
997 pub fn group_values(&self) -> &[Payload] {
999 &self.group_values
1000 }
1001
1002 pub fn count(&self) -> usize {
1004 self.count
1005 }
1006}
1007
1008impl<T> WorkflowClientTrait for T
1009where
1010 T: WorkflowService + NamespacedClient + Clone + Send + Sync + 'static,
1011{
1012 async fn start_workflow<W>(
1013 &self,
1014 workflow: W,
1015 input: W::Input,
1016 options: WorkflowStartOptions,
1017 ) -> Result<WorkflowHandle<Self, W>, WorkflowStartError>
1018 where
1019 W: WorkflowDefinition,
1020 W::Input: Send,
1021 {
1022 let payloads = self
1023 .data_converter()
1024 .to_payloads(&SerializationContextData::Workflow, &input)
1025 .await?;
1026 let namespace = self.namespace();
1027 let workflow_id = options.workflow_id.clone();
1028 let task_queue_name = options.task_queue.clone();
1029
1030 let run_id = if let Some(start_signal) = options.start_signal {
1031 let res = WorkflowService::signal_with_start_workflow_execution(
1033 &mut self.clone(),
1034 SignalWithStartWorkflowExecutionRequest {
1035 namespace: namespace.clone(),
1036 workflow_id: workflow_id.clone(),
1037 workflow_type: Some(WorkflowType {
1038 name: workflow.name().to_string(),
1039 }),
1040 task_queue: Some(TaskQueue {
1041 name: task_queue_name,
1042 kind: TaskQueueKind::Normal as i32,
1043 normal_name: "".to_string(),
1044 }),
1045 input: payloads.into_payloads(),
1046 signal_name: start_signal.signal_name,
1047 signal_input: start_signal.input,
1048 identity: self.identity(),
1049 request_id: Uuid::new_v4().to_string(),
1050 workflow_id_reuse_policy: options.id_reuse_policy as i32,
1051 workflow_id_conflict_policy: options.id_conflict_policy as i32,
1052 workflow_execution_timeout: options
1053 .execution_timeout
1054 .and_then(|d| d.try_into().ok()),
1055 workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()),
1056 workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()),
1057 search_attributes: options.search_attributes.map(|d| d.into()),
1058 cron_schedule: options.cron_schedule.unwrap_or_default(),
1059 header: options.header.or(start_signal.header),
1060 ..Default::default()
1061 }
1062 .into_request(),
1063 )
1064 .await?
1065 .into_inner();
1066 res.run_id
1067 } else {
1068 let res = self
1070 .clone()
1071 .start_workflow_execution(
1072 StartWorkflowExecutionRequest {
1073 namespace: namespace.clone(),
1074 input: payloads.into_payloads(),
1075 workflow_id: workflow_id.clone(),
1076 workflow_type: Some(WorkflowType {
1077 name: workflow.name().to_string(),
1078 }),
1079 task_queue: Some(TaskQueue {
1080 name: task_queue_name,
1081 kind: TaskQueueKind::Unspecified as i32,
1082 normal_name: "".to_string(),
1083 }),
1084 request_id: Uuid::new_v4().to_string(),
1085 workflow_id_reuse_policy: options.id_reuse_policy as i32,
1086 workflow_id_conflict_policy: options.id_conflict_policy as i32,
1087 workflow_execution_timeout: options
1088 .execution_timeout
1089 .and_then(|d| d.try_into().ok()),
1090 workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()),
1091 workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()),
1092 search_attributes: options.search_attributes.map(|d| d.into()),
1093 cron_schedule: options.cron_schedule.unwrap_or_default(),
1094 request_eager_execution: options.enable_eager_workflow_start,
1095 retry_policy: options.retry_policy,
1096 links: options.links,
1097 completion_callbacks: options.completion_callbacks,
1098 priority: Some(options.priority.into()),
1099 header: options.header,
1100 ..Default::default()
1101 }
1102 .into_request(),
1103 )
1104 .await
1105 .map_err(|status| {
1106 if status.code() == Code::AlreadyExists {
1107 let run_id =
1108 decode_status_detail::<WorkflowExecutionAlreadyStartedFailure>(
1109 status.details(),
1110 )
1111 .map(|f| f.run_id);
1112 WorkflowStartError::AlreadyStarted {
1113 run_id,
1114 source: status,
1115 }
1116 } else {
1117 WorkflowStartError::Rpc(status)
1118 }
1119 })?
1120 .into_inner();
1121 res.run_id
1122 };
1123
1124 Ok(WorkflowHandle::new(
1125 self.clone(),
1126 WorkflowExecutionInfo {
1127 namespace,
1128 workflow_id,
1129 run_id: Some(run_id.clone()),
1130 first_execution_run_id: Some(run_id),
1131 },
1132 ))
1133 }
1134
1135 fn get_workflow_handle<W: WorkflowDefinition>(
1136 &self,
1137 workflow_id: impl Into<String>,
1138 ) -> WorkflowHandle<Self, W>
1139 where
1140 Self: Sized,
1141 {
1142 WorkflowHandle::new(
1143 self.clone(),
1144 WorkflowExecutionInfo {
1145 namespace: self.namespace(),
1146 workflow_id: workflow_id.into(),
1147 run_id: None,
1148 first_execution_run_id: None,
1149 },
1150 )
1151 }
1152
1153 fn list_workflows(
1154 &self,
1155 query: impl Into<String>,
1156 opts: WorkflowListOptions,
1157 ) -> ListWorkflowsStream {
1158 let client = self.clone();
1159 let namespace = self.namespace();
1160 let query = query.into();
1161 let limit = opts.limit;
1162
1163 let initial_state = (Vec::new(), VecDeque::new(), 0, false);
1165
1166 let stream = stream::unfold(
1167 initial_state,
1168 move |(next_page_token, mut buffer, mut yielded, exhausted)| {
1169 let mut client = client.clone();
1170 let namespace = namespace.clone();
1171 let query = query.clone();
1172
1173 async move {
1174 if let Some(l) = limit
1175 && yielded >= l
1176 {
1177 return None;
1178 }
1179
1180 if let Some(exec) = buffer.pop_front() {
1181 yielded += 1;
1182 return Some((Ok(exec), (next_page_token, buffer, yielded, exhausted)));
1183 }
1184
1185 if exhausted {
1186 return None;
1187 }
1188
1189 let response = WorkflowService::list_workflow_executions(
1190 &mut client,
1191 ListWorkflowExecutionsRequest {
1192 namespace,
1193 page_size: 0, next_page_token: next_page_token.clone(),
1195 query,
1196 }
1197 .into_request(),
1198 )
1199 .await;
1200
1201 match response {
1202 Ok(resp) => {
1203 let resp = resp.into_inner();
1204 let new_exhausted = resp.next_page_token.is_empty();
1205 let new_token = resp.next_page_token;
1206
1207 buffer = resp
1208 .executions
1209 .into_iter()
1210 .map(WorkflowExecution::from)
1211 .collect();
1212
1213 if let Some(exec) = buffer.pop_front() {
1214 yielded += 1;
1215 Some((Ok(exec), (new_token, buffer, yielded, new_exhausted)))
1216 } else {
1217 None
1218 }
1219 }
1220 Err(e) => Some((Err(e.into()), (next_page_token, buffer, yielded, true))),
1221 }
1222 }
1223 },
1224 );
1225
1226 ListWorkflowsStream::new(Box::pin(stream))
1227 }
1228
1229 async fn count_workflows(
1230 &self,
1231 query: impl Into<String>,
1232 _opts: WorkflowCountOptions,
1233 ) -> Result<WorkflowExecutionCount, ClientError> {
1234 let resp = WorkflowService::count_workflow_executions(
1235 &mut self.clone(),
1236 CountWorkflowExecutionsRequest {
1237 namespace: self.namespace(),
1238 query: query.into(),
1239 }
1240 .into_request(),
1241 )
1242 .await?
1243 .into_inner();
1244
1245 Ok(WorkflowExecutionCount::from_response(resp))
1246 }
1247
1248 fn get_async_activity_handle(&self, identifier: ActivityIdentifier) -> AsyncActivityHandle<Self>
1249 where
1250 Self: Sized,
1251 {
1252 AsyncActivityHandle::new(self.clone(), identifier)
1253 }
1254}
1255
/// Logs an error and, in debug builds, also panics with the same message.
///
/// Takes `format!`-style arguments. In debug builds the arguments are evaluated twice (once for
/// the log, once for the assertion), so avoid side effects in them.
///
/// The expansion is a block using a fully-qualified `::tracing::error!`. It therefore injects
/// no `use` item into the caller's scope: repeated invocations in one scope, or a caller that
/// already imports `error`, cannot collide. It also works in expression position.
macro_rules! dbg_panic {
    ($($arg:tt)*) => {{
        ::tracing::error!($($arg)*);
        debug_assert!(false, $($arg)*);
    }};
}
pub(crate) use dbg_panic;
1264
#[cfg(test)]
mod tests {
    use super::*;
    use tonic::metadata::Ascii;
    use url::Url;

    /// Exercises `ServiceCallInterceptor` header handling:
    /// * user ASCII headers, binary headers, and the API key (as a `Bearer` authorization header)
    ///   are added to a bare request;
    /// * headers already present on a request are not overwritten;
    /// * later changes to the shared header state are picked up on the next call;
    /// * unrelated metadata such as `grpc-timeout` passes through unchanged.
    #[test]
    fn applies_headers() {
        let headers = Arc::new(RwLock::new(ClientHeaders {
            user_headers: HashMap::new(),
            user_binary_headers: HashMap::new(),
            api_key: Some("my-api-key".to_owned()),
        }));
        // The interceptor below holds a clone of this same `Arc`, so writing through `headers`
        // directly is enough for the interceptor to observe the change.
        headers.write().user_headers.insert(
            "my-meta-key".parse().unwrap(),
            "my-meta-val".parse().unwrap(),
        );
        headers.write().user_binary_headers.insert(
            "my-bin-meta-key-bin".parse().unwrap(),
            vec![1, 2, 3].try_into().unwrap(),
        );
        let mut interceptor = ServiceCallInterceptor {
            client_name: "cute-kitty".to_string(),
            client_version: "0.1.0".to_string(),
            headers: headers.clone(),
        };

        // Client-configured headers are applied to a request that carries none of its own.
        let req = interceptor.call(tonic::Request::new(())).unwrap();
        assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val");
        assert_eq!(
            req.metadata().get("authorization").unwrap(),
            "Bearer my-api-key"
        );
        assert_eq!(
            req.metadata().get_bin("my-bin-meta-key-bin").unwrap(),
            vec![1, 2, 3].as_slice()
        );

        // Values already set on the request take precedence over the client-level headers.
        let mut req = tonic::Request::new(());
        req.metadata_mut()
            .insert("my-meta-key", "my-meta-val2".parse().unwrap());
        req.metadata_mut()
            .insert("authorization", "my-api-key2".parse().unwrap());
        req.metadata_mut()
            .insert_bin("my-bin-meta-key-bin", vec![4, 5, 6].try_into().unwrap());
        let req = interceptor.call(req).unwrap();
        assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val2");
        assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key2");
        assert_eq!(
            req.metadata().get_bin("my-bin-meta-key-bin").unwrap(),
            vec![4, 5, 6].as_slice()
        );

        // A user-supplied `authorization` header is expected to win over the configured API key.
        headers.write().user_headers.insert(
            "authorization".parse().unwrap(),
            "my-api-key3".parse().unwrap(),
        );
        let req = interceptor.call(tonic::Request::new(())).unwrap();
        assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val");
        assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key3");

        // Clearing the shared state removes everything the interceptor previously injected.
        // The write guard lives in its own scope so it is released before the interceptor
        // touches the headers again.
        {
            let mut state = headers.write();
            state.user_headers.clear();
            state.user_binary_headers.clear();
            state.api_key = None;
        }
        let req = interceptor.call(tonic::Request::new(())).unwrap();
        assert!(!req.metadata().contains_key("my-meta-key"));
        assert!(!req.metadata().contains_key("authorization"));
        assert!(!req.metadata().contains_key("my-bin-meta-key-bin"));

        // Metadata the interceptor does not manage is passed through untouched.
        let mut req = tonic::Request::new(());
        req.metadata_mut()
            .insert("grpc-timeout", "1S".parse().unwrap());
        let req = interceptor.call(req).unwrap();
        assert_eq!(
            req.metadata().get("grpc-timeout").unwrap(),
            "1S".parse::<MetadataValue<Ascii>>().unwrap()
        );
    }

    /// A key ending in `-bin` names a binary header and must be rejected as an ASCII header key.
    #[test]
    fn invalid_ascii_header_key() {
        let invalid_headers =
            HashMap::from([("x-binary-key-bin".to_owned(), "value".to_owned())]);

        let err = parse_ascii_headers(invalid_headers)
            .err()
            .expect("a `-bin` key must be rejected as an ASCII header key");
        assert_eq!(
            err.to_string(),
            "Invalid ASCII header key 'x-binary-key-bin': invalid gRPC metadata key name"
        );
    }

    /// A control character (NUL) is not a valid ASCII metadata value and must be rejected.
    #[test]
    fn invalid_ascii_header_value() {
        let invalid_headers = HashMap::from([("x-ascii-key".to_owned(), "\x00value".to_owned())]);

        let err = parse_ascii_headers(invalid_headers)
            .err()
            .expect("a value containing NUL must be rejected");
        assert_eq!(
            err.to_string(),
            "Invalid ASCII header value for key 'x-ascii-key': failed to parse metadata value"
        );
    }

    /// A key without the `-bin` suffix cannot carry a binary header value.
    #[test]
    fn invalid_binary_header_key() {
        let invalid_headers = HashMap::from([("x-ascii-key".to_owned(), vec![1, 2, 3])]);

        let err = parse_binary_headers(invalid_headers)
            .err()
            .expect("a key without `-bin` must be rejected as a binary header key");
        assert_eq!(
            err.to_string(),
            "Invalid binary header key 'x-ascii-key': invalid gRPC metadata key name"
        );
    }

    /// Keep-alive is enabled with `ClientKeepAliveOptions::default()` settings unless it is
    /// explicitly disabled with `keep_alive(None)`.
    #[test]
    fn keep_alive_defaults() {
        let opts = ConnectionOptions::new(Url::parse("https://smolkitty").unwrap())
            .identity("enchicat".to_string())
            .client_name("cute-kitty".to_string())
            .client_version("0.1.0".to_string())
            .build();
        let keep_alive = opts
            .keep_alive
            .as_ref()
            .expect("keep-alive should be enabled by default");
        let defaults = ClientKeepAliveOptions::default();
        assert_eq!(keep_alive.interval, defaults.interval);
        assert_eq!(keep_alive.timeout, defaults.timeout);

        let opts = ConnectionOptions::new(Url::parse("https://smolkitty").unwrap())
            .identity("enchicat".to_string())
            .client_name("cute-kitty".to_string())
            .client_version("0.1.0".to_string())
            .keep_alive(None)
            .build();
        assert!(opts.keep_alive.is_none());
    }

    mod list_workflows_tests {
        use super::*;
        use futures_util::{FutureExt, StreamExt};
        use std::sync::atomic::{AtomicUsize, Ordering};
        use temporalio_common::protos::temporal::api::common::v1::WorkflowExecution as ProtoWorkflowExecution;
        use tonic::{Request, Response};

        /// In-memory `WorkflowService` that serves a fixed number of synthetic workflows in pages.
        ///
        /// Page tokens are the decimal offset of the next workflow, encoded as UTF-8 bytes; an
        /// empty token on a request means "start from the beginning" and an empty token on a
        /// response means "no more pages".
        #[derive(Clone)]
        struct MockListWorkflowsClient {
            /// Number of `list_workflow_executions` calls observed; shared with the test body.
            call_count: Arc<AtomicUsize>,
            /// Maximum number of executions returned per page (the request's page size is ignored).
            page_size: usize,
            /// Total number of workflows available across all pages.
            total_workflows: usize,
        }

        impl NamespacedClient for MockListWorkflowsClient {
            fn namespace(&self) -> String {
                "test-namespace".to_string()
            }
            fn identity(&self) -> String {
                "test-identity".to_string()
            }
        }

        impl WorkflowService for MockListWorkflowsClient {
            fn list_workflow_executions(
                &mut self,
                request: Request<ListWorkflowExecutionsRequest>,
            ) -> futures_util::future::BoxFuture<
                '_,
                Result<Response<ListWorkflowExecutionsResponse>, tonic::Status>,
            > {
                self.call_count.fetch_add(1, Ordering::SeqCst);
                let req = request.into_inner();

                // An empty token starts the listing at the first workflow.
                let offset: usize = if req.next_page_token.is_empty() {
                    0
                } else {
                    std::str::from_utf8(&req.next_page_token)
                        .expect("mock page tokens are UTF-8")
                        .parse()
                        .expect("mock page tokens are decimal offsets")
                };

                let remaining = self.total_workflows.saturating_sub(offset);
                let count = remaining.min(self.page_size);
                let new_offset = offset + count;

                let executions: Vec<_> = (offset..new_offset)
                    .map(|i| workflow::WorkflowExecutionInfo {
                        execution: Some(ProtoWorkflowExecution {
                            workflow_id: format!("wf-{i}"),
                            run_id: format!("run-{i}"),
                        }),
                        r#type: Some(WorkflowType {
                            name: "TestWorkflow".to_string(),
                        }),
                        task_queue: "test-queue".to_string(),
                        ..Default::default()
                    })
                    .collect();

                // Only hand out a continuation token while workflows remain.
                let next_page_token = if new_offset < self.total_workflows {
                    new_offset.to_string().into_bytes()
                } else {
                    vec![]
                };

                async move {
                    Ok(Response::new(ListWorkflowExecutionsResponse {
                        executions,
                        next_page_token,
                    }))
                }
                .boxed()
            }
        }

        /// Without a limit every page is fetched (10 workflows / 3 per page = 4 calls) and
        /// results arrive in server order.
        #[tokio::test]
        async fn list_workflows_paginates_through_all_results() {
            let call_count = Arc::new(AtomicUsize::new(0));
            let client = MockListWorkflowsClient {
                call_count: call_count.clone(),
                page_size: 3,
                total_workflows: 10,
            };

            let stream = client.list_workflows("", WorkflowListOptions::default());
            let results: Vec<_> = stream.collect().await;

            assert_eq!(results.len(), 10);
            for (i, result) in results.iter().enumerate() {
                let wf = result.as_ref().unwrap();
                assert_eq!(wf.id(), format!("wf-{i}"));
                assert_eq!(wf.run_id(), format!("run-{i}"));
            }
            assert_eq!(call_count.load(Ordering::SeqCst), 4);
        }

        /// A limit of 5 with 3 per page stops after the second page instead of fetching the rest.
        #[tokio::test]
        async fn list_workflows_respects_limit() {
            let call_count = Arc::new(AtomicUsize::new(0));
            let client = MockListWorkflowsClient {
                call_count: call_count.clone(),
                page_size: 3,
                total_workflows: 10,
            };

            let opts = WorkflowListOptions::builder().limit(5).build();
            let stream = client.list_workflows("", opts);
            let results: Vec<_> = stream.collect().await;

            assert_eq!(results.len(), 5);
            for (i, result) in results.iter().enumerate() {
                let wf = result.as_ref().unwrap();
                assert_eq!(wf.id(), format!("wf-{i}"));
            }
            assert_eq!(call_count.load(Ordering::SeqCst), 2);
        }

        /// A limit smaller than one page is satisfied by a single request.
        #[tokio::test]
        async fn list_workflows_limit_less_than_page_size() {
            let call_count = Arc::new(AtomicUsize::new(0));
            let client = MockListWorkflowsClient {
                call_count: call_count.clone(),
                page_size: 10,
                total_workflows: 100,
            };

            let opts = WorkflowListOptions::builder().limit(3).build();
            let stream = client.list_workflows("", opts);
            let results: Vec<_> = stream.collect().await;

            assert_eq!(results.len(), 3);
            assert_eq!(call_count.load(Ordering::SeqCst), 1);
        }

        /// An empty listing still issues exactly one request and yields nothing.
        #[tokio::test]
        async fn list_workflows_empty_results() {
            let call_count = Arc::new(AtomicUsize::new(0));
            let client = MockListWorkflowsClient {
                call_count: call_count.clone(),
                page_size: 10,
                total_workflows: 0,
            };

            let stream = client.list_workflows("", WorkflowListOptions::default());
            let results: Vec<_> = stream.collect().await;

            assert_eq!(results.len(), 0);
            assert_eq!(call_count.load(Ordering::SeqCst), 1);
        }
    }
}