1use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::Arc;
6
7use anyhow::Context;
8use anyhow::Result;
9use anyhow::bail;
10use serde::Deserialize;
11use serde::Serialize;
12use tracing::warn;
13use url::Url;
14
15use crate::DockerBackend;
16use crate::LocalBackend;
17use crate::SYSTEM;
18use crate::TaskExecutionBackend;
19use crate::TesBackend;
20use crate::convert_unit_string;
21use crate::path::is_url;
22
/// The largest value accepted for the `task.retries` configuration setting.
///
/// `TaskConfig::validate` rejects any configured retry count above this.
pub const MAX_RETRIES: u64 = 100;

/// The default task shell name (`bash`).
///
/// NOTE(review): presumably used when `task.shell` is not configured —
/// confirm at the use site.
pub const DEFAULT_TASK_SHELL: &str = "bash";

/// The default limit on concurrent downloads.
///
/// NOTE(review): presumably applied when `http.max_concurrent_downloads` is
/// unset — confirm at the use site.
pub const DEFAULT_MAX_CONCURRENT_DOWNLOADS: u64 = 10;
31
32#[derive(Debug, Default, Clone, Serialize, Deserialize)]
34#[serde(rename_all = "snake_case", deny_unknown_fields)]
35pub struct Config {
36 #[serde(default)]
38 pub http: HttpConfig,
39 #[serde(default)]
41 pub workflow: WorkflowConfig,
42 #[serde(default)]
44 pub task: TaskConfig,
45 #[serde(default)]
47 pub backend: BackendConfig,
48 #[serde(default)]
50 pub storage: StorageConfig,
51}
52
53impl Config {
54 pub fn validate(&self) -> Result<()> {
56 self.http.validate()?;
57 self.workflow.validate()?;
58 self.task.validate()?;
59 self.backend.validate()?;
60 self.storage.validate()?;
61 Ok(())
62 }
63
64 pub async fn create_backend(&self) -> Result<Arc<dyn TaskExecutionBackend>> {
66 match &self.backend {
67 BackendConfig::Local(config) => {
68 warn!(
69 "the engine is configured to use the local backend: tasks will not be run \
70 inside of a container"
71 );
72 Ok(Arc::new(LocalBackend::new(&self.task, config)?))
73 }
74 BackendConfig::Docker(config) => {
75 Ok(Arc::new(DockerBackend::new(&self.task, config).await?))
76 }
77 BackendConfig::Tes(config) => Ok(Arc::new(TesBackend::new(&self.task, config).await?)),
78 }
79 }
80}
81
82#[derive(Debug, Default, Clone, Serialize, Deserialize)]
84#[serde(rename_all = "snake_case", deny_unknown_fields)]
85pub struct HttpConfig {
86 #[serde(default)]
90 pub cache: Option<PathBuf>,
91 #[serde(default, skip_serializing_if = "Option::is_none")]
95 pub max_concurrent_downloads: Option<u64>,
96}
97
98impl HttpConfig {
99 pub fn validate(&self) -> Result<()> {
101 if let Some(limit) = self.max_concurrent_downloads {
102 if limit == 0 {
103 bail!("configuration value `http.max_concurrent_downloads` cannot be zero");
104 }
105 }
106 Ok(())
107 }
108}
109
110#[derive(Debug, Default, Clone, Serialize, Deserialize)]
112#[serde(rename_all = "snake_case", deny_unknown_fields)]
113pub struct StorageConfig {
114 #[serde(default)]
116 pub azure: AzureStorageConfig,
117 #[serde(default)]
119 pub s3: S3StorageConfig,
120 #[serde(default)]
122 pub google: GoogleStorageConfig,
123}
124
125impl StorageConfig {
126 pub fn validate(&self) -> Result<()> {
128 self.azure.validate()?;
129 self.s3.validate()?;
130 self.google.validate()?;
131 Ok(())
132 }
133}
134
135#[derive(Debug, Default, Clone, Serialize, Deserialize)]
137#[serde(rename_all = "snake_case", deny_unknown_fields)]
138pub struct AzureStorageConfig {
139 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
148 pub auth: HashMap<String, HashMap<String, String>>,
149}
150
151impl AzureStorageConfig {
152 pub fn validate(&self) -> Result<()> {
154 Ok(())
155 }
156}
157
158#[derive(Debug, Default, Clone, Serialize, Deserialize)]
160#[serde(rename_all = "snake_case", deny_unknown_fields)]
161pub struct S3StorageConfig {
162 #[serde(default, skip_serializing_if = "Option::is_none")]
167 pub region: Option<String>,
168
169 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
175 pub auth: HashMap<String, String>,
176}
177
178impl S3StorageConfig {
179 pub fn validate(&self) -> Result<()> {
181 Ok(())
182 }
183}
184
185#[derive(Debug, Default, Clone, Serialize, Deserialize)]
187#[serde(rename_all = "snake_case", deny_unknown_fields)]
188pub struct GoogleStorageConfig {
189 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
195 pub auth: HashMap<String, String>,
196}
197
198impl GoogleStorageConfig {
199 pub fn validate(&self) -> Result<()> {
201 Ok(())
202 }
203}
204
205#[derive(Debug, Default, Clone, Serialize, Deserialize)]
207#[serde(rename_all = "snake_case", deny_unknown_fields)]
208pub struct WorkflowConfig {
209 #[serde(default)]
211 pub scatter: ScatterConfig,
212}
213
214impl WorkflowConfig {
215 pub fn validate(&self) -> Result<()> {
217 self.scatter.validate()?;
218 Ok(())
219 }
220}
221
222#[derive(Debug, Default, Clone, Serialize, Deserialize)]
224#[serde(rename_all = "snake_case", deny_unknown_fields)]
225pub struct ScatterConfig {
226 #[serde(default, skip_serializing_if = "Option::is_none")]
279 pub concurrency: Option<u64>,
280}
281
282impl ScatterConfig {
283 pub fn validate(&self) -> Result<()> {
285 if let Some(concurrency) = self.concurrency {
286 if concurrency == 0 {
287 bail!("configuration value `workflow.scatter.concurrency` cannot be zero");
288 }
289 }
290
291 Ok(())
292 }
293}
294
295#[derive(Debug, Default, Clone, Serialize, Deserialize)]
297#[serde(rename_all = "snake_case", deny_unknown_fields)]
298pub struct TaskConfig {
299 #[serde(default, skip_serializing_if = "Option::is_none")]
305 pub retries: Option<u64>,
306 #[serde(default, skip_serializing_if = "Option::is_none")]
311 pub container: Option<String>,
312 #[serde(default, skip_serializing_if = "Option::is_none")]
320 pub shell: Option<String>,
321}
322
323impl TaskConfig {
324 pub fn validate(&self) -> Result<()> {
326 if self.retries.unwrap_or(0) > MAX_RETRIES {
327 bail!("configuration value `task.retries` cannot exceed {MAX_RETRIES}");
328 }
329
330 Ok(())
331 }
332}
333
334#[derive(Debug, Clone, Serialize, Deserialize)]
336#[serde(rename_all = "snake_case", tag = "type")]
337pub enum BackendConfig {
338 Local(LocalBackendConfig),
340 Docker(DockerBackendConfig),
342 Tes(Box<TesBackendConfig>),
344}
345
346impl Default for BackendConfig {
347 fn default() -> Self {
348 Self::Docker(Default::default())
349 }
350}
351
352impl BackendConfig {
353 pub fn validate(&self) -> Result<()> {
355 match self {
356 Self::Local(config) => config.validate(),
357 Self::Docker(config) => config.validate(),
358 Self::Tes(config) => config.validate(),
359 }
360 }
361}
362
363#[derive(Debug, Default, Clone, Serialize, Deserialize)]
370#[serde(rename_all = "snake_case", deny_unknown_fields)]
371pub struct LocalBackendConfig {
372 #[serde(default, skip_serializing_if = "Option::is_none")]
378 pub cpu: Option<u64>,
379
380 #[serde(default, skip_serializing_if = "Option::is_none")]
387 pub memory: Option<String>,
388}
389
390impl LocalBackendConfig {
391 pub fn validate(&self) -> Result<()> {
393 if let Some(cpu) = self.cpu {
394 if cpu == 0 {
395 bail!("local backend configuration value `cpu` cannot be zero");
396 }
397
398 let total = SYSTEM.cpus().len() as u64;
399 if cpu > total {
400 bail!(
401 "local backend configuration value `cpu` cannot exceed the virtual CPUs \
402 available to the host ({total})"
403 );
404 }
405 }
406
407 if let Some(memory) = &self.memory {
408 let memory = convert_unit_string(memory).with_context(|| {
409 format!("local backend configuration value `memory` has invalid value `{memory}`")
410 })?;
411
412 if memory == 0 {
413 bail!("local backend configuration value `memory` cannot be zero");
414 }
415
416 let total = SYSTEM.total_memory();
417 if memory > total {
418 bail!(
419 "local backend configuration value `memory` cannot exceed the total memory of \
420 the host ({total} bytes)"
421 );
422 }
423 }
424
425 Ok(())
426 }
427}
428
/// The serde default for `DockerBackendConfig::cleanup`: enabled.
const fn cleanup_default() -> bool {
    true
}
433
434#[derive(Debug, Clone, Serialize, Deserialize)]
436#[serde(rename_all = "snake_case", deny_unknown_fields)]
437pub struct DockerBackendConfig {
438 #[serde(default = "cleanup_default")]
442 pub cleanup: bool,
443}
444
445impl DockerBackendConfig {
446 pub fn validate(&self) -> Result<()> {
448 Ok(())
449 }
450}
451
452impl Default for DockerBackendConfig {
453 fn default() -> Self {
454 Self { cleanup: true }
455 }
456}
457
458#[derive(Debug, Default, Clone, Serialize, Deserialize)]
460#[serde(rename_all = "snake_case", deny_unknown_fields)]
461pub struct BasicAuthConfig {
462 pub username: Option<String>,
464 pub password: Option<String>,
466}
467
468impl BasicAuthConfig {
469 pub fn validate(&self) -> Result<()> {
471 if self.username.is_none() {
472 bail!("HTTP basic auth configuration value `username` is required");
473 }
474
475 if self.password.is_none() {
476 bail!("HTTP basic auth configuration value `password` is required");
477 }
478
479 Ok(())
480 }
481}
482
483#[derive(Debug, Clone, Serialize, Deserialize)]
485#[serde(rename_all = "snake_case", tag = "type")]
486pub enum TesBackendAuthConfig {
487 Basic(BasicAuthConfig),
489}
490
491impl TesBackendAuthConfig {
492 pub fn validate(&self) -> Result<()> {
494 match self {
495 Self::Basic(auth) => auth.validate(),
496 }
497 }
498}
499
500#[derive(Debug, Default, Clone, Serialize, Deserialize)]
502#[serde(rename_all = "snake_case", deny_unknown_fields)]
503pub struct TesBackendConfig {
504 #[serde(default)]
506 pub url: Option<Url>,
507
508 #[serde(default, skip_serializing_if = "Option::is_none")]
510 pub auth: Option<TesBackendAuthConfig>,
511
512 #[serde(default, skip_serializing_if = "Option::is_none")]
514 pub inputs: Option<Url>,
515
516 #[serde(default, skip_serializing_if = "Option::is_none")]
518 pub outputs: Option<Url>,
519
520 #[serde(default)]
524 pub interval: Option<u64>,
525
526 #[serde(default)]
530 pub max_concurrency: Option<u64>,
531
532 #[serde(default)]
535 pub insecure: bool,
536}
537
538impl TesBackendConfig {
539 pub fn validate(&self) -> Result<()> {
541 match &self.url {
542 Some(url) => {
543 if !self.insecure && url.scheme() != "https" {
544 bail!(
545 "TES backend configuration value `url` has invalid value `{url}`: URL \
546 must use a HTTPS scheme"
547 );
548 }
549 }
550 None => bail!("TES backend configuration value `url` is required"),
551 }
552
553 if let Some(auth) = &self.auth {
554 auth.validate()?;
555 }
556
557 match &self.inputs {
558 Some(url) => {
559 if !is_url(url.as_str()) {
560 bail!(
561 "TES backend storage configuration value `inputs` has invalid value \
562 `{url}`: URL scheme is not supported"
563 );
564 }
565
566 if !url.path().ends_with('/') {
567 bail!(
568 "TES backend storage configuration value `inputs` has invalid value \
569 `{url}`: URL path must end with a slash"
570 );
571 }
572 }
573 None => bail!("TES backend configuration value `inputs` is required"),
574 }
575
576 match &self.outputs {
577 Some(url) => {
578 if !is_url(url.as_str()) {
579 bail!(
580 "TES backend storage configuration value `outputs` has invalid value \
581 `{url}`: URL scheme is not supported"
582 );
583 }
584
585 if !url.path().ends_with('/') {
586 bail!(
587 "TES backend storage configuration value `outputs` has invalid value \
588 `{url}`: URL path must end with a slash"
589 );
590 }
591 }
592 None => bail!("TES backend storage configuration value `outputs` is required"),
593 }
594
595 Ok(())
596 }
597}
598
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;

    use super::*;

    /// Validates `config`, asserting that validation fails, and returns the
    /// error's display text.
    fn validation_error(config: &Config) -> String {
        config.validate().unwrap_err().to_string()
    }

    /// Builds a default configuration that selects the local backend.
    fn with_local(local: LocalBackendConfig) -> Config {
        Config {
            backend: BackendConfig::Local(local),
            ..Default::default()
        }
    }

    /// Builds a default configuration that selects the TES backend.
    fn with_tes(tes: TesBackendConfig) -> Config {
        Config {
            backend: BackendConfig::Tes(Box::new(tes)),
            ..Default::default()
        }
    }

    #[test]
    fn test_config_validate() {
        // `task.retries` above the maximum.
        let mut config = Config::default();
        config.task.retries = Some(1000000);
        assert_eq!(
            validation_error(&config),
            "configuration value `task.retries` cannot exceed 100"
        );

        // Zero scatter concurrency.
        let mut config = Config::default();
        config.workflow.scatter.concurrency = Some(0);
        assert_eq!(
            validation_error(&config),
            "configuration value `workflow.scatter.concurrency` cannot be zero"
        );

        // Local backend CPU bounds.
        let config = with_local(LocalBackendConfig {
            cpu: Some(0),
            ..Default::default()
        });
        assert_eq!(
            validation_error(&config),
            "local backend configuration value `cpu` cannot be zero"
        );
        let config = with_local(LocalBackendConfig {
            cpu: Some(10000000),
            ..Default::default()
        });
        assert!(validation_error(&config).starts_with(
            "local backend configuration value `cpu` cannot exceed the virtual CPUs available to \
             the host"
        ));

        // Local backend memory bounds and parsing.
        let config = with_local(LocalBackendConfig {
            memory: Some("0 GiB".to_string()),
            ..Default::default()
        });
        assert_eq!(
            validation_error(&config),
            "local backend configuration value `memory` cannot be zero"
        );
        let config = with_local(LocalBackendConfig {
            memory: Some("100 meows".to_string()),
            ..Default::default()
        });
        assert_eq!(
            validation_error(&config),
            "local backend configuration value `memory` has invalid value `100 meows`"
        );
        let config = with_local(LocalBackendConfig {
            memory: Some("1000 TiB".to_string()),
            ..Default::default()
        });
        assert!(validation_error(&config).starts_with(
            "local backend configuration value `memory` cannot exceed the total memory of the host"
        ));

        // TES backend URL requirements.
        let config = with_tes(Default::default());
        assert_eq!(
            validation_error(&config),
            "TES backend configuration value `url` is required"
        );

        let plain_http = || TesBackendConfig {
            url: Some("http://example.com".parse().unwrap()),
            inputs: Some("http://example.com".parse().unwrap()),
            outputs: Some("http://example.com".parse().unwrap()),
            ..Default::default()
        };

        let config = with_tes(plain_http());
        assert_eq!(
            validation_error(&config),
            "TES backend configuration value `url` has invalid value `http://example.com/`: URL \
             must use a HTTPS scheme"
        );

        let config = with_tes(TesBackendConfig {
            insecure: true,
            ..plain_http()
        });
        config.validate().expect("configuration should validate");

        // TES basic auth requirements.
        let config = with_tes(TesBackendConfig {
            url: Some(Url::parse("https://example.com").unwrap()),
            auth: Some(TesBackendAuthConfig::Basic(Default::default())),
            ..Default::default()
        });
        assert_eq!(
            validation_error(&config),
            "HTTP basic auth configuration value `username` is required"
        );
        let config = with_tes(TesBackendConfig {
            url: Some(Url::parse("https://example.com").unwrap()),
            auth: Some(TesBackendAuthConfig::Basic(BasicAuthConfig {
                username: Some("Foo".into()),
                ..Default::default()
            })),
            ..Default::default()
        });
        assert_eq!(
            validation_error(&config),
            "HTTP basic auth configuration value `password` is required"
        );

        // HTTP download concurrency.
        let mut config = Config::default();
        config.http.max_concurrent_downloads = Some(0);
        assert_eq!(
            validation_error(&config),
            "configuration value `http.max_concurrent_downloads` cannot be zero"
        );

        for (limit, reason) in [
            (Some(5), "should pass for valid configuration"),
            (None, "should pass for default (None)"),
        ] {
            let mut config = Config::default();
            config.http.max_concurrent_downloads = limit;
            assert!(config.validate().is_ok(), "{reason}");
        }
    }
}