tcrm_task/tasks/async_tokio/direct/
start.rs1use tokio::process::Command;
2use tokio::sync::{mpsc, oneshot, watch};
3
4use crate::tasks::async_tokio::direct::command::setup_command;
5use crate::tasks::async_tokio::direct::watchers::input::spawn_stdin_watcher;
6use crate::tasks::async_tokio::direct::watchers::output::spawn_output_watchers;
7use crate::tasks::async_tokio::direct::watchers::result::spawn_result_watcher;
8use crate::tasks::async_tokio::direct::watchers::timeout::spawn_timeout_watcher;
9use crate::tasks::async_tokio::direct::watchers::wait::spawn_wait_watcher;
10use crate::tasks::async_tokio::spawner::TaskSpawner;
11use crate::tasks::error::TaskError;
12use crate::tasks::event::{TaskEvent, TaskEventStopReason};
13use crate::tasks::state::{TaskState, TaskTerminateReason};
14
15impl TaskSpawner {
16 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, event_tx), fields(task_name = %self.task_name)))]
17 pub async fn start_direct(
18 &mut self,
19 event_tx: mpsc::Sender<TaskEvent>,
20 ) -> Result<u32, TaskError> {
21 self.update_state(TaskState::Initiating).await;
22
23 match self.config.validate() {
24 Ok(_) => {}
25 Err(e) => {
26 #[cfg(feature = "tracing")]
27 tracing::error!(error = %e, "Invalid task configuration");
28
29 let error_event = TaskEvent::Error {
30 task_name: self.task_name.clone(),
31 error: e.clone(),
32 };
33
34 if let Err(_) = event_tx.send(error_event).await {
35 #[cfg(feature = "tracing")]
36 tracing::warn!("Event channel closed while sending TaskEvent::Error");
37 };
38
39 return Err(e);
40 }
41 }
42
43 let mut cmd = Command::new(&self.config.command);
44 let mut cmd = cmd.kill_on_drop(true);
45
46 setup_command(&mut cmd, &self.config);
47 let mut child = match cmd.spawn() {
48 Ok(c) => c,
49 Err(e) => {
50 #[cfg(feature = "tracing")]
51 tracing::error!(error = %e, "Failed to spawn child process");
52
53 let error_event = TaskEvent::Error {
54 task_name: self.task_name.clone(),
55 error: TaskError::IO(e.to_string()),
56 };
57
58 if let Err(_) = event_tx.send(error_event).await {
59 #[cfg(feature = "tracing")]
60 tracing::warn!("Event channel closed while sending TaskEvent::Error");
61 };
62
63 return Err(TaskError::IO(e.to_string()));
64 }
65 };
66 let child_id = match child.id() {
67 Some(id) => id,
68 None => {
69 let msg = "Failed to get process id";
70
71 #[cfg(feature = "tracing")]
72 tracing::error!(msg);
73
74 let error_event = TaskEvent::Error {
75 task_name: self.task_name.clone(),
76 error: TaskError::IO(msg.to_string()),
77 };
78
79 if let Err(_) = event_tx.send(error_event).await {
80 #[cfg(feature = "tracing")]
81 tracing::warn!("Event channel closed while sending TaskEvent::Error");
82 };
83
84 return Err(TaskError::IO(msg.to_string()));
85 }
86 };
87 *self.process_id.write().await = Some(child_id);
88 let mut task_handles = vec![];
89 self.update_state(TaskState::Running).await;
90 if let Err(_) = event_tx
91 .send(TaskEvent::Started {
92 task_name: self.task_name.clone(),
93 })
94 .await
95 {
96 #[cfg(feature = "tracing")]
97 tracing::warn!("Event channel closed while sending TaskEvent::Started");
98 }
99
100 let (result_tx, result_rx) = oneshot::channel::<(Option<i32>, TaskEventStopReason)>();
101 let (terminate_tx, terminate_rx) = oneshot::channel::<TaskTerminateReason>();
102 let (handle_terminator_tx, handle_terminator_rx) = watch::channel(false);
103
104 let handles = spawn_output_watchers(
106 self.task_name.clone(),
107 event_tx.clone(),
108 &mut child,
109 handle_terminator_rx.clone(),
110 );
111 task_handles.extend(handles);
112
113 if let Some((stdin, stdin_rx)) = child.stdin.take().zip(self.stdin_rx.take()) {
115 let handle = spawn_stdin_watcher(stdin, stdin_rx, handle_terminator_rx.clone());
116 task_handles.push(handle);
117 }
118
119 *self.terminate_tx.lock().await = Some(terminate_tx);
121
122 let handle = spawn_wait_watcher(
123 self.task_name.clone(),
124 self.state.clone(),
125 child,
126 terminate_rx,
127 handle_terminator_tx.clone(),
128 result_tx,
129 self.process_id.clone(),
130 );
131 task_handles.push(handle);
132
133 if let Some(timeout_ms) = self.config.timeout_ms {
135 let handle =
136 spawn_timeout_watcher(self.terminate_tx.clone(), timeout_ms, handle_terminator_rx);
137 task_handles.push(handle);
138 }
139
140 let _handle = spawn_result_watcher(
142 self.task_name.clone(),
143 self.state.clone(),
144 self.finished_at.clone(),
145 event_tx,
146 result_rx,
147 task_handles,
148 );
149
150 Ok(child_id)
151 }
152}
153
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::sync::mpsc;
    use tokio::time::{sleep, timeout};

    use crate::tasks::{
        async_tokio::spawner::TaskSpawner,
        config::{StreamSource, TaskConfig},
        error::TaskError,
        event::{TaskEvent, TaskEventStopReason},
        state::TaskTerminateReason,
    };

    /// Upper bound on how long a test waits for any single event.
    ///
    /// Without a bound, a watcher that never drops its `event_tx` clone would hang the
    /// whole test run instead of failing the test.
    const EVENT_TIMEOUT: Duration = Duration::from_secs(30);

    /// Receives the next event, or `None` once every sender has been dropped.
    ///
    /// # Panics
    ///
    /// Panics if neither an event nor channel closure arrives within `EVENT_TIMEOUT`.
    async fn next_event(rx: &mut mpsc::Receiver<TaskEvent>) -> Option<TaskEvent> {
        timeout(EVENT_TIMEOUT, rx.recv())
            .await
            .expect("timed out waiting for a task event")
    }

    #[tokio::test]
    async fn start_direct_fn_echo_command() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "echo hello"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "echo hello"]);

        let mut spawner = TaskSpawner::new("echo_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut stopped = false;
        while let Some(event) = next_event(&mut rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "echo_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(line, "hello");
                    assert_eq!(src, StreamSource::Stdout);
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    reason: _,
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(stopped);
    }

    #[tokio::test]
    async fn start_direct_timeout_terminated_task() {
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "sleep 2"])
            .timeout_ms(1);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "sleep 2"])
            .timeout_ms(1);

        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let mut spawner = TaskSpawner::new("sleep_with_timeout_task".into(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut stopped = false;
        while let Some(event) = next_event(&mut rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    started = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    reason,
                } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    assert_eq!(exit_code, None);
                    assert_eq!(
                        reason,
                        TaskEventStopReason::Terminated(TaskTerminateReason::Timeout)
                    );
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(stopped);
    }

    #[tokio::test]
    async fn start_direct_fn_invalid_empty_command() {
        let (tx, _rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("");
        let mut spawner = TaskSpawner::new("bad_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));
    }

    #[tokio::test]
    async fn start_direct_fn_stdin_valid() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"])
            .enable_stdin(true);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "read line; echo $line"])
            .enable_stdin(true);

        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        stdin_tx.send("hello world".to_string()).await.unwrap();

        let mut started = false;
        let mut output_ok = false;
        let mut stopped = false;

        while let Some(event) = next_event(&mut rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(line, "hello world");
                    assert_eq!(src, StreamSource::Stdout);
                    output_ok = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(output_ok);
        assert!(stopped);
    }

    #[tokio::test]
    async fn start_direct_fn_stdin_ignore() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "read line; echo $line"]);

        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let send_result = stdin_tx.send("hello world".to_string()).await;
        assert!(
            send_result.is_err(),
            "Sending to stdin_tx should error because receiver is dropped"
        );

        let mut started = false;
        let mut output_found = false;
        let mut stopped = false;

        while let Some(event) = next_event(&mut rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output { .. } => {
                    output_found = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(
            !output_found,
            "Should not receive output from stdin when not enabled"
        );
        assert!(stopped);
    }

    #[tokio::test]
    async fn start_direct_command_not_found() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("non_existent_command");
        let mut spawner = TaskSpawner::new("error_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::IO(_))));

        if let Some(TaskEvent::Error { task_name, error }) = next_event(&mut rx).await {
            assert_eq!(task_name, "error_task");
            assert!(matches!(error, TaskError::IO(_)));
            if let TaskError::IO(msg) = error {
                #[cfg(windows)]
                assert!(msg.contains("not found") || msg.contains("cannot find"));
                #[cfg(unix)]
                assert!(msg.contains("No such file or directory"));
            }
        } else {
            panic!("Expected TaskEvent::Error");
        }
    }

    #[tokio::test]
    async fn start_direct_invalid_working_directory() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("echo").working_dir("/non/existent/directory");

        let mut spawner = TaskSpawner::new("working_dir_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        if let Some(TaskEvent::Error { task_name, error }) = next_event(&mut rx).await {
            assert_eq!(task_name, "working_dir_task");
            assert!(matches!(error, TaskError::InvalidConfiguration(_)));
        } else {
            panic!("Expected TaskEvent::Error");
        }
    }

    #[tokio::test]
    async fn start_direct_zero_timeout() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Start-Sleep -Seconds 1"])
            .timeout_ms(0);
        #[cfg(unix)]
        let config = TaskConfig::new("sleep").args(["1"]).timeout_ms(0);

        let mut spawner = TaskSpawner::new("timeout_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        if let Some(TaskEvent::Error { task_name, error }) = next_event(&mut rx).await {
            assert_eq!(task_name, "timeout_task");
            assert!(matches!(error, TaskError::InvalidConfiguration(_)));
        } else {
            panic!("Expected TaskEvent::Error with InvalidConfiguration");
        }
    }

    #[tokio::test]
    async fn process_id_is_none_after_task_stopped() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "echo done"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "echo done"]);

        let mut spawner = TaskSpawner::new("pid_test_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut stopped = false;
        while let Some(event) = next_event(&mut rx).await {
            if let TaskEvent::Stopped { task_name, .. } = event {
                assert_eq!(task_name, "pid_test_task");
                stopped = true;
                break;
            }
        }
        assert!(stopped, "Task should emit Stopped event");
        let pid = spawner.get_process_id().await;
        assert!(
            pid.is_none(),
            "process_id should be None after task is stopped"
        );
    }

    #[tokio::test]
    async fn process_id_is_some_while_task_running() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "Start-Sleep -Seconds 2"]);
        #[cfg(unix)]
        let config = TaskConfig::new("sleep").args(["2"]);

        let mut spawner = TaskSpawner::new("pid_running_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        sleep(Duration::from_millis(500)).await;
        let pid = spawner.get_process_id().await;
        assert!(
            pid.is_some(),
            "process_id should be Some while task is running"
        );

        while let Some(event) = next_event(&mut rx).await {
            if let TaskEvent::Stopped { .. } = event {
                break;
            }
        }
    }
}