1#![warn(missing_docs)] #[macro_use]
8extern crate tracing;
9
10mod async_activity_handle;
11pub mod callback_based;
12mod dns;
13#[cfg(feature = "envconfig")]
15pub mod envconfig;
16pub mod errors;
17pub mod grpc;
18mod metrics;
19mod options_structs;
20#[doc(hidden)]
22pub mod proxy;
23mod replaceable;
24pub mod request_extensions;
25mod retry;
26pub mod schedules;
28pub mod worker;
29mod workflow_handle;
30
31pub use crate::{
32 proxy::HttpConnectProxyOptions,
33 retry::{CallType, RETRYABLE_ERROR_CODES},
34};
35pub use async_activity_handle::{
36 ActivityHeartbeatResponse, ActivityIdentifier, AsyncActivityHandle,
37};
38
39pub use metrics::{LONG_REQUEST_LATENCY_HISTOGRAM_NAME, REQUEST_LATENCY_HISTOGRAM_NAME};
40pub use options_structs::*;
41pub use replaceable::SharedReplaceableClient;
42pub use retry::RetryOptions;
/// Re-exports of TLS types that rustls keeps in its `danger` module.
///
/// `ServerCertVerifier` lets callers replace server certificate verification (see the
/// `server_cert_verifier` TLS option). rustls places it under `danger` because a custom
/// verifier can weaken certificate checking; use with care.
pub mod danger {
    pub use tokio_rustls::rustls::client::danger::ServerCertVerifier;
}
50pub use tonic;
51pub use workflow_handle::{
52 UntypedQuery, UntypedSignal, UntypedUpdate, UntypedWorkflow, UntypedWorkflowHandle,
53 WorkflowExecutionDescription, WorkflowExecutionInfo, WorkflowExecutionResult, WorkflowHandle,
54 WorkflowHistory, WorkflowUpdateHandle,
55};
56
57use crate::{
58 grpc::{
59 AttachMetricLabels, CloudService, HealthService, OperatorService, TestService,
60 WorkflowService,
61 },
62 metrics::{ChannelOrGrpcOverride, GrpcMetricSvc, MetricsContext},
63 request_extensions::RequestExt,
64 worker::ClientWorkerSet,
65};
66use errors::*;
67use futures_util::{stream, stream::Stream};
68use http::Uri;
69use parking_lot::RwLock;
70use std::{
71 collections::{HashMap, VecDeque},
72 fmt::Debug,
73 pin::Pin,
74 str::FromStr,
75 sync::{Arc, OnceLock},
76 task::{Context, Poll},
77 time::{Duration, SystemTime},
78};
79use temporalio_common::{
80 HasWorkflowDefinition,
81 data_converters::{
82 DataConverter, GenericPayloadConverter, PayloadConverter, SerializationContext,
83 SerializationContextData,
84 },
85 protos::{
86 coresdk::IntoPayloadsExt,
87 grpc::health::v1::health_client::HealthClient,
88 proto_ts_to_system_time,
89 temporal::api::{
90 cloud::cloudservice::v1::cloud_service_client::CloudServiceClient,
91 common::v1::{Memo, Payload, SearchAttributes, WorkflowType},
92 enums::v1::{TaskQueueKind, WorkflowExecutionStatus},
93 errordetails::v1::WorkflowExecutionAlreadyStartedFailure,
94 operatorservice::v1::operator_service_client::OperatorServiceClient,
95 sdk::v1::UserMetadata,
96 taskqueue::v1::TaskQueue,
97 testservice::v1::test_service_client::TestServiceClient,
98 workflow::v1 as workflow,
99 workflowservice::v1::{
100 count_workflow_executions_response, workflow_service_client::WorkflowServiceClient,
101 *,
102 },
103 },
104 utilities::decode_status_detail,
105 },
106};
107use tonic::{
108 Code, IntoRequest,
109 body::Body,
110 client::GrpcService,
111 codec::CompressionEncoding,
112 codegen::InterceptedService,
113 metadata::{
114 AsciiMetadataKey, AsciiMetadataValue, BinaryMetadataKey, BinaryMetadataValue, MetadataMap,
115 MetadataValue,
116 },
117 service::Interceptor,
118 transport::{Certificate, Endpoint, Identity},
119};
120use tower::ServiceBuilder;
121use uuid::Uuid;
122
/// Metadata key under which `ServiceCallInterceptor` sends the client name.
static CLIENT_NAME_HEADER_KEY: &str = "client-name";
/// Metadata key under which `ServiceCallInterceptor` sends the client version.
static CLIENT_VERSION_HEADER_KEY: &str = "client-version";
/// Metadata key for the target namespace (not used in this part of the file; presumably set
/// on namespaced requests elsewhere in the crate — confirm).
static TEMPORAL_NAMESPACE_HEADER_KEY: &str = "temporal-namespace";

/// Marker key used to tag errors caused by oversized messages (usage is outside this file
/// section — NOTE(review): confirm where it is attached and read).
#[doc(hidden)]
pub static MESSAGE_TOO_LARGE_KEY: &str = "message-too-large";
/// Marker key used to tag errors returned because of a short circuit (usage is outside this
/// file section — NOTE(review): confirm the producer/consumer).
#[doc(hidden)]
pub static ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT: &str = "short-circuit";

