tcrm_task/tasks/async_tokio/direct/start.rs

1use tokio::process::Command;
2use tokio::sync::{mpsc, oneshot, watch};
3
4use crate::tasks::async_tokio::direct::command::setup_command;
5use crate::tasks::async_tokio::direct::watchers::input::spawn_stdin_watcher;
6use crate::tasks::async_tokio::direct::watchers::output::spawn_output_watchers;
7use crate::tasks::async_tokio::direct::watchers::result::spawn_result_watcher;
8use crate::tasks::async_tokio::direct::watchers::timeout::spawn_timeout_watcher;
9use crate::tasks::async_tokio::direct::watchers::wait::spawn_wait_watcher;
10use crate::tasks::async_tokio::spawner::TaskSpawner;
11use crate::tasks::error::TaskError;
12use crate::tasks::event::{TaskEvent, TaskEventStopReason};
13use crate::tasks::state::{TaskState, TaskTerminateReason};
14
15impl TaskSpawner {
16    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, event_tx), fields(task_name = %self.task_name)))]
17    pub async fn start_direct(
18        &mut self,
19        event_tx: mpsc::Sender<TaskEvent>,
20    ) -> Result<u32, TaskError> {
21        self.update_state(TaskState::Initiating).await;
22
23        match self.config.validate() {
24            Ok(_) => {}
25            Err(e) => {
26                #[cfg(feature = "tracing")]
27                tracing::error!(error = %e, "Invalid task configuration");
28
29                self.update_state(TaskState::Finished).await;
30                let error_event = TaskEvent::Error {
31                    task_name: self.task_name.clone(),
32                    error: e.clone(),
33                };
34
35                if let Err(_) = event_tx.send(error_event).await {
36                    #[cfg(feature = "tracing")]
37                    tracing::warn!("Event channel closed while sending TaskEvent::Error");
38                };
39                return Err(e);
40            }
41        }
42
43        let mut cmd = Command::new(&self.config.command);
44        let mut cmd = cmd.kill_on_drop(true);
45
46        setup_command(&mut cmd, &self.config);
47        let mut child = match cmd.spawn() {
48            Ok(c) => c,
49            Err(e) => {
50                #[cfg(feature = "tracing")]
51                tracing::error!(error = %e, "Failed to spawn child process");
52
53                self.update_state(TaskState::Finished).await;
54                let error_event = TaskEvent::Error {
55                    task_name: self.task_name.clone(),
56                    error: TaskError::IO(e.to_string()),
57                };
58
59                if let Err(_) = event_tx.send(error_event).await {
60                    #[cfg(feature = "tracing")]
61                    tracing::warn!("Event channel closed while sending TaskEvent::Error");
62                };
63
64                return Err(TaskError::IO(e.to_string()));
65            }
66        };
67        let child_id = match child.id() {
68            Some(id) => id,
69            None => {
70                let msg = "Failed to get process id";
71
72                #[cfg(feature = "tracing")]
73                tracing::error!(msg);
74
75                self.update_state(TaskState::Finished).await;
76                let error_event = TaskEvent::Error {
77                    task_name: self.task_name.clone(),
78                    error: TaskError::IO(msg.to_string()),
79                };
80
81                if let Err(_) = event_tx.send(error_event).await {
82                    #[cfg(feature = "tracing")]
83                    tracing::warn!("Event channel closed while sending TaskEvent::Error");
84                };
85
86                return Err(TaskError::IO(msg.to_string()));
87            }
88        };
89        *self.process_id.write().await = Some(child_id);
90        let mut task_handles = vec![];
91        self.update_state(TaskState::Running).await;
92        if let Err(_) = event_tx
93            .send(TaskEvent::Started {
94                task_name: self.task_name.clone(),
95            })
96            .await
97        {
98            #[cfg(feature = "tracing")]
99            tracing::warn!("Event channel closed while sending TaskEvent::Started");
100        }
101
102        let (result_tx, result_rx) = oneshot::channel::<(Option<i32>, TaskEventStopReason)>();
103        let (terminate_tx, terminate_rx) = oneshot::channel::<TaskTerminateReason>();
104        let (handle_terminator_tx, handle_terminator_rx) = watch::channel(false);
105
106        // Spawn stdout and stderr watchers
107        let handles = spawn_output_watchers(
108            self.task_name.clone(),
109            self.state.clone(),
110            event_tx.clone(),
111            &mut child,
112            handle_terminator_rx.clone(),
113            self.config.ready_indicator.clone(),
114            self.config.ready_indicator_source.clone(),
115        );
116        task_handles.extend(handles);
117
118        // Spawn stdin watcher if configured
119        if let Some((stdin, stdin_rx)) = child.stdin.take().zip(self.stdin_rx.take()) {
120            let handle = spawn_stdin_watcher(stdin, stdin_rx, handle_terminator_rx.clone());
121            task_handles.push(handle);
122        }
123
124        // Spawn child wait watcher
125        *self.terminate_tx.lock().await = Some(terminate_tx);
126
127        let handle = spawn_wait_watcher(
128            self.task_name.clone(),
129            self.state.clone(),
130            child,
131            terminate_rx,
132            handle_terminator_tx.clone(),
133            result_tx,
134            self.process_id.clone(),
135        );
136        task_handles.push(handle);
137
138        // Spawn timeout watcher if configured
139        if let Some(timeout_ms) = self.config.timeout_ms {
140            let handle =
141                spawn_timeout_watcher(self.terminate_tx.clone(), timeout_ms, handle_terminator_rx);
142            task_handles.push(handle);
143        }
144
145        // Spawn result watcher
146        let _handle = spawn_result_watcher(
147            self.task_name.clone(),
148            self.state.clone(),
149            self.finished_at.clone(),
150            event_tx,
151            result_rx,
152            task_handles,
153        );
154
155        Ok(child_id)
156    }
157}
158
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::sync::mpsc;
    use tokio::time::sleep;

    use crate::tasks::{
        async_tokio::spawner::TaskSpawner,
        config::{StreamSource, TaskConfig},
        error::TaskError,
        event::{TaskEvent, TaskEventStopReason},
        state::{TaskState, TaskTerminateReason},
    };

    /// Upper bound on how long a test waits for the next task event.
    ///
    /// Most tests drain events until every sender is dropped. Without a bound,
    /// a watcher that never finishes would hang the whole test run instead of
    /// failing it.
    const EVENT_TIMEOUT: Duration = Duration::from_secs(30);

    /// Receives the next event from `rx`.
    ///
    /// Returns `None` once the channel is closed. Panics if nothing arrives
    /// within [`EVENT_TIMEOUT`].
    async fn next_event(rx: &mut mpsc::Receiver<TaskEvent>) -> Option<TaskEvent> {
        tokio::time::timeout(EVENT_TIMEOUT, rx.recv())
            .await
            .expect("timed out waiting for a task event")
    }

    /// A ready indicator printed on stdout, with stdout as the configured
    /// source, produces a `Ready` event.
    #[tokio::test]
    async fn start_direct_ready_indicator_source_stdout() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Write-Output 'READY_INDICATOR'"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stdout);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "echo READY_INDICATOR"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stdout);

        let mut spawner = TaskSpawner::new("ready_stdout_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut ready_event = false;
        while let Some(event) = next_event(&mut rx).await {
            if let TaskEvent::Ready { task_name } = event {
                assert_eq!(task_name, "ready_stdout_task");
                ready_event = true;
            }
        }
        assert!(
            ready_event,
            "Should emit Ready event when indicator is in stdout"
        );
    }

    /// A ready indicator printed on stderr, with stderr as the configured
    /// source, produces a `Ready` event.
    #[tokio::test]
    async fn start_direct_ready_indicator_source_stderr() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Write-Error 'READY_INDICATOR'"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "echo READY_INDICATOR 1>&2"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);

        let mut spawner = TaskSpawner::new("ready_stderr_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut ready_event = false;
        while let Some(event) = next_event(&mut rx).await {
            if let TaskEvent::Ready { task_name } = event {
                assert_eq!(task_name, "ready_stderr_task");
                ready_event = true;
            }
        }
        assert!(
            ready_event,
            "Should emit Ready event when indicator is in stderr"
        );
    }

    /// A ready indicator printed on a stream other than the configured source
    /// must not produce a `Ready` event.
    #[tokio::test]
    async fn start_direct_ready_indicator_source_mismatch() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Write-Output 'READY_INDICATOR'"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "echo READY_INDICATOR"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);

        let mut spawner = TaskSpawner::new("ready_mismatch_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut ready_event = false;
        while let Some(event) = next_event(&mut rx).await {
            if let TaskEvent::Ready { .. } = event {
                ready_event = true;
            }
        }
        assert!(
            !ready_event,
            "Should NOT emit Ready event if indicator is in wrong stream"
        );
    }

    /// A simple echo command emits `Started`, one stdout `Output` line and a
    /// `Stopped` event with exit code 0.
    #[tokio::test]
    async fn start_direct_fn_echo_command() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "echo hello"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "echo hello"]);

        let mut spawner = TaskSpawner::new("echo_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut stopped = false;
        while let Some(event) = next_event(&mut rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "echo_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(line, "hello");
                    assert_eq!(src, StreamSource::Stdout);
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    reason: _,
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(stopped);
    }

    /// A task that outlives its timeout is stopped with no exit code and a
    /// `Terminated(Timeout)` reason.
    #[tokio::test]
    async fn start_direct_timeout_terminated_task() {
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "sleep 2"])
            .timeout_ms(1);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "sleep 2"])
            .timeout_ms(1);

        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let mut spawner = TaskSpawner::new("sleep_with_timeout_task".into(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut stopped = false;
        while let Some(event) = next_event(&mut rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    started = true;
                }

                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    reason,
                } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    assert_eq!(exit_code, None);
                    assert_eq!(
                        reason,
                        TaskEventStopReason::Terminated(TaskTerminateReason::Timeout)
                    );
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(stopped);
    }

    /// An empty command is rejected as invalid configuration and leaves the
    /// task in the `Finished` state.
    #[tokio::test]
    async fn start_direct_fn_invalid_empty_command() {
        let (tx, _rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new(""); // invalid: empty command
        let mut spawner = TaskSpawner::new("bad_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        // Ensure TaskState is Finished after error, not stalled at Initiating
        let state = spawner.get_state().await;
        assert_eq!(
            state,
            TaskState::Finished,
            "TaskState should be Finished after error, not Initiating"
        );
    }

    /// With stdin enabled, a line sent through the stdin channel reaches the
    /// child and is echoed back on stdout.
    #[tokio::test]
    async fn start_direct_fn_stdin_valid() {
        // Channel for receiving task events
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"])
            .enable_stdin(true);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "read line; echo $line"])
            .enable_stdin(true);

        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        // Spawn the task
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        // Send input via stdin if enabled
        stdin_tx.send("hello world".to_string()).await.unwrap();

        let mut started = false;
        let mut output_ok = false;
        let mut stopped = false;

        while let Some(event) = next_event(&mut rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(line, "hello world");
                    assert_eq!(src, StreamSource::Stdout);
                    output_ok = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(output_ok);
        assert!(stopped);
    }

    /// With stdin not enabled, the stdin receiver is dropped by
    /// `start_direct`, so sending fails and no echoed output appears.
    #[tokio::test]
    async fn start_direct_fn_stdin_ignore() {
        // Channel for receiving task events
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "read line; echo $line"]);

        // Note: stdin is not enabled in config, so stdin should be ignored
        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        // Spawn the task
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        // Send input, but it should be ignored (receiver will be dropped, so this should error)
        let send_result = stdin_tx.send("hello world".to_string()).await;
        assert!(
            send_result.is_err(),
            "Sending to stdin_tx should error because receiver is dropped"
        );

        let mut started = false;
        let mut output_found = false;
        let mut stopped = false;

        while let Some(event) = next_event(&mut rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output { .. } => {
                    // Should NOT receive output from stdin
                    output_found = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(
            !output_found,
            "Should not receive output from stdin when not enabled"
        );
        assert!(stopped);
    }

    // Error scenario tests

    /// A missing executable fails with `TaskError::IO` and emits a matching
    /// `Error` event.
    #[tokio::test]
    async fn start_direct_command_not_found() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("non_existent_command");
        let mut spawner = TaskSpawner::new("error_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::IO(_))));

        if let Some(TaskEvent::Error { task_name, error }) = next_event(&mut rx).await {
            assert_eq!(task_name, "error_task");
            assert!(matches!(error, TaskError::IO(_)));
            if let TaskError::IO(msg) = error {
                #[cfg(windows)]
                assert!(msg.contains("not found") || msg.contains("cannot find"));
                #[cfg(unix)]
                assert!(msg.contains("No such file or directory"));
            }
        } else {
            panic!("Expected TaskEvent::Error");
        }
    }

    /// A non-existent working directory is rejected as invalid configuration
    /// and emits a matching `Error` event.
    #[tokio::test]
    async fn start_direct_invalid_working_directory() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("echo").working_dir("/non/existent/directory");

        let mut spawner = TaskSpawner::new("working_dir_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        if let Some(TaskEvent::Error { task_name, error }) = next_event(&mut rx).await {
            assert_eq!(task_name, "working_dir_task");
            assert!(matches!(error, TaskError::InvalidConfiguration(_)));
        } else {
            panic!("Expected TaskEvent::Error");
        }
    }

    /// A zero timeout is rejected as invalid configuration and emits a
    /// matching `Error` event.
    #[tokio::test]
    async fn start_direct_zero_timeout() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Start-Sleep -Seconds 1"])
            .timeout_ms(0);
        #[cfg(unix)]
        let config = TaskConfig::new("sleep").args(["1"]).timeout_ms(0);

        let mut spawner = TaskSpawner::new("timeout_task".to_string(), config);

        // Zero timeout should be rejected as invalid configuration
        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        // Should receive an error event
        if let Some(TaskEvent::Error { task_name, error }) = next_event(&mut rx).await {
            assert_eq!(task_name, "timeout_task");
            assert!(matches!(error, TaskError::InvalidConfiguration(_)));
        } else {
            panic!("Expected TaskEvent::Error with InvalidConfiguration");
        }
    }

    /// Once the `Stopped` event has been observed, the recorded process id is
    /// cleared.
    #[tokio::test]
    async fn process_id_is_none_after_task_stopped() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "echo done"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "echo done"]);

        let mut spawner = TaskSpawner::new("pid_test_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut stopped = false;
        while let Some(event) = next_event(&mut rx).await {
            if let TaskEvent::Stopped { task_name, .. } = event {
                assert_eq!(task_name, "pid_test_task");
                stopped = true;
                break;
            }
        }
        assert!(stopped, "Task should emit Stopped event");
        // process_id should be None after stopped
        let pid = spawner.get_process_id().await;
        assert!(
            pid.is_none(),
            "process_id should be None after task is stopped"
        );
    }

    /// While the child is still running, the recorded process id is present.
    #[tokio::test]
    async fn process_id_is_some_while_task_running() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "Start-Sleep -Seconds 2"]);
        #[cfg(unix)]
        let config = TaskConfig::new("sleep").args(["2"]);

        let mut spawner = TaskSpawner::new("pid_running_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        // Wait a short time to ensure the process is running
        sleep(Duration::from_millis(500)).await;
        let pid = spawner.get_process_id().await;
        assert!(
            pid.is_some(),
            "process_id should be Some while task is running"
        );

        // Drain events to ensure clean test exit
        while let Some(event) = next_event(&mut rx).await {
            if let TaskEvent::Stopped { .. } = event {
                break;
            }
        }
    }
}