1use crate::helper::generate_simple_id;
2use crate::remote_connection::{RemoteConnectionInfo, RemoteConnectionManager};
3use chrono::{DateTime, Utc};
4use std::{collections::HashMap, process::Stdio, sync::Arc, time::Duration};
5use tokio::{
6 io::{AsyncBufReadExt, BufReader},
7 process::Command,
8 sync::{broadcast, mpsc, oneshot},
9 time::timeout,
10};
11
/// Delay that `TaskManagerHandle::start_task` and `TaskManagerHandle::resume_task`
/// wait after the manager accepts a task before taking a status snapshot.
/// `start_task` uses it so commands that fail immediately are reported as
/// `TaskError::TaskFailedOnStart` instead of as a running task.
const START_TASK_WAIT_TIME: Duration = Duration::from_millis(300);
13
/// Forcefully kills the process group led by `process_id`, and the leader itself.
///
/// Local tasks are spawned as process-group leaders (`process_group(0)` on Unix,
/// `CREATE_NEW_PROCESS_GROUP` on Windows), so killing the group also stops any
/// children the shell started. This is best-effort: every failure is ignored.
///
/// NOTE(review): the external `kill` / `tasklist` / `taskkill` binaries are run
/// synchronously, which blocks the calling thread while they execute.
fn terminate_process_group(process_id: u32) {
    #[cfg(unix)]
    {
        use std::process::Command;

        // `kill -0` delivers no signal; it only reports whether the pid exists.
        let leader_alive = Command::new("kill")
            .arg("-0")
            .arg(process_id.to_string())
            .output()
            .map(|output| output.status.success())
            .unwrap_or(false);

        if leader_alive {
            // A negative pid addresses the whole process group. POSIX requires the
            // `--` separator so `-<pgid>` is not mistaken for an option or signal.
            let _ = Command::new("kill")
                .arg("-9")
                .arg("--")
                .arg(format!("-{}", process_id))
                .output();

            // Belt-and-braces: also signal the leader pid directly.
            let _ = Command::new("kill")
                .arg("-9")
                .arg(process_id.to_string())
                .output();
        }
    }

    #[cfg(windows)]
    {
        use std::process::Command;

        let check_result = Command::new("tasklist")
            .arg("/FI")
            .arg(format!("PID eq {}", process_id))
            .arg("/FO")
            .arg("CSV")
            .output();

        if let Ok(output) = check_result {
            let output_str = String::from_utf8_lossy(&output.stdout);
            // A CSV match prints a header line plus a row, so more than one line
            // means the process exists.
            if output_str.lines().count() > 1 {
                // `/T` also terminates the child processes of the target.
                let _ = Command::new("taskkill")
                    .arg("/F")
                    .arg("/T")
                    .arg("/PID")
                    .arg(process_id.to_string())
                    .output();
            }
        }
    }
}
78
/// Identifier of a task. When a `Start` message carries no id, the manager
/// generates one with `generate_simple_id(6)`.
pub type TaskId = String;
80
/// Lifecycle state of a task.
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
pub enum TaskStatus {
    /// Not assigned by the code in this file; presumably reserved — verify usage.
    Pending,
    /// The command is executing. Set when a task starts and when it is resumed.
    Running,
    /// A local command exited successfully, or a remote command returned
    /// (whatever its exit code).
    Completed,
    /// Spawning, reading output, waiting, or the command itself failed.
    Failed,
    /// The task was cancelled.
    Cancelled,
    /// The task exceeded its configured timeout.
    TimedOut,
    /// A local command exited with code 10; the task may be resumed.
    Paused,
}
91
/// Full manager-side record of a task.
#[derive(Debug, Clone)]
pub struct Task {
    pub id: TaskId,
    pub status: TaskStatus,
    /// Most recent command; replaced when the task is resumed.
    pub command: String,
    pub description: Option<String>,
    /// When `Some`, the command runs on this remote connection instead of locally.
    pub remote_connection: Option<RemoteConnectionInfo>,
    /// Output accumulated from partial updates; replaced by the final output when
    /// the task finishes.
    pub output: Option<String>,
    /// Error message reported by the finished task, if any.
    pub error: Option<String>,
    /// Set when the task is created; not reset on resume.
    pub start_time: DateTime<Utc>,
    /// Elapsed time since `start_time`, recorded when the task finishes.
    pub duration: Option<Duration>,
    pub timeout: Option<Duration>,
    /// Populated when the task finishes as `Paused` or `Completed`.
    pub pause_info: Option<PauseInfo>,
}
106
/// A task plus the runtime handles needed to stop it.
pub struct TaskEntry {
    pub task: Task,
    /// Join handle of the spawned execution future; aborted on cancel/shutdown.
    pub handle: tokio::task::JoinHandle<()>,
    /// Pid of the local child process; stays `None` for remote tasks or when the
    /// pid was never reported.
    pub process_id: Option<u32>,
    /// One-shot used to request cancellation from the execution future.
    pub cancel_tx: Option<oneshot::Sender<()>>,
}
113
/// Serializable snapshot of a task returned to callers. Omits the remote
/// connection, error message, and timeout held by [`Task`].
#[derive(Debug, Clone, serde::Serialize)]
pub struct TaskInfo {
    pub id: TaskId,
    pub status: TaskStatus,
    pub command: String,
    pub description: Option<String>,
    pub output: Option<String>,
    pub start_time: DateTime<Utc>,
    /// For running tasks, the time elapsed so far at snapshot time.
    pub duration: Option<Duration>,
    pub pause_info: Option<PauseInfo>,
}
125
126impl From<&Task> for TaskInfo {
127 fn from(task: &Task) -> Self {
128 let duration = if matches!(task.status, TaskStatus::Running) {
129 Some(
131 Utc::now()
132 .signed_duration_since(task.start_time)
133 .to_std()
134 .unwrap_or_default(),
135 )
136 } else {
137 task.duration
139 };
140
141 TaskInfo {
142 id: task.id.clone(),
143 status: task.status.clone(),
144 command: task.command.clone(),
145 description: task.description.clone(),
146 output: task.output.clone(),
147 start_time: task.start_time,
148 duration,
149 pause_info: task.pause_info.clone(),
150 }
151 }
152}
153
/// Final result of one task run, delivered to the manager via
/// `TaskMessage::TaskUpdate`.
pub struct TaskCompletion {
    pub output: String,
    pub error: Option<String>,
    pub final_status: TaskStatus,
}
159
/// Resume information captured when a task finishes as `Paused` or `Completed`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct PauseInfo {
    /// String value of a top-level `"checkpoint_id"` key when the output parses as
    /// JSON; otherwise `None`.
    pub checkpoint_id: Option<String>,
    /// The task's complete final output.
    pub raw_output: Option<String>,
}
165
/// Errors returned by the task manager and its handle.
#[derive(Debug, thiserror::Error)]
pub enum TaskError {
    /// No task with this id exists.
    #[error("Task not found: {0}")]
    TaskNotFound(TaskId),
    /// A task with this id already exists (returned whatever that task's status).
    #[error("Task already running: {0}")]
    TaskAlreadyRunning(TaskId),
    /// The manager is no longer receiving or answering messages.
    #[error("Manager shutdown")]
    ManagerShutdown,
    /// Not constructed by the code in this file; presumably used by callers — verify.
    #[error("Command execution failed: {0}")]
    ExecutionFailed(String),
    /// Not constructed by the code in this file; presumably used by callers — verify.
    #[error("Task timeout")]
    TaskTimeout,
    /// Not constructed by the code in this file; presumably used by callers — verify.
    #[error("Task cancelled")]
    TaskCancelled,
    /// The task was already `Failed` or `Cancelled` shortly after starting;
    /// carries its output (or "Unknown reason").
    #[error("Task failed on start: {0}")]
    TaskFailedOnStart(String),
    /// Resume was requested for a task that is neither `Paused` nor `Completed`.
    #[error("Task not paused: {0}")]
    TaskNotPaused(TaskId),
}
185
/// Messages processed by the [`TaskManager`] run loop.
pub enum TaskMessage {
    /// Create and launch a task; an id is generated when `id` is `None`.
    /// Replies with the task id or the start error.
    Start {
        id: Option<TaskId>,
        command: String,
        description: Option<String>,
        remote_connection: Option<RemoteConnectionInfo>,
        timeout: Option<Duration>,
        response_tx: oneshot::Sender<Result<TaskId, TaskError>>,
    },
    /// Stop a task and remove it from the task table.
    Cancel {
        id: TaskId,
        response_tx: oneshot::Sender<Result<(), TaskError>>,
    },
    /// Reply with the task's current status, or `None` if it is unknown.
    GetStatus {
        id: TaskId,
        response_tx: oneshot::Sender<Option<TaskStatus>>,
    },
    /// Reply with a snapshot of one task, or `None` if it is unknown.
    GetTaskDetails {
        id: TaskId,
        response_tx: oneshot::Sender<Option<TaskInfo>>,
    },
    /// Reply with snapshots of all tasks, newest `start_time` first.
    GetAllTasks {
        response_tx: oneshot::Sender<Vec<TaskInfo>>,
    },
    /// Stop every task and end the run loop; replies once tasks are stopped.
    Shutdown {
        response_tx: oneshot::Sender<()>,
    },
    /// Sent by a task's execution future when the run finishes.
    TaskUpdate {
        id: TaskId,
        completion: TaskCompletion,
    },
    /// Incremental output chunk appended to the task's stored output.
    PartialUpdate {
        id: TaskId,
        output: String,
    },
    /// Re-run a `Paused` or `Completed` task with a new command.
    Resume {
        id: TaskId,
        command: String,
        response_tx: oneshot::Sender<Result<(), TaskError>>,
    },
}
227
/// Actor that owns all tasks and serializes every change through a message loop
/// (see [`TaskManager::run`]).
pub struct TaskManager {
    /// All known tasks, keyed by id.
    tasks: HashMap<TaskId, TaskEntry>,
    /// Sender side of the message channel; cloned into handles and running tasks.
    tx: mpsc::UnboundedSender<TaskMessage>,
    /// Receiver side of the message channel, drained by `run`.
    rx: mpsc::UnboundedReceiver<TaskMessage>,
    /// Shutdown broadcast sender; cloned into handles.
    shutdown_tx: broadcast::Sender<()>,
    /// Shutdown broadcast receiver watched by `run`.
    shutdown_rx: broadcast::Receiver<()>,
}
235
236impl Default for TaskManager {
237 fn default() -> Self {
238 Self::new()
239 }
240}
241
242impl TaskManager {
243 pub fn new() -> Self {
244 let (tx, rx) = mpsc::unbounded_channel();
245 let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
246
247 Self {
248 tasks: HashMap::new(),
249 tx,
250 rx,
251 shutdown_tx,
252 shutdown_rx,
253 }
254 }
255
256 pub fn handle(&self) -> Arc<TaskManagerHandle> {
257 Arc::new(TaskManagerHandle {
258 tx: self.tx.clone(),
259 shutdown_tx: self.shutdown_tx.clone(),
260 })
261 }
262
263 pub async fn run(mut self) {
264 loop {
265 tokio::select! {
266 msg = self.rx.recv() => {
267 match msg {
268 Some(msg) => {
269 if self.handle_message(msg).await {
270 break;
271 }
272 }
273 None => {
274 self.shutdown_all_tasks().await;
277 break;
278 }
279 }
280 }
281 _ = self.shutdown_rx.recv() => {
282 self.shutdown_all_tasks().await;
283 break;
284 }
285 }
286 }
287 }
288
289 async fn handle_message(&mut self, msg: TaskMessage) -> bool {
290 match msg {
291 TaskMessage::Start {
292 id,
293 command,
294 description,
295 remote_connection,
296 timeout,
297 response_tx,
298 } => {
299 let task_id = id.unwrap_or_else(|| generate_simple_id(6));
300 let result = self
301 .start_task(
302 task_id.clone(),
303 command,
304 description,
305 timeout,
306 remote_connection,
307 )
308 .await;
309 let _ = response_tx.send(result.map(|_| task_id.clone()));
310 false
311 }
312 TaskMessage::Cancel { id, response_tx } => {
313 let result = self.cancel_task(&id).await;
314 let _ = response_tx.send(result);
315 false
316 }
317 TaskMessage::GetStatus { id, response_tx } => {
318 let status = self.tasks.get(&id).map(|entry| entry.task.status.clone());
319 let _ = response_tx.send(status);
320 false
321 }
322 TaskMessage::GetTaskDetails { id, response_tx } => {
323 let task_info = self.tasks.get(&id).map(|entry| TaskInfo::from(&entry.task));
324 let _ = response_tx.send(task_info);
325 false
326 }
327 TaskMessage::GetAllTasks { response_tx } => {
328 let mut tasks: Vec<TaskInfo> = self
329 .tasks
330 .values()
331 .map(|entry| TaskInfo::from(&entry.task))
332 .collect();
333 tasks.sort_by(|a, b| b.start_time.cmp(&a.start_time));
334 let _ = response_tx.send(tasks);
335 false
336 }
337 TaskMessage::TaskUpdate { id, completion } => {
338 if let Some(entry) = self.tasks.get_mut(&id) {
339 entry.task.status = completion.final_status.clone();
340 entry.task.output = Some(completion.output.clone());
341 entry.task.error = completion.error;
342 entry.task.duration = Some(
343 Utc::now()
344 .signed_duration_since(entry.task.start_time)
345 .to_std()
346 .unwrap_or_default(),
347 );
348
349 if matches!(
351 completion.final_status,
352 TaskStatus::Paused | TaskStatus::Completed
353 ) {
354 let checkpoint_id =
355 serde_json::from_str::<serde_json::Value>(&completion.output)
356 .ok()
357 .and_then(|v| {
358 v.get("checkpoint_id")
359 .and_then(|c| c.as_str())
360 .map(|s| s.to_string())
361 });
362 entry.task.pause_info = Some(PauseInfo {
363 checkpoint_id,
364 raw_output: Some(completion.output),
365 });
366 }
367
368 }
374 false
375 }
376 TaskMessage::PartialUpdate { id, output } => {
377 if let Some(entry) = self.tasks.get_mut(&id) {
378 match &entry.task.output {
379 Some(existing) => {
380 entry.task.output = Some(format!("{}{}", existing, output));
381 }
382 None => {
383 entry.task.output = Some(output);
384 }
385 }
386 }
387 false
388 }
389 TaskMessage::Resume {
390 id,
391 command,
392 response_tx,
393 } => {
394 let result = self.resume_task(id, command).await;
395 let _ = response_tx.send(result);
396 false
397 }
398 TaskMessage::Shutdown { response_tx } => {
399 self.shutdown_all_tasks().await;
400 let _ = response_tx.send(());
401 true
402 }
403 }
404 }
405
406 async fn start_task(
407 &mut self,
408 id: TaskId,
409 command: String,
410 description: Option<String>,
411 timeout: Option<Duration>,
412 remote_connection: Option<RemoteConnectionInfo>,
413 ) -> Result<(), TaskError> {
414 if self.tasks.contains_key(&id) {
415 return Err(TaskError::TaskAlreadyRunning(id));
416 }
417
418 let task = Task {
419 id: id.clone(),
420 status: TaskStatus::Running,
421 command: command.clone(),
422 description,
423 remote_connection: remote_connection.clone(),
424 output: None,
425 error: None,
426 start_time: Utc::now(),
427 duration: None,
428 timeout,
429 pause_info: None,
430 };
431
432 let (cancel_tx, cancel_rx) = oneshot::channel();
433 let (process_tx, process_rx) = oneshot::channel();
434 let task_tx: mpsc::UnboundedSender<TaskMessage> = self.tx.clone();
435
436 let is_remote_task = remote_connection.is_some();
437
438 let handle = tokio::spawn(Self::execute_task(
440 id.clone(),
441 command,
442 remote_connection,
443 timeout,
444 cancel_rx,
445 process_tx,
446 task_tx,
447 ));
448
449 let entry = TaskEntry {
450 task,
451 handle,
452 process_id: None,
453 cancel_tx: Some(cancel_tx),
454 };
455
456 self.tasks.insert(id.clone(), entry);
457
458 if !is_remote_task {
460 if let Ok(process_id) = process_rx.await
462 && let Some(entry) = self.tasks.get_mut(&id)
463 {
464 entry.process_id = Some(process_id);
465 }
466 }
467 Ok(())
470 }
471
472 async fn resume_task(&mut self, id: TaskId, command: String) -> Result<(), TaskError> {
473 if let Some(entry) = self.tasks.get(&id) {
475 if !matches!(
476 entry.task.status,
477 TaskStatus::Paused | TaskStatus::Completed
478 ) {
479 return Err(TaskError::TaskNotPaused(id));
480 }
481 } else {
482 return Err(TaskError::TaskNotFound(id));
483 }
484
485 let entry = self.tasks.get_mut(&id).unwrap();
487 entry.task.status = TaskStatus::Running;
488 entry.task.command = command.clone();
489 entry.task.pause_info = None;
490 entry.task.output = None;
491 entry.task.error = None;
492
493 let (cancel_tx, cancel_rx) = oneshot::channel();
494 let (process_tx, process_rx) = oneshot::channel();
495 let task_tx = self.tx.clone();
496
497 let remote_connection = entry.task.remote_connection.clone();
498 let timeout = entry.task.timeout;
499
500 let handle = tokio::spawn(Self::execute_task(
501 id.clone(),
502 command,
503 remote_connection.clone(),
504 timeout,
505 cancel_rx,
506 process_tx,
507 task_tx,
508 ));
509
510 entry.handle = handle;
511 entry.cancel_tx = Some(cancel_tx);
512 entry.process_id = None;
513
514 if remote_connection.is_none()
516 && let Ok(process_id) = process_rx.await
517 && let Some(entry) = self.tasks.get_mut(&id)
518 {
519 entry.process_id = Some(process_id);
520 }
521
522 Ok(())
523 }
524
525 async fn cancel_task(&mut self, id: &TaskId) -> Result<(), TaskError> {
526 if let Some(mut entry) = self.tasks.remove(id) {
527 entry.task.status = TaskStatus::Cancelled;
528
529 if let Some(cancel_tx) = entry.cancel_tx.take() {
530 let _ = cancel_tx.send(());
531 }
532
533 if let Some(process_id) = entry.process_id {
534 terminate_process_group(process_id);
535 }
536
537 entry.handle.abort();
538 Ok(())
539 } else {
540 Err(TaskError::TaskNotFound(id.clone()))
541 }
542 }
543
544 async fn execute_task(
545 id: TaskId,
546 command: String,
547 remote_connection: Option<RemoteConnectionInfo>,
548 task_timeout: Option<Duration>,
549 mut cancel_rx: oneshot::Receiver<()>,
550 process_tx: oneshot::Sender<u32>,
551 task_tx: mpsc::UnboundedSender<TaskMessage>,
552 ) {
553 let completion = if let Some(remote_info) = remote_connection {
554 Self::execute_remote_task(
556 id.clone(),
557 command,
558 remote_info,
559 task_timeout,
560 &mut cancel_rx,
561 &task_tx,
562 )
563 .await
564 } else {
565 Self::execute_local_task(
567 id.clone(),
568 command,
569 task_timeout,
570 &mut cancel_rx,
571 process_tx,
572 &task_tx,
573 )
574 .await
575 };
576
577 let _ = task_tx.send(TaskMessage::TaskUpdate {
579 id: id.clone(),
580 completion,
581 });
582 }
583
584 async fn execute_local_task(
585 id: TaskId,
586 command: String,
587 task_timeout: Option<Duration>,
588 cancel_rx: &mut oneshot::Receiver<()>,
589 process_tx: oneshot::Sender<u32>,
590 task_tx: &mpsc::UnboundedSender<TaskMessage>,
591 ) -> TaskCompletion {
592 let mut cmd = Command::new("sh");
593 cmd.arg("-c")
594 .arg(&command)
595 .stdin(Stdio::null())
596 .stdout(Stdio::piped())
597 .stderr(Stdio::piped());
598 #[cfg(unix)]
599 {
600 cmd.env("DEBIAN_FRONTEND", "noninteractive")
601 .env("SUDO_ASKPASS", "/bin/false")
602 .process_group(0);
603 }
604 #[cfg(windows)]
605 {
606 cmd.creation_flags(0x00000200); }
609
610 let mut child = match cmd.spawn() {
611 Ok(child) => child,
612 Err(err) => {
613 return TaskCompletion {
614 output: String::new(),
615 error: Some(format!("Failed to spawn command: {}", err)),
616 final_status: TaskStatus::Failed,
617 };
618 }
619 };
620
621 if let Some(process_id) = child.id() {
623 let _ = process_tx.send(process_id);
624 }
625
626 let stdout = child.stdout.take().unwrap();
628 let stderr = child.stderr.take().unwrap();
629
630 let stdout_reader = BufReader::new(stdout);
631 let stderr_reader = BufReader::new(stderr);
632
633 let mut stdout_lines = stdout_reader.lines();
634 let mut stderr_lines = stderr_reader.lines();
635
636 let stream_output = async {
638 let mut final_output = String::new();
639 let mut final_error: Option<String> = None;
640
641 loop {
642 tokio::select! {
643 line = stdout_lines.next_line() => {
644 match line {
645 Ok(Some(line)) => {
646 let output_line = format!("{}\n", line);
647 final_output.push_str(&output_line);
648 let _ = task_tx.send(TaskMessage::PartialUpdate {
649 id: id.clone(),
650 output: output_line,
651 });
652 }
653 Ok(None) => {
654 }
656 Err(err) => {
657 final_error = Some(format!("Error reading stdout: {}", err));
658 break;
659 }
660 }
661 }
662 line = stderr_lines.next_line() => {
663 match line {
664 Ok(Some(line)) => {
665 let output_line = format!("{}\n", line);
666 final_output.push_str(&output_line);
667 let _ = task_tx.send(TaskMessage::PartialUpdate {
668 id: id.clone(),
669 output: output_line,
670 });
671 }
672 Ok(None) => {
673 }
675 Err(err) => {
676 final_error = Some(format!("Error reading stderr: {}", err));
677 break;
678 }
679 }
680 }
681 status = child.wait() => {
682 match status {
683 Ok(exit_status) => {
684 if final_output.is_empty() {
685 final_output = "No output".to_string();
686 }
687
688 let completion = if exit_status.success() {
689 TaskCompletion {
690 output: final_output,
691 error: final_error,
692 final_status: TaskStatus::Completed,
693 }
694 } else if exit_status.code() == Some(10) {
695 TaskCompletion {
696 output: final_output,
697 error: None,
698 final_status: TaskStatus::Paused,
699 }
700 } else {
701 TaskCompletion {
702 output: final_output,
703 error: final_error.or_else(|| Some(format!("Command failed with exit code: {:?}", exit_status.code()))),
704 final_status: TaskStatus::Failed,
705 }
706 };
707 return completion;
708 }
709 Err(err) => {
710 return TaskCompletion {
711 output: final_output,
712 error: Some(err.to_string()),
713 final_status: TaskStatus::Failed,
714 };
715 }
716 }
717 }
718 _ = &mut *cancel_rx => {
719 return TaskCompletion {
720 output: final_output,
721 error: Some("Tool call was cancelled and don't try to run it again".to_string()),
722 final_status: TaskStatus::Cancelled,
723 };
724 }
725 }
726 }
727
728 TaskCompletion {
729 output: final_output,
730 error: final_error,
731 final_status: TaskStatus::Failed,
732 }
733 };
734
735 if let Some(timeout_duration) = task_timeout {
737 match timeout(timeout_duration, stream_output).await {
738 Ok(result) => result,
739 Err(_) => TaskCompletion {
740 output: String::new(),
741 error: Some("Task timed out".to_string()),
742 final_status: TaskStatus::TimedOut,
743 },
744 }
745 } else {
746 stream_output.await
747 }
748 }
749
750 async fn execute_remote_task(
751 id: TaskId,
752 command: String,
753 remote_info: RemoteConnectionInfo,
754 task_timeout: Option<Duration>,
755 cancel_rx: &mut oneshot::Receiver<()>,
756 task_tx: &mpsc::UnboundedSender<TaskMessage>,
757 ) -> TaskCompletion {
758 let connection_manager = RemoteConnectionManager::new();
760 let connection = match connection_manager.get_connection(&remote_info).await {
761 Ok(conn) => conn,
762 Err(e) => {
763 return TaskCompletion {
764 output: String::new(),
765 error: Some(format!("Failed to establish remote connection: {}", e)),
766 final_status: TaskStatus::Failed,
767 };
768 }
769 };
770
771 let task_tx_clone = task_tx.clone();
773 let id_clone = id.clone();
774 let progress_callback = move |output: String| {
775 if !output.trim().is_empty() {
776 let _ = task_tx_clone.send(TaskMessage::PartialUpdate {
777 id: id_clone.clone(),
778 output,
779 });
780 }
781 };
782
783 let options = crate::remote_connection::CommandOptions {
785 timeout: task_timeout,
786 with_progress: false,
787 simple: false,
788 };
789
790 match connection
791 .execute_command_unified(&command, options, cancel_rx, Some(progress_callback), None)
792 .await
793 {
794 Ok((output, exit_code)) => TaskCompletion {
795 output,
796 error: if exit_code != 0 {
797 Some(format!("Command exited with code {}", exit_code))
798 } else {
799 None
800 },
801 final_status: TaskStatus::Completed,
802 },
803 Err(e) => {
804 let error_msg = e.to_string();
805 let status = if error_msg.contains("timed out") {
806 TaskStatus::TimedOut
807 } else if error_msg.contains("cancelled") {
808 TaskStatus::Cancelled
809 } else {
810 TaskStatus::Failed
811 };
812
813 TaskCompletion {
814 output: String::new(),
815 error: Some(if error_msg.contains("cancelled") {
816 "Tool call was cancelled and don't try to run it again".to_string()
817 } else {
818 format!("Remote command failed: {}", error_msg)
819 }),
820 final_status: status,
821 }
822 }
823 }
824 }
825
826 async fn shutdown_all_tasks(&mut self) {
827 for (_id, mut entry) in self.tasks.drain() {
828 if let Some(cancel_tx) = entry.cancel_tx.take() {
829 let _ = cancel_tx.send(());
830 }
831
832 if let Some(process_id) = entry.process_id {
833 terminate_process_group(process_id);
834 }
835
836 entry.handle.abort();
837 }
838 }
839}
840
/// Cloneable-by-`Arc` front end for a running [`TaskManager`]; every method sends
/// a message and awaits the manager's reply.
pub struct TaskManagerHandle {
    /// Message channel into the manager's run loop.
    tx: mpsc::UnboundedSender<TaskMessage>,
    /// Used by `Drop` to request manager shutdown.
    shutdown_tx: broadcast::Sender<()>,
}
845
846impl Drop for TaskManagerHandle {
847 fn drop(&mut self) {
848 let _ = self.shutdown_tx.send(());
857 }
858}
859
860impl TaskManagerHandle {
861 pub async fn start_task(
862 &self,
863 command: String,
864 description: Option<String>,
865 timeout: Option<Duration>,
866 remote_connection: Option<RemoteConnectionInfo>,
867 ) -> Result<TaskInfo, TaskError> {
868 let (response_tx, response_rx) = oneshot::channel();
869
870 self.tx
871 .send(TaskMessage::Start {
872 id: None,
873 command: command.clone(),
874 description,
875 remote_connection: remote_connection.clone(),
876 timeout,
877 response_tx,
878 })
879 .map_err(|_| TaskError::ManagerShutdown)?;
880
881 let task_id = response_rx
882 .await
883 .map_err(|_| TaskError::ManagerShutdown)??;
884
885 tokio::time::sleep(START_TASK_WAIT_TIME).await;
887
888 let task_info = self
889 .get_task_details(task_id.clone())
890 .await
891 .map_err(|_| TaskError::ManagerShutdown)?
892 .ok_or_else(|| TaskError::TaskNotFound(task_id.clone()))?;
893
894 if matches!(task_info.status, TaskStatus::Failed | TaskStatus::Cancelled) {
896 return Err(TaskError::TaskFailedOnStart(
897 task_info
898 .output
899 .unwrap_or_else(|| "Unknown reason".to_string()),
900 ));
901 }
902
903 Ok(task_info)
905 }
906
907 pub async fn cancel_task(&self, id: TaskId) -> Result<TaskInfo, TaskError> {
908 let task_info = self
910 .get_all_tasks()
911 .await?
912 .into_iter()
913 .find(|task| task.id == id)
914 .ok_or_else(|| TaskError::TaskNotFound(id.clone()))?;
915
916 let (response_tx, response_rx) = oneshot::channel();
917
918 self.tx
919 .send(TaskMessage::Cancel { id, response_tx })
920 .map_err(|_| TaskError::ManagerShutdown)?;
921
922 response_rx
923 .await
924 .map_err(|_| TaskError::ManagerShutdown)??;
925
926 Ok(TaskInfo {
928 status: TaskStatus::Cancelled,
929 duration: Some(
930 Utc::now()
931 .signed_duration_since(task_info.start_time)
932 .to_std()
933 .unwrap_or_default(),
934 ),
935 ..task_info
936 })
937 }
938
939 pub async fn resume_task(&self, id: TaskId, command: String) -> Result<TaskInfo, TaskError> {
940 let (response_tx, response_rx) = oneshot::channel();
941
942 self.tx
943 .send(TaskMessage::Resume {
944 id: id.clone(),
945 command,
946 response_tx,
947 })
948 .map_err(|_| TaskError::ManagerShutdown)?;
949
950 response_rx
951 .await
952 .map_err(|_| TaskError::ManagerShutdown)??;
953
954 tokio::time::sleep(START_TASK_WAIT_TIME).await;
956
957 let task_info = self
958 .get_task_details(id.clone())
959 .await
960 .map_err(|_| TaskError::ManagerShutdown)?
961 .ok_or(TaskError::TaskNotFound(id))?;
962
963 Ok(task_info)
964 }
965
966 pub async fn get_task_status(&self, id: TaskId) -> Result<Option<TaskStatus>, TaskError> {
967 let (response_tx, response_rx) = oneshot::channel();
968
969 self.tx
970 .send(TaskMessage::GetStatus { id, response_tx })
971 .map_err(|_| TaskError::ManagerShutdown)?;
972
973 response_rx.await.map_err(|_| TaskError::ManagerShutdown)
974 }
975
976 pub async fn get_task_details(&self, id: TaskId) -> Result<Option<TaskInfo>, TaskError> {
977 let (response_tx, response_rx) = oneshot::channel();
978
979 self.tx
980 .send(TaskMessage::GetTaskDetails { id, response_tx })
981 .map_err(|_| TaskError::ManagerShutdown)?;
982
983 response_rx.await.map_err(|_| TaskError::ManagerShutdown)
984 }
985
986 pub async fn get_all_tasks(&self) -> Result<Vec<TaskInfo>, TaskError> {
987 let (response_tx, response_rx) = oneshot::channel();
988
989 self.tx
990 .send(TaskMessage::GetAllTasks { response_tx })
991 .map_err(|_| TaskError::ManagerShutdown)?;
992
993 response_rx.await.map_err(|_| TaskError::ManagerShutdown)
994 }
995
996 pub async fn shutdown(&self) -> Result<(), TaskError> {
997 let (response_tx, response_rx) = oneshot::channel();
998
999 self.tx
1000 .send(TaskMessage::Shutdown { response_tx })
1001 .map_err(|_| TaskError::ManagerShutdown)?;
1002
1003 response_rx.await.map_err(|_| TaskError::ManagerShutdown)
1004 }
1005}
1006
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{Duration, sleep};

    #[tokio::test]
    async fn test_task_manager_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let task_info = handle
            .start_task("sleep 5".to_string(), None, None, None)
            .await
            .expect("Failed to start task");

        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get task status");
        assert_eq!(status, Some(TaskStatus::Running));

        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");

        sleep(Duration::from_millis(100)).await;

        assert!(manager_handle.is_finished());
    }

    #[tokio::test]
    async fn test_task_manager_cancels_tasks_on_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let task_info = handle
            .start_task("sleep 10".to_string(), None, None, None)
            .await
            .expect("Failed to start task");

        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get task status");
        assert_eq!(status, Some(TaskStatus::Running));

        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");

        sleep(Duration::from_millis(100)).await;

        assert!(manager_handle.is_finished());
    }

    #[tokio::test]
    async fn test_task_manager_start_and_complete_task() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let task_info = handle
            .start_task("echo 'Hello, World!'".to_string(), None, None, None)
            .await
            .expect("Failed to start task");

        sleep(Duration::from_millis(500)).await;

        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get task status");
        assert_eq!(status, Some(TaskStatus::Completed));

        let tasks = handle
            .get_all_tasks()
            .await
            .expect("Failed to get all tasks");
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].status, TaskStatus::Completed);

        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }

    #[tokio::test]
    async fn test_task_manager_detects_immediate_failure() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let result = handle
            .start_task("nonexistent_command_12345".to_string(), None, None, None)
            .await;

        assert!(matches!(result, Err(TaskError::TaskFailedOnStart(_))));

        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }

    #[tokio::test]
    async fn test_task_manager_handle_drop_triggers_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let _task_info = handle
            .start_task("sleep 30".to_string(), None, None, None)
            .await
            .expect("Failed to start task");

        drop(handle);

        sleep(Duration::from_millis(500)).await;

        assert!(
            manager_handle.is_finished(),
            "TaskManager::run() should have exited after handle was dropped"
        );
    }

    #[tokio::test]
    async fn test_task_manager_handle_drop_kills_child_processes() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // The command only creates the marker if it survives for a full second,
        // so the marker's absence afterwards proves the child was killed.
        let marker = format!("/tmp/stakpak_test_drop_{}", std::process::id());
        let _ = std::fs::remove_file(&marker);
        let task_info = handle
            .start_task(format!("sleep 1 && touch {}", marker), None, None, None)
            .await
            .expect("Failed to start task");

        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get status");
        assert_eq!(status, Some(TaskStatus::Running));

        drop(handle);
        sleep(Duration::from_millis(1500)).await;

        let survived = std::path::Path::new(&marker).exists();
        let _ = std::fs::remove_file(&marker);
        assert!(
            !survived,
            "child process kept running after the handle was dropped"
        );
    }

    #[tokio::test]
    async fn test_task_manager_detects_immediate_exit_code_failure() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let result = handle
            .start_task("exit 1".to_string(), None, None, None)
            .await;

        assert!(matches!(result, Err(TaskError::TaskFailedOnStart(_))));

        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }
}