/// Timeout for long-poll calls (not referenced in this file section; presumably applied by the
/// worker/poll code — confirm).
const LONG_POLL_TIMEOUT: Duration = Duration::from_secs(70);
/// Default timeout applied by `ServiceCallInterceptor` to calls that do not set their own.
const OTHER_CALL_TIMEOUT: Duration = Duration::from_secs(30);
/// This crate's version, taken from Cargo at compile time.
const VERSION: &str = env!("CARGO_PKG_VERSION");
138
/// A connection to a Temporal server.
///
/// Cloning is cheap: all clones share one inner state behind an `Arc`. Mutating accessors such
/// as `retry_options_mut` and `identity_mut` use `Arc::make_mut`, so they only affect the handle
/// they are called on, while headers (`set_api_key`, `set_headers`, ...) stay shared.
#[derive(Clone)]
pub struct Connection {
    inner: Arc<ConnectionInner>,
}
147
/// State shared by all clones of a [`Connection`].
#[derive(Clone)]
struct ConnectionInner {
    /// Service clients, all built on the same intercepted channel.
    service: TemporalServiceClient,
    /// Retry settings, exposed through `Connection::retry_options_mut`.
    retry_options: RetryOptions,
    /// Identity reported to the server on requests that carry one.
    identity: String,
    /// Header state, shared with the `ServiceCallInterceptor` so that updates apply to later
    /// calls. Because it sits behind its own `Arc`, it stays shared even after `Arc::make_mut`
    /// clones this struct.
    headers: Arc<RwLock<ClientHeaders>>,
    /// Value sent in the `client-name` header.
    client_name: String,
    /// Value sent in the `client-version` header.
    client_version: String,
    /// Capabilities reported by `GetSystemInfo`. `None` if the call was skipped, was
    /// unimplemented, or reported no capabilities.
    capabilities: Option<get_system_info_response::Capabilities>,
    /// Worker set associated with this connection.
    workers: Arc<ClientWorkerSet>,
    /// Held only to keep the DNS re-resolution handle alive while the connection exists (set
    /// when DNS load balancing is enabled). NOTE(review): presumably dropping the handle stops
    /// re-resolution — confirm in the `dns` module.
    _dns_task: Option<Arc<dns::DnsReresolutionHandle>>,
}
161
impl Connection {
    /// Establishes a connection to a Temporal server.
    ///
    /// The transport is chosen from `options`:
    /// * `service_override` set: that service is used as-is, and gRPC compression is disabled.
    /// * DNS load balancing configured: a balanced channel is created, and a DNS re-resolution
    ///   task is spawned with the configured `resolution_interval`.
    /// * Otherwise: a single [`Endpoint`] is built from `target`. TLS, keep-alive, the origin
    ///   override and an optional HTTP CONNECT proxy are applied, then it is connected eagerly.
    ///
    /// Every request passes through a [`ServiceCallInterceptor`], which adds the client
    /// name/version, user headers and API key. Unless `skip_get_system_info` is set,
    /// `GetSystemInfo` is called once to record server capabilities. A server that answers
    /// `Unimplemented` is tolerated and yields no capabilities.
    ///
    /// # Errors
    /// Returns [`ClientConnectError`] in these cases:
    /// * the DNS or TLS configuration is invalid;
    /// * the target cannot be parsed or connected to;
    /// * a configured header is invalid;
    /// * `GetSystemInfo` fails with any code other than `Unimplemented`.
    pub async fn connect(options: ConnectionOptions) -> Result<Self, ClientConnectError> {
        // Validates the DNS load-balancing settings; `Some` selects the balanced-channel path.
        let dns_lb_opts = dns::validate_and_get_dns_lb(&options)?.cloned();
        // Compression is only configured for channels built here; a service override is
        // always used without it.
        let compression = if options.service_override.is_some() {
            GrpcCompression::None
        } else {
            options.grpc_compression
        };
        let (service, dns_task) = if let Some(service_override) = options.service_override {
            (
                GrpcMetricSvc {
                    inner: ChannelOrGrpcOverride::GrpcOverride(service_override),
                    metrics: options.metrics_meter.clone().map(MetricsContext::new),
                    disable_errcode_label: options.disable_error_code_metric_tags,
                },
                None,
            )
        } else if let Some(dns_opts) = &dns_lb_opts {
            let (channel, sender) = dns::create_balanced_channel(&options).await?;
            // The handle is stored in `ConnectionInner::_dns_task`, so it lives exactly as long
            // as the connection state.
            let handle = dns::spawn_dns_reresolution(
                sender,
                options.target.clone(),
                options.tls_options.clone(),
                options.keep_alive.clone(),
                options.override_origin.clone(),
                dns_opts.resolution_interval,
            );
            (
                ServiceBuilder::new()
                    .layer_fn(move |channel| GrpcMetricSvc {
                        inner: ChannelOrGrpcOverride::Channel(channel),
                        metrics: options.metrics_meter.clone().map(MetricsContext::new),
                        disable_errcode_label: options.disable_error_code_metric_tags,
                    })
                    .service(channel),
                Some(handle),
            )
        } else {
            let channel = Endpoint::from_shared(options.target.to_string())?;
            let channel = add_tls_to_channel(options.tls_options.as_ref(), channel).await?;
            let channel = if let Some(keep_alive) = options.keep_alive.as_ref() {
                channel
                    .keep_alive_while_idle(true)
                    .http2_keep_alive_interval(keep_alive.interval)
                    .keep_alive_timeout(keep_alive.timeout)
            } else {
                channel
            };
            // Applied after TLS setup, so an explicit override replaces any origin derived from
            // the TLS domain.
            let channel = if let Some(origin) = options.override_origin.clone() {
                channel.origin(origin)
            } else {
                channel
            };
            let channel = if let Some(proxy) = options.http_connect_proxy.as_ref() {
                proxy.connect_endpoint(&channel).await?
            } else {
                channel.connect().await?
            };
            (
                ServiceBuilder::new()
                    .layer_fn(move |channel| GrpcMetricSvc {
                        inner: ChannelOrGrpcOverride::Channel(channel),
                        metrics: options.metrics_meter.clone().map(MetricsContext::new),
                        disable_errcode_label: options.disable_error_code_metric_tags,
                    })
                    .service(channel),
                None,
            )
        };

        // The interceptor shares this header state, so the `set_*` methods on `Connection`
        // affect every later call.
        let headers = Arc::new(RwLock::new(ClientHeaders {
            user_headers: parse_ascii_headers(options.headers.clone().unwrap_or_default())?,
            user_binary_headers: parse_binary_headers(
                options.binary_headers.clone().unwrap_or_default(),
            )?,
            api_key: options.api_key.clone(),
        }));
        let interceptor = ServiceCallInterceptor {
            client_name: options.client_name.clone(),
            client_version: options.client_version.clone(),
            headers: headers.clone(),
        };
        let svc = InterceptedService::new(service, interceptor);
        let mut svc_client = TemporalServiceClient::new(svc, compression);

        let capabilities = if !options.skip_get_system_info {
            match svc_client
                .get_system_info(GetSystemInfoRequest::default().into_request())
                .await
            {
                Ok(sysinfo) => sysinfo.into_inner().capabilities,
                Err(status) => match status.code() {
                    // Servers without GetSystemInfo are accepted; they just report nothing.
                    Code::Unimplemented => None,
                    _ => return Err(ClientConnectError::SystemInfoCallError(status)),
                },
            }
        } else {
            None
        };
        Ok(Self {
            inner: Arc::new(ConnectionInner {
                service: svc_client,
                retry_options: options.retry_options,
                identity: options.identity,
                headers,
                client_name: options.client_name,
                client_version: options.client_version,
                capabilities,
                workers: Arc::new(ClientWorkerSet::new()),
                _dns_task: dns_task,
            }),
        })
    }

    /// Sets the API key, or clears it with `None`.
    ///
    /// While set, every outgoing request that lacks an `authorization` header gets
    /// `authorization: Bearer <api_key>`. Applies to subsequent calls on this connection and
    /// all of its clones.
    pub fn set_api_key(&self, api_key: Option<String>) {
        self.inner.headers.write().api_key = api_key;
    }

    /// Replaces the user-supplied ASCII headers sent with every request.
    ///
    /// Headers already present on an individual request are not overwritten.
    ///
    /// # Errors
    /// Returns [`InvalidHeaderError`] if any key or value is not valid ASCII metadata. In that
    /// case the previously configured headers are left unchanged.
    pub fn set_headers(&self, headers: HashMap<String, String>) -> Result<(), InvalidHeaderError> {
        self.inner.headers.write().user_headers = parse_ascii_headers(headers)?;
        Ok(())
    }

    /// Replaces the user-supplied binary headers sent with every request.
    ///
    /// Headers already present on an individual request are not overwritten.
    ///
    /// # Errors
    /// Returns [`InvalidHeaderError`] if any key is not a valid binary metadata key. In that
    /// case the previously configured binary headers are left unchanged.
    pub fn set_binary_headers(
        &self,
        binary_headers: HashMap<String, Vec<u8>>,
    ) -> Result<(), InvalidHeaderError> {
        self.inner.headers.write().user_binary_headers = parse_binary_headers(binary_headers)?;
        Ok(())
    }

    /// Returns the client name sent in the `client-name` header.
    pub fn client_name(&self) -> &str {
        &self.inner.client_name
    }

    /// Returns the client version sent in the `client-version` header.
    pub fn client_version(&self) -> &str {
        &self.inner.client_version
    }

    /// Returns the server capabilities recorded at connect time.
    ///
    /// Returns `None` if `GetSystemInfo` was skipped, was unimplemented by the server, or
    /// reported no capabilities.
    pub fn capabilities(&self) -> Option<&get_system_info_response::Capabilities> {
        self.inner.capabilities.as_ref()
    }

    /// Returns mutable access to the retry options.
    ///
    /// If other clones share this connection's state, the state is copied first
    /// (`Arc::make_mut`), so the change affects only this handle. Headers and the worker set
    /// remain shared because they live behind their own `Arc`s.
    pub fn retry_options_mut(&mut self) -> &mut RetryOptions {
        &mut Arc::make_mut(&mut self.inner).retry_options
    }

    /// Returns the identity reported to the server.
    pub fn identity(&self) -> &str {
        &self.inner.identity
    }

    /// Returns mutable access to the identity.
    ///
    /// Like `retry_options_mut`, this copies shared state first, so only this handle changes.
    pub fn identity_mut(&mut self) -> &mut String {
        &mut Arc::make_mut(&mut self.inner).identity
    }

