Skip to main content

temporalio_client/
lib.rs

1#![warn(missing_docs)] // error if there are missing docs
2
3//! This crate contains client implementations that can be used to contact the Temporal service.
4//!
5//! It implements auto-retry behavior and metrics collection.
6
7#[macro_use]
8extern crate tracing;
9
10mod async_activity_handle;
11pub mod callback_based;
12mod dns;
13/// Configuration loading from environment variables and TOML files.
14#[cfg(feature = "envconfig")]
15pub mod envconfig;
16pub mod errors;
17pub mod grpc;
18mod metrics;
19mod options_structs;
20/// Visible only for tests
21#[doc(hidden)]
22pub mod proxy;
23mod replaceable;
24pub mod request_extensions;
25mod retry;
26/// Schedule operations: create, describe, update, pause, trigger, backfill, list, and delete.
27pub mod schedules;
28pub mod worker;
29mod workflow_handle;
30
31pub use crate::{
32    proxy::HttpConnectProxyOptions,
33    retry::{CallType, RETRYABLE_ERROR_CODES},
34};
35pub use async_activity_handle::{
36    ActivityHeartbeatResponse, ActivityIdentifier, AsyncActivityHandle,
37};
38
39pub use metrics::{LONG_REQUEST_LATENCY_HISTOGRAM_NAME, REQUEST_LATENCY_HISTOGRAM_NAME};
40pub use options_structs::*;
41pub use replaceable::SharedReplaceableClient;
42pub use retry::RetryOptions;
43pub use tonic;
44pub use workflow_handle::{
45    UntypedQuery, UntypedSignal, UntypedUpdate, UntypedWorkflow, UntypedWorkflowHandle,
46    WorkflowExecutionDescription, WorkflowExecutionInfo, WorkflowExecutionResult, WorkflowHandle,
47    WorkflowHistory, WorkflowUpdateHandle,
48};
49
50use crate::{
51    grpc::{
52        AttachMetricLabels, CloudService, HealthService, OperatorService, TestService,
53        WorkflowService,
54    },
55    metrics::{ChannelOrGrpcOverride, GrpcMetricSvc, MetricsContext},
56    request_extensions::RequestExt,
57    worker::ClientWorkerSet,
58};
59use errors::*;
60use futures_util::{stream, stream::Stream};
61use http::Uri;
62use parking_lot::RwLock;
63use std::{
64    collections::{HashMap, VecDeque},
65    fmt::Debug,
66    pin::Pin,
67    str::FromStr,
68    sync::{Arc, OnceLock},
69    task::{Context, Poll},
70    time::{Duration, SystemTime},
71};
72use temporalio_common::{
73    HasWorkflowDefinition,
74    data_converters::{
75        DataConverter, GenericPayloadConverter, PayloadConverter, SerializationContext,
76        SerializationContextData,
77    },
78    protos::{
79        coresdk::IntoPayloadsExt,
80        grpc::health::v1::health_client::HealthClient,
81        proto_ts_to_system_time,
82        temporal::api::{
83            cloud::cloudservice::v1::cloud_service_client::CloudServiceClient,
84            common::v1::{Memo, Payload, SearchAttributes, WorkflowType},
85            enums::v1::{TaskQueueKind, WorkflowExecutionStatus},
86            errordetails::v1::WorkflowExecutionAlreadyStartedFailure,
87            operatorservice::v1::operator_service_client::OperatorServiceClient,
88            sdk::v1::UserMetadata,
89            taskqueue::v1::TaskQueue,
90            testservice::v1::test_service_client::TestServiceClient,
91            workflow::v1 as workflow,
92            workflowservice::v1::{
93                count_workflow_executions_response, workflow_service_client::WorkflowServiceClient,
94                *,
95            },
96        },
97        utilities::decode_status_detail,
98    },
99};
100use tonic::{
101    Code, IntoRequest,
102    body::Body,
103    client::GrpcService,
104    codegen::InterceptedService,
105    metadata::{
106        AsciiMetadataKey, AsciiMetadataValue, BinaryMetadataKey, BinaryMetadataValue, MetadataMap,
107        MetadataValue,
108    },
109    service::Interceptor,
110    transport::{Certificate, Channel, Endpoint, Identity},
111};
112use tower::ServiceBuilder;
113use uuid::Uuid;
114
/// gRPC metadata key under which the client (SDK) name is sent on every outgoing call.
static CLIENT_NAME_HEADER_KEY: &str = "client-name";
/// gRPC metadata key under which the client (SDK) version is sent on every outgoing call.
static CLIENT_VERSION_HEADER_KEY: &str = "client-version";
/// gRPC metadata key for the namespace header. Not referenced in this part of the file —
/// presumably attached to requests elsewhere in the crate.
static TEMPORAL_NAMESPACE_HEADER_KEY: &str = "temporal-namespace";

#[doc(hidden)]
/// Key used to communicate when a gRPC message is too large.
pub static MESSAGE_TOO_LARGE_KEY: &str = "message-too-large";
#[doc(hidden)]
/// Key used to indicate an error was returned by the retryer because of the short-circuit
/// predicate.
pub static ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT: &str = "short-circuit";

