tcrm_task/tasks/async_tokio/direct/
start.rs

1use tokio::process::Command;
2use tokio::sync::{mpsc, oneshot, watch};
3
4use crate::tasks::async_tokio::direct::command::setup_command;
5use crate::tasks::async_tokio::direct::watchers::input::spawn_stdin_watcher;
6use crate::tasks::async_tokio::direct::watchers::output::spawn_output_watchers;
7use crate::tasks::async_tokio::direct::watchers::result::spawn_result_watcher;
8use crate::tasks::async_tokio::direct::watchers::timeout::spawn_timeout_watcher;
9use crate::tasks::async_tokio::direct::watchers::wait::spawn_wait_watcher;
10use crate::tasks::async_tokio::spawner::TaskSpawner;
11use crate::tasks::error::TaskError;
12use crate::tasks::event::{TaskEvent, TaskEventStopReason};
13use crate::tasks::state::{TaskState, TaskTerminateReason};
14
15impl TaskSpawner {
16    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, event_tx), fields(task_name = %self.task_name)))]
17    pub async fn start_direct(
18        &mut self,
19        event_tx: mpsc::Sender<TaskEvent>,
20    ) -> Result<u32, TaskError> {
21        self.update_state(TaskState::Initiating).await;
22
23        match self.config.validate() {
24            Ok(_) => {}
25            Err(e) => {
26                #[cfg(feature = "tracing")]
27                tracing::error!(error = %e, "Invalid task configuration");
28
29                let error_event = TaskEvent::Error {
30                    task_name: self.task_name.clone(),
31                    error: e.clone(),
32                };
33
34                if let Err(_) = event_tx.send(error_event).await {
35                    #[cfg(feature = "tracing")]
36                    tracing::warn!("Event channel closed while sending TaskEvent::Error");
37                };
38
39                return Err(e);
40            }
41        }
42
43        let mut cmd = Command::new(&self.config.command);
44        let mut cmd = cmd.kill_on_drop(true);
45
46        setup_command(&mut cmd, &self.config);
47        let mut child = match cmd.spawn() {
48            Ok(c) => c,
49            Err(e) => {
50                #[cfg(feature = "tracing")]
51                tracing::error!(error = %e, "Failed to spawn child process");
52
53                let error_event = TaskEvent::Error {
54                    task_name: self.task_name.clone(),
55                    error: TaskError::IO(e.to_string()),
56                };
57
58                if let Err(_) = event_tx.send(error_event).await {
59                    #[cfg(feature = "tracing")]
60                    tracing::warn!("Event channel closed while sending TaskEvent::Error");
61                };
62
63                return Err(TaskError::IO(e.to_string()));
64            }
65        };
66        let child_id = match child.id() {
67            Some(id) => id,
68            None => {
69                let msg = "Failed to get process id";
70
71                #[cfg(feature = "tracing")]
72                tracing::error!(msg);
73
74                let error_event = TaskEvent::Error {
75                    task_name: self.task_name.clone(),
76                    error: TaskError::IO(msg.to_string()),
77                };
78
79                if let Err(_) = event_tx.send(error_event).await {
80                    #[cfg(feature = "tracing")]
81                    tracing::warn!("Event channel closed while sending TaskEvent::Error");
82                };
83
84                return Err(TaskError::IO(msg.to_string()));
85            }
86        };
87        self.process_id = Some(child_id);
88        let mut task_handles = vec![];
89        self.update_state(TaskState::Running).await;
90        if let Err(_) = event_tx
91            .send(TaskEvent::Started {
92                task_name: self.task_name.clone(),
93            })
94            .await
95        {
96            #[cfg(feature = "tracing")]
97            tracing::warn!("Event channel closed while sending TaskEvent::Started");
98        }
99
100        let (result_tx, result_rx) = oneshot::channel::<(Option<i32>, TaskEventStopReason)>();
101        let (terminate_tx, terminate_rx) = oneshot::channel::<TaskTerminateReason>();
102        let (handle_terminator_tx, handle_terminator_rx) = watch::channel(false);
103
104        // Spawn stdout and stderr watchers
105        let handles = spawn_output_watchers(
106            self.task_name.clone(),
107            event_tx.clone(),
108            &mut child,
109            handle_terminator_rx.clone(),
110        );
111        task_handles.extend(handles);
112
113        // Spawn stdin watcher if configured
114        if let Some((stdin, stdin_rx)) = child.stdin.take().zip(self.stdin_rx.take()) {
115            let handle = spawn_stdin_watcher(stdin, stdin_rx, handle_terminator_rx.clone());
116            task_handles.push(handle);
117        }
118
119        // Spawn child wait watcher
120        *self.terminate_tx.lock().await = Some(terminate_tx);
121
122        let handle = spawn_wait_watcher(
123            self.task_name.clone(),
124            self.state.clone(),
125            child,
126            terminate_rx,
127            handle_terminator_tx.clone(),
128            result_tx,
129        );
130        task_handles.push(handle);
131
132        // Spawn timeout watcher if configured
133        if let Some(timeout_ms) = self.config.timeout_ms {
134            let handle =
135                spawn_timeout_watcher(self.terminate_tx.clone(), timeout_ms, handle_terminator_rx);
136            task_handles.push(handle);
137        }
138
139        // Spawn result watcher
140        let _handle = spawn_result_watcher(
141            self.task_name.clone(),
142            self.state.clone(),
143            self.finished_at.clone(),
144            event_tx,
145            result_rx,
146            task_handles,
147        );
148
149        Ok(child_id)
150    }
151}
152
#[cfg(test)]
mod tests {
    use tokio::sync::mpsc;