    /// Returns the worker set associated with this connection (shared, not copied).
    pub fn workers(&self) -> Arc<ClientWorkerSet> {
        self.inner.workers.clone()
    }

    /// Returns the grouping key of this connection's worker set.
    pub fn worker_grouping_key(&self) -> Uuid {
        self.inner.workers.worker_grouping_key()
    }

    /// Returns a boxed clone of the workflow service client.
    pub fn workflow_service(&self) -> Box<dyn WorkflowService> {
        self.inner.service.workflow_service()
    }

    /// Returns a boxed clone of the operator service client.
    pub fn operator_service(&self) -> Box<dyn OperatorService> {
        self.inner.service.operator_service()
    }

    /// Returns a boxed clone of the cloud service client.
    pub fn cloud_service(&self) -> Box<dyn CloudService> {
        self.inner.service.cloud_service()
    }

    /// Returns a boxed clone of the test service client.
    pub fn test_service(&self) -> Box<dyn TestService> {
        self.inner.service.test_service()
    }

    /// Returns a boxed clone of the health service client.
    pub fn health_service(&self) -> Box<dyn HealthService> {
        self.inner.service.health_service()
    }
}
386
/// Header state that `ServiceCallInterceptor` merges into every outgoing request.
#[derive(Debug)]
struct ClientHeaders {
    /// User-supplied ASCII metadata.
    user_headers: HashMap<AsciiMetadataKey, AsciiMetadataValue>,
    /// User-supplied binary metadata (keys end in `-bin`, per gRPC convention).
    user_binary_headers: HashMap<BinaryMetadataKey, BinaryMetadataValue>,
    /// When set, sent as `authorization: Bearer <key>`.
    api_key: Option<String>,
}
393
394impl ClientHeaders {
395 fn apply_to_metadata(&self, metadata: &mut MetadataMap) {
396 for (key, val) in self.user_headers.iter() {
397 if !metadata.contains_key(key) {
399 metadata.insert(key, val.clone());
400 }
401 }
402 for (key, val) in self.user_binary_headers.iter() {
403 if !metadata.contains_key(key) {
405 metadata.insert_bin(key, val.clone());
406 }
407 }
408 if let Some(api_key) = &self.api_key {
409 if !metadata.contains_key("authorization")
411 && let Ok(val) = format!("Bearer {api_key}").parse()
412 {
413 metadata.insert("authorization", val);
414 }
415 }
416 }
417}
418
419async fn add_tls_to_channel(
422 tls_options: Option<&TlsOptions>,
423 mut channel: Endpoint,
424) -> Result<Endpoint, ClientConnectError> {
425 if let Some(tls_cfg) = tls_options {
426 if tls_cfg.server_cert_verifier.is_some() && tls_cfg.server_root_ca_cert.is_some() {
427 return Err(ClientConnectError::InvalidConfig(
428 "Cannot set both `server_root_ca_cert` and `server_cert_verifier`".to_owned(),
429 ));
430 }
431
432 let mut tls = tonic::transport::ClientTlsConfig::new();
433
434 if tls_cfg.server_cert_verifier.is_none() {
435 if let Some(root_cert) = &tls_cfg.server_root_ca_cert {
436 let server_root_ca_cert = Certificate::from_pem(root_cert);
437 tls = tls.ca_certificate(server_root_ca_cert);
438 } else {
439 tls = tls.with_native_roots();
440 }
441 }
442
443 if let Some(domain) = &tls_cfg.domain {
444 tls = tls.domain_name(domain);
445
446 let uri: Uri = format!("https://{domain}").parse()?;
451 channel = channel.origin(uri);
452 }
453
454 if let Some(client_opts) = &tls_cfg.client_tls_options {
455 let client_identity =
456 Identity::from_pem(&client_opts.client_cert, &client_opts.client_private_key);
457 tls = tls.identity(client_identity);
458 }
459
460 return if let Some(verifier) = &tls_cfg.server_cert_verifier {
461 channel
462 .tls_config_with_verifier(tls, verifier.clone())
463 .map_err(Into::into)
464 } else {
465 channel.tls_config(tls).map_err(Into::into)
466 };
467 }
468 Ok(channel)
469}
470
471fn parse_ascii_headers(
472 headers: HashMap<String, String>,
473) -> Result<HashMap<AsciiMetadataKey, AsciiMetadataValue>, InvalidHeaderError> {
474 let mut parsed_headers = HashMap::with_capacity(headers.len());
475 for (k, v) in headers.into_iter() {
476 let key = match AsciiMetadataKey::from_str(&k) {
477 Ok(key) => key,
478 Err(err) => {
479 return Err(InvalidHeaderError::InvalidAsciiHeaderKey {
480 key: k,
481 source: err,
482 });
483 }
484 };
485 let value = match MetadataValue::from_str(&v) {
486 Ok(value) => value,
487 Err(err) => {
488 return Err(InvalidHeaderError::InvalidAsciiHeaderValue {
489 key: k,
490 value: v,
491 source: err,
492 });
493 }
494 };
495 parsed_headers.insert(key, value);
496 }
497
498 Ok(parsed_headers)
499}
500
501fn parse_binary_headers(
502 headers: HashMap<String, Vec<u8>>,
503) -> Result<HashMap<BinaryMetadataKey, BinaryMetadataValue>, InvalidHeaderError> {
504 let mut parsed_headers = HashMap::with_capacity(headers.len());
505 for (k, v) in headers.into_iter() {
506 let key = match BinaryMetadataKey::from_str(&k) {
507 Ok(key) => key,
508 Err(err) => {
509 return Err(InvalidHeaderError::InvalidBinaryHeaderKey {
510 key: k,
511 source: err,
512 });
513 }
514 };
515 let value = BinaryMetadataValue::from_bytes(&v);
516 parsed_headers.insert(key, value);
517 }
518
519 Ok(parsed_headers)
520}
521
/// Tonic interceptor applied to every request sent through a [`Connection`].
///
/// Adds the `client-name` and `client-version` headers, the shared user headers and API key,
/// and a default call timeout.
#[derive(Clone)]
pub struct ServiceCallInterceptor {
    /// Value for the `client-name` header.
    client_name: String,
    /// Value for the `client-version` header.
    client_version: String,
    /// Header state shared with the owning `Connection`, so updates there apply here.
    headers: Arc<RwLock<ClientHeaders>>,
}
530
531impl Interceptor for ServiceCallInterceptor {
532 fn call(
535 &mut self,
536 mut request: tonic::Request<()>,
537 ) -> Result<tonic::Request<()>, tonic::Status> {
538 let metadata = request.metadata_mut();
539 if !metadata.contains_key(CLIENT_NAME_HEADER_KEY) {
540 metadata.insert(
541 CLIENT_NAME_HEADER_KEY,
542 self.client_name
543 .parse()
544 .unwrap_or_else(|_| MetadataValue::from_static("")),
545 );
546 }
547 if !metadata.contains_key(CLIENT_VERSION_HEADER_KEY) {
548 metadata.insert(
549 CLIENT_VERSION_HEADER_KEY,
550 self.client_version
551 .parse()
552 .unwrap_or_else(|_| MetadataValue::from_static("")),
553 );
554 }
555 self.headers.read().apply_to_metadata(metadata);
556 request.set_default_timeout(OTHER_CALL_TIMEOUT);
557
558 Ok(request)
559 }
560}
561
/// A set of boxed gRPC clients, one per Temporal service.
///
/// Cloning clones each boxed client.
#[derive(Clone)]
pub struct TemporalServiceClient {
    workflow_svc_client: Box<dyn WorkflowService>,
    operator_svc_client: Box<dyn OperatorService>,
    cloud_svc_client: Box<dyn CloudService>,
    test_svc_client: Box<dyn TestService>,
    health_svc_client: Box<dyn HealthService>,
}
571
/// Returns the largest gRPC message, in bytes, that the service clients will decode.
///
/// Read once from the `TEMPORAL_MAX_INCOMING_GRPC_BYTES` environment variable. Falls back to
/// 128 MiB when the variable is unset or not a valid `usize`. The result is cached for the rest
/// of the process.
fn get_decode_max_size() -> usize {
    const DEFAULT_MAX_BYTES: usize = 128 * 1024 * 1024;
    static CACHED_LIMIT: OnceLock<usize> = OnceLock::new();

    *CACHED_LIMIT.get_or_init(|| match std::env::var("TEMPORAL_MAX_INCOMING_GRPC_BYTES") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_MAX_BYTES),
        Err(_) => DEFAULT_MAX_BYTES,
    })
}
583
impl TemporalServiceClient {
    /// Builds every service client on top of the single gRPC service `svc`.
    ///
    /// Each client decodes messages up to [`get_decode_max_size`] bytes. When `compression` is
    /// [`GrpcCompression::Gzip`], each client also sends and accepts gzip-compressed messages.
    fn new<T>(svc: T, compression: GrpcCompression) -> Self
    where
        T: GrpcService<Body> + Send + Sync + Clone + 'static,
        T::ResponseBody: tonic::codegen::Body<Data = tonic::codegen::Bytes> + Send + 'static,
        T::Error: Into<tonic::codegen::StdError>,
        <T::ResponseBody as tonic::codegen::Body>::Error: Into<tonic::codegen::StdError> + Send,
        <T as GrpcService<Body>>::Future: Send,
    {
        // Applies the shared decode limit and compression settings to a generated tonic
        // client. A macro is used because each generated client is a distinct type.
        macro_rules! configure {
            ($client:expr) => {{
                let client = $client.max_decoding_message_size(get_decode_max_size());
                match compression {
                    GrpcCompression::Gzip => client
                        .send_compressed(CompressionEncoding::Gzip)
                        .accept_compressed(CompressionEncoding::Gzip),
                    GrpcCompression::None => client,
                }
            }};
        }

        let workflow_svc_client = Box::new(configure!(WorkflowServiceClient::new(svc.clone())));
        let operator_svc_client = Box::new(configure!(OperatorServiceClient::new(svc.clone())));
        let cloud_svc_client = Box::new(configure!(CloudServiceClient::new(svc.clone())));
        let test_svc_client = Box::new(configure!(TestServiceClient::new(svc.clone())));
        let health_svc_client = Box::new(configure!(HealthClient::new(svc.clone())));

        Self {
            workflow_svc_client,
            operator_svc_client,
            cloud_svc_client,
            test_svc_client,
            health_svc_client,
        }
    }

