1use std::{collections::HashMap, fs};
7use url::Url;
8
9pub use temporalio_common::envconfig::{
10 ClientConfigProfile, ConfigError, DataSource, LoadClientConfigProfileOptions,
11};
12use temporalio_common::envconfig::{ClientConfigTLS, load_client_config_profile};
13
14use crate::{ClientOptions, ClientTlsOptions, ConnectionOptions, TlsOptions};
15
/// Server address used when the profile does not specify one.
///
/// This value already carries an explicit `http://` scheme and a host, so `parse_address`
/// returns it unchanged regardless of whether TLS is enabled.
// NOTE(review): a profile that turns TLS on (e.g. only an API key is set) but omits the address
// therefore ends up with an `http://` target plus TLS options — confirm this is intended.
const DEFAULT_ADDRESS: &str = "http://localhost:7233";
/// Namespace used when the profile does not specify one.
const DEFAULT_NAMESPACE: &str = "default";
18
19impl ClientOptions {
20 pub fn load_from_config(
22 options: LoadClientConfigProfileOptions,
23 ) -> Result<(ConnectionOptions, ClientOptions), ConfigError> {
24 load_from_config_with_env(options, None)
25 }
26}
27
28fn load_from_config_with_env(
30 options: LoadClientConfigProfileOptions,
31 env_vars: Option<&HashMap<String, String>>,
32) -> Result<(ConnectionOptions, ClientOptions), ConfigError> {
33 let profile = load_client_config_profile(options, env_vars)?;
34 let namespace = profile
35 .namespace
36 .clone()
37 .unwrap_or_else(|| DEFAULT_NAMESPACE.to_owned());
38 let conn_opts = ConnectionOptions::try_from(profile)?;
39 let client_opts = ClientOptions::new(namespace).build();
40 Ok((conn_opts, client_opts))
41}
42
43fn parse_address(address: &str, use_tls: bool) -> Result<Url, ConfigError> {
49 if let Ok(url) = Url::parse(address)
52 && url.host().is_some()
53 {
54 return Ok(url);
55 }
56 let scheme = if use_tls { "https" } else { "http" };
57 Url::parse(&format!("{scheme}://{address}"))
58 .map_err(|e| ConfigError::InvalidConfig(format!("Invalid address: {e}")))
59}
60
61fn build_tls_options(tls: ClientConfigTLS) -> Result<TlsOptions, ConfigError> {
63 let client_tls_options = match (tls.client_cert, tls.client_key) {
64 (Some(cert), Some(key)) => {
65 let cert_bytes =
66 resolve_datasource(cert).map_err(|e| ConfigError::LoadError(e.into()))?;
67 let key_bytes =
68 resolve_datasource(key).map_err(|e| ConfigError::LoadError(e.into()))?;
69 Some(ClientTlsOptions {
70 client_cert: cert_bytes,
71 client_private_key: key_bytes,
72 })
73 }
74 (Some(_), None) | (None, Some(_)) => {
75 return Err(ConfigError::InvalidConfig(
76 "Both client certificate and client key must be provided together".to_string(),
77 ));
78 }
79 (None, None) => None,
80 };
81
82 let server_root_ca_cert = tls
83 .server_ca_cert
84 .map(resolve_datasource)
85 .transpose()
86 .map_err(|e| ConfigError::LoadError(e.into()))?;
87
88 Ok(TlsOptions {
89 server_root_ca_cert,
90 domain: tls.server_name,
91 client_tls_options,
92 })
93}
94
95fn should_enable_tls(tls: &Option<ClientConfigTLS>, has_api_key: bool) -> bool {
101 match tls {
102 Some(t) => t.disabled != Some(true),
103 None => has_api_key,
104 }
105}
106
107impl TryFrom<ClientConfigProfile> for ConnectionOptions {
108 type Error = ConfigError;
109
110 fn try_from(profile: ClientConfigProfile) -> Result<Self, Self::Error> {
111 let ClientConfigProfile {
112 address,
113 namespace: _,
114 api_key,
115 tls,
116 codec: _,
117 grpc_meta,
118 } = profile;
119
120 let has_api_key = api_key.is_some();
121 let use_tls = should_enable_tls(&tls, has_api_key);
122 let target = parse_address(address.as_deref().unwrap_or(DEFAULT_ADDRESS), use_tls)?;
123
124 let tls_options = if use_tls {
125 match tls {
126 Some(tls_cfg) => Some(build_tls_options(tls_cfg)?),
127 None => Some(TlsOptions::default()),
128 }
129 } else {
130 None
131 };
132
133 let headers = (!grpc_meta.is_empty()).then_some(grpc_meta);
134
135 Ok(ConnectionOptions::new(target)
136 .maybe_api_key(api_key)
137 .maybe_tls_options(tls_options)
138 .maybe_headers(headers)
139 .build())
140 }
141}
142
143fn resolve_datasource(data_source: DataSource) -> Result<Vec<u8>, std::io::Error> {
145 match data_source {
146 DataSource::Path(path) => fs::read(path),
147 DataSource::Data(data) => Ok(data),
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::{fixture, rstest};
    use std::path::PathBuf;
    use tempfile::TempDir;
    use temporalio_common::envconfig::{ClientConfigTLS, DataSource};

    // Per-test scratch directory, injected by rstest into any test parameter named
    // `config_dir`. The directory is removed when the `TempDir` is dropped.
    #[fixture]
    fn config_dir() -> TempDir {
        TempDir::new().unwrap()
    }

    // Writes `content` to `<dir>/temporal.toml` and returns that file's path.
    fn write_config(dir: &TempDir, content: &str) -> PathBuf {
        let path = dir.path().join("temporal.toml");
        std::fs::write(&path, content).unwrap();
        path
    }

    // Address normalization:
    // - a missing address falls back to the default;
    // - a bare `host:port` gets `https://` or `http://` depending on TLS;
    // - an explicit scheme is kept even when TLS is enabled.
    // Expected strings include the trailing `/` that `Url` serializes for an empty path.
    #[rstest]
    #[case::default(None, false, "http://localhost:7233/")]
    #[case::with_scheme(Some("https://my-server:7233"), false, "https://my-server:7233/")]
    #[case::without_scheme(Some("localhost:7233"), false, "http://localhost:7233/")]
    #[case::without_scheme_tls(Some("localhost:7233"), true, "https://localhost:7233/")]
    #[case::explicit_http_with_tls(Some("http://my-server:7233"), true, "http://my-server:7233/")]
    fn address_parsing(
        #[case] address: Option<&str>,
        #[case] enable_tls: bool,
        #[case] expected: &str,
    ) {
        // Presence of a default TLS section (with `disabled` left at its default) is what
        // enables TLS here.
        let tls = enable_tls.then(ClientConfigTLS::default);
        let profile = ClientConfigProfile {
            address: address.map(str::to_string),
            tls,
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        assert_eq!(conn.target.as_str(), expected);
    }

    // An address that cannot be parsed even with a scheme prepended is rejected.
    #[test]
    fn invalid_address_errors() {
        let profile = ClientConfigProfile {
            address: Some("://bad".to_string()),
            ..Default::default()
        };
        assert!(ConnectionOptions::try_from(profile).is_err());
    }

    // With the config file disabled and an empty environment map, every setting falls back to
    // its default: plaintext localhost target, "default" namespace, no TLS/headers/API key.
    #[test]
    fn empty_profile_defaults() {
        let env = HashMap::new();
        let opts = LoadClientConfigProfileOptions {
            disable_file: true,
            ..Default::default()
        };
        let (conn, client) = load_from_config_with_env(opts, Some(&env)).unwrap();

        assert_eq!(conn.target.as_str(), "http://localhost:7233/");
        assert_eq!(client.namespace, "default");
        assert!(conn.tls_options.is_none());
        assert!(conn.headers.is_none());
        assert!(conn.api_key.is_none());
    }

    // `TEMPORAL_NAMESPACE` in the supplied environment map sets the client namespace.
    #[test]
    fn namespace_override() {
        let mut env = HashMap::new();
        env.insert("TEMPORAL_NAMESPACE".to_string(), "my-namespace".to_string());
        let opts = LoadClientConfigProfileOptions {
            disable_file: true,
            ..Default::default()
        };
        let (_, client) = load_from_config_with_env(opts, Some(&env)).unwrap();
        assert_eq!(client.namespace, "my-namespace");
    }

    // Non-empty gRPC metadata is forwarded unchanged as connection headers.
    #[test]
    fn grpc_metadata_passthrough() {
        let mut meta = HashMap::new();
        meta.insert("x-custom".to_string(), "value".to_string());
        meta.insert("another".to_string(), "header".to_string());
        let profile = ClientConfigProfile {
            grpc_meta: meta.clone(),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        assert_eq!(conn.headers.unwrap(), meta);
    }

    // The profile's API key is copied onto the connection options.
    #[test]
    fn api_key_populates_field() {
        let profile = ClientConfigProfile {
            api_key: Some("my-key".to_string()),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        assert_eq!(conn.api_key.as_deref(), Some("my-key"));
    }

    // TLS enablement truth table. `tls_disabled` encodes the TLS section:
    // - `None`: no TLS section at all;
    // - `Some(x)`: a TLS section whose `disabled` field is `x`.
    // Without a section, an API key turns TLS on; an explicit section overrides that default.
    #[rstest]
    #[case::no_tls_no_key(None, None, false)]
    #[case::no_tls_with_key(None, Some("key"), true)]
    #[case::tls_disabled_false(Some(Some(false)), None, true)]
    #[case::tls_disabled_true(Some(Some(true)), None, false)]
    #[case::tls_disabled_none(Some(None), None, true)]
    #[case::key_with_tls_disabled(Some(Some(true)), Some("key"), false)]
    #[case::key_with_tls_enabled(Some(Some(false)), Some("key"), true)]
    fn tls_enablement(
        #[case] tls_disabled: Option<Option<bool>>,
        #[case] api_key: Option<&str>,
        #[case] expect_tls: bool,
    ) {
        let profile = ClientConfigProfile {
            api_key: api_key.map(str::to_string),
            tls: tls_disabled.map(|disabled| ClientConfigTLS {
                disabled,
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        assert_eq!(conn.tls_options.is_some(), expect_tls);
    }

    // Inline (`DataSource::Data`) client cert and key are passed through as mTLS bytes.
    #[test]
    fn data_source_certs() {
        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                client_cert: Some(DataSource::Data(b"cert-data".to_vec())),
                client_key: Some(DataSource::Data(b"key-data".to_vec())),
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        let tls = conn.tls_options.unwrap();
        let mtls = tls.client_tls_options.unwrap();
        assert_eq!(mtls.client_cert, b"cert-data");
        assert_eq!(mtls.client_private_key, b"key-data");
    }

    // Path-based (`DataSource::Path`) client cert and key are read from disk.
    #[rstest]
    fn path_source_certs(config_dir: TempDir) {
        let cert_path = config_dir.path().join("cert.pem");
        let key_path = config_dir.path().join("key.pem");
        std::fs::write(&cert_path, b"file-cert").unwrap();
        std::fs::write(&key_path, b"file-key").unwrap();

        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                client_cert: Some(DataSource::Path(cert_path.to_str().unwrap().to_string())),
                client_key: Some(DataSource::Path(key_path.to_str().unwrap().to_string())),
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        let tls = conn.tls_options.unwrap();
        let mtls = tls.client_tls_options.unwrap();
        assert_eq!(mtls.client_cert, b"file-cert");
        assert_eq!(mtls.client_private_key, b"file-key");
    }

    // The server CA certificate becomes `TlsOptions::server_root_ca_cert`.
    #[test]
    fn server_ca_cert() {
        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                server_ca_cert: Some(DataSource::Data(b"ca-data".to_vec())),
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        let tls = conn.tls_options.unwrap();
        assert_eq!(tls.server_root_ca_cert.unwrap(), b"ca-data");
    }

    // `server_name` from the TLS section becomes `TlsOptions::domain`.
    #[test]
    fn server_name_sni() {
        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                server_name: Some("my.server.com".to_string()),
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        let tls = conn.tls_options.unwrap();
        assert_eq!(tls.domain.as_deref(), Some("my.server.com"));
    }

    // A client certificate without its key (or a key without its certificate) is rejected.
    #[rstest]
    #[case::cert_without_key(Some(DataSource::Data(b"cert".to_vec())), None)]
    #[case::key_without_cert(None, Some(DataSource::Data(b"key".to_vec())))]
    fn partial_tls_errors(
        #[case] client_cert: Option<DataSource>,
        #[case] client_key: Option<DataSource>,
    ) {
        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                client_cert,
                client_key,
                ..Default::default()
            }),
            ..Default::default()
        };
        assert!(ConnectionOptions::try_from(profile).is_err());
    }

    // End-to-end load from a TOML file, with environment loading disabled for isolation.
    // - The default profile sets an API key, which implicitly enables TLS, so its bare address
    //   becomes `https://`.
    // - The named `custom` profile has no API key or TLS section, so its address stays `http://`.
    #[rstest]
    fn load_from_config_from_toml(config_dir: TempDir) {
        let config_path = write_config(
            &config_dir,
            r#"
[profile.default]
address = "toml-server:7233"
namespace = "toml-ns"
api_key = "toml-key"

[profile.default.grpc_meta]
x-custom = "value"

[profile.custom]
address = "custom-server:9090"
namespace = "custom-ns"
"#,
        );

        // No profile name given: the `default` profile is loaded.
        let opts = LoadClientConfigProfileOptions {
            config_source: Some(DataSource::Path(config_path.to_str().unwrap().to_string())),
            disable_env: true,
            ..Default::default()
        };
        let (conn, client) = ClientOptions::load_from_config(opts).unwrap();
        assert_eq!(conn.target.as_str(), "https://toml-server:7233/");
        assert_eq!(client.namespace, "toml-ns");
        assert_eq!(conn.api_key.as_deref(), Some("toml-key"));
        assert!(conn.tls_options.is_some());
        assert_eq!(
            conn.headers.as_ref().unwrap().get("x-custom").unwrap(),
            "value"
        );

        // Explicitly select the `custom` profile from the same file.
        let opts = LoadClientConfigProfileOptions {
            config_source: Some(DataSource::Path(config_path.to_str().unwrap().to_string())),
            config_file_profile: Some("custom".to_string()),
            disable_env: true,
            ..Default::default()
        };
        let (conn, client) = ClientOptions::load_from_config(opts).unwrap();
        assert_eq!(conn.target.as_str(), "http://custom-server:9090/");
        assert_eq!(client.namespace, "custom-ns");
    }
}