1use std::{collections::HashMap, fs};
7use url::Url;
8
9pub use temporalio_common::envconfig::{
10 ClientConfigProfile, ConfigError, DataSource, LoadClientConfigProfileOptions,
11};
12use temporalio_common::envconfig::{ClientConfigTLS, load_client_config_profile};
13
14use crate::{ClientOptions, ClientTlsOptions, ConnectionOptions, TlsOptions};
15
/// Server address used when a profile does not specify one.
const DEFAULT_ADDRESS: &str = "http://localhost:7233";
/// Namespace used when a profile does not specify one.
const DEFAULT_NAMESPACE: &str = "default";
18
19impl ClientOptions {
20 pub fn load_from_config(
22 options: LoadClientConfigProfileOptions,
23 ) -> Result<(ConnectionOptions, ClientOptions), ConfigError> {
24 load_from_config_with_env(options, None)
25 }
26}
27
28fn load_from_config_with_env(
30 options: LoadClientConfigProfileOptions,
31 env_vars: Option<&HashMap<String, String>>,
32) -> Result<(ConnectionOptions, ClientOptions), ConfigError> {
33 let profile = load_client_config_profile(options, env_vars)?;
34 let namespace = profile
35 .namespace
36 .clone()
37 .unwrap_or_else(|| DEFAULT_NAMESPACE.to_owned());
38 let conn_opts = ConnectionOptions::try_from(profile)?;
39 let client_opts = ClientOptions::new(namespace).build();
40 Ok((conn_opts, client_opts))
41}
42
43fn parse_address(address: &str, use_tls: bool) -> Result<Url, ConfigError> {
49 if let Ok(url) = Url::parse(address)
52 && url.host().is_some()
53 {
54 return Ok(url);
55 }
56 let scheme = if use_tls { "https" } else { "http" };
57 Url::parse(&format!("{scheme}://{address}"))
58 .map_err(|e| ConfigError::InvalidConfig(format!("Invalid address: {e}")))
59}
60
61fn build_tls_options(tls: ClientConfigTLS) -> Result<TlsOptions, ConfigError> {
63 let client_tls_options = match (tls.client_cert, tls.client_key) {
64 (Some(cert), Some(key)) => {
65 let cert_bytes =
66 resolve_datasource(cert).map_err(|e| ConfigError::LoadError(e.into()))?;
67 let key_bytes =
68 resolve_datasource(key).map_err(|e| ConfigError::LoadError(e.into()))?;
69 Some(ClientTlsOptions {
70 client_cert: cert_bytes,
71 client_private_key: key_bytes,
72 })
73 }
74 (Some(_), None) | (None, Some(_)) => {
75 return Err(ConfigError::InvalidConfig(
76 "Both client certificate and client key must be provided together".to_string(),
77 ));
78 }
79 (None, None) => None,
80 };
81
82 let server_root_ca_cert = tls
83 .server_ca_cert
84 .map(resolve_datasource)
85 .transpose()
86 .map_err(|e| ConfigError::LoadError(e.into()))?;
87
88 Ok(TlsOptions {
89 server_root_ca_cert,
90 domain: tls.server_name,
91 client_tls_options,
92 server_cert_verifier: None,
93 })
94}
95
96fn should_enable_tls(tls: &Option<ClientConfigTLS>, has_api_key: bool) -> bool {
102 match tls {
103 Some(t) => t.disabled != Some(true),
104 None => has_api_key,
105 }
106}
107
108impl TryFrom<ClientConfigProfile> for ConnectionOptions {
109 type Error = ConfigError;
110
111 fn try_from(profile: ClientConfigProfile) -> Result<Self, Self::Error> {
112 let ClientConfigProfile {
113 address,
114 namespace: _,
115 api_key,
116 tls,
117 codec: _,
118 grpc_meta,
119 } = profile;
120
121 let has_api_key = api_key.is_some();
122 let use_tls = should_enable_tls(&tls, has_api_key);
123 let target = parse_address(address.as_deref().unwrap_or(DEFAULT_ADDRESS), use_tls)?;
124
125 let tls_options = if use_tls {
126 match tls {
127 Some(tls_cfg) => Some(build_tls_options(tls_cfg)?),
128 None => Some(TlsOptions::default()),
129 }
130 } else {
131 None
132 };
133
134 let headers = (!grpc_meta.is_empty()).then_some(grpc_meta);
135
136 Ok(ConnectionOptions::new(target)
137 .maybe_api_key(api_key)
138 .maybe_tls_options(tls_options)
139 .maybe_headers(headers)
140 .build())
141 }
142}
143
144fn resolve_datasource(data_source: DataSource) -> Result<Vec<u8>, std::io::Error> {
146 match data_source {
147 DataSource::Path(path) => fs::read(path),
148 DataSource::Data(data) => Ok(data),
149 }
150}
151
// Unit tests for turning loaded profiles into `ConnectionOptions` / `ClientOptions`.
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::{fixture, rstest};
    use std::path::PathBuf;
    use tempfile::TempDir;
    use temporalio_common::envconfig::{ClientConfigTLS, DataSource};

    /// rstest fixture: a fresh `tempfile::TempDir` for tests that write files to disk.
    #[fixture]
    fn config_dir() -> TempDir {
        TempDir::new().unwrap()
    }

    /// Writes `content` to `temporal.toml` inside `dir` and returns that file's path.
    fn write_config(dir: &TempDir, content: &str) -> PathBuf {
        let path = dir.path().join("temporal.toml");
        std::fs::write(&path, content).unwrap();
        path
    }

    /// Address resolution behavior:
    /// - a missing address falls back to the default;
    /// - a scheme-less address gets `http` or `https` depending on TLS;
    /// - an explicit scheme is kept even when it disagrees with the TLS setting.
    ///
    /// Expected strings include the trailing `/` of the parsed `Url`'s empty path.
    #[rstest]
    #[case::default(None, false, "http://localhost:7233/")]
    #[case::with_scheme(Some("https://my-server:7233"), false, "https://my-server:7233/")]
    #[case::without_scheme(Some("localhost:7233"), false, "http://localhost:7233/")]
    #[case::without_scheme_tls(Some("localhost:7233"), true, "https://localhost:7233/")]
    #[case::explicit_http_with_tls(Some("http://my-server:7233"), true, "http://my-server:7233/")]
    fn address_parsing(
        #[case] address: Option<&str>,
        #[case] enable_tls: bool,
        #[case] expected: &str,
    ) {
        // A default (non-disabled) TLS section turns TLS on; `None` leaves it off (no API key).
        let tls = enable_tls.then(ClientConfigTLS::default);
        let profile = ClientConfigProfile {
            address: address.map(str::to_string),
            tls,
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        assert_eq!(conn.target.as_str(), expected);
    }

    /// An address that fails to parse even after a scheme is prefixed is rejected.
    #[test]
    fn invalid_address_errors() {
        let profile = ClientConfigProfile {
            address: Some("://bad".to_string()),
            ..Default::default()
        };
        assert!(ConnectionOptions::try_from(profile).is_err());
    }

    /// With no config file and an empty env map, every setting takes its default.
    #[test]
    fn empty_profile_defaults() {
        let env = HashMap::new();
        let opts = LoadClientConfigProfileOptions {
            disable_file: true,
            ..Default::default()
        };
        let (conn, client) = load_from_config_with_env(opts, Some(&env)).unwrap();

        assert_eq!(conn.target.as_str(), "http://localhost:7233/");
        assert_eq!(client.namespace, "default");
        assert!(conn.tls_options.is_none());
        assert!(conn.headers.is_none());
        assert!(conn.api_key.is_none());
    }

    /// `TEMPORAL_NAMESPACE` from the supplied env map flows through to `ClientOptions`.
    #[test]
    fn namespace_override() {
        let mut env = HashMap::new();
        env.insert("TEMPORAL_NAMESPACE".to_string(), "my-namespace".to_string());
        let opts = LoadClientConfigProfileOptions {
            disable_file: true,
            ..Default::default()
        };
        let (_, client) = load_from_config_with_env(opts, Some(&env)).unwrap();
        assert_eq!(client.namespace, "my-namespace");
    }

    /// Non-empty `grpc_meta` becomes the connection headers unchanged.
    #[test]
    fn grpc_metadata_passthrough() {
        let mut meta = HashMap::new();
        meta.insert("x-custom".to_string(), "value".to_string());
        meta.insert("another".to_string(), "header".to_string());
        let profile = ClientConfigProfile {
            grpc_meta: meta.clone(),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        assert_eq!(conn.headers.unwrap(), meta);
    }

    /// The profile API key is copied onto the connection options.
    #[test]
    fn api_key_populates_field() {
        let profile = ClientConfigProfile {
            api_key: Some("my-key".to_string()),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        assert_eq!(conn.api_key.as_deref(), Some("my-key"));
    }

    /// TLS enablement matrix. For `tls_disabled`:
    /// - outer `None` means no TLS section at all;
    /// - `Some(d)` means a TLS section whose `disabled` field is `d`.
    ///
    /// An explicit `disabled = true` wins over an API key.
    #[rstest]
    #[case::no_tls_no_key(None, None, false)]
    #[case::no_tls_with_key(None, Some("key"), true)]
    #[case::tls_disabled_false(Some(Some(false)), None, true)]
    #[case::tls_disabled_true(Some(Some(true)), None, false)]
    #[case::tls_disabled_none(Some(None), None, true)]
    #[case::key_with_tls_disabled(Some(Some(true)), Some("key"), false)]
    #[case::key_with_tls_enabled(Some(Some(false)), Some("key"), true)]
    fn tls_enablement(
        #[case] tls_disabled: Option<Option<bool>>,
        #[case] api_key: Option<&str>,
        #[case] expect_tls: bool,
    ) {
        let profile = ClientConfigProfile {
            api_key: api_key.map(str::to_string),
            tls: tls_disabled.map(|disabled| ClientConfigTLS {
                disabled,
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        assert_eq!(conn.tls_options.is_some(), expect_tls);
    }

    /// Inline (`DataSource::Data`) client cert and key are used verbatim for mTLS.
    #[test]
    fn data_source_certs() {
        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                client_cert: Some(DataSource::Data(b"cert-data".to_vec())),
                client_key: Some(DataSource::Data(b"key-data".to_vec())),
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        let tls = conn.tls_options.unwrap();
        let mtls = tls.client_tls_options.unwrap();
        assert_eq!(mtls.client_cert, b"cert-data");
        assert_eq!(mtls.client_private_key, b"key-data");
    }

    /// File-backed (`DataSource::Path`) client cert and key are read from disk.
    #[rstest]
    fn path_source_certs(config_dir: TempDir) {
        let cert_path = config_dir.path().join("cert.pem");
        let key_path = config_dir.path().join("key.pem");
        std::fs::write(&cert_path, b"file-cert").unwrap();
        std::fs::write(&key_path, b"file-key").unwrap();

        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                client_cert: Some(DataSource::Path(cert_path.to_str().unwrap().to_string())),
                client_key: Some(DataSource::Path(key_path.to_str().unwrap().to_string())),
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        let tls = conn.tls_options.unwrap();
        let mtls = tls.client_tls_options.unwrap();
        assert_eq!(mtls.client_cert, b"file-cert");
        assert_eq!(mtls.client_private_key, b"file-key");
    }

    /// A configured server CA cert becomes `server_root_ca_cert`.
    #[test]
    fn server_ca_cert() {
        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                server_ca_cert: Some(DataSource::Data(b"ca-data".to_vec())),
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        let tls = conn.tls_options.unwrap();
        assert_eq!(tls.server_root_ca_cert.unwrap(), b"ca-data");
    }

    /// `server_name` is passed through as the TLS `domain` (SNI override).
    #[test]
    fn server_name_sni() {
        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                server_name: Some("my.server.com".to_string()),
                ..Default::default()
            }),
            ..Default::default()
        };
        let conn: ConnectionOptions = profile.try_into().unwrap();
        let tls = conn.tls_options.unwrap();
        assert_eq!(tls.domain.as_deref(), Some("my.server.com"));
    }

    /// A client cert without a key (or a key without a cert) is rejected.
    #[rstest]
    #[case::cert_without_key(Some(DataSource::Data(b"cert".to_vec())), None)]
    #[case::key_without_cert(None, Some(DataSource::Data(b"key".to_vec())))]
    fn partial_tls_errors(
        #[case] client_cert: Option<DataSource>,
        #[case] client_key: Option<DataSource>,
    ) {
        let profile = ClientConfigProfile {
            tls: Some(ClientConfigTLS {
                client_cert,
                client_key,
                ..Default::default()
            }),
            ..Default::default()
        };
        assert!(ConnectionOptions::try_from(profile).is_err());
    }

    /// End-to-end load from a TOML file, with environment lookup disabled.
    ///
    /// - The default profile has an API key, so TLS turns on and its scheme-less address
    ///   gets `https`.
    /// - The `custom` profile has neither an API key nor a TLS section, so its address gets
    ///   `http`.
    #[rstest]
    fn load_from_config_from_toml(config_dir: TempDir) {
        let config_path = write_config(
            &config_dir,
            r#"
[profile.default]
address = "toml-server:7233"
namespace = "toml-ns"
api_key = "toml-key"

[profile.default.grpc_meta]
x-custom = "value"

[profile.custom]
address = "custom-server:9090"
namespace = "custom-ns"
"#,
        );

        // No `config_file_profile` is given here, so the `default` profile is expected.
        let opts = LoadClientConfigProfileOptions {
            config_source: Some(DataSource::Path(config_path.to_str().unwrap().to_string())),
            disable_env: true,
            ..Default::default()
        };
        let (conn, client) = ClientOptions::load_from_config(opts).unwrap();
        assert_eq!(conn.target.as_str(), "https://toml-server:7233/");
        assert_eq!(client.namespace, "toml-ns");
        assert_eq!(conn.api_key.as_deref(), Some("toml-key"));
        assert!(conn.tls_options.is_some());
        assert_eq!(
            conn.headers.as_ref().unwrap().get("x-custom").unwrap(),
            "value"
        );

        // Same file, explicitly selecting the `custom` profile.
        let opts = LoadClientConfigProfileOptions {
            config_source: Some(DataSource::Path(config_path.to_str().unwrap().to_string())),
            config_file_profile: Some("custom".to_string()),
            disable_env: true,
            ..Default::default()
        };
        let (conn, client) = ClientOptions::load_from_config(opts).unwrap();
        assert_eq!(conn.target.as_str(), "http://custom-server:9090/");
        assert_eq!(client.namespace, "custom-ns");
    }
}