tcrm_task/tasks/async_tokio/direct/
start.rs1use tokio::process::Command;
2use tokio::sync::{mpsc, oneshot, watch};
3
4use crate::tasks::async_tokio::direct::command::setup_command;
5use crate::tasks::async_tokio::direct::watchers::input::spawn_stdin_watcher;
6use crate::tasks::async_tokio::direct::watchers::output::spawn_output_watchers;
7use crate::tasks::async_tokio::direct::watchers::result::spawn_result_watcher;
8use crate::tasks::async_tokio::direct::watchers::timeout::spawn_timeout_watcher;
9use crate::tasks::async_tokio::direct::watchers::wait::spawn_wait_watcher;
10use crate::tasks::async_tokio::spawner::TaskSpawner;
11use crate::tasks::error::TaskError;
12use crate::tasks::event::{TaskEvent, TaskEventStopReason};
13use crate::tasks::state::{TaskState, TaskTerminateReason};
14
15impl TaskSpawner {
16 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, event_tx), fields(task_name = %self.task_name)))]
17 pub async fn start_direct(
18 &mut self,
19 event_tx: mpsc::Sender<TaskEvent>,
20 ) -> Result<u32, TaskError> {
21 self.update_state(TaskState::Initiating).await;
22
23 match self.config.validate() {
24 Ok(_) => {}
25 Err(e) => {
26 #[cfg(feature = "tracing")]
27 tracing::error!(error = %e, "Invalid task configuration");
28
29 let error_event = TaskEvent::Error {
30 task_name: self.task_name.clone(),
31 error: e.clone(),
32 };
33
34 if let Err(_) = event_tx.send(error_event).await {
35 #[cfg(feature = "tracing")]
36 tracing::warn!("Event channel closed while sending TaskEvent::Error");
37 };
38
39 return Err(e);
40 }
41 }
42
43 let mut cmd = Command::new(&self.config.command);
44 let mut cmd = cmd.kill_on_drop(true);
45
46 setup_command(&mut cmd, &self.config);
47 let mut child = match cmd.spawn() {
48 Ok(c) => c,
49 Err(e) => {
50 #[cfg(feature = "tracing")]
51 tracing::error!(error = %e, "Failed to spawn child process");
52
53 let error_event = TaskEvent::Error {
54 task_name: self.task_name.clone(),
55 error: TaskError::IO(e.to_string()),
56 };
57
58 if let Err(_) = event_tx.send(error_event).await {
59 #[cfg(feature = "tracing")]
60 tracing::warn!("Event channel closed while sending TaskEvent::Error");
61 };
62
63 return Err(TaskError::IO(e.to_string()));
64 }
65 };
66 let child_id = match child.id() {
67 Some(id) => id,
68 None => {
69 let msg = "Failed to get process id";
70
71 #[cfg(feature = "tracing")]
72 tracing::error!(msg);
73
74 let error_event = TaskEvent::Error {
75 task_name: self.task_name.clone(),
76 error: TaskError::IO(msg.to_string()),
77 };
78
79 if let Err(_) = event_tx.send(error_event).await {
80 #[cfg(feature = "tracing")]
81 tracing::warn!("Event channel closed while sending TaskEvent::Error");
82 };
83
84 return Err(TaskError::IO(msg.to_string()));
85 }
86 };
87 self.process_id = Some(child_id);
88 let mut task_handles = vec![];
89 self.update_state(TaskState::Running).await;
90 if let Err(_) = event_tx
91 .send(TaskEvent::Started {
92 task_name: self.task_name.clone(),
93 })
94 .await
95 {
96 #[cfg(feature = "tracing")]
97 tracing::warn!("Event channel closed while sending TaskEvent::Started");
98 }
99
100 let (result_tx, result_rx) = oneshot::channel::<(Option<i32>, TaskEventStopReason)>();
101 let (terminate_tx, terminate_rx) = oneshot::channel::<TaskTerminateReason>();
102 let (handle_terminator_tx, handle_terminator_rx) = watch::channel(false);
103
104 let handles = spawn_output_watchers(
106 self.task_name.clone(),
107 event_tx.clone(),
108 &mut child,
109 handle_terminator_rx.clone(),
110 );
111 task_handles.extend(handles);
112
113 if let Some((stdin, stdin_rx)) = child.stdin.take().zip(self.stdin_rx.take()) {
115 let handle = spawn_stdin_watcher(stdin, stdin_rx, handle_terminator_rx.clone());
116 task_handles.push(handle);
117 }
118
119 *self.terminate_tx.lock().await = Some(terminate_tx);
121
122 let handle = spawn_wait_watcher(
123 self.task_name.clone(),
124 self.state.clone(),
125 child,
126 terminate_rx,
127 handle_terminator_tx.clone(),
128 result_tx,
129 );
130 task_handles.push(handle);
131
132 if let Some(timeout_ms) = self.config.timeout_ms {
134 let handle =
135 spawn_timeout_watcher(self.terminate_tx.clone(), timeout_ms, handle_terminator_rx);
136 task_handles.push(handle);
137 }
138
139 let _handle = spawn_result_watcher(
141 self.task_name.clone(),
142 self.state.clone(),
143 self.finished_at.clone(),
144 event_tx,
145 result_rx,
146 task_handles,
147 );
148
149 Ok(child_id)
150 }
151}
152
#[cfg(test)]
mod tests {
    use tokio::sync::mpsc;

