1use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::Arc;
6
7use anyhow::Context;
8use anyhow::Result;
9use anyhow::bail;
10use serde::Deserialize;
11use serde::Serialize;
12use tracing::warn;
13
14use crate::DockerBackend;
15use crate::LocalBackend;
16use crate::SYSTEM;
17use crate::TaskExecutionBackend;
18use crate::convert_unit_string;
19
/// The maximum value accepted for the `task.retries` configuration setting.
///
/// `TaskConfig::validate` rejects any configured retry count greater than this.
pub const MAX_RETRIES: u64 = 100;

/// The default shell name for tasks.
///
/// NOTE(review): presumably used when `task.shell` is unset — the use site is
/// not visible here; confirm.
pub const DEFAULT_TASK_SHELL: &str = "bash";

/// The default maximum number of concurrent downloads.
///
/// NOTE(review): presumably used when `http.max_concurrent_downloads` is
/// unset — the use site is not visible here; confirm.
pub const DEFAULT_MAX_CONCURRENT_DOWNLOADS: u64 = 10;
28
/// Top-level engine configuration.
///
/// Every section is optional when deserializing and falls back to its
/// `Default` value; unknown fields are rejected.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct Config {
    /// HTTP configuration (see [`HttpConfig`]).
    #[serde(default)]
    pub http: HttpConfig,
    /// Workflow configuration (see [`WorkflowConfig`]).
    #[serde(default)]
    pub workflow: WorkflowConfig,
    /// Task configuration (see [`TaskConfig`]).
    #[serde(default)]
    pub task: TaskConfig,
    /// Task execution backend configuration; defaults to the Docker backend.
    #[serde(default)]
    pub backend: BackendConfig,
    /// Azure, S3 and Google storage configuration (see [`StorageConfig`]).
    #[serde(default)]
    pub storage: StorageConfig,
}
49
50impl Config {
51 pub fn validate(&self) -> Result<()> {
53 self.http.validate()?;
54 self.workflow.validate()?;
55 self.task.validate()?;
56 self.backend.validate()?;
57 self.storage.validate()?;
58 Ok(())
59 }
60
61 pub async fn create_backend(&self) -> Result<Arc<dyn TaskExecutionBackend>> {
63 match &self.backend {
64 BackendConfig::Local(config) => {
65 warn!(
66 "the engine is configured to use the local backend: tasks will not be run \
67 inside of a container"
68 );
69 Ok(Arc::new(LocalBackend::new(&self.task, config)?))
70 }
71 BackendConfig::Docker(config) => {
72 Ok(Arc::new(DockerBackend::new(&self.task, config).await?))
73 }
74 }
75 }
76}
77
/// HTTP configuration.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct HttpConfig {
    /// An optional path to use for the HTTP cache.
    ///
    /// NOTE(review): presumably a directory; what happens when this is unset is
    /// decided by the consumer of this value, which is not visible here.
    #[serde(default)]
    pub cache: Option<PathBuf>,
    /// The maximum number of concurrent downloads.
    ///
    /// When set, it must not be zero (see `HttpConfig::validate`). It is
    /// omitted from serialized output when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_concurrent_downloads: Option<u64>,
}
93
94impl HttpConfig {
95 pub fn validate(&self) -> Result<()> {
97 if let Some(limit) = self.max_concurrent_downloads {
98 if limit == 0 {
99 bail!("configuration value `http.max_concurrent_downloads` cannot be zero");
100 }
101 }
102 Ok(())
103 }
104}
105
/// Cloud storage configuration, one section per provider.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct StorageConfig {
    /// Azure storage configuration.
    #[serde(default)]
    pub azure: AzureStorageConfig,
    /// S3 storage configuration.
    #[serde(default)]
    pub s3: S3StorageConfig,
    /// Google storage configuration.
    #[serde(default)]
    pub google: GoogleStorageConfig,
}
120
121impl StorageConfig {
122 pub fn validate(&self) -> Result<()> {
124 self.azure.validate()?;
125 self.s3.validate()?;
126 self.google.validate()?;
127 Ok(())
128 }
129}
130
/// Azure storage configuration.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct AzureStorageConfig {
    /// Authentication settings as a two-level string map.
    ///
    /// Omitted from serialized output when empty.
    ///
    /// NOTE(review): the meaning of the outer and inner keys (e.g. account /
    /// container) is not visible here — confirm against the storage code.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub auth: HashMap<String, HashMap<String, String>>,
}

impl AzureStorageConfig {
    /// Validates the Azure storage configuration.
    ///
    /// Currently performs no checks and always succeeds.
    pub fn validate(&self) -> Result<()> {
        Ok(())
    }
}
153
/// S3 storage configuration.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct S3StorageConfig {
    /// An optional region name; omitted from serialized output when unset.
    ///
    /// NOTE(review): the fallback used when unset is decided elsewhere — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub region: Option<String>,

    /// Authentication settings as a string map.
    ///
    /// Omitted from serialized output when empty.
    ///
    /// NOTE(review): the expected keys are not visible here — confirm against
    /// the storage code.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub auth: HashMap<String, String>,
}

