wdl_engine/
config.rs

1//! Implementation of engine configuration.
2
3use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::Arc;
6
7use anyhow::Context;
8use anyhow::Result;
9use anyhow::bail;
10use serde::Deserialize;
11use serde::Serialize;
12use tracing::warn;
13
14use crate::DockerBackend;
15use crate::LocalBackend;
16use crate::SYSTEM;
17use crate::TaskExecutionBackend;
18use crate::convert_unit_string;
19
/// The inclusive maximum number of task retries the engine supports.
///
/// [`TaskConfig::validate`] rejects a `task.retries` value greater than this
/// constant; a value equal to it is accepted.
pub const MAX_RETRIES: u64 = 100;

/// The default task shell.
///
/// This matches the documented default of [`TaskConfig::shell`].
pub const DEFAULT_TASK_SHELL: &str = "bash";

/// The default maximum number of concurrent HTTP downloads.
///
/// This matches the documented default of
/// [`HttpConfig::max_concurrent_downloads`].
pub const DEFAULT_MAX_CONCURRENT_DOWNLOADS: u64 = 10;
28
/// Represents WDL evaluation configuration.
///
/// This is the root of the configuration tree. Every section is optional when
/// deserializing and falls back to its `Default` implementation when omitted;
/// unknown fields are rejected.
///
/// Call [`Config::validate`] after loading to check the values of every
/// section.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct Config {
    /// HTTP configuration.
    #[serde(default)]
    pub http: HttpConfig,
    /// Workflow evaluation configuration.
    #[serde(default)]
    pub workflow: WorkflowConfig,
    /// Task evaluation configuration.
    #[serde(default)]
    pub task: TaskConfig,
    /// Task execution backend configuration.
    ///
    /// The default backend is the Docker backend (see the `Default`
    /// implementation of [`BackendConfig`]).
    #[serde(default)]
    pub backend: BackendConfig,
    /// Storage configuration.
    #[serde(default)]
    pub storage: StorageConfig,
}
49
50impl Config {
51    /// Validates the evaluation configuration.
52    pub fn validate(&self) -> Result<()> {
53        self.http.validate()?;
54        self.workflow.validate()?;
55        self.task.validate()?;
56        self.backend.validate()?;
57        self.storage.validate()?;
58        Ok(())
59    }
60
61    /// Creates a new task execution backend based on this configuration.
62    pub async fn create_backend(&self) -> Result<Arc<dyn TaskExecutionBackend>> {
63        match &self.backend {
64            BackendConfig::Local(config) => {
65                warn!(
66                    "the engine is configured to use the local backend: tasks will not be run \
67                     inside of a container"
68                );
69                Ok(Arc::new(LocalBackend::new(&self.task, config)?))
70            }
71            BackendConfig::Docker(config) => {
72                Ok(Arc::new(DockerBackend::new(&self.task, config).await?))
73            }
74        }
75    }
76}
77
78/// Represents HTTP configuration.
79#[derive(Debug, Default, Clone, Serialize, Deserialize)]
80#[serde(rename_all = "snake_case", deny_unknown_fields)]
81pub struct HttpConfig {
82    /// The HTTP download cache location.
83    ///
84    /// Defaults to using the system cache directory.
85    #[serde(default)]
86    pub cache: Option<PathBuf>,
87    /// The maximum number of concurrent downloads allowed.
88    ///
89    /// Defaults to 10.
90    #[serde(default, skip_serializing_if = "Option::is_none")]
91    pub max_concurrent_downloads: Option<u64>,
92}
93
94impl HttpConfig {
95    /// Validates the HTTP configuration.
96    pub fn validate(&self) -> Result<()> {
97        if let Some(limit) = self.max_concurrent_downloads {
98            if limit == 0 {
99                bail!("configuration value `http.max_concurrent_downloads` cannot be zero");
100            }
101        }
102        Ok(())
103    }
104}
105
/// Represents storage configuration.
///
/// Groups the per-provider cloud storage settings. Each provider section is
/// optional when deserializing and falls back to its `Default`
/// implementation when omitted; unknown fields are rejected.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct StorageConfig {
    /// Azure Blob Storage configuration.
    #[serde(default)]
    pub azure: AzureStorageConfig,
    /// AWS S3 configuration.
    #[serde(default)]
    pub s3: S3StorageConfig,
    /// Google Cloud Storage configuration.
    #[serde(default)]
    pub google: GoogleStorageConfig,
}
120
121impl StorageConfig {
122    /// Validates the HTTP configuration.
123    pub fn validate(&self) -> Result<()> {
124        self.azure.validate()?;
125        self.s3.validate()?;
126        self.google.validate()?;
127        Ok(())
128    }
129}
130
/// Represents configuration for Azure Blob Storage.
///
/// Unknown fields are rejected when deserializing.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct AzureStorageConfig {
    /// The Azure Blob Storage authentication configuration.
    ///
    /// The key for the outer map is the storage account name.
    ///
    /// The key for the inner map is the container name.
    ///
    /// The value for the inner map is the SAS token query string to apply to
    /// matching Azure Blob Storage URLs.
    ///
    /// Defaults to an empty map; an empty map is omitted when serializing.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub auth: HashMap<String, HashMap<String, String>>,
}
146
impl AzureStorageConfig {
    /// Validates the Azure Blob Storage configuration.
    ///
    /// No checks are currently performed, so this always returns `Ok(())`.
    /// It is called by [`StorageConfig::validate`].
    pub fn validate(&self) -> Result<()> {
        Ok(())
    }
}
153
/// Represents configuration for AWS S3 storage.
///
/// Unknown fields are rejected when deserializing.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct S3StorageConfig {
    /// The default region to use for S3-schemed URLs (e.g.
    /// `s3://<bucket>/<blob>`).
    ///
    /// Defaults to `us-east-1`.
    ///
    /// Omitted when serializing if unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub region: Option<String>,

    /// The AWS S3 storage authentication configuration.
    ///
    /// The key for the map is the bucket name.
    ///
    /// The value for the map is the presigned query string.
    ///
    /// Defaults to an empty map; an empty map is omitted when serializing.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub auth: HashMap<String, String>,
}
173
impl S3StorageConfig {
    /// Validates the AWS S3 storage configuration.
    ///
    /// No checks are currently performed, so this always returns `Ok(())`.
    /// It is called by [`StorageConfig::validate`].
    pub fn validate(&self) -> Result<()> {
        Ok(())
    }
}
180
/// Represents configuration for Google Cloud Storage.
///
/// Unknown fields are rejected when deserializing.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct GoogleStorageConfig {
    /// The Google Cloud Storage authentication configuration.
    ///
    /// The key for the map is the bucket name.
    ///
    /// The value for the map is the presigned query string.
    ///
    /// Defaults to an empty map; an empty map is omitted when serializing.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub auth: HashMap<String, String>,
}
193
impl GoogleStorageConfig {
    /// Validates the Google Cloud Storage configuration.
    ///
    /// No checks are currently performed, so this always returns `Ok(())`.
    /// It is called by [`StorageConfig::validate`].
    pub fn validate(&self) -> Result<()> {
        Ok(())
    }
}
200
/// Represents workflow evaluation configuration.
///
/// Unknown fields are rejected when deserializing.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct WorkflowConfig {
    /// Scatter statement evaluation configuration.
    ///
    /// Falls back to the default [`ScatterConfig`] when omitted.
    #[serde(default)]
    pub scatter: ScatterConfig,
}
209
210impl WorkflowConfig {
211    /// Validates the workflow configuration.
212    pub fn validate(&self) -> Result<()> {
213        self.scatter.validate()?;
214        Ok(())
215    }
216}
217
/// Represents scatter statement evaluation configuration.
///
/// Unknown fields are rejected when deserializing.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct ScatterConfig {
    /// The number of scatter array elements to process concurrently.
    ///
    /// By default, the value is the parallelism supported by the task
    /// execution backend.
    ///
    /// A value of `0` is invalid and is rejected by
    /// [`ScatterConfig::validate`].
    ///
    /// Lower values use less memory for evaluation and higher values may better
    /// saturate the task execution backend with tasks to execute.
    ///
    /// This setting does not change how many tasks an execution backend can run
    /// concurrently, but may affect how many tasks are sent to the backend to
    /// run at a time.
    ///
    /// For example, if `concurrency` was set to 10 and we evaluate the
    /// following scatters:
    ///
    /// ```wdl
    /// scatter (i in range(100)) {
    ///     call my_task
    /// }
    ///
    /// scatter (j in range(100)) {
    ///     call my_task as my_task2
    /// }
    /// ```
    ///
    /// Here each scatter is independent and therefore there will be 20 calls
    /// (10 for each scatter) made concurrently. If the task execution
    /// backend can only execute 5 tasks concurrently, 5 tasks will execute
    /// and 15 will be "ready" to execute and waiting for an executing task
    /// to complete.
    ///
    /// If instead we evaluate the following scatters:
    ///
    /// ```wdl
    /// scatter (i in range(100)) {
    ///     scatter (j in range(100)) {
    ///         call my_task
    ///     }
    /// }
    /// ```
    ///
    /// Then there will be 100 calls (10*10 as 10 are made for each outer
    /// element) made concurrently. If the task execution backend can only
    /// execute 5 tasks concurrently, 5 tasks will execute and 95 will be
    /// "ready" to execute and waiting for an executing task to complete.
    ///
    /// <div class="warning">
    /// Warning: nested scatter statements cause exponential memory usage based
    /// on this value, as each scatter statement evaluation requires allocating
    /// new scopes for scatter array elements being processed. </div>
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub concurrency: Option<u64>,
}
277
278impl ScatterConfig {
279    /// Validates the scatter configuration.
280    pub fn validate(&self) -> Result<()> {
281        if let Some(concurrency) = self.concurrency {
282            if concurrency == 0 {
283                bail!("configuration value `workflow.scatter.concurrency` cannot be zero");
284            }
285        }
286
287        Ok(())
288    }
289}
290
/// Represents task evaluation configuration.
///
/// All fields are optional; fields that are `None` are omitted when
/// serializing and unknown fields are rejected when deserializing.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct TaskConfig {
    /// The default maximum number of retries to attempt if a task fails.
    ///
    /// A task's `max_retries` requirement will override this value.
    ///
    /// The value cannot exceed [`MAX_RETRIES`].
    ///
    /// Defaults to 0 (no retries).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub retries: Option<u64>,
    /// The default container to use if a container is not specified in a task's
    /// requirements.
    ///
    /// Defaults to `ubuntu:latest`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub container: Option<String>,
    /// The default shell to use for tasks.
    ///
    /// Defaults to `bash` (see [`DEFAULT_TASK_SHELL`]).
    ///
    /// <div class="warning">
    /// Warning: the use of a shell other than `bash` may lead to tasks that may
    /// not be portable to other execution engines.</div>
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub shell: Option<String>,
}
318
319impl TaskConfig {
320    /// Validates the task evaluation configuration.
321    pub fn validate(&self) -> Result<()> {
322        if self.retries.unwrap_or(0) > MAX_RETRIES {
323            bail!("configuration value `task.retries` cannot exceed {MAX_RETRIES}");
324        }
325
326        Ok(())
327    }
328}
329
/// Represents supported task execution backends.
///
/// The enum is internally tagged for (de)serialization: a `type` field
/// selects the variant, using the snake_case variant name (`local` or
/// `docker`). The remaining fields are those of the variant's configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum BackendConfig {
    /// Use the local task execution backend.
    Local(LocalBackendConfig),
    /// Use the Docker task execution backend.
    Docker(DockerBackendConfig),
}
339
340impl Default for BackendConfig {
341    fn default() -> Self {
342        Self::Docker(Default::default())
343    }
344}
345
346impl BackendConfig {
347    /// Validates the backend configuration.
348    pub fn validate(&self) -> Result<()> {
349        match self {
350            Self::Local(config) => config.validate(),
351            Self::Docker(config) => config.validate(),
352        }
353    }
354}
355
/// Represents configuration for the local task execution backend.
///
/// Both fields are optional; fields that are `None` are omitted when
/// serializing and unknown fields are rejected when deserializing.
///
/// <div class="warning">
/// Warning: the local task execution backend spawns processes on the host
/// directly without the use of a container; only use this backend on trusted
/// WDL. </div>
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct LocalBackendConfig {
    /// Set the number of CPUs available for task execution.
    ///
    /// Defaults to the number of logical CPUs for the host.
    ///
    /// The value cannot be zero or exceed the host's number of CPUs.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cpu: Option<u64>,

    /// Set the total amount of memory for task execution as a unit string (e.g.
    /// `2 GiB`).
    ///
    /// Defaults to the total amount of memory for the host.
    ///
    /// The value cannot be zero or exceed the host's total amount of memory.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub memory: Option<String>,
}
382
383impl LocalBackendConfig {
384    /// Validates the local task execution backend configuration.
385    pub fn validate(&self) -> Result<()> {
386        if let Some(cpu) = self.cpu {
387            if cpu == 0 {
388                bail!("local backend configuration value `cpu` cannot be zero");
389            }
390
391            let total = SYSTEM.cpus().len() as u64;
392            if cpu > total {
393                bail!(
394                    "local backend configuration value `cpu` cannot exceed the virtual CPUs \
395                     available to the host ({total})"
396                );
397            }
398        }
399
400        if let Some(memory) = &self.memory {
401            let memory = convert_unit_string(memory).with_context(|| {
402                format!("local backend configuration value `memory` has invalid value `{memory}`")
403            })?;
404
405            if memory == 0 {
406                bail!("local backend configuration value `memory` cannot be zero");
407            }
408
409            let total = SYSTEM.total_memory();
410            if memory > total {
411                bail!(
412                    "local backend configuration value `memory` cannot exceed the total memory of \
413                     the host ({total} bytes)"
414                );
415            }
416        }
417
418        Ok(())
419    }
420}
421
/// Gets the default value for the docker `cleanup` field.
///
/// Referenced by the `#[serde(default = "cleanup_default")]` attribute on
/// [`DockerBackendConfig::cleanup`], so that an omitted field deserializes to
/// `true` rather than `bool`'s default of `false`.
const fn cleanup_default() -> bool {
    true
}
426
/// Represents configuration for the Docker backend.
///
/// Unknown fields are rejected when deserializing.
///
/// `Default` is implemented by hand rather than derived so that `cleanup`
/// defaults to `true`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct DockerBackendConfig {
    /// Whether or not to remove a task's container after the task completes.
    ///
    /// Defaults to `true`.
    #[serde(default = "cleanup_default")]
    pub cleanup: bool,
}
437
impl DockerBackendConfig {
    /// Validates the Docker backend configuration.
    ///
    /// No checks are currently performed, so this always returns `Ok(())`.
    /// It is called by [`BackendConfig::validate`].
    pub fn validate(&self) -> Result<()> {
        Ok(())
    }
}
444
445impl Default for DockerBackendConfig {
446    fn default() -> Self {
447        Self { cleanup: true }
448    }
449}
450
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;

    use super::*;

    /// Builds an otherwise-default configuration that uses the local backend
    /// with the given settings.
    fn with_local_backend(local: LocalBackendConfig) -> Config {
        Config {
            backend: BackendConfig::Local(local),
            ..Default::default()
        }
    }

    /// Validates `config` and returns the message of the resulting error.
    ///
    /// Panics if validation unexpectedly succeeds.
    fn validation_error(config: &Config) -> String {
        config.validate().unwrap_err().to_string()
    }

    #[test]
    fn test_config_validate() {
        // A retry count above the supported maximum is rejected.
        let config = Config {
            task: TaskConfig {
                retries: Some(1000000),
                ..Default::default()
            },
            ..Default::default()
        };
        assert_eq!(
            validation_error(&config),
            "configuration value `task.retries` cannot exceed 100"
        );

        // A scatter concurrency of zero is rejected.
        let mut config = Config::default();
        config.workflow.scatter.concurrency = Some(0);
        assert_eq!(
            validation_error(&config),
            "configuration value `workflow.scatter.concurrency` cannot be zero"
        );

        // Local backend: a CPU count of zero is rejected.
        let config = with_local_backend(LocalBackendConfig {
            cpu: Some(0),
            ..Default::default()
        });
        assert_eq!(
            validation_error(&config),
            "local backend configuration value `cpu` cannot be zero"
        );

        // Local backend: a CPU count above what the host has is rejected.
        let config = with_local_backend(LocalBackendConfig {
            cpu: Some(10000000),
            ..Default::default()
        });
        assert!(validation_error(&config).starts_with(
            "local backend configuration value `cpu` cannot exceed the virtual CPUs available to \
             the host"
        ));

        // Local backend: zero memory is rejected.
        let config = with_local_backend(LocalBackendConfig {
            memory: Some("0 GiB".to_string()),
            ..Default::default()
        });
        assert_eq!(
            validation_error(&config),
            "local backend configuration value `memory` cannot be zero"
        );

        // Local backend: an unparsable memory unit string is rejected.
        let config = with_local_backend(LocalBackendConfig {
            memory: Some("100 meows".to_string()),
            ..Default::default()
        });
        assert_eq!(
            validation_error(&config),
            "local backend configuration value `memory` has invalid value `100 meows`"
        );

        // Local backend: more memory than the host has is rejected.
        let config = with_local_backend(LocalBackendConfig {
            memory: Some("1000 TiB".to_string()),
            ..Default::default()
        });
        assert!(validation_error(&config).starts_with(
            "local backend configuration value `memory` cannot exceed the total memory of the host"
        ));

        // HTTP: a concurrent download limit of zero is rejected.
        let mut config = Config::default();
        config.http.max_concurrent_downloads = Some(0);
        assert_eq!(
            validation_error(&config),
            "configuration value `http.max_concurrent_downloads` cannot be zero"
        );

        // HTTP: a non-zero concurrent download limit is accepted.
        let mut config = Config::default();
        config.http.max_concurrent_downloads = Some(5);
        assert!(
            config.validate().is_ok(),
            "should pass for valid configuration"
        );

        // HTTP: an unset concurrent download limit is accepted.
        let mut config = Config::default();
        config.http.max_concurrent_downloads = None;
        assert!(config.validate().is_ok(), "should pass for default (None)");
    }
}