1use std::collections::HashMap;
4use std::ffi::OsStr;
5use std::fs;
6use std::fs::File;
7use std::path::Path;
8use std::process::Stdio;
9use std::sync::Arc;
10
11use anyhow::Context;
12use anyhow::Result;
13use anyhow::bail;
14use futures::FutureExt;
15use futures::future::BoxFuture;
16use tokio::process::Command;
17use tokio::select;
18use tokio::sync::oneshot;
19use tokio::task::JoinSet;
20use tokio_util::sync::CancellationToken;
21use tracing::info;
22use tracing::warn;
23
24use super::TaskExecutionBackend;
25use super::TaskExecutionConstraints;
26use super::TaskExecutionEvents;
27use super::TaskManager;
28use super::TaskManagerRequest;
29use super::TaskSpawnRequest;
30use crate::COMMAND_FILE_NAME;
31use crate::Input;
32use crate::ONE_GIBIBYTE;
33use crate::PrimitiveValue;
34use crate::STDERR_FILE_NAME;
35use crate::STDOUT_FILE_NAME;
36use crate::SYSTEM;
37use crate::TaskExecutionResult;
38use crate::Value;
39use crate::WORK_DIR_NAME;
40use crate::config::Config;
41use crate::config::DEFAULT_TASK_SHELL;
42use crate::config::LocalBackendConfig;
43use crate::config::TaskResourceLimitBehavior;
44use crate::convert_unit_string;
45use crate::http::Downloader;
46use crate::http::HttpDownloader;
47use crate::http::Location;
48use crate::path::EvaluationPath;
49use crate::v1::cpu;
50use crate::v1::memory;
51
52#[derive(Debug)]
57struct LocalTaskRequest {
58 config: Arc<Config>,
60 inner: TaskSpawnRequest,
62 cpu: f64,
66 memory: u64,
70 token: CancellationToken,
72}
73
74impl TaskManagerRequest for LocalTaskRequest {
75 fn cpu(&self) -> f64 {
76 self.cpu
77 }
78
79 fn memory(&self) -> u64 {
80 self.memory
81 }
82
83 async fn run(self, spawned: oneshot::Sender<()>) -> Result<TaskExecutionResult> {
84 let work_dir = self.inner.attempt_dir().join(WORK_DIR_NAME);
86 fs::create_dir_all(&work_dir).with_context(|| {
87 format!(
88 "failed to create directory `{path}`",
89 path = work_dir.display()
90 )
91 })?;
92
93 let command_path = self.inner.attempt_dir().join(COMMAND_FILE_NAME);
95 fs::write(&command_path, self.inner.command()).with_context(|| {
96 format!(
97 "failed to write command contents to `{path}`",
98 path = command_path.display()
99 )
100 })?;
101
102 let stdout_path = self.inner.attempt_dir().join(STDOUT_FILE_NAME);
104 let stdout = File::create(&stdout_path).with_context(|| {
105 format!(
106 "failed to create stdout file `{path}`",
107 path = stdout_path.display()
108 )
109 })?;
110
111 let stderr_path = self.inner.attempt_dir().join(STDERR_FILE_NAME);
113 let stderr = File::create(&stderr_path).with_context(|| {
114 format!(
115 "failed to create stderr file `{path}`",
116 path = stderr_path.display()
117 )
118 })?;
119
120 let mut command = Command::new(
121 self.config
122 .task
123 .shell
124 .as_deref()
125 .unwrap_or(DEFAULT_TASK_SHELL),
126 );
127 command
128 .current_dir(&work_dir)
129 .arg(command_path)
130 .stdin(Stdio::null())
131 .stdout(stdout)
132 .stderr(stderr)
133 .envs(
134 self.inner
135 .env()
136 .iter()
137 .map(|(k, v)| (OsStr::new(k), OsStr::new(v))),
138 )
139 .kill_on_drop(true);
140
141 #[cfg(windows)]
144 if let Ok(path) = std::env::var("PATH") {
145 command.env("PATH", path);
146 }
147
148 let mut child = command.spawn().context("failed to spawn shell")?;
149
150 spawned.send(()).ok();
152
153 let id = child.id().expect("should have id");
154 info!("spawned local shell process {id} for task execution");
155
156 select! {
157 biased;
159
160 _ = self.token.cancelled() => {
161 bail!("task was cancelled");
162 }
163 status = child.wait() => {
164 let status = status.with_context(|| {
165 format!("failed to wait for termination of task child process {id}")
166 })?;
167
168 #[cfg(unix)]
169 {
170 use std::os::unix::process::ExitStatusExt;
171 if let Some(signal) = status.signal() {
172 tracing::warn!("task process {id} has terminated with signal {signal}");
173
174 bail!(
175 "task child process {id} has terminated with signal {signal}; see stderr file \
176 `{path}` for more details",
177 path = stderr_path.display()
178 );
179 }
180 }
181
182 let exit_code = status.code().expect("process should have exited");
183 info!("task process {id} has terminated with status code {exit_code}");
184 Ok(TaskExecutionResult {
185 inputs: self.inner.info.inputs,
186 exit_code,
187 work_dir: EvaluationPath::Local(work_dir),
188 stdout: PrimitiveValue::new_file(stdout_path.into_os_string().into_string().expect("path should be UTF-8")).into(),
189 stderr: PrimitiveValue::new_file(stderr_path.into_os_string().into_string().expect("path should be UTF-8")).into(),
190 })
191 }
192 }
193 }
194}
195
196pub struct LocalBackend {
203 config: Arc<Config>,
205 cpu: u64,
207 memory: u64,
209 manager: TaskManager<LocalTaskRequest>,
211}
212
213impl LocalBackend {
214 pub fn new(config: Arc<Config>, backend_config: &LocalBackendConfig) -> Result<Self> {
219 info!("initializing local backend");
220
221 let cpu = backend_config
222 .cpu
223 .unwrap_or_else(|| SYSTEM.cpus().len() as u64);
224 let memory = backend_config
225 .memory
226 .as_ref()
227 .map(|s| convert_unit_string(s).expect("value should be valid"))
228 .unwrap_or_else(|| SYSTEM.total_memory());
229 let manager = TaskManager::new(cpu, cpu, memory, memory);
230
231 Ok(Self {
232 config,
233 cpu,
234 memory,
235 manager,
236 })
237 }
238}
239
240impl TaskExecutionBackend for LocalBackend {
241 fn max_concurrency(&self) -> u64 {
242 self.cpu
243 }
244
245 fn constraints(
246 &self,
247 requirements: &HashMap<String, Value>,
248 _: &HashMap<String, Value>,
249 ) -> Result<TaskExecutionConstraints> {
250 let mut cpu = cpu(requirements);
251 if (self.cpu as f64) < cpu {
252 let env_specific = if self.config.suppress_env_specific_output {
253 String::new()
254 } else {
255 format!(
256 ", but the host only has {total_cpu} available",
257 total_cpu = self.cpu
258 )
259 };
260 match self.config.task.cpu_limit_behavior {
261 TaskResourceLimitBehavior::TryWithMax => {
262 warn!(
263 "task requires at least {cpu} CPU{s}{env_specific}",
264 s = if cpu == 1.0 { "" } else { "s" },
265 );
266 cpu = self.cpu as f64;
268 }
269 TaskResourceLimitBehavior::Deny => {
270 bail!(
271 "task requires at least {cpu} CPU{s}{env_specific}",
272 s = if cpu == 1.0 { "" } else { "s" },
273 );
274 }
275 }
276 }
277
278 let mut memory = memory(requirements)?;
279 if self.memory < memory as u64 {
280 let env_specific = if self.config.suppress_env_specific_output {
281 String::new()
282 } else {
283 format!(
284 ", but the host only has {total_memory} GiB available",
285 total_memory = self.memory as f64 / ONE_GIBIBYTE,
286 )
287 };
288 match self.config.task.memory_limit_behavior {
289 TaskResourceLimitBehavior::TryWithMax => {
290 warn!(
291 "task requires at least {memory} GiB of memory{env_specific}",
292 memory = memory as f64 / ONE_GIBIBYTE,
294 );
295 memory = self.memory.try_into().unwrap_or(i64::MAX);
297 }
298 TaskResourceLimitBehavior::Deny => {
299 bail!(
300 "task requires at least {memory} GiB of memory{env_specific}",
301 memory = memory as f64 / ONE_GIBIBYTE,
303 );
304 }
305 }
306 }
307
308 Ok(TaskExecutionConstraints {
309 container: None,
310 cpu,
311 memory,
312 gpu: Default::default(),
313 fpga: Default::default(),
314 disks: Default::default(),
315 })
316 }
317
318 fn guest_work_dir(&self) -> Option<&Path> {
319 None
321 }
322
323 fn localize_inputs<'a, 'b, 'c, 'd>(
324 &'a self,
325 downloader: &'b HttpDownloader,
326 inputs: &'c mut [Input],
327 ) -> BoxFuture<'d, Result<()>>
328 where
329 'a: 'd,
330 'b: 'd,
331 'c: 'd,
332 Self: 'd,
333 {
334 async move {
335 let mut downloads = JoinSet::new();
336
337 for (idx, input) in inputs.iter_mut().enumerate() {
338 match input.path() {
339 EvaluationPath::Local(path) => {
340 let location = Location::Path(path.clone().into());
341 let guest_path = location
342 .to_str()
343 .with_context(|| {
344 format!("path `{path}` is not UTF-8", path = path.display())
345 })?
346 .to_string();
347 input.set_location(location.into_owned());
348 input.set_guest_path(guest_path);
349 }
350 EvaluationPath::Remote(url) => {
351 let downloader = downloader.clone();
352 let url = url.clone();
353 downloads.spawn(async move {
354 let location_result = downloader.download(&url).await;
355
356 match location_result {
357 Ok(location) => Ok((idx, location.into_owned())),
358 Err(e) => bail!("failed to localize `{url}`: {e:?}"),
359 }
360 });
361 }
362 }
363 }
364
365 while let Some(result) = downloads.join_next().await {
366 match result {
367 Ok(Ok((idx, location))) => {
368 let guest_path = location
369 .to_str()
370 .with_context(|| {
371 format!(
372 "downloaded path `{path}` is not UTF-8",
373 path = location.display()
374 )
375 })?
376 .to_string();
377
378 let input = inputs.get_mut(idx).expect("index should be valid");
379 input.set_location(location);
380 input.set_guest_path(guest_path);
381 }
382 Ok(Err(e)) => {
383 bail!(e);
385 }
386 Err(e) => {
387 bail!("download task failed: {e}");
389 }
390 }
391 }
392
393 Ok(())
394 }
395 .boxed()
396 }
397
398 fn spawn(
399 &self,
400 request: TaskSpawnRequest,
401 token: CancellationToken,
402 ) -> Result<TaskExecutionEvents> {
403 let (spawned_tx, spawned_rx) = oneshot::channel();
404 let (completed_tx, completed_rx) = oneshot::channel();
405
406 let requirements = request.requirements();
407 let mut cpu = cpu(requirements);
408 if let TaskResourceLimitBehavior::TryWithMax = self.config.task.cpu_limit_behavior {
409 cpu = std::cmp::min(cpu.ceil() as u64, self.cpu) as f64;
410 }
411 let mut memory = memory(requirements)? as u64;
412 if let TaskResourceLimitBehavior::TryWithMax = self.config.task.memory_limit_behavior {
413 memory = std::cmp::min(memory, self.memory);
414 }
415
416 self.manager.send(
417 LocalTaskRequest {
418 config: self.config.clone(),
419 inner: request,
420 cpu,
421 memory,
422 token,
423 },
424 spawned_tx,
425 completed_tx,
426 );
427
428 Ok(TaskExecutionEvents {
429 spawned: spawned_rx,
430 completed: completed_rx,
431 })
432 }
433}