1use std::collections::HashSet;
8
9use tcrm_task::tasks::{
10 event::{TaskEvent, TaskEventStopReason, TaskTerminateReason},
11 state::TaskState,
12};
13use tokio::sync::mpsc::{self, Sender};
14
15use crate::monitor::{
16 error::{ControlCommandError, SendStdinErrorReason, TaskMonitorError},
17 event::{TaskMonitorControlCommand, TaskMonitorControlEvent, TaskMonitorEvent},
18 tasks::TaskMonitor,
19};
20
21impl TaskMonitor {
22 #[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
87 pub async fn execute_all_direct(&mut self, event_tx: Option<Sender<TaskEvent>>) {
88 let (task_event_tx, mut task_event_rx) = mpsc::channel::<TaskEvent>(1024);
89 self.start_independent_tasks_direct(&task_event_tx).await;
90
91 let mut active_tasks: HashSet<String> = self.tasks_spawner.keys().cloned().collect();
93
94 while let Some(event) = task_event_rx.recv().await {
95 if let Some(ref tx) = event_tx
96 && let Err(_e) = tx.send(event.clone()).await
97 {
98 #[cfg(feature = "tracing")]
99 tracing::warn!(event = ?event, "Failed to forward event");
100 }
101 match event {
102 TaskEvent::Started { .. } | TaskEvent::Output { .. } => {}
103 TaskEvent::Ready { task_name } => {
104 self.start_ready_dependents_direct(
105 &mut active_tasks,
106 &task_name,
107 None,
108 &task_event_tx.clone(),
109 )
110 .await;
111 }
112 TaskEvent::Stopped {
113 task_name,
114 exit_code: _,
115 reason,
116 } => {
117 active_tasks.remove(&task_name);
118
119 self.terminate_dependencies_if_all_dependent_finished(&task_name)
120 .await;
121
122 self.start_ready_dependents_direct(
123 &mut active_tasks,
124 &task_name,
125 Some(reason),
126 &task_event_tx.clone(),
127 )
128 .await;
129 }
130 TaskEvent::Error { task_name, error } => {
131 active_tasks.remove(&task_name);
132
133 self.terminate_dependencies_if_all_dependent_finished(&task_name)
134 .await;
135
136 self.start_ready_dependents_direct(
137 &mut active_tasks,
138 &task_name,
139 Some(TaskEventStopReason::Error(error.to_string())),
140 &task_event_tx.clone(),
141 )
142 .await;
143 }
144 }
145
146 if active_tasks.is_empty() {
148 #[cfg(feature = "tracing")]
149 tracing::debug!("All tasks completed");
150 break;
151 }
152 }
153 }
154
155 #[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
231 pub async fn execute_all_direct_with_control(
232 &mut self,
233 event_tx: Option<Sender<TaskMonitorEvent>>,
234 mut control_rx: mpsc::Receiver<TaskMonitorControlCommand>,
235 ) {
236 let total_tasks = self.tasks_spawner.len();
237
238 if let Some(ref tx) = event_tx
240 && tx
241 .send(TaskMonitorEvent::Started { total_tasks })
242 .await
243 .is_err()
244 {
245 #[cfg(feature = "tracing")]
246 tracing::warn!("Event channel closed while sending ExecutionStarted");
247 }
248 let (task_event_tx, mut task_event_rx) = mpsc::channel::<TaskEvent>(1024);
249 self.start_independent_tasks_direct(&task_event_tx).await;
250
251 let mut active_tasks: HashSet<String> = self.tasks_spawner.keys().cloned().collect();
253 let mut completed_tasks = 0;
254 let mut failed_tasks = 0;
255
256 loop {
257 tokio::select! {
258 event = task_event_rx.recv() => {
260 let should_break = self.handle_task_event(event,&mut completed_tasks,&mut failed_tasks, &mut active_tasks, &task_event_tx, &event_tx).await;
261 if should_break {
262 break;
263 }
264 }
265 control = control_rx.recv() => {
267 let should_break = self.handle_control_event(control, &event_tx).await;
268 if should_break {
269 break;
270 }
271
272
273 }
274 }
275 }
276
277 if let Some(ref tx) = event_tx
279 && tx
280 .send(TaskMonitorEvent::Completed {
281 completed_tasks,
282 failed_tasks,
283 })
284 .await
285 .is_err()
286 {
287 #[cfg(feature = "tracing")]
288 tracing::warn!("Event channel closed while sending ExecutionCompleted");
289 }
290 }
291
292 async fn handle_task_event(
294 &mut self,
295 event: Option<TaskEvent>,
296 completed_tasks: &mut usize,
297 failed_tasks: &mut usize,
298 active_tasks: &mut HashSet<String>,
299 task_event_tx: &Sender<TaskEvent>,
300 event_tx: &Option<Sender<TaskMonitorEvent>>,
301 ) -> bool {
302 let event = match event {
303 Some(e) => e,
304 None => {
305 #[cfg(feature = "tracing")]
306 tracing::debug!("Task event channel closed");
307 return true;
308 }
309 };
310
311 if let Some(tx) = event_tx
313 && let Err(_e) = tx.send(TaskMonitorEvent::Task(event.clone())).await
314 {
315 #[cfg(feature = "tracing")]
316 tracing::warn!(event = ?event, "Failed to forward task event");
317 }
318 match event {
319 TaskEvent::Started { .. } | TaskEvent::Output { .. } => {}
320 TaskEvent::Ready { task_name } => {
321 self.start_ready_dependents_direct(
322 active_tasks,
323 &task_name,
324 None,
325 &task_event_tx.clone(),
326 )
327 .await;
328 }
329 TaskEvent::Stopped {
330 task_name,
331 exit_code: _,
332 reason,
333 } => {
334 active_tasks.remove(&task_name);
335 if let TaskEventStopReason::Error(_) = reason {
336 *failed_tasks += 1;
337 } else {
338 *completed_tasks += 1;
339 }
340
341 self.terminate_dependencies_if_all_dependent_finished(&task_name)
342 .await;
343
344 self.start_ready_dependents_direct(
345 active_tasks,
346 &task_name,
347 Some(reason),
348 &task_event_tx.clone(),
349 )
350 .await;
351 }
352 TaskEvent::Error { task_name, error } => {
353 active_tasks.remove(&task_name);
354 *failed_tasks += 1;
355
356 self.terminate_dependencies_if_all_dependent_finished(&task_name)
357 .await;
358
359 self.start_ready_dependents_direct(
360 active_tasks,
361 &task_name,
362 Some(TaskEventStopReason::Error(error.to_string())),
363 &task_event_tx.clone(),
364 )
365 .await;
366 }
367 }
368
369 if active_tasks.is_empty() {
371 #[cfg(feature = "tracing")]
372 tracing::debug!(active_tasks = active_tasks.len(), "Execution loop ending");
373 return true;
374 }
375 false
376 }
377
378 async fn handle_control_event(
379 &mut self,
380 control: Option<TaskMonitorControlCommand>,
381 event_tx: &Option<Sender<TaskMonitorEvent>>,
382 ) -> bool {
383 let control = match control {
384 Some(c) => c,
385 None => {
386 #[cfg(feature = "tracing")]
387 tracing::debug!("Control channel closed");
388 return false;
389 }
390 };
391 let control_event = TaskMonitorControlEvent::ControlReceived {
393 control: control.clone(),
394 };
395 if let Some(tx) = event_tx
396 && tx
397 .send(TaskMonitorEvent::Control(control_event))
398 .await
399 .is_err()
400 {
401 #[cfg(feature = "tracing")]
402 tracing::warn!("Event channel closed while sending ControlReceived");
403 }
404 match control {
405 TaskMonitorControlCommand::TerminateAllTasks => {
406 #[cfg(feature = "tracing")]
407 tracing::debug!("Received TerminateAllTasks signal, terminating all tasks");
408
409 self.terminate_all_tasks(TaskTerminateReason::UserRequested)
410 .await;
411 }
412 TaskMonitorControlCommand::TerminateTask { ref task_name } => {
413 #[cfg(feature = "tracing")]
414 tracing::debug!(task_name = %task_name, "Terminating specific task");
415
416 self.terminate_task(&task_name, TaskTerminateReason::UserRequested)
417 .await;
418 }
419 TaskMonitorControlCommand::SendStdin {
420 ref task_name,
421 ref input,
422 } => {
423 #[cfg(feature = "tracing")]
424 tracing::debug!(task_name = %task_name, input_len = input.len(), "Sending stdin to task");
425
426 match self.send_stdin_to_task(&task_name, &input).await {
427 Ok(()) => {}
428 Err(e) => {
429 #[cfg(feature = "tracing")]
430 tracing::warn!(task_name = %task_name, error = %e, "Failed to send stdin to task");
431 let control_command = ControlCommandError::SendStdin {
433 task_name: task_name.clone(),
434 input: input.clone(),
435 reason: e,
436 };
437 if let Some(tx) = event_tx
438 && tx
439 .send(TaskMonitorEvent::Error(TaskMonitorError::ControlError(
440 control_command,
441 )))
442 .await
443 .is_err()
444 {
445 #[cfg(feature = "tracing")]
446 tracing::warn!(
447 "Event channel closed while sending TaskMonitorError::ControlError"
448 );
449 }
450 return false;
451 }
452 }
453 }
454 }
455
456 let control_event = TaskMonitorControlEvent::ControlProcessed { control };
457 if let Some(tx) = event_tx
458 && tx
459 .send(TaskMonitorEvent::Control(control_event))
460 .await
461 .is_err()
462 {
463 #[cfg(feature = "tracing")]
464 tracing::warn!("Event channel closed while sending ControlProcessed");
465 }
466 false
467 }
468 async fn terminate_all_tasks(&mut self, reason: TaskTerminateReason) {
473 for (task_name, spawner) in &mut self.tasks_spawner {
474 let state = spawner.get_state().await;
475 if !matches!(state, TaskState::Running | TaskState::Ready) {
476 continue;
477 }
478 #[allow(clippy::used_underscore_binding)]
479 if let Err(_e) = spawner.send_terminate_signal(reason.clone()).await {
480 #[cfg(feature = "tracing")]
481 tracing::warn!(
482 task_name = %task_name,
483 error = %_e,
484 "Failed to terminate task"
485 );
486
487 #[cfg(not(feature = "tracing"))]
489 let _ = task_name;
490 }
491 }
492 }
493
494 async fn terminate_task(&mut self, task_name: &str, reason: TaskTerminateReason) {
499 let spawner = match self.tasks_spawner.get_mut(task_name) {
500 Some(s) => s,
501 None => {
502 #[cfg(feature = "tracing")]
503 tracing::warn!(task_name = %task_name, "Task not found");
504 return;
505 }
506 };
507 let state = spawner.get_state().await;
508 if matches!(state, TaskState::Running | TaskState::Ready) {
509 #[allow(clippy::used_underscore_binding)]
510 if let Err(_e) = spawner.send_terminate_signal(reason).await {
511 #[cfg(feature = "tracing")]
512 tracing::warn!(
513 task_name = %task_name,
514 error = %_e,
515 "Failed to terminate task"
516 );
517 }
518 } else {
519 #[cfg(feature = "tracing")]
520 tracing::warn!(
521 task_name = %task_name,
522 state = ?state,
523 "Task is not in a state that can be terminated"
524 );
525 }
526 }
527
528 async fn send_stdin_to_task(
534 &mut self,
535 task_name: &str,
536 input: &str,
537 ) -> Result<(), SendStdinErrorReason> {
538 let Some(task_spec) = self.tasks.get(task_name) else {
540 #[cfg(feature = "tracing")]
541 tracing::warn!(
542 task_name = %task_name,
543 "Task not found"
544 );
545 return Err(SendStdinErrorReason::TaskNotFound);
546 };
547
548 let has_stdin_enabled = task_spec.config.enable_stdin.unwrap_or(false);
550
551 if !has_stdin_enabled {
552 #[cfg(feature = "tracing")]
553 tracing::warn!(
554 task_name = %task_name,
555 "Task does not have stdin enabled in configuration"
556 );
557 return Err(SendStdinErrorReason::StdinNotEnabled);
558 }
559
560 let Some(stdin_sender) = self.stdin_senders.get(task_name) else {
562 #[cfg(feature = "tracing")]
563 tracing::warn!(
564 task_name = %task_name,
565 "Task does not have a stdin sender (stdin might not be enabled)"
566 );
567 return Err(SendStdinErrorReason::TaskNotFound);
568 };
569
570 let Some(spawner) = self.tasks_spawner.get(task_name) else {
572 #[cfg(feature = "tracing")]
573 tracing::warn!(task_name = %task_name, "Task spawner not found");
574 return Err(SendStdinErrorReason::TaskNotFound);
575 };
576
577 let state = spawner.get_state().await;
579 if !matches!(state, TaskState::Running | TaskState::Ready) {
580 #[cfg(feature = "tracing")]
581 tracing::warn!(
582 task_name = %task_name,
583 state = ?state,
584 "Task is not in a state that can receive stdin input"
585 );
586 return Err(SendStdinErrorReason::TaskNotActive);
587 }
588
589 match stdin_sender.send(input.to_string()).await {
591 Ok(()) => {
592 #[cfg(feature = "tracing")]
593 tracing::info!(
594 task_name = %task_name,
595 input_len = input.len(),
596 "Successfully sent stdin to task: '{}'",
597 input.trim()
598 );
599 Ok(())
600 }
601 #[allow(clippy::used_underscore_binding)]
602 Err(_e) => {
603 #[cfg(feature = "tracing")]
604 tracing::error!(
605 task_name = %task_name,
606 error = %_e,
607 "Failed to send stdin to task"
608 );
609 Err(SendStdinErrorReason::ChannelClosed)
610 }
611 }
612 }
613
614 async fn start_task_direct(&mut self, name: &str, tx: &mpsc::Sender<TaskEvent>) {
619 let Some(spawner) = self.tasks_spawner.get_mut(name) else {
620 return;
621 };
622 let state = spawner.get_state().await;
623 if state != TaskState::Pending {
624 #[cfg(feature = "tracing")]
625 tracing::warn!(
626 task_name = name,
627 state = ?state,
628 "Task is not in Pending state",
629 );
630 return;
631 }
632 let _id = spawner.start_direct(tx.clone()).await;
633 }
634
635 async fn start_independent_tasks_direct(&mut self, tx: &mpsc::Sender<TaskEvent>) {
640 let independent_tasks: Vec<String> = self
642 .tasks_spawner
643 .keys()
644 .filter(|name| !self.dependencies.contains_key(*name))
645 .cloned()
646 .collect();
647 for name in independent_tasks {
649 self.start_task_direct(&name, &tx).await;
650 }
651 }
652
653 async fn start_ready_dependents_direct(
659 &mut self,
660 active_tasks: &mut HashSet<String>,
661 parent_task: &str,
662 parent_task_stop_reason: Option<TaskEventStopReason>,
663 tx: &mpsc::Sender<TaskEvent>,
664 ) {
665 let dependents = match self.dependents.get(parent_task) {
666 Some(d) => d.clone(),
667 None => return,
668 };
669 for task_name in dependents {
670 let Some(dependencies) = self.dependencies.get(&task_name) else {
671 #[cfg(feature = "tracing")]
672 tracing::error!(
673 task_name,
674 "Task has no dependencies, unexpected behavior, it should not be started by this function"
675 );
676 break;
677 };
678 let mut all_dependencies_ready = true;
679 for dependency in dependencies {
680 let state = if let Some(c) = self.tasks_spawner.get(dependency) {
681 c.get_state().await
682 } else {
683 #[cfg(feature = "tracing")]
684 tracing::error!(task_name, "Failed to get task spawner, unexpected behavior");
685 all_dependencies_ready = false;
686 break;
687 };
688 let not_ready = !matches!(state, TaskState::Ready | TaskState::Finished);
689 if not_ready {
690 all_dependencies_ready = false;
691 break;
692 }
693 }
694 if !all_dependencies_ready {
695 continue;
696 }
697
698 let ignore_dependencies_error = match self.tasks.get(&task_name) {
699 Some(config) => config.ignore_dependencies_error.unwrap_or_default(),
700 None => false,
701 };
702
703 let should_start = match &parent_task_stop_reason {
704 Some(TaskEventStopReason::Terminated(_) | TaskEventStopReason::Error(_)) => {
705 ignore_dependencies_error
706 }
707 Some(TaskEventStopReason::Finished) | None => true,
708 };
709 if should_start {
710 self.start_task_direct(&task_name.clone(), tx).await;
711 } else {
712 active_tasks.remove(&task_name);
714 if let Some(child_dependents) = self.dependents.get(&task_name) {
715 for child in child_dependents {
716 active_tasks.remove(child);
717 }
718 }
719 }
720 }
721 }
722}
723
724#[cfg(test)]
725mod tests {
726
727 #[tokio::test]
728 async fn test_terminate_task_sends_user_requested_stop_event() {
729 use crate::monitor::event::TaskMonitorControlCommand;
730 use tcrm_task::tasks::event::{TaskEventStopReason, TaskTerminateReason};
731
732 let mut tasks = HashMap::new();
733 #[cfg(windows)]
735 let cmd = TaskConfig::new("ping").args(["127.0.0.1", "-n", "10"]);
736 #[cfg(not(windows))]
737 let cmd = TaskConfig::new("sh").args(["-c", "sleep 10"]);
738 tasks.insert(
739 "long_task".to_string(),
740 TaskSpec::new(cmd).shell(TaskShell::Auto),
741 );
742
743 let mut monitor = TaskMonitor::new(tasks).unwrap();
744 let (event_tx, mut event_rx) = mpsc::channel(100);
745 let (control_tx, control_rx) = mpsc::channel(10);
746
747 let monitor_handle = tokio::spawn(async move {
749 monitor
750 .execute_all_direct_with_control(Some(event_tx), control_rx)
751 .await;
752 });
753
754 let mut started = false;
756 while let Some(event) = event_rx.recv().await {
757 if let crate::monitor::event::TaskMonitorEvent::Task(TaskEvent::Started { task_name }) =
758 &event
759 {
760 if task_name == "long_task" {
761 started = true;
762 break;
763 }
764 }
765 }
766 assert!(started, "Task should have started");
767
768 control_tx
770 .send(TaskMonitorControlCommand::TerminateTask {
771 task_name: "long_task".to_string(),
772 })
773 .await
774 .unwrap();
775
776 let mut found = false;
778 for _ in 0..10 {
779 if let Some(event) = event_rx.recv().await {
780 if let crate::monitor::event::TaskMonitorEvent::Task(TaskEvent::Stopped {
781 task_name,
782 reason,
783 ..
784 }) = event
785 {
786 if task_name == "long_task" {
787 if let TaskEventStopReason::Terminated(TaskTerminateReason::UserRequested) =
788 reason
789 {
790 found = true;
791 break;
792 }
793 }
794 }
795 } else {
796 break;
797 }
798 }
799 assert!(
800 found,
801 "Should receive Stopped event with Terminated(UserRequested) reason"
802 );
803
804 let _ = monitor_handle.await;
807 }
808 use std::{collections::HashMap, time::Duration};
809
810 use tcrm_task::tasks::{
811 config::{StreamSource, TaskConfig},
812 error::TaskError,
813 event::TaskEvent,
814 };
815 use tokio::{sync::mpsc, time::timeout};
816
817 use crate::monitor::{
818 config::{TaskShell, TaskSpec},
819 error::SendStdinErrorReason,
820 tasks::TaskMonitor,
821 };
822
823 async fn collect_events(
828 event_rx: &mut mpsc::Receiver<TaskEvent>,
829 max_events: usize,
830 ) -> Vec<TaskEvent> {
831 let mut events = Vec::new();
832 let mut event_count = 0;
833
834 while event_count < max_events {
835 match timeout(Duration::from_secs(5), event_rx.recv()).await {
836 Ok(Some(event)) => {
837 events.push(event);
838 event_count += 1;
839 }
840 Ok(None) => break, Err(_) => break, }
843 }
844 events
845 }
846
847 #[tokio::test]
848 async fn test_execute_all_simple_chain() {
849 let mut tasks = HashMap::new();
850
851 tasks.insert(
852 "task1".to_string(),
853 TaskSpec::new(TaskConfig::new("echo").args(["hello"])).shell(TaskShell::Auto),
854 );
855
856 tasks.insert(
857 "task2".to_string(),
858 TaskSpec::new(TaskConfig::new("echo").args(["world"]))
859 .dependencies(["task1"])
860 .shell(TaskShell::Auto),
861 );
862
863 tasks.insert(
864 "task3".to_string(),
865 TaskSpec::new(TaskConfig::new("echo").args(["!"]))
866 .dependencies(["task2"])
867 .shell(TaskShell::Auto),
868 );
869
870 let mut monitor = TaskMonitor::new(tasks).unwrap();
871 let (event_tx, mut event_rx) = mpsc::channel(1024);
872
873 let execute_handle =
875 tokio::spawn(async move { monitor.execute_all_direct(Some(event_tx)).await });
876
877 let events = collect_events(&mut event_rx, 10).await;
878 let result = timeout(Duration::from_secs(10), execute_handle).await;
880 assert!(result.is_ok());
881
882 let started_tasks: Vec<_> = events
884 .iter()
885 .filter_map(|e| match e {
886 TaskEvent::Started { task_name } => Some(task_name.clone()),
887 _ => None,
888 })
889 .collect();
890
891 assert!(!started_tasks.is_empty());
892
893 let task1_idx = started_tasks.iter().position(|x| x == "task1").unwrap();
895 let task2_idx = started_tasks.iter().position(|x| x == "task2").unwrap();
896 let task3_idx = started_tasks.iter().position(|x| x == "task3").unwrap();
897
898 assert!(task1_idx < task2_idx);
899 assert!(task2_idx < task3_idx);
900 }
901
902 #[tokio::test]
903 async fn test_execute_all_independent_tasks() {
904 let mut tasks = HashMap::new();
905
906 tasks.insert(
907 "independent1".to_string(),
908 TaskSpec::new(TaskConfig::new("echo").args(["task1"])).shell(TaskShell::Auto),
909 );
910 tasks.insert(
911 "independent2".to_string(),
912 TaskSpec::new(TaskConfig::new("echo").args(["task2"])).shell(TaskShell::Auto),
913 );
914 tasks.insert(
915 "independent3".to_string(),
916 TaskSpec::new(TaskConfig::new("echo").args(["task3"])).shell(TaskShell::Auto),
917 );
918
919 let mut monitor = TaskMonitor::new(tasks).unwrap();
920 let (event_tx, mut event_rx) = mpsc::channel(1024);
921
922 let execute_handle =
923 tokio::spawn(async move { monitor.execute_all_direct(Some(event_tx)).await });
924
925 let events = collect_events(&mut event_rx, 15).await;
926
927 let result = timeout(Duration::from_secs(10), execute_handle).await;
928 assert!(result.is_ok());
929
930 let started_tasks: Vec<_> = events
932 .iter()
933 .filter_map(|e| match e {
934 TaskEvent::Started { task_name } => Some(task_name.clone()),
935 _ => None,
936 })
937 .collect();
938
939 let stopped_tasks: Vec<_> = events
940 .iter()
941 .filter_map(|e| match e {
942 TaskEvent::Stopped { task_name, .. } => Some(task_name.clone()),
943 _ => None,
944 })
945 .collect();
946
947 assert_eq!(started_tasks.len(), 3);
949 assert_eq!(stopped_tasks.len(), 3);
950 }
951
952 #[tokio::test]
953 async fn test_task_with_ready_indicator() {
954 let mut tasks = HashMap::new();
958
959 tasks.insert(
961 "server".to_string(),
962 TaskSpec::new(
963 TaskConfig::new("echo ready!; sleep 10")
964 .ready_indicator("ready!".to_string())
965 .ready_indicator_source(StreamSource::Stdout),
966 )
967 .terminate_after_dependents(true)
968 .shell(TaskShell::Auto),
969 );
970
971 tasks.insert(
973 "client".to_string(),
974 TaskSpec::new(TaskConfig::new("echo client-started"))
975 .dependencies(["server"])
976 .shell(TaskShell::Auto),
977 );
978
979 let mut monitor = TaskMonitor::new(tasks).unwrap();
980 let (event_tx, mut event_rx) = mpsc::channel(1024);
981
982 let execute_handle =
983 tokio::spawn(async move { monitor.execute_all_direct(Some(event_tx)).await });
984
985 let events = collect_events(&mut event_rx, 10).await;
986
987 println!("Collected events: {:?}", events);
988 let server_ready = events
990 .iter()
991 .find(|e| matches!(e, TaskEvent::Ready { task_name } if task_name == "server"));
992 assert!(server_ready.is_some());
993
994 let client_started = events
996 .iter()
997 .find(|e| matches!(e, TaskEvent::Started { task_name } if task_name == "client"));
998 assert!(client_started.is_some());
999
1000 let server_ready_idx = events
1002 .iter()
1003 .position(|e| matches!(e, TaskEvent::Ready { task_name } if task_name == "server"))
1004 .unwrap();
1005 let client_start_idx = events
1006 .iter()
1007 .position(|e| matches!(e, TaskEvent::Started { task_name } if task_name == "client"))
1008 .unwrap();
1009 assert!(server_ready_idx < client_start_idx);
1010
1011 let _ = timeout(Duration::from_secs(10), execute_handle).await;
1013 }
1014
1015 #[tokio::test]
1016 async fn test_task_error_handling() {
1017 let mut tasks = HashMap::new();
1018
1019 tasks.insert(
1021 "failing_task".to_string(),
1022 TaskSpec::new(TaskConfig::new("exit").args(["1"])),
1023 );
1024
1025 tasks.insert(
1027 "dependent_task".to_string(),
1028 TaskSpec::new(TaskConfig::new("echo").args(["should-not-run"]))
1029 .dependencies(["failing_task"]),
1030 );
1031
1032 tasks.insert(
1034 "resilient_task".to_string(),
1035 TaskSpec::new(TaskConfig::new("echo").args(["should-run"]))
1036 .shell(TaskShell::Auto)
1037 .dependencies(["failing_task"])
1038 .ignore_dependencies_error(true),
1039 );
1040
1041 let mut monitor = TaskMonitor::new(tasks).unwrap();
1042 let (event_tx, mut event_rx) = mpsc::channel(1024);
1043 let execute_handle =
1044 tokio::spawn(async move { monitor.execute_all_direct(Some(event_tx)).await });
1045
1046 let events = collect_events(&mut event_rx, 15).await;
1047
1048 let failing_task_stopped = events.iter().find(|e| {
1050 matches!(e,
1051 TaskEvent::Error {
1052 task_name,
1053 error: TaskError::IO(_),
1054 } if task_name == "failing_task"
1055 )
1056 });
1057 assert!(failing_task_stopped.is_some());
1058
1059 let dependent_started = events.iter().find(|e| {
1061 matches!(e,
1062 TaskEvent::Started {
1063 task_name
1064 } if task_name == "dependent_task"
1065 )
1066 });
1067 assert!(dependent_started.is_none());
1068
1069 let resilient_started = events.iter().find(|e| {
1071 matches!(e,
1072 TaskEvent::Started {
1073 task_name
1074 } if task_name == "resilient_task"
1075 )
1076 });
1077 assert!(resilient_started.is_some());
1078
1079 let _ = timeout(Duration::from_secs(10), execute_handle).await;
1080 }
1081
1082 #[tokio::test]
1083 async fn test_channel_overflow_resilience() {
1084 let mut tasks = HashMap::new();
1086
1087 for i in 0..10 {
1089 tasks.insert(
1090 format!("task_{}", i),
1091 TaskSpec::new(TaskConfig::new("echo").args([&format!("output from task {}", i)])),
1092 );
1093 }
1094
1095 let mut monitor = TaskMonitor::new(tasks).unwrap();
1096
1097 let (event_tx, mut event_rx) = mpsc::channel(2);
1099 let execute_handle =
1100 tokio::spawn(async move { monitor.execute_all_direct(Some(event_tx)).await });
1101
1102 let mut event_count = 0;
1104 let timeout_duration = Duration::from_millis(200);
1105
1106 while let Ok(Some(_event)) = timeout(timeout_duration, event_rx.recv()).await {
1107 event_count += 1;
1108 if event_count > 50 {
1109 break; }
1111 tokio::time::sleep(Duration::from_millis(2)).await; }
1113
1114 let result = timeout(Duration::from_secs(10), execute_handle).await;
1116 assert!(
1117 result.is_ok(),
1118 "Execution should complete despite channel pressure"
1119 );
1120
1121 assert!(
1123 event_count >= 10,
1124 "Should receive at least 10 events, got {}",
1125 event_count
1126 );
1127 }
1128
1129 #[tokio::test]
1130 async fn test_concurrent_task_state_consistency() {
1131 let mut tasks = HashMap::new();
1133
1134 tasks.insert(
1136 "first".to_string(),
1137 TaskSpec::new(TaskConfig::new("echo").args(["first"])),
1138 );
1139
1140 for i in 1..10 {
1141 tasks.insert(
1142 format!("task_{}", i),
1143 TaskSpec::new(TaskConfig::new("echo").args([&format!("task {}", i)]))
1144 .dependencies([&format!("task_{}", i - 1)]),
1145 );
1146 }
1147 tasks.insert(
1148 "task_0".to_string(),
1149 TaskSpec::new(TaskConfig::new("echo").args(["task 0"])).dependencies(["first"]),
1150 );
1151
1152 let mut monitor = TaskMonitor::new(tasks).unwrap();
1153 let (event_tx, mut event_rx) = mpsc::channel(1024);
1154
1155 let execute_handle =
1156 tokio::spawn(async move { monitor.execute_all_direct(Some(event_tx)).await });
1157
1158 let events = collect_events(&mut event_rx, 50).await;
1159
1160 let mut start_times = HashMap::new();
1162 for (idx, event) in events.iter().enumerate() {
1163 if let TaskEvent::Started { task_name } = event {
1164 start_times.insert(task_name.clone(), idx);
1165 }
1166 }
1167
1168 if let (Some(&first_start), Some(&task_0_start)) =
1170 (start_times.get("first"), start_times.get("task_0"))
1171 {
1172 assert!(
1173 first_start < task_0_start,
1174 "Dependencies should start before dependents"
1175 );
1176 }
1177
1178 for i in 1..9 {
1179 let task_name = format!("task_{}", i);
1180 let prev_task_name = format!("task_{}", i - 1);
1181 if let (Some(¤t), Some(&previous)) = (
1182 start_times.get(&task_name),
1183 start_times.get(&prev_task_name),
1184 ) {
1185 assert!(
1186 previous < current,
1187 "Task {} should start after task {}",
1188 i,
1189 i - 1
1190 );
1191 }
1192 }
1193
1194 let _ = timeout(Duration::from_secs(10), execute_handle).await;
1195 }
1196
1197 #[tokio::test]
1198 async fn test_resource_cleanup_on_early_termination() {
1199 let mut tasks = HashMap::new();
1201
1202 tasks.insert(
1204 "long_task_1".to_string(),
1205 TaskSpec::new(TaskConfig::new("ping").args(["127.0.0.1", "-n", "100"])),
1206 );
1207 tasks.insert(
1208 "long_task_2".to_string(),
1209 TaskSpec::new(TaskConfig::new("ping").args(["127.0.0.1", "-n", "100"])),
1210 );
1211
1212 let mut monitor = TaskMonitor::new(tasks).unwrap();
1213 let (event_tx, mut event_rx) = mpsc::channel(1024);
1214
1215 let execute_handle =
1216 tokio::spawn(async move { monitor.execute_all_direct(Some(event_tx)).await });
1217
1218 let mut started_count = 0;
1220 while let Ok(Some(event)) = timeout(Duration::from_secs(2), event_rx.recv()).await {
1221 if matches!(event, TaskEvent::Started { .. }) {
1222 started_count += 1;
1223 if started_count >= 2 {
1224 break;
1225 }
1226 }
1227 }
1228
1229 execute_handle.abort();
1231
1232 tokio::time::sleep(Duration::from_millis(100)).await;
1234
1235 assert!(
1237 started_count >= 2,
1238 "Both long-running tasks should have started"
1239 );
1240 }
1241
1242 #[tokio::test]
1243 async fn test_invalid_command_error_propagation() {
1244 let mut tasks = HashMap::new();
1246
1247 tasks.insert(
1248 "invalid_command".to_string(),
1249 TaskSpec::new(TaskConfig::new("definitely_not_a_real_command_12345")),
1250 );
1251
1252 tasks.insert(
1253 "dependent_task".to_string(),
1254 TaskSpec::new(TaskConfig::new("echo").args(["should not run"]))
1255 .dependencies(["invalid_command"]),
1256 );
1257
1258 let mut monitor = TaskMonitor::new(tasks).unwrap();
1259 let (event_tx, mut event_rx) = mpsc::channel(1024);
1260
1261 let execute_handle =
1262 tokio::spawn(async move { monitor.execute_all_direct(Some(event_tx)).await });
1263
1264 let events = collect_events(&mut event_rx, 10).await;
1265
1266 let error_event = events.iter().find(
1268 |e| matches!(e, TaskEvent::Error { task_name, .. } if task_name == "invalid_command"),
1269 );
1270 assert!(
1271 error_event.is_some(),
1272 "Invalid command should generate an error event"
1273 );
1274
1275 let dependent_started = events.iter().find(
1277 |e| matches!(e, TaskEvent::Started { task_name } if task_name == "dependent_task"),
1278 );
1279 assert!(
1280 dependent_started.is_none(),
1281 "Dependent task should not start after dependency error"
1282 );
1283
1284 let _ = timeout(Duration::from_secs(5), execute_handle).await;
1285 }
1286
1287 #[tokio::test]
1288 async fn test_stdin_functionality_with_control() {
1289 let mut tasks = HashMap::new();
1291
1292 #[cfg(windows)]
1294 let stdin_task_config = TaskConfig::new("powershell")
1295 .args(["-Command", "'ready'; $host.UI.ReadLine()"])
1296 .enable_stdin(true)
1297 .ready_indicator("ready")
1298 .timeout_ms(5000);
1299 #[cfg(not(windows))]
1300 let stdin_task_config = TaskConfig::new("sh")
1301 .args(["-c", "echo 'ready'; read input; echo $input"])
1302 .enable_stdin(true)
1303 .ready_indicator("ready")
1304 .timeout_ms(5000);
1305
1306 tasks.insert(
1307 "stdin_task".to_string(),
1308 TaskSpec::new(stdin_task_config).shell(TaskShell::Auto),
1309 );
1310
1311 tasks.insert(
1313 "no_stdin_task".to_string(),
1314 TaskSpec::new(TaskConfig::new("echo").args(["No stdin task"])).shell(TaskShell::Auto),
1315 );
1316
1317 let mut monitor = TaskMonitor::new(tasks).unwrap();
1318
1319 assert!(monitor.stdin_senders.contains_key("stdin_task"));
1321 assert!(!monitor.stdin_senders.contains_key("no_stdin_task"));
1322 assert_eq!(monitor.stdin_senders.len(), 1);
1323
1324 let result = monitor
1326 .send_stdin_to_task("stdin_task", "Hello stdin!")
1327 .await;
1328 assert!(
1330 result.is_err(),
1331 "Sending stdin to non-running task should fail"
1332 );
1333 if let Err(e) = result {
1334 println!("Expected error: {}", e);
1335 }
1336
1337 let result = monitor
1339 .send_stdin_to_task("nonexistent_task", "Should be ignored")
1340 .await;
1341 assert!(
1342 result.is_err(),
1343 "Sending stdin to nonexistent task should fail"
1344 );
1345
1346 let result = monitor
1348 .send_stdin_to_task("no_stdin_task", "Should be rejected")
1349 .await;
1350 assert!(
1351 result.is_err(),
1352 "Sending stdin to task without stdin enabled should fail"
1353 );
1354 }
1355
1356 #[tokio::test]
1357 async fn test_stdin_channel_creation() {
1358 let mut tasks = HashMap::new();
1360
1361 tasks.insert(
1363 "with_stdin".to_string(),
1364 TaskSpec::new(TaskConfig::new("echo").args(["test"]).enable_stdin(true))
1365 .shell(TaskShell::Auto),
1366 );
1367
1368 tasks.insert(
1370 "without_stdin".to_string(),
1371 TaskSpec::new(TaskConfig::new("echo").args(["test"]).enable_stdin(false))
1372 .shell(TaskShell::Auto),
1373 );
1374
1375 tasks.insert(
1377 "default_stdin".to_string(),
1378 TaskSpec::new(TaskConfig::new("echo").args(["test"])).shell(TaskShell::Auto),
1379 );
1380
1381 let monitor = TaskMonitor::new(tasks).unwrap();
1382
1383 assert!(monitor.stdin_senders.contains_key("with_stdin"));
1385 assert!(!monitor.stdin_senders.contains_key("without_stdin"));
1386 assert!(!monitor.stdin_senders.contains_key("default_stdin"));
1387 assert_eq!(monitor.stdin_senders.len(), 1);
1388
1389 assert_eq!(monitor.tasks_spawner.len(), 3);
1391 }
1392
1393 #[tokio::test]
1394 async fn test_stdin_validation_and_error_handling() {
1395 let mut tasks = HashMap::new();
1397
1398 #[cfg(windows)]
1399 let stdin_task_config = TaskConfig::new("powershell")
1400 .args(["-Command", "'ready'; $host.UI.ReadLine()"])
1401 .enable_stdin(true)
1402 .ready_indicator("ready")
1403 .timeout_ms(5000);
1404 #[cfg(not(windows))]
1405 let stdin_task_config = TaskConfig::new("sh")
1406 .args(["-c", "echo 'ready'; read input; echo $input"])
1407 .enable_stdin(true)
1408 .ready_indicator("ready")
1409 .timeout_ms(5000);
1410
1411 tasks.insert(
1412 "stdin_enabled".to_string(),
1413 TaskSpec::new(stdin_task_config).shell(TaskShell::Auto),
1414 );
1415
1416 let mut monitor = TaskMonitor::new(tasks).unwrap();
1417
1418 assert!(monitor.stdin_senders.contains_key("stdin_enabled"));
1420 assert_eq!(monitor.stdin_senders.len(), 1);
1421
1422 let result = monitor
1424 .send_stdin_to_task("stdin_enabled", "Valid input")
1425 .await;
1426 assert!(
1428 result.is_err(),
1429 "Sending stdin to non-running task should fail"
1430 );
1431
1432 let result = monitor
1433 .send_stdin_to_task("nonexistent_task", "Should be ignored")
1434 .await;
1435 assert!(
1436 result.is_err(),
1437 "Sending stdin to nonexistent task should fail"
1438 );
1439 }
1440
1441 #[tokio::test]
1442 async fn test_stdin_error_types() {
1443 let mut tasks = HashMap::new();
1445
1446 #[cfg(windows)]
1447 let stdin_task_config = TaskConfig::new("powershell")
1448 .args(["-Command", "'ready'; $host.UI.ReadLine()"])
1449 .enable_stdin(true)
1450 .ready_indicator("ready")
1451 .timeout_ms(5000);
1452 #[cfg(not(windows))]
1453 let stdin_task_config = TaskConfig::new("sh")
1454 .args(["-c", "echo 'ready'; read input; echo $input"])
1455 .enable_stdin(true)
1456 .ready_indicator("ready")
1457 .timeout_ms(5000);
1458
1459 tasks.insert(
1460 "stdin_task".to_string(),
1461 TaskSpec::new(stdin_task_config).shell(TaskShell::Auto),
1462 );
1463
1464 tasks.insert(
1465 "no_stdin_task".to_string(),
1466 TaskSpec::new(TaskConfig::new("echo").args(["test"])).shell(TaskShell::Auto),
1467 );
1468
1469 let mut monitor = TaskMonitor::new(tasks).unwrap();
1470
1471 let result = monitor.send_stdin_to_task("stdin_task", "input").await;
1473 assert!(result.is_err());
1474 assert_eq!(result, Err(SendStdinErrorReason::TaskNotActive));
1475
1476 let result = monitor.send_stdin_to_task("no_stdin_task", "input").await;
1478 assert!(result.is_err());
1479 assert_eq!(result, Err(SendStdinErrorReason::StdinNotEnabled));
1480
1481 let result = monitor.send_stdin_to_task("nonexistent", "input").await;
1483 assert!(result.is_err());
1484 assert_eq!(result, Err(SendStdinErrorReason::TaskNotFound));
1485 }
1486
1487 #[tokio::test]
1488 async fn test_multiple_stdin_tasks_concurrent() {
1489 let mut tasks = HashMap::new();
1491
1492 for i in 1..=3 {
1493 #[cfg(windows)]
1494 let stdin_task_config = TaskConfig::new("powershell")
1495 .args(["-Command", "'ready'; $host.UI.ReadLine()"])
1496 .enable_stdin(true)
1497 .ready_indicator("ready")
1498 .timeout_ms(5000);
1499 #[cfg(not(windows))]
1500 let stdin_task_config = TaskConfig::new("sh")
1501 .args(["-c", "echo 'ready'; read input; echo $input"])
1502 .enable_stdin(true)
1503 .ready_indicator("ready")
1504 .timeout_ms(5000);
1505
1506 tasks.insert(
1507 format!("stdin_task_{}", i),
1508 TaskSpec::new(stdin_task_config).shell(TaskShell::Auto),
1509 );
1510 }
1511
1512 let mut monitor = TaskMonitor::new(tasks).unwrap();
1513
1514 assert_eq!(monitor.stdin_senders.len(), 3);
1516 for i in 1..=3 {
1517 assert!(
1518 monitor
1519 .stdin_senders
1520 .contains_key(&format!("stdin_task_{}", i))
1521 );
1522 }
1523
1524 for i in 1..=3 {
1526 let result = monitor
1527 .send_stdin_to_task(
1528 &format!("stdin_task_{}", i),
1529 &format!("Input for task {}", i),
1530 )
1531 .await;
1532 assert!(
1533 result.is_err(),
1534 "Sending stdin to non-running stdin_task_{} should fail",
1535 i
1536 );
1537 }
1538 }
1539}