wdl_engine/backend/
tes.rs

1//! Implementation of the TES backend.
2
3use std::borrow::Cow;
4use std::collections::HashMap;
5use std::fs;
6use std::sync::Arc;
7use std::sync::Mutex;
8
9use anyhow::Context;
10use anyhow::Result;
11use anyhow::bail;
12use cloud_copy::UrlExt;
13use crankshaft::config::backend;
14use crankshaft::config::backend::tes::http::HttpAuthConfig;
15use crankshaft::engine::Task;
16use crankshaft::engine::service::name::GeneratorIterator;
17use crankshaft::engine::service::name::UniqueAlphanumeric;
18use crankshaft::engine::service::runner::Backend;
19use crankshaft::engine::service::runner::backend::TaskRunError;
20use crankshaft::engine::service::runner::backend::tes;
21use crankshaft::engine::task::Execution;
22use crankshaft::engine::task::Input;
23use crankshaft::engine::task::Output;
24use crankshaft::engine::task::Resources;
25use crankshaft::engine::task::input::Contents;
26use crankshaft::engine::task::input::Type as InputType;
27use crankshaft::engine::task::output::Type as OutputType;
28use crankshaft::events::Event;
29use nonempty::NonEmpty;
30use secrecy::ExposeSecret;
31use tokio::sync::broadcast;
32use tokio::sync::oneshot;
33use tokio::sync::oneshot::Receiver;
34use tokio::task::JoinSet;
35use tokio_util::sync::CancellationToken;
36use tracing::info;
37use wdl_ast::v1::TASK_REQUIREMENT_DISKS;
38
39use super::TaskExecutionBackend;
40use super::TaskExecutionConstraints;
41use super::TaskExecutionResult;
42use super::TaskManager;
43use super::TaskManagerRequest;
44use super::TaskSpawnRequest;
45use crate::COMMAND_FILE_NAME;
46use crate::ONE_GIBIBYTE;
47use crate::PrimitiveValue;
48use crate::STDERR_FILE_NAME;
49use crate::STDOUT_FILE_NAME;
50use crate::Value;
51use crate::WORK_DIR_NAME;
52use crate::backend::INITIAL_EXPECTED_NAMES;
53use crate::config::Config;
54use crate::config::DEFAULT_TASK_SHELL;
55use crate::config::TesBackendAuthConfig;
56use crate::config::TesBackendConfig;
57use crate::hash::UrlDigestExt;
58use crate::hash::calculate_path_digest;
59use crate::path::EvaluationPath;
60use crate::v1::DEFAULT_TASK_REQUIREMENT_DISKS;
61use crate::v1::container;
62use crate::v1::cpu;
63use crate::v1::disks;
64use crate::v1::max_cpu;
65use crate::v1::max_memory;
66use crate::v1::memory;
67use crate::v1::preemptible;
68
/// How often, in seconds, the TES backend polls the server by default.
const DEFAULT_TES_INTERVAL: u64 = 1;

/// Directory inside the container under which task inputs are placed.
const GUEST_INPUTS_DIR: &str = "/mnt/task/inputs/";

/// Directory inside the container used as the task's working directory.
const GUEST_WORK_DIR: &str = "/mnt/task/work";

/// Location inside the container of the evaluated command file.
const GUEST_COMMAND_PATH: &str = "/mnt/task/command";

/// Location inside the container where the task's stdout is captured.
const GUEST_STDOUT_PATH: &str = "/mnt/task/stdout";

