wdl_engine/backend/local.rs

1//! Implementation of the local backend.
2
3use std::collections::HashMap;
4use std::ffi::OsStr;
5use std::fs;
6use std::fs::File;
7use std::path::Path;
8use std::process::Stdio;
9use std::sync::Arc;
10
11use anyhow::Context;
12use anyhow::Result;
13use anyhow::bail;
14use futures::FutureExt;
15use futures::future::BoxFuture;
16use tokio::process::Command;
17use tokio::select;
18use tokio::sync::oneshot;
19use tokio::task::JoinSet;
20use tokio_util::sync::CancellationToken;
21use tracing::info;
22use tracing::warn;
23
24use super::TaskExecutionBackend;
25use super::TaskExecutionConstraints;
26use super::TaskExecutionEvents;
27use super::TaskManager;
28use super::TaskManagerRequest;
29use super::TaskSpawnRequest;
30use crate::COMMAND_FILE_NAME;
31use crate::Input;
32use crate::ONE_GIBIBYTE;
33use crate::PrimitiveValue;
34use crate::STDERR_FILE_NAME;
35use crate::STDOUT_FILE_NAME;
36use crate::SYSTEM;
37use crate::TaskExecutionResult;
38use crate::Value;
39use crate::WORK_DIR_NAME;
40use crate::config::Config;
41use crate::config::DEFAULT_TASK_SHELL;
42use crate::config::LocalBackendConfig;
43use crate::config::TaskResourceLimitBehavior;
44use crate::convert_unit_string;
45use crate::http::Downloader;
46use crate::http::HttpDownloader;
47use crate::http::Location;
48use crate::path::EvaluationPath;
49use crate::v1::cpu;
50use crate::v1::memory;
51
52/// Represents a local task request.
53///
54/// This request contains the requested cpu and memory reservations for the task
55/// as well as the result receiver channel.
56#[derive(Debug)]
57struct LocalTaskRequest {
58    /// The engine configuration.
59    config: Arc<Config>,
60    /// The inner task spawn request.
61    inner: TaskSpawnRequest,
62    /// The requested CPU reservation for the task.
63    ///
64    /// Note that CPU isn't actually reserved for the task process.
65    cpu: f64,
66    /// The requested memory reservation for the task.
67    ///
68    /// Note that memory isn't actually reserved for the task process.
69    memory: u64,
70    /// The cancellation token for the request.
71    token: CancellationToken,
72}
73
74impl TaskManagerRequest for LocalTaskRequest {
75    fn cpu(&self) -> f64 {
76        self.cpu
77    }
78
79    fn memory(&self) -> u64 {
80        self.memory
81    }
82
83    async fn run(self, spawned: oneshot::Sender<()>) -> Result<TaskExecutionResult> {
84        // Create the working directory
85        let work_dir = self.inner.attempt_dir().join(WORK_DIR_NAME);
86        fs::create_dir_all(&work_dir).with_context(|| {
87            format!(
88                "failed to create directory `{path}`",
89                path = work_dir.display()
90            )
91        })?;
92
93        // Write the evaluated command to disk
94        let command_path = self.inner.attempt_dir().join(COMMAND_FILE_NAME);
95        fs::write(&command_path, self.inner.command()).with_context(|| {
96            format!(
97                "failed to write command contents to `{path}`",
98                path = command_path.display()
99            )
100        })?;
101
102        // Create a file for the stdout
103        let stdout_path = self.inner.attempt_dir().join(STDOUT_FILE_NAME);
104        let stdout = File::create(&stdout_path).with_context(|| {
105            format!(
106                "failed to create stdout file `{path}`",
107                path = stdout_path.display()
108            )
109        })?;
110
111        // Create a file for the stderr
112        let stderr_path = self.inner.attempt_dir().join(STDERR_FILE_NAME);
113        let stderr = File::create(&stderr_path).with_context(|| {
114            format!(
115                "failed to create stderr file `{path}`",
116                path = stderr_path.display()
117            )
118        })?;
119
120        let mut command = Command::new(
121            self.config
122                .task
123                .shell
124                .as_deref()
125                .unwrap_or(DEFAULT_TASK_SHELL),
126        );
127        command
128            .current_dir(&work_dir)
129            .arg("-C")
130            .arg(command_path)
131            .stdin(Stdio::null())
132            .stdout(stdout)
133            .stderr(stderr)
134            .envs(
135                self.inner
136                    .env()
137                    .iter()
138                    .map(|(k, v)| (OsStr::new(k), OsStr::new(v))),
139            )
140            .kill_on_drop(true);
141
142        // Set the PATH variable for the child on Windows to get consistent PATH
143        // searching. See: https://github.com/rust-lang/rust/issues/122660
144        #[cfg(windows)]
145        if let Ok(path) = std::env::var("PATH") {
146            command.env("PATH", path);
147        }
148
149        let mut child = command.spawn().context("failed to spawn `bash`")?;
150
151        // Notify that the process has spawned
152        spawned.send(()).ok();
153
154        let id = child.id().expect("should have id");
155        info!("spawned local `bash` process {id} for task execution");
156
157        select! {
158            // Poll the cancellation token before the child future
159            biased;
160
161            _ = self.token.cancelled() => {
162                bail!("task was cancelled");
163            }
164            status = child.wait() => {
165                let status = status.with_context(|| {
166                    format!("failed to wait for termination of task child process {id}")
167                })?;
168
169                #[cfg(unix)]
170                {
171                    use std::os::unix::process::ExitStatusExt;
172                    if let Some(signal) = status.signal() {
173                        tracing::warn!("task process {id} has terminated with signal {signal}");
174
175                        bail!(
176                            "task child process {id} has terminated with signal {signal}; see stderr file \
177                            `{path}` for more details",
178                            path = stderr_path.display()
179                        );
180                    }
181                }
182
183                let exit_code = status.code().expect("process should have exited");
184                info!("task process {id} has terminated with status code {exit_code}");
185                Ok(TaskExecutionResult {
186                    inputs: self.inner.info.inputs,
187                    exit_code,
188                    work_dir: EvaluationPath::Local(work_dir),
189                    stdout: PrimitiveValue::new_file(stdout_path.into_os_string().into_string().expect("path should be UTF-8")).into(),
190                    stderr: PrimitiveValue::new_file(stderr_path.into_os_string().into_string().expect("path should be UTF-8")).into(),
191                })
192            }
193        }
194    }
195}
196
197/// Represents a task execution backend that locally executes tasks.
198///
199/// <div class="warning">
200/// Warning: the local task execution backend spawns processes on the host
201/// directly without the use of a container; only use this backend on trusted
202/// WDL. </div>
203pub struct LocalBackend {
204    /// The engine configuration.
205    config: Arc<Config>,
206    /// The total CPU of the host.
207    cpu: u64,
208    /// The total memory of the host.
209    memory: u64,
210    /// The underlying task manager.
211    manager: TaskManager<LocalTaskRequest>,
212}
213
214impl LocalBackend {
215    /// Constructs a new local task execution backend with the given
216    /// configuration.
217    ///
218    /// The provided configuration is expected to have already been validated.
219    pub fn new(config: Arc<Config>, backend_config: &LocalBackendConfig) -> Result<Self> {
220        info!("initializing local backend");
221
222        let cpu = backend_config
223            .cpu
224            .unwrap_or_else(|| SYSTEM.cpus().len() as u64);
225        let memory = backend_config
226            .memory
227            .as_ref()
228            .map(|s| convert_unit_string(s).expect("value should be valid"))
229            .unwrap_or_else(|| SYSTEM.total_memory());
230        let manager = TaskManager::new(cpu, cpu, memory, memory);
231
232        Ok(Self {
233            config,
234            cpu,
235            memory,
236            manager,
237        })
238    }
239}
240
241impl TaskExecutionBackend for LocalBackend {
242    fn max_concurrency(&self) -> u64 {
243        self.cpu
244    }
245
246    fn constraints(
247        &self,
248        requirements: &HashMap<String, Value>,
249        _: &HashMap<String, Value>,
250    ) -> Result<TaskExecutionConstraints> {
251        let mut cpu = cpu(requirements);
252        if (self.cpu as f64) < cpu {
253            match self.config.task.cpu_limit_behavior {
254                TaskResourceLimitBehavior::TryWithMax => {
255                    warn!(
256                        "task requires at least {cpu} CPU{s}, but the host only has {total_cpu} \
257                         available",
258                        s = if cpu == 1.0 { "" } else { "s" },
259                        total_cpu = self.cpu,
260                    );
261                    // clamp the reported constraint to what's available
262                    cpu = self.cpu as f64;
263                }
264                TaskResourceLimitBehavior::Deny => {
265                    bail!(
266                        "task requires at least {cpu} CPU{s}, but the host only has {total_cpu} \
267                         available",
268                        s = if cpu == 1.0 { "" } else { "s" },
269                        total_cpu = self.cpu,
270                    );
271                }
272            }
273        }
274
275        let mut memory = memory(requirements)?;
276        if self.memory < memory as u64 {
277            match self.config.task.memory_limit_behavior {
278                TaskResourceLimitBehavior::TryWithMax => {
279                    warn!(
280                        "task requires at least {memory} GiB of memory, but the host only has \
281                         {total_memory} GiB available",
282                        // Display the error in GiB, as it is the most common unit for memory
283                        memory = memory as f64 / ONE_GIBIBYTE,
284                        total_memory = self.memory as f64 / ONE_GIBIBYTE,
285                    );
286                    // clamp the reported constraint to what's available
287                    memory = self.memory.try_into().unwrap_or(i64::MAX);
288                }
289                TaskResourceLimitBehavior::Deny => {
290                    bail!(
291                        "task requires at least {memory} GiB of memory, but the host only has \
292                         {total_memory} GiB available",
293                        // Display the error in GiB, as it is the most common unit for memory
294                        memory = memory as f64 / ONE_GIBIBYTE,
295                        total_memory = self.memory as f64 / ONE_GIBIBYTE,
296                    );
297                }
298            }
299        }
300
301        Ok(TaskExecutionConstraints {
302            container: None,
303            cpu,
304            memory,
305            gpu: Default::default(),
306            fpga: Default::default(),
307            disks: Default::default(),
308        })
309    }
310
311    fn guest_work_dir(&self) -> Option<&Path> {
312        // Local execution does not use a container
313        None
314    }
315
316    fn localize_inputs<'a, 'b, 'c, 'd>(
317        &'a self,
318        downloader: &'b HttpDownloader,
319        inputs: &'c mut [Input],
320    ) -> BoxFuture<'d, Result<()>>
321    where
322        'a: 'd,
323        'b: 'd,
324        'c: 'd,
325        Self: 'd,
326    {
327        async move {
328            let mut downloads = JoinSet::new();
329
330            for (idx, input) in inputs.iter_mut().enumerate() {
331                match input.path() {
332                    EvaluationPath::Local(path) => {
333                        let location = Location::Path(path.clone().into());
334                        let guest_path = location
335                            .to_str()
336                            .with_context(|| {
337                                format!("path `{path}` is not UTF-8", path = path.display())
338                            })?
339                            .to_string();
340                        input.set_location(location.into_owned());
341                        input.set_guest_path(guest_path);
342                    }
343                    EvaluationPath::Remote(url) => {
344                        let downloader = downloader.clone();
345                        let url = url.clone();
346                        downloads.spawn(async move {
347                            let location_result = downloader.download(&url).await;
348
349                            match location_result {
350                                Ok(location) => Ok((idx, location.into_owned())),
351                                Err(e) => bail!("failed to localize `{url}`: {e:?}"),
352                            }
353                        });
354                    }
355                }
356            }
357
358            while let Some(result) = downloads.join_next().await {
359                match result {
360                    Ok(Ok((idx, location))) => {
361                        let guest_path = location
362                            .to_str()
363                            .with_context(|| {
364                                format!(
365                                    "downloaded path `{path}` is not UTF-8",
366                                    path = location.display()
367                                )
368                            })?
369                            .to_string();
370
371                        let input = inputs.get_mut(idx).expect("index should be valid");
372                        input.set_location(location);
373                        input.set_guest_path(guest_path);
374                    }
375                    Ok(Err(e)) => {
376                        // Futures are aborted when the `JoinSet` is dropped.
377                        bail!(e);
378                    }
379                    Err(e) => {
380                        // Futures are aborted when the `JoinSet` is dropped.
381                        bail!("download task failed: {e}");
382                    }
383                }
384            }
385
386            Ok(())
387        }
388        .boxed()
389    }
390
391    fn spawn(
392        &self,
393        request: TaskSpawnRequest,
394        token: CancellationToken,
395    ) -> Result<TaskExecutionEvents> {
396        let (spawned_tx, spawned_rx) = oneshot::channel();
397        let (completed_tx, completed_rx) = oneshot::channel();
398
399        let requirements = request.requirements();
400        let mut cpu = cpu(requirements);
401        if let TaskResourceLimitBehavior::TryWithMax = self.config.task.cpu_limit_behavior {
402            cpu = std::cmp::min(cpu.ceil() as u64, self.cpu) as f64;
403        }
404        let mut memory = memory(requirements)? as u64;
405        if let TaskResourceLimitBehavior::TryWithMax = self.config.task.memory_limit_behavior {
406            memory = std::cmp::min(memory, self.memory);
407        }
408
409        self.manager.send(
410            LocalTaskRequest {
411                config: self.config.clone(),
412                inner: request,
413                cpu,
414                memory,
415                token,
416            },
417            spawned_tx,
418            completed_tx,
419        );
420
421        Ok(TaskExecutionEvents {
422            spawned: spawned_rx,
423            completed: completed_rx,
424        })
425    }
426}