wdl_engine/backend/
local.rs

1//! Implementation of the local backend.
2
3use std::collections::HashMap;
4use std::ffi::OsStr;
5use std::fs;
6use std::fs::File;
7use std::path::Path;
8use std::process::Stdio;
9use std::sync::Arc;
10
11use anyhow::Context;
12use anyhow::Result;
13use anyhow::bail;
14use futures::FutureExt;
15use futures::future::BoxFuture;
16use tokio::process::Command;
17use tokio::select;
18use tokio::sync::oneshot;
19use tokio::task::JoinSet;
20use tokio_util::sync::CancellationToken;
21use tracing::info;
22
23use super::TaskExecutionBackend;
24use super::TaskExecutionConstraints;
25use super::TaskExecutionEvents;
26use super::TaskManager;
27use super::TaskManagerRequest;
28use super::TaskSpawnRequest;
29use crate::COMMAND_FILE_NAME;
30use crate::Input;
31use crate::ONE_GIBIBYTE;
32use crate::PrimitiveValue;
33use crate::STDERR_FILE_NAME;
34use crate::STDOUT_FILE_NAME;
35use crate::SYSTEM;
36use crate::TaskExecutionResult;
37use crate::Value;
38use crate::WORK_DIR_NAME;
39use crate::config::Config;
40use crate::config::DEFAULT_TASK_SHELL;
41use crate::config::LocalBackendConfig;
42use crate::convert_unit_string;
43use crate::http::Downloader;
44use crate::http::HttpDownloader;
45use crate::http::Location;
46use crate::path::EvaluationPath;
47use crate::v1::cpu;
48use crate::v1::memory;
49
50/// Represents a local task request.
51///
52/// This request contains the requested cpu and memory reservations for the task
53/// as well as the result receiver channel.
54#[derive(Debug)]
55struct LocalTaskRequest {
56    /// The engine configuration.
57    config: Arc<Config>,
58    /// The inner task spawn request.
59    inner: TaskSpawnRequest,
60    /// The requested CPU reservation for the task.
61    ///
62    /// Note that CPU isn't actually reserved for the task process.
63    cpu: f64,
64    /// The requested memory reservation for the task.
65    ///
66    /// Note that memory isn't actually reserved for the task process.
67    memory: u64,
68    /// The cancellation token for the request.
69    token: CancellationToken,
70}
71
72impl TaskManagerRequest for LocalTaskRequest {
73    fn cpu(&self) -> f64 {
74        self.cpu
75    }
76
77    fn memory(&self) -> u64 {
78        self.memory
79    }
80
81    async fn run(self, spawned: oneshot::Sender<()>) -> Result<TaskExecutionResult> {
82        // Create the working directory
83        let work_dir = self.inner.attempt_dir().join(WORK_DIR_NAME);
84        fs::create_dir_all(&work_dir).with_context(|| {
85            format!(
86                "failed to create directory `{path}`",
87                path = work_dir.display()
88            )
89        })?;
90
91        // Write the evaluated command to disk
92        let command_path = self.inner.attempt_dir().join(COMMAND_FILE_NAME);
93        fs::write(&command_path, self.inner.command()).with_context(|| {
94            format!(
95                "failed to write command contents to `{path}`",
96                path = command_path.display()
97            )
98        })?;
99
100        // Create a file for the stdout
101        let stdout_path = self.inner.attempt_dir().join(STDOUT_FILE_NAME);
102        let stdout = File::create(&stdout_path).with_context(|| {
103            format!(
104                "failed to create stdout file `{path}`",
105                path = stdout_path.display()
106            )
107        })?;
108
109        // Create a file for the stderr
110        let stderr_path = self.inner.attempt_dir().join(STDERR_FILE_NAME);
111        let stderr = File::create(&stderr_path).with_context(|| {
112            format!(
113                "failed to create stderr file `{path}`",
114                path = stderr_path.display()
115            )
116        })?;
117
118        let mut command = Command::new(
119            self.config
120                .task
121                .shell
122                .as_deref()
123                .unwrap_or(DEFAULT_TASK_SHELL),
124        );
125        command
126            .current_dir(&work_dir)
127            .arg("-C")
128            .arg(command_path)
129            .stdin(Stdio::null())
130            .stdout(stdout)
131            .stderr(stderr)
132            .envs(
133                self.inner
134                    .env()
135                    .iter()
136                    .map(|(k, v)| (OsStr::new(k), OsStr::new(v))),
137            )
138            .kill_on_drop(true);
139
140        // Set the PATH variable for the child on Windows to get consistent PATH
141        // searching. See: https://github.com/rust-lang/rust/issues/122660
142        #[cfg(windows)]
143        if let Ok(path) = std::env::var("PATH") {
144            command.env("PATH", path);
145        }
146
147        let mut child = command.spawn().context("failed to spawn `bash`")?;
148
149        // Notify that the process has spawned
150        spawned.send(()).ok();
151
152        let id = child.id().expect("should have id");
153        info!("spawned local `bash` process {id} for task execution");
154
155        select! {
156            // Poll the cancellation token before the child future
157            biased;
158
159            _ = self.token.cancelled() => {
160                bail!("task was cancelled");
161            }
162            status = child.wait() => {
163                let status = status.with_context(|| {
164                    format!("failed to wait for termination of task child process {id}")
165                })?;
166
167                #[cfg(unix)]
168                {
169                    use std::os::unix::process::ExitStatusExt;
170                    if let Some(signal) = status.signal() {
171                        tracing::warn!("task process {id} has terminated with signal {signal}");
172
173                        bail!(
174                            "task child process {id} has terminated with signal {signal}; see stderr file \
175                            `{path}` for more details",
176                            path = stderr_path.display()
177                        );
178                    }
179                }
180
181                let exit_code = status.code().expect("process should have exited");
182                info!("task process {id} has terminated with status code {exit_code}");
183                Ok(TaskExecutionResult {
184                    inputs: self.inner.info.inputs,
185                    exit_code,
186                    work_dir: EvaluationPath::Local(work_dir),
187                    stdout: PrimitiveValue::new_file(stdout_path.into_os_string().into_string().expect("path should be UTF-8")).into(),
188                    stderr: PrimitiveValue::new_file(stderr_path.into_os_string().into_string().expect("path should be UTF-8")).into(),
189                })
190            }
191        }
192    }
193}
194
195/// Represents a task execution backend that locally executes tasks.
196///
197/// <div class="warning">
198/// Warning: the local task execution backend spawns processes on the host
199/// directly without the use of a container; only use this backend on trusted
200/// WDL. </div>
201pub struct LocalBackend {
202    /// The engine configuration.
203    config: Arc<Config>,
204    /// The total CPU of the host.
205    cpu: u64,
206    /// The total memory of the host.
207    memory: u64,
208    /// The underlying task manager.
209    manager: TaskManager<LocalTaskRequest>,
210}
211
212impl LocalBackend {
213    /// Constructs a new local task execution backend with the given
214    /// configuration.
215    ///
216    /// The provided configuration is expected to have already been validated.
217    pub fn new(config: Arc<Config>, backend_config: &LocalBackendConfig) -> Result<Self> {
218        info!("initializing local backend");
219
220        let cpu = backend_config
221            .cpu
222            .unwrap_or_else(|| SYSTEM.cpus().len() as u64);
223        let memory = backend_config
224            .memory
225            .as_ref()
226            .map(|s| convert_unit_string(s).expect("value should be valid"))
227            .unwrap_or_else(|| SYSTEM.total_memory());
228        let manager = TaskManager::new(cpu, cpu, memory, memory);
229
230        Ok(Self {
231            config,
232            cpu,
233            memory,
234            manager,
235        })
236    }
237}
238
239impl TaskExecutionBackend for LocalBackend {
240    fn max_concurrency(&self) -> u64 {
241        self.cpu
242    }
243
244    fn constraints(
245        &self,
246        requirements: &HashMap<String, Value>,
247        _: &HashMap<String, Value>,
248    ) -> Result<TaskExecutionConstraints> {
249        let cpu = cpu(requirements);
250        if (self.cpu as f64) < cpu {
251            bail!(
252                "task requires at least {cpu} CPU{s}, but the host only has {total_cpu} available",
253                s = if cpu == 1.0 { "" } else { "s" },
254                total_cpu = self.cpu,
255            );
256        }
257
258        let memory = memory(requirements)?;
259        if self.memory < memory as u64 {
260            // Display the error in GiB, as it is the most common unit for memory
261            let memory = memory as f64 / ONE_GIBIBYTE;
262            let total_memory = self.memory as f64 / ONE_GIBIBYTE;
263
264            bail!(
265                "task requires at least {memory} GiB of memory, but the host only has \
266                 {total_memory} GiB available",
267            );
268        }
269
270        Ok(TaskExecutionConstraints {
271            container: None,
272            cpu,
273            memory,
274            gpu: Default::default(),
275            fpga: Default::default(),
276            disks: Default::default(),
277        })
278    }
279
280    fn guest_work_dir(&self) -> Option<&Path> {
281        // Local execution does not use a container
282        None
283    }
284
285    fn localize_inputs<'a, 'b, 'c, 'd>(
286        &'a self,
287        downloader: &'b HttpDownloader,
288        inputs: &'c mut [Input],
289    ) -> BoxFuture<'d, Result<()>>
290    where
291        'a: 'd,
292        'b: 'd,
293        'c: 'd,
294        Self: 'd,
295    {
296        async move {
297            let mut downloads = JoinSet::new();
298
299            for (idx, input) in inputs.iter_mut().enumerate() {
300                match input.path() {
301                    EvaluationPath::Local(path) => {
302                        let location = Location::Path(path.clone().into());
303                        let guest_path = location
304                            .to_str()
305                            .with_context(|| {
306                                format!("path `{path}` is not UTF-8", path = path.display())
307                            })?
308                            .to_string();
309                        input.set_location(location.into_owned());
310                        input.set_guest_path(guest_path);
311                    }
312                    EvaluationPath::Remote(url) => {
313                        let downloader = downloader.clone();
314                        let url = url.clone();
315                        downloads.spawn(async move {
316                            let location_result = downloader.download(&url).await;
317
318                            match location_result {
319                                Ok(location) => Ok((idx, location.into_owned())),
320                                Err(e) => bail!("failed to localize `{url}`: {e:?}"),
321                            }
322                        });
323                    }
324                }
325            }
326
327            while let Some(result) = downloads.join_next().await {
328                match result {
329                    Ok(Ok((idx, location))) => {
330                        let guest_path = location
331                            .to_str()
332                            .with_context(|| {
333                                format!(
334                                    "downloaded path `{path}` is not UTF-8",
335                                    path = location.display()
336                                )
337                            })?
338                            .to_string();
339
340                        let input = inputs.get_mut(idx).expect("index should be valid");
341                        input.set_location(location);
342                        input.set_guest_path(guest_path);
343                    }
344                    Ok(Err(e)) => {
345                        // Futures are aborted when the `JoinSet` is dropped.
346                        bail!(e);
347                    }
348                    Err(e) => {
349                        // Futures are aborted when the `JoinSet` is dropped.
350                        bail!("download task failed: {e}");
351                    }
352                }
353            }
354
355            Ok(())
356        }
357        .boxed()
358    }
359
360    fn spawn(
361        &self,
362        request: TaskSpawnRequest,
363        token: CancellationToken,
364    ) -> Result<TaskExecutionEvents> {
365        let (spawned_tx, spawned_rx) = oneshot::channel();
366        let (completed_tx, completed_rx) = oneshot::channel();
367
368        let requirements = request.requirements();
369        let cpu = cpu(requirements);
370        let memory = memory(requirements)? as u64;
371
372        self.manager.send(
373            LocalTaskRequest {
374                config: self.config.clone(),
375                inner: request,
376                cpu,
377                memory,
378                token,
379            },
380            spawned_tx,
381            completed_tx,
382        );
383
384        Ok(TaskExecutionEvents {
385            spawned: spawned_rx,
386            completed: completed_rx,
387        })
388    }
389}