    /// Assembles a client from caller-provided service implementations.
    ///
    /// This skips the channel setup, decode limit and compression configuration done by
    /// `new`.
    pub fn from_services(
        workflow: Box<dyn WorkflowService>,
        operator: Box<dyn OperatorService>,
        cloud: Box<dyn CloudService>,
        test: Box<dyn TestService>,
        health: Box<dyn HealthService>,
    ) -> Self {
        Self {
            workflow_svc_client: workflow,
            operator_svc_client: operator,
            cloud_svc_client: cloud,
            test_svc_client: test,
            health_svc_client: health,
        }
    }

    /// Returns a boxed clone of the workflow service client.
    pub fn workflow_service(&self) -> Box<dyn WorkflowService> {
        self.workflow_svc_client.clone()
    }
    /// Returns a boxed clone of the operator service client.
    pub fn operator_service(&self) -> Box<dyn OperatorService> {
        self.operator_svc_client.clone()
    }
    /// Returns a boxed clone of the cloud service client.
    pub fn cloud_service(&self) -> Box<dyn CloudService> {
        self.cloud_svc_client.clone()
    }
    /// Returns a boxed clone of the test service client.
    pub fn test_service(&self) -> Box<dyn TestService> {
        self.test_svc_client.clone()
    }
    /// Returns a boxed clone of the health service client.
    pub fn health_service(&self) -> Box<dyn HealthService> {
        self.health_svc_client.clone()
    }
}
661
/// A namespaced client for workflow operations, built on a [`Connection`].
///
/// Cloning is cheap: the connection shares its state through an `Arc`, and the options sit
/// behind an `Arc`.
#[derive(Clone)]
pub struct Client {
    connection: Connection,
    options: Arc<ClientOptions>,
}
669
670impl Client {
671 pub fn new(connection: Connection, options: ClientOptions) -> Result<Self, ClientNewError> {
676 Ok(Client {
677 connection,
678 options: Arc::new(options),
679 })
680 }
681
682 pub fn options(&self) -> &ClientOptions {
684 &self.options
685 }
686
687 pub fn options_mut(&mut self) -> &mut ClientOptions {
692 Arc::make_mut(&mut self.options)
693 }
694
695 pub fn connection(&self) -> &Connection {
697 &self.connection
698 }
699
700 pub fn connection_mut(&mut self) -> &mut Connection {
702 &mut self.connection
703 }
704}
705
706impl Client {
710 pub async fn start_workflow<W>(
715 &self,
716 workflow: W,
717 input: W::Input,
718 options: WorkflowStartOptions,
719 ) -> Result<WorkflowHandle<Self, W>, WorkflowStartError>
720 where
721 W: HasWorkflowDefinition,
722 W::Input: Send,
723 {
724 WorkflowClientTrait::start_workflow(self, workflow, input, options).await
725 }
726
727 pub fn get_workflow_handle<W: HasWorkflowDefinition>(
731 &self,
732 workflow_id: impl Into<String>,
733 ) -> WorkflowHandle<Self, W> {
734 WorkflowClientTrait::get_workflow_handle(self, workflow_id)
735 }
736
737 pub fn list_workflows(
742 &self,
743 query: impl Into<String>,
744 opts: WorkflowListOptions,
745 ) -> ListWorkflowsStream {
746 WorkflowClientTrait::list_workflows(self, query, opts)
747 }
748
749 pub async fn count_workflows(
751 &self,
752 query: impl Into<String>,
753 opts: WorkflowCountOptions,
754 ) -> Result<WorkflowExecutionCount, ClientError> {
755 WorkflowClientTrait::count_workflows(self, query, opts).await
756 }
757
758 pub fn get_async_activity_handle(
762 &self,
763 identifier: ActivityIdentifier,
764 ) -> AsyncActivityHandle<Self> {
765 WorkflowClientTrait::get_async_activity_handle(self, identifier)
766 }
767}
768
769impl NamespacedClient for Client {
770 fn namespace(&self) -> String {
771 self.options.namespace.clone()
772 }
773
774 fn identity(&self) -> String {
775 self.connection.identity().to_owned()
776 }
777
778 fn data_converter(&self) -> &DataConverter {
779 &self.options.data_converter
780 }
781}
782
/// Identifies a namespace either by name or by ID.
#[derive(Clone)]
pub enum Namespace {
    /// The namespace's name.
    Name(String),
    /// The namespace's ID.
    Id(String),
}
791
792impl Namespace {
793 pub fn into_describe_namespace_request(self) -> DescribeNamespaceRequest {
795 let (namespace, id) = match self {
796 Namespace::Name(n) => (n, "".to_owned()),
797 Namespace::Id(n) => ("".to_owned(), n),
798 };
799 DescribeNamespaceRequest {
800 namespace,
801 id,
802 weak_consistency: false,
803 }
804 }
805}
806
/// Workflow operations available on any namespaced workflow-service client.
///
/// A blanket impl covers every
/// `T: WorkflowService + NamespacedClient + Clone + Send + Sync + 'static`.
pub(crate) trait WorkflowClientTrait: NamespacedClient {
    /// Starts `workflow` with `input`, returning a handle bound to the new run.
    fn start_workflow<W>(
        &self,
        workflow: W,
        input: W::Input,
        options: WorkflowStartOptions,
    ) -> impl Future<Output = Result<WorkflowHandle<Self, W>, WorkflowStartError>>
    where
        Self: Sized,
        W: HasWorkflowDefinition,
        W::Input: Send;