    use crate::tasks::{
        async_tokio::spawner::TaskSpawner,
        config::{StreamSource, TaskConfig},
        error::TaskError,
        event::{TaskEvent, TaskEventStopReason},
        state::TaskTerminateReason,
    };

    /// A simple echo emits Started, at least one "hello" stdout line, and a
    /// Stopped event with exit code 0.
    #[tokio::test]
    async fn start_direct_fn_echo_command() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "echo hello"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "echo hello"]);

        let mut spawner = TaskSpawner::new("echo_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut output_ok = false;
        let mut stopped = false;
        // Loop ends once every event sender held by the watchers is dropped.
        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "echo_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(line, "hello");
                    assert_eq!(src, StreamSource::Stdout);
                    output_ok = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    reason: _,
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(output_ok, "expected the echoed line on stdout");
        assert!(stopped);
    }

    /// A 1 ms timeout on a 2 s sleep stops the task as Terminated(Timeout)
    /// without an exit code.
    #[tokio::test]
    async fn start_direct_timeout_terminated_task() {
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "sleep 2"])
            .timeout_ms(1);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "sleep 2"])
            .timeout_ms(1);

        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let mut spawner = TaskSpawner::new("sleep_with_timeout_task".into(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut stopped = false;
        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    started = true;
                }

                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    reason,
                } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    assert_eq!(exit_code, None);
                    assert_eq!(
                        reason,
                        TaskEventStopReason::Terminated(TaskTerminateReason::Timeout)
                    );
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(stopped);
    }

    /// An empty command is rejected by validation; the failure is both
    /// returned and reported as a TaskEvent::Error.
    #[tokio::test]
    async fn start_direct_fn_invalid_empty_command() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new(""); // invalid: empty command
        let mut spawner = TaskSpawner::new("bad_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        if let Some(TaskEvent::Error { task_name, error }) = rx.recv().await {
            assert_eq!(task_name, "bad_task");
            assert!(matches!(error, TaskError::InvalidConfiguration(_)));
        } else {
            panic!("Expected TaskEvent::Error");
        }
    }

    /// With stdin enabled, a line sent through the stdin channel is read by
    /// the child and echoed back on stdout.
    #[tokio::test]
    async fn start_direct_fn_stdin_valid() {
        // Channel for receiving task events
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"])
            .enable_stdin(true);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "read line; echo $line"])
            .enable_stdin(true);

        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        // Spawn the task
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        // Feed the child one line through the stdin watcher
        stdin_tx.send("hello world".to_string()).await.unwrap();

        let mut started = false;
        let mut output_ok = false;
        let mut stopped = false;

        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(line, "hello world");
                    assert_eq!(src, StreamSource::Stdout);
                    output_ok = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(output_ok);
        assert!(stopped);
    }

    /// Without `enable_stdin`, the stdin receiver is dropped during start-up,
    /// so sending on it fails and no input reaches the child.
    #[tokio::test]
    async fn start_direct_fn_stdin_ignore() {
        // Channel for receiving task events
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "read line; echo $line"]);

        // Note: stdin is not enabled in config, so stdin should be ignored
        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        // Spawn the task
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        // The receiver was dropped inside start_direct, so this send must fail
        let send_result = stdin_tx.send("hello world".to_string()).await;
        assert!(
            send_result.is_err(),
            "Sending to stdin_tx should error because receiver is dropped"
        );

        let mut started = false;
        let mut output_found = false;
        let mut stopped = false;

        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output { .. } => {
                    // No Output event is expected when stdin is not enabled
                    output_found = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(
            !output_found,
            "Should not receive output from stdin when not enabled"
        );
        assert!(stopped);
    }

    // Error scenario tests

    /// A missing executable fails at spawn time with an IO error that is
    /// both returned and reported as a TaskEvent::Error.
    #[tokio::test]
    async fn start_direct_command_not_found() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("non_existent_command");
        let mut spawner = TaskSpawner::new("error_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::IO(_))));

        if let Some(TaskEvent::Error { task_name, error }) = rx.recv().await {
            assert_eq!(task_name, "error_task");
            assert!(matches!(error, TaskError::IO(_)));
            if let TaskError::IO(msg) = error {
                #[cfg(windows)]
                assert!(msg.contains("not found") || msg.contains("cannot find"));
                #[cfg(unix)]
                assert!(msg.contains("No such file or directory"));
            }
        } else {
            panic!("Expected TaskEvent::Error");
        }
    }

    /// A non-existent working directory is rejected by validation.
    #[tokio::test]
    async fn start_direct_invalid_working_directory() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("echo").working_dir("/non/existent/directory");

        let mut spawner = TaskSpawner::new("working_dir_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        if let Some(TaskEvent::Error { task_name, error }) = rx.recv().await {
            assert_eq!(task_name, "working_dir_task");
            assert!(matches!(error, TaskError::InvalidConfiguration(_)));
        } else {
            panic!("Expected TaskEvent::Error");
        }
    }

    /// A zero timeout is rejected by validation before anything is spawned.
    #[tokio::test]
    async fn start_direct_zero_timeout() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Start-Sleep -Seconds 1"])
            .timeout_ms(0);
        #[cfg(unix)]
        let config = TaskConfig::new("sleep").args(["1"]).timeout_ms(0);

        let mut spawner = TaskSpawner::new("timeout_task".to_string(), config);

        // Zero timeout should be rejected as invalid configuration
        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        // Should receive an error event
        if let Some(TaskEvent::Error { task_name, error }) = rx.recv().await {
            assert_eq!(task_name, "timeout_task");
            assert!(matches!(error, TaskError::InvalidConfiguration(_)));
        } else {
            panic!("Expected TaskEvent::Error with InvalidConfiguration");
        }
    }
}