1use tokio::process::Command;
2use tokio::sync::{mpsc, oneshot, watch};
3
4use crate::tasks::async_tokio::direct::command::setup_command;
5use crate::tasks::async_tokio::direct::watchers::input::spawn_stdin_watcher;
6use crate::tasks::async_tokio::direct::watchers::output::spawn_output_watchers;
7use crate::tasks::async_tokio::direct::watchers::result::spawn_result_watcher;
8use crate::tasks::async_tokio::direct::watchers::timeout::spawn_timeout_watcher;
9use crate::tasks::async_tokio::direct::watchers::wait::spawn_wait_watcher;
10use crate::tasks::async_tokio::spawner::TaskSpawner;
11use crate::tasks::error::TaskError;
12use crate::tasks::event::{TaskEvent, TaskEventStopReason};
13use crate::tasks::state::{TaskState, TaskTerminateReason};
14
15impl TaskSpawner {
16 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, event_tx), fields(task_name = %self.task_name)))]
17 pub async fn start_direct(
18 &mut self,
19 event_tx: mpsc::Sender<TaskEvent>,
20 ) -> Result<u32, TaskError> {
21 self.update_state(TaskState::Initiating).await;
22
23 match self.config.validate() {
24 Ok(_) => {}
25 Err(e) => {
26 #[cfg(feature = "tracing")]
27 tracing::error!(error = %e, "Invalid task configuration");
28
29 self.update_state(TaskState::Finished).await;
30 let error_event = TaskEvent::Error {
31 task_name: self.task_name.clone(),
32 error: e.clone(),
33 };
34
35 if let Err(_) = event_tx.send(error_event).await {
36 #[cfg(feature = "tracing")]
37 tracing::warn!("Event channel closed while sending TaskEvent::Error");
38 };
39 return Err(e);
40 }
41 }
42
43 let mut cmd = Command::new(&self.config.command);
44 let mut cmd = cmd.kill_on_drop(true);
45
46 setup_command(&mut cmd, &self.config);
47 let mut child = match cmd.spawn() {
48 Ok(c) => c,
49 Err(e) => {
50 #[cfg(feature = "tracing")]
51 tracing::error!(error = %e, "Failed to spawn child process");
52
53 self.update_state(TaskState::Finished).await;
54 let error_event = TaskEvent::Error {
55 task_name: self.task_name.clone(),
56 error: TaskError::IO(e.to_string()),
57 };
58
59 if let Err(_) = event_tx.send(error_event).await {
60 #[cfg(feature = "tracing")]
61 tracing::warn!("Event channel closed while sending TaskEvent::Error");
62 };
63
64 return Err(TaskError::IO(e.to_string()));
65 }
66 };
67 let child_id = match child.id() {
68 Some(id) => id,
69 None => {
70 let msg = "Failed to get process id";
71
72 #[cfg(feature = "tracing")]
73 tracing::error!(msg);
74
75 self.update_state(TaskState::Finished).await;
76 let error_event = TaskEvent::Error {
77 task_name: self.task_name.clone(),
78 error: TaskError::IO(msg.to_string()),
79 };
80
81 if let Err(_) = event_tx.send(error_event).await {
82 #[cfg(feature = "tracing")]
83 tracing::warn!("Event channel closed while sending TaskEvent::Error");
84 };
85
86 return Err(TaskError::IO(msg.to_string()));
87 }
88 };
89 *self.process_id.write().await = Some(child_id);
90 let mut task_handles = vec![];
91 self.update_state(TaskState::Running).await;
92 if let Err(_) = event_tx
93 .send(TaskEvent::Started {
94 task_name: self.task_name.clone(),
95 })
96 .await
97 {
98 #[cfg(feature = "tracing")]
99 tracing::warn!("Event channel closed while sending TaskEvent::Started");
100 }
101
102 let (result_tx, result_rx) = oneshot::channel::<(Option<i32>, TaskEventStopReason)>();
103 let (terminate_tx, terminate_rx) = oneshot::channel::<TaskTerminateReason>();
104 let (handle_terminator_tx, handle_terminator_rx) = watch::channel(false);
105
106 let handles = spawn_output_watchers(
108 self.task_name.clone(),
109 self.state.clone(),
110 event_tx.clone(),
111 &mut child,
112 handle_terminator_rx.clone(),
113 self.config.ready_indicator.clone(),
114 self.config.ready_indicator_source.clone(),
115 );
116 task_handles.extend(handles);
117
118 if let Some((stdin, stdin_rx)) = child.stdin.take().zip(self.stdin_rx.take()) {
120 let handle = spawn_stdin_watcher(stdin, stdin_rx, handle_terminator_rx.clone());
121 task_handles.push(handle);
122 }
123
124 *self.terminate_tx.lock().await = Some(terminate_tx);
126
127 let handle = spawn_wait_watcher(
128 self.task_name.clone(),
129 self.state.clone(),
130 child,
131 terminate_rx,
132 handle_terminator_tx.clone(),
133 result_tx,
134 self.process_id.clone(),
135 );
136 task_handles.push(handle);
137
138 if let Some(timeout_ms) = self.config.timeout_ms {
140 let handle =
141 spawn_timeout_watcher(self.terminate_tx.clone(), timeout_ms, handle_terminator_rx);
142 task_handles.push(handle);
143 }
144
145 let _handle = spawn_result_watcher(
147 self.task_name.clone(),
148 self.state.clone(),
149 self.finished_at.clone(),
150 event_tx,
151 result_rx,
152 task_handles,
153 );
154
155 Ok(child_id)
156 }
157}
158
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::sync::mpsc;
    use tokio::time::{sleep, timeout};

    use crate::tasks::{
        async_tokio::spawner::TaskSpawner,
        config::{StreamSource, TaskConfig},
        error::TaskError,
        event::{TaskEvent, TaskEventStopReason},
        state::{TaskState, TaskTerminateReason},
    };

    /// Longest time a test waits for any single event, or for the event channel to close.
    ///
    /// Generous because PowerShell start-up on Windows can be slow. Its only purpose
    /// is to turn a watcher that never releases its `event_tx` into a test failure
    /// instead of a test run that hangs forever.
    const EVENT_TIMEOUT: Duration = Duration::from_secs(30);

    /// Receives the next task event, or `None` once every sender has been dropped.
    ///
    /// # Panics
    ///
    /// Panics if nothing arrives on the channel within [`EVENT_TIMEOUT`].
    async fn next_event(rx: &mut mpsc::Receiver<TaskEvent>) -> Option<TaskEvent> {
        timeout(EVENT_TIMEOUT, rx.recv())
            .await
            .expect("timed out waiting for a task event")
    }

    #[tokio::test]
    async fn start_direct_ready_indicator_source_stdout() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Write-Output 'READY_INDICATOR'"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stdout);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "echo READY_INDICATOR"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stdout);

        let mut spawner = TaskSpawner::new("ready_stdout_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut ready_event = false;
        while let Some(event) = next_event(&mut rx).await {
            if let TaskEvent::Ready { task_name } = event {
                assert_eq!(task_name, "ready_stdout_task");
                ready_event = true;
            }
        }
        assert!(
            ready_event,
            "Should emit Ready event when indicator is in stdout"
        );
    }

    #[tokio::test]
    async fn start_direct_ready_indicator_source_stderr() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Write-Error 'READY_INDICATOR'"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "echo READY_INDICATOR 1>&2"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);

        let mut spawner = TaskSpawner::new("ready_stderr_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut ready_event = false;
        while let Some(event) = next_event(&mut rx).await {
            if let TaskEvent::Ready { task_name } = event {
                assert_eq!(task_name, "ready_stderr_task");
                ready_event = true;
            }
        }
        assert!(
            ready_event,
            "Should emit Ready event when indicator is in stderr"
        );
    }

    #[tokio::test]
    async fn start_direct_ready_indicator_source_mismatch() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Write-Output 'READY_INDICATOR'"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "echo READY_INDICATOR"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);

        let mut spawner = TaskSpawner::new("ready_mismatch_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut ready_event = false;
        while let Some(event) = next_event(&mut rx).await {
            if matches!(event, TaskEvent::Ready { .. }) {
                ready_event = true;
            }
        }
        assert!(
            !ready_event,
            "Should NOT emit Ready event if indicator is in wrong stream"
        );
    }

    #[tokio::test]
    async fn start_direct_fn_echo_command() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "echo hello"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "echo hello"]);

        let mut spawner = TaskSpawner::new("echo_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut stopped = false;
        while let Some(event) = next_event(&mut rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "echo_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(line, "hello");
                    assert_eq!(src, StreamSource::Stdout);
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(stopped);
    }

    #[tokio::test]
    async fn start_direct_timeout_terminated_task() {
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "sleep 2"])
            .timeout_ms(1);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "sleep 2"])
            .timeout_ms(1);

        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let mut spawner = TaskSpawner::new("sleep_with_timeout_task".into(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut stopped = false;
        while let Some(event) = next_event(&mut rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    started = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    reason,
                } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    assert_eq!(exit_code, None);
                    assert_eq!(
                        reason,
                        TaskEventStopReason::Terminated(TaskTerminateReason::Timeout)
                    );
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(stopped);
    }

    #[tokio::test]
    async fn start_direct_fn_invalid_empty_command() {
        let (tx, _rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("");
        let mut spawner = TaskSpawner::new("bad_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        let state = spawner.get_state().await;
        assert_eq!(
            state,
            TaskState::Finished,
            "TaskState should be Finished after error, not Initiating"
        );
    }

    #[tokio::test]
    async fn start_direct_fn_stdin_valid() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"])
            .enable_stdin(true);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "read line; echo $line"])
            .enable_stdin(true);

        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        stdin_tx.send("hello world".to_string()).await.unwrap();

        let mut started = false;
        let mut output_ok = false;
        let mut stopped = false;

        while let Some(event) = next_event(&mut rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(line, "hello world");
                    assert_eq!(src, StreamSource::Stdout);
                    output_ok = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(output_ok);
        assert!(stopped);
    }

    #[tokio::test]
    async fn start_direct_fn_stdin_ignore() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "read line; echo $line"]);

        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let send_result = stdin_tx.send("hello world".to_string()).await;
        assert!(
            send_result.is_err(),
            "Sending to stdin_tx should error because receiver is dropped"
        );

        let mut started = false;
        let mut output_found = false;
        let mut stopped = false;

        while let Some(event) = next_event(&mut rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output { .. } => {
                    output_found = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(
            !output_found,
            "Should not receive output from stdin when not enabled"
        );
        assert!(stopped);
    }

    #[tokio::test]
    async fn start_direct_command_not_found() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("non_existent_command");
        let mut spawner = TaskSpawner::new("error_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::IO(_))));

        match next_event(&mut rx).await {
            Some(TaskEvent::Error {
                task_name,
                error: TaskError::IO(msg),
            }) => {
                assert_eq!(task_name, "error_task");
                #[cfg(windows)]
                assert!(msg.contains("not found") || msg.contains("cannot find"));
                #[cfg(unix)]
                assert!(msg.contains("No such file or directory"));
            }
            Some(TaskEvent::Error { .. }) => panic!("Expected TaskError::IO"),
            _ => panic!("Expected TaskEvent::Error"),
        }
    }

    #[tokio::test]
    async fn start_direct_invalid_working_directory() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("echo").working_dir("/non/existent/directory");

        let mut spawner = TaskSpawner::new("working_dir_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        match next_event(&mut rx).await {
            Some(TaskEvent::Error { task_name, error }) => {
                assert_eq!(task_name, "working_dir_task");
                assert!(matches!(error, TaskError::InvalidConfiguration(_)));
            }
            _ => panic!("Expected TaskEvent::Error"),
        }
    }

    #[tokio::test]
    async fn start_direct_zero_timeout() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Start-Sleep -Seconds 1"])
            .timeout_ms(0);
        #[cfg(unix)]
        let config = TaskConfig::new("sleep").args(["1"]).timeout_ms(0);

        let mut spawner = TaskSpawner::new("timeout_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        match next_event(&mut rx).await {
            Some(TaskEvent::Error { task_name, error }) => {
                assert_eq!(task_name, "timeout_task");
                assert!(matches!(error, TaskError::InvalidConfiguration(_)));
            }
            _ => panic!("Expected TaskEvent::Error with InvalidConfiguration"),
        }
    }

    #[tokio::test]
    async fn process_id_is_none_after_task_stopped() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "echo done"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "echo done"]);

        let mut spawner = TaskSpawner::new("pid_test_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut stopped = false;
        while let Some(event) = next_event(&mut rx).await {
            if let TaskEvent::Stopped { task_name, .. } = event {
                assert_eq!(task_name, "pid_test_task");
                stopped = true;
                break;
            }
        }
        assert!(stopped, "Task should emit Stopped event");

        let pid = spawner.get_process_id().await;
        assert!(
            pid.is_none(),
            "process_id should be None after task is stopped"
        );
    }

    #[tokio::test]
    async fn process_id_is_some_while_task_running() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "Start-Sleep -Seconds 2"]);
        #[cfg(unix)]
        let config = TaskConfig::new("sleep").args(["2"]);

        let mut spawner = TaskSpawner::new("pid_running_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        // Give the child a moment to be well under way; it sleeps for 2 s.
        sleep(Duration::from_millis(500)).await;
        let pid = spawner.get_process_id().await;
        assert!(
            pid.is_some(),
            "process_id should be Some while task is running"
        );

        let mut stopped = false;
        while let Some(event) = next_event(&mut rx).await {
            if matches!(event, TaskEvent::Stopped { .. }) {
                stopped = true;
                break;
            }
        }
        assert!(stopped, "Task should emit Stopped event");
    }
}