/// The server times out polls after 60 seconds. Set our timeout to be slightly beyond that.
const LONG_POLL_TIMEOUT: Duration = Duration::from_secs(70);
/// Default timeout the service-call interceptor applies to outgoing requests via
/// `set_default_timeout`.
const OTHER_CALL_TIMEOUT: Duration = Duration::from_secs(30);
129const VERSION: &str = env!("CARGO_PKG_VERSION");
130
131/// A connection to the Temporal service.
132///
133/// Cloning a connection is cheap (single Arc increment). The underlying connection is shared
134/// between clones.
135#[derive(Clone)]
136pub struct Connection {
137    inner: Arc<ConnectionInner>,
138}
139
140#[derive(Clone)]
141struct ConnectionInner {
142    service: TemporalServiceClient,
143    retry_options: RetryOptions,
144    identity: String,
145    headers: Arc<RwLock<ClientHeaders>>,
146    client_name: String,
147    client_version: String,
148    /// Capabilities as read from the `get_system_info` RPC call made on client connection
149    capabilities: Option<get_system_info_response::Capabilities>,
150    workers: Arc<ClientWorkerSet>,
151    _dns_task: Option<Arc<dns::DnsReresolutionHandle>>,
152}
153
impl Connection {
    /// Connect to a Temporal service.
    ///
    /// The transport is picked in this order:
    /// 1. `options.service_override`, if set, is used directly (wrapped for metrics) and no
    ///    channel is opened here.
    /// 2. If `dns::validate_and_get_dns_lb` yields DNS load-balancing options, a balanced
    ///    channel is created and a DNS re-resolution task is spawned.
    /// 3. Otherwise a single channel to `options.target` is built, with TLS, keep-alive, origin
    ///    override and (optionally) an HTTP CONNECT proxy applied.
    ///
    /// The transport is wrapped in [`GrpcMetricSvc`] and then in a [`ServiceCallInterceptor`].
    /// The interceptor attaches the client name/version headers, user headers and the API key.
    ///
    /// Unless `options.skip_get_system_info` is set, a `GetSystemInfo` call is made to learn the
    /// server's capabilities. A server answering `Unimplemented` is tolerated.
    ///
    /// # Errors
    ///
    /// Returns [`ClientConnectError`] in any of these cases:
    /// - DNS load-balancing validation fails.
    /// - The target URI or TLS configuration is rejected.
    /// - The channel (or proxy) connection fails.
    /// - A configured header is invalid.
    /// - `GetSystemInfo` fails with any code other than `Unimplemented`.
    pub async fn connect(options: ConnectionOptions) -> Result<Self, ClientConnectError> {
        // Validate DNS load-balancing settings up front (errors surface via `?`). The options
        // are cloned out so that `options` can be partially moved below.
        let dns_lb_opts = dns::validate_and_get_dns_lb(&options)?.cloned();
        let (service, dns_task) = if let Some(service_override) = options.service_override {
            // Caller-supplied gRPC service: used as-is, with no channel and no DNS task.
            (
                GrpcMetricSvc {
                    inner: ChannelOrGrpcOverride::GrpcOverride(service_override),
                    metrics: options.metrics_meter.clone().map(MetricsContext::new),
                    disable_errcode_label: options.disable_error_code_metric_tags,
                },
                None,
            )
        } else if let Some(dns_opts) = &dns_lb_opts {
            // DNS load balancing: a balanced channel plus a background re-resolution task fed
            // through `sender`.
            // NOTE(review): this branch never consults `options.http_connect_proxy`; presumably
            // `validate_and_get_dns_lb` rejects that combination — confirm.
            let (channel, sender) = dns::create_balanced_channel(&options).await?;
            let handle = dns::spawn_dns_reresolution(
                sender,
                options.target.clone(),
                options.tls_options.clone(),
                options.keep_alive.clone(),
                options.override_origin.clone(),
                dns_opts.resolution_interval,
            );
            (
                // This `move` closure captures only `metrics_meter` and
                // `disable_error_code_metric_tags` (disjoint closure capture). The other fields
                // of `options` stay usable afterwards.
                ServiceBuilder::new()
                    .layer_fn(move |channel| GrpcMetricSvc {
                        inner: ChannelOrGrpcOverride::Channel(channel),
                        metrics: options.metrics_meter.clone().map(MetricsContext::new),
                        disable_errcode_label: options.disable_error_code_metric_tags,
                    })
                    .service(channel),
                Some(handle),
            )
        } else {
            // Single-endpoint channel.
            let channel = Channel::from_shared(options.target.to_string())?;
            let channel = add_tls_to_channel(options.tls_options.as_ref(), channel).await?;
            // Optional HTTP/2 keep-alive pings, also sent while the connection is idle.
            let channel = if let Some(keep_alive) = options.keep_alive.as_ref() {
                channel
                    .keep_alive_while_idle(true)
                    .http2_keep_alive_interval(keep_alive.interval)
                    .keep_alive_timeout(keep_alive.timeout)
            } else {
                channel
            };
            // Applied after TLS, so an explicit `override_origin` replaces any origin that
            // `add_tls_to_channel` derived from the TLS domain.
            let channel = if let Some(origin) = options.override_origin.clone() {
                channel.origin(origin)
            } else {
                channel
            };
            // If there is a proxy, we have to connect that way
            let channel = if let Some(proxy) = options.http_connect_proxy.as_ref() {
                proxy.connect_endpoint(&channel).await?
            } else {
                channel.connect().await?
            };
            (
                ServiceBuilder::new()
                    .layer_fn(move |channel| GrpcMetricSvc {
                        inner: ChannelOrGrpcOverride::Channel(channel),
                        metrics: options.metrics_meter.clone().map(MetricsContext::new),
                        disable_errcode_label: options.disable_error_code_metric_tags,
                    })
                    .service(channel),
                None,
            )
        };

        // Headers sit behind a shared lock. The interceptor reads them on every call, and the
        // `set_*` methods below replace them in place.
        let headers = Arc::new(RwLock::new(ClientHeaders {
            user_headers: parse_ascii_headers(options.headers.clone().unwrap_or_default())?,
            user_binary_headers: parse_binary_headers(
                options.binary_headers.clone().unwrap_or_default(),
            )?,
            api_key: options.api_key.clone(),
        }));
        let interceptor = ServiceCallInterceptor {
            client_name: options.client_name.clone(),
            client_version: options.client_version.clone(),
            headers: headers.clone(),
        };
        let svc = InterceptedService::new(service, interceptor);
        let mut svc_client = TemporalServiceClient::new(svc);

        // Probe server capabilities unless disabled. Servers that don't implement
        // `GetSystemInfo` are tolerated; any other failure aborts the connect.
        let capabilities = if !options.skip_get_system_info {
            match svc_client
                .get_system_info(GetSystemInfoRequest::default().into_request())
                .await
            {
                Ok(sysinfo) => sysinfo.into_inner().capabilities,
                Err(status) => match status.code() {
                    Code::Unimplemented => None,
                    _ => return Err(ClientConnectError::SystemInfoCallError(status)),
                },
            }
        } else {
            None
        };
        Ok(Self {
            inner: Arc::new(ConnectionInner {
                service: svc_client,
                retry_options: options.retry_options,
                identity: options.identity,
                headers,
                client_name: options.client_name,
                client_version: options.client_version,
                capabilities,
                workers: Arc::new(ClientWorkerSet::new()),
                _dns_task: dns_task,
            }),
        })
    }

    /// Set API key, overwriting any previous one. Passing `None` removes it.
    ///
    /// Headers are shared through an `Arc`, so every clone of this connection sees the change
    /// (including clones detached by `retry_options_mut`/`identity_mut`).
    pub fn set_api_key(&self, api_key: Option<String>) {
        self.inner.headers.write().api_key = api_key;
    }

    /// Set HTTP request headers overwriting previous headers.
    ///
    /// This will not affect headers set via [ConnectionOptions::binary_headers]. Like
    /// [`Connection::set_api_key`], the change is visible to every clone of this connection.
    ///
    /// # Errors
    ///
    /// Will return an error if any of the provided keys or values are not valid gRPC metadata.
    /// If an error is returned, the previous headers will remain unchanged.
    pub fn set_headers(&self, headers: HashMap<String, String>) -> Result<(), InvalidHeaderError> {
        // Parsing happens before the write lock's assignment, so a failure leaves the old
        // headers in place.
        self.inner.headers.write().user_headers = parse_ascii_headers(headers)?;
        Ok(())
    }

    /// Set binary HTTP request headers overwriting previous headers.
    ///
    /// This will not affect headers set via [ConnectionOptions::headers]. Like
    /// [`Connection::set_api_key`], the change is visible to every clone of this connection.
    ///
    /// # Errors
    ///
    /// Will return an error if any of the provided keys are not valid gRPC binary metadata keys.
    /// If an error is returned, the previous headers will remain unchanged.
    pub fn set_binary_headers(
        &self,
        binary_headers: HashMap<String, Vec<u8>>,
    ) -> Result<(), InvalidHeaderError> {
        self.inner.headers.write().user_binary_headers = parse_binary_headers(binary_headers)?;
        Ok(())
    }

    /// Returns the value used for the `client-name` header by this connection.
    pub fn client_name(&self) -> &str {
        &self.inner.client_name
    }

    /// Returns the value used for the `client-version` header by this connection.
    pub fn client_version(&self) -> &str {
        &self.inner.client_version
    }

    /// Returns the server capabilities we (may have) learned about when establishing an initial
    /// connection.
    ///
    /// `None` in three cases: the lookup was skipped (`skip_get_system_info`), the server
    /// answered `Unimplemented`, or the response carried no capabilities.
    pub fn capabilities(&self) -> Option<&get_system_info_response::Capabilities> {
        self.inner.capabilities.as_ref()
    }

    /// Get a mutable reference to the retry options.
    ///
    /// Note: If this connection has been cloned, this will copy-on-write to avoid
    /// affecting other clones. Only by-value fields are copied: headers and the worker registry
    /// are `Arc`s and remain shared with the other clones.
    pub fn retry_options_mut(&mut self) -> &mut RetryOptions {
        &mut Arc::make_mut(&mut self.inner).retry_options
    }

    /// Get a reference to the connection identity.
    pub fn identity(&self) -> &str {
        &self.inner.identity
    }

    /// Get a mutable reference to the connection identity.
    ///
    /// Note: If this connection has been cloned, this will copy-on-write to avoid
    /// affecting other clones. As with [`Connection::retry_options_mut`], headers and the worker
    /// registry stay shared.
    pub fn identity_mut(&mut self) -> &mut String {
        &mut Arc::make_mut(&mut self.inner).identity
    }

