1use tokio::process::Command;
2use tokio::sync::{mpsc, oneshot, watch};
3
4use crate::tasks::async_tokio::direct::command::setup_command;
5use crate::tasks::async_tokio::direct::watchers::input::spawn_stdin_watcher;
6use crate::tasks::async_tokio::direct::watchers::output::spawn_output_watchers;
7use crate::tasks::async_tokio::direct::watchers::result::spawn_result_watcher;
8use crate::tasks::async_tokio::direct::watchers::timeout::spawn_timeout_watcher;
9use crate::tasks::async_tokio::direct::watchers::wait::spawn_wait_watcher;
10use crate::tasks::async_tokio::spawner::TaskSpawner;
11use crate::tasks::error::TaskError;
12use crate::tasks::event::{TaskEvent, TaskEventStopReason};
13use crate::tasks::state::{TaskState, TaskTerminateReason};
14
15impl TaskSpawner {
16 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, event_tx), fields(task_name = %self.task_name)))]
190 #[allow(clippy::too_many_lines)]
191 pub async fn start_direct(
192 &mut self,
193 event_tx: mpsc::Sender<TaskEvent>,
194 ) -> Result<u32, TaskError> {
195 self.update_state(TaskState::Initiating).await;
196
197 match self.config.validate() {
198 Ok(()) => {}
199 Err(e) => {
200 #[cfg(feature = "tracing")]
201 tracing::error!(error = %e, "Invalid task configuration");
202
203 self.update_state(TaskState::Finished).await;
204 let error_event = TaskEvent::Error {
205 task_name: self.task_name.clone(),
206 error: e.clone(),
207 };
208
209 if (event_tx.send(error_event).await).is_err() {
210 #[cfg(feature = "tracing")]
211 tracing::warn!("Event channel closed while sending TaskEvent::Error");
212 }
213 return Err(e);
214 }
215 }
216
217 let mut cmd = Command::new(&self.config.command);
218 let cmd = cmd.kill_on_drop(true);
219
220 setup_command(cmd, &self.config);
221 let mut child = match cmd.spawn() {
222 Ok(c) => c,
223 Err(e) => {
224 #[cfg(feature = "tracing")]
225 tracing::error!(error = %e, "Failed to spawn child process");
226
227 self.update_state(TaskState::Finished).await;
228 let error_event = TaskEvent::Error {
229 task_name: self.task_name.clone(),
230 error: TaskError::IO(e.to_string()),
231 };
232
233 if (event_tx.send(error_event).await).is_err() {
234 #[cfg(feature = "tracing")]
235 tracing::warn!("Event channel closed while sending TaskEvent::Error");
236 }
237
238 return Err(TaskError::IO(e.to_string()));
239 }
240 };
241 let Some(child_id) = child.id() else {
242 let msg = "Failed to get process id";
243
244 #[cfg(feature = "tracing")]
245 tracing::error!(msg);
246
247 self.update_state(TaskState::Finished).await;
248 let error_event = TaskEvent::Error {
249 task_name: self.task_name.clone(),
250 error: TaskError::IO(msg.to_string()),
251 };
252
253 if (event_tx.send(error_event).await).is_err() {
254 #[cfg(feature = "tracing")]
255 tracing::warn!("Event channel closed while sending TaskEvent::Error");
256 }
257
258 return Err(TaskError::IO(msg.to_string()));
259 };
260 *self.process_id.write().await = Some(child_id);
261 let mut task_handles = vec![];
262 self.update_state(TaskState::Running).await;
263 if (event_tx
264 .send(TaskEvent::Started {
265 task_name: self.task_name.clone(),
266 })
267 .await)
268 .is_err()
269 {
270 #[cfg(feature = "tracing")]
271 tracing::warn!("Event channel closed while sending TaskEvent::Started");
272 }
273
274 let (result_tx, result_rx) = oneshot::channel::<(Option<i32>, TaskEventStopReason)>();
275 let (terminate_tx, terminate_rx) = oneshot::channel::<TaskTerminateReason>();
276 let (handle_terminator_tx, handle_terminator_rx) = watch::channel(false);
277
278 let handles = spawn_output_watchers(
280 self.task_name.clone(),
281 self.state.clone(),
282 event_tx.clone(),
283 &mut child,
284 handle_terminator_rx.clone(),
285 self.config.ready_indicator.clone(),
286 self.config.ready_indicator_source.clone(),
287 );
288 task_handles.extend(handles);
289
290 if let Some((stdin, stdin_rx)) = child.stdin.take().zip(self.stdin_rx.take()) {
292 let handle = spawn_stdin_watcher(stdin, stdin_rx, handle_terminator_rx.clone());
293 task_handles.push(handle);
294 }
295
296 *self.terminate_tx.lock().await = Some(terminate_tx);
298
299 let handle = spawn_wait_watcher(
300 self.task_name.clone(),
301 self.state.clone(),
302 child,
303 terminate_rx,
304 handle_terminator_tx.clone(),
305 result_tx,
306 self.process_id.clone(),
307 );
308 task_handles.push(handle);
309
310 if let Some(timeout_ms) = self.config.timeout_ms {
312 let handle =
313 spawn_timeout_watcher(self.terminate_tx.clone(), timeout_ms, handle_terminator_rx);
314 task_handles.push(handle);
315 }
316
317 let _handle = spawn_result_watcher(
319 self.task_name.clone(),
320 self.state.clone(),
321 self.finished_at.clone(),
322 event_tx,
323 result_rx,
324 task_handles,
325 );
326
327 Ok(child_id)
328 }
329}
330
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::sync::mpsc;
    use tokio::time::{sleep, timeout};

    use crate::tasks::{
        async_tokio::spawner::TaskSpawner,
        config::{StreamSource, TaskConfig},
        error::TaskError,
        event::{TaskEvent, TaskEventStopReason},
        state::{TaskState, TaskTerminateReason},
    };

    /// Upper bound for waiting on the event channel. The longest child
    /// process used below sleeps for two seconds, so hitting this limit means
    /// the channel was never closed rather than that the task was slow.
    const DRAIN_TIMEOUT: Duration = Duration::from_secs(30);

    /// Collects every event until the channel is closed, i.e. until all
    /// senders have been dropped.
    ///
    /// Panics after `DRAIN_TIMEOUT` so a leaked sender fails the test instead
    /// of hanging the whole test run.
    async fn drain_events(mut rx: mpsc::Receiver<TaskEvent>) -> Vec<TaskEvent> {
        let drain = async {
            let mut events = Vec::new();
            while let Some(event) = rx.recv().await {
                events.push(event);
            }
            events
        };

        timeout(DRAIN_TIMEOUT, drain)
            .await
            .expect("event channel was not closed in time")
    }

    /// Receives events until the first `TaskEvent::Stopped` and returns it;
    /// returns `None` when the channel closes first.
    ///
    /// Panics after `DRAIN_TIMEOUT` for the same reason as `drain_events`.
    async fn next_stopped(rx: &mut mpsc::Receiver<TaskEvent>) -> Option<TaskEvent> {
        let wait = async {
            while let Some(event) = rx.recv().await {
                if matches!(event, TaskEvent::Stopped { .. }) {
                    return Some(event);
                }
            }
            None
        };

        timeout(DRAIN_TIMEOUT, wait)
            .await
            .expect("no Stopped event and no channel close in time")
    }

    /// The indicator is printed on stdout and stdout is the configured
    /// source, so a `Ready` event is expected.
    #[tokio::test]
    async fn start_direct_ready_indicator_source_stdout() {
        let (tx, rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Write-Output 'READY_INDICATOR'"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stdout);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "echo READY_INDICATOR"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stdout);

        let mut spawner = TaskSpawner::new("ready_stdout_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut ready_event = false;
        for event in drain_events(rx).await {
            if let TaskEvent::Ready { task_name } = event {
                assert_eq!(task_name, "ready_stdout_task");
                ready_event = true;
            }
        }
        assert!(
            ready_event,
            "Should emit Ready event when indicator is in stdout"
        );
    }

    /// The indicator is printed on stderr and stderr is the configured
    /// source, so a `Ready` event is expected.
    #[tokio::test]
    async fn start_direct_ready_indicator_source_stderr() {
        let (tx, rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Write-Error 'READY_INDICATOR'"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "echo READY_INDICATOR 1>&2"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);

        let mut spawner = TaskSpawner::new("ready_stderr_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut ready_event = false;
        for event in drain_events(rx).await {
            if let TaskEvent::Ready { task_name } = event {
                assert_eq!(task_name, "ready_stderr_task");
                ready_event = true;
            }
        }
        assert!(
            ready_event,
            "Should emit Ready event when indicator is in stderr"
        );
    }

    /// The indicator is printed on stdout while stderr is the configured
    /// source, so no `Ready` event may be emitted.
    #[tokio::test]
    async fn start_direct_ready_indicator_source_mismatch() {
        let (tx, rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Write-Output 'READY_INDICATOR'"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "echo READY_INDICATOR"])
            .ready_indicator("READY_INDICATOR".to_string())
            .ready_indicator_source(StreamSource::Stderr);

        let mut spawner = TaskSpawner::new("ready_mismatch_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut ready_event = false;
        for event in drain_events(rx).await {
            if let TaskEvent::Ready { .. } = event {
                ready_event = true;
            }
        }
        assert!(
            !ready_event,
            "Should NOT emit Ready event if indicator is in wrong stream"
        );
    }

    /// A plain `echo` must yield `Started`, the echoed line on stdout and
    /// `Stopped` with exit code 0.
    #[tokio::test]
    async fn start_direct_fn_echo_command() {
        let (tx, rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "echo hello"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "echo hello"]);

        let mut spawner = TaskSpawner::new("echo_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut stopped = false;
        for event in drain_events(rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "echo_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(line, "hello");
                    assert_eq!(src, StreamSource::Stdout);
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    reason: _,
                } => {
                    assert_eq!(task_name, "echo_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(stopped);
    }

    /// A 1 ms timeout on a two second sleep must stop the task with the
    /// `Timeout` termination reason and no exit code.
    #[tokio::test]
    async fn start_direct_timeout_terminated_task() {
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "sleep 2"])
            .timeout_ms(1);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "sleep 2"])
            .timeout_ms(1);

        let (tx, rx) = mpsc::channel::<TaskEvent>(1024);
        let mut spawner = TaskSpawner::new("sleep_with_timeout_task".into(), config);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let mut started = false;
        let mut stopped = false;
        for event in drain_events(rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    started = true;
                }

                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    reason,
                } => {
                    assert_eq!(task_name, "sleep_with_timeout_task");
                    assert_eq!(exit_code, None);
                    assert_eq!(
                        reason,
                        TaskEventStopReason::Terminated(TaskTerminateReason::Timeout)
                    );
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(stopped);
    }

    /// An empty command must be rejected as an invalid configuration and
    /// leave the task in the `Finished` state.
    #[tokio::test]
    async fn start_direct_fn_invalid_empty_command() {
        let (tx, _rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("");
        let mut spawner = TaskSpawner::new("bad_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        let state = spawner.get_state().await;
        assert_eq!(
            state,
            TaskState::Finished,
            "TaskState should be Finished after error, not Initiating"
        );
    }

    /// With stdin enabled, a line sent through the stdin channel must come
    /// back as stdout output of the child.
    #[tokio::test]
    async fn start_direct_fn_stdin_valid() {
        let (tx, rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"])
            .enable_stdin(true);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "read line; echo $line"])
            .enable_stdin(true);

        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        stdin_tx.send("hello world".to_string()).await.unwrap();

        let mut started = false;
        let mut output_ok = false;
        let mut stopped = false;

        for event in drain_events(rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output {
                    task_name,
                    line,
                    src,
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(line, "hello world");
                    assert_eq!(src, StreamSource::Stdout);
                    output_ok = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(output_ok);
        assert!(stopped);
    }

    /// Without `enable_stdin`, the stdin receiver handed to the spawner is
    /// expected to be dropped during start, so sending fails and the child
    /// produces no output.
    #[tokio::test]
    async fn start_direct_fn_stdin_ignore() {
        let (tx, rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "read line; echo $line"]);

        let mut spawner = TaskSpawner::new("stdin_task".to_string(), config).set_stdin(stdin_rx);

        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let send_result = stdin_tx.send("hello world".to_string()).await;
        assert!(
            send_result.is_err(),
            "Sending to stdin_tx should error because receiver is dropped"
        );

        let mut started = false;
        let mut output_found = false;
        let mut stopped = false;

        for event in drain_events(rx).await {
            match event {
                TaskEvent::Started { task_name } => {
                    assert_eq!(task_name, "stdin_task");
                    started = true;
                }
                TaskEvent::Output { .. } => {
                    output_found = true;
                }
                TaskEvent::Stopped {
                    task_name,
                    exit_code,
                    ..
                } => {
                    assert_eq!(task_name, "stdin_task");
                    assert_eq!(exit_code, Some(0));
                    stopped = true;
                }
                _ => {}
            }
        }

        assert!(started);
        assert!(
            !output_found,
            "Should not receive output from stdin when not enabled"
        );
        assert!(stopped);
    }

    /// Spawning a command that does not exist must fail with `TaskError::IO`
    /// and report the same kind of error as an event.
    #[tokio::test]
    async fn start_direct_command_not_found() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("non_existent_command");
        let mut spawner = TaskSpawner::new("error_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::IO(_))));

        if let Some(TaskEvent::Error { task_name, error }) = rx.recv().await {
            assert_eq!(task_name, "error_task");
            assert!(matches!(error, TaskError::IO(_)));
            if let TaskError::IO(msg) = error {
                #[cfg(windows)]
                assert!(msg.contains("not found") || msg.contains("cannot find"));
                #[cfg(unix)]
                assert!(msg.contains("No such file or directory"));
            }
        } else {
            panic!("Expected TaskEvent::Error");
        }
    }

    /// A working directory that does not exist must be rejected as an
    /// invalid configuration and reported as an error event.
    #[tokio::test]
    async fn start_direct_invalid_working_directory() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("echo").working_dir("/non/existent/directory");

        let mut spawner = TaskSpawner::new("working_dir_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        if let Some(TaskEvent::Error { task_name, error }) = rx.recv().await {
            assert_eq!(task_name, "working_dir_task");
            assert!(matches!(error, TaskError::InvalidConfiguration(_)));
        } else {
            panic!("Expected TaskEvent::Error");
        }
    }

    /// A timeout of zero milliseconds must be rejected as an invalid
    /// configuration and reported as an error event.
    #[tokio::test]
    async fn start_direct_zero_timeout() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "Start-Sleep -Seconds 1"])
            .timeout_ms(0);
        #[cfg(unix)]
        let config = TaskConfig::new("sleep").args(["1"]).timeout_ms(0);

        let mut spawner = TaskSpawner::new("timeout_task".to_string(), config);

        let result = spawner.start_direct(tx).await;
        assert!(matches!(result, Err(TaskError::InvalidConfiguration(_))));

        if let Some(TaskEvent::Error { task_name, error }) = rx.recv().await {
            assert_eq!(task_name, "timeout_task");
            assert!(matches!(error, TaskError::InvalidConfiguration(_)));
        } else {
            panic!("Expected TaskEvent::Error with InvalidConfiguration");
        }
    }

    /// Once `Stopped` has been observed, the spawner must no longer report a
    /// process id.
    #[tokio::test]
    async fn process_id_is_none_after_task_stopped() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "echo done"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "echo done"]);

        let mut spawner = TaskSpawner::new("pid_test_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        let Some(TaskEvent::Stopped { task_name, .. }) = next_stopped(&mut rx).await else {
            panic!("Task should emit Stopped event");
        };
        assert_eq!(task_name, "pid_test_task");

        let pid = spawner.get_process_id().await;
        assert!(
            pid.is_none(),
            "process_id should be None after task is stopped"
        );
    }

    /// While the child is still sleeping, the spawner must report a process id.
    #[tokio::test]
    async fn process_id_is_some_while_task_running() {
        let (tx, mut rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "Start-Sleep -Seconds 2"]);
        #[cfg(unix)]
        let config = TaskConfig::new("sleep").args(["2"]);

        let mut spawner = TaskSpawner::new("pid_running_task".to_string(), config);
        let result = spawner.start_direct(tx).await;
        assert!(result.is_ok());

        // Check well inside the two second lifetime of the child.
        sleep(Duration::from_millis(500)).await;
        let pid = spawner.get_process_id().await;
        assert!(
            pid.is_some(),
            "process_id should be Some while task is running"
        );

        // Only wait for the task to wind down; whether `Stopped` arrived is
        // not part of this test.
        let _ = next_stopped(&mut rx).await;
    }
}