1use std::collections::HashMap;
4use std::collections::VecDeque;
5use std::fmt;
6use std::future::Future;
7use std::ops::Add;
8use std::ops::Range;
9use std::ops::Sub;
10use std::path::Path;
11use std::path::PathBuf;
12use std::sync::Arc;
13
14use anyhow::Result;
15use anyhow::anyhow;
16use futures::future::BoxFuture;
17use indexmap::IndexMap;
18use ordered_float::OrderedFloat;
19use tokio::sync::mpsc;
20use tokio::sync::oneshot;
21use tokio::sync::oneshot::Receiver;
22use tokio::task::JoinSet;
23use tokio_util::sync::CancellationToken;
24use tracing::debug;
25
26use crate::Input;
27use crate::Value;
28use crate::http::Transferer;
29use crate::path::EvaluationPath;
30
31mod apptainer;
32mod docker;
33mod local;
34mod lsf_apptainer;
35mod slurm_apptainer;
36mod tes;
37
38pub use apptainer::*;
39pub use docker::*;
40pub use local::*;
41pub use lsf_apptainer::*;
42pub use slurm_apptainer::*;
43pub use tes::*;
44
45pub(crate) const WORK_DIR_NAME: &str = "work";
47
48pub(crate) const COMMAND_FILE_NAME: &str = "command";
50
51pub(crate) const STDOUT_FILE_NAME: &str = "stdout";
53
54pub(crate) const STDERR_FILE_NAME: &str = "stderr";
56
57const INITIAL_EXPECTED_NAMES: usize = 1000;
62
63pub struct TaskExecutionConstraints {
65 pub container: Option<String>,
69 pub cpu: f64,
71 pub memory: i64,
73 pub gpu: Vec<String>,
79 pub fpga: Vec<String>,
85 pub disks: IndexMap<String, i64>,
97}
98
99pub struct TaskSpawnInfo {
101 command: String,
103 inputs: Vec<Input>,
105 requirements: Arc<HashMap<String, Value>>,
107 hints: Arc<HashMap<String, Value>>,
109 env: Arc<IndexMap<String, String>>,
111 transferer: Arc<dyn Transferer>,
113}
114
115impl TaskSpawnInfo {
116 pub fn new(
118 command: String,
119 inputs: Vec<Input>,
120 requirements: Arc<HashMap<String, Value>>,
121 hints: Arc<HashMap<String, Value>>,
122 env: Arc<IndexMap<String, String>>,
123 transferer: Arc<dyn Transferer>,
124 ) -> Self {
125 Self {
126 command,
127 inputs,
128 requirements,
129 hints,
130 env,
131 transferer,
132 }
133 }
134}
135
136impl fmt::Debug for TaskSpawnInfo {
137 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138 f.debug_struct("TaskSpawnInfo")
139 .field("command", &self.command)
140 .field("inputs", &self.inputs)
141 .field("requirements", &self.requirements)
142 .field("hints", &self.hints)
143 .field("env", &self.env)
144 .field("transferer", &"<transferer>")
145 .finish()
146 }
147}
148
149#[derive(Debug)]
151pub struct TaskSpawnRequest {
152 id: String,
154 info: TaskSpawnInfo,
156 attempt: u64,
158 attempt_dir: PathBuf,
160 task_eval_root: PathBuf,
162 temp_dir: PathBuf,
164}
165
166impl TaskSpawnRequest {
167 pub fn new(
169 id: String,
170 info: TaskSpawnInfo,
171 attempt: u64,
172 attempt_dir: PathBuf,
173 task_eval_root: PathBuf,
174 temp_dir: PathBuf,
175 ) -> Self {
176 Self {
177 id,
178 info,
179 attempt,
180 attempt_dir,
181 task_eval_root,
182 temp_dir,
183 }
184 }
185
186 pub fn id(&self) -> &str {
188 &self.id
189 }
190
191 pub fn command(&self) -> &str {
193 &self.info.command
194 }
195
196 pub fn inputs(&self) -> &[Input] {
198 &self.info.inputs
199 }
200
201 pub fn requirements(&self) -> &HashMap<String, Value> {
203 &self.info.requirements
204 }
205
206 pub fn hints(&self) -> &HashMap<String, Value> {
208 &self.info.hints
209 }
210
211 pub fn env(&self) -> &IndexMap<String, String> {
213 &self.info.env
214 }
215
216 pub fn transferer(&self) -> &Arc<dyn Transferer> {
218 &self.info.transferer
219 }
220
221 pub fn attempt(&self) -> u64 {
225 self.attempt
226 }
227
228 pub fn attempt_dir(&self) -> &Path {
230 &self.attempt_dir
231 }
232
233 pub fn task_eval_root_dir(&self) -> &Path {
235 &self.task_eval_root
236 }
237
238 pub fn temp_dir(&self) -> &Path {
240 &self.temp_dir
241 }
242
243 pub fn wdl_command_host_path(&self) -> PathBuf {
246 self.attempt_dir.join(COMMAND_FILE_NAME)
247 }
248
249 pub fn wdl_work_dir_host_path(&self) -> PathBuf {
251 self.attempt_dir.join(WORK_DIR_NAME)
252 }
253
254 pub fn wdl_stdout_host_path(&self) -> PathBuf {
257 self.attempt_dir.join(STDOUT_FILE_NAME)
258 }
259
260 pub fn wdl_stderr_host_path(&self) -> PathBuf {
263 self.attempt_dir.join(STDERR_FILE_NAME)
264 }
265}
266
267#[derive(Debug)]
269pub struct TaskExecutionResult {
270 pub exit_code: i32,
272 pub work_dir: EvaluationPath,
274 pub stdout: Value,
276 pub stderr: Value,
278}
279
280pub trait TaskExecutionBackend: Send + Sync {
282 fn max_concurrency(&self) -> u64;
284
285 fn constraints(
290 &self,
291 requirements: &HashMap<String, Value>,
292 hints: &HashMap<String, Value>,
293 ) -> Result<TaskExecutionConstraints>;
294
295 fn guest_inputs_dir(&self) -> Option<&'static str>;
301
302 fn needs_local_inputs(&self) -> bool;
307
308 fn spawn(
312 &self,
313 request: TaskSpawnRequest,
314 token: CancellationToken,
315 ) -> Result<Receiver<Result<TaskExecutionResult>>>;
316
317 fn cleanup<'a>(
321 &'a self,
322 work_dir: &'a EvaluationPath,
323 token: CancellationToken,
324 ) -> Option<BoxFuture<'a, ()>> {
325 let _ = work_dir;
326 let _ = token;
327 None
328 }
329}
330
331trait TaskManagerRequest: Send + Sync + 'static {
333 fn cpu(&self) -> f64;
335
336 fn memory(&self) -> u64;
338
339 fn run(self) -> impl Future<Output = Result<TaskExecutionResult>> + Send;
341}
342
343struct TaskManagerResponse {
345 cpu: f64,
347 memory: u64,
349 result: Result<TaskExecutionResult>,
351 tx: oneshot::Sender<Result<TaskExecutionResult>>,
353}
354
355struct TaskManagerState<Req> {
357 cpu: OrderedFloat<f64>,
359 memory: u64,
361 spawned: JoinSet<TaskManagerResponse>,
363 parked: VecDeque<(Req, oneshot::Sender<Result<TaskExecutionResult>>)>,
365}
366
367impl<Req> TaskManagerState<Req> {
368 fn new(cpu: u64, memory: u64) -> Self {
370 Self {
371 cpu: OrderedFloat(cpu as f64),
372 memory,
373 spawned: Default::default(),
374 parked: Default::default(),
375 }
376 }
377
378 fn unlimited(&self) -> bool {
380 self.cpu == u64::MAX as f64 && self.memory == u64::MAX
381 }
382}
383
384#[derive(Debug)]
386struct TaskManager<Req> {
387 tx: mpsc::UnboundedSender<(Req, oneshot::Sender<Result<TaskExecutionResult>>)>,
389}
390
391impl<Req> TaskManager<Req>
392where
393 Req: TaskManagerRequest,
394{
395 fn new(cpu: u64, max_cpu: u64, memory: u64, max_memory: u64) -> Self {
398 let (tx, rx) = mpsc::unbounded_channel();
399
400 tokio::spawn(async move {
401 Self::run_request_queue(rx, cpu, max_cpu, memory, max_memory).await;
402 });
403
404 Self { tx }
405 }
406
407 fn new_unlimited(max_cpu: u64, max_memory: u64) -> Self {
410 Self::new(u64::MAX, max_cpu, u64::MAX, max_memory)
411 }
412
413 fn send(&self, request: Req, completed: oneshot::Sender<Result<TaskExecutionResult>>) {
415 self.tx.send((request, completed)).ok();
416 }
417
418 async fn run_request_queue(
420 mut rx: mpsc::UnboundedReceiver<(Req, oneshot::Sender<Result<TaskExecutionResult>>)>,
421 cpu: u64,
422 max_cpu: u64,
423 memory: u64,
424 max_memory: u64,
425 ) {
426 let mut state = TaskManagerState::new(cpu, memory);
427
428 loop {
429 if state.spawned.is_empty() {
431 assert!(
432 state.parked.is_empty(),
433 "there can't be any parked requests if there are no spawned tasks"
434 );
435 match rx.recv().await {
436 Some((req, completed)) => {
437 Self::handle_spawn_request(&mut state, max_cpu, max_memory, req, completed);
438 continue;
439 }
440 None => break,
441 }
442 }
443
444 tokio::select! {
446 request = rx.recv() => {
447 match request {
448 Some((req, completed)) => {
449 Self::handle_spawn_request(&mut state, max_cpu, max_memory, req, completed);
450 }
451 None => break,
452 }
453 }
454 Some(Ok(response)) = state.spawned.join_next() => {
455 if !state.unlimited() {
456 state.cpu += response.cpu;
457 state.memory += response.memory;
458 }
459
460 response.tx.send(response.result).ok();
461 Self::spawn_parked_tasks(&mut state, max_cpu, max_memory);
462 }
463 }
464 }
465 }
466
467 fn handle_spawn_request(
470 state: &mut TaskManagerState<Req>,
471 max_cpu: u64,
472 max_memory: u64,
473 request: Req,
474 completed: oneshot::Sender<Result<TaskExecutionResult>>,
475 ) {
476 let cpu = request.cpu();
478 if cpu > max_cpu as f64 {
479 completed
480 .send(Err(anyhow!(
481 "requested task CPU count of {cpu} exceeds the maximum CPU count of {max_cpu}",
482 )))
483 .ok();
484 return;
485 }
486
487 let memory = request.memory();
489 if memory > max_memory {
490 completed
491 .send(Err(anyhow!(
492 "requested task memory of {memory} byte{s} exceeds the maximum memory of \
493 {max_memory}",
494 s = if memory == 1 { "" } else { "s" }
495 )))
496 .ok();
497 return;
498 }
499
500 if !state.unlimited() {
501 if cpu > state.cpu.into() || memory > state.memory {
505 debug!(
506 "parking task due to insufficient resources: task reserves {cpu} CPU(s) and \
507 {memory} bytes of memory but there are only {cpu_remaining} CPU(s) and \
508 {memory_remaining} bytes of memory available",
509 cpu_remaining = state.cpu,
510 memory_remaining = state.memory
511 );
512 state.parked.push_back((request, completed));
513 return;
514 }
515
516 state.cpu -= cpu;
518 state.memory -= memory;
519 debug!(
520 "spawning task with {cpu} CPUs and {memory} bytes of memory remaining",
521 cpu = state.cpu,
522 memory = state.memory
523 );
524 }
525
526 state.spawned.spawn(async move {
527 TaskManagerResponse {
528 cpu: request.cpu(),
529 memory: request.memory(),
530 result: request.run().await,
531 tx: completed,
532 }
533 });
534 }
535
536 fn spawn_parked_tasks(state: &mut TaskManagerState<Req>, max_cpu: u64, max_memory: u64) {
538 if state.parked.is_empty() {
539 return;
540 }
541
542 debug!(
543 "attempting to unpark tasks with {cpu} CPUs and {memory} bytes of memory available",
544 cpu = state.cpu,
545 memory = state.memory,
546 );
547
548 loop {
560 let cpu_by_memory_len = {
561 let range =
564 fit_longest_range(state.parked.make_contiguous(), state.cpu, |(r, ..)| {
565 OrderedFloat(r.cpu())
566 });
567
568 fit_longest_range(
571 &mut state.parked.make_contiguous()[range],
572 state.memory,
573 |(r, ..)| r.memory(),
574 )
575 .len()
576 };
577
578 let memory_by_cpu =
581 fit_longest_range(state.parked.make_contiguous(), state.memory, |(r, ..)| {
582 r.memory()
583 });
584
585 let memory_by_cpu = fit_longest_range(
588 &mut state.parked.make_contiguous()[memory_by_cpu],
589 state.cpu,
590 |(r, ..)| OrderedFloat(r.cpu()),
591 );
592
593 if cpu_by_memory_len == 0 && memory_by_cpu.is_empty() {
595 break;
596 }
597
598 let range = if memory_by_cpu.len() >= cpu_by_memory_len {
601 memory_by_cpu
602 } else {
603 let range =
606 fit_longest_range(state.parked.make_contiguous(), state.cpu, |(r, ..)| {
607 OrderedFloat(r.cpu())
608 });
609
610 fit_longest_range(
611 &mut state.parked.make_contiguous()[range],
612 state.memory,
613 |(r, ..)| r.memory(),
614 )
615 };
616
617 debug!("unparking {len} task(s)", len = range.len());
618
619 assert_eq!(
620 range.start, 0,
621 "expected the fit tasks to be at the front of the queue"
622 );
623 for _ in range {
624 let (request, completed) = state.parked.pop_front().unwrap();
625
626 debug!(
627 "unparking task with reservation of {cpu} CPU(s) and {memory} bytes of memory",
628 cpu = request.cpu(),
629 memory = request.memory(),
630 );
631
632 Self::handle_spawn_request(state, max_cpu, max_memory, request, completed);
633 }
634 }
635 }
636}
637
638fn fit_longest_range<T, F, W>(slice: &mut [T], total_weight: W, mut weight_fn: F) -> Range<usize>
675where
676 F: FnMut(&T) -> W,
677 W: Ord + Add<Output = W> + Sub<Output = W> + Default,
678{
679 fn partition<T, F, W>(
686 slice: &mut [T],
687 weight_fn: &mut F,
688 mut low: usize,
689 high: usize,
690 ) -> (usize, W, W)
691 where
692 F: FnMut(&T) -> W,
693 W: Ord + Add<Output = W> + Sub<Output = W> + Default,
694 {
695 assert!(low < high);
696
697 slice.swap(high, rand::random_range(low..high));
699
700 let pivot_weight = weight_fn(&slice[high]);
701 let mut sum_weight = W::default();
702 let range = low..=high;
703 for i in range {
704 let weight = weight_fn(&slice[i]);
705 if weight < pivot_weight {
707 slice.swap(i, low);
708 low += 1;
709 sum_weight = sum_weight.add(weight);
710 }
711 }
712
713 slice.swap(low, high);
714 (low, pivot_weight, sum_weight)
715 }
716
717 fn recurse_fit_maximal_range<T, F, W>(
718 slice: &mut [T],
719 mut remaining_weight: W,
720 weight_fn: &mut F,
721 low: usize,
722 high: usize,
723 end: &mut usize,
724 ) where
725 F: FnMut(&T) -> W,
726 W: Ord + Add<Output = W> + Sub<Output = W> + Default,
727 {
728 if low == high {
729 let weight = weight_fn(&slice[low]);
730 if weight <= remaining_weight {
731 *end += 1;
732 }
733
734 return;
735 }
736
737 if low < high {
738 let (pivot, pivot_weight, sum) = partition(slice, weight_fn, low, high);
739 if sum <= remaining_weight {
740 *end += pivot - low;
742 remaining_weight = remaining_weight.sub(sum);
743
744 if pivot_weight <= remaining_weight {
746 *end += 1;
747 remaining_weight = remaining_weight.sub(pivot_weight);
748 }
749
750 recurse_fit_maximal_range(slice, remaining_weight, weight_fn, pivot + 1, high, end);
752 } else if pivot > 0 {
753 recurse_fit_maximal_range(slice, remaining_weight, weight_fn, low, pivot - 1, end);
756 }
757 }
758 }
759
760 assert!(
761 total_weight >= W::default(),
762 "total weight cannot be negative"
763 );
764
765 if slice.is_empty() {
766 return 0..0;
767 }
768
769 let mut end = 0;
770 recurse_fit_maximal_range(
771 slice,
772 total_weight,
773 &mut weight_fn,
774 0,
775 slice.len() - 1, &mut end,
777 );
778
779 0..end
780}
781
782#[cfg(test)]
783mod test {
784 use super::*;
785
786 #[test]
787 fn fit_empty_slice() {
788 let r = fit_longest_range(&mut [], 100, |i| *i);
789 assert!(r.is_empty());
790 }
791
792 #[test]
793 #[should_panic(expected = "total weight cannot be negative")]
794 fn fit_negative_panic() {
795 fit_longest_range(&mut [0], -1, |i| *i);
796 }
797
798 #[test]
799 fn no_fit() {
800 let r = fit_longest_range(&mut [100, 101, 102], 99, |i| *i);
801 assert!(r.is_empty());
802 }
803
804 #[test]
805 fn fit_all() {
806 let r = fit_longest_range(&mut [1, 2, 3, 4, 5], 15, |i| *i);
807 assert_eq!(r.len(), 5);
808
809 let r = fit_longest_range(&mut [5, 4, 3, 2, 1], 20, |i| *i);
810 assert_eq!(r.len(), 5);
811 }
812
813 #[test]
814 fn fit_some() {
815 let s = &mut [8, 2, 2, 3, 2, 1, 2, 4, 1];
816 let r = fit_longest_range(s, 10, |i| *i);
817 assert_eq!(r.len(), 6);
818 assert_eq!(s[r.start..r.end].iter().copied().sum::<i32>(), 10);
819 assert!(s[r.end..].contains(&8));
820 assert!(s[r.end..].contains(&4));
821 assert!(s[r.end..].contains(&3));
822 }
823
824 #[test]
825 fn unlimited_state() {
826 let manager_state = TaskManagerState::<()>::new(u64::MAX, u64::MAX);
827 assert!(manager_state.unlimited());
828 }
829}