/// Location inside the container where the task's stderr is captured.
const GUEST_STDERR_PATH: &str = "/mnt/task/stderr";
86
87/// Represents a TES task request.
88///
89/// This request contains the requested cpu and memory reservations for the task
90/// as well as the result receiver channel.
91#[derive(Debug)]
92struct TesTaskRequest {
93    /// The engine configuration.
94    config: Arc<Config>,
95    /// The backend configuration.
96    backend_config: Arc<TesBackendConfig>,
97    /// The inner task spawn request.
98    inner: TaskSpawnRequest,
99    /// The Crankshaft TES backend to use.
100    backend: Arc<tes::Backend>,
101    /// The name of the task.
102    name: String,
103    /// The requested container for the task.
104    container: String,
105    /// The requested CPU reservation for the task.
106    cpu: f64,
107    /// The requested memory reservation for the task, in bytes.
108    memory: u64,
109    /// The requested maximum CPU limit for the task.
110    max_cpu: Option<f64>,
111    /// The requested maximum memory limit for the task, in bytes.
112    max_memory: Option<u64>,
113    /// The number of preemptible task retries to do before using a
114    /// non-preemptible task.
115    ///
116    /// If this value is 0, no preemptible tasks are requested from the TES
117    /// server.
118    preemptible: i64,
119    /// The cancellation token for the request.
120    token: CancellationToken,
121}
122
123impl TesTaskRequest {
124    /// Gets the TES disk resource for the request.
125    fn disk_resource(&self) -> Result<f64> {
126        let disks = disks(self.inner.requirements(), self.inner.hints())?;
127        if disks.len() > 1 {
128            bail!(
129                "TES backend does not support more than one disk specification for the \
130                 `{TASK_REQUIREMENT_DISKS}` task requirement"
131            );
132        }
133
134        if let Some(mount_point) = disks.keys().next()
135            && *mount_point != "/"
136        {
137            bail!(
138                "TES backend does not support a disk mount point other than `/` for the \
139                 `{TASK_REQUIREMENT_DISKS}` task requirement"
140            );
141        }
142
143        Ok(disks
144            .values()
145            .next()
146            .map(|d| d.size as f64)
147            .unwrap_or(DEFAULT_TASK_REQUIREMENT_DISKS))
148    }
149}
150
151impl TaskManagerRequest for TesTaskRequest {
152    fn cpu(&self) -> f64 {
153        self.cpu
154    }
155
156    fn memory(&self) -> u64 {
157        self.memory
158    }
159
160    async fn run(self) -> Result<TaskExecutionResult> {
161        // Create the attempt directory
162        let attempt_dir = self.inner.attempt_dir();
163        fs::create_dir_all(attempt_dir).with_context(|| {
164            format!(
165                "failed to create directory `{path}`",
166                path = attempt_dir.display()
167            )
168        })?;
169
170        // Write the evaluated command to disk
171        // This is done even for remote execution so that a copy exists locally
172        let command_path = attempt_dir.join(COMMAND_FILE_NAME);
173        fs::write(&command_path, self.inner.command()).with_context(|| {
174            format!(
175                "failed to write command contents to `{path}`",
176                path = command_path.display()
177            )
178        })?;
179
180        // SAFETY: currently `inputs` is required by configuration validation, so it
181        // should always unwrap
182        let inputs_url = Arc::new(
183            self.backend_config
184                .inputs
185                .clone()
186                .expect("should have inputs URL"),
187        );
188
189        // Start with the command file as an input
190        let mut inputs = vec![
191            Input::builder()
192                .path(GUEST_COMMAND_PATH)
193                .contents(Contents::Path(command_path.to_path_buf()))
194                .ty(InputType::File)
195                .read_only(true)
196                .build(),
197        ];
198
199        // Spawn upload tasks for inputs available locally, and apply authentication to
200        // the URLs for remote inputs.
201        let mut uploads = JoinSet::new();
202        for (i, input) in self.inner.inputs().iter().enumerate() {
203            match input.path() {
204                EvaluationPath::Local(path) => {
205                    // Input is local, spawn an upload of it
206                    let path = path.to_path_buf();
207                    let transferer = self.inner.transferer().clone();
208                    let inputs_url = inputs_url.clone();
209                    uploads.spawn(async move {
210                        let url = inputs_url.join_digest(
211                            calculate_path_digest(&path).await.with_context(|| {
212                                format!(
213                                    "failed to calculate digest of `{path}`",
214                                    path = path.display()
215                                )
216                            })?,
217                        );
218                        transferer
219                            .upload(&path, &url)
220                            .await
221                            .with_context(|| {
222                                format!(
223                                    "failed to upload `{path}` to `{url}`",
224                                    path = path.display(),
225                                    url = url.display()
226                                )
227                            })
228                            .map(|_| (i, url))
229                    });
230                }
231                EvaluationPath::Remote(url) => {
232                    // Input is already remote, add it to the Crankshaft inputs list
233                    let url = match self.inner.transferer().apply_auth(url)? {
234                        Cow::Borrowed(_) => url.clone(),
235                        Cow::Owned(url) => url,
236                    };
237                    inputs.push(
238                        Input::builder()
239                            .path(
240                                input
241                                    .guest_path()
242                                    .expect("input should have guest path")
243                                    .as_str(),
244                            )
245                            .contents(Contents::Url(url))
246                            .ty(input.kind())
247                            .read_only(true)
248                            .build(),
249                    );
250                }
251            }
252        }
253
254        // Wait for any uploads to complete
255        while let Some(result) = uploads.join_next().await {
256            let (i, url) = result.context("upload task")??;
257            let input = &self.inner.inputs()[i];
258            let url = match self.inner.transferer().apply_auth(&url)? {
259                Cow::Borrowed(_) => url,
260                Cow::Owned(url) => url,
261            };
262
263            inputs.push(
264                Input::builder()
265                    .path(
266                        input
267                            .guest_path()
268                            .expect("input should have guest path")
269                            .as_str(),
270                    )
271                    .contents(Contents::Url(url))
272                    .ty(input.kind())
273                    .read_only(true)
274                    .build(),
275            );
276        }
277
278        let output_dir = format!(
279            "{name}-{timestamp}/",
280            name = self.name,
281            timestamp = chrono::Utc::now().format("%Y%m%d-%H%M%S")
282        );
283
284        // SAFETY: currently `outputs` is required by configuration validation, so it
285        // should always unwrap
286        let outputs_url = self
287            .backend_config
288            .outputs
289            .as_ref()
290            .expect("should have outputs URL")
291            .join(&output_dir)
292            .expect("should join");
293
294        let work_dir_url = outputs_url.join(WORK_DIR_NAME).expect("should join");
295        let stdout_url = outputs_url.join(STDOUT_FILE_NAME).expect("should join");
296        let stderr_url = outputs_url.join(STDERR_FILE_NAME).expect("should join");
297
298        let mut work_dir_url = match self.inner.transferer().apply_auth(&work_dir_url)? {
299            Cow::Borrowed(_) => work_dir_url,
300            Cow::Owned(url) => url,
301        };
302
303        let stdout_url = match self.inner.transferer().apply_auth(&stdout_url)? {
304            Cow::Borrowed(_) => stdout_url,
305            Cow::Owned(url) => url,
306        };
307
308        let stderr_url = match self.inner.transferer().apply_auth(&stderr_url)? {
309            Cow::Borrowed(_) => stderr_url,
310            Cow::Owned(url) => url,
311        };
312
313        // The TES backend will output three things: the working directory contents,
314        // stdout, and stderr.
315        let outputs = vec![
316            Output::builder()
317                .path(GUEST_WORK_DIR)
318                .url(work_dir_url.clone())
319                .ty(OutputType::Directory)
320                .build(),
321            Output::builder()
322                .path(GUEST_STDOUT_PATH)
323                .url(stdout_url.clone())
324                .ty(OutputType::File)
325                .build(),
326            Output::builder()
327                .path(GUEST_STDERR_PATH)
328                .url(stderr_url.clone())
329                .ty(OutputType::File)
330                .build(),
331        ];
332
333        let mut preemptible = self.preemptible;
334        loop {
335            let task = Task::builder()
336                .name(&self.name)
337                .executions(NonEmpty::new(
338                    Execution::builder()
339                        .image(&self.container)
340                        .program(
341                            self.config
342                                .task
343                                .shell
344                                .as_deref()
345                                .unwrap_or(DEFAULT_TASK_SHELL),
346                        )
347                        .args([GUEST_COMMAND_PATH.to_string()])
348                        .work_dir(GUEST_WORK_DIR)
349                        .env(self.inner.env().clone())
350                        .stdout(GUEST_STDOUT_PATH)
351                        .stderr(GUEST_STDERR_PATH)
352                        .build(),
353                ))
354                .inputs(inputs.clone())
355                .outputs(outputs.clone())
356                .resources(
357                    Resources::builder()
358                        .cpu(self.cpu)
359                        .maybe_cpu_limit(self.max_cpu)
360                        .ram(self.memory as f64 / ONE_GIBIBYTE)
361                        .disk(self.disk_resource()?)
362                        .maybe_ram_limit(self.max_memory.map(|m| m as f64 / ONE_GIBIBYTE))
363                        .preemptible(preemptible > 0)
364                        .build(),
365                )
366                .build();
367
368            let statuses = match self.backend.run(task, self.token.clone())?.await {
369                Ok(statuses) => statuses,
370                Err(TaskRunError::Preempted) if preemptible > 0 => {
371                    // Decrement the preemptible count and retry
372                    preemptible -= 1;
373                    continue;
374                }
375                Err(e) => {
376                    return Err(e.into());
377                }
378            };
379
380            assert_eq!(statuses.len(), 1, "there should only be one output");
381            let status = statuses.first();
382
383            // Push an empty path segment so that future joins of the work directory URL
384            // treat it as a directory
385            work_dir_url.path_segments_mut().unwrap().push("");
386
387            return Ok(TaskExecutionResult {
388                exit_code: status.code().expect("should have exit code"),
389                work_dir: EvaluationPath::Remote(work_dir_url),
390                stdout: PrimitiveValue::new_file(stdout_url).into(),
391                stderr: PrimitiveValue::new_file(stderr_url).into(),
392            });
393        }
394    }
395}
396
397/// Represents the Task Execution Service (TES) backend.
398pub struct TesBackend {
399    /// The engine configuration.
400    config: Arc<Config>,
401    /// The backend configuration.
402    backend_config: Arc<TesBackendConfig>,
403    /// The underlying Crankshaft backend.
404    inner: Arc<tes::Backend>,
405    /// The maximum CPUs for any of one node.
406    max_cpu: u64,
407    /// The maximum memory for any of one node.
408    max_memory: u64,
409    /// The task manager for the backend.
410    manager: TaskManager<TesTaskRequest>,
411    /// The name generator for tasks.
412    names: Arc<Mutex<GeneratorIterator<UniqueAlphanumeric>>>,
413}
414
415impl TesBackend {
416    /// Constructs a new TES task execution backend with the given
417    /// configuration.
418    ///
419    /// The provided configuration is expected to have already been validated.
420    pub async fn new(
421        config: Arc<Config>,
422        backend_config: &TesBackendConfig,
423        events: Option<broadcast::Sender<Event>>,
424    ) -> Result<Self> {
425        info!("initializing TES backend");
426
427        // There's no way to ask the TES service for its limits, so use the maximums
428        // allowed
429        let max_cpu = u64::MAX;
430        let max_memory = u64::MAX;
431        let manager = TaskManager::new_unlimited(max_cpu, max_memory);
432
433        let mut http = backend::tes::http::Config::default();
434        match &backend_config.auth {
435            Some(TesBackendAuthConfig::Basic(config)) => {
436                http.auth = Some(HttpAuthConfig::Basic {
437                    username: config.username.clone(),
438                    password: config.password.inner().expose_secret().to_string(),
439                });
440            }
441            Some(TesBackendAuthConfig::Bearer(config)) => {
442                http.auth = Some(HttpAuthConfig::Bearer {
443                    token: config.token.inner().expose_secret().to_string(),
444                });
445            }
446            None => {}
447        }
448
449        http.retries = backend_config.retries;
450        http.max_concurrency = backend_config.max_concurrency.map(|c| c as usize);
451
452        let names = Arc::new(Mutex::new(GeneratorIterator::new(
453            UniqueAlphanumeric::default_with_expected_generations(INITIAL_EXPECTED_NAMES),
454            INITIAL_EXPECTED_NAMES,
455        )));
456
457        let backend = tes::Backend::initialize(
458            backend::tes::Config::builder()
459                .url(backend_config.url.clone().expect("should have URL"))
460                .http(http)
461                .interval(backend_config.interval.unwrap_or(DEFAULT_TES_INTERVAL))
462                .build(),
463            names.clone(),
464            events,
465        )
466        .await;
467
468        Ok(Self {
469            config,
470            backend_config: Arc::new(backend_config.clone()),
471            inner: Arc::new(backend),
472            max_cpu,
473            max_memory,
474            manager,
475            names,
476        })
477    }
478}
479
480impl TaskExecutionBackend for TesBackend {
481    fn max_concurrency(&self) -> u64 {
482        // The TES backend doesn't limit the number of tasks that can be queued at a
483        // time
484        u64::MAX
485    }
486
487    fn constraints(
488        &self,
489        requirements: &HashMap<String, Value>,
490        hints: &HashMap<String, Value>,
491    ) -> Result<TaskExecutionConstraints> {
492        let container = container(requirements, self.config.task.container.as_deref());
493
494        let cpu = cpu(requirements);
495        if (self.max_cpu as f64) < cpu {
496            bail!(
497                "task requires at least {cpu} CPU{s}, but the execution backend has a maximum of \
498                 {max_cpu}",
499                s = if cpu == 1.0 { "" } else { "s" },
500                max_cpu = self.max_cpu,
501            );
502        }
503
504        let memory = memory(requirements)?;
505        if self.max_memory < memory as u64 {
506            // Display the error in GiB, as it is the most common unit for memory
507            let memory = memory as f64 / ONE_GIBIBYTE;
508            let max_memory = self.max_memory as f64 / ONE_GIBIBYTE;
509
510            bail!(
511                "task requires at least {memory} GiB of memory, but the execution backend has a \
512                 maximum of {max_memory} GiB",
513            );
514        }
515
516        // TODO: only parse the disks requirement once
517        let disks = disks(requirements, hints)?
518            .into_iter()
519            .map(|(mp, disk)| (mp.to_string(), disk.size))
520            .collect();
521
522        Ok(TaskExecutionConstraints {
523            container: Some(container.into_owned()),
524            cpu,
525            memory,
526            gpu: Default::default(),
527            fpga: Default::default(),
528            disks,
529        })
530    }
531
532    fn guest_inputs_dir(&self) -> Option<&'static str> {
533        Some(GUEST_INPUTS_DIR)
534    }
535
536    fn needs_local_inputs(&self) -> bool {
537        false
538    }
539
540    fn spawn(
541        &self,
542        request: TaskSpawnRequest,
543        token: CancellationToken,
544    ) -> Result<Receiver<Result<TaskExecutionResult>>> {
545        let (completed_tx, completed_rx) = oneshot::channel();
546
547        let requirements = request.requirements();
548        let hints = request.hints();
549
550        let container = container(requirements, self.config.task.container.as_deref()).into_owned();
551        let cpu = cpu(requirements);
552        let memory = memory(requirements)? as u64;
553        let max_cpu = max_cpu(hints);
554        let max_memory = max_memory(hints)?.map(|i| i as u64);
555        let preemptible = preemptible(hints);
556
557        let name = format!(
558            "{id}-{generated}",
559            id = request.id(),
560            generated = self
561                .names
562                .lock()
563                .expect("generator should always acquire")
564                .next()
565                .expect("generator should never be exhausted")
566        );
567        self.manager.send(
568            TesTaskRequest {
569                config: self.config.clone(),
570                backend_config: self.backend_config.clone(),
571                inner: request,
572                backend: self.inner.clone(),
573                name,
574                container,
575                cpu,
576                memory,
577                max_cpu,
578                max_memory,
579                token,
580                preemptible,
581            },
582            completed_tx,
583        );
584
585        Ok(completed_rx)
586    }
587}