    /// Returns a handle to `workflow_id` in this client's namespace without an RPC.
    ///
    /// The handle is not bound to a specific run.
    fn get_workflow_handle<W: HasWorkflowDefinition>(
        &self,
        workflow_id: impl Into<String>,
    ) -> WorkflowHandle<Self, W>
    where
        Self: Sized;

    /// Lazily streams executions that match the visibility `query`, page by page.
    fn list_workflows(
        &self,
        query: impl Into<String>,
        opts: WorkflowListOptions,
    ) -> ListWorkflowsStream;

    /// Counts executions that match the visibility `query`.
    fn count_workflows(
        &self,
        query: impl Into<String>,
        opts: WorkflowCountOptions,
    ) -> impl Future<Output = Result<WorkflowExecutionCount, ClientError>>;

    /// Returns a handle for the activity identified by `identifier`.
    fn get_async_activity_handle(
        &self,
        identifier: ActivityIdentifier,
    ) -> AsyncActivityHandle<Self>
    where
        Self: Sized;
}
861
/// A client bound to a namespace and an identity, with a data converter for payloads.
pub trait NamespacedClient {
    /// The namespace that requests target.
    fn namespace(&self) -> String;
    /// The identity attached to requests that carry one.
    fn identity(&self) -> String;
    /// The converter used to serialize payloads such as workflow input.
    ///
    /// The default implementation returns a single process-wide `DataConverter::default()`,
    /// created lazily on first use.
    fn data_converter(&self) -> &DataConverter {
        static DEFAULT: OnceLock<DataConverter> = OnceLock::new();
        DEFAULT.get_or_init(DataConverter::default)
    }
}
875
/// One workflow execution from a list query.
///
/// Wraps the raw `WorkflowExecutionInfo` proto and adds convenience accessors.
#[derive(Debug, Clone)]
pub struct WorkflowExecution {
    raw: workflow::WorkflowExecutionInfo,
}
882
883impl WorkflowExecution {
884 pub fn new(raw: workflow::WorkflowExecutionInfo) -> Self {
886 Self { raw }
887 }
888
889 pub fn id(&self) -> &str {
891 self.raw
892 .execution
893 .as_ref()
894 .map(|e| e.workflow_id.as_str())
895 .unwrap_or("")
896 }
897
898 pub fn run_id(&self) -> &str {
900 self.raw
901 .execution
902 .as_ref()
903 .map(|e| e.run_id.as_str())
904 .unwrap_or("")
905 }
906
907 pub fn workflow_type(&self) -> &str {
909 self.raw
910 .r#type
911 .as_ref()
912 .map(|t| t.name.as_str())
913 .unwrap_or("")
914 }
915
916 pub fn status(&self) -> WorkflowExecutionStatus {
918 self.raw.status()
919 }
920
921 pub fn start_time(&self) -> Option<SystemTime> {
923 self.raw
924 .start_time
925 .as_ref()
926 .and_then(proto_ts_to_system_time)
927 }
928
929 pub fn execution_time(&self) -> Option<SystemTime> {
931 self.raw
932 .execution_time
933 .as_ref()
934 .and_then(proto_ts_to_system_time)
935 }
936
937 pub fn close_time(&self) -> Option<SystemTime> {
939 self.raw
940 .close_time
941 .as_ref()
942 .and_then(proto_ts_to_system_time)
943 }
944
945 pub fn task_queue(&self) -> &str {
947 &self.raw.task_queue
948 }
949
950 pub fn history_length(&self) -> i64 {
952 self.raw.history_length
953 }
954
955 pub fn memo(&self) -> Option<&Memo> {
957 self.raw.memo.as_ref()
958 }
959
960 pub fn parent_id(&self) -> Option<&str> {
962 self.raw
963 .parent_execution
964 .as_ref()
965 .map(|e| e.workflow_id.as_str())
966 }
967
968 pub fn parent_run_id(&self) -> Option<&str> {
970 self.raw
971 .parent_execution
972 .as_ref()
973 .map(|e| e.run_id.as_str())
974 }
975
976 pub fn search_attributes(&self) -> Option<&SearchAttributes> {
978 self.raw.search_attributes.as_ref()
979 }
980
981 pub fn raw(&self) -> &workflow::WorkflowExecutionInfo {
983 &self.raw
984 }
985
986 pub fn into_raw(self) -> workflow::WorkflowExecutionInfo {
988 self.raw
989 }
990}
991
992impl From<workflow::WorkflowExecutionInfo> for WorkflowExecution {
993 fn from(raw: workflow::WorkflowExecutionInfo) -> Self {
994 Self::new(raw)
995 }
996}
997
/// Stream of workflow executions that match a list query.
///
/// Pages are fetched lazily as the stream is polled. RPC failures are yielded as `Err` items.
pub struct ListWorkflowsStream {
    /// The boxed paging stream that does the actual work.
    inner: Pin<Box<dyn Stream<Item = Result<WorkflowExecution, ClientError>> + Send>>,
}
1003
1004impl ListWorkflowsStream {
1005 fn new(
1006 inner: Pin<Box<dyn Stream<Item = Result<WorkflowExecution, ClientError>> + Send>>,
1007 ) -> Self {
1008 Self { inner }
1009 }
1010}
1011
1012impl Stream for ListWorkflowsStream {
1013 type Item = Result<WorkflowExecution, ClientError>;
1014
1015 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1016 self.inner.as_mut().poll_next(cx)
1017 }
1018}
1019
/// Result of a count query.
///
/// Holds the total number of matching executions, plus any aggregation groups the server
/// returned.
#[derive(Debug, Clone)]
pub struct WorkflowExecutionCount {
    count: usize,
    groups: Vec<WorkflowCountAggregationGroup>,
}
1029
1030impl WorkflowExecutionCount {
1031 pub(crate) fn from_response(resp: CountWorkflowExecutionsResponse) -> Self {
1032 Self {
1033 count: resp.count as usize,
1034 groups: resp
1035 .groups
1036 .into_iter()
1037 .map(WorkflowCountAggregationGroup::from_proto)
1038 .collect(),
1039 }
1040 }
1041
1042 pub fn count(&self) -> usize {
1045 self.count
1046 }
1047
1048 pub fn groups(&self) -> &[WorkflowCountAggregationGroup] {
1050 &self.groups
1051 }
1052}
1053
/// One aggregation group from a count query: the group's key values and its count.
#[derive(Debug, Clone)]
pub struct WorkflowCountAggregationGroup {
    group_values: Vec<Payload>,
    count: usize,
}
1060
1061impl WorkflowCountAggregationGroup {
1062 fn from_proto(proto: count_workflow_executions_response::AggregationGroup) -> Self {
1063 Self {
1064 group_values: proto.group_values,
1065 count: proto.count as usize,
1066 }
1067 }
1068
1069 pub fn group_values(&self) -> &[Payload] {
1071 &self.group_values
1072 }
1073
1074 pub fn count(&self) -> usize {
1076 self.count
1077 }
1078}
1079
1080impl<T> WorkflowClientTrait for T
1081where
1082 T: WorkflowService + NamespacedClient + Clone + Send + Sync + 'static,
1083{
1084 async fn start_workflow<W>(
1085 &self,
1086 workflow: W,
1087 input: W::Input,
1088 options: WorkflowStartOptions,
1089 ) -> Result<WorkflowHandle<Self, W>, WorkflowStartError>
1090 where
1091 W: HasWorkflowDefinition,
1092 W::Input: Send,
1093 {
1094 let payloads = self
1095 .data_converter()
1096 .to_payloads(&SerializationContextData::Workflow, &input)
1097 .await?;
1098 let namespace = self.namespace();
1099 let workflow_id = options.workflow_id.clone();
1100 let task_queue_name = options.task_queue.clone();
1101
1102 let user_metadata = if options.static_summary.is_some() || options.static_details.is_some()
1103 {
1104 let payload_converter = PayloadConverter::default();
1105 let context = SerializationContext {
1106 data: &SerializationContextData::Workflow,
1107 converter: &payload_converter,
1108 };
1109 Some(UserMetadata {
1110 summary: options.static_summary.map(|s| {
1111 payload_converter
1112 .to_payload(&context, &s)
1113 .expect("String-to-JSON payload serialization is infallible")
1114 }),
1115 details: options.static_details.map(|s| {
1116 payload_converter
1117 .to_payload(&context, &s)
1118 .expect("String-to-JSON payload serialization is infallible")
1119 }),
1120 })
1121 } else {
1122 None
1123 };
1124
1125 let run_id = if let Some(start_signal) = options.start_signal {
1126 let res = WorkflowService::signal_with_start_workflow_execution(
1128 &mut self.clone(),
1129 SignalWithStartWorkflowExecutionRequest {
1130 namespace: namespace.clone(),
1131 workflow_id: workflow_id.clone(),
1132 workflow_type: Some(WorkflowType {
1133 name: workflow.name().to_string(),
1134 }),
1135 task_queue: Some(TaskQueue {
1136 name: task_queue_name,
1137 kind: TaskQueueKind::Normal as i32,
1138 normal_name: "".to_string(),
1139 }),
1140 input: payloads.into_payloads(),
1141 signal_name: start_signal.signal_name,
1142 signal_input: start_signal.input,
1143 identity: self.identity(),
1144 request_id: Uuid::new_v4().to_string(),
1145 workflow_id_reuse_policy: options.id_reuse_policy as i32,
1146 workflow_id_conflict_policy: options.id_conflict_policy as i32,
1147 workflow_execution_timeout: options
1148 .execution_timeout
1149 .and_then(|d| d.try_into().ok()),
1150 workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()),
1151 workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()),
1152 search_attributes: options.search_attributes.map(|d| d.into()),
1153 cron_schedule: options.cron_schedule.unwrap_or_default(),
1154 header: options.header.or(start_signal.header),
1155 user_metadata,
1156 ..Default::default()
1157 }
1158 .into_request(),
1159 )
1160 .await?
1161 .into_inner();
1162 res.run_id
1163 } else {
1164 let res = self
1166 .clone()
1167 .start_workflow_execution(
1168 StartWorkflowExecutionRequest {
1169 namespace: namespace.clone(),
1170 input: payloads.into_payloads(),
1171 workflow_id: workflow_id.clone(),
1172 workflow_type: Some(WorkflowType {
1173 name: workflow.name().to_string(),
1174 }),
1175 task_queue: Some(TaskQueue {
1176 name: task_queue_name,
1177 kind: TaskQueueKind::Unspecified as i32,
1178 normal_name: "".to_string(),
1179 }),
1180 request_id: Uuid::new_v4().to_string(),
1181 workflow_id_reuse_policy: options.id_reuse_policy as i32,
1182 workflow_id_conflict_policy: options.id_conflict_policy as i32,
1183 workflow_execution_timeout: options
1184 .execution_timeout
1185 .and_then(|d| d.try_into().ok()),
1186 workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()),
1187 workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()),
1188 search_attributes: options.search_attributes.map(|d| d.into()),
1189 cron_schedule: options.cron_schedule.unwrap_or_default(),
1190 request_eager_execution: options.enable_eager_workflow_start,
1191 retry_policy: options.retry_policy,
1192 links: options.links,
1193 completion_callbacks: options.completion_callbacks,
1194 priority: Some(options.priority.into()),
1195 header: options.header,
1196 user_metadata,
1197 ..Default::default()
1198 }
1199 .into_request(),
1200 )
1201 .await
1202 .map_err(|status| {
1203 if status.code() == Code::AlreadyExists {
1204 let run_id =
1205 decode_status_detail::<WorkflowExecutionAlreadyStartedFailure>(
1206 status.details(),
1207 )
1208 .map(|f| f.run_id);
1209 WorkflowStartError::AlreadyStarted {
1210 run_id,
1211 source: status,
1212 }
1213 } else {
1214 WorkflowStartError::Rpc(status)
1215 }
1216 })?
1217 .into_inner();
1218 res.run_id
1219 };
1220
1221 Ok(WorkflowHandle::new(
1222 self.clone(),
1223 WorkflowExecutionInfo {
1224 namespace,
1225 workflow_id,
1226 run_id: Some(run_id.clone()),
1227 first_execution_run_id: Some(run_id),
1228 },
1229 ))
1230 }
1231
1232 fn get_workflow_handle<W: HasWorkflowDefinition>(
1233 &self,
1234 workflow_id: impl Into<String>,
1235 ) -> WorkflowHandle<Self, W>
1236 where
1237 Self: Sized,
1238 {
1239 WorkflowHandle::new(
1240 self.clone(),
1241 WorkflowExecutionInfo {
1242 namespace: self.namespace(),
1243 workflow_id: workflow_id.into(),
1244 run_id: None,
1245 first_execution_run_id: None,
1246 },
1247 )
1248 }
1249
1250 fn list_workflows(
1251 &self,
1252 query: impl Into<String>,
1253 opts: WorkflowListOptions,
1254 ) -> ListWorkflowsStream {
1255 let client = self.clone();
1256 let namespace = self.namespace();
1257 let query = query.into();
1258 let limit = opts.limit;
1259
1260 let initial_state = (Vec::new(), VecDeque::new(), 0, false);
1262
1263 let stream = stream::unfold(
1264 initial_state,
1265 move |(next_page_token, mut buffer, mut yielded, exhausted)| {
1266 let mut client = client.clone();
1267 let namespace = namespace.clone();
1268 let query = query.clone();
1269
1270 async move {
1271 if let Some(l) = limit
1272 && yielded >= l
1273 {
1274 return None;
1275 }
1276
1277 if let Some(exec) = buffer.pop_front() {
1278 yielded += 1;
1279 return Some((Ok(exec), (next_page_token, buffer, yielded, exhausted)));
1280 }
1281
1282 if exhausted {
1283 return None;
1284 }
1285
1286 let response = WorkflowService::list_workflow_executions(
1287 &mut client,
1288 ListWorkflowExecutionsRequest {
1289 namespace,
1290 page_size: 0, next_page_token: next_page_token.clone(),
1292 query,
1293 }
1294 .into_request(),
1295 )
1296 .await;
1297
1298 match response {
1299 Ok(resp) => {
1300 let resp = resp.into_inner();
1301 let new_exhausted = resp.next_page_token.is_empty();
1302 let new_token = resp.next_page_token;
1303
1304 buffer = resp
1305 .executions
1306 .into_iter()
1307 .map(WorkflowExecution::from)
1308 .collect();
1309
1310 if let Some(exec) = buffer.pop_front() {
1311 yielded += 1;
1312 Some((Ok(exec), (new_token, buffer, yielded, new_exhausted)))
1313 } else {
1314 None
1315 }
1316 }
1317 Err(e) => Some((Err(e.into()), (next_page_token, buffer, yielded, true))),
1318 }
1319 }
1320 },
1321 );
1322
1323 ListWorkflowsStream::new(Box::pin(stream))
1324 }
1325
1326 async fn count_workflows(
1327 &self,
1328 query: impl Into<String>,
1329 _opts: WorkflowCountOptions,
1330 ) -> Result<WorkflowExecutionCount, ClientError> {
1331 let resp = WorkflowService::count_workflow_executions(
1332 &mut self.clone(),
1333 CountWorkflowExecutionsRequest {
1334 namespace: self.namespace(),
1335 query: query.into(),
1336 }
1337 .into_request(),
1338 )
1339 .await?
1340 .into_inner();
1341
1342 Ok(WorkflowExecutionCount::from_response(resp))
1343 }
1344
1345 fn get_async_activity_handle(&self, identifier: ActivityIdentifier) -> AsyncActivityHandle<Self>
1346 where
1347 Self: Sized,
1348 {
1349 AsyncActivityHandle::new(self.clone(), identifier)
1350 }
1351}
1352
/// Logs an error via `tracing` and, when debug assertions are enabled, panics with the same
/// message.
///
/// Takes `format!`-style arguments. The message is formatted exactly once, so side effects in
/// the arguments are not repeated. The expansion is a single block, so it does not leak any
/// `use` into the caller's scope and it also works in expression position.
macro_rules! dbg_panic {
    ($($arg:tt)*) => {{
        let message = ::std::format!($($arg)*);
        ::tracing::error!("{}", message);
        ::std::debug_assert!(false, "{}", message);
    }};
}
pub(crate) use dbg_panic;
1361
#[cfg(test)]
mod tests {
    use super::*;
    use tonic::metadata::Ascii;
    use url::Url;

