Skip to main content

temporalio_client/
lib.rs

1#![warn(missing_docs)] // error if there are missing docs
2
3//! This crate contains client implementations that can be used to contact the Temporal service.
4//!
5//! It implements auto-retry behavior and metrics collection.
6
7#[macro_use]
8extern crate tracing;
9
10mod async_activity_handle;
11pub mod callback_based;
12mod dns;
13/// Configuration loading from environment variables and TOML files.
14#[cfg(feature = "envconfig")]
15pub mod envconfig;
16pub mod errors;
17pub mod grpc;
18mod metrics;
19mod options_structs;
20/// Visible only for tests
21#[doc(hidden)]
22pub mod proxy;
23mod replaceable;
24pub mod request_extensions;
25mod retry;
26/// Schedule operations: create, describe, update, pause, trigger, backfill, list, and delete.
27pub mod schedules;
28pub mod worker;
29mod workflow_handle;
30
31pub use crate::{
32    proxy::HttpConnectProxyOptions,
33    retry::{CallType, RETRYABLE_ERROR_CODES},
34};
35pub use async_activity_handle::{
36    ActivityHeartbeatResponse, ActivityIdentifier, AsyncActivityHandle,
37};
38
39pub use metrics::{LONG_REQUEST_LATENCY_HISTOGRAM_NAME, REQUEST_LATENCY_HISTOGRAM_NAME};
40pub use options_structs::*;
41pub use replaceable::SharedReplaceableClient;
42pub use retry::RetryOptions;
43/// Potentially dangerous TLS related functionality.
44pub mod danger {
45    /// Re-export the `ServerCertVerifier` trait so that users can implement custom TLS
46    /// server certificate verification without depending on `tokio-rustls` directly,
47    /// while explicitly acknowledging the danger in the import path.
48    pub use tokio_rustls::rustls::client::danger::ServerCertVerifier;
49}
50pub use tonic;
51pub use workflow_handle::{
52    UntypedQuery, UntypedSignal, UntypedUpdate, UntypedWorkflow, UntypedWorkflowHandle,
53    WorkflowExecutionDescription, WorkflowExecutionInfo, WorkflowExecutionResult, WorkflowHandle,
54    WorkflowHistory, WorkflowUpdateHandle,
55};
56
57use crate::{
58    grpc::{
59        AttachMetricLabels, CloudService, HealthService, OperatorService, TestService,
60        WorkflowService,
61    },
62    metrics::{ChannelOrGrpcOverride, GrpcMetricSvc, MetricsContext},
63    request_extensions::RequestExt,
64    worker::ClientWorkerSet,
65};
66use errors::*;
67use futures_util::{stream, stream::Stream};
68use http::Uri;
69use parking_lot::RwLock;
70use std::{
71    collections::{HashMap, VecDeque},
72    fmt::Debug,
73    pin::Pin,
74    str::FromStr,
75    sync::{Arc, OnceLock},
76    task::{Context, Poll},
77    time::{Duration, SystemTime},
78};
79use temporalio_common::{
80    HasWorkflowDefinition,
81    data_converters::{
82        DataConverter, GenericPayloadConverter, PayloadConverter, SerializationContext,
83        SerializationContextData,
84    },
85    protos::{
86        coresdk::IntoPayloadsExt,
87        grpc::health::v1::health_client::HealthClient,
88        proto_ts_to_system_time,
89        temporal::api::{
90            cloud::cloudservice::v1::cloud_service_client::CloudServiceClient,
91            common::v1::{Memo, Payload, SearchAttributes, WorkflowType},
92            enums::v1::{TaskQueueKind, WorkflowExecutionStatus},
93            errordetails::v1::WorkflowExecutionAlreadyStartedFailure,
94            operatorservice::v1::operator_service_client::OperatorServiceClient,
95            sdk::v1::UserMetadata,
96            taskqueue::v1::TaskQueue,
97            testservice::v1::test_service_client::TestServiceClient,
98            workflow::v1 as workflow,
99            workflowservice::v1::{
100                count_workflow_executions_response, workflow_service_client::WorkflowServiceClient,
101                *,
102            },
103        },
104        utilities::decode_status_detail,
105    },
106};
107use tonic::{
108    Code, IntoRequest,
109    body::Body,
110    client::GrpcService,
111    codec::CompressionEncoding,
112    codegen::InterceptedService,
113    metadata::{
114        AsciiMetadataKey, AsciiMetadataValue, BinaryMetadataKey, BinaryMetadataValue, MetadataMap,
115        MetadataValue,
116    },
117    service::Interceptor,
118    transport::{Certificate, Endpoint, Identity},
119};
120use tower::ServiceBuilder;
121use uuid::Uuid;
122
123static CLIENT_NAME_HEADER_KEY: &str = "client-name";
124static CLIENT_VERSION_HEADER_KEY: &str = "client-version";
125static TEMPORAL_NAMESPACE_HEADER_KEY: &str = "temporal-namespace";
126
127#[doc(hidden)]
128/// Key used to communicate when a GRPC message is too large
129pub static MESSAGE_TOO_LARGE_KEY: &str = "message-too-large";
130#[doc(hidden)]
131/// Key used to indicate a error was returned by the retryer because of the short-circuit predicate
132pub static ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT: &str = "short-circuit";
133
134/// The server times out polls after 60 seconds. Set our timeout to be slightly beyond that.
135const LONG_POLL_TIMEOUT: Duration = Duration::from_secs(70);
136const OTHER_CALL_TIMEOUT: Duration = Duration::from_secs(30);
137const VERSION: &str = env!("CARGO_PKG_VERSION");
138
/// A connection to the Temporal service.
///
/// Cloning a connection is cheap (single Arc increment). The underlying connection is shared
/// between clones.
///
/// Header state (API key and user headers) is shared as well: a change made through
/// `set_api_key`, `set_headers` or `set_binary_headers` on any clone is seen by all of them.
/// `retry_options_mut` and `identity_mut` copy the connection's settings on write, but even
/// after such a copy the header state and worker registry remain shared.
#[derive(Clone)]
pub struct Connection {
    // All state lives behind a single `Arc` so that cloning is a reference-count bump.
    inner: Arc<ConnectionInner>,
}
147
/// State behind a [`Connection`], shared by all of its clones.
///
/// `Clone` is derived so that `Arc::make_mut` can copy it when a per-clone setting (retry
/// options or identity) is mutated. Such a copy only bumps the reference counts of the
/// `Arc`-wrapped fields, so those remain shared with the other clones.
#[derive(Clone)]
struct ConnectionInner {
    /// Service clients built on the connection's transport (metrics layer + request
    /// interceptor).
    service: TemporalServiceClient,
    /// Retry settings associated with this connection.
    retry_options: RetryOptions,
    /// Identity of this connection, exposed via `Connection::identity`.
    identity: String,
    /// Header state shared with the request interceptor; replaced by `set_api_key`,
    /// `set_headers` and `set_binary_headers`.
    headers: Arc<RwLock<ClientHeaders>>,
    /// Value used for the `client-name` header.
    client_name: String,
    /// Value used for the `client-version` header.
    client_version: String,
    /// Capabilities as read from the `get_system_info` RPC call made on client connection
    capabilities: Option<get_system_info_response::Capabilities>,
    /// Registry of workers using this connection.
    workers: Arc<ClientWorkerSet>,
    /// Handle to the background DNS re-resolution task; `Some` only when DNS load balancing is
    /// in use. Kept alive for as long as any clone exists (presumably the task stops when the
    /// last handle is dropped — TODO confirm in `dns`).
    _dns_task: Option<Arc<dns::DnsReresolutionHandle>>,
}
161
162impl Connection {
163    /// Connect to a Temporal service.
164    pub async fn connect(options: ConnectionOptions) -> Result<Self, ClientConnectError> {
165        let dns_lb_opts = dns::validate_and_get_dns_lb(&options)?.cloned();
166        // The callback-based override transport cannot decode compressed request bodies, so
167        // compression is forced off whenever a service override is in use.
168        let compression = if options.service_override.is_some() {
169            GrpcCompression::None
170        } else {
171            options.grpc_compression
172        };
173        let (service, dns_task) = if let Some(service_override) = options.service_override {
174            (
175                GrpcMetricSvc {
176                    inner: ChannelOrGrpcOverride::GrpcOverride(service_override),
177                    metrics: options.metrics_meter.clone().map(MetricsContext::new),
178                    disable_errcode_label: options.disable_error_code_metric_tags,
179                },
180                None,
181            )
182        } else if let Some(dns_opts) = &dns_lb_opts {
183            let (channel, sender) = dns::create_balanced_channel(&options).await?;
184            let handle = dns::spawn_dns_reresolution(
185                sender,
186                options.target.clone(),
187                options.tls_options.clone(),
188                options.keep_alive.clone(),
189                options.override_origin.clone(),
190                dns_opts.resolution_interval,
191            );
192            (
193                ServiceBuilder::new()
194                    .layer_fn(move |channel| GrpcMetricSvc {
195                        inner: ChannelOrGrpcOverride::Channel(channel),
196                        metrics: options.metrics_meter.clone().map(MetricsContext::new),
197                        disable_errcode_label: options.disable_error_code_metric_tags,
198                    })
199                    .service(channel),
200                Some(handle),
201            )
202        } else {
203            let channel = Endpoint::from_shared(options.target.to_string())?;
204            let channel = add_tls_to_channel(options.tls_options.as_ref(), channel).await?;
205            let channel = if let Some(keep_alive) = options.keep_alive.as_ref() {
206                channel
207                    .keep_alive_while_idle(true)
208                    .http2_keep_alive_interval(keep_alive.interval)
209                    .keep_alive_timeout(keep_alive.timeout)
210            } else {
211                channel
212            };
213            let channel = if let Some(origin) = options.override_origin.clone() {
214                channel.origin(origin)
215            } else {
216                channel
217            };
218            // If there is a proxy, we have to connect that way
219            let channel = if let Some(proxy) = options.http_connect_proxy.as_ref() {
220                proxy.connect_endpoint(&channel).await?
221            } else {
222                channel.connect().await?
223            };
224            (
225                ServiceBuilder::new()
226                    .layer_fn(move |channel| GrpcMetricSvc {
227                        inner: ChannelOrGrpcOverride::Channel(channel),
228                        metrics: options.metrics_meter.clone().map(MetricsContext::new),
229                        disable_errcode_label: options.disable_error_code_metric_tags,
230                    })
231                    .service(channel),
232                None,
233            )
234        };
235
236        let headers = Arc::new(RwLock::new(ClientHeaders {
237            user_headers: parse_ascii_headers(options.headers.clone().unwrap_or_default())?,
238            user_binary_headers: parse_binary_headers(
239                options.binary_headers.clone().unwrap_or_default(),
240            )?,
241            api_key: options.api_key.clone(),
242        }));
243        let interceptor = ServiceCallInterceptor {
244            client_name: options.client_name.clone(),
245            client_version: options.client_version.clone(),
246            headers: headers.clone(),
247        };
248        let svc = InterceptedService::new(service, interceptor);
249        let mut svc_client = TemporalServiceClient::new(svc, compression);
250
251        let capabilities = if !options.skip_get_system_info {
252            match svc_client
253                .get_system_info(GetSystemInfoRequest::default().into_request())
254                .await
255            {
256                Ok(sysinfo) => sysinfo.into_inner().capabilities,
257                Err(status) => match status.code() {
258                    Code::Unimplemented => None,
259                    _ => return Err(ClientConnectError::SystemInfoCallError(status)),
260                },
261            }
262        } else {
263            None
264        };
265        Ok(Self {
266            inner: Arc::new(ConnectionInner {
267                service: svc_client,
268                retry_options: options.retry_options,
269                identity: options.identity,
270                headers,
271                client_name: options.client_name,
272                client_version: options.client_version,
273                capabilities,
274                workers: Arc::new(ClientWorkerSet::new()),
275                _dns_task: dns_task,
276            }),
277        })
278    }
279
280    /// Set API key, overwriting any previous one.
281    pub fn set_api_key(&self, api_key: Option<String>) {
282        self.inner.headers.write().api_key = api_key;
283    }
284
285    /// Set HTTP request headers overwriting previous headers.
286    ///
287    /// This will not affect headers set via [ConnectionOptions::binary_headers].
288    ///
289    /// # Errors
290    ///
291    /// Will return an error if any of the provided keys or values are not valid gRPC metadata.
292    /// If an error is returned, the previous headers will remain unchanged.
293    pub fn set_headers(&self, headers: HashMap<String, String>) -> Result<(), InvalidHeaderError> {
294        self.inner.headers.write().user_headers = parse_ascii_headers(headers)?;
295        Ok(())
296    }
297
298    /// Set binary HTTP request headers overwriting previous headers.
299    ///
300    /// This will not affect headers set via [ConnectionOptions::headers].
301    ///
302    /// # Errors
303    ///
304    /// Will return an error if any of the provided keys are not valid gRPC binary metadata keys.
305    /// If an error is returned, the previous headers will remain unchanged.
306    pub fn set_binary_headers(
307        &self,
308        binary_headers: HashMap<String, Vec<u8>>,
309    ) -> Result<(), InvalidHeaderError> {
310        self.inner.headers.write().user_binary_headers = parse_binary_headers(binary_headers)?;
311        Ok(())
312    }
313
314    /// Returns the value used for the `client-name` header by this connection.
315    pub fn client_name(&self) -> &str {
316        &self.inner.client_name
317    }
318
319    /// Returns the value used for the `client-version` header by this connection.
320    pub fn client_version(&self) -> &str {
321        &self.inner.client_version
322    }
323
324    /// Returns the server capabilities we (may have) learned about when establishing an initial
325    /// connection
326    pub fn capabilities(&self) -> Option<&get_system_info_response::Capabilities> {
327        self.inner.capabilities.as_ref()
328    }
329
330    /// Get a mutable reference to the retry options.
331    ///
332    /// Note: If this connection has been cloned, this will copy-on-write to avoid
333    /// affecting other clones.
334    pub fn retry_options_mut(&mut self) -> &mut RetryOptions {
335        &mut Arc::make_mut(&mut self.inner).retry_options
336    }
337
338    /// Get a reference to the connection identity.
339    pub fn identity(&self) -> &str {
340        &self.inner.identity
341    }
342
343    /// Get a mutable reference to the connection identity.
344    ///
345    /// Note: If this connection has been cloned, this will copy-on-write to avoid
346    /// affecting other clones.
347    pub fn identity_mut(&mut self) -> &mut String {
348        &mut Arc::make_mut(&mut self.inner).identity
349    }
350
351    /// Returns a reference to a registry with workers using this client instance.
352    pub fn workers(&self) -> Arc<ClientWorkerSet> {
353        self.inner.workers.clone()
354    }
355
356    /// Returns the client-wide key.
357    pub fn worker_grouping_key(&self) -> Uuid {
358        self.inner.workers.worker_grouping_key()
359    }
360
361    /// Get the underlying workflow service client for making raw gRPC calls.
362    pub fn workflow_service(&self) -> Box<dyn WorkflowService> {
363        self.inner.service.workflow_service()
364    }
365
366    /// Get the underlying operator service client for making raw gRPC calls.
367    pub fn operator_service(&self) -> Box<dyn OperatorService> {
368        self.inner.service.operator_service()
369    }
370
371    /// Get the underlying cloud service client for making raw gRPC calls.
372    pub fn cloud_service(&self) -> Box<dyn CloudService> {
373        self.inner.service.cloud_service()
374    }
375
376    /// Get the underlying test service client for making raw gRPC calls.
377    pub fn test_service(&self) -> Box<dyn TestService> {
378        self.inner.service.test_service()
379    }
380
381    /// Get the underlying health service client for making raw gRPC calls.
382    pub fn health_service(&self) -> Box<dyn HealthService> {
383        self.inner.service.health_service()
384    }
385}
386
387#[derive(Debug)]
388struct ClientHeaders {
389    user_headers: HashMap<AsciiMetadataKey, AsciiMetadataValue>,
390    user_binary_headers: HashMap<BinaryMetadataKey, BinaryMetadataValue>,
391    api_key: Option<String>,
392}
393
394impl ClientHeaders {
395    fn apply_to_metadata(&self, metadata: &mut MetadataMap) {
396        for (key, val) in self.user_headers.iter() {
397            // Only if not already present
398            if !metadata.contains_key(key) {
399                metadata.insert(key, val.clone());
400            }
401        }
402        for (key, val) in self.user_binary_headers.iter() {
403            // Only if not already present
404            if !metadata.contains_key(key) {
405                metadata.insert_bin(key, val.clone());
406            }
407        }
408        if let Some(api_key) = &self.api_key {
409            // Only if not already present
410            if !metadata.contains_key("authorization")
411                && let Ok(val) = format!("Bearer {api_key}").parse()
412            {
413                metadata.insert("authorization", val);
414            }
415        }
416    }
417}
418
419/// If TLS is configured, set the appropriate options on the provided channel and return it.
420/// Passes it through if TLS options not set.
421async fn add_tls_to_channel(
422    tls_options: Option<&TlsOptions>,
423    mut channel: Endpoint,
424) -> Result<Endpoint, ClientConnectError> {
425    if let Some(tls_cfg) = tls_options {
426        if tls_cfg.server_cert_verifier.is_some() && tls_cfg.server_root_ca_cert.is_some() {
427            return Err(ClientConnectError::InvalidConfig(
428                "Cannot set both `server_root_ca_cert` and `server_cert_verifier`".to_owned(),
429            ));
430        }
431
432        let mut tls = tonic::transport::ClientTlsConfig::new();
433
434        if tls_cfg.server_cert_verifier.is_none() {
435            if let Some(root_cert) = &tls_cfg.server_root_ca_cert {
436                let server_root_ca_cert = Certificate::from_pem(root_cert);
437                tls = tls.ca_certificate(server_root_ca_cert);
438            } else {
439                tls = tls.with_native_roots();
440            }
441        }
442
443        if let Some(domain) = &tls_cfg.domain {
444            tls = tls.domain_name(domain);
445
446            // This song and dance ultimately is just to make sure the `:authority` header ends
447            // up correct on requests while we use TLS. Setting the header directly in our
448            // interceptor doesn't work since seemingly it is overridden at some point by
449            // something lower level.
450            let uri: Uri = format!("https://{domain}").parse()?;
451            channel = channel.origin(uri);
452        }
453
454        if let Some(client_opts) = &tls_cfg.client_tls_options {
455            let client_identity =
456                Identity::from_pem(&client_opts.client_cert, &client_opts.client_private_key);
457            tls = tls.identity(client_identity);
458        }
459
460        return if let Some(verifier) = &tls_cfg.server_cert_verifier {
461            channel
462                .tls_config_with_verifier(tls, verifier.clone())
463                .map_err(Into::into)
464        } else {
465            channel.tls_config(tls).map_err(Into::into)
466        };
467    }
468    Ok(channel)
469}
470
471fn parse_ascii_headers(
472    headers: HashMap<String, String>,
473) -> Result<HashMap<AsciiMetadataKey, AsciiMetadataValue>, InvalidHeaderError> {
474    let mut parsed_headers = HashMap::with_capacity(headers.len());
475    for (k, v) in headers.into_iter() {
476        let key = match AsciiMetadataKey::from_str(&k) {
477            Ok(key) => key,
478            Err(err) => {
479                return Err(InvalidHeaderError::InvalidAsciiHeaderKey {
480                    key: k,
481                    source: err,
482                });
483            }
484        };
485        let value = match MetadataValue::from_str(&v) {
486            Ok(value) => value,
487            Err(err) => {
488                return Err(InvalidHeaderError::InvalidAsciiHeaderValue {
489                    key: k,
490                    value: v,
491                    source: err,
492                });
493            }
494        };
495        parsed_headers.insert(key, value);
496    }
497
498    Ok(parsed_headers)
499}
500
501fn parse_binary_headers(
502    headers: HashMap<String, Vec<u8>>,
503) -> Result<HashMap<BinaryMetadataKey, BinaryMetadataValue>, InvalidHeaderError> {
504    let mut parsed_headers = HashMap::with_capacity(headers.len());
505    for (k, v) in headers.into_iter() {
506        let key = match BinaryMetadataKey::from_str(&k) {
507            Ok(key) => key,
508            Err(err) => {
509                return Err(InvalidHeaderError::InvalidBinaryHeaderKey {
510                    key: k,
511                    source: err,
512                });
513            }
514        };
515        let value = BinaryMetadataValue::from_bytes(&v);
516        parsed_headers.insert(key, value);
517    }
518
519    Ok(parsed_headers)
520}
521
522/// Interceptor which attaches common metadata (like "client-name") to every outgoing call
523#[derive(Clone)]
524pub struct ServiceCallInterceptor {
525    client_name: String,
526    client_version: String,
527    /// Only accessed as a reader
528    headers: Arc<RwLock<ClientHeaders>>,
529}
530
531impl Interceptor for ServiceCallInterceptor {
532    /// This function will get called on each outbound request. Returning a `Status` here will
533    /// cancel the request and have that status returned to the caller.
534    fn call(
535        &mut self,
536        mut request: tonic::Request<()>,
537    ) -> Result<tonic::Request<()>, tonic::Status> {
538        let metadata = request.metadata_mut();
539        if !metadata.contains_key(CLIENT_NAME_HEADER_KEY) {
540            metadata.insert(
541                CLIENT_NAME_HEADER_KEY,
542                self.client_name
543                    .parse()
544                    .unwrap_or_else(|_| MetadataValue::from_static("")),
545            );
546        }
547        if !metadata.contains_key(CLIENT_VERSION_HEADER_KEY) {
548            metadata.insert(
549                CLIENT_VERSION_HEADER_KEY,
550                self.client_version
551                    .parse()
552                    .unwrap_or_else(|_| MetadataValue::from_static("")),
553            );
554        }
555        self.headers.read().apply_to_metadata(metadata);
556        request.set_default_timeout(OTHER_CALL_TIMEOUT);
557
558        Ok(request)
559    }
560}
561
/// Aggregates various services exposed by the Temporal server
///
/// Holds one boxed client per service. The aggregate derives `Clone`, so each boxed trait object
/// must itself be cloneable; the accessor methods hand out such clones.
#[derive(Clone)]
pub struct TemporalServiceClient {
    /// Client for the workflow service.
    workflow_svc_client: Box<dyn WorkflowService>,
    /// Client for the operator service.
    operator_svc_client: Box<dyn OperatorService>,
    /// Client for the cloud service.
    cloud_svc_client: Box<dyn CloudService>,
    /// Client for the test service.
    test_svc_client: Box<dyn TestService>,
    /// Client for the gRPC health service.
    health_svc_client: Box<dyn HealthService>,
}
571
/// Returns the maximum size, in bytes, of a message the client will decode from the server.
///
/// We up the limit on incoming messages from the server from the 4MB default to 128MiB. If for
/// whatever reason this needs to be changed, set the `TEMPORAL_MAX_INCOMING_GRPC_BYTES`
/// environment variable to a positive integer (surrounding whitespace is ignored).
///
/// The variable is read once, on first use, and the result is cached for the life of the
/// process. A value that is missing, not a number, or zero falls back to the default.
fn get_decode_max_size() -> usize {
    const DEFAULT_DECODE_MAX_SIZE: usize = 128 * 1024 * 1024;
    static DECODE_MAX_SIZE: OnceLock<usize> = OnceLock::new();
    *DECODE_MAX_SIZE.get_or_init(|| {
        std::env::var("TEMPORAL_MAX_INCOMING_GRPC_BYTES")
            .ok()
            .and_then(|raw| raw.trim().parse::<usize>().ok())
            // A limit of zero would reject every non-empty response message; treat it as unset.
            .filter(|&limit| limit > 0)
            .unwrap_or(DEFAULT_DECODE_MAX_SIZE)
    })
}
583
584impl TemporalServiceClient {
585    fn new<T>(svc: T, compression: GrpcCompression) -> Self
586    where
587        T: GrpcService<Body> + Send + Sync + Clone + 'static,
588        T::ResponseBody: tonic::codegen::Body<Data = tonic::codegen::Bytes> + Send + 'static,
589        T::Error: Into<tonic::codegen::StdError>,
590        <T::ResponseBody as tonic::codegen::Body>::Error: Into<tonic::codegen::StdError> + Send,
591        <T as GrpcService<Body>>::Future: Send,
592    {
593        // The generated service clients don't share a trait exposing the compression setters, so
594        // a macro applies the same configuration to each concrete client type.
595        macro_rules! configure {
596            ($client:expr) => {{
597                let client = $client.max_decoding_message_size(get_decode_max_size());
598                match compression {
599                    GrpcCompression::Gzip => client
600                        .send_compressed(CompressionEncoding::Gzip)
601                        .accept_compressed(CompressionEncoding::Gzip),
602                    GrpcCompression::None => client,
603                }
604            }};
605        }
606
607        let workflow_svc_client = Box::new(configure!(WorkflowServiceClient::new(svc.clone())));
608        let operator_svc_client = Box::new(configure!(OperatorServiceClient::new(svc.clone())));
609        let cloud_svc_client = Box::new(configure!(CloudServiceClient::new(svc.clone())));
610        let test_svc_client = Box::new(configure!(TestServiceClient::new(svc.clone())));
611        let health_svc_client = Box::new(configure!(HealthClient::new(svc.clone())));
612
613        Self {
614            workflow_svc_client,
615            operator_svc_client,
616            cloud_svc_client,
617            test_svc_client,
618            health_svc_client,
619        }
620    }
621
622    /// Create a service client from implementations of the individual underlying services. Useful
623    /// for mocking out service implementations.
624    pub fn from_services(
625        workflow: Box<dyn WorkflowService>,
626        operator: Box<dyn OperatorService>,
627        cloud: Box<dyn CloudService>,
628        test: Box<dyn TestService>,
629        health: Box<dyn HealthService>,
630    ) -> Self {
631        Self {
632            workflow_svc_client: workflow,
633            operator_svc_client: operator,
634            cloud_svc_client: cloud,
635            test_svc_client: test,
636            health_svc_client: health,
637        }
638    }
639
640    /// Get the underlying workflow service client
641    pub fn workflow_service(&self) -> Box<dyn WorkflowService> {
642        self.workflow_svc_client.clone()
643    }
644    /// Get the underlying operator service client
645    pub fn operator_service(&self) -> Box<dyn OperatorService> {
646        self.operator_svc_client.clone()
647    }
648    /// Get the underlying cloud service client
649    pub fn cloud_service(&self) -> Box<dyn CloudService> {
650        self.cloud_svc_client.clone()
651    }
652    /// Get the underlying test service client
653    pub fn test_service(&self) -> Box<dyn TestService> {
654        self.test_svc_client.clone()
655    }
656    /// Get the underlying health service client
657    pub fn health_service(&self) -> Box<dyn HealthService> {
658        self.health_svc_client.clone()
659    }
660}
661
662/// Contains an instance of a namespace-bound client for interacting with the Temporal server.
663/// Cheap to clone.
664#[derive(Clone)]
665pub struct Client {
666    connection: Connection,
667    options: Arc<ClientOptions>,
668}
669
670impl Client {
671    /// Create a new client from a connection and options.
672    ///
673    /// Currently infallible, but returns a `Result` for future extensibility
674    /// (e.g., interceptor or plugin validation).
675    pub fn new(connection: Connection, options: ClientOptions) -> Result<Self, ClientNewError> {
676        Ok(Client {
677            connection,
678            options: Arc::new(options),
679        })
680    }
681
682    /// Return the options this client was initialized with
683    pub fn options(&self) -> &ClientOptions {
684        &self.options
685    }
686
687    /// Return this client's options mutably.
688    ///
689    /// Note: If this client has been cloned, this will copy-on-write to avoid affecting other
690    /// clones.
691    pub fn options_mut(&mut self) -> &mut ClientOptions {
692        Arc::make_mut(&mut self.options)
693    }
694
695    /// Returns a reference to the underlying connection
696    pub fn connection(&self) -> &Connection {
697        &self.connection
698    }
699
700    /// Returns a mutable reference to the underlying connection
701    pub fn connection_mut(&mut self) -> &mut Connection {
702        &mut self.connection
703    }
704}
705
706// High-level workflow operations on Client.
707// These forward to the internal WorkflowClientTrait blanket impl which is
708// available because Client implements WorkflowService + NamespacedClient + Clone.
709impl Client {
710    /// Start a workflow execution.
711    ///
712    /// Returns a [`WorkflowHandle`] that can be used to interact with the workflow
713    /// (e.g., get its result, send signals, query, etc.).
714    pub async fn start_workflow<W>(
715        &self,
716        workflow: W,
717        input: W::Input,
718        options: WorkflowStartOptions,
719    ) -> Result<WorkflowHandle<Self, W>, WorkflowStartError>
720    where
721        W: HasWorkflowDefinition,
722        W::Input: Send,
723    {
724        WorkflowClientTrait::start_workflow(self, workflow, input, options).await
725    }
726
727    /// Get a handle to an existing workflow.
728    ///
729    /// For untyped access, use `get_workflow_handle::<UntypedWorkflow>(...)`.
730    pub fn get_workflow_handle<W: HasWorkflowDefinition>(
731        &self,
732        workflow_id: impl Into<String>,
733    ) -> WorkflowHandle<Self, W> {
734        WorkflowClientTrait::get_workflow_handle(self, workflow_id)
735    }
736
737    /// List workflows matching a query.
738    ///
739    /// Returns a stream that lazily paginates through results.
740    /// Use `limit` in options to cap the number of results returned.
741    pub fn list_workflows(
742        &self,
743        query: impl Into<String>,
744        opts: WorkflowListOptions,
745    ) -> ListWorkflowsStream {
746        WorkflowClientTrait::list_workflows(self, query, opts)
747    }
748
749    /// Count workflows matching a query.
750    pub async fn count_workflows(
751        &self,
752        query: impl Into<String>,
753        opts: WorkflowCountOptions,
754    ) -> Result<WorkflowExecutionCount, ClientError> {
755        WorkflowClientTrait::count_workflows(self, query, opts).await
756    }
757
758    /// Get a handle to complete an activity asynchronously.
759    ///
760    /// An activity returning `ActivityError::WillCompleteAsync` can be completed with this handle.
761    pub fn get_async_activity_handle(
762        &self,
763        identifier: ActivityIdentifier,
764    ) -> AsyncActivityHandle<Self> {
765        WorkflowClientTrait::get_async_activity_handle(self, identifier)
766    }
767}
768
769impl NamespacedClient for Client {
770    fn namespace(&self) -> String {
771        self.options.namespace.clone()
772    }
773
774    fn identity(&self) -> String {
775        self.connection.identity().to_owned()
776    }
777
778    fn data_converter(&self) -> &DataConverter {
779        &self.options.data_converter
780    }
781}
782
/// Enum to help reference a namespace by either the namespace name or the namespace id
///
/// Derives `Debug`, equality and hashing so it can be logged, compared, and used as a map key.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Namespace {
    /// Namespace name
    Name(String),
    /// Namespace id
    Id(String),
}
791
792impl Namespace {
793    /// Convert into grpc request
794    pub fn into_describe_namespace_request(self) -> DescribeNamespaceRequest {
795        let (namespace, id) = match self {
796            Namespace::Name(n) => (n, "".to_owned()),
797            Namespace::Id(n) => ("".to_owned(), n),
798        };
799        DescribeNamespaceRequest {
800            namespace,
801            id,
802            weak_consistency: false,
803        }
804    }
805}
806
/// This trait provides higher-level friendlier interaction with the server.
/// See the [WorkflowService] trait for a lower-level client.
///
/// The trait is crate-private. [`Client`] re-exposes each operation as an inherent method that
/// forwards here. Operations whose results are parameterized by `Self` (workflow and activity
/// handles) carry a `Self: Sized` bound.
pub(crate) trait WorkflowClientTrait: NamespacedClient {
    /// Start a workflow execution.
    ///
    /// `input` is the workflow's typed input. On success, the returned future resolves to a
    /// [`WorkflowHandle`] bound to this client.
    fn start_workflow<W>(
        &self,
        workflow: W,
        input: W::Input,
        options: WorkflowStartOptions,
    ) -> impl Future<Output = Result<WorkflowHandle<Self, W>, WorkflowStartError>>
    where
        Self: Sized,
        W: HasWorkflowDefinition,
        W::Input: Send;