    use crate::tasks::{
        async_tokio::spawner::TaskSpawner,
        config::{StreamSource, TaskConfig},
        error::TaskError,
        event::{TaskEvent, TaskEventStopReason},
        state::TaskTerminateReason,
    };

    /// A plain `echo hello` must produce `Started`, one stdout line `hello`, and
    /// `Stopped` with exit code 0.
    #[tokio::test]
    async fn start_direct_fn_echo_command() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "echo hello"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "echo hello"]);

        let mut spawner = TaskSpawner::new("echo_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut output_ok = false;
        let mut stopped = false;

        // `recv` yields `None` once every sender of the event channel is dropped,
        // which ends the loop.
        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "echo_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(line, "hello");
                    assert_eq!(src, StreamSource::Stdout);
                    output_ok = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        // Without this check the test would also pass if no output was emitted at all.
        assert!(output_ok, "Expected an Output event carrying the echoed line");
        assert!(stopped);
    }

    /// A task that outlives its 1 ms timeout must stop without an exit code and
    /// with the `Timeout` termination reason.
    #[tokio::test]
    async fn start_direct_timeout_terminated_task() {
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "sleep 2"])
            .timeout_ms(1);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "sleep 2"])
            .timeout_ms(1);

        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let mut spawner = TaskSpawner::new("sleep_with_timeout_task".into(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut stopped = false;
        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    started = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    reason,
                } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    assert_eq!(exit_code, None);
                    assert_eq!(
                        reason,
                        TaskEventStopReason::Terminated(TaskTerminateReason::Timeout)
                    );
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(stopped);
    }

    /// An empty command must be rejected as an invalid configuration.
    #[tokio::test]
    async fn start_direct_fn_invalid_empty_command() {
        let (tx, _rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("");
        let mut spawner = TaskSpawner::new("bad_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));
    }

    /// With stdin enabled, a line sent through the stdin channel must be read by
    /// the process and echoed back on stdout.
    #[tokio::test]
    async fn start_direct_fn_stdin_valid() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"])
            .enable_stdin(true);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "read line; echo $line"])
            .enable_stdin(true);

        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        stdin_tx.send("hello world".to_string()).await.unwrap();

        let mut started = false;
        let mut output_ok = false;
        let mut stopped = false;

        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(line, "hello world");
                    assert_eq!(src, StreamSource::Stdout);
                    output_ok = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(output_ok);
        assert!(stopped);
    }

    /// When stdin is not enabled in the configuration, a supplied stdin receiver is
    /// dropped on start, so sending fails and the process produces no output.
    #[tokio::test]
    async fn start_direct_fn_stdin_ignore() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "read line; echo $line"]);

        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let send_result = stdin_tx.send("hello world".to_string()).await;
        assert!(
            send_result.is_err(),
            "Sending to stdin_tx should error because receiver is dropped"
        );

        let mut started = false;
        let mut output_found = false;
        let mut stopped = false;

        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output { .. } => {
                    output_found = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(
            !output_found,
            "Should not receive output from stdin when not enabled"
        );
        assert!(stopped);
    }

    /// A command that does not exist must fail with `TaskError::IO`, both as the
    /// returned error and as an emitted `TaskEvent::Error`.
    #[tokio::test]
    async fn start_direct_command_not_found() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("non_existent_command");
        let mut spawner = TaskSpawner::new("error_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::IO(_))));

        if let Some(TaskEvent::Error { task_name, error }) = rx.recv().await {
            assert_eq!(task_name, "error_task");
            assert!(matches!(error, TaskError::IO(_)));
            if let TaskError::IO(msg) = error {
                // The OS error text differs per platform.
                #[cfg(windows)]
                assert!(msg.contains("not found") || msg.contains("cannot find"));
                #[cfg(unix)]
                assert!(msg.contains("No such file or directory"));
            }
        } else {
            panic!("Expected TaskEvent::Error");
        }
    }

    /// A working directory that does not exist must be rejected as an invalid
    /// configuration and reported through `TaskEvent::Error`.
    #[tokio::test]
    async fn start_direct_invalid_working_directory() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("echo").working_dir("/non/existent/directory");

        let mut spawner = TaskSpawner::new("working_dir_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        if let Some(TaskEvent::Error { task_name, error }) = rx.recv().await {
            assert_eq!(task_name, "working_dir_task");
            assert!(matches!(error, TaskError::InvalidConfiguration(_)));
        } else {
            panic!("Expected TaskEvent::Error");
        }
    }

    /// A timeout of 0 ms must be rejected as an invalid configuration and reported
    /// through `TaskEvent::Error`.
    #[tokio::test]
    async fn start_direct_zero_timeout() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Start-Sleep -Seconds 1"])
            .timeout_ms(0);
        #[cfg(unix)]
        let config = TaskConfig::new("sleep").args(["1"]).timeout_ms(0);

        let mut spawner = TaskSpawner::new("timeout_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        if let Some(TaskEvent::Error { task_name, error }) = rx.recv().await {
            assert_eq!(task_name, "timeout_task");
            assert!(matches!(error, TaskError::InvalidConfiguration(_)));
        } else {
            panic!("Expected TaskEvent::Error with InvalidConfiguration");
        }
    }
}