tcrm_task/tasks/async_tokio/direct/
start.rs

1use tokio::process::Command;
2use tokio::sync::{mpsc, oneshot, watch};
3
4use crate::tasks::async_tokio::direct::command::setup_command;
5use crate::tasks::async_tokio::direct::watchers::input::spawn_stdin_watcher;
6use crate::tasks::async_tokio::direct::watchers::output::spawn_output_watchers;
7use crate::tasks::async_tokio::direct::watchers::result::spawn_result_watcher;
8use crate::tasks::async_tokio::direct::watchers::timeout::spawn_timeout_watcher;
9use crate::tasks::async_tokio::direct::watchers::wait::spawn_wait_watcher;
10use crate::tasks::async_tokio::spawner::TaskSpawner;
11use crate::tasks::error::TaskError;
12use crate::tasks::event::{TaskEvent, TaskEventStopReason};
13use crate::tasks::state::{TaskState, TaskTerminateReason};
14
15impl TaskSpawner {
16    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, event_tx), fields(task_name = %self.task_name)))]
17    pub async fn start_direct(
18        &mut self,
19        event_tx: mpsc::Sender<TaskEvent>,
20    ) -> Result<u32, TaskError> {
21        self.update_state(TaskState::Initiating).await;
22
23        match self.config.validate() {
24            Ok(_) => {}
25            Err(e) => {
26                #[cfg(feature = "tracing")]
27                tracing::error!(error = %e, "Invalid task configuration");
28
29                let error_event = TaskEvent::Error {
30                    task_name: self.task_name.clone(),
31                    error: e.clone(),
32                };
33
34                if let Err(_) = event_tx.send(error_event).await {
35                    #[cfg(feature = "tracing")]
36                    tracing::warn!("Event channel closed while sending TaskEvent::Error");
37                };
38
39                return Err(e);
40            }
41        }
42
43        let mut cmd = Command::new(&self.config.command);
44        let mut cmd = cmd.kill_on_drop(true);
45
46        setup_command(&mut cmd, &self.config);
47        let mut child = match cmd.spawn() {
48            Ok(c) => c,
49            Err(e) => {
50                #[cfg(feature = "tracing")]
51                tracing::error!(error = %e, "Failed to spawn child process");
52
53                let error_event = TaskEvent::Error {
54                    task_name: self.task_name.clone(),
55                    error: TaskError::IO(e.to_string()),
56                };
57
58                if let Err(_) = event_tx.send(error_event).await {
59                    #[cfg(feature = "tracing")]
60                    tracing::warn!("Event channel closed while sending TaskEvent::Error");
61                };
62
63                return Err(TaskError::IO(e.to_string()));
64            }
65        };
66        let child_id = match child.id() {
67            Some(id) => id,
68            None => {
69                let msg = "Failed to get process id";
70
71                #[cfg(feature = "tracing")]
72                tracing::error!(msg);
73
74                let error_event = TaskEvent::Error {
75                    task_name: self.task_name.clone(),
76                    error: TaskError::IO(msg.to_string()),
77                };
78
79                if let Err(_) = event_tx.send(error_event).await {
80                    #[cfg(feature = "tracing")]
81                    tracing::warn!("Event channel closed while sending TaskEvent::Error");
82                };
83
84                return Err(TaskError::IO(msg.to_string()));
85            }
86        };
87        *self.process_id.write().await = Some(child_id);
88        let mut task_handles = vec![];
89        self.update_state(TaskState::Running).await;
90        if let Err(_) = event_tx
91            .send(TaskEvent::Started {
92                task_name: self.task_name.clone(),
93            })
94            .await
95        {
96            #[cfg(feature = "tracing")]
97            tracing::warn!("Event channel closed while sending TaskEvent::Started");
98        }
99
100        let (result_tx, result_rx) = oneshot::channel::<(Option<i32>, TaskEventStopReason)>();
101        let (terminate_tx, terminate_rx) = oneshot::channel::<TaskTerminateReason>();
102        let (handle_terminator_tx, handle_terminator_rx) = watch::channel(false);
103
104        // Spawn stdout and stderr watchers
105        let handles = spawn_output_watchers(
106            self.task_name.clone(),
107            event_tx.clone(),
108            &mut child,
109            handle_terminator_rx.clone(),
110        );
111        task_handles.extend(handles);
112
113        // Spawn stdin watcher if configured
114        if let Some((stdin, stdin_rx)) = child.stdin.take().zip(self.stdin_rx.take()) {
115            let handle = spawn_stdin_watcher(stdin, stdin_rx, handle_terminator_rx.clone());
116            task_handles.push(handle);
117        }
118
119        // Spawn child wait watcher
120        *self.terminate_tx.lock().await = Some(terminate_tx);
121
122        let handle = spawn_wait_watcher(
123            self.task_name.clone(),
124            self.state.clone(),
125            child,
126            terminate_rx,
127            handle_terminator_tx.clone(),
128            result_tx,
129            self.process_id.clone(),
130        );
131        task_handles.push(handle);
132
133        // Spawn timeout watcher if configured
134        if let Some(timeout_ms) = self.config.timeout_ms {
135            let handle =
136                spawn_timeout_watcher(self.terminate_tx.clone(), timeout_ms, handle_terminator_rx);
137            task_handles.push(handle);
138        }
139
140        // Spawn result watcher
141        let _handle = spawn_result_watcher(
142            self.task_name.clone(),
143            self.state.clone(),
144            self.finished_at.clone(),
145            event_tx,
146            result_rx,
147            task_handles,
148        );
149
150        Ok(child_id)
151    }
152}
153
/// Tests for `TaskSpawner::start_direct` that spawn real processes
/// (`bash`/`sleep` on Unix, PowerShell on Windows).
#[cfg(test)]
mod tests {
    use tokio::sync::mpsc;

