tcrm_task/tasks/async_tokio/direct/start.rs

1use tokio::process::Command;
2use tokio::sync::{mpsc, oneshot, watch};
3
4use crate::tasks::async_tokio::direct::command::setup_command;
5use crate::tasks::async_tokio::direct::watchers::input::spawn_stdin_watcher;
6use crate::tasks::async_tokio::direct::watchers::output::spawn_output_watchers;
7use crate::tasks::async_tokio::direct::watchers::result::spawn_result_watcher;
8use crate::tasks::async_tokio::direct::watchers::timeout::spawn_timeout_watcher;
9use crate::tasks::async_tokio::direct::watchers::wait::spawn_wait_watcher;
10use crate::tasks::async_tokio::spawner::TaskSpawner;
11use crate::tasks::error::TaskError;
12use crate::tasks::event::{TaskEvent, TaskEventStopReason};
13use crate::tasks::state::{TaskState, TaskTerminateReason};
14
15impl TaskSpawner {
16    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, event_tx), fields(task_name = %self.task_name)))]
17    pub async fn start_direct(
18        &mut self,
19        event_tx: mpsc::Sender<TaskEvent>,
20    ) -> Result<u32, TaskError> {
21        self.update_state(TaskState::Initiating).await;
22
23        self.config.validate()?;
24
25        let mut cmd = Command::new(&self.config.command);
26        let mut cmd = cmd.kill_on_drop(true);
27
28        setup_command(&mut cmd, &self.config);
29        let mut child = cmd.spawn()?;
30        let child_id = match child.id() {
31            Some(id) => id,
32            None => {
33                #[cfg(feature = "tracing")]
34                tracing::error!("Failed to get process id");
35                return Err(TaskError::IO(std::io::Error::new(
36                    std::io::ErrorKind::Other,
37                    "Failed to get process id",
38                )));
39            }
40        };
41        self.process_id = Some(child_id);
42        let mut task_handles = vec![];
43        self.update_state(TaskState::Running).await;
44        if let Err(_) = event_tx
45            .send(TaskEvent::Started {
46                task_name: self.task_name.clone(),
47            })
48            .await
49        {
50            #[cfg(feature = "tracing")]
51            tracing::warn!("Event channel closed while sending TaskEvent::Started");
52        }
53
54        let (result_tx, result_rx) = oneshot::channel::<(Option<i32>, TaskEventStopReason)>();
55        let (terminate_tx, terminate_rx) = oneshot::channel::<TaskTerminateReason>();
56        let (handle_terminator_tx, handle_terminator_rx) = watch::channel(false);
57
58        // Spawn stdout and stderr watchers
59        let handles = spawn_output_watchers(
60            self.task_name.clone(),
61            event_tx.clone(),
62            &mut child,
63            handle_terminator_rx.clone(),
64        );
65        task_handles.extend(handles);
66
67        // Spawn stdin watcher if configured
68        if let Some((stdin, stdin_rx)) = child.stdin.take().zip(self.stdin_rx.take()) {
69            let handle = spawn_stdin_watcher(stdin, stdin_rx, handle_terminator_rx.clone());
70            task_handles.push(handle);
71        }
72
73        // Spawn child wait watcher
74        *self.terminate_tx.lock().await = Some(terminate_tx);
75
76        let handle = spawn_wait_watcher(
77            self.task_name.clone(),
78            self.state.clone(),
79            child,
80            terminate_rx,
81            handle_terminator_tx.clone(),
82            result_tx,
83        );
84        task_handles.push(handle);
85
86        // Spawn timeout watcher if configured
87        if let Some(timeout_ms) = self.config.timeout_ms {
88            let handle =
89                spawn_timeout_watcher(self.terminate_tx.clone(), timeout_ms, handle_terminator_rx);
90            task_handles.push(handle);
91        }
92
93        // Spawn result watcher
94        let _handle = spawn_result_watcher(
95            self.task_name.clone(),
96            self.state.clone(),
97            self.finished_at.clone(),
98            event_tx,
99            result_rx,
100            task_handles,
101        );
102
103        Ok(child_id)
104    }
105}
106
#[cfg(test)]
mod tests {
    use tokio::sync::mpsc;

    use crate::tasks::{
        async_tokio::spawner::TaskSpawner,
        config::{StreamSource, TaskConfig},
        error::TaskError,
        event::{TaskEvent, TaskEventStopReason},
        state::TaskTerminateReason,
    };

    /// A short-lived command must emit `Started`, at least one stdout line (all equal to
    /// `hello`) and a `Stopped` event with exit code 0.
    #[tokio::test]
    async fn start_direct_fn_echo_command() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "echo hello"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "echo hello"]);

        let mut spawner = TaskSpawner::new("echo_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut output_ok = false;
        let mut stopped = false;
        // `recv` yields `None` once every sender held by the watchers has been dropped.
        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "echo_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(line, "hello");
                    assert_eq!(src, StreamSource::Stdout);
                    output_ok = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    reason: _,
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(output_ok, "expected the `hello` line on stdout");
        assert!(stopped);
    }

    /// A command that outlives its configured timeout must be stopped with
    /// `Terminated(Timeout)` and no exit code.
    #[tokio::test]
    async fn start_direct_timeout_terminated_task() {
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "sleep 2"])
            .timeout_ms(1);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "sleep 2"])
            .timeout_ms(1);

        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let mut spawner = TaskSpawner::new("sleep_with_timeout_task".into(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut stopped = false;
        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    started = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    reason,
                } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    assert_eq!(exit_code, None);
                    assert_eq!(
                        reason,
                        TaskEventStopReason::Terminated(TaskTerminateReason::Timeout)
                    );
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(stopped);
    }

    /// An empty command is rejected by config validation before anything is spawned.
    #[tokio::test]
    async fn start_direct_fn_invalid_empty_command() {
        let (tx, _rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new(""); // invalid: empty command
        let mut spawner = TaskSpawner::new("bad_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));
    }

    /// With stdin enabled, a line sent through the stdin channel is echoed back on stdout.
    #[tokio::test]
    async fn start_direct_fn_stdin_valid() {
        // Channel for receiving task events
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"])
            .enable_stdin(true);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "read line; echo $line"])
            .enable_stdin(true);

        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        // Spawn the task
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        // Send input via stdin if enabled
        stdin_tx.send("hello world".to_string()).await.unwrap();

        let mut started = false;
        let mut output_ok = false;
        let mut stopped = false;

        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(line, "hello world");
                    assert_eq!(src, StreamSource::Stdout);
                    output_ok = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(output_ok);
        assert!(stopped);
    }

    /// With stdin disabled, the supplied stdin receiver is dropped by `start_direct`, so
    /// sending fails and no output is produced from it.
    #[tokio::test]
    async fn start_direct_fn_stdin_ignore() {
        // Channel for receiving task events
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "read line; echo $line"]);

        // Note: stdin is not enabled in config, so stdin should be ignored
        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        // Spawn the task
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        // Send input, but it should be ignored (receiver will be dropped, so this should error)
        let send_result = stdin_tx.send("hello world".to_string()).await;
        assert!(
            send_result.is_err(),
            "Sending to stdin_tx should error because receiver is dropped"
        );

        let mut started = false;
        let mut output_found = false;
        let mut stopped = false;

        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output { .. } => {
                    // Should NOT receive output from stdin
                    output_found = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(
            !output_found,
            "Should not receive output from stdin when not enabled"
        );
        assert!(stopped);
    }
}