1use tokio::process::Command;
2use tokio::sync::{mpsc, oneshot, watch};
3
4use crate::tasks::async_tokio::direct::command::setup_command;
5use crate::tasks::async_tokio::direct::watchers::input::spawn_stdin_watcher;
6use crate::tasks::async_tokio::direct::watchers::output::spawn_output_watchers;
7use crate::tasks::async_tokio::direct::watchers::result::spawn_result_watcher;
8use crate::tasks::async_tokio::direct::watchers::timeout::spawn_timeout_watcher;
9use crate::tasks::async_tokio::direct::watchers::wait::spawn_wait_watcher;
10use crate::tasks::async_tokio::spawner::TaskSpawner;
11use crate::tasks::error::TaskError;
12use crate::tasks::event::{TaskEvent, TaskEventStopReason, TaskTerminateReason};
13use crate::tasks::state::TaskState;
14
15impl TaskSpawner {
16 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, event_tx), fields(task_name = %self.task_name)))]
190 #[allow(clippy::too_many_lines)]
191 pub async fn start_direct(
192 &mut self,
193 event_tx: mpsc::Sender<TaskEvent>,
194 ) -> Result<u32, TaskError> {
195 self.update_state(TaskState::Initiating).await;
196
197 match self.config.validate() {
198 Ok(()) => {}
199 Err(e) => {
200 #[cfg(feature = "tracing")]
201 tracing::error!(error = %e, "Invalid task configuration");
202
203 self.update_state(TaskState::Finished).await;
204 let error_event = TaskEvent::Error {
205 task_name: self.task_name.clone(),
206 error: e.clone(),
207 };
208
209 if (event_tx.send(error_event).await).is_err() {
210 #[cfg(feature = "tracing")]
211 tracing::warn!("Event channel closed while sending TaskEvent::Error");
212 }
213 return Err(e);
214 }
215 }
216
217 let mut cmd = Command::new(&self.config.command);
218 let cmd = cmd.kill_on_drop(true);
219
220 setup_command(cmd, &self.config);
221 let mut child = match cmd.spawn() {
222 Ok(c) => c,
223 Err(e) => {
224 #[cfg(feature = "tracing")]
225 tracing::error!(error = %e, "Failed to spawn child process");
226
227 self.update_state(TaskState::Finished).await;
228 let error_event = TaskEvent::Error {
229 task_name: self.task_name.clone(),
230 error: TaskError::IO(e.to_string()),
231 };
232
233 if (event_tx.send(error_event).await).is_err() {
234 #[cfg(feature = "tracing")]
235 tracing::warn!("Event channel closed while sending TaskEvent::Error");
236 }
237
238 return Err(TaskError::IO(e.to_string()));
239 }
240 };
241 let Some(child_id) = child.id() else {
242 let msg = "Failed to get process id";
243
244 #[cfg(feature = "tracing")]
245 tracing::error!(msg);
246
247 self.update_state(TaskState::Finished).await;
248 let error_event = TaskEvent::Error {
249 task_name: self.task_name.clone(),
250 error: TaskError::Handle(msg.to_string()),
251 };
252
253 if (event_tx.send(error_event).await).is_err() {
254 #[cfg(feature = "tracing")]
255 tracing::warn!("Event channel closed while sending TaskEvent::Error");
256 }
257
258 return Err(TaskError::Handle(msg.to_string()));
259 };
260 *self.process_id.write().await = Some(child_id);
261 let mut task_handles = vec![];
262 self.update_state(TaskState::Running).await;
263 if (event_tx
264 .send(TaskEvent::Started {
265 task_name: self.task_name.clone(),
266 })
267 .await)
268 .is_err()
269 {
270 #[cfg(feature = "tracing")]
271 tracing::warn!("Event channel closed while sending TaskEvent::Started");
272 }
273
274 let (result_tx, result_rx) = oneshot::channel::<(Option<i32>, TaskEventStopReason)>();
275 let (terminate_tx, terminate_rx) = oneshot::channel::<TaskTerminateReason>();
276 let (handle_terminator_tx, handle_terminator_rx) = watch::channel(false);
277
278 let handles = spawn_output_watchers(
280 self.task_name.clone(),
281 self.state.clone(),
282 event_tx.clone(),
283 &mut child,
284 handle_terminator_rx.clone(),
285 self.config.ready_indicator.clone(),
286 self.config.ready_indicator_source.clone(),
287 );
288 task_handles.extend(handles);
289
290 if let Some((stdin, stdin_rx)) = child.stdin.take().zip(self.stdin_rx.take()) {
292 let handle = spawn_stdin_watcher(stdin, stdin_rx, handle_terminator_rx.clone());
293 task_handles.push(handle);
294 }
295
296 *self.terminate_tx.lock().await = Some(terminate_tx);
298
299 let handle = spawn_wait_watcher(
300 self.task_name.clone(),
301 self.state.clone(),
302 child,
303 terminate_rx,
304 handle_terminator_tx.clone(),
305 result_tx,
306 self.process_id.clone(),
307 );
308 task_handles.push(handle);
309
310 if let Some(timeout_ms) = self.config.timeout_ms {
312 let handle =
313 spawn_timeout_watcher(self.terminate_tx.clone(), timeout_ms, handle_terminator_rx);
314 task_handles.push(handle);
315 }
316
317 let _handle = spawn_result_watcher(
319 self.task_name.clone(),
320 self.state.clone(),
321 self.finished_at.clone(),
322 event_tx,
323 result_rx,
324 task_handles,
325 );
326
327 Ok(child_id)
328 }
329}
330
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::sync::mpsc;
    use tokio::time::{sleep, timeout};

    use crate::tasks::{
        async_tokio::spawner::TaskSpawner,
        config::{StreamSource, TaskConfig},
        error::TaskError,
        event::{TaskEvent, TaskEventStopReason, TaskTerminateReason},
        state::TaskState,
    };

    /// Upper bound on how long a test waits on a task's event stream. A watcher that never
    /// releases its sender then fails the test instead of hanging the whole suite.
    const EVENT_STREAM_TIMEOUT: Duration = Duration::from_secs(30);

    /// Receives events until every sender is dropped and returns them in arrival order.
    ///
    /// # Panics
    ///
    /// Panics if the channel is still open after [`EVENT_STREAM_TIMEOUT`].
    async fn drain_events(rx: &mut mpsc::Receiver<TaskEvent>) -> Vec<TaskEvent> {
        let mut events = Vec::new();
        timeout(EVENT_STREAM_TIMEOUT, async {
            while let Some(event) = rx.recv().await {
                events.push(event);
            }
        })
        .await
        .expect("event channel was not closed before the timeout");
        events
    }

    /// Receives events until the first [`TaskEvent::Stopped`] and returns it. Returns `None`
    /// if the channel closes first. Events received before it are discarded.
    ///
    /// # Panics
    ///
    /// Panics if neither happens within [`EVENT_STREAM_TIMEOUT`].
    async fn next_stopped(rx: &mut mpsc::Receiver<TaskEvent>) -> Option<TaskEvent> {
        timeout(EVENT_STREAM_TIMEOUT, async {
            while let Some(event) = rx.recv().await {
                if matches!(event, TaskEvent::Stopped { .. }) {
                    return Some(event);
                }
            }
            None
        })
        .await
        .expect("timed out waiting for TaskEvent::Stopped")
    }

    #[tokio::test]
    async fn start_direct_ready_indicator_source_stdout() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Write-Output 'READY_INDICATOR'"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stdout);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "echo READY_INDICATOR"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stdout);

        let mut spawner = TaskSpawner::new("ready_stdout_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut ready_event = false;
        for event in drain_events(&mut rx).await {
            if let TaskEvent::Ready { task_name } = event {
                assert_eq!(task_name, "ready_stdout_task");
                ready_event = true;
            }
        }
        assert!(
            ready_event,
            "Should emit Ready event when indicator is in stdout"
        );
    }

    #[tokio::test]
    async fn start_direct_ready_indicator_source_stderr() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Write-Error 'READY_INDICATOR'"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "echo READY_INDICATOR 1>&2"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);

        let mut spawner = TaskSpawner::new("ready_stderr_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut ready_event = false;
        for event in drain_events(&mut rx).await {
            if let TaskEvent::Ready { task_name } = event {
                assert_eq!(task_name, "ready_stderr_task");
                ready_event = true;
            }
        }
        assert!(
            ready_event,
            "Should emit Ready event when indicator is in stderr"
        );
    }

    #[tokio::test]
    async fn start_direct_ready_indicator_source_mismatch() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Write-Output 'READY_INDICATOR'"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "echo READY_INDICATOR"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);

        let mut spawner = TaskSpawner::new("ready_mismatch_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let ready_event = drain_events(&mut rx)
            .await
            .iter()
            .any(|event| matches!(event, TaskEvent::Ready { .. }));
        assert!(
            !ready_event,
            "Should NOT emit Ready event if indicator is in wrong stream"
        );
    }

    #[tokio::test]
    async fn start_direct_fn_echo_command() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "echo hello"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "echo hello"]);

        let mut spawner = TaskSpawner::new("echo_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut stopped = false;
        for event in drain_events(&mut rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "echo_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(line, "hello");
                    assert_eq!(src, StreamSource::Stdout);
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    reason: _,
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(stopped);
    }

    #[tokio::test]
    async fn start_direct_timeout_terminated_task() {
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "sleep 2"])
            .timeout_ms(1);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "sleep 2"])
            .timeout_ms(1);

        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let mut spawner = TaskSpawner::new("sleep_with_timeout_task".into(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut stopped = false;
        for event in drain_events(&mut rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    started = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    reason,
                } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    assert_eq!(exit_code, None);
                    assert_eq!(
                        reason,
                        TaskEventStopReason::Terminated(TaskTerminateReason::Timeout)
                    );
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(stopped);
    }

    #[tokio::test]
    async fn start_direct_fn_invalid_empty_command() {
        // Bound to `_rx` (not `_`) so the receiver stays open while the Error event is sent.
        let (tx, _rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("");
        let mut spawner = TaskSpawner::new("bad_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        let state = spawner.get_state().await;
        assert_eq!(
            state,
            TaskState::Finished,
            "TaskState should be Finished after error, not Initiating"
        );
    }

    #[tokio::test]
    async fn start_direct_fn_stdin_valid() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"])
            .enable_stdin(true);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "read line; echo $line"])
            .enable_stdin(true);

        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        stdin_tx.send("hello world".to_string()).await.unwrap();

        let mut started = false;
        let mut output_ok = false;
        let mut stopped = false;
        for event in drain_events(&mut rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(line, "hello world");
                    assert_eq!(src, StreamSource::Stdout);
                    output_ok = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(output_ok);
        assert!(stopped);
    }

    #[tokio::test]
    async fn start_direct_fn_stdin_ignore() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "read line; echo $line"]);

        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let send_result = stdin_tx.send("hello world".to_string()).await;
        assert!(
            send_result.is_err(),
            "Sending to stdin_tx should error because receiver is dropped"
        );

        let mut started = false;
        let mut output_found = false;
        let mut stopped = false;
        for event in drain_events(&mut rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output { .. } => {
                    output_found = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(
            !output_found,
            "Should not receive output from stdin when not enabled"
        );
        assert!(stopped);
    }

    #[tokio::test]
    async fn start_direct_command_not_found() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("non_existent_command");
        let mut spawner = TaskSpawner::new("error_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::IO(_))));

        let Some(TaskEvent::Error { task_name, error }) = rx.recv().await else {
            panic!("Expected TaskEvent::Error");
        };
        assert_eq!(task_name, "error_task");
        let TaskError::IO(msg) = error else {
            panic!("Expected TaskError::IO");
        };
        #[cfg(windows)]
        assert!(msg.contains("not found") || msg.contains("cannot find"));
        #[cfg(unix)]
        assert!(msg.contains("No such file or directory"));
    }

    #[tokio::test]
    async fn start_direct_invalid_working_directory() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("echo").working_dir("/non/existent/directory");

        let mut spawner = TaskSpawner::new("working_dir_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        let Some(TaskEvent::Error { task_name, error }) = rx.recv().await else {
            panic!("Expected TaskEvent::Error");
        };
        assert_eq!(task_name, "working_dir_task");
        assert!(matches!(error, TaskError::InvalidConfiguration(_)));
    }

    #[tokio::test]
    async fn start_direct_zero_timeout() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Start-Sleep -Seconds 1"])
            .timeout_ms(0);
        #[cfg(unix)]
        let config = TaskConfig::new("sleep").args(["1"]).timeout_ms(0);

        let mut spawner = TaskSpawner::new("timeout_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        let Some(TaskEvent::Error { task_name, error }) = rx.recv().await else {
            panic!("Expected TaskEvent::Error with InvalidConfiguration");
        };
        assert_eq!(task_name, "timeout_task");
        assert!(matches!(error, TaskError::InvalidConfiguration(_)));
    }

    #[tokio::test]
    async fn process_id_is_none_after_task_stopped() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "echo done"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "echo done"]);

        let mut spawner = TaskSpawner::new("pid_test_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let Some(TaskEvent::Stopped { task_name, .. }) = next_stopped(&mut rx).await else {
            panic!("Task should emit Stopped event");
        };
        assert_eq!(task_name, "pid_test_task");

        let pid = spawner.get_process_id().await;
        assert!(
            pid.is_none(),
            "process_id should be None after task is stopped"
        );
    }

    #[tokio::test]
    async fn process_id_is_some_while_task_running() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "Start-Sleep -Seconds 2"]);
        #[cfg(unix)]
        let config = TaskConfig::new("sleep").args(["2"]);

        let mut spawner = TaskSpawner::new("pid_running_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        sleep(Duration::from_millis(500)).await;
        let pid = spawner.get_process_id().await;
        assert!(
            pid.is_some(),
            "process_id should be Some while task is running"
        );

        // Let the task run to completion before the test returns.
        next_stopped(&mut rx).await;
    }
}