    use crate::tasks::{
        async_tokio::spawner::TaskSpawner,
        config::{StreamSource, TaskConfig},
        error::TaskError,
        event::{TaskEvent, TaskEventStopReason},
        state::TaskTerminateReason,
    };

    #[tokio::test]
    async fn start_direct_fn_echo_command() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "echo hello"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "echo hello"]);

        let mut spawner = TaskSpawner::new("echo_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut output_ok = false;
        let mut stopped = false;
        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "echo_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(line, "hello");
                    assert_eq!(src, StreamSource::Stdout);
                    output_ok = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    reason: _,
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        // The echoed line must actually be delivered, not just well-formed if present.
        assert!(output_ok, "Expected an Output event with the echoed line");
        assert!(stopped);
    }

    #[tokio::test]
    async fn start_direct_timeout_terminated_task() {
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "sleep 2"])
            .timeout_ms(1);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "sleep 2"])
            .timeout_ms(1);

        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let mut spawner = TaskSpawner::new("sleep_with_timeout_task".into(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut stopped = false;
        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    started = true;
                }

                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    reason,
                } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    assert_eq!(exit_code, None);
                    assert_eq!(
                        reason,
                        TaskEventStopReason::Terminated(TaskTerminateReason::Timeout)
                    );
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(stopped);
    }

    #[tokio::test]
    async fn start_direct_fn_invalid_empty_command() {
        let (tx, _rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new(""); // invalid: empty command
        let mut spawner = TaskSpawner::new("bad_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));
    }

    #[tokio::test]
    async fn start_direct_fn_stdin_valid() {
        // Channel for receiving task events
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"])
            .enable_stdin(true);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "read line; echo $line"])
            .enable_stdin(true);

        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        // Spawn the task
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        // Send input via stdin if enabled
        stdin_tx.send("hello world".to_string()).await.unwrap();

        let mut started = false;
        let mut output_ok = false;
        let mut stopped = false;

        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(line, "hello world");
                    assert_eq!(src, StreamSource::Stdout);
                    output_ok = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(output_ok);
        assert!(stopped);
    }

    #[tokio::test]
    async fn start_direct_fn_stdin_ignore() {
        // Channel for receiving task events
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "read line; echo $line"]);

        // Note: stdin is not enabled in config, so stdin should be ignored
        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        // Spawn the task
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        // Send input, but it should be ignored (receiver will be dropped, so this should error)
        let send_result = stdin_tx.send("hello world".to_string()).await;
        assert!(
            send_result.is_err(),
            "Sending to stdin_tx should error because receiver is dropped"
        );

        let mut started = false;
        let mut output_found = false;
        let mut stopped = false;

        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output { .. } => {
                    // Should NOT receive output from stdin
                    output_found = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(
            !output_found,
            "Should not receive output from stdin when not enabled"
        );
        assert!(stopped);
    }

    // Error scenario tests
    #[tokio::test]
    async fn start_direct_command_not_found() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("non_existent_command");
        let mut spawner = TaskSpawner::new("error_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::IO(_))));

        // A single nested pattern checks both the event kind and the error variant.
        match rx.recv().await {
            Some(TaskEvent::Error {
                task_name,
                error: TaskError::IO(msg),
            }) => {
                assert_eq!(task_name, "error_task");
                #[cfg(windows)]
                assert!(msg.contains("not found") || msg.contains("cannot find"));
                #[cfg(unix)]
                assert!(msg.contains("No such file or directory"));
            }
            _ => panic!("Expected TaskEvent::Error"),
        }
    }

    #[tokio::test]
    async fn start_direct_invalid_working_directory() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("echo").working_dir("/non/existent/directory");

        let mut spawner = TaskSpawner::new("working_dir_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        if let Some(TaskEvent::Error { task_name, error }) = rx.recv().await {
            assert_eq!(task_name, "working_dir_task");
            assert!(matches!(error, TaskError::InvalidConfiguration(_)));
        } else {
            panic!("Expected TaskEvent::Error");
        }
    }

    #[tokio::test]
    async fn start_direct_zero_timeout() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Start-Sleep -Seconds 1"])
            .timeout_ms(0);
        #[cfg(unix)]
        let config = TaskConfig::new("sleep").args(["1"]).timeout_ms(0);

        let mut spawner = TaskSpawner::new("timeout_task".to_string(), config);

        // Zero timeout should be rejected as invalid configuration
        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        // Should receive an error event
        if let Some(TaskEvent::Error { task_name, error }) = rx.recv().await {
            assert_eq!(task_name, "timeout_task");
            assert!(matches!(error, TaskError::InvalidConfiguration(_)));
        } else {
            panic!("Expected TaskEvent::Error with InvalidConfiguration");
        }
    }

    #[tokio::test]
    async fn process_id_is_none_after_task_stopped() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "echo done"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "echo done"]);

        let mut spawner = TaskSpawner::new("pid_test_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut stopped = false;
        while let Some(event) = rx.recv().await {
            if let TaskEvent::Stopped { task_name, .. } = event {
                assert_eq!(task_name, "pid_test_task");
                stopped = true;
                break;
            }
        }
        assert!(stopped, "Task should emit Stopped event");
        // process_id should be None after stopped
        let pid = spawner.get_process_id().await;
        assert!(
            pid.is_none(),
            "process_id should be None after task is stopped"
        );
    }

    #[tokio::test]
    async fn process_id_is_some_while_task_running() {
        use std::time::Duration;
        use tokio::time::sleep;
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "Start-Sleep -Seconds 2"]);
        #[cfg(unix)]
        let config = TaskConfig::new("sleep").args(["2"]);

        let mut spawner = TaskSpawner::new("pid_running_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        // Wait a short time to ensure the process is running
        sleep(Duration::from_millis(500)).await;
        let pid = spawner.get_process_id().await;
        assert!(
            pid.is_some(),
            "process_id should be Some while task is running"
        );

        // Drain events to ensure clean test exit
        while let Some(event) = rx.recv().await {
            if let TaskEvent::Stopped { .. } = event {
                break;
            }
        }
    }
}