Skip to main content

stakpak_shared/
task_manager.rs

1use crate::helper::generate_simple_id;
2use crate::remote_connection::{RemoteConnectionInfo, RemoteConnectionManager};
3use chrono::{DateTime, Utc};
4use std::{collections::HashMap, process::Stdio, sync::Arc, time::Duration};
5use tokio::{
6    io::{AsyncBufReadExt, BufReader},
7    process::Command,
8    sync::{broadcast, mpsc, oneshot},
9    time::timeout,
10};
11
/// Grace period that `TaskManagerHandle::start_task` sleeps after the manager
/// has accepted a task, before it reads the task's details back so that an
/// immediate failure can be reported to the caller.
const START_TASK_WAIT_TIME: Duration = Duration::from_millis(300);
13
/// Kill a process and its entire process group.
///
/// Uses process group kill (`kill -9 -{pid}`) on Unix and `taskkill /F /T` on
/// Windows to ensure child processes spawned by shells (node, vite, esbuild, etc.)
/// are also terminated.
///
/// This is safe to call even if the process has already exited: on Unix the
/// kill is only attempted when either the process itself or its process group
/// still answers a signal-0 probe, so children that outlive an already reaped
/// group leader are still cleaned up.
///
/// `process_id` values that can never name a single child are ignored: `0`
/// (which `kill` interprets as "the caller's own process group") and, on Unix,
/// anything that does not fit a positive `i32`.
///
/// The external commands are run synchronously; all failures are ignored
/// because this is best-effort cleanup.
fn terminate_process_group(process_id: u32) {
    // PID 0 would address our own process group instead of a child.
    if process_id == 0 {
        return;
    }

    #[cfg(unix)]
    {
        use std::process::Command;

        // A value outside the positive `i32` range cannot be a real PID here,
        // and how `kill` would parse it is implementation-specific.
        if i32::try_from(process_id).is_err() {
            return;
        }

        // Signal 0 delivers nothing; the exit status only tells us whether the
        // target (a PID, or a process group when prefixed with `-`) exists.
        let exists = |target: &str| {
            Command::new("kill")
                .arg("-0")
                .arg(target)
                .output()
                .map(|output| output.status.success())
                .unwrap_or(false)
        };

        let leader = process_id.to_string();
        // Negative PID = the whole process group. Since we spawn with
        // `.process_group(0)`, the shell is the group leader and its PID is
        // also the group id.
        let group = format!("-{}", process_id);

        let leader_alive = exists(&leader);
        // The group can outlive its leader (e.g. `sh -c "server &"`), so probe
        // it separately instead of giving up once the leader is gone.
        let group_alive = leader_alive || exists(&group);

        if group_alive {
            // Kills the shell and its children (node/vite/esbuild, ...).
            let _ = Command::new("kill").arg("-9").arg(&group).output();
        }

        if leader_alive {
            // Also kill the individual process in case it is not a group leader.
            let _ = Command::new("kill").arg("-9").arg(&leader).output();
        }
    }

    #[cfg(windows)]
    {
        use std::process::Command;
        // On Windows, use taskkill with /T flag to kill the process tree
        let check_result = Command::new("tasklist")
            .arg("/FI")
            .arg(format!("PID eq {}", process_id))
            .arg("/FO")
            .arg("CSV")
            .output();

        // Only kill if the process actually exists
        if let Ok(output) = check_result {
            let output_str = String::from_utf8_lossy(&output.stdout);
            if output_str.lines().count() > 1 {
                // More than just header line - use /T to kill process tree
                let _ = Command::new("taskkill")
                    .arg("/F")
                    .arg("/T") // Kill process tree
                    .arg("/PID")
                    .arg(process_id.to_string())
                    .output();
            }
        }
    }
}
78
/// Identifier of a task; used as the key of the manager's task map.
pub type TaskId = String;
80
/// Lifecycle state of a task as tracked by the manager.
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
pub enum TaskStatus {
    /// NOTE(review): not assigned anywhere in the visible part of this file;
    /// tasks are created directly as `Running`.
    Pending,
    /// The task has been spawned and has not reported a completion yet.
    Running,
    /// The command finished; for local tasks this means a successful exit status.
    Completed,
    /// The command could not be started, exited unsuccessfully, or its output
    /// could not be read.
    Failed,
    /// The task was cancelled through its cancellation channel.
    Cancelled,
    /// The task did not finish within its configured timeout.
    TimedOut,
    /// A local command exited with code 10, which this file treats as "paused".
    Paused,
}
91
/// Full bookkeeping record for one task owned by the manager.
#[derive(Debug, Clone)]
pub struct Task {
    /// Identifier under which the task is stored.
    pub id: TaskId,
    /// Current lifecycle state.
    pub status: TaskStatus,
    /// Command line to execute (replaced when the task is resumed).
    pub command: String,
    /// Optional description supplied when the task was started.
    pub description: Option<String>,
    /// When present, the command runs over this remote connection instead of
    /// as a local process.
    pub remote_connection: Option<RemoteConnectionInfo>,
    /// Output collected so far: appended to by partial updates and replaced by
    /// the final output when the task reports completion.
    pub output: Option<String>,
    /// Error text reported with the task's completion, if any.
    pub error: Option<String>,
    /// Time the task was first started; resuming a task does not reset it.
    pub start_time: DateTime<Utc>,
    /// Time from `start_time` to the moment the completion was recorded.
    pub duration: Option<Duration>,
    /// Optional execution time limit.
    pub timeout: Option<Duration>,
    /// Checkpoint data recorded when the task finishes as `Paused` or `Completed`.
    pub pause_info: Option<PauseInfo>,
}
106
/// A task together with the runtime handles the manager needs to control it.
pub struct TaskEntry {
    /// The bookkeeping record.
    pub task: Task,
    /// Join handle of the tokio task that executes the command.
    pub handle: tokio::task::JoinHandle<()>,
    /// OS process id of the local shell, once known; stays `None` for remote tasks.
    pub process_id: Option<u32>,
    /// Sender used to request cancellation; taken (set to `None`) when it is used.
    pub cancel_tx: Option<oneshot::Sender<()>>,
}
113
/// Serializable snapshot of a [`Task`] handed out to callers.
///
/// Compared to `Task` it omits the remote connection, the error text and the
/// timeout.
#[derive(Debug, Clone, serde::Serialize)]
pub struct TaskInfo {
    /// Identifier of the task.
    pub id: TaskId,
    /// Lifecycle state at the time the snapshot was taken.
    pub status: TaskStatus,
    /// Command line of the task.
    pub command: String,
    /// Optional description supplied when the task was started.
    pub description: Option<String>,
    /// Output collected so far.
    pub output: Option<String>,
    /// Time the task was started.
    pub start_time: DateTime<Utc>,
    /// Elapsed time so far for running tasks, stored duration otherwise.
    pub duration: Option<Duration>,
    /// Checkpoint data, if any was recorded.
    pub pause_info: Option<PauseInfo>,
}
125
126impl From<&Task> for TaskInfo {
127    fn from(task: &Task) -> Self {
128        let duration = if matches!(task.status, TaskStatus::Running) {
129            // For running tasks, calculate duration from start time to now
130            Some(
131                Utc::now()
132                    .signed_duration_since(task.start_time)
133                    .to_std()
134                    .unwrap_or_default(),
135            )
136        } else {
137            // For completed/failed/cancelled tasks, use the stored duration
138            task.duration
139        };
140
141        TaskInfo {
142            id: task.id.clone(),
143            status: task.status.clone(),
144            command: task.command.clone(),
145            description: task.description.clone(),
146            output: task.output.clone(),
147            start_time: task.start_time,
148            duration,
149            pause_info: task.pause_info.clone(),
150        }
151    }
152}
153
/// Final result of one execution, sent from the executing task to the manager.
pub struct TaskCompletion {
    /// Complete output to store on the task.
    pub output: String,
    /// Error text, if the execution reported one.
    pub error: Option<String>,
    /// Status the task should end up in.
    pub final_status: TaskStatus,
}
159
/// Checkpoint data recorded when a task finishes as `Paused` or `Completed`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct PauseInfo {
    /// Value of the top-level `"checkpoint_id"` string field, when the final
    /// output parses as JSON and contains one.
    pub checkpoint_id: Option<String>,
    /// The final output exactly as reported by the execution.
    pub raw_output: Option<String>,
}
165
/// Errors reported by the task manager and its handle.
#[derive(Debug, thiserror::Error)]
pub enum TaskError {
    /// No task with this id is known to the manager.
    #[error("Task not found: {0}")]
    TaskNotFound(TaskId),
    /// A task with this id is already registered.
    #[error("Task already running: {0}")]
    TaskAlreadyRunning(TaskId),
    /// The manager's channels are closed, i.e. it is no longer running.
    #[error("Manager shutdown")]
    ManagerShutdown,
    /// NOTE(review): not constructed in the visible part of this file.
    #[error("Command execution failed: {0}")]
    ExecutionFailed(String),
    /// NOTE(review): not constructed in the visible part of this file.
    #[error("Task timeout")]
    TaskTimeout,
    /// NOTE(review): not constructed in the visible part of this file.
    #[error("Task cancelled")]
    TaskCancelled,
    /// The task was already failed or cancelled shortly after being started;
    /// carries the output collected up to that point.
    #[error("Task failed on start: {0}")]
    TaskFailedOnStart(String),
    /// A resume was requested for a task that is neither `Paused` nor `Completed`.
    #[error("Task not paused: {0}")]
    TaskNotPaused(TaskId),
}
185
/// Messages processed by the manager's event loop.
///
/// Requests carry a `response_tx` one-shot sender on which the manager
/// answers; `TaskUpdate` and `PartialUpdate` are sent by the executing tasks
/// themselves and have no reply.
pub enum TaskMessage {
    /// Start a new task; a random id is generated when `id` is `None`.
    Start {
        id: Option<TaskId>,
        command: String,
        description: Option<String>,
        remote_connection: Option<RemoteConnectionInfo>,
        timeout: Option<Duration>,
        response_tx: oneshot::Sender<Result<TaskId, TaskError>>,
    },
    /// Cancel a task and remove it from the manager.
    Cancel {
        id: TaskId,
        response_tx: oneshot::Sender<Result<(), TaskError>>,
    },
    /// Query only the status of a task (`None` if the id is unknown).
    GetStatus {
        id: TaskId,
        response_tx: oneshot::Sender<Option<TaskStatus>>,
    },
    /// Query a full snapshot of a task (`None` if the id is unknown).
    GetTaskDetails {
        id: TaskId,
        response_tx: oneshot::Sender<Option<TaskInfo>>,
    },
    /// Query snapshots of all tasks, newest start time first.
    GetAllTasks {
        response_tx: oneshot::Sender<Vec<TaskInfo>>,
    },
    /// Stop every task and end the event loop; answered once cleanup is done.
    Shutdown {
        response_tx: oneshot::Sender<()>,
    },
    /// Final result of an execution.
    TaskUpdate {
        id: TaskId,
        completion: TaskCompletion,
    },
    /// A chunk of output to append to the task's collected output.
    PartialUpdate {
        id: TaskId,
        output: String,
    },
    /// Re-run a `Paused` or `Completed` task with a new command.
    Resume {
        id: TaskId,
        command: String,
        response_tx: oneshot::Sender<Result<(), TaskError>>,
    },
}
227
/// Owner of all task state; processes [`TaskMessage`]s in [`TaskManager::run`].
pub struct TaskManager {
    /// All known tasks, including finished ones, keyed by id.
    tasks: HashMap<TaskId, TaskEntry>,
    /// Sender side of the message channel; cloned into handles and executing tasks.
    tx: mpsc::UnboundedSender<TaskMessage>,
    /// Receiver side of the message channel, drained by `run`.
    rx: mpsc::UnboundedReceiver<TaskMessage>,
    /// Broadcast sender cloned into handles so they can signal shutdown.
    shutdown_tx: broadcast::Sender<()>,
    /// Broadcast receiver that `run` listens on for the shutdown signal.
    shutdown_rx: broadcast::Receiver<()>,
}
235
236impl Default for TaskManager {
237    fn default() -> Self {
238        Self::new()
239    }
240}
241
242impl TaskManager {
243    pub fn new() -> Self {
244        let (tx, rx) = mpsc::unbounded_channel();
245        let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
246
247        Self {
248            tasks: HashMap::new(),
249            tx,
250            rx,
251            shutdown_tx,
252            shutdown_rx,
253        }
254    }
255
256    pub fn handle(&self) -> Arc<TaskManagerHandle> {
257        Arc::new(TaskManagerHandle {
258            tx: self.tx.clone(),
259            shutdown_tx: self.shutdown_tx.clone(),
260        })
261    }
262
263    pub async fn run(mut self) {
264        loop {
265            tokio::select! {
266                msg = self.rx.recv() => {
267                    match msg {
268                        Some(msg) => {
269                            if self.handle_message(msg).await {
270                                break;
271                            }
272                        }
273                        None => {
274                            // All senders (TaskManagerHandles) have been dropped.
275                            // Clean up all running tasks and child processes.
276                            self.shutdown_all_tasks().await;
277                            break;
278                        }
279                    }
280                }
281                _ = self.shutdown_rx.recv() => {
282                    self.shutdown_all_tasks().await;
283                    break;
284                }
285            }
286        }
287    }
288
289    async fn handle_message(&mut self, msg: TaskMessage) -> bool {
290        match msg {
291            TaskMessage::Start {
292                id,
293                command,
294                description,
295                remote_connection,
296                timeout,
297                response_tx,
298            } => {
299                let task_id = id.unwrap_or_else(|| generate_simple_id(6));
300                let result = self
301                    .start_task(
302                        task_id.clone(),
303                        command,
304                        description,
305                        timeout,
306                        remote_connection,
307                    )
308                    .await;
309                let _ = response_tx.send(result.map(|_| task_id.clone()));
310                false
311            }
312            TaskMessage::Cancel { id, response_tx } => {
313                let result = self.cancel_task(&id).await;
314                let _ = response_tx.send(result);
315                false
316            }
317            TaskMessage::GetStatus { id, response_tx } => {
318                let status = self.tasks.get(&id).map(|entry| entry.task.status.clone());
319                let _ = response_tx.send(status);
320                false
321            }
322            TaskMessage::GetTaskDetails { id, response_tx } => {
323                let task_info = self.tasks.get(&id).map(|entry| TaskInfo::from(&entry.task));
324                let _ = response_tx.send(task_info);
325                false
326            }
327            TaskMessage::GetAllTasks { response_tx } => {
328                let mut tasks: Vec<TaskInfo> = self
329                    .tasks
330                    .values()
331                    .map(|entry| TaskInfo::from(&entry.task))
332                    .collect();
333                tasks.sort_by(|a, b| b.start_time.cmp(&a.start_time));
334                let _ = response_tx.send(tasks);
335                false
336            }
337            TaskMessage::TaskUpdate { id, completion } => {
338                if let Some(entry) = self.tasks.get_mut(&id) {
339                    entry.task.status = completion.final_status.clone();
340                    entry.task.output = Some(completion.output.clone());
341                    entry.task.error = completion.error;
342                    entry.task.duration = Some(
343                        Utc::now()
344                            .signed_duration_since(entry.task.start_time)
345                            .to_std()
346                            .unwrap_or_default(),
347                    );
348
349                    // Extract checkpoint info for paused and completed tasks
350                    if matches!(
351                        completion.final_status,
352                        TaskStatus::Paused | TaskStatus::Completed
353                    ) {
354                        let checkpoint_id =
355                            serde_json::from_str::<serde_json::Value>(&completion.output)
356                                .ok()
357                                .and_then(|v| {
358                                    v.get("checkpoint_id")
359                                        .and_then(|c| c.as_str())
360                                        .map(|s| s.to_string())
361                                });
362                        entry.task.pause_info = Some(PauseInfo {
363                            checkpoint_id,
364                            raw_output: Some(completion.output),
365                        });
366                    }
367
368                    // Keep completed tasks in the list so they can be viewed with get_all_tasks
369                    // TODO: Consider implementing a cleanup mechanism for old completed tasks
370                    // if matches!(entry.task.status, TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled | TaskStatus::TimedOut) {
371                    //     self.tasks.remove(&id);
372                    // }
373                }
374                false
375            }
376            TaskMessage::PartialUpdate { id, output } => {
377                if let Some(entry) = self.tasks.get_mut(&id) {
378                    match &entry.task.output {
379                        Some(existing) => {
380                            entry.task.output = Some(format!("{}{}", existing, output));
381                        }
382                        None => {
383                            entry.task.output = Some(output);
384                        }
385                    }
386                }
387                false
388            }
389            TaskMessage::Resume {
390                id,
391                command,
392                response_tx,
393            } => {
394                let result = self.resume_task(id, command).await;
395                let _ = response_tx.send(result);
396                false
397            }
398            TaskMessage::Shutdown { response_tx } => {
399                self.shutdown_all_tasks().await;
400                let _ = response_tx.send(());
401                true
402            }
403        }
404    }
405
406    async fn start_task(
407        &mut self,
408        id: TaskId,
409        command: String,
410        description: Option<String>,
411        timeout: Option<Duration>,
412        remote_connection: Option<RemoteConnectionInfo>,
413    ) -> Result<(), TaskError> {
414        if self.tasks.contains_key(&id) {
415            return Err(TaskError::TaskAlreadyRunning(id));
416        }
417
418        let task = Task {
419            id: id.clone(),
420            status: TaskStatus::Running,
421            command: command.clone(),
422            description,
423            remote_connection: remote_connection.clone(),
424            output: None,
425            error: None,
426            start_time: Utc::now(),
427            duration: None,
428            timeout,
429            pause_info: None,
430        };
431
432        let (cancel_tx, cancel_rx) = oneshot::channel();
433        let (process_tx, process_rx) = oneshot::channel();
434        let task_tx: mpsc::UnboundedSender<TaskMessage> = self.tx.clone();
435
436        let is_remote_task = remote_connection.is_some();
437
438        // Spawn task immediately - SSH connection happens inside the task
439        let handle = tokio::spawn(Self::execute_task(
440            id.clone(),
441            command,
442            remote_connection,
443            timeout,
444            cancel_rx,
445            process_tx,
446            task_tx,
447        ));
448
449        let entry = TaskEntry {
450            task,
451            handle,
452            process_id: None,
453            cancel_tx: Some(cancel_tx),
454        };
455
456        self.tasks.insert(id.clone(), entry);
457
458        // Wait for the process ID for local tasks only
459        if !is_remote_task {
460            // Local task - wait for process ID for proper cleanup
461            if let Ok(process_id) = process_rx.await
462                && let Some(entry) = self.tasks.get_mut(&id)
463            {
464                entry.process_id = Some(process_id);
465            }
466        }
467        // Remote tasks don't have local process IDs, so we skip waiting
468
469        Ok(())
470    }
471
472    async fn resume_task(&mut self, id: TaskId, command: String) -> Result<(), TaskError> {
473        // Verify the task exists and is in a resumable state
474        if let Some(entry) = self.tasks.get(&id) {
475            if !matches!(
476                entry.task.status,
477                TaskStatus::Paused | TaskStatus::Completed
478            ) {
479                return Err(TaskError::TaskNotPaused(id));
480            }
481        } else {
482            return Err(TaskError::TaskNotFound(id));
483        }
484
485        // Update the task to Running and start a new execution
486        let entry = self.tasks.get_mut(&id).unwrap();
487        entry.task.status = TaskStatus::Running;
488        entry.task.command = command.clone();
489        entry.task.pause_info = None;
490        entry.task.output = None;
491        entry.task.error = None;
492
493        let (cancel_tx, cancel_rx) = oneshot::channel();
494        let (process_tx, process_rx) = oneshot::channel();
495        let task_tx = self.tx.clone();
496
497        let remote_connection = entry.task.remote_connection.clone();
498        let timeout = entry.task.timeout;
499
500        let handle = tokio::spawn(Self::execute_task(
501            id.clone(),
502            command,
503            remote_connection.clone(),
504            timeout,
505            cancel_rx,
506            process_tx,
507            task_tx,
508        ));
509
510        entry.handle = handle;
511        entry.cancel_tx = Some(cancel_tx);
512        entry.process_id = None;
513
514        // Wait for process ID for local tasks
515        if remote_connection.is_none()
516            && let Ok(process_id) = process_rx.await
517            && let Some(entry) = self.tasks.get_mut(&id)
518        {
519            entry.process_id = Some(process_id);
520        }
521
522        Ok(())
523    }
524
525    async fn cancel_task(&mut self, id: &TaskId) -> Result<(), TaskError> {
526        if let Some(mut entry) = self.tasks.remove(id) {
527            entry.task.status = TaskStatus::Cancelled;
528
529            if let Some(cancel_tx) = entry.cancel_tx.take() {
530                let _ = cancel_tx.send(());
531            }
532
533            if let Some(process_id) = entry.process_id {
534                terminate_process_group(process_id);
535            }
536
537            entry.handle.abort();
538            Ok(())
539        } else {
540            Err(TaskError::TaskNotFound(id.clone()))
541        }
542    }
543
544    async fn execute_task(
545        id: TaskId,
546        command: String,
547        remote_connection: Option<RemoteConnectionInfo>,
548        task_timeout: Option<Duration>,
549        mut cancel_rx: oneshot::Receiver<()>,
550        process_tx: oneshot::Sender<u32>,
551        task_tx: mpsc::UnboundedSender<TaskMessage>,
552    ) {
553        let completion = if let Some(remote_info) = remote_connection {
554            // Remote execution
555            Self::execute_remote_task(
556                id.clone(),
557                command,
558                remote_info,
559                task_timeout,
560                &mut cancel_rx,
561                &task_tx,
562            )
563            .await
564        } else {
565            // Local execution (existing logic)
566            Self::execute_local_task(
567                id.clone(),
568                command,
569                task_timeout,
570                &mut cancel_rx,
571                process_tx,
572                &task_tx,
573            )
574            .await
575        };
576
577        // Send task completion back to manager
578        let _ = task_tx.send(TaskMessage::TaskUpdate {
579            id: id.clone(),
580            completion,
581        });
582    }
583
584    async fn execute_local_task(
585        id: TaskId,
586        command: String,
587        task_timeout: Option<Duration>,
588        cancel_rx: &mut oneshot::Receiver<()>,
589        process_tx: oneshot::Sender<u32>,
590        task_tx: &mpsc::UnboundedSender<TaskMessage>,
591    ) -> TaskCompletion {
592        let mut cmd = Command::new("sh");
593        cmd.arg("-c")
594            .arg(&command)
595            .stdin(Stdio::null())
596            .stdout(Stdio::piped())
597            .stderr(Stdio::piped());
598        #[cfg(unix)]
599        {
600            cmd.env("DEBIAN_FRONTEND", "noninteractive")
601                .env("SUDO_ASKPASS", "/bin/false")
602                .process_group(0);
603        }
604        #[cfg(windows)]
605        {
606            // On Windows, create a new process group
607            cmd.creation_flags(0x00000200); // CREATE_NEW_PROCESS_GROUP
608        }
609
610        let mut child = match cmd.spawn() {
611            Ok(child) => child,
612            Err(err) => {
613                return TaskCompletion {
614                    output: String::new(),
615                    error: Some(format!("Failed to spawn command: {}", err)),
616                    final_status: TaskStatus::Failed,
617                };
618            }
619        };
620
621        // Send the process ID back to the manager for tracking
622        if let Some(process_id) = child.id() {
623            let _ = process_tx.send(process_id);
624        }
625
626        // Take stdout and stderr for streaming
627        let stdout = child.stdout.take().unwrap();
628        let stderr = child.stderr.take().unwrap();
629
630        let stdout_reader = BufReader::new(stdout);
631        let stderr_reader = BufReader::new(stderr);
632
633        let mut stdout_lines = stdout_reader.lines();
634        let mut stderr_lines = stderr_reader.lines();
635
636        // Helper function to stream output and handle cancellation
637        let stream_output = async {
638            let mut final_output = String::new();
639            let mut final_error: Option<String> = None;
640
641            loop {
642                tokio::select! {
643                    line = stdout_lines.next_line() => {
644                        match line {
645                            Ok(Some(line)) => {
646                                let output_line = format!("{}\n", line);
647                                final_output.push_str(&output_line);
648                                let _ = task_tx.send(TaskMessage::PartialUpdate {
649                                    id: id.clone(),
650                                    output: output_line,
651                                });
652                            }
653                            Ok(None) => {
654                                // stdout stream ended
655                            }
656                            Err(err) => {
657                                final_error = Some(format!("Error reading stdout: {}", err));
658                                break;
659                            }
660                        }
661                    }
662                    line = stderr_lines.next_line() => {
663                        match line {
664                            Ok(Some(line)) => {
665                                let output_line = format!("{}\n", line);
666                                final_output.push_str(&output_line);
667                                let _ = task_tx.send(TaskMessage::PartialUpdate {
668                                    id: id.clone(),
669                                    output: output_line,
670                                });
671                            }
672                            Ok(None) => {
673                                // stderr stream ended
674                            }
675                            Err(err) => {
676                                final_error = Some(format!("Error reading stderr: {}", err));
677                                break;
678                            }
679                        }
680                    }
681                    status = child.wait() => {
682                        match status {
683                            Ok(exit_status) => {
684                                if final_output.is_empty() {
685                                    final_output = "No output".to_string();
686                                }
687
688                                let completion = if exit_status.success() {
689                                    TaskCompletion {
690                                        output: final_output,
691                                        error: final_error,
692                                        final_status: TaskStatus::Completed,
693                                    }
694                                } else if exit_status.code() == Some(10) {
695                                    TaskCompletion {
696                                        output: final_output,
697                                        error: None,
698                                        final_status: TaskStatus::Paused,
699                                    }
700                                } else {
701                                    TaskCompletion {
702                                        output: final_output,
703                                        error: final_error.or_else(|| Some(format!("Command failed with exit code: {:?}", exit_status.code()))),
704                                        final_status: TaskStatus::Failed,
705                                    }
706                                };
707                                return completion;
708                            }
709                            Err(err) => {
710                                return TaskCompletion {
711                                    output: final_output,
712                                    error: Some(err.to_string()),
713                                    final_status: TaskStatus::Failed,
714                                };
715                            }
716                        }
717                    }
718                    _ = &mut *cancel_rx => {
719                        return TaskCompletion {
720                            output: final_output,
721                            error: Some("Tool call was cancelled and don't try to run it again".to_string()),
722                            final_status: TaskStatus::Cancelled,
723                        };
724                    }
725                }
726            }
727
728            TaskCompletion {
729                output: final_output,
730                error: final_error,
731                final_status: TaskStatus::Failed,
732            }
733        };
734
735        // Execute with timeout if provided
736        if let Some(timeout_duration) = task_timeout {
737            match timeout(timeout_duration, stream_output).await {
738                Ok(result) => result,
739                Err(_) => TaskCompletion {
740                    output: String::new(),
741                    error: Some("Task timed out".to_string()),
742                    final_status: TaskStatus::TimedOut,
743                },
744            }
745        } else {
746            stream_output.await
747        }
748    }
749
750    async fn execute_remote_task(
751        id: TaskId,
752        command: String,
753        remote_info: RemoteConnectionInfo,
754        task_timeout: Option<Duration>,
755        cancel_rx: &mut oneshot::Receiver<()>,
756        task_tx: &mpsc::UnboundedSender<TaskMessage>,
757    ) -> TaskCompletion {
758        // Use RemoteConnectionManager to get a connection
759        let connection_manager = RemoteConnectionManager::new();
760        let connection = match connection_manager.get_connection(&remote_info).await {
761            Ok(conn) => conn,
762            Err(e) => {
763                return TaskCompletion {
764                    output: String::new(),
765                    error: Some(format!("Failed to establish remote connection: {}", e)),
766                    final_status: TaskStatus::Failed,
767                };
768            }
769        };
770
771        // Create progress callback for streaming updates
772        let task_tx_clone = task_tx.clone();
773        let id_clone = id.clone();
774        let progress_callback = move |output: String| {
775            if !output.trim().is_empty() {
776                let _ = task_tx_clone.send(TaskMessage::PartialUpdate {
777                    id: id_clone.clone(),
778                    output,
779                });
780            }
781        };
782
783        // Use unified execution with proper cancellation and timeout
784        let options = crate::remote_connection::CommandOptions {
785            timeout: task_timeout,
786            with_progress: false,
787            simple: false,
788        };
789
790        match connection
791            .execute_command_unified(&command, options, cancel_rx, Some(progress_callback), None)
792            .await
793        {
794            Ok((output, exit_code)) => TaskCompletion {
795                output,
796                error: if exit_code != 0 {
797                    Some(format!("Command exited with code {}", exit_code))
798                } else {
799                    None
800                },
801                final_status: TaskStatus::Completed,
802            },
803            Err(e) => {
804                let error_msg = e.to_string();
805                let status = if error_msg.contains("timed out") {
806                    TaskStatus::TimedOut
807                } else if error_msg.contains("cancelled") {
808                    TaskStatus::Cancelled
809                } else {
810                    TaskStatus::Failed
811                };
812
813                TaskCompletion {
814                    output: String::new(),
815                    error: Some(if error_msg.contains("cancelled") {
816                        "Tool call was cancelled and don't try to run it again".to_string()
817                    } else {
818                        format!("Remote command failed: {}", error_msg)
819                    }),
820                    final_status: status,
821                }
822            }
823        }
824    }
825
826    async fn shutdown_all_tasks(&mut self) {
827        for (_id, mut entry) in self.tasks.drain() {
828            if let Some(cancel_tx) = entry.cancel_tx.take() {
829                let _ = cancel_tx.send(());
830            }
831
832            if let Some(process_id) = entry.process_id {
833                terminate_process_group(process_id);
834            }
835
836            entry.handle.abort();
837        }
838    }
839}
840
/// Client-side handle used to send requests to a running [`TaskManager`].
pub struct TaskManagerHandle {
    /// Sender for request messages to the manager's event loop.
    tx: mpsc::UnboundedSender<TaskMessage>,
    /// Broadcast sender used to signal shutdown when the handle is dropped.
    shutdown_tx: broadcast::Sender<()>,
}
845
impl Drop for TaskManagerHandle {
    /// Broadcasts the shutdown signal when the handle is dropped.
    fn drop(&mut self) {
        // Signal the TaskManager to shut down all tasks and kill child processes.
        // This fires on the broadcast channel that TaskManager::run() listens on,
        // triggering shutdown_all_tasks() which kills every process group.
        //
        // This is a last-resort safety net — callers should prefer calling
        // handle.shutdown().await for a clean async shutdown. But if the handle
        // is dropped without that (e.g. while a panic unwinds or on an
        // unexpected scope exit), this ensures child processes don't leak.
        //
        // Destructors do not run on std::process::exit or when a panic aborts
        // the process, so those cases are NOT covered by this safety net.
        //
        // NOTE(review): every TaskManager::handle() call creates its own
        // TaskManagerHandle sharing the same broadcast sender, so dropping any
        // one of them shuts the manager down for all — confirm this is intended.
        let _ = self.shutdown_tx.send(());
    }
}
859
860impl TaskManagerHandle {
861    pub async fn start_task(
862        &self,
863        command: String,
864        description: Option<String>,
865        timeout: Option<Duration>,
866        remote_connection: Option<RemoteConnectionInfo>,
867    ) -> Result<TaskInfo, TaskError> {
868        let (response_tx, response_rx) = oneshot::channel();
869
870        self.tx
871            .send(TaskMessage::Start {
872                id: None,
873                command: command.clone(),
874                description,
875                remote_connection: remote_connection.clone(),
876                timeout,
877                response_tx,
878            })
879            .map_err(|_| TaskError::ManagerShutdown)?;
880
881        let task_id = response_rx
882            .await
883            .map_err(|_| TaskError::ManagerShutdown)??;
884
885        // Wait for the task to start and get its status
886        tokio::time::sleep(START_TASK_WAIT_TIME).await;
887
888        let task_info = self
889            .get_task_details(task_id.clone())
890            .await
891            .map_err(|_| TaskError::ManagerShutdown)?
892            .ok_or_else(|| TaskError::TaskNotFound(task_id.clone()))?;
893
894        // If the task failed or was cancelled during start, return an error
895        if matches!(task_info.status, TaskStatus::Failed | TaskStatus::Cancelled) {
896            return Err(TaskError::TaskFailedOnStart(
897                task_info
898                    .output
899                    .unwrap_or_else(|| "Unknown reason".to_string()),
900            ));
901        }
902
903        // Return the task info with updated status
904        Ok(task_info)
905    }
906
907    pub async fn cancel_task(&self, id: TaskId) -> Result<TaskInfo, TaskError> {
908        // Get the task info before cancelling
909        let task_info = self
910            .get_all_tasks()
911            .await?
912            .into_iter()
913            .find(|task| task.id == id)
914            .ok_or_else(|| TaskError::TaskNotFound(id.clone()))?;
915
916        let (response_tx, response_rx) = oneshot::channel();
917
918        self.tx
919            .send(TaskMessage::Cancel { id, response_tx })
920            .map_err(|_| TaskError::ManagerShutdown)?;
921
922        response_rx
923            .await
924            .map_err(|_| TaskError::ManagerShutdown)??;
925
926        // Return the task info with updated status
927        Ok(TaskInfo {
928            status: TaskStatus::Cancelled,
929            duration: Some(
930                Utc::now()
931                    .signed_duration_since(task_info.start_time)
932                    .to_std()
933                    .unwrap_or_default(),
934            ),
935            ..task_info
936        })
937    }
938
939    pub async fn resume_task(&self, id: TaskId, command: String) -> Result<TaskInfo, TaskError> {
940        let (response_tx, response_rx) = oneshot::channel();
941
942        self.tx
943            .send(TaskMessage::Resume {
944                id: id.clone(),
945                command,
946                response_tx,
947            })
948            .map_err(|_| TaskError::ManagerShutdown)?;
949
950        response_rx
951            .await
952            .map_err(|_| TaskError::ManagerShutdown)??;
953
954        // Wait for the task to start
955        tokio::time::sleep(START_TASK_WAIT_TIME).await;
956
957        let task_info = self
958            .get_task_details(id.clone())
959            .await
960            .map_err(|_| TaskError::ManagerShutdown)?
961            .ok_or(TaskError::TaskNotFound(id))?;
962
963        Ok(task_info)
964    }
965
966    pub async fn get_task_status(&self, id: TaskId) -> Result<Option<TaskStatus>, TaskError> {
967        let (response_tx, response_rx) = oneshot::channel();
968
969        self.tx
970            .send(TaskMessage::GetStatus { id, response_tx })
971            .map_err(|_| TaskError::ManagerShutdown)?;
972
973        response_rx.await.map_err(|_| TaskError::ManagerShutdown)
974    }
975
976    pub async fn get_task_details(&self, id: TaskId) -> Result<Option<TaskInfo>, TaskError> {
977        let (response_tx, response_rx) = oneshot::channel();
978
979        self.tx
980            .send(TaskMessage::GetTaskDetails { id, response_tx })
981            .map_err(|_| TaskError::ManagerShutdown)?;
982
983        response_rx.await.map_err(|_| TaskError::ManagerShutdown)
984    }
985
986    pub async fn get_all_tasks(&self) -> Result<Vec<TaskInfo>, TaskError> {
987        let (response_tx, response_rx) = oneshot::channel();
988
989        self.tx
990            .send(TaskMessage::GetAllTasks { response_tx })
991            .map_err(|_| TaskError::ManagerShutdown)?;
992
993        response_rx.await.map_err(|_| TaskError::ManagerShutdown)
994    }
995
996    pub async fn shutdown(&self) -> Result<(), TaskError> {
997        let (response_tx, response_rx) = oneshot::channel();
998
999        self.tx
1000            .send(TaskMessage::Shutdown { response_tx })
1001            .map_err(|_| TaskError::ManagerShutdown)?;
1002
1003        response_rx.await.map_err(|_| TaskError::ManagerShutdown)
1004    }
1005}
1006
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{Duration, sleep, timeout};

    /// Upper bound on how long a test waits for `TaskManager::run()` to return.
    const MANAGER_EXIT_TIMEOUT: Duration = Duration::from_secs(5);

    /// Waits (bounded) for the spawned manager task to finish.
    ///
    /// Awaiting the join handle avoids the race of sleeping for a fixed time
    /// and then polling `is_finished()`. `when` describes the trigger and is
    /// included in the failure message.
    async fn expect_manager_exit(manager: tokio::task::JoinHandle<()>, when: &str) {
        match timeout(MANAGER_EXIT_TIMEOUT, manager).await {
            Ok(joined) => joined.expect("TaskManager task panicked"),
            Err(_) => panic!("TaskManager::run() should have exited {}", when),
        }
    }

    /// An explicit `shutdown()` makes `TaskManager::run()` return.
    #[tokio::test]
    async fn test_task_manager_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a background task
        let task_info = handle
            .start_task("sleep 5".to_string(), None, None, None)
            .await
            .expect("Failed to start task");

        // Verify task is running
        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get task status");
        assert_eq!(status, Some(TaskStatus::Running));

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");

        // Verify the manager task has completed
        expect_manager_exit(manager_handle, "after shutdown()").await;
    }

    /// Shutting down while a long-running task is active still lets the
    /// manager exit.
    #[tokio::test]
    async fn test_task_manager_cancels_tasks_on_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a long-running background task
        let task_info = handle
            .start_task("sleep 10".to_string(), None, None, None)
            .await
            .expect("Failed to start task");

        // Verify task is running
        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get task status");
        assert_eq!(status, Some(TaskStatus::Running));

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");

        // Verify the manager task has completed
        expect_manager_exit(manager_handle, "after shutdown()").await;
    }

    /// A short command is reported as `Completed` once it has run.
    #[tokio::test]
    async fn test_task_manager_start_and_complete_task() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a simple task
        let task_info = handle
            .start_task("echo 'Hello, World!'".to_string(), None, None, None)
            .await
            .expect("Failed to start task");

        // Wait for the task to complete
        sleep(Duration::from_millis(500)).await;

        // Get task status
        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get task status");
        assert_eq!(status, Some(TaskStatus::Completed));

        // Get all tasks
        let tasks = handle
            .get_all_tasks()
            .await
            .expect("Failed to get all tasks");
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].status, TaskStatus::Completed);

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }

    /// A command that cannot be run surfaces as `TaskFailedOnStart`.
    #[tokio::test]
    async fn test_task_manager_detects_immediate_failure() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a task that will fail immediately
        let result = handle
            .start_task("nonexistent_command_12345".to_string(), None, None, None)
            .await;

        // Should get a TaskFailedOnStart error
        assert!(matches!(result, Err(TaskError::TaskFailedOnStart(_))));

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }

    /// Dropping the handle without `shutdown()` still stops the manager.
    #[tokio::test]
    async fn test_task_manager_handle_drop_triggers_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a long-running task
        let _task_info = handle
            .start_task("sleep 30".to_string(), None, None, None)
            .await
            .expect("Failed to start task");

        // Drop the handle WITHOUT calling shutdown()
        drop(handle);

        // The Drop impl sends on the broadcast shutdown channel,
        // which causes TaskManager::run() to call shutdown_all_tasks() and exit.
        expect_manager_exit(manager_handle, "after handle was dropped").await;
    }

    /// Dropping the handle while a child process is running triggers the
    /// manager's shutdown path.
    #[tokio::test]
    async fn test_task_manager_handle_drop_kills_child_processes() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a task that writes a marker file while running. The marker
        // lives in the platform temp dir and is quoted for the shell.
        let marker =
            std::env::temp_dir().join(format!("stakpak_test_drop_{}", std::process::id()));
        let task_info = handle
            .start_task(
                format!("touch '{}' && sleep 30", marker.display()),
                None,
                None,
                None,
            )
            .await
            .expect("Failed to start task");

        // Verify task is running
        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get status");
        assert_eq!(status, Some(TaskStatus::Running));

        // Drop handle without explicit shutdown — Drop should kill the process
        drop(handle);
        let exited = timeout(MANAGER_EXIT_TIMEOUT, manager_handle).await;

        // Clean up marker file before asserting so a failure does not leak it.
        let _ = std::fs::remove_file(&marker);

        // NOTE(review): this checks that the manager's shutdown path ran; it
        // does not observe the child process itself. Presumably run() only
        // returns after shutdown_all_tasks() — confirm before relying on it.
        assert!(
            exited.is_ok(),
            "TaskManager::run() should have exited after handle was dropped"
        );
    }

    /// A command that exits non-zero right away surfaces as `TaskFailedOnStart`.
    #[tokio::test]
    async fn test_task_manager_detects_immediate_exit_code_failure() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a task that will exit with non-zero code immediately
        let result = handle
            .start_task("exit 1".to_string(), None, None, None)
            .await;

        // Should get a TaskFailedOnStart error
        assert!(matches!(result, Err(TaskError::TaskFailedOnStart(_))));

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }
}