    /// The interceptor copies configured headers onto outgoing requests: ASCII headers, binary
    /// headers, and the API key as a bearer `authorization` value. It never overwrites metadata
    /// the request already carries, and it sees later edits made through the shared header store.
    #[test]
    fn applies_headers() {
        let headers = Arc::new(RwLock::new(ClientHeaders {
            user_headers: HashMap::new(),
            user_binary_headers: HashMap::new(),
            api_key: Some("my-api-key".to_owned()),
        }));
        headers.write().user_headers.insert(
            "my-meta-key".parse().unwrap(),
            "my-meta-val".parse().unwrap(),
        );
        headers.write().user_binary_headers.insert(
            "my-bin-meta-key-bin".parse().unwrap(),
            vec![1, 2, 3].try_into().unwrap(),
        );
        let mut interceptor = ServiceCallInterceptor {
            client_name: "cute-kitty".to_string(),
            client_version: "0.1.0".to_string(),
            headers: Arc::clone(&headers),
        };

        // Configured headers are attached to a bare request.
        let req = interceptor.call(tonic::Request::new(())).unwrap();
        assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val");
        assert_eq!(
            req.metadata().get("authorization").unwrap(),
            "Bearer my-api-key"
        );
        assert_eq!(
            req.metadata().get_bin("my-bin-meta-key-bin").unwrap(),
            vec![1, 2, 3].as_slice()
        );

        // Metadata already present on the request takes precedence over configured headers.
        let mut req = tonic::Request::new(());
        req.metadata_mut()
            .insert("my-meta-key", "my-meta-val2".parse().unwrap());
        req.metadata_mut()
            .insert("authorization", "my-api-key2".parse().unwrap());
        req.metadata_mut()
            .insert_bin("my-bin-meta-key-bin", vec![4, 5, 6].try_into().unwrap());
        let req = interceptor.call(req).unwrap();
        assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val2");
        assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key2");
        assert_eq!(
            req.metadata().get_bin("my-bin-meta-key-bin").unwrap(),
            vec![4, 5, 6].as_slice()
        );

        // The existing interceptor picks up edits to the shared store. An explicit
        // `authorization` user header is sent instead of the bearer API key.
        headers.write().user_headers.insert(
            "authorization".parse().unwrap(),
            "my-api-key3".parse().unwrap(),
        );
        let req = interceptor.call(tonic::Request::new(())).unwrap();
        assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val");
        assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key3");

        // With every configured header cleared, nothing is added.
        {
            let mut configured = headers.write();
            configured.user_headers.clear();
            configured.user_binary_headers.clear();
            configured.api_key = None;
        }
        let req = interceptor.call(tonic::Request::new(())).unwrap();
        assert!(!req.metadata().contains_key("my-meta-key"));
        assert!(!req.metadata().contains_key("authorization"));
        assert!(!req.metadata().contains_key("my-bin-meta-key-bin"));

        // A caller-supplied `grpc-timeout` passes through the interceptor untouched.
        let mut req = tonic::Request::new(());
        req.metadata_mut()
            .insert("grpc-timeout", "1S".parse().unwrap());
        let req = interceptor.call(req).unwrap();
        assert_eq!(
            req.metadata().get("grpc-timeout").unwrap(),
            "1S".parse::<MetadataValue<Ascii>>().unwrap()
        );
    }

