1use tokio::process::Command;
2use tokio::sync::{mpsc, oneshot, watch};
3
4use crate::tasks::async_tokio::direct::command::setup_command;
5use crate::tasks::async_tokio::direct::watchers::input::spawn_stdin_watcher;
6use crate::tasks::async_tokio::direct::watchers::output::spawn_output_watchers;
7use crate::tasks::async_tokio::direct::watchers::result::spawn_result_watcher;
8use crate::tasks::async_tokio::direct::watchers::timeout::spawn_timeout_watcher;
9use crate::tasks::async_tokio::direct::watchers::wait::spawn_wait_watcher;
10use crate::tasks::async_tokio::process_group::ProcessGroup;
11use crate::tasks::async_tokio::spawner::TaskSpawner;
12use crate::tasks::error::TaskError;
13use crate::tasks::event::{TaskEvent, TaskEventStopReason, TaskTerminateReason};
14use crate::tasks::state::TaskState;
15
16impl TaskSpawner {
17 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, event_tx), fields(task_name = %self.task_name)))]
187 #[allow(clippy::too_many_lines)]
188 pub async fn start_direct(
189 &mut self,
190 event_tx: mpsc::Sender<TaskEvent>,
191 ) -> Result<u32, TaskError> {
192 self.update_state(TaskState::Initiating).await;
193
194 match self.config.validate() {
195 Ok(()) => {}
196 Err(e) => {
197 #[cfg(feature = "tracing")]
198 tracing::error!(error = %e, "Invalid task configuration");
199
200 self.update_state(TaskState::Finished).await;
201 let error_event = TaskEvent::Error {
202 task_name: self.task_name.clone(),
203 error: e.clone(),
204 };
205
206 if (event_tx.send(error_event).await).is_err() {
207 #[cfg(feature = "tracing")]
208 tracing::warn!("Event channel closed while sending TaskEvent::Error");
209 }
210 return Err(e);
211 }
212 }
213
214 let mut cmd = Command::new(&self.config.command);
215 cmd.kill_on_drop(true);
216
217 setup_command(&mut cmd, &self.config);
218
219 let (mut configured_cmd, process_group) = if self.config.is_process_group_enabled() {
221 match ProcessGroup::create_with_command(cmd) {
222 Ok((cmd, group)) => (cmd, Some(group)),
223 Err(e) => {
224 #[cfg(feature = "tracing")]
225 tracing::error!(error = %e, "Failed to create process group");
226
227 self.update_state(TaskState::Finished).await;
228 let error_event = TaskEvent::Error {
229 task_name: self.task_name.clone(),
230 error: TaskError::Handle(format!("Failed to create process group: {}", e)),
231 };
232
233 if (event_tx.send(error_event).await).is_err() {
234 #[cfg(feature = "tracing")]
235 tracing::warn!("Event channel closed while sending TaskEvent::Error");
236 }
237
238 return Err(TaskError::Handle(format!(
239 "Failed to create process group: {}",
240 e
241 )));
242 }
243 }
244 } else {
245 #[cfg(feature = "tracing")]
246 tracing::debug!("Process group management disabled by configuration");
247 (cmd, None)
248 };
249
250 let mut child = match configured_cmd.spawn() {
251 Ok(c) => c,
252 Err(e) => {
253 #[cfg(feature = "tracing")]
254 tracing::error!(error = %e, "Failed to spawn child process");
255
256 self.update_state(TaskState::Finished).await;
257 let error_event = TaskEvent::Error {
258 task_name: self.task_name.clone(),
259 error: TaskError::IO(e.to_string()),
260 };
261
262 if (event_tx.send(error_event).await).is_err() {
263 #[cfg(feature = "tracing")]
264 tracing::warn!("Event channel closed while sending TaskEvent::Error");
265 }
266
267 return Err(TaskError::IO(e.to_string()));
268 }
269 };
270
271 if let Some(ref pg) = process_group {
273 if let Err(e) = pg.assign_child(&child).await {
274 #[cfg(feature = "tracing")]
275 tracing::error!(error = %e, "Failed to assign child to process group");
276
277 self.update_state(TaskState::Finished).await;
278 let error_event = TaskEvent::Error {
279 task_name: self.task_name.clone(),
280 error: TaskError::Handle(format!(
281 "Failed to assign child to process group: {}",
282 e
283 )),
284 };
285
286 if (event_tx.send(error_event).await).is_err() {
287 #[cfg(feature = "tracing")]
288 tracing::warn!("Event channel closed while sending TaskEvent::Error");
289 }
290
291 return Err(TaskError::Handle(format!(
292 "Failed to assign child to process group: {}",
293 e
294 )));
295 }
296 }
297 let Some(child_id) = child.id() else {
298 let msg = "Failed to get process id";
299
300 #[cfg(feature = "tracing")]
301 tracing::error!(msg);
302
303 self.update_state(TaskState::Finished).await;
304 let error_event = TaskEvent::Error {
305 task_name: self.task_name.clone(),
306 error: TaskError::Handle(msg.to_string()),
307 };
308
309 if (event_tx.send(error_event).await).is_err() {
310 #[cfg(feature = "tracing")]
311 tracing::warn!("Event channel closed while sending TaskEvent::Error");
312 }
313
314 return Err(TaskError::Handle(msg.to_string()));
315 };
316 *self.process_id.write().await = Some(child_id);
317 let mut task_handles = vec![];
318 self.update_state(TaskState::Running).await;
319 if (event_tx
320 .send(TaskEvent::Started {
321 task_name: self.task_name.clone(),
322 })
323 .await)
324 .is_err()
325 {
326 #[cfg(feature = "tracing")]
327 tracing::warn!("Event channel closed while sending TaskEvent::Started");
328 }
329
330 let (result_tx, result_rx) = oneshot::channel::<(Option<i32>, TaskEventStopReason)>();
331 let (terminate_tx, terminate_rx) = oneshot::channel::<TaskTerminateReason>();
332 let (handle_terminator_tx, handle_terminator_rx) = watch::channel(false);
333
334 let handles = spawn_output_watchers(
336 self.task_name.clone(),
337 self.state.clone(),
338 event_tx.clone(),
339 &mut child,
340 handle_terminator_rx.clone(),
341 self.config.ready_indicator.clone(),
342 self.config.ready_indicator_source.clone(),
343 );
344 task_handles.extend(handles);
345
346 if let Some((stdin, stdin_rx)) = child.stdin.take().zip(self.stdin_rx.take()) {
348 let handle = spawn_stdin_watcher(stdin, stdin_rx, handle_terminator_rx.clone());
349 task_handles.push(handle);
350 }
351
352 *self.terminate_tx.lock().await = Some(terminate_tx);
354
355 let handle = spawn_wait_watcher(
356 self.task_name.clone(),
357 self.state.clone(),
358 child,
359 process_group,
360 terminate_rx,
361 handle_terminator_tx.clone(),
362 result_tx,
363 self.process_id.clone(),
364 );
365 task_handles.push(handle);
366
367 if let Some(timeout_ms) = self.config.timeout_ms {
369 let handle =
370 spawn_timeout_watcher(self.terminate_tx.clone(), timeout_ms, handle_terminator_rx);
371 task_handles.push(handle);
372 }
373
374 let _handle = spawn_result_watcher(
376 self.task_name.clone(),
377 self.state.clone(),
378 self.finished_at.clone(),
379 event_tx,
380 result_rx,
381 task_handles,
382 );
383
384 Ok(child_id)
385 }
386}
387
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::sync::mpsc;
    use tokio::time::sleep;

    use crate::tasks::{
        async_tokio::spawner::TaskSpawner,
        config::{StreamSource, TaskConfig},
        error::TaskError,
        event::{TaskEvent, TaskEventStopReason, TaskTerminateReason},
        state::TaskState,
    };

    /// A ready indicator printed on stdout triggers `TaskEvent::Ready` when stdout is the
    /// configured source.
    #[tokio::test]
    async fn start_direct_ready_indicator_source_stdout() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Write-Output 'READY_INDICATOR'"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stdout);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "echo READY_INDICATOR"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stdout);

        let mut spawner = TaskSpawner::new("ready_stdout_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut ready_event = false;
        while let Some(event) = rx.recv().await {
            if let TaskEvent::Ready { task_name } = event {
                assert_eq!(task_name, "ready_stdout_task");
                ready_event = true;
            }
        }
        assert!(
            ready_event,
            "Should emit Ready event when indicator is in stdout"
        );
    }

    /// A ready indicator printed on stderr triggers `TaskEvent::Ready` when stderr is the
    /// configured source.
    #[tokio::test]
    async fn start_direct_ready_indicator_source_stderr() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Write-Error 'READY_INDICATOR'"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "echo READY_INDICATOR 1>&2"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);

        let mut spawner = TaskSpawner::new("ready_stderr_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut ready_event = false;
        while let Some(event) = rx.recv().await {
            if let TaskEvent::Ready { task_name } = event {
                assert_eq!(task_name, "ready_stderr_task");
                ready_event = true;
            }
        }
        assert!(
            ready_event,
            "Should emit Ready event when indicator is in stderr"
        );
    }

    /// A ready indicator on stdout must not trigger `TaskEvent::Ready` when stderr is the
    /// configured source.
    #[tokio::test]
    async fn start_direct_ready_indicator_source_mismatch() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Write-Output 'READY_INDICATOR'"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "echo READY_INDICATOR"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);

        let mut spawner = TaskSpawner::new("ready_mismatch_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut ready_event = false;
        while let Some(event) = rx.recv().await {
            if matches!(event, TaskEvent::Ready { .. }) {
                ready_event = true;
            }
        }
        assert!(
            !ready_event,
            "Should NOT emit Ready event if indicator is in wrong stream"
        );
    }

    /// A simple echo emits Started, a single stdout Output line and a Stopped event with
    /// exit code 0.
    #[tokio::test]
    async fn start_direct_fn_echo_command() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "echo hello"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "echo hello"]);

        let mut spawner = TaskSpawner::new("echo_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut stopped = false;
        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "echo_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(line, "hello");
                    assert_eq!(src, StreamSource::Stdout);
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(stopped);
    }

    /// A task that outlives its timeout is stopped with the Timeout termination reason
    /// and no exit code.
    #[tokio::test]
    async fn start_direct_timeout_terminated_task() {
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "sleep 2"])
            .timeout_ms(1);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "sleep 2"])
            .timeout_ms(1);

        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let mut spawner = TaskSpawner::new("sleep_with_timeout_task".into(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut stopped = false;
        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    started = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    reason,
                } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    assert_eq!(exit_code, None);
                    assert_eq!(
                        reason,
                        TaskEventStopReason::Terminated(TaskTerminateReason::Timeout)
                    );
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(stopped);
    }

    /// An empty command is rejected as an invalid configuration, and the state ends up
    /// Finished rather than stuck in Initiating.
    #[tokio::test]
    async fn start_direct_fn_invalid_empty_command() {
        let (tx, _rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("");
        let mut spawner = TaskSpawner::new("bad_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        let state = spawner.get_state().await;
        assert_eq!(
            state,
            TaskState::Finished,
            "TaskState should be Finished after error, not Initiating"
        );
    }

    /// With stdin enabled, a line sent through the stdin channel is echoed back on stdout.
    #[tokio::test]
    async fn start_direct_fn_stdin_valid() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"])
            .enable_stdin(true);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "read line; echo $line"])
            .enable_stdin(true);

        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        stdin_tx.send("hello world".to_string()).await.unwrap();

        let mut started = false;
        let mut output_ok = false;
        let mut stopped = false;

        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(line, "hello world");
                    assert_eq!(src, StreamSource::Stdout);
                    output_ok = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(output_ok);
        assert!(stopped);
    }

    /// Without `enable_stdin`, the stdin receiver is dropped at start, so sends fail and
    /// no output is produced.
    #[tokio::test]
    async fn start_direct_fn_stdin_ignore() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "read line; echo $line"]);

        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let send_result = stdin_tx.send("hello world".to_string()).await;
        assert!(
            send_result.is_err(),
            "Sending to stdin_tx should error because receiver is dropped"
        );

        let mut started = false;
        let mut output_found = false;
        let mut stopped = false;

        while let Some(event) = rx.recv().await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output { .. } => {
                    output_found = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(
            !output_found,
            "Should not receive output from stdin when not enabled"
        );
        assert!(stopped);
    }

    /// A non-existent executable fails to spawn with an IO error, which is also reported
    /// through a `TaskEvent::Error`.
    #[tokio::test]
    async fn start_direct_command_not_found() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("non_existent_command");
        let mut spawner = TaskSpawner::new("error_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::IO(_))));

        let Some(TaskEvent::Error { task_name, error }) = rx.recv().await else {
            panic!("Expected TaskEvent::Error");
        };
        assert_eq!(task_name, "error_task");
        let TaskError::IO(msg) = error else {
            panic!("Expected TaskError::IO in the error event");
        };
        #[cfg(windows)]
        assert!(msg.contains("not found") || msg.contains("cannot find"));
        #[cfg(unix)]
        assert!(msg.contains("No such file or directory"));
    }

    /// A non-existent working directory is rejected as an invalid configuration and
    /// reported through a `TaskEvent::Error`.
    #[tokio::test]
    async fn start_direct_invalid_working_directory() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("echo").working_dir("/non/existent/directory");

        let mut spawner = TaskSpawner::new("working_dir_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        let Some(TaskEvent::Error { task_name, error }) = rx.recv().await else {
            panic!("Expected TaskEvent::Error");
        };
        assert_eq!(task_name, "working_dir_task");
        assert!(matches!(error, TaskError::InvalidConfiguration(_)));
    }

    /// A zero timeout is rejected as an invalid configuration and reported through a
    /// `TaskEvent::Error`.
    #[tokio::test]
    async fn start_direct_zero_timeout() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Start-Sleep -Seconds 1"])
            .timeout_ms(0);
        #[cfg(unix)]
        let config = TaskConfig::new("sleep").args(["1"]).timeout_ms(0);

        let mut spawner = TaskSpawner::new("timeout_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        let Some(TaskEvent::Error { task_name, error }) = rx.recv().await else {
            panic!("Expected TaskEvent::Error with InvalidConfiguration");
        };
        assert_eq!(task_name, "timeout_task");
        assert!(matches!(error, TaskError::InvalidConfiguration(_)));
    }

    /// Once the Stopped event has been observed, the stored process id is cleared.
    #[tokio::test]
    async fn process_id_is_none_after_task_stopped() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "echo done"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "echo done"]);

        let mut spawner = TaskSpawner::new("pid_test_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut stopped = false;
        while let Some(event) = rx.recv().await {
            if let TaskEvent::Stopped { task_name, .. } = event {
                assert_eq!(task_name, "pid_test_task");
                stopped = true;
                break;
            }
        }
        assert!(stopped, "Task should emit Stopped event");

        let pid = spawner.get_process_id().await;
        assert!(
            pid.is_none(),
            "process_id should be None after task is stopped"
        );
    }

    /// While the child is still running, the stored process id is populated.
    #[tokio::test]
    async fn process_id_is_some_while_task_running() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "Start-Sleep -Seconds 2"]);
        #[cfg(unix)]
        let config = TaskConfig::new("sleep").args(["2"]);

        let mut spawner = TaskSpawner::new("pid_running_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        sleep(Duration::from_millis(500)).await;

        let pid = spawner.get_process_id().await;
        assert!(
            pid.is_some(),
            "process_id should be Some while task is running"
        );

        // Drain events until the child exits so the task does not outlive the test.
        while let Some(event) = rx.recv().await {
            if matches!(event, TaskEvent::Stopped { .. }) {
                break;
            }
        }
    }
}