impl S3StorageConfig {
    /// Validates the S3 storage configuration.
    ///
    /// Currently performs no checks and always succeeds.
    pub fn validate(&self) -> Result<()> {
        Ok(())
    }
}
180
/// Google storage configuration.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct GoogleStorageConfig {
    /// Authentication settings as a string map.
    ///
    /// Omitted from serialized output when empty.
    ///
    /// NOTE(review): the expected keys are not visible here — confirm against
    /// the storage code.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub auth: HashMap<String, String>,
}

impl GoogleStorageConfig {
    /// Validates the Google storage configuration.
    ///
    /// Currently performs no checks and always succeeds.
    pub fn validate(&self) -> Result<()> {
        Ok(())
    }
}
200
/// Workflow evaluation configuration.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct WorkflowConfig {
    /// Scatter statement configuration (see [`ScatterConfig`]).
    #[serde(default)]
    pub scatter: ScatterConfig,
}
209
210impl WorkflowConfig {
211 pub fn validate(&self) -> Result<()> {
213 self.scatter.validate()?;
214 Ok(())
215 }
216}
217
/// Scatter statement configuration.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct ScatterConfig {
    /// The scatter concurrency limit.
    ///
    /// When set, it must not be zero (see `ScatterConfig::validate`). It is
    /// omitted from serialized output when unset.
    ///
    /// NOTE(review): the behavior when unset is decided by the workflow
    /// evaluator, which is not visible here — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub concurrency: Option<u64>,
}
277
278impl ScatterConfig {
279 pub fn validate(&self) -> Result<()> {
281 if let Some(concurrency) = self.concurrency {
282 if concurrency == 0 {
283 bail!("configuration value `workflow.scatter.concurrency` cannot be zero");
284 }
285 }
286
287 Ok(())
288 }
289}
290
/// Task evaluation configuration.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct TaskConfig {
    /// The configured number of task retries.
    ///
    /// When set, it may not exceed [`MAX_RETRIES`] (see `TaskConfig::validate`).
    /// It is omitted from serialized output when unset.
    ///
    /// NOTE(review): how this interacts with per-task settings is not visible
    /// here — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub retries: Option<u64>,
    /// An optional container name; omitted from serialized output when unset.
    ///
    /// NOTE(review): presumably the default container for tasks that do not
    /// specify one — confirm at the use site.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub container: Option<String>,
    /// An optional shell name; omitted from serialized output when unset.
    ///
    /// NOTE(review): presumably falls back to [`DEFAULT_TASK_SHELL`] when unset
    /// — confirm at the use site.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub shell: Option<String>,
}
318
319impl TaskConfig {
320 pub fn validate(&self) -> Result<()> {
322 if self.retries.unwrap_or(0) > MAX_RETRIES {
323 bail!("configuration value `task.retries` cannot exceed {MAX_RETRIES}");
324 }
325
326 Ok(())
327 }
328}
329
/// Task execution backend configuration.
///
/// Serialized as an internally tagged enum: the `type` field selects the
/// variant (`local` or `docker`) and the remaining fields configure it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum BackendConfig {
    /// Use the local backend; tasks are not run inside of a container.
    Local(LocalBackendConfig),
    /// Use the Docker backend.
    Docker(DockerBackendConfig),
}
339
340impl Default for BackendConfig {
341 fn default() -> Self {
342 Self::Docker(Default::default())
343 }
344}
345
346impl BackendConfig {
347 pub fn validate(&self) -> Result<()> {
349 match self {
350 Self::Local(config) => config.validate(),
351 Self::Docker(config) => config.validate(),
352 }
353 }
354}
355
/// Configuration for the local task execution backend.
///
/// With this backend, tasks are not run inside of a container.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct LocalBackendConfig {
    /// An optional CPU count for the local backend.
    ///
    /// When set, it must be non-zero and no greater than the number of virtual
    /// CPUs on the host (see `LocalBackendConfig::validate`).
    ///
    /// NOTE(review): presumably a limit on what tasks may request — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cpu: Option<u64>,

    /// An optional amount of memory, written as a unit string (e.g. `"1 GiB"`).
    ///
    /// When set, it must parse via `convert_unit_string`, be non-zero, and be
    /// no greater than the host's total memory (see
    /// `LocalBackendConfig::validate`).
    ///
    /// NOTE(review): presumably a limit on what tasks may request — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub memory: Option<String>,
}
382
383impl LocalBackendConfig {
384 pub fn validate(&self) -> Result<()> {
386 if let Some(cpu) = self.cpu {
387 if cpu == 0 {
388 bail!("local backend configuration value `cpu` cannot be zero");
389 }
390
391 let total = SYSTEM.cpus().len() as u64;
392 if cpu > total {
393 bail!(
394 "local backend configuration value `cpu` cannot exceed the virtual CPUs \
395 available to the host ({total})"
396 );
397 }
398 }
399
400 if let Some(memory) = &self.memory {
401 let memory = convert_unit_string(memory).with_context(|| {
402 format!("local backend configuration value `memory` has invalid value `{memory}`")
403 })?;
404
405 if memory == 0 {
406 bail!("local backend configuration value `memory` cannot be zero");
407 }
408
409 let total = SYSTEM.total_memory();
410 if memory > total {
411 bail!(
412 "local backend configuration value `memory` cannot exceed the total memory of \
413 the host ({total} bytes)"
414 );
415 }
416 }
417
418 Ok(())
419 }
420}
421
422const fn cleanup_default() -> bool {
424 true
425}
426
427#[derive(Debug, Clone, Serialize, Deserialize)]
429#[serde(rename_all = "snake_case", deny_unknown_fields)]
430pub struct DockerBackendConfig {
431 #[serde(default = "cleanup_default")]
435 pub cleanup: bool,
436}
437
438impl DockerBackendConfig {
439 pub fn validate(&self) -> Result<()> {
441 Ok(())
442 }
443}
444
445impl Default for DockerBackendConfig {
446 fn default() -> Self {
447 Self { cleanup: true }
448 }
449}
450
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;

    use super::*;

    #[test]
    fn test_config_validate() {
        // Validate task retries over the limit.
        let mut config = Config::default();
        config.task.retries = Some(1000000);
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "configuration value `task.retries` cannot exceed 100"
        );

        // The retry limit is inclusive: exactly `MAX_RETRIES` is accepted and
        // one more is rejected (pins against an off-by-one in the comparison).
        let mut config = Config::default();
        config.task.retries = Some(MAX_RETRIES);
        assert!(config.validate().is_ok(), "should pass at the retry limit");
        config.task.retries = Some(MAX_RETRIES + 1);
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            format!("configuration value `task.retries` cannot exceed {MAX_RETRIES}")
        );

        // Validate zero scatter concurrency.
        let mut config = Config::default();
        config.workflow.scatter.concurrency = Some(0);
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "configuration value `workflow.scatter.concurrency` cannot be zero"
        );

        // A non-zero scatter concurrency is accepted.
        let mut config = Config::default();
        config.workflow.scatter.concurrency = Some(1);
        assert!(
            config.validate().is_ok(),
            "should pass for non-zero scatter concurrency"
        );

        // Validate local backend CPU settings.
        let config = Config {
            backend: BackendConfig::Local(LocalBackendConfig {
                cpu: Some(0),
                ..Default::default()
            }),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "local backend configuration value `cpu` cannot be zero"
        );
        let config = Config {
            backend: BackendConfig::Local(LocalBackendConfig {
                cpu: Some(10000000),
                ..Default::default()
            }),
            ..Default::default()
        };
        assert!(config.validate().unwrap_err().to_string().starts_with(
            "local backend configuration value `cpu` cannot exceed the virtual CPUs available to \
             the host"
        ));

        // Validate local backend memory settings.
        let config = Config {
            backend: BackendConfig::Local(LocalBackendConfig {
                memory: Some("0 GiB".to_string()),
                ..Default::default()
            }),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "local backend configuration value `memory` cannot be zero"
        );
        let config = Config {
            backend: BackendConfig::Local(LocalBackendConfig {
                memory: Some("100 meows".to_string()),
                ..Default::default()
            }),
            ..Default::default()
        };
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "local backend configuration value `memory` has invalid value `100 meows`"
        );

        let config = Config {
            backend: BackendConfig::Local(LocalBackendConfig {
                memory: Some("1000 TiB".to_string()),
                ..Default::default()
            }),
            ..Default::default()
        };
        assert!(config.validate().unwrap_err().to_string().starts_with(
            "local backend configuration value `memory` cannot exceed the total memory of the host"
        ));

        // Validate HTTP download concurrency.
        let mut config = Config::default();
        config.http.max_concurrent_downloads = Some(0);
        assert_eq!(
            config.validate().unwrap_err().to_string(),
            "configuration value `http.max_concurrent_downloads` cannot be zero"
        );

        let mut config = Config::default();
        config.http.max_concurrent_downloads = Some(5);
        assert!(
            config.validate().is_ok(),
            "should pass for valid configuration"
        );

        let mut config = Config::default();
        config.http.max_concurrent_downloads = None;
        assert!(config.validate().is_ok(), "should pass for default (None)");
    }
}