Skip to main content

temporalio_client/
envconfig.rs

//! Conversion from [`temporalio_common::envconfig::ClientConfigProfile`] to [`ConnectionOptions`]
//! and [`ClientOptions`].
//!
//! This module bridges the environment- and file-based configuration in `temporalio-common`
//! with the client connection types.
5
6use std::{collections::HashMap, fs};
7use url::Url;
8
9pub use temporalio_common::envconfig::{
10    ClientConfigProfile, ConfigError, DataSource, LoadClientConfigProfileOptions,
11};
12use temporalio_common::envconfig::{ClientConfigTLS, load_client_config_profile};
13
14use crate::{ClientOptions, ClientTlsOptions, ConnectionOptions, TlsOptions};
15
/// Address used when a profile does not set one.
///
/// Deliberately scheme-less (bare `host:port`, the form other SDKs use) so that
/// [`parse_address`] prepends `https://` when TLS is enabled (for example because an API key is
/// configured) and `http://` otherwise. Hard-coding `http://` here would pair enabled TLS
/// options with an `http://` target, because `parse_address` keeps any explicit scheme.
const DEFAULT_ADDRESS: &str = "localhost:7233";

/// Namespace used when a profile does not set one.
const DEFAULT_NAMESPACE: &str = "default";
18
19impl ClientOptions {
20    /// Load client and connection options from environment variables and/or a TOML config file.
21    pub fn load_from_config(
22        options: LoadClientConfigProfileOptions,
23    ) -> Result<(ConnectionOptions, ClientOptions), ConfigError> {
24        load_from_config_with_env(options, None)
25    }
26}
27
28// Separate function allows injecting env vars for testing.
29fn load_from_config_with_env(
30    options: LoadClientConfigProfileOptions,
31    env_vars: Option<&HashMap<String, String>>,
32) -> Result<(ConnectionOptions, ClientOptions), ConfigError> {
33    let profile = load_client_config_profile(options, env_vars)?;
34    let namespace = profile
35        .namespace
36        .clone()
37        .unwrap_or_else(|| DEFAULT_NAMESPACE.to_owned());
38    let conn_opts = ConnectionOptions::try_from(profile)?;
39    let client_opts = ClientOptions::new(namespace).build();
40    Ok((conn_opts, client_opts))
41}
42
43/// Parse an address string into a [`Url`], prepending a scheme if none is present.
44///
45/// Other SDKs pass addresses as bare `host:port` strings. Our [`ConnectionOptions`] requires a
46/// [`Url`], so we attempt a direct parse first and fall back to prepending a scheme.
47/// When the user omits a scheme, we use `https://` if TLS will be enabled, otherwise `http://`.
48fn parse_address(address: &str, use_tls: bool) -> Result<Url, ConfigError> {
49    // Try parsing as-is. `Url::parse("localhost:7233")` "succeeds" by treating `localhost` as
50    // the scheme, so reject parses that have no host — those need a scheme prefix.
51    if let Ok(url) = Url::parse(address)
52        && url.host().is_some()
53    {
54        return Ok(url);
55    }
56    let scheme = if use_tls { "https" } else { "http" };
57    Url::parse(&format!("{scheme}://{address}"))
58        .map_err(|e| ConfigError::InvalidConfig(format!("Invalid address: {e}")))
59}
60
61/// Build [`TlsOptions`] from a [`ClientConfigTLS`] config, resolving any file-based data sources.
62fn build_tls_options(tls: ClientConfigTLS) -> Result<TlsOptions, ConfigError> {
63    let client_tls_options = match (tls.client_cert, tls.client_key) {
64        (Some(cert), Some(key)) => {
65            let cert_bytes =
66                resolve_datasource(cert).map_err(|e| ConfigError::LoadError(e.into()))?;
67            let key_bytes =
68                resolve_datasource(key).map_err(|e| ConfigError::LoadError(e.into()))?;
69            Some(ClientTlsOptions {
70                client_cert: cert_bytes,
71                client_private_key: key_bytes,
72            })
73        }
74        (Some(_), None) | (None, Some(_)) => {
75            return Err(ConfigError::InvalidConfig(
76                "Both client certificate and client key must be provided together".to_string(),
77            ));
78        }
79        (None, None) => None,
80    };
81
82    let server_root_ca_cert = tls
83        .server_ca_cert
84        .map(resolve_datasource)
85        .transpose()
86        .map_err(|e| ConfigError::LoadError(e.into()))?;
87
88    Ok(TlsOptions {
89        server_root_ca_cert,
90        domain: tls.server_name,
91        client_tls_options,
92        server_cert_verifier: None,
93    })
94}
95
96/// Determine whether TLS should be enabled based on the profile's TLS config and API key.
97///
98/// TLS is enabled when:
99/// - There is a TLS section that is not explicitly disabled, OR
100/// - An API key is set and TLS is not explicitly disabled
101fn should_enable_tls(tls: &Option<ClientConfigTLS>, has_api_key: bool) -> bool {
102    match tls {
103        Some(t) => t.disabled != Some(true),
104        None => has_api_key,
105    }
106}
107
108impl TryFrom<ClientConfigProfile> for ConnectionOptions {
109    type Error = ConfigError;
110
111    fn try_from(profile: ClientConfigProfile) -> Result<Self, Self::Error> {
112        let ClientConfigProfile {
113            address,
114            namespace: _,
115            api_key,
116            tls,
117            codec: _,
118            grpc_meta,
119        } = profile;
120
121        let has_api_key = api_key.is_some();
122        let use_tls = should_enable_tls(&tls, has_api_key);
123        let target = parse_address(address.as_deref().unwrap_or(DEFAULT_ADDRESS), use_tls)?;
124
125        let tls_options = if use_tls {
126            match tls {
127                Some(tls_cfg) => Some(build_tls_options(tls_cfg)?),
128                None => Some(TlsOptions::default()),
129            }
130        } else {
131            None
132        };
133
134        let headers = (!grpc_meta.is_empty()).then_some(grpc_meta);
135
136        Ok(ConnectionOptions::new(target)
137            .maybe_api_key(api_key)
138            .maybe_tls_options(tls_options)
139            .maybe_headers(headers)
140            .build())
141    }
142}
143
144/// Resolve a data source to its raw bytes.
145fn resolve_datasource(data_source: DataSource) -> Result<Vec<u8>, std::io::Error> {
146    match data_source {
147        DataSource::Path(path) => fs::read(path),
148        DataSource::Data(data) => Ok(data),
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::{fixture, rstest};
    use std::path::PathBuf;
    use tempfile::TempDir;
    use temporalio_common::envconfig::{ClientConfigTLS, DataSource};

    /// Create a fresh temporary directory for config and certificate files.
    ///
    /// The directory is removed when the returned `TempDir` is dropped, so tests must keep the
    /// handle alive for as long as they use files inside it.
    #[fixture]
    fn config_dir() -> TempDir {
        TempDir::new().unwrap()
    }

    /// Write `content` to `temporal.toml` inside `dir`, returning the file path.
    fn write_config(dir: &TempDir, content: &str) -> PathBuf {
        let path = dir.path().join("temporal.toml");
        std::fs::write(&path, content).unwrap();
        path
    }

    #[rstest]
    #[case::default(None, false, "http://localhost:7233/")]
    #[case::with_scheme(Some("https://my-server:7233"), false, "https://my-server:7233/")]
    #[case::without_scheme(Some("localhost:7233"), false, "http://localhost:7233/")]
    #[case::without_scheme_tls(Some("localhost:7233"), true, "https://localhost:7233/")]
    #[case::explicit_http_with_tls(Some("http://my-server:7233"), true, "http://my-server:7233/")]
    // `my.server.com` is a syntactically valid URL scheme, so this exercises the
    // "parsed but has no host" fallback rather than the "failed to parse" one.
    #[case::dotted_host(Some("my.server.com:7233"), false, "http://my.server.com:7233/")]
    // A leading digit is not a valid scheme start, so the verbatim parse fails outright.
    #[case::ipv4(Some("127.0.0.1:7233"), true, "https://127.0.0.1:7233/")]
    fn address_parsing(
        #[case] address: Option<&str>,
        #[case] enable_tls: bool,
        #[case] expected: &str,
    ) {
        let tls = enable_tls.then(ClientConfigTLS::default);
        let profile = ClientConfigProfile {
            address: address.map(str::to_string),
            tls,
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        assert_eq!(conn.target.as_str(), expected);
    }

    #[test]
    fn invalid_address_errors() {
        let profile = ClientConfigProfile {
            address: Some("://bad".to_string()),
            ..Default::default()
        };
        assert!(matches!(
            ConnectionOptions::try_from(profile),
            Err(ConfigError::InvalidConfig(_))
        ));
    }

    #[test]
    fn empty_profile_defaults() {
        let env = HashMap::new();
        let opts = LoadClientConfigProfileOptions {
            disable_file: true,
            ..Default::default()
        };
        let (conn, client) = load_from_config_with_env(opts, Some(&env)).unwrap();

        assert_eq!(conn.target.as_str(), "http://localhost:7233/");
        assert_eq!(client.namespace, "default");
        assert!(conn.tls_options.is_none());
        assert!(conn.headers.is_none());
        assert!(conn.api_key.is_none());
    }

    #[test]
    fn namespace_override() {
        let mut env = HashMap::new();
        env.insert("TEMPORAL_NAMESPACE".to_string(), "my-namespace".to_string());
        let opts = LoadClientConfigProfileOptions {
            disable_file: true,
            ..Default::default()
        };
        let (_, client) = load_from_config_with_env(opts, Some(&env)).unwrap();
        assert_eq!(client.namespace, "my-namespace");
    }

    #[test]
    fn grpc_metadata_passthrough() {
        let mut meta = HashMap::new();
        meta.insert("x-custom".to_string(), "value".to_string());
        meta.insert("another".to_string(), "header".to_string());
        let profile = ClientConfigProfile {
            grpc_meta: meta.clone(),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        assert_eq!(conn.headers.unwrap(), meta);
    }

    #[test]
    fn api_key_populates_field() {
        let profile = ClientConfigProfile {
            api_key: Some("my-key".to_string()),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        assert_eq!(conn.api_key.as_deref(), Some("my-key"));
    }

    #[rstest]
    #[case::no_tls_no_key(None, None, false)]
    #[case::no_tls_with_key(None, Some("key"), true)]
    #[case::tls_disabled_false(Some(Some(false)), None, true)]
    #[case::tls_disabled_true(Some(Some(true)), None, false)]
    #[case::tls_disabled_none(Some(None), None, true)]
    #[case::key_with_tls_disabled(Some(Some(true)), Some("key"), false)]
    #[case::key_with_tls_enabled(Some(Some(false)), Some("key"), true)]
    fn tls_enablement(
        #[case] tls_disabled: Option<Option<bool>>,
        #[case] api_key: Option<&str>,
        #[case] expect_tls: bool,
    ) {
        let profile = ClientConfigProfile {
            api_key: api_key.map(str::to_string),
            tls: tls_disabled.map(|disabled| ClientConfigTLS {
                disabled,
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        assert_eq!(conn.tls_options.is_some(), expect_tls);
    }

    #[test]
    fn data_source_certs() {
        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                client_cert: Some(DataSource::Data(b"cert-data".to_vec())),
                client_key: Some(DataSource::Data(b"key-data".to_vec())),
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        let tls = conn.tls_options.unwrap();
        let mtls = tls.client_tls_options.unwrap();
        assert_eq!(mtls.client_cert, b"cert-data");
        assert_eq!(mtls.client_private_key, b"key-data");
    }

    #[rstest]
    fn path_source_certs(config_dir: TempDir) {
        let cert_path = config_dir.path().join("cert.pem");
        let key_path = config_dir.path().join("key.pem");
        std::fs::write(&cert_path, b"file-cert").unwrap();
        std::fs::write(&key_path, b"file-key").unwrap();

        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                client_cert: Some(DataSource::Path(cert_path.to_str().unwrap().to_string())),
                client_key: Some(DataSource::Path(key_path.to_str().unwrap().to_string())),
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        let tls = conn.tls_options.unwrap();
        let mtls = tls.client_tls_options.unwrap();
        assert_eq!(mtls.client_cert, b"file-cert");
        assert_eq!(mtls.client_private_key, b"file-key");
    }

    #[rstest]
    fn missing_cert_file_errors(config_dir: TempDir) {
        // The path lives in a real temp dir but the file is never created.
        let missing = config_dir.path().join("missing-cert.pem");
        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                client_cert: Some(DataSource::Path(missing.to_str().unwrap().to_string())),
                client_key: Some(DataSource::Data(b"key".to_vec())),
                ..Default::default()
            }),
            ..Default::default()
        };
        assert!(matches!(
            ConnectionOptions::try_from(profile),
            Err(ConfigError::LoadError(_))
        ));
    }

    #[test]
    fn server_ca_cert() {
        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                server_ca_cert: Some(DataSource::Data(b"ca-data".to_vec())),
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        let tls = conn.tls_options.unwrap();
        assert_eq!(tls.server_root_ca_cert.unwrap(), b"ca-data");
    }

    #[test]
    fn server_name_sni() {
        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                server_name: Some("my.server.com".to_string()),
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        let tls = conn.tls_options.unwrap();
        assert_eq!(tls.domain.as_deref(), Some("my.server.com"));
    }

    #[rstest]
    #[case::cert_without_key(Some(DataSource::Data(b"cert".to_vec())), None)]
    #[case::key_without_cert(None, Some(DataSource::Data(b"key".to_vec())))]
    fn partial_tls_errors(
        #[case] client_cert: Option<DataSource>,
        #[case] client_key: Option<DataSource>,
    ) {
        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                client_cert,
                client_key,
                ..Default::default()
            }),
            ..Default::default()
        };
        assert!(matches!(
            ConnectionOptions::try_from(profile),
            Err(ConfigError::InvalidConfig(_))
        ));
    }

    #[rstest]
    fn load_from_config_from_toml(config_dir: TempDir) {
        let config_path = write_config(
            &config_dir,
            r#"
[profile.default]
address = "toml-server:7233"
namespace = "toml-ns"
api_key = "toml-key"

[profile.default.grpc_meta]
x-custom = "value"

[profile.custom]
address = "custom-server:9090"
namespace = "custom-ns"
"#,
        );

        // Default profile
        let opts = LoadClientConfigProfileOptions {
            config_source: Some(DataSource::Path(config_path.to_str().unwrap().to_string())),
            disable_env: true,
            ..Default::default()
        };
        let (conn, client) = ClientOptions::load_from_config(opts).unwrap();
        assert_eq!(conn.target.as_str(), "https://toml-server:7233/");
        assert_eq!(client.namespace, "toml-ns");
        assert_eq!(conn.api_key.as_deref(), Some("toml-key"));
        assert!(conn.tls_options.is_some());
        assert_eq!(
            conn.headers.as_ref().unwrap().get("x-custom").unwrap(),
            "value"
        );

        // Custom profile
        let opts = LoadClientConfigProfileOptions {
            config_source: Some(DataSource::Path(config_path.to_str().unwrap().to_string())),
            config_file_profile: Some("custom".to_string()),
            disable_env: true,
            ..Default::default()
        };
        let (conn, client) = ClientOptions::load_from_config(opts).unwrap();
        assert_eq!(conn.target.as_str(), "http://custom-server:9090/");
        assert_eq!(client.namespace, "custom-ns");
    }
}