    /// A key with the binary `-bin` suffix is rejected as an ASCII header key.
    #[test]
    fn invalid_ascii_header_key() {
        let invalid_headers =
            HashMap::from([("x-binary-key-bin".to_owned(), "value".to_owned())]);

        let Err(err) = parse_ascii_headers(invalid_headers) else {
            panic!("a `-bin` key must be rejected as an ASCII header key");
        };
        assert_eq!(
            err.to_string(),
            "Invalid ASCII header key 'x-binary-key-bin': invalid gRPC metadata key name"
        );
    }

    /// An ASCII header value containing a NUL byte is rejected.
    #[test]
    fn invalid_ascii_header_value() {
        let invalid_headers = HashMap::from([("x-ascii-key".to_owned(), "\x00value".to_owned())]);

        let Err(err) = parse_ascii_headers(invalid_headers) else {
            panic!("a value containing NUL must be rejected");
        };
        assert_eq!(
            err.to_string(),
            "Invalid ASCII header value for key 'x-ascii-key': failed to parse metadata value"
        );
    }

    /// A key without the `-bin` suffix is rejected as a binary header key.
    #[test]
    fn invalid_binary_header_key() {
        let invalid_headers = HashMap::from([("x-ascii-key".to_owned(), vec![1, 2, 3])]);

        let Err(err) = parse_binary_headers(invalid_headers) else {
            panic!("a key without `-bin` must be rejected as a binary header key");
        };
        assert_eq!(
            err.to_string(),
            "Invalid binary header key 'x-ascii-key': invalid gRPC metadata key name"
        );
    }

    /// Keep-alive is on by default with `ClientKeepAliveOptions::default()` values, and
    /// `.keep_alive(None)` turns it off.
    #[test]
    fn keep_alive_defaults() {
        let opts = ConnectionOptions::new(Url::parse("https://smolkitty").unwrap())
            .identity("enchicat".to_string())
            .client_name("cute-kitty".to_string())
            .client_version("0.1.0".to_string())
            .build();
        let keep_alive = opts
            .keep_alive
            .as_ref()
            .expect("keep-alive should be enabled by default");
        let defaults = ClientKeepAliveOptions::default();
        assert_eq!(keep_alive.interval, defaults.interval);
        assert_eq!(keep_alive.timeout, defaults.timeout);

        let opts = ConnectionOptions::new(Url::parse("https://smolkitty").unwrap())
            .identity("enchicat".to_string())
            .client_name("cute-kitty".to_string())
            .client_version("0.1.0".to_string())
            .keep_alive(None)
            .build();
        assert!(opts.keep_alive.is_none());
    }

    mod tls_custom_verifier_tests {
        use super::*;
        use tokio_rustls::rustls::{
            DigitallySignedStruct, Error as RustlsError, SignatureScheme,
            client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
            pki_types::{CertificateDer, ServerName, UnixTime},
        };

        /// Test-only verifier that accepts every certificate and signature.
        #[derive(Debug)]
        struct MockVerifier;

        impl ServerCertVerifier for MockVerifier {
            fn verify_server_cert(
                &self,
                _end_entity: &CertificateDer<'_>,
                _intermediates: &[CertificateDer<'_>],
                _server_name: &ServerName<'_>,
                _ocsp_response: &[u8],
                _now: UnixTime,
            ) -> Result<ServerCertVerified, RustlsError> {
                Ok(ServerCertVerified::assertion())
            }

            fn verify_tls12_signature(
                &self,
                _message: &[u8],
                _cert: &CertificateDer<'_>,
                _dss: &DigitallySignedStruct,
            ) -> Result<HandshakeSignatureValid, RustlsError> {
                Ok(HandshakeSignatureValid::assertion())
            }

            fn verify_tls13_signature(
                &self,
                _message: &[u8],
                _cert: &CertificateDer<'_>,
                _dss: &DigitallySignedStruct,
            ) -> Result<HandshakeSignatureValid, RustlsError> {
                Ok(HandshakeSignatureValid::assertion())
            }

            fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
                vec![
                    SignatureScheme::ECDSA_NISTP256_SHA256,
                    SignatureScheme::RSA_PSS_SHA256,
                ]
            }
        }