    /// Get a handle to an existing workflow identified by `workflow_id`.
    ///
    /// No run id is taken here, so the handle is not pinned to a specific run (presumably it
    /// targets the most recent execution with that id — confirm in `WorkflowHandle`).
    ///
    /// For untyped access, use `get_workflow_handle::<UntypedWorkflow>(...)`.
    ///
    /// See also [WorkflowHandle::new], for specifying namespace or first_execution_run_id.
    fn get_workflow_handle<W: HasWorkflowDefinition>(
        &self,
        workflow_id: impl Into<String>,
    ) -> WorkflowHandle<Self, W>
    where
        Self: Sized;

    /// List workflows matching a query.
    /// Returns a stream that lazily paginates through results.
    /// Use `limit` in options to cap the number of results returned.
    fn list_workflows(
        &self,
        query: impl Into<String>,
        opts: WorkflowListOptions,
    ) -> ListWorkflowsStream;

    /// Count workflows matching a query.
    fn count_workflows(
        &self,
        query: impl Into<String>,
        opts: WorkflowCountOptions,
    ) -> impl Future<Output = Result<WorkflowExecutionCount, ClientError>>;

    /// Get a handle to complete an activity asynchronously.
    ///
    /// An activity returning `ActivityError::WillCompleteAsync` can be completed with this handle.
    fn get_async_activity_handle(
        &self,
        identifier: ActivityIdentifier,
    ) -> AsyncActivityHandle<Self>
    where
        Self: Sized;
}
861
862/// A client that is bound to a namespace
863pub trait NamespacedClient {
864    /// Returns the namespace this client is bound to
865    fn namespace(&self) -> String;
866    /// Returns the client identity
867    fn identity(&self) -> String;
868    /// Returns the data converter for serializing/deserializing payloads.
869    /// Default implementation returns a static default converter.
870    fn data_converter(&self) -> &DataConverter {
871        static DEFAULT: OnceLock<DataConverter> = OnceLock::new();
872        DEFAULT.get_or_init(DataConverter::default)
873    }
874}
875
876/// A workflow execution returned from list operations.
877/// This represents information about a workflow present in visibility.
878#[derive(Debug, Clone)]
879pub struct WorkflowExecution {
880    raw: workflow::WorkflowExecutionInfo,
881}
882
883impl WorkflowExecution {
884    /// Create a new WorkflowExecution from the raw proto.
885    pub fn new(raw: workflow::WorkflowExecutionInfo) -> Self {
886        Self { raw }
887    }
888
889    /// The workflow ID.
890    pub fn id(&self) -> &str {
891        self.raw
892            .execution
893            .as_ref()
894            .map(|e| e.workflow_id.as_str())
895            .unwrap_or("")
896    }
897
898    /// The run ID.
899    pub fn run_id(&self) -> &str {
900        self.raw
901            .execution
902            .as_ref()
903            .map(|e| e.run_id.as_str())
904            .unwrap_or("")
905    }
906
907    /// The workflow type name.
908    pub fn workflow_type(&self) -> &str {
909        self.raw
910            .r#type
911            .as_ref()
912            .map(|t| t.name.as_str())
913            .unwrap_or("")
914    }
915
916    /// The current status of the workflow execution.
917    pub fn status(&self) -> WorkflowExecutionStatus {
918        self.raw.status()
919    }
920
921    /// When the workflow was created.
922    pub fn start_time(&self) -> Option<SystemTime> {
923        self.raw
924            .start_time
925            .as_ref()
926            .and_then(proto_ts_to_system_time)
927    }
928
929    /// When the workflow run started or should start.
930    pub fn execution_time(&self) -> Option<SystemTime> {
931        self.raw
932            .execution_time
933            .as_ref()
934            .and_then(proto_ts_to_system_time)
935    }
936
937    /// When the workflow was closed, if closed.
938    pub fn close_time(&self) -> Option<SystemTime> {
939        self.raw
940            .close_time
941            .as_ref()
942            .and_then(proto_ts_to_system_time)
943    }
944
945    /// The task queue the workflow runs on.
946    pub fn task_queue(&self) -> &str {
947        &self.raw.task_queue
948    }
949
950    /// Number of events in history.
951    pub fn history_length(&self) -> i64 {
952        self.raw.history_length
953    }
954
955    /// Workflow memo.
956    pub fn memo(&self) -> Option<&Memo> {
957        self.raw.memo.as_ref()
958    }
959
960    /// Parent workflow ID, if this is a child workflow.
961    pub fn parent_id(&self) -> Option<&str> {
962        self.raw
963            .parent_execution
964            .as_ref()
965            .map(|e| e.workflow_id.as_str())
966    }
967
968    /// Parent run ID, if this is a child workflow.
969    pub fn parent_run_id(&self) -> Option<&str> {
970        self.raw
971            .parent_execution
972            .as_ref()
973            .map(|e| e.run_id.as_str())
974    }
975
976    /// Search attributes on the workflow.
977    pub fn search_attributes(&self) -> Option<&SearchAttributes> {
978        self.raw.search_attributes.as_ref()
979    }
980
981    /// Access the raw proto for additional fields not exposed via accessors.
982    pub fn raw(&self) -> &workflow::WorkflowExecutionInfo {
983        &self.raw
984    }
985
986    /// Consume the wrapper and return the raw proto.
987    pub fn into_raw(self) -> workflow::WorkflowExecutionInfo {
988        self.raw
989    }
990}
991
992impl From<workflow::WorkflowExecutionInfo> for WorkflowExecution {
993    fn from(raw: workflow::WorkflowExecutionInfo) -> Self {
994        Self::new(raw)
995    }
996}
997
998/// A stream of workflow executions from a list query.
999/// Internally paginates through results from the server.
1000pub struct ListWorkflowsStream {
1001    inner: Pin<Box<dyn Stream<Item = Result<WorkflowExecution, ClientError>> + Send>>,
1002}
1003
1004impl ListWorkflowsStream {
1005    fn new(
1006        inner: Pin<Box<dyn Stream<Item = Result<WorkflowExecution, ClientError>> + Send>>,
1007    ) -> Self {
1008        Self { inner }
1009    }
1010}
1011
1012impl Stream for ListWorkflowsStream {
1013    type Item = Result<WorkflowExecution, ClientError>;
1014
1015    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1016        self.inner.as_mut().poll_next(cx)
1017    }
1018}
1019
1020/// Result of a workflow count operation.
1021///
1022/// If the query includes a group-by clause, `groups` will contain the aggregated
1023/// counts and `count` will be the sum of all group counts.
1024#[derive(Debug, Clone)]
1025pub struct WorkflowExecutionCount {
1026    count: usize,
1027    groups: Vec<WorkflowCountAggregationGroup>,
1028}
1029
1030impl WorkflowExecutionCount {
1031    pub(crate) fn from_response(resp: CountWorkflowExecutionsResponse) -> Self {
1032        Self {
1033            count: resp.count as usize,
1034            groups: resp
1035                .groups
1036                .into_iter()
1037                .map(WorkflowCountAggregationGroup::from_proto)
1038                .collect(),
1039        }
1040    }
1041
1042    /// The approximate number of workflows matching the query.
1043    /// If grouping was applied, this is the sum of all group counts.
1044    pub fn count(&self) -> usize {
1045        self.count
1046    }
1047
1048    /// The groups if the query had a group-by clause, or empty if not.
1049    pub fn groups(&self) -> &[WorkflowCountAggregationGroup] {
1050        &self.groups
1051    }
1052}
1053
1054/// Aggregation group from a workflow count query with a group-by clause.
1055#[derive(Debug, Clone)]
1056pub struct WorkflowCountAggregationGroup {
1057    group_values: Vec<Payload>,
1058    count: usize,
1059}
1060
1061impl WorkflowCountAggregationGroup {
1062    fn from_proto(proto: count_workflow_executions_response::AggregationGroup) -> Self {
1063        Self {
1064            group_values: proto.group_values,
1065            count: proto.count as usize,
1066        }
1067    }
1068
1069    /// The search attribute values for this group.
1070    pub fn group_values(&self) -> &[Payload] {
1071        &self.group_values
1072    }
1073
1074    /// The approximate number of workflows matching for this group.
1075    pub fn count(&self) -> usize {
1076        self.count
1077    }
1078}
1079
1080impl<T> WorkflowClientTrait for T
1081where
1082    T: WorkflowService + NamespacedClient + Clone + Send + Sync + 'static,
1083{
1084    async fn start_workflow<W>(
1085        &self,
1086        workflow: W,
1087        input: W::Input,
1088        options: WorkflowStartOptions,
1089    ) -> Result<WorkflowHandle<Self, W>, WorkflowStartError>
1090    where
1091        W: HasWorkflowDefinition,
1092        W::Input: Send,
1093    {
1094        let payloads = self
1095            .data_converter()
1096            .to_payloads(&SerializationContextData::Workflow, &input)
1097            .await?;
1098        let namespace = self.namespace();
1099        let workflow_id = options.workflow_id.clone();
1100        let task_queue_name = options.task_queue.clone();
1101
1102        let user_metadata = if options.static_summary.is_some() || options.static_details.is_some()
1103        {
1104            let payload_converter = PayloadConverter::default();
1105            let context = SerializationContext {
1106                data: &SerializationContextData::Workflow,
1107                converter: &payload_converter,
1108            };
1109            Some(UserMetadata {
1110                summary: options.static_summary.map(|s| {
1111                    payload_converter
1112                        .to_payload(&context, &s)
1113                        .expect("String-to-JSON payload serialization is infallible")
1114                }),
1115                details: options.static_details.map(|s| {
1116                    payload_converter
1117                        .to_payload(&context, &s)
1118                        .expect("String-to-JSON payload serialization is infallible")
1119                }),
1120            })
1121        } else {
1122            None
1123        };
1124
1125        let run_id = if let Some(start_signal) = options.start_signal {
1126            // Use signal-with-start when a start_signal is provided
1127            let res = WorkflowService::signal_with_start_workflow_execution(
1128                &mut self.clone(),
1129                SignalWithStartWorkflowExecutionRequest {
1130                    namespace: namespace.clone(),
1131                    workflow_id: workflow_id.clone(),
1132                    workflow_type: Some(WorkflowType {
1133                        name: workflow.name().to_string(),
1134                    }),
1135                    task_queue: Some(TaskQueue {
1136                        name: task_queue_name,
1137                        kind: TaskQueueKind::Normal as i32,
1138                        normal_name: "".to_string(),
1139                    }),
1140                    input: payloads.into_payloads(),
1141                    signal_name: start_signal.signal_name,
1142                    signal_input: start_signal.input,
1143                    identity: self.identity(),
1144                    request_id: Uuid::new_v4().to_string(),
1145                    workflow_id_reuse_policy: options.id_reuse_policy as i32,
1146                    workflow_id_conflict_policy: options.id_conflict_policy as i32,
1147                    workflow_execution_timeout: options
1148                        .execution_timeout
1149                        .and_then(|d| d.try_into().ok()),
1150                    workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()),
1151                    workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()),
1152                    search_attributes: options.search_attributes.map(|d| d.into()),
1153                    cron_schedule: options.cron_schedule.unwrap_or_default(),
1154                    header: options.header.or(start_signal.header),
1155                    user_metadata,
1156                    ..Default::default()
1157                }
1158                .into_request(),
1159            )
1160            .await?
1161            .into_inner();
1162            res.run_id
1163        } else {
1164            // Normal start workflow
1165            let res = self
1166                .clone()
1167                .start_workflow_execution(
1168                    StartWorkflowExecutionRequest {
1169                        namespace: namespace.clone(),
1170                        input: payloads.into_payloads(),
1171                        workflow_id: workflow_id.clone(),
1172                        workflow_type: Some(WorkflowType {
1173                            name: workflow.name().to_string(),
1174                        }),
1175                        task_queue: Some(TaskQueue {
1176                            name: task_queue_name,
1177                            kind: TaskQueueKind::Unspecified as i32,
1178                            normal_name: "".to_string(),
1179                        }),
1180                        request_id: Uuid::new_v4().to_string(),
1181                        workflow_id_reuse_policy: options.id_reuse_policy as i32,
1182                        workflow_id_conflict_policy: options.id_conflict_policy as i32,
1183                        workflow_execution_timeout: options
1184                            .execution_timeout
1185                            .and_then(|d| d.try_into().ok()),
1186                        workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()),
1187                        workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()),
1188                        search_attributes: options.search_attributes.map(|d| d.into()),
1189                        cron_schedule: options.cron_schedule.unwrap_or_default(),
1190                        request_eager_execution: options.enable_eager_workflow_start,
1191                        retry_policy: options.retry_policy,
1192                        links: options.links,
1193                        completion_callbacks: options.completion_callbacks,
1194                        priority: Some(options.priority.into()),
1195                        header: options.header,
1196                        user_metadata,
1197                        ..Default::default()
1198                    }
1199                    .into_request(),
1200                )
1201                .await
1202                .map_err(|status| {
1203                    if status.code() == Code::AlreadyExists {
1204                        let run_id =
1205                            decode_status_detail::<WorkflowExecutionAlreadyStartedFailure>(
1206                                status.details(),
1207                            )
1208                            .map(|f| f.run_id);
1209                        WorkflowStartError::AlreadyStarted {
1210                            run_id,
1211                            source: status,
1212                        }
1213                    } else {
1214                        WorkflowStartError::Rpc(status)
1215                    }
1216                })?
1217                .into_inner();
1218            res.run_id
1219        };
1220
1221        Ok(WorkflowHandle::new(
1222            self.clone(),
1223            WorkflowExecutionInfo {
1224                namespace,
1225                workflow_id,
1226                run_id: Some(run_id.clone()),
1227                first_execution_run_id: Some(run_id),
1228            },
1229        ))
1230    }
1231
1232    fn get_workflow_handle<W: HasWorkflowDefinition>(
1233        &self,
1234        workflow_id: impl Into<String>,
1235    ) -> WorkflowHandle<Self, W>
1236    where
1237        Self: Sized,
1238    {
1239        WorkflowHandle::new(
1240            self.clone(),
1241            WorkflowExecutionInfo {
1242                namespace: self.namespace(),
1243                workflow_id: workflow_id.into(),
1244                run_id: None,
1245                first_execution_run_id: None,
1246            },
1247        )
1248    }
1249
1250    fn list_workflows(
1251        &self,
1252        query: impl Into<String>,
1253        opts: WorkflowListOptions,
1254    ) -> ListWorkflowsStream {
1255        let client = self.clone();
1256        let namespace = self.namespace();
1257        let query = query.into();
1258        let limit = opts.limit;
1259
1260        // State: (next_page_token, buffer, yielded_count, exhausted)
1261        let initial_state = (Vec::new(), VecDeque::new(), 0, false);
1262
1263        let stream = stream::unfold(
1264            initial_state,
1265            move |(next_page_token, mut buffer, mut yielded, exhausted)| {
1266                let mut client = client.clone();
1267                let namespace = namespace.clone();
1268                let query = query.clone();
1269
1270                async move {
1271                    if let Some(l) = limit
1272                        && yielded >= l
1273                    {
1274                        return None;
1275                    }
1276
1277                    if let Some(exec) = buffer.pop_front() {
1278                        yielded += 1;
1279                        return Some((Ok(exec), (next_page_token, buffer, yielded, exhausted)));
1280                    }
1281
1282                    if exhausted {
1283                        return None;
1284                    }
1285
1286                    let response = WorkflowService::list_workflow_executions(
1287                        &mut client,
1288                        ListWorkflowExecutionsRequest {
1289                            namespace,
1290                            page_size: 0, // Use server default
1291                            next_page_token: next_page_token.clone(),
1292                            query,
1293                        }
1294                        .into_request(),
1295                    )
1296                    .await;
1297
1298                    match response {
1299                        Ok(resp) => {
1300                            let resp = resp.into_inner();
1301                            let new_exhausted = resp.next_page_token.is_empty();
1302                            let new_token = resp.next_page_token;
1303
1304                            buffer = resp
1305                                .executions
1306                                .into_iter()
1307                                .map(WorkflowExecution::from)
1308                                .collect();
1309
1310                            if let Some(exec) = buffer.pop_front() {
1311                                yielded += 1;
1312                                Some((Ok(exec), (new_token, buffer, yielded, new_exhausted)))
1313                            } else {
1314                                None
1315                            }
1316                        }
1317                        Err(e) => Some((Err(e.into()), (next_page_token, buffer, yielded, true))),
1318                    }
1319                }
1320            },
1321        );
1322
1323        ListWorkflowsStream::new(Box::pin(stream))
1324    }
1325
1326    async fn count_workflows(
1327        &self,
1328        query: impl Into<String>,
1329        _opts: WorkflowCountOptions,
1330    ) -> Result<WorkflowExecutionCount, ClientError> {
1331        let resp = WorkflowService::count_workflow_executions(
1332            &mut self.clone(),
1333            CountWorkflowExecutionsRequest {
1334                namespace: self.namespace(),
1335                query: query.into(),
1336            }
1337            .into_request(),
1338        )
1339        .await?
1340        .into_inner();
1341
1342        Ok(WorkflowExecutionCount::from_response(resp))
1343    }
1344
1345    fn get_async_activity_handle(&self, identifier: ActivityIdentifier) -> AsyncActivityHandle<Self>
1346    where
1347        Self: Sized,
1348    {
1349        AsyncActivityHandle::new(self.clone(), identifier)
1350    }
1351}
1352
/// Log an error via `tracing` and, in debug builds, also panic with the same message.
///
/// In release builds `debug_assert!` is compiled out, so only the error event is emitted.
/// The expansion is a single block evaluating to `()`. This keeps it usable in both
/// statement and expression position (e.g. as a `match` arm), and it no longer injects a
/// `use tracing::error;` into the caller's scope.
macro_rules! dbg_panic {
    ($($arg:tt)*) => {{
        ::tracing::error!($($arg)*);
        debug_assert!(false, $($arg)*);
    }};
}
pub(crate) use dbg_panic;
1361
1362#[cfg(test)]
1363mod tests {
1364    use super::*;
1365    use tonic::metadata::Ascii;
1366    use url::Url;
1367
1368    #[test]
1369    fn applies_headers() {
1370        // Initial header set
1371        let headers = Arc::new(RwLock::new(ClientHeaders {
1372            user_headers: HashMap::new(),
1373            user_binary_headers: HashMap::new(),
1374            api_key: Some("my-api-key".to_owned()),
1375        }));
1376        headers.clone().write().user_headers.insert(
1377            "my-meta-key".parse().unwrap(),
1378            "my-meta-val".parse().unwrap(),
1379        );
1380        headers.clone().write().user_binary_headers.insert(
1381            "my-bin-meta-key-bin".parse().unwrap(),
1382            vec![1, 2, 3].try_into().unwrap(),
1383        );
1384        let mut interceptor = ServiceCallInterceptor {
1385            client_name: "cute-kitty".to_string(),
1386            client_version: "0.1.0".to_string(),
1387            headers: headers.clone(),
1388        };
1389
1390        // Confirm on metadata
1391        let req = interceptor.call(tonic::Request::new(())).unwrap();
1392        assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val");
1393        assert_eq!(
1394            req.metadata().get("authorization").unwrap(),
1395            "Bearer my-api-key"
1396        );
1397        assert_eq!(
1398            req.metadata().get_bin("my-bin-meta-key-bin").unwrap(),
1399            vec![1, 2, 3].as_slice()
1400        );
1401
1402        // Overwrite at request time
1403        let mut req = tonic::Request::new(());
1404        req.metadata_mut()
1405            .insert("my-meta-key", "my-meta-val2".parse().unwrap());
1406        req.metadata_mut()
1407            .insert("authorization", "my-api-key2".parse().unwrap());
1408        req.metadata_mut()
1409            .insert_bin("my-bin-meta-key-bin", vec![4, 5, 6].try_into().unwrap());
1410        let req = interceptor.call(req).unwrap();
1411        assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val2");
1412        assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key2");
1413        assert_eq!(
1414            req.metadata().get_bin("my-bin-meta-key-bin").unwrap(),
1415            vec![4, 5, 6].as_slice()
1416        );
1417
1418        // Overwrite auth on header
1419        headers.clone().write().user_headers.insert(
1420            "authorization".parse().unwrap(),
1421            "my-api-key3".parse().unwrap(),
1422        );
1423        let req = interceptor.call(tonic::Request::new(())).unwrap();
1424        assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val");
1425        assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key3");
1426
1427        // Remove headers and auth and confirm gone
1428        headers.clone().write().user_headers.clear();
1429        headers.clone().write().user_binary_headers.clear();
1430        headers.clone().write().api_key.take();
1431        let req = interceptor.call(tonic::Request::new(())).unwrap();
1432        assert!(!req.metadata().contains_key("my-meta-key"));
1433        assert!(!req.metadata().contains_key("authorization"));
1434        assert!(!req.metadata().contains_key("my-bin-meta-key-bin"));
1435
1436        // Timeout header not overriden
1437        let mut req = tonic::Request::new(());
1438        req.metadata_mut()
1439            .insert("grpc-timeout", "1S".parse().unwrap());
1440        let req = interceptor.call(req).unwrap();
1441        assert_eq!(
1442            req.metadata().get("grpc-timeout").unwrap(),
1443            "1S".parse::<MetadataValue<Ascii>>().unwrap()
1444        );
1445    }
1446
1447    #[test]
1448    fn invalid_ascii_header_key() {
1449        let invalid_headers = {
1450            let mut h = HashMap::new();
1451            h.insert("x-binary-key-bin".to_owned(), "value".to_owned());
1452            h
1453        };
1454
1455        let result = parse_ascii_headers(invalid_headers);
1456        assert!(result.is_err());
1457        assert_eq!(
1458            result.err().unwrap().to_string(),
1459            "Invalid ASCII header key 'x-binary-key-bin': invalid gRPC metadata key name"
1460        );
1461    }
1462
1463    #[test]
1464    fn invalid_ascii_header_value() {
1465        let invalid_headers = {
1466            let mut h = HashMap::new();
1467            // Nul bytes are valid UTF-8, but not valid ascii gRPC headers:
1468            h.insert("x-ascii-key".to_owned(), "\x00value".to_owned());
1469            h
1470        };
1471
1472        let result = parse_ascii_headers(invalid_headers);
1473        assert!(result.is_err());
1474        assert_eq!(
1475            result.err().unwrap().to_string(),
1476            "Invalid ASCII header value for key 'x-ascii-key': failed to parse metadata value"
1477        );
1478    }
1479
1480    #[test]
1481    fn invalid_binary_header_key() {
1482        let invalid_headers = {
1483            let mut h = HashMap::new();
1484            h.insert("x-ascii-key".to_owned(), vec![1, 2, 3]);
1485            h
1486        };
1487
1488        let result = parse_binary_headers(invalid_headers);
1489        assert!(result.is_err());
1490        assert_eq!(
1491            result.err().unwrap().to_string(),
1492            "Invalid binary header key 'x-ascii-key': invalid gRPC metadata key name"
1493        );
1494    }
1495
1496    #[test]
1497    fn keep_alive_defaults() {
1498        let opts = ConnectionOptions::new(Url::parse("https://smolkitty").unwrap())
1499            .identity("enchicat".to_string())
1500            .client_name("cute-kitty".to_string())
1501            .client_version("0.1.0".to_string())
1502            .build();
1503        assert_eq!(
1504            opts.keep_alive.clone().unwrap().interval,
1505            ClientKeepAliveOptions::default().interval
1506        );
1507        assert_eq!(
1508            opts.keep_alive.clone().unwrap().timeout,
1509            ClientKeepAliveOptions::default().timeout
1510        );
1511
1512        // Can be explicitly set to None
1513        let opts = ConnectionOptions::new(Url::parse("https://smolkitty").unwrap())
1514            .identity("enchicat".to_string())
1515            .client_name("cute-kitty".to_string())
1516            .client_version("0.1.0".to_string())
1517            .keep_alive(None)
1518            .build();
1519        dbg!(&opts.keep_alive);
1520        assert!(opts.keep_alive.is_none());
1521    }
1522
    /// Tests for `add_tls_to_channel` when `TlsOptions::server_cert_verifier` is used.
    mod tls_custom_verifier_tests {
        use super::*;
        use tokio_rustls::rustls::{
            DigitallySignedStruct, Error as RustlsError, SignatureScheme,
            client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
            pki_types::{CertificateDer, ServerName, UnixTime},
        };

        /// A minimal mock verifier for testing. In production, users would
        /// implement real certificate pinning or custom validation here.
        ///
        /// Every verification method unconditionally succeeds; it must never be used outside
        /// tests.
        #[derive(Debug)]
        struct MockVerifier;

        impl ServerCertVerifier for MockVerifier {
            // Accepts any certificate chain for any server name.
            fn verify_server_cert(
                &self,
                _end_entity: &CertificateDer<'_>,
                _intermediates: &[CertificateDer<'_>],
                _server_name: &ServerName<'_>,
                _ocsp_response: &[u8],
                _now: UnixTime,
            ) -> Result<ServerCertVerified, RustlsError> {
                Ok(ServerCertVerified::assertion())
            }

            // Accepts any TLS 1.2 handshake signature.
            fn verify_tls12_signature(
                &self,
                _message: &[u8],
                _cert: &CertificateDer<'_>,
                _dss: &DigitallySignedStruct,
            ) -> Result<HandshakeSignatureValid, RustlsError> {
                Ok(HandshakeSignatureValid::assertion())
            }

            // Accepts any TLS 1.3 handshake signature.
            fn verify_tls13_signature(
                &self,
                _message: &[u8],
                _cert: &CertificateDer<'_>,
                _dss: &DigitallySignedStruct,
            ) -> Result<HandshakeSignatureValid, RustlsError> {
                Ok(HandshakeSignatureValid::assertion())
            }

            // Advertises a small fixed set of schemes; enough for configuration to succeed.
            fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
                vec![
                    SignatureScheme::ECDSA_NISTP256_SHA256,
                    SignatureScheme::RSA_PSS_SHA256,
                ]
            }
        }

        /// A custom verifier alone (no CA cert) is a valid TLS configuration.
        /// Only channel configuration is checked here; no connection is attempted.
        #[tokio::test]
        async fn add_tls_to_channel_with_custom_verifier() {
            let tls_opts = TlsOptions {
                server_cert_verifier: Some(Arc::new(MockVerifier)),
                domain: Some("test.temporal.io".to_string()),
                ..Default::default()
            };
            let endpoint = tonic::transport::Channel::from_static("https://test.temporal.io:7233");
            let result = add_tls_to_channel(Some(&tls_opts), endpoint).await;
            assert!(
                result.is_ok(),
                "add_tls_to_channel should succeed with a custom verifier: {:?}",
                result.err()
            );
        }

        /// A custom verifier and a root CA cert are mutually exclusive.
        #[tokio::test]
        async fn add_tls_to_channel_with_verifier_and_ca_cert_fails() {
            // When both server_cert_verifier and server_root_ca_cert are set,
            // add_tls_to_channel should fail with InvalidConfig.
            let tls_opts = TlsOptions {
                server_root_ca_cert: Some(b"some-ca-cert-bytes".to_vec()),
                server_cert_verifier: Some(Arc::new(MockVerifier)),
                domain: Some("test.temporal.io".to_string()),
                ..Default::default()
            };
            let endpoint = tonic::transport::Channel::from_static("https://test.temporal.io:7233");
            let result = add_tls_to_channel(Some(&tls_opts), endpoint).await;
            assert!(
                matches!(result, Err(ClientConnectError::InvalidConfig(_))),
                "add_tls_to_channel should fail with InvalidConfig when both CA cert and verifier are set: {:?}",
                result
            );
        }

        /// Without a custom verifier, configuration still succeeds.
        #[tokio::test]
        async fn add_tls_to_channel_without_verifier_still_works() {
            // Regression test: the original PEM path must still work.
            let tls_opts = TlsOptions {
                domain: Some("test.temporal.io".to_string()),
                ..Default::default()
            };
            let endpoint = tonic::transport::Channel::from_static("https://test.temporal.io:7233");
            let result = add_tls_to_channel(Some(&tls_opts), endpoint).await;
            assert!(
                result.is_ok(),
                "add_tls_to_channel should succeed without a verifier (native roots): {:?}",
                result.err()
            );
        }
    }
1625
1626    mod list_workflows_tests {
1627        use super::*;
1628        use futures_util::{FutureExt, StreamExt};
1629        use std::sync::atomic::{AtomicUsize, Ordering};
1630        use temporalio_common::protos::temporal::api::common::v1::WorkflowExecution as ProtoWorkflowExecution;
1631        use tonic::{Request, Response};
1632
1633        #[derive(Clone)]
1634        struct MockListWorkflowsClient {
1635            call_count: Arc<AtomicUsize>,
1636            // Returns this many workflows per page
1637            page_size: usize,
1638            // Total workflows available
1639            total_workflows: usize,
1640        }
1641
1642        impl NamespacedClient for MockListWorkflowsClient {
1643            fn namespace(&self) -> String {
1644                "test-namespace".to_string()
1645            }
1646            fn identity(&self) -> String {
1647                "test-identity".to_string()
1648            }
1649        }
1650
1651        impl WorkflowService for MockListWorkflowsClient {
1652            fn list_workflow_executions(
1653                &mut self,
1654                request: Request<ListWorkflowExecutionsRequest>,
1655            ) -> futures_util::future::BoxFuture<
1656                '_,
1657                Result<Response<ListWorkflowExecutionsResponse>, tonic::Status>,
1658            > {
1659                self.call_count.fetch_add(1, Ordering::SeqCst);
1660                let req = request.into_inner();
1661
1662                // Determine offset from page token
1663                let offset: usize = if req.next_page_token.is_empty() {
1664                    0
1665                } else {
1666                    String::from_utf8(req.next_page_token)
1667                        .unwrap()
1668                        .parse()
1669                        .unwrap()
1670                };
1671
1672                let remaining = self.total_workflows.saturating_sub(offset);
1673                let count = remaining.min(self.page_size);
1674                let new_offset = offset + count;
1675
1676                let executions: Vec<_> = (offset..offset + count)
1677                    .map(|i| workflow::WorkflowExecutionInfo {
1678                        execution: Some(ProtoWorkflowExecution {
1679                            workflow_id: format!("wf-{i}"),
1680                            run_id: format!("run-{i}"),
1681                        }),
1682                        r#type: Some(WorkflowType {
1683                            name: "TestWorkflow".to_string(),
1684                        }),
1685                        task_queue: "test-queue".to_string(),
1686                        ..Default::default()
1687                    })
1688                    .collect();
1689
1690                let next_page_token = if new_offset < self.total_workflows {
1691                    new_offset.to_string().into_bytes()
1692                } else {
1693                    vec![]
1694                };
1695
1696                async move {
1697                    Ok(Response::new(ListWorkflowExecutionsResponse {
1698                        executions,
1699                        next_page_token,
1700                    }))
1701                }
1702                .boxed()
1703            }
1704        }
1705
1706        #[tokio::test]
1707        async fn list_workflows_paginates_through_all_results() {
1708            let call_count = Arc::new(AtomicUsize::new(0));
1709            let client = MockListWorkflowsClient {
1710                call_count: call_count.clone(),
1711                page_size: 3,
1712                total_workflows: 10,
1713            };
1714
1715            let stream = client.list_workflows("", WorkflowListOptions::default());
1716            let results: Vec<_> = stream.collect().await;
1717
1718            assert_eq!(results.len(), 10);
1719            for (i, result) in results.iter().enumerate() {
1720                let wf = result.as_ref().unwrap();
1721                assert_eq!(wf.id(), format!("wf-{i}"));
1722                assert_eq!(wf.run_id(), format!("run-{i}"));
1723            }
1724            // Should have made 4 calls: pages of 3, 3, 3, 1
1725            assert_eq!(call_count.load(Ordering::SeqCst), 4);
1726        }
1727
1728        #[tokio::test]
1729        async fn list_workflows_respects_limit() {
1730            let call_count = Arc::new(AtomicUsize::new(0));
1731            let client = MockListWorkflowsClient {
1732                call_count: call_count.clone(),
1733                page_size: 3,
1734                total_workflows: 10,
1735            };
1736
1737            let opts = WorkflowListOptions::builder().limit(5).build();
1738            let stream = client.list_workflows("", opts);
1739            let results: Vec<_> = stream.collect().await;
1740
1741            assert_eq!(results.len(), 5);
1742            for (i, result) in results.iter().enumerate() {
1743                let wf = result.as_ref().unwrap();
1744                assert_eq!(wf.id(), format!("wf-{i}"));
1745            }
1746            // Should have made 2 calls: 1 page of 3, then 2 more from next page
1747            assert_eq!(call_count.load(Ordering::SeqCst), 2);
1748        }
1749
1750        #[tokio::test]
1751        async fn list_workflows_limit_less_than_page_size() {
1752            let call_count = Arc::new(AtomicUsize::new(0));
1753            let client = MockListWorkflowsClient {
1754                call_count: call_count.clone(),
1755                page_size: 10,
1756                total_workflows: 100,
1757            };
1758
1759            let opts = WorkflowListOptions::builder().limit(3).build();
1760            let stream = client.list_workflows("", opts);
1761            let results: Vec<_> = stream.collect().await;
1762
1763            assert_eq!(results.len(), 3);
1764            // Only 1 call needed since limit < page_size
1765            assert_eq!(call_count.load(Ordering::SeqCst), 1);
1766        }
1767
1768        #[tokio::test]
1769        async fn list_workflows_empty_results() {
1770            let call_count = Arc::new(AtomicUsize::new(0));
1771            let client = MockListWorkflowsClient {
1772                call_count: call_count.clone(),
1773                page_size: 10,
1774                total_workflows: 0,
1775            };
1776
1777            let stream = client.list_workflows("", WorkflowListOptions::default());
1778            let results: Vec<_> = stream.collect().await;
1779
1780            assert_eq!(results.len(), 0);
1781            assert_eq!(call_count.load(Ordering::SeqCst), 1);
1782        }
1783    }
1784}