    /// Returns a shared handle to the registry of workers using this client instance.
    pub fn workers(&self) -> Arc<ClientWorkerSet> {
        self.inner.workers.clone()
    }

    /// Returns the client-wide key, as reported by this connection's worker registry.
    pub fn worker_grouping_key(&self) -> Uuid {
        self.inner.workers.worker_grouping_key()
    }

    /// Get the underlying workflow service client for making raw gRPC calls.
    pub fn workflow_service(&self) -> Box<dyn WorkflowService> {
        self.inner.service.workflow_service()
    }

    /// Get the underlying operator service client for making raw gRPC calls.
    pub fn operator_service(&self) -> Box<dyn OperatorService> {
        self.inner.service.operator_service()
    }

    /// Get the underlying cloud service client for making raw gRPC calls.
    pub fn cloud_service(&self) -> Box<dyn CloudService> {
        self.inner.service.cloud_service()
    }

    /// Get the underlying test service client for making raw gRPC calls.
    pub fn test_service(&self) -> Box<dyn TestService> {
        self.inner.service.test_service()
    }

    /// Get the underlying health service client for making raw gRPC calls.
    pub fn health_service(&self) -> Box<dyn HealthService> {
        self.inner.service.health_service()
    }
}
371
372#[derive(Debug)]
373struct ClientHeaders {
374    user_headers: HashMap<AsciiMetadataKey, AsciiMetadataValue>,
375    user_binary_headers: HashMap<BinaryMetadataKey, BinaryMetadataValue>,
376    api_key: Option<String>,
377}
378
379impl ClientHeaders {
380    fn apply_to_metadata(&self, metadata: &mut MetadataMap) {
381        for (key, val) in self.user_headers.iter() {
382            // Only if not already present
383            if !metadata.contains_key(key) {
384                metadata.insert(key, val.clone());
385            }
386        }
387        for (key, val) in self.user_binary_headers.iter() {
388            // Only if not already present
389            if !metadata.contains_key(key) {
390                metadata.insert_bin(key, val.clone());
391            }
392        }
393        if let Some(api_key) = &self.api_key {
394            // Only if not already present
395            if !metadata.contains_key("authorization")
396                && let Ok(val) = format!("Bearer {api_key}").parse()
397            {
398                metadata.insert("authorization", val);
399            }
400        }
401    }
402}
403
404/// If TLS is configured, set the appropriate options on the provided channel and return it.
405/// Passes it through if TLS options not set.
406async fn add_tls_to_channel(
407    tls_options: Option<&TlsOptions>,
408    mut channel: Endpoint,
409) -> Result<Endpoint, ClientConnectError> {
410    if let Some(tls_cfg) = tls_options {
411        let mut tls = tonic::transport::ClientTlsConfig::new();
412
413        if let Some(root_cert) = &tls_cfg.server_root_ca_cert {
414            let server_root_ca_cert = Certificate::from_pem(root_cert);
415            tls = tls.ca_certificate(server_root_ca_cert);
416        } else {
417            tls = tls.with_native_roots();
418        }
419
420        if let Some(domain) = &tls_cfg.domain {
421            tls = tls.domain_name(domain);
422
423            // This song and dance ultimately is just to make sure the `:authority` header ends
424            // up correct on requests while we use TLS. Setting the header directly in our
425            // interceptor doesn't work since seemingly it is overridden at some point by
426            // something lower level.
427            let uri: Uri = format!("https://{domain}").parse()?;
428            channel = channel.origin(uri);
429        }
430
431        if let Some(client_opts) = &tls_cfg.client_tls_options {
432            let client_identity =
433                Identity::from_pem(&client_opts.client_cert, &client_opts.client_private_key);
434            tls = tls.identity(client_identity);
435        }
436
437        return channel.tls_config(tls).map_err(Into::into);
438    }
439    Ok(channel)
440}
441
442fn parse_ascii_headers(
443    headers: HashMap<String, String>,
444) -> Result<HashMap<AsciiMetadataKey, AsciiMetadataValue>, InvalidHeaderError> {
445    let mut parsed_headers = HashMap::with_capacity(headers.len());
446    for (k, v) in headers.into_iter() {
447        let key = match AsciiMetadataKey::from_str(&k) {
448            Ok(key) => key,
449            Err(err) => {
450                return Err(InvalidHeaderError::InvalidAsciiHeaderKey {
451                    key: k,
452                    source: err,
453                });
454            }
455        };
456        let value = match MetadataValue::from_str(&v) {
457            Ok(value) => value,
458            Err(err) => {
459                return Err(InvalidHeaderError::InvalidAsciiHeaderValue {
460                    key: k,
461                    value: v,
462                    source: err,
463                });
464            }
465        };
466        parsed_headers.insert(key, value);
467    }
468
469    Ok(parsed_headers)
470}
471
472fn parse_binary_headers(
473    headers: HashMap<String, Vec<u8>>,
474) -> Result<HashMap<BinaryMetadataKey, BinaryMetadataValue>, InvalidHeaderError> {
475    let mut parsed_headers = HashMap::with_capacity(headers.len());
476    for (k, v) in headers.into_iter() {
477        let key = match BinaryMetadataKey::from_str(&k) {
478            Ok(key) => key,
479            Err(err) => {
480                return Err(InvalidHeaderError::InvalidBinaryHeaderKey {
481                    key: k,
482                    source: err,
483                });
484            }
485        };
486        let value = BinaryMetadataValue::from_bytes(&v);
487        parsed_headers.insert(key, value);
488    }
489
490    Ok(parsed_headers)
491}
492
493/// Interceptor which attaches common metadata (like "client-name") to every outgoing call
494#[derive(Clone)]
495pub struct ServiceCallInterceptor {
496    client_name: String,
497    client_version: String,
498    /// Only accessed as a reader
499    headers: Arc<RwLock<ClientHeaders>>,
500}
501
502impl Interceptor for ServiceCallInterceptor {
503    /// This function will get called on each outbound request. Returning a `Status` here will
504    /// cancel the request and have that status returned to the caller.
505    fn call(
506        &mut self,
507        mut request: tonic::Request<()>,
508    ) -> Result<tonic::Request<()>, tonic::Status> {
509        let metadata = request.metadata_mut();
510        if !metadata.contains_key(CLIENT_NAME_HEADER_KEY) {
511            metadata.insert(
512                CLIENT_NAME_HEADER_KEY,
513                self.client_name
514                    .parse()
515                    .unwrap_or_else(|_| MetadataValue::from_static("")),
516            );
517        }
518        if !metadata.contains_key(CLIENT_VERSION_HEADER_KEY) {
519            metadata.insert(
520                CLIENT_VERSION_HEADER_KEY,
521                self.client_version
522                    .parse()
523                    .unwrap_or_else(|_| MetadataValue::from_static("")),
524            );
525        }
526        self.headers.read().apply_to_metadata(metadata);
527        request.set_default_timeout(OTHER_CALL_TIMEOUT);
528
529        Ok(request)
530    }
531}
532
533/// Aggregates various services exposed by the Temporal server
534#[derive(Clone)]
535pub struct TemporalServiceClient {
536    workflow_svc_client: Box<dyn WorkflowService>,
537    operator_svc_client: Box<dyn OperatorService>,
538    cloud_svc_client: Box<dyn CloudService>,
539    test_svc_client: Box<dyn TestService>,
540    health_svc_client: Box<dyn HealthService>,
541}
542
/// Maximum size, in bytes, of a single incoming message the service clients will decode.
///
/// The limit is raised from the 4 MB default to 128 MiB. Users can override it with the
/// `TEMPORAL_MAX_INCOMING_GRPC_BYTES` environment variable; a missing or unparsable value falls
/// back to 128 MiB. The value is resolved once per process and then cached.
fn get_decode_max_size() -> usize {
    const DEFAULT_MAX_DECODE_BYTES: usize = 128 * 1024 * 1024;
    static CACHED_MAX_DECODE_BYTES: OnceLock<usize> = OnceLock::new();
    *CACHED_MAX_DECODE_BYTES.get_or_init(|| {
        match std::env::var("TEMPORAL_MAX_INCOMING_GRPC_BYTES") {
            Ok(raw) => raw.parse().unwrap_or(DEFAULT_MAX_DECODE_BYTES),
            Err(_) => DEFAULT_MAX_DECODE_BYTES,
        }
    })
}
554
555impl TemporalServiceClient {
556    fn new<T>(svc: T) -> Self
557    where
558        T: GrpcService<Body> + Send + Sync + Clone + 'static,
559        T::ResponseBody: tonic::codegen::Body<Data = tonic::codegen::Bytes> + Send + 'static,
560        T::Error: Into<tonic::codegen::StdError>,
561        <T::ResponseBody as tonic::codegen::Body>::Error: Into<tonic::codegen::StdError> + Send,
562        <T as GrpcService<Body>>::Future: Send,
563    {
564        let workflow_svc_client = Box::new(
565            WorkflowServiceClient::new(svc.clone())
566                .max_decoding_message_size(get_decode_max_size()),
567        );
568        let operator_svc_client = Box::new(
569            OperatorServiceClient::new(svc.clone())
570                .max_decoding_message_size(get_decode_max_size()),
571        );
572        let cloud_svc_client = Box::new(
573            CloudServiceClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()),
574        );
575        let test_svc_client = Box::new(
576            TestServiceClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()),
577        );
578        let health_svc_client = Box::new(
579            HealthClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()),
580        );
581
582        Self {
583            workflow_svc_client,
584            operator_svc_client,
585            cloud_svc_client,
586            test_svc_client,
587            health_svc_client,
588        }
589    }
590
591    /// Create a service client from implementations of the individual underlying services. Useful
592    /// for mocking out service implementations.
593    pub fn from_services(
594        workflow: Box<dyn WorkflowService>,
595        operator: Box<dyn OperatorService>,
596        cloud: Box<dyn CloudService>,
597        test: Box<dyn TestService>,
598        health: Box<dyn HealthService>,
599    ) -> Self {
600        Self {
601            workflow_svc_client: workflow,
602            operator_svc_client: operator,
603            cloud_svc_client: cloud,
604            test_svc_client: test,
605            health_svc_client: health,
606        }
607    }
608
609    /// Get the underlying workflow service client
610    pub fn workflow_service(&self) -> Box<dyn WorkflowService> {
611        self.workflow_svc_client.clone()
612    }
613    /// Get the underlying operator service client
614    pub fn operator_service(&self) -> Box<dyn OperatorService> {
615        self.operator_svc_client.clone()
616    }
617    /// Get the underlying cloud service client
618    pub fn cloud_service(&self) -> Box<dyn CloudService> {
619        self.cloud_svc_client.clone()
620    }
621    /// Get the underlying test service client
622    pub fn test_service(&self) -> Box<dyn TestService> {
623        self.test_svc_client.clone()
624    }
625    /// Get the underlying health service client
626    pub fn health_service(&self) -> Box<dyn HealthService> {
627        self.health_svc_client.clone()
628    }
629}
630
631/// Contains an instance of a namespace-bound client for interacting with the Temporal server.
632/// Cheap to clone.
633#[derive(Clone)]
634pub struct Client {
635    connection: Connection,
636    options: Arc<ClientOptions>,
637}
638
639impl Client {
640    /// Create a new client from a connection and options.
641    ///
642    /// Currently infallible, but returns a `Result` for future extensibility
643    /// (e.g., interceptor or plugin validation).
644    pub fn new(connection: Connection, options: ClientOptions) -> Result<Self, ClientNewError> {
645        Ok(Client {
646            connection,
647            options: Arc::new(options),
648        })
649    }
650
651    /// Return the options this client was initialized with
652    pub fn options(&self) -> &ClientOptions {
653        &self.options
654    }
655
656    /// Return this client's options mutably.
657    ///
658    /// Note: If this client has been cloned, this will copy-on-write to avoid affecting other
659    /// clones.
660    pub fn options_mut(&mut self) -> &mut ClientOptions {
661        Arc::make_mut(&mut self.options)
662    }
663
664    /// Returns a reference to the underlying connection
665    pub fn connection(&self) -> &Connection {
666        &self.connection
667    }
668
669    /// Returns a mutable reference to the underlying connection
670    pub fn connection_mut(&mut self) -> &mut Connection {
671        &mut self.connection
672    }
673}
674
675// High-level workflow operations on Client.
676// These forward to the internal WorkflowClientTrait blanket impl which is
677// available because Client implements WorkflowService + NamespacedClient + Clone.
678impl Client {
679    /// Start a workflow execution.
680    ///
681    /// Returns a [`WorkflowHandle`] that can be used to interact with the workflow
682    /// (e.g., get its result, send signals, query, etc.).
683    pub async fn start_workflow<W>(
684        &self,
685        workflow: W,
686        input: W::Input,
687        options: WorkflowStartOptions,
688    ) -> Result<WorkflowHandle<Self, W>, WorkflowStartError>
689    where
690        W: HasWorkflowDefinition,
691        W::Input: Send,
692    {
693        WorkflowClientTrait::start_workflow(self, workflow, input, options).await
694    }
695
696    /// Get a handle to an existing workflow.
697    ///
698    /// For untyped access, use `get_workflow_handle::<UntypedWorkflow>(...)`.
699    pub fn get_workflow_handle<W: HasWorkflowDefinition>(
700        &self,
701        workflow_id: impl Into<String>,
702    ) -> WorkflowHandle<Self, W> {
703        WorkflowClientTrait::get_workflow_handle(self, workflow_id)
704    }
705
706    /// List workflows matching a query.
707    ///
708    /// Returns a stream that lazily paginates through results.
709    /// Use `limit` in options to cap the number of results returned.
710    pub fn list_workflows(
711        &self,
712        query: impl Into<String>,
713        opts: WorkflowListOptions,
714    ) -> ListWorkflowsStream {
715        WorkflowClientTrait::list_workflows(self, query, opts)
716    }
717
718    /// Count workflows matching a query.
719    pub async fn count_workflows(
720        &self,
721        query: impl Into<String>,
722        opts: WorkflowCountOptions,
723    ) -> Result<WorkflowExecutionCount, ClientError> {
724        WorkflowClientTrait::count_workflows(self, query, opts).await
725    }
726
727    /// Get a handle to complete an activity asynchronously.
728    ///
729    /// An activity returning `ActivityError::WillCompleteAsync` can be completed with this handle.
730    pub fn get_async_activity_handle(
731        &self,
732        identifier: ActivityIdentifier,
733    ) -> AsyncActivityHandle<Self> {
734        WorkflowClientTrait::get_async_activity_handle(self, identifier)
735    }
736}
737
738impl NamespacedClient for Client {
739    fn namespace(&self) -> String {
740        self.options.namespace.clone()
741    }
742
743    fn identity(&self) -> String {
744        self.connection.identity().to_owned()
745    }
746
747    fn data_converter(&self) -> &DataConverter {
748        &self.options.data_converter
749    }
750}
751
/// Reference to a namespace, either by its name or by its id.
///
/// Derives `Debug`/`PartialEq`/`Eq`/`Hash` so it can be logged, compared and used as a map key
/// like other public types.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Namespace {
    /// Namespace name
    Name(String),
    /// Namespace id
    Id(String),
}
760
761impl Namespace {
762    /// Convert into grpc request
763    pub fn into_describe_namespace_request(self) -> DescribeNamespaceRequest {
764        let (namespace, id) = match self {
765            Namespace::Name(n) => (n, "".to_owned()),
766            Namespace::Id(n) => ("".to_owned(), n),
767        };
768        DescribeNamespaceRequest { namespace, id }
769    }
770}
771
772/// This trait provides higher-level friendlier interaction with the server.
773/// See the [WorkflowService] trait for a lower-level client.
774pub(crate) trait WorkflowClientTrait: NamespacedClient {
775    /// Start a workflow execution.
776    fn start_workflow<W>(
777        &self,
778        workflow: W,
779        input: W::Input,
780        options: WorkflowStartOptions,
781    ) -> impl Future<Output = Result<WorkflowHandle<Self, W>, WorkflowStartError>>
782    where
783        Self: Sized,
784        W: HasWorkflowDefinition,
785        W::Input: Send;
786
787    /// Get a handle to an existing workflow. `run_id` may be left blank to specify the most recent
788    /// execution having the provided `workflow_id`.
789    ///
790    /// For untyped access, use `get_workflow_handle::<UntypedWorkflow>(...)`.
791    ///
792    /// See also [WorkflowHandle::new], for specifying namespace or first_execution_run_id.
793    fn get_workflow_handle<W: HasWorkflowDefinition>(
794        &self,
795        workflow_id: impl Into<String>,
796    ) -> WorkflowHandle<Self, W>
797    where
798        Self: Sized;
799
800    /// List workflows matching a query.
801    /// Returns a stream that lazily paginates through results.
802    /// Use `limit` in options to cap the number of results returned.
803    fn list_workflows(
804        &self,
805        query: impl Into<String>,
806        opts: WorkflowListOptions,
807    ) -> ListWorkflowsStream;
808
809    /// Count workflows matching a query.
810    fn count_workflows(
811        &self,
812        query: impl Into<String>,
813        opts: WorkflowCountOptions,
814    ) -> impl Future<Output = Result<WorkflowExecutionCount, ClientError>>;
815
816    /// Get a handle to complete an activity asynchronously.
817    ///
818    /// An activity returning `ActivityError::WillCompleteAsync` can be completed with this handle.
819    fn get_async_activity_handle(
820        &self,
821        identifier: ActivityIdentifier,
822    ) -> AsyncActivityHandle<Self>
823    where
824        Self: Sized;
825}
826
827/// A client that is bound to a namespace
828pub trait NamespacedClient {
829    /// Returns the namespace this client is bound to
830    fn namespace(&self) -> String;
831    /// Returns the client identity
832    fn identity(&self) -> String;
833    /// Returns the data converter for serializing/deserializing payloads.
834    /// Default implementation returns a static default converter.
835    fn data_converter(&self) -> &DataConverter {
836        static DEFAULT: OnceLock<DataConverter> = OnceLock::new();
837        DEFAULT.get_or_init(DataConverter::default)
838    }
839}
840
/// A workflow execution returned from list operations.
/// This represents information about a workflow present in visibility.
///
/// Wraps a raw `workflow::WorkflowExecutionInfo` proto and exposes convenience accessors;
/// use [`WorkflowExecution::raw`] or [`WorkflowExecution::into_raw`] for other fields.
#[derive(Debug, Clone)]
pub struct WorkflowExecution {
    // The wrapped proto; every accessor reads from it.
    raw: workflow::WorkflowExecutionInfo,
}
847
848impl WorkflowExecution {
849    /// Create a new WorkflowExecution from the raw proto.
850    pub fn new(raw: workflow::WorkflowExecutionInfo) -> Self {
851        Self { raw }
852    }
853
854    /// The workflow ID.
855    pub fn id(&self) -> &str {
856        self.raw
857            .execution
858            .as_ref()
859            .map(|e| e.workflow_id.as_str())
860            .unwrap_or("")
861    }
862
863    /// The run ID.
864    pub fn run_id(&self) -> &str {
865        self.raw
866            .execution
867            .as_ref()
868            .map(|e| e.run_id.as_str())
869            .unwrap_or("")
870    }
871
872    /// The workflow type name.
873    pub fn workflow_type(&self) -> &str {
874        self.raw
875            .r#type
876            .as_ref()
877            .map(|t| t.name.as_str())
878            .unwrap_or("")
879    }
880
881    /// The current status of the workflow execution.
882    pub fn status(&self) -> WorkflowExecutionStatus {
883        self.raw.status()
884    }
885
886    /// When the workflow was created.
887    pub fn start_time(&self) -> Option<SystemTime> {
888        self.raw
889            .start_time
890            .as_ref()
891            .and_then(proto_ts_to_system_time)
892    }
893
894    /// When the workflow run started or should start.
895    pub fn execution_time(&self) -> Option<SystemTime> {
896        self.raw
897            .execution_time
898            .as_ref()
899            .and_then(proto_ts_to_system_time)
900    }
901
902    /// When the workflow was closed, if closed.
903    pub fn close_time(&self) -> Option<SystemTime> {
904        self.raw
905            .close_time
906            .as_ref()
907            .and_then(proto_ts_to_system_time)
908    }
909
910    /// The task queue the workflow runs on.
911    pub fn task_queue(&self) -> &str {
912        &self.raw.task_queue
913    }
914
915    /// Number of events in history.
916    pub fn history_length(&self) -> i64 {
917        self.raw.history_length
918    }
919
920    /// Workflow memo.
921    pub fn memo(&self) -> Option<&Memo> {
922        self.raw.memo.as_ref()
923    }
924
925    /// Parent workflow ID, if this is a child workflow.
926    pub fn parent_id(&self) -> Option<&str> {
927        self.raw
928            .parent_execution
929            .as_ref()
930            .map(|e| e.workflow_id.as_str())
931    }
932
933    /// Parent run ID, if this is a child workflow.
934    pub fn parent_run_id(&self) -> Option<&str> {
935        self.raw
936            .parent_execution
937            .as_ref()
938            .map(|e| e.run_id.as_str())
939    }
940
941    /// Search attributes on the workflow.
942    pub fn search_attributes(&self) -> Option<&SearchAttributes> {
943        self.raw.search_attributes.as_ref()
944    }
945
946    /// Access the raw proto for additional fields not exposed via accessors.
947    pub fn raw(&self) -> &workflow::WorkflowExecutionInfo {
948        &self.raw
949    }
950
951    /// Consume the wrapper and return the raw proto.
952    pub fn into_raw(self) -> workflow::WorkflowExecutionInfo {
953        self.raw
954    }
955}
956
957impl From<workflow::WorkflowExecutionInfo> for WorkflowExecution {
958    fn from(raw: workflow::WorkflowExecutionInfo) -> Self {
959        Self::new(raw)
960    }
961}
962
963/// A stream of workflow executions from a list query.
964/// Internally paginates through results from the server.
965pub struct ListWorkflowsStream {
966    inner: Pin<Box<dyn Stream<Item = Result<WorkflowExecution, ClientError>> + Send>>,
967}
968
969impl ListWorkflowsStream {
970    fn new(
971        inner: Pin<Box<dyn Stream<Item = Result<WorkflowExecution, ClientError>> + Send>>,
972    ) -> Self {
973        Self { inner }
974    }
975}
976
977impl Stream for ListWorkflowsStream {
978    type Item = Result<WorkflowExecution, ClientError>;
979
980    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
981        self.inner.as_mut().poll_next(cx)
982    }
983}
984
985/// Result of a workflow count operation.
986///
987/// If the query includes a group-by clause, `groups` will contain the aggregated
988/// counts and `count` will be the sum of all group counts.
989#[derive(Debug, Clone)]
990pub struct WorkflowExecutionCount {
991    count: usize,
992    groups: Vec<WorkflowCountAggregationGroup>,
993}
994
995impl WorkflowExecutionCount {
996    pub(crate) fn from_response(resp: CountWorkflowExecutionsResponse) -> Self {
997        Self {
998            count: resp.count as usize,
999            groups: resp
1000                .groups
1001                .into_iter()
1002                .map(WorkflowCountAggregationGroup::from_proto)
1003                .collect(),
1004        }
1005    }
1006
1007    /// The approximate number of workflows matching the query.
1008    /// If grouping was applied, this is the sum of all group counts.
1009    pub fn count(&self) -> usize {
1010        self.count
1011    }
1012
1013    /// The groups if the query had a group-by clause, or empty if not.
1014    pub fn groups(&self) -> &[WorkflowCountAggregationGroup] {
1015        &self.groups
1016    }
1017}
1018
1019/// Aggregation group from a workflow count query with a group-by clause.
1020#[derive(Debug, Clone)]
1021pub struct WorkflowCountAggregationGroup {
1022    group_values: Vec<Payload>,
1023    count: usize,
1024}
1025
1026impl WorkflowCountAggregationGroup {
1027    fn from_proto(proto: count_workflow_executions_response::AggregationGroup) -> Self {
1028        Self {
1029            group_values: proto.group_values,
1030            count: proto.count as usize,
1031        }
1032    }
1033
1034    /// The search attribute values for this group.
1035    pub fn group_values(&self) -> &[Payload] {
1036        &self.group_values
1037    }
1038
1039    /// The approximate number of workflows matching for this group.
1040    pub fn count(&self) -> usize {
1041        self.count
1042    }
1043}
1044
1045impl<T> WorkflowClientTrait for T
1046where
1047    T: WorkflowService + NamespacedClient + Clone + Send + Sync + 'static,
1048{
1049    async fn start_workflow<W>(
1050        &self,
1051        workflow: W,
1052        input: W::Input,
1053        options: WorkflowStartOptions,
1054    ) -> Result<WorkflowHandle<Self, W>, WorkflowStartError>
1055    where
1056        W: HasWorkflowDefinition,
1057        W::Input: Send,
1058    {
1059        let payloads = self
1060            .data_converter()
1061            .to_payloads(&SerializationContextData::Workflow, &input)
1062            .await?;
1063        let namespace = self.namespace();
1064        let workflow_id = options.workflow_id.clone();
1065        let task_queue_name = options.task_queue.clone();
1066
1067        let user_metadata = if options.static_summary.is_some() || options.static_details.is_some()
1068        {
1069            let payload_converter = PayloadConverter::default();
1070            let context = SerializationContext {
1071                data: &SerializationContextData::Workflow,
1072                converter: &payload_converter,
1073            };
1074            Some(UserMetadata {
1075                summary: options.static_summary.map(|s| {
1076                    payload_converter
1077                        .to_payload(&context, &s)
1078                        .expect("String-to-JSON payload serialization is infallible")
1079                }),
1080                details: options.static_details.map(|s| {
1081                    payload_converter
1082                        .to_payload(&context, &s)
1083                        .expect("String-to-JSON payload serialization is infallible")
1084                }),
1085            })
1086        } else {
1087            None
1088        };
1089
1090        let run_id = if let Some(start_signal) = options.start_signal {
1091            // Use signal-with-start when a start_signal is provided
1092            let res = WorkflowService::signal_with_start_workflow_execution(
1093                &mut self.clone(),
1094                SignalWithStartWorkflowExecutionRequest {
1095                    namespace: namespace.clone(),
1096                    workflow_id: workflow_id.clone(),
1097                    workflow_type: Some(WorkflowType {
1098                        name: workflow.name().to_string(),
1099                    }),
1100                    task_queue: Some(TaskQueue {
1101                        name: task_queue_name,
1102                        kind: TaskQueueKind::Normal as i32,
1103                        normal_name: "".to_string(),
1104                    }),
1105                    input: payloads.into_payloads(),
1106                    signal_name: start_signal.signal_name,
1107                    signal_input: start_signal.input,
1108                    identity: self.identity(),
1109                    request_id: Uuid::new_v4().to_string(),
1110                    workflow_id_reuse_policy: options.id_reuse_policy as i32,
1111                    workflow_id_conflict_policy: options.id_conflict_policy as i32,
1112                    workflow_execution_timeout: options
1113                        .execution_timeout
1114                        .and_then(|d| d.try_into().ok()),
1115                    workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()),
1116                    workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()),
1117                    search_attributes: options.search_attributes.map(|d| d.into()),
1118                    cron_schedule: options.cron_schedule.unwrap_or_default(),
1119                    header: options.header.or(start_signal.header),
1120                    user_metadata,
1121                    ..Default::default()
1122                }
1123                .into_request(),
1124            )
1125            .await?
1126            .into_inner();
1127            res.run_id
1128        } else {
1129            // Normal start workflow
1130            let res = self
1131                .clone()
1132                .start_workflow_execution(
1133                    StartWorkflowExecutionRequest {
1134                        namespace: namespace.clone(),
1135                        input: payloads.into_payloads(),
1136                        workflow_id: workflow_id.clone(),
1137                        workflow_type: Some(WorkflowType {
1138                            name: workflow.name().to_string(),
1139                        }),
1140                        task_queue: Some(TaskQueue {
1141                            name: task_queue_name,
1142                            kind: TaskQueueKind::Unspecified as i32,
1143                            normal_name: "".to_string(),
1144                        }),
1145                        request_id: Uuid::new_v4().to_string(),
1146                        workflow_id_reuse_policy: options.id_reuse_policy as i32,
1147                        workflow_id_conflict_policy: options.id_conflict_policy as i32,
1148                        workflow_execution_timeout: options
1149                            .execution_timeout
1150                            .and_then(|d| d.try_into().ok()),
1151                        workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()),
1152                        workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()),
1153                        search_attributes: options.search_attributes.map(|d| d.into()),
1154                        cron_schedule: options.cron_schedule.unwrap_or_default(),
1155                        request_eager_execution: options.enable_eager_workflow_start,
1156                        retry_policy: options.retry_policy,
1157                        links: options.links,
1158                        completion_callbacks: options.completion_callbacks,
1159                        priority: Some(options.priority.into()),
1160                        header: options.header,
1161                        user_metadata,
1162                        ..Default::default()
1163                    }
1164                    .into_request(),
1165                )
1166                .await
1167                .map_err(|status| {
1168                    if status.code() == Code::AlreadyExists {
1169                        let run_id =
1170                            decode_status_detail::<WorkflowExecutionAlreadyStartedFailure>(
1171                                status.details(),
1172                            )
1173                            .map(|f| f.run_id);
1174                        WorkflowStartError::AlreadyStarted {
1175                            run_id,
1176                            source: status,
1177                        }
1178                    } else {
1179                        WorkflowStartError::Rpc(status)
1180                    }
1181                })?
1182                .into_inner();
1183            res.run_id
1184        };
1185
1186        Ok(WorkflowHandle::new(
1187            self.clone(),
1188            WorkflowExecutionInfo {
1189                namespace,
1190                workflow_id,
1191                run_id: Some(run_id.clone()),
1192                first_execution_run_id: Some(run_id),
1193            },
1194        ))
1195    }
1196
1197    fn get_workflow_handle<W: HasWorkflowDefinition>(
1198        &self,
1199        workflow_id: impl Into<String>,
1200    ) -> WorkflowHandle<Self, W>
1201    where
1202        Self: Sized,
1203    {
1204        WorkflowHandle::new(
1205            self.clone(),
1206            WorkflowExecutionInfo {
1207                namespace: self.namespace(),
1208                workflow_id: workflow_id.into(),
1209                run_id: None,
1210                first_execution_run_id: None,
1211            },
1212        )
1213    }
1214
1215    fn list_workflows(
1216        &self,
1217        query: impl Into<String>,
1218        opts: WorkflowListOptions,
1219    ) -> ListWorkflowsStream {
1220        let client = self.clone();
1221        let namespace = self.namespace();
1222        let query = query.into();
1223        let limit = opts.limit;
1224
1225        // State: (next_page_token, buffer, yielded_count, exhausted)
1226        let initial_state = (Vec::new(), VecDeque::new(), 0, false);
1227
1228        let stream = stream::unfold(
1229            initial_state,
1230            move |(next_page_token, mut buffer, mut yielded, exhausted)| {
1231                let mut client = client.clone();
1232                let namespace = namespace.clone();
1233                let query = query.clone();
1234
1235                async move {
1236                    if let Some(l) = limit
1237                        && yielded >= l
1238                    {
1239                        return None;
1240                    }
1241
1242                    if let Some(exec) = buffer.pop_front() {
1243                        yielded += 1;
1244                        return Some((Ok(exec), (next_page_token, buffer, yielded, exhausted)));
1245                    }
1246
1247                    if exhausted {
1248                        return None;
1249                    }
1250
1251                    let response = WorkflowService::list_workflow_executions(
1252                        &mut client,
1253                        ListWorkflowExecutionsRequest {
1254                            namespace,
1255                            page_size: 0, // Use server default
1256                            next_page_token: next_page_token.clone(),
1257                            query,
1258                        }
1259                        .into_request(),
1260                    )
1261                    .await;
1262
1263                    match response {
1264                        Ok(resp) => {
1265                            let resp = resp.into_inner();
1266                            let new_exhausted = resp.next_page_token.is_empty();
1267                            let new_token = resp.next_page_token;
1268
1269                            buffer = resp
1270                                .executions
1271                                .into_iter()
1272                                .map(WorkflowExecution::from)
1273                                .collect();
1274
1275                            if let Some(exec) = buffer.pop_front() {
1276                                yielded += 1;
1277                                Some((Ok(exec), (new_token, buffer, yielded, new_exhausted)))
1278                            } else {
1279                                None
1280                            }
1281                        }
1282                        Err(e) => Some((Err(e.into()), (next_page_token, buffer, yielded, true))),
1283                    }
1284                }
1285            },
1286        );
1287
1288        ListWorkflowsStream::new(Box::pin(stream))
1289    }
1290
1291    async fn count_workflows(
1292        &self,
1293        query: impl Into<String>,
1294        _opts: WorkflowCountOptions,
1295    ) -> Result<WorkflowExecutionCount, ClientError> {
1296        let resp = WorkflowService::count_workflow_executions(
1297            &mut self.clone(),
1298            CountWorkflowExecutionsRequest {
1299                namespace: self.namespace(),
1300                query: query.into(),
1301            }
1302            .into_request(),
1303        )
1304        .await?
1305        .into_inner();
1306
1307        Ok(WorkflowExecutionCount::from_response(resp))
1308    }
1309
1310    fn get_async_activity_handle(&self, identifier: ActivityIdentifier) -> AsyncActivityHandle<Self>
1311    where
1312        Self: Sized,
1313    {
1314        AsyncActivityHandle::new(self.clone(), identifier)
1315    }
1316}
1317
/// Log the message at error level and, when debug assertions are enabled, panic with it.
///
/// Expands to a block, so it may be invoked several times in the same scope and in
/// expression position (it evaluates to `()`). A bare `use` in the expansion would collide
/// on repeated invocations and could not be used as an expression.
macro_rules! dbg_panic {
    ($($arg:tt)*) => {{
        ::tracing::error!($($arg)*);
        debug_assert!(false, $($arg)*);
    }};
}
pub(crate) use dbg_panic;
1326
1327#[cfg(test)]
1328mod tests {
1329    use super::*;
1330    use tonic::metadata::Ascii;
1331    use url::Url;
1332
    /// Checks `ServiceCallInterceptor` header handling across successive calls that share one
    /// `ClientHeaders` instance: client-level headers and the API key are attached, metadata
    /// already on a request takes precedence, a user `authorization` header replaces the
    /// API-key bearer token, cleared headers disappear, and an existing `grpc-timeout` is kept.
    #[test]
    fn applies_headers() {
        // Initial header set
        let headers = Arc::new(RwLock::new(ClientHeaders {
            user_headers: HashMap::new(),
            user_binary_headers: HashMap::new(),
            api_key: Some("my-api-key".to_owned()),
        }));
        headers.clone().write().user_headers.insert(
            "my-meta-key".parse().unwrap(),
            "my-meta-val".parse().unwrap(),
        );
        headers.clone().write().user_binary_headers.insert(
            "my-bin-meta-key-bin".parse().unwrap(),
            vec![1, 2, 3].try_into().unwrap(),
        );
        // The interceptor holds a clone of the same Arc, so later writes to `headers`
        // are visible to it.
        let mut interceptor = ServiceCallInterceptor {
            client_name: "cute-kitty".to_string(),
            client_version: "0.1.0".to_string(),
            headers: headers.clone(),
        };

        // Confirm on metadata; the API key is sent as a bearer token
        let req = interceptor.call(tonic::Request::new(())).unwrap();
        assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val");
        assert_eq!(
            req.metadata().get("authorization").unwrap(),
            "Bearer my-api-key"
        );
        assert_eq!(
            req.metadata().get_bin("my-bin-meta-key-bin").unwrap(),
            vec![1, 2, 3].as_slice()
        );

        // Overwrite at request time: values already on the request win over client headers
        let mut req = tonic::Request::new(());
        req.metadata_mut()
            .insert("my-meta-key", "my-meta-val2".parse().unwrap());
        req.metadata_mut()
            .insert("authorization", "my-api-key2".parse().unwrap());
        req.metadata_mut()
            .insert_bin("my-bin-meta-key-bin", vec![4, 5, 6].try_into().unwrap());
        let req = interceptor.call(req).unwrap();
        assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val2");
        assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key2");
        assert_eq!(
            req.metadata().get_bin("my-bin-meta-key-bin").unwrap(),
            vec![4, 5, 6].as_slice()
        );

        // Overwrite auth on header: a user `authorization` header replaces the API-key value
        headers.clone().write().user_headers.insert(
            "authorization".parse().unwrap(),
            "my-api-key3".parse().unwrap(),
        );
        let req = interceptor.call(tonic::Request::new(())).unwrap();
        assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val");
        assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key3");

        // Remove headers and auth and confirm gone
        headers.clone().write().user_headers.clear();
        headers.clone().write().user_binary_headers.clear();
        headers.clone().write().api_key.take();
        let req = interceptor.call(tonic::Request::new(())).unwrap();
        assert!(!req.metadata().contains_key("my-meta-key"));
        assert!(!req.metadata().contains_key("authorization"));
        assert!(!req.metadata().contains_key("my-bin-meta-key-bin"));

        // Timeout header is not overridden
        let mut req = tonic::Request::new(());
        req.metadata_mut()
            .insert("grpc-timeout", "1S".parse().unwrap());
        let req = interceptor.call(req).unwrap();
        assert_eq!(
            req.metadata().get("grpc-timeout").unwrap(),
            "1S".parse::<MetadataValue<Ascii>>().unwrap()
        );
    }
1411
1412    #[test]
1413    fn invalid_ascii_header_key() {
1414        let invalid_headers = {
1415            let mut h = HashMap::new();
1416            h.insert("x-binary-key-bin".to_owned(), "value".to_owned());
1417            h
1418        };
1419
1420        let result = parse_ascii_headers(invalid_headers);
1421        assert!(result.is_err());
1422        assert_eq!(
1423            result.err().unwrap().to_string(),
1424            "Invalid ASCII header key 'x-binary-key-bin': invalid gRPC metadata key name"
1425        );
1426    }
1427
1428    #[test]
1429    fn invalid_ascii_header_value() {
1430        let invalid_headers = {
1431            let mut h = HashMap::new();
1432            // Nul bytes are valid UTF-8, but not valid ascii gRPC headers:
1433            h.insert("x-ascii-key".to_owned(), "\x00value".to_owned());
1434            h
1435        };
1436
1437        let result = parse_ascii_headers(invalid_headers);
1438        assert!(result.is_err());
1439        assert_eq!(
1440            result.err().unwrap().to_string(),
1441            "Invalid ASCII header value for key 'x-ascii-key': failed to parse metadata value"
1442        );
1443    }
1444
1445    #[test]
1446    fn invalid_binary_header_key() {
1447        let invalid_headers = {
1448            let mut h = HashMap::new();
1449            h.insert("x-ascii-key".to_owned(), vec![1, 2, 3]);
1450            h
1451        };
1452
1453        let result = parse_binary_headers(invalid_headers);
1454        assert!(result.is_err());
1455        assert_eq!(
1456            result.err().unwrap().to_string(),
1457            "Invalid binary header key 'x-ascii-key': invalid gRPC metadata key name"
1458        );
1459    }
1460
1461    #[test]
1462    fn keep_alive_defaults() {
1463        let opts = ConnectionOptions::new(Url::parse("https://smolkitty").unwrap())
1464            .identity("enchicat".to_string())
1465            .client_name("cute-kitty".to_string())
1466            .client_version("0.1.0".to_string())
1467            .build();
1468        assert_eq!(
1469            opts.keep_alive.clone().unwrap().interval,
1470            ClientKeepAliveOptions::default().interval
1471        );
1472        assert_eq!(
1473            opts.keep_alive.clone().unwrap().timeout,
1474            ClientKeepAliveOptions::default().timeout
1475        );
1476
1477        // Can be explicitly set to None
1478        let opts = ConnectionOptions::new(Url::parse("https://smolkitty").unwrap())
1479            .identity("enchicat".to_string())
1480            .client_name("cute-kitty".to_string())
1481            .client_version("0.1.0".to_string())
1482            .keep_alive(None)
1483            .build();
1484        dbg!(&opts.keep_alive);
1485        assert!(opts.keep_alive.is_none());
1486    }
1487
    // Tests for the blanket `WorkflowClientTrait::list_workflows` stream, driven by an
    // in-memory mock of `list_workflow_executions`.
    mod list_workflows_tests {
        use super::*;
        use futures_util::{FutureExt, StreamExt};
        use std::sync::atomic::{AtomicUsize, Ordering};
        use temporalio_common::protos::temporal::api::common::v1::WorkflowExecution as ProtoWorkflowExecution;
        use tonic::{Request, Response};

        /// Fake client serving `total_workflows` synthetic executions `page_size` at a time,
        /// counting how many list RPCs it receives.
        #[derive(Clone)]
        struct MockListWorkflowsClient {
            // Incremented on every `list_workflow_executions` call; shared across clones.
            call_count: Arc<AtomicUsize>,
            // Returns this many workflows per page
            page_size: usize,
            // Total workflows available
            total_workflows: usize,
        }

        impl NamespacedClient for MockListWorkflowsClient {
            fn namespace(&self) -> String {
                "test-namespace".to_string()
            }
            fn identity(&self) -> String {
                "test-identity".to_string()
            }
        }

        impl WorkflowService for MockListWorkflowsClient {
            // Page tokens are the decimal offset of the next execution, as UTF-8 bytes; an
            // empty token means "start from the beginning" on input and "no more pages" on
            // output. Executions are named `wf-{i}` / `run-{i}` by their global index.
            fn list_workflow_executions(
                &mut self,
                request: Request<ListWorkflowExecutionsRequest>,
            ) -> futures_util::future::BoxFuture<
                '_,
                Result<Response<ListWorkflowExecutionsResponse>, tonic::Status>,
            > {
                self.call_count.fetch_add(1, Ordering::SeqCst);
                let req = request.into_inner();

                // Determine offset from page token
                let offset: usize = if req.next_page_token.is_empty() {
                    0
                } else {
                    String::from_utf8(req.next_page_token)
                        .unwrap()
                        .parse()
                        .unwrap()
                };

                let remaining = self.total_workflows.saturating_sub(offset);
                let count = remaining.min(self.page_size);
                let new_offset = offset + count;

                let executions: Vec<_> = (offset..offset + count)
                    .map(|i| workflow::WorkflowExecutionInfo {
                        execution: Some(ProtoWorkflowExecution {
                            workflow_id: format!("wf-{i}"),
                            run_id: format!("run-{i}"),
                        }),
                        r#type: Some(WorkflowType {
                            name: "TestWorkflow".to_string(),
                        }),
                        task_queue: "test-queue".to_string(),
                        ..Default::default()
                    })
                    .collect();

                // Only hand out a token while executions remain beyond this page.
                let next_page_token = if new_offset < self.total_workflows {
                    new_offset.to_string().into_bytes()
                } else {
                    vec![]
                };

                async move {
                    Ok(Response::new(ListWorkflowExecutionsResponse {
                        executions,
                        next_page_token,
                    }))
                }
                .boxed()
            }
        }

        #[tokio::test]
        async fn list_workflows_paginates_through_all_results() {
            let call_count = Arc::new(AtomicUsize::new(0));
            let client = MockListWorkflowsClient {
                call_count: call_count.clone(),
                page_size: 3,
                total_workflows: 10,
            };

            let stream = client.list_workflows("", WorkflowListOptions::default());
            let results: Vec<_> = stream.collect().await;

            // Every execution is yielded exactly once, in server order.
            assert_eq!(results.len(), 10);
            for (i, result) in results.iter().enumerate() {
                let wf = result.as_ref().unwrap();
                assert_eq!(wf.id(), format!("wf-{i}"));
                assert_eq!(wf.run_id(), format!("run-{i}"));
            }
            // Should have made 4 calls: pages of 3, 3, 3, 1
            assert_eq!(call_count.load(Ordering::SeqCst), 4);
        }

        #[tokio::test]
        async fn list_workflows_respects_limit() {
            let call_count = Arc::new(AtomicUsize::new(0));
            let client = MockListWorkflowsClient {
                call_count: call_count.clone(),
                page_size: 3,
                total_workflows: 10,
            };

            let opts = WorkflowListOptions::builder().limit(5).build();
            let stream = client.list_workflows("", opts);
            let results: Vec<_> = stream.collect().await;

            assert_eq!(results.len(), 5);
            for (i, result) in results.iter().enumerate() {
                let wf = result.as_ref().unwrap();
                assert_eq!(wf.id(), format!("wf-{i}"));
            }
            // Should have made 2 calls: 1 page of 3, then 2 more from next page
            assert_eq!(call_count.load(Ordering::SeqCst), 2);
        }

        #[tokio::test]
        async fn list_workflows_limit_less_than_page_size() {
            let call_count = Arc::new(AtomicUsize::new(0));
            let client = MockListWorkflowsClient {
                call_count: call_count.clone(),
                page_size: 10,
                total_workflows: 100,
            };

            let opts = WorkflowListOptions::builder().limit(3).build();
            let stream = client.list_workflows("", opts);
            let results: Vec<_> = stream.collect().await;

            assert_eq!(results.len(), 3);
            // Only 1 call needed since limit < page_size
            assert_eq!(call_count.load(Ordering::SeqCst), 1);
        }

        #[tokio::test]
        async fn list_workflows_empty_results() {
            let call_count = Arc::new(AtomicUsize::new(0));
            let client = MockListWorkflowsClient {
                call_count: call_count.clone(),
                page_size: 10,
                total_workflows: 0,
            };

            let stream = client.list_workflows("", WorkflowListOptions::default());
            let results: Vec<_> = stream.collect().await;

            // A single empty, token-less page ends the stream after one RPC.
            assert_eq!(results.len(), 0);
            assert_eq!(call_count.load(Ordering::SeqCst), 1);
        }
    }
1646}