        /// A custom verifier on its own is an accepted TLS configuration.
        #[tokio::test]
        async fn add_tls_to_channel_with_custom_verifier() {
            let tls_opts = TlsOptions {
                server_cert_verifier: Some(Arc::new(MockVerifier)),
                domain: Some("test.temporal.io".to_string()),
                ..Default::default()
            };
            let endpoint = tonic::transport::Channel::from_static("https://test.temporal.io:7233");
            let result = add_tls_to_channel(Some(&tls_opts), endpoint).await;
            assert!(
                result.is_ok(),
                "add_tls_to_channel should succeed with a custom verifier: {:?}",
                result.err()
            );
        }

        /// Supplying both a root CA certificate and a custom verifier is a config error.
        #[tokio::test]
        async fn add_tls_to_channel_with_verifier_and_ca_cert_fails() {
            let tls_opts = TlsOptions {
                server_root_ca_cert: Some(b"some-ca-cert-bytes".to_vec()),
                server_cert_verifier: Some(Arc::new(MockVerifier)),
                domain: Some("test.temporal.io".to_string()),
                ..Default::default()
            };
            let endpoint = tonic::transport::Channel::from_static("https://test.temporal.io:7233");
            let result = add_tls_to_channel(Some(&tls_opts), endpoint).await;
            assert!(
                matches!(result, Err(ClientConnectError::InvalidConfig(_))),
                "add_tls_to_channel should fail with InvalidConfig when both CA cert and verifier are set: {:?}",
                result
            );
        }

        /// Without a custom verifier, TLS setup still succeeds.
        #[tokio::test]
        async fn add_tls_to_channel_without_verifier_still_works() {
            let tls_opts = TlsOptions {
                domain: Some("test.temporal.io".to_string()),
                ..Default::default()
            };
            let endpoint = tonic::transport::Channel::from_static("https://test.temporal.io:7233");
            let result = add_tls_to_channel(Some(&tls_opts), endpoint).await;
            assert!(
                result.is_ok(),
                "add_tls_to_channel should succeed without a verifier (native roots): {:?}",
                result.err()
            );
        }
    }

    mod list_workflows_tests {
        use super::*;
        use futures_util::{FutureExt, StreamExt};
        use std::sync::atomic::{AtomicUsize, Ordering};
        use temporalio_common::protos::temporal::api::common::v1::WorkflowExecution as ProtoWorkflowExecution;
        use tonic::{Request, Response};

        /// Fake `ListWorkflowExecutions` backend.
        ///
        /// It serves `total_workflows` executions in pages of `page_size`. The page token is
        /// the decimal offset of the next page, and an empty token ends the listing.
        /// `call_count` records how many requests were made.
        #[derive(Clone)]
        struct MockListWorkflowsClient {
            call_count: Arc<AtomicUsize>,
            page_size: usize,
            total_workflows: usize,
        }

        impl NamespacedClient for MockListWorkflowsClient {
            fn namespace(&self) -> String {
                "test-namespace".to_string()
            }
            fn identity(&self) -> String {
                "test-identity".to_string()
            }
        }

        impl WorkflowService for MockListWorkflowsClient {
            fn list_workflow_executions(
                &mut self,
                request: Request<ListWorkflowExecutionsRequest>,
            ) -> futures_util::future::BoxFuture<
                '_,
                Result<Response<ListWorkflowExecutionsResponse>, tonic::Status>,
            > {
                self.call_count.fetch_add(1, Ordering::SeqCst);
                let req = request.into_inner();

                // An empty token requests the first page; otherwise it encodes the offset.
                let offset: usize = if req.next_page_token.is_empty() {
                    0
                } else {
                    String::from_utf8(req.next_page_token)
                        .unwrap()
                        .parse()
                        .unwrap()
                };

                let remaining = self.total_workflows.saturating_sub(offset);
                let count = remaining.min(self.page_size);
                let new_offset = offset + count;

                let executions: Vec<_> = (offset..new_offset)
                    .map(|i| workflow::WorkflowExecutionInfo {
                        execution: Some(ProtoWorkflowExecution {
                            workflow_id: format!("wf-{i}"),
                            run_id: format!("run-{i}"),
                        }),
                        r#type: Some(WorkflowType {
                            name: "TestWorkflow".to_string(),
                        }),
                        task_queue: "test-queue".to_string(),
                        ..Default::default()
                    })
                    .collect();

                let next_page_token = if new_offset < self.total_workflows {
                    new_offset.to_string().into_bytes()
                } else {
                    vec![]
                };

                async move {
                    Ok(Response::new(ListWorkflowExecutionsResponse {
                        executions,
                        next_page_token,
                    }))
                }
                .boxed()
            }
        }

        /// Every execution is yielded in order across page boundaries.
        #[tokio::test]
        async fn list_workflows_paginates_through_all_results() {
            let call_count = Arc::new(AtomicUsize::new(0));
            let client = MockListWorkflowsClient {
                call_count: call_count.clone(),
                page_size: 3,
                total_workflows: 10,
            };

            let stream = client.list_workflows("", WorkflowListOptions::default());
            let results: Vec<_> = stream.collect().await;

            assert_eq!(results.len(), 10);
            for (i, result) in results.iter().enumerate() {
                let wf = result.as_ref().unwrap();
                assert_eq!(wf.id(), format!("wf-{i}"));
                assert_eq!(wf.run_id(), format!("run-{i}"));
            }
            // Pages of 3 + 3 + 3 + 1.
            assert_eq!(call_count.load(Ordering::SeqCst), 4);
        }

        /// The stream stops at the limit without requesting pages it does not need.
        #[tokio::test]
        async fn list_workflows_respects_limit() {
            let call_count = Arc::new(AtomicUsize::new(0));
            let client = MockListWorkflowsClient {
                call_count: call_count.clone(),
                page_size: 3,
                total_workflows: 10,
            };

            let opts = WorkflowListOptions::builder().limit(5).build();
            let stream = client.list_workflows("", opts);
            let results: Vec<_> = stream.collect().await;

            assert_eq!(results.len(), 5);
            for (i, result) in results.iter().enumerate() {
                let wf = result.as_ref().unwrap();
                assert_eq!(wf.id(), format!("wf-{i}"));
            }
            // Five items fit within two pages of three.
            assert_eq!(call_count.load(Ordering::SeqCst), 2);
        }

        /// A limit smaller than one page needs exactly one request.
        #[tokio::test]
        async fn list_workflows_limit_less_than_page_size() {
            let call_count = Arc::new(AtomicUsize::new(0));
            let client = MockListWorkflowsClient {
                call_count: call_count.clone(),
                page_size: 10,
                total_workflows: 100,
            };

            let opts = WorkflowListOptions::builder().limit(3).build();
            let stream = client.list_workflows("", opts);
            let results: Vec<_> = stream.collect().await;

            assert_eq!(results.len(), 3);
            assert_eq!(call_count.load(Ordering::SeqCst), 1);
        }

        /// An empty final page yields an empty stream after a single request.
        #[tokio::test]
        async fn list_workflows_empty_results() {
            let call_count = Arc::new(AtomicUsize::new(0));
            let client = MockListWorkflowsClient {
                call_count: call_count.clone(),
                page_size: 10,
                total_workflows: 0,
            };

            let stream = client.list_workflows("", WorkflowListOptions::default());
            let results: Vec<_> = stream.collect().await;

            assert_eq!(results.len(), 0);
            assert_eq!(call_count.load(Ordering::SeqCst), 1);
        }
    }
}