Skip to main content

temporalio_client/
envconfig.rs

1//! Conversion from [`temporalio_common::envconfig::ClientConfigProfile`] to [`ConnectionOptions`] and [`ClientOptions`].
2//!
3//! This module bridges the environment/file-based configuration in `temporalio-common` with
4//! the client connection types.
5
6use std::{collections::HashMap, fs};
7use url::Url;
8
9pub use temporalio_common::envconfig::{
10    ClientConfigProfile, ConfigError, DataSource, LoadClientConfigProfileOptions,
11};
12use temporalio_common::envconfig::{ClientConfigTLS, load_client_config_profile};
13
14use crate::{ClientOptions, ClientTlsOptions, ConnectionOptions, TlsOptions};
15
/// Address used when the profile does not specify one.
///
/// Deliberately scheme-less so it goes through the same rule as a user-supplied bare
/// `host:port` in `parse_address`: `http://` when TLS is off, `https://` when TLS is enabled
/// (for example because an API key is configured). A hard-coded `http://` here would keep a
/// plaintext scheme even when TLS options are being attached to the connection.
const DEFAULT_ADDRESS: &str = "localhost:7233";
/// Namespace used when the profile does not specify one.
const DEFAULT_NAMESPACE: &str = "default";
18
19impl ClientOptions {
20    /// Load client and connection options from environment variables and/or a TOML config file.
21    pub fn load_from_config(
22        options: LoadClientConfigProfileOptions,
23    ) -> Result<(ConnectionOptions, ClientOptions), ConfigError> {
24        load_from_config_with_env(options, None)
25    }
26}
27
28// Separate function allows injecting env vars for testing.
29fn load_from_config_with_env(
30    options: LoadClientConfigProfileOptions,
31    env_vars: Option<&HashMap<String, String>>,
32) -> Result<(ConnectionOptions, ClientOptions), ConfigError> {
33    let profile = load_client_config_profile(options, env_vars)?;
34    let namespace = profile
35        .namespace
36        .clone()
37        .unwrap_or_else(|| DEFAULT_NAMESPACE.to_owned());
38    let conn_opts = ConnectionOptions::try_from(profile)?;
39    let client_opts = ClientOptions::new(namespace).build();
40    Ok((conn_opts, client_opts))
41}
42
43/// Parse an address string into a [`Url`], prepending a scheme if none is present.
44///
45/// Other SDKs pass addresses as bare `host:port` strings. Our [`ConnectionOptions`] requires a
46/// [`Url`], so we attempt a direct parse first and fall back to prepending a scheme.
47/// When the user omits a scheme, we use `https://` if TLS will be enabled, otherwise `http://`.
48fn parse_address(address: &str, use_tls: bool) -> Result<Url, ConfigError> {
49    // Try parsing as-is. `Url::parse("localhost:7233")` "succeeds" by treating `localhost` as
50    // the scheme, so reject parses that have no host — those need a scheme prefix.
51    if let Ok(url) = Url::parse(address)
52        && url.host().is_some()
53    {
54        return Ok(url);
55    }
56    let scheme = if use_tls { "https" } else { "http" };
57    Url::parse(&format!("{scheme}://{address}"))
58        .map_err(|e| ConfigError::InvalidConfig(format!("Invalid address: {e}")))
59}
60
61/// Build [`TlsOptions`] from a [`ClientConfigTLS`] config, resolving any file-based data sources.
62fn build_tls_options(tls: ClientConfigTLS) -> Result<TlsOptions, ConfigError> {
63    let client_tls_options = match (tls.client_cert, tls.client_key) {
64        (Some(cert), Some(key)) => {
65            let cert_bytes =
66                resolve_datasource(cert).map_err(|e| ConfigError::LoadError(e.into()))?;
67            let key_bytes =
68                resolve_datasource(key).map_err(|e| ConfigError::LoadError(e.into()))?;
69            Some(ClientTlsOptions {
70                client_cert: cert_bytes,
71                client_private_key: key_bytes,
72            })
73        }
74        (Some(_), None) | (None, Some(_)) => {
75            return Err(ConfigError::InvalidConfig(
76                "Both client certificate and client key must be provided together".to_string(),
77            ));
78        }
79        (None, None) => None,
80    };
81
82    let server_root_ca_cert = tls
83        .server_ca_cert
84        .map(resolve_datasource)
85        .transpose()
86        .map_err(|e| ConfigError::LoadError(e.into()))?;
87
88    Ok(TlsOptions {
89        server_root_ca_cert,
90        domain: tls.server_name,
91        client_tls_options,
92    })
93}
94
95/// Determine whether TLS should be enabled based on the profile's TLS config and API key.
96///
97/// TLS is enabled when:
98/// - There is a TLS section that is not explicitly disabled, OR
99/// - An API key is set and TLS is not explicitly disabled
100fn should_enable_tls(tls: &Option<ClientConfigTLS>, has_api_key: bool) -> bool {
101    match tls {
102        Some(t) => t.disabled != Some(true),
103        None => has_api_key,
104    }
105}
106
107impl TryFrom<ClientConfigProfile> for ConnectionOptions {
108    type Error = ConfigError;
109
110    fn try_from(profile: ClientConfigProfile) -> Result<Self, Self::Error> {
111        let ClientConfigProfile {
112            address,
113            namespace: _,
114            api_key,
115            tls,
116            codec: _,
117            grpc_meta,
118        } = profile;
119
120        let has_api_key = api_key.is_some();
121        let use_tls = should_enable_tls(&tls, has_api_key);
122        let target = parse_address(address.as_deref().unwrap_or(DEFAULT_ADDRESS), use_tls)?;
123
124        let tls_options = if use_tls {
125            match tls {
126                Some(tls_cfg) => Some(build_tls_options(tls_cfg)?),
127                None => Some(TlsOptions::default()),
128            }
129        } else {
130            None
131        };
132
133        let headers = (!grpc_meta.is_empty()).then_some(grpc_meta);
134
135        Ok(ConnectionOptions::new(target)
136            .maybe_api_key(api_key)
137            .maybe_tls_options(tls_options)
138            .maybe_headers(headers)
139            .build())
140    }
141}
142
143/// Resolve a data source to its raw bytes.
144fn resolve_datasource(data_source: DataSource) -> Result<Vec<u8>, std::io::Error> {
145    match data_source {
146        DataSource::Path(path) => fs::read(path),
147        DataSource::Data(data) => Ok(data),
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::{fixture, rstest};
    use std::path::PathBuf;
    use tempfile::TempDir;
    use temporalio_common::envconfig::{ClientConfigTLS, DataSource};

    /// Create a fresh, empty temporary directory for files written by a test.
    /// The `TempDir` handle keeps the directory alive; it is cleaned up on drop.
    #[fixture]
    fn config_dir() -> TempDir {
        TempDir::new().unwrap()
    }

    /// Write `content` to `temporal.toml` inside `dir`, returning the file path.
    fn write_config(dir: &TempDir, content: &str) -> PathBuf {
        let path = dir.path().join("temporal.toml");
        std::fs::write(&path, content).unwrap();
        path
    }

    #[rstest]
    #[case::default(None, false, "http://localhost:7233/")]
    #[case::with_scheme(Some("https://my-server:7233"), false, "https://my-server:7233/")]
    #[case::without_scheme(Some("localhost:7233"), false, "http://localhost:7233/")]
    #[case::without_scheme_tls(Some("localhost:7233"), true, "https://localhost:7233/")]
    #[case::ip_without_scheme(Some("127.0.0.1:7233"), false, "http://127.0.0.1:7233/")]
    #[case::explicit_http_with_tls(Some("http://my-server:7233"), true, "http://my-server:7233/")]
    fn address_parsing(
        #[case] address: Option<&str>,
        #[case] enable_tls: bool,
        #[case] expected: &str,
    ) {
        let tls = enable_tls.then(ClientConfigTLS::default);
        let profile = ClientConfigProfile {
            address: address.map(str::to_string),
            tls,
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        assert_eq!(conn.target.as_str(), expected);
    }

    #[test]
    fn invalid_address_errors() {
        let profile = ClientConfigProfile {
            address: Some("://bad".to_string()),
            ..Default::default()
        };
        assert!(ConnectionOptions::try_from(profile).is_err());
    }

    #[test]
    fn empty_profile_defaults() {
        let env = HashMap::new();
        let opts = LoadClientConfigProfileOptions {
            disable_file: true,
            ..Default::default()
        };
        let (conn, client) = load_from_config_with_env(opts, Some(&env)).unwrap();

        assert_eq!(conn.target.as_str(), "http://localhost:7233/");
        assert_eq!(client.namespace, "default");
        assert!(conn.tls_options.is_none());
        assert!(conn.headers.is_none());
        assert!(conn.api_key.is_none());
    }

    #[test]
    fn namespace_override() {
        let mut env = HashMap::new();
        env.insert("TEMPORAL_NAMESPACE".to_string(), "my-namespace".to_string());
        let opts = LoadClientConfigProfileOptions {
            disable_file: true,
            ..Default::default()
        };
        let (_, client) = load_from_config_with_env(opts, Some(&env)).unwrap();
        assert_eq!(client.namespace, "my-namespace");
    }

    #[test]
    fn grpc_metadata_passthrough() {
        let mut meta = HashMap::new();
        meta.insert("x-custom".to_string(), "value".to_string());
        meta.insert("another".to_string(), "header".to_string());
        let profile = ClientConfigProfile {
            grpc_meta: meta.clone(),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        assert_eq!(conn.headers.unwrap(), meta);
    }

    #[test]
    fn api_key_populates_field() {
        let profile = ClientConfigProfile {
            api_key: Some("my-key".to_string()),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        assert_eq!(conn.api_key.as_deref(), Some("my-key"));
    }

    #[test]
    fn api_key_without_tls_section_uses_https() {
        let profile = ClientConfigProfile {
            address: Some("my-server:7233".to_string()),
            api_key: Some("my-key".to_string()),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        assert_eq!(conn.target.as_str(), "https://my-server:7233/");
        let tls = conn.tls_options.unwrap();
        assert!(tls.client_tls_options.is_none());
    }

    #[rstest]
    #[case::no_tls_no_key(None, None, false)]
    #[case::no_tls_with_key(None, Some("key"), true)]
    #[case::tls_disabled_false(Some(Some(false)), None, true)]
    #[case::tls_disabled_true(Some(Some(true)), None, false)]
    #[case::tls_disabled_none(Some(None), None, true)]
    #[case::key_with_tls_disabled(Some(Some(true)), Some("key"), false)]
    #[case::key_with_tls_enabled(Some(Some(false)), Some("key"), true)]
    fn tls_enablement(
        #[case] tls_disabled: Option<Option<bool>>,
        #[case] api_key: Option<&str>,
        #[case] expect_tls: bool,
    ) {
        let profile = ClientConfigProfile {
            api_key: api_key.map(str::to_string),
            tls: tls_disabled.map(|disabled| ClientConfigTLS {
                disabled,
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        assert_eq!(conn.tls_options.is_some(), expect_tls);
    }

    #[test]
    fn data_source_certs() {
        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                client_cert: Some(DataSource::Data(b"cert-data".to_vec())),
                client_key: Some(DataSource::Data(b"key-data".to_vec())),
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        let tls = conn.tls_options.unwrap();
        let mtls = tls.client_tls_options.unwrap();
        assert_eq!(mtls.client_cert, b"cert-data");
        assert_eq!(mtls.client_private_key, b"key-data");
    }

    #[rstest]
    fn path_source_certs(config_dir: TempDir) {
        let cert_path = config_dir.path().join("cert.pem");
        let key_path = config_dir.path().join("key.pem");
        std::fs::write(&cert_path, b"file-cert").unwrap();
        std::fs::write(&key_path, b"file-key").unwrap();

        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                client_cert: Some(DataSource::Path(cert_path.to_str().unwrap().to_string())),
                client_key: Some(DataSource::Path(key_path.to_str().unwrap().to_string())),
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        let tls = conn.tls_options.unwrap();
        let mtls = tls.client_tls_options.unwrap();
        assert_eq!(mtls.client_cert, b"file-cert");
        assert_eq!(mtls.client_private_key, b"file-key");
    }

    #[rstest]
    fn missing_cert_file_errors(config_dir: TempDir) {
        // Nothing is written here, so both paths point at a file that does not exist.
        let missing = config_dir.path().join("missing.pem");
        let missing = missing.to_str().unwrap().to_string();
        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                client_cert: Some(DataSource::Path(missing.clone())),
                client_key: Some(DataSource::Path(missing)),
                ..Default::default()
            }),
            ..Default::default()
        };
        assert!(ConnectionOptions::try_from(profile).is_err());
    }

    #[test]
    fn server_ca_cert() {
        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                server_ca_cert: Some(DataSource::Data(b"ca-data".to_vec())),
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        let tls = conn.tls_options.unwrap();
        assert_eq!(tls.server_root_ca_cert.unwrap(), b"ca-data");
    }

    #[test]
    fn server_name_sni() {
        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                server_name: Some("my.server.com".to_string()),
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        let tls = conn.tls_options.unwrap();
        assert_eq!(tls.domain.as_deref(), Some("my.server.com"));
    }

    #[rstest]
    #[case::cert_without_key(Some(DataSource::Data(b"cert".to_vec())), None)]
    #[case::key_without_cert(None, Some(DataSource::Data(b"key".to_vec())))]
    fn partial_tls_errors(
        #[case] client_cert: Option<DataSource>,
        #[case] client_key: Option<DataSource>,
    ) {
        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                client_cert,
                client_key,
                ..Default::default()
            }),
            ..Default::default()
        };
        assert!(ConnectionOptions::try_from(profile).is_err());
    }

    #[rstest]
    fn load_from_config_from_toml(config_dir: TempDir) {
        let config_path = write_config(
            &config_dir,
            r#"
[profile.default]
address = "toml-server:7233"
namespace = "toml-ns"
api_key = "toml-key"

[profile.default.grpc_meta]
x-custom = "value"

[profile.custom]
address = "custom-server:9090"
namespace = "custom-ns"
"#,
        );

        // Default profile
        let opts = LoadClientConfigProfileOptions {
            config_source: Some(DataSource::Path(config_path.to_str().unwrap().to_string())),
            disable_env: true,
            ..Default::default()
        };
        let (conn, client) = ClientOptions::load_from_config(opts).unwrap();
        assert_eq!(conn.target.as_str(), "https://toml-server:7233/");
        assert_eq!(client.namespace, "toml-ns");
        assert_eq!(conn.api_key.as_deref(), Some("toml-key"));
        assert!(conn.tls_options.is_some());
        assert_eq!(
            conn.headers.as_ref().unwrap().get("x-custom").unwrap(),
            "value"
        );

        // Custom profile
        let opts = LoadClientConfigProfileOptions {
            config_source: Some(DataSource::Path(config_path.to_str().unwrap().to_string())),
            config_file_profile: Some("custom".to_string()),
            disable_env: true,
            ..Default::default()
        };
        let (conn, client) = ClientOptions::load_from_config(opts).unwrap();
        assert_eq!(conn.target.as_str(), "http://custom-server:9090/");
        assert_eq!(client.namespace, "custom-ns");
    }
}