// tcrm_task/tasks/async_tokio/spawner.rs

1use std::sync::Arc;
2use std::time::Duration;
3use tokio::sync::{Mutex, RwLock, mpsc, oneshot};
4use tokio::task::JoinHandle;
5use tokio::time::{Instant, timeout};
6
7use crate::tasks::error::TaskError;
8use crate::tasks::state::TaskTerminateReason;
9use crate::tasks::{config::TaskConfig, state::TaskState};
10
11#[derive(Debug, Clone)]
12pub struct TaskInfo {
13    pub name: String,
14    pub state: TaskState,
15    pub uptime: Duration,
16    pub created_at: Instant,
17    pub finished_at: Option<Instant>,
18}
19/// Spawns and manages the lifecycle of a task
20#[derive(Debug)]
21pub struct TaskSpawner {
22    pub(crate) config: TaskConfig,
23    pub(crate) task_name: String,
24    pub(crate) state: Arc<RwLock<TaskState>>,
25    pub(crate) terminate_tx: Arc<Mutex<Option<oneshot::Sender<TaskTerminateReason>>>>,
26    pub(crate) process_id: Arc<RwLock<Option<u32>>>,
27    pub(crate) created_at: Instant,
28    pub(crate) finished_at: Arc<RwLock<Option<Instant>>>,
29    pub(crate) stdin_rx: Option<mpsc::Receiver<String>>,
30}
31
32impl TaskSpawner {
33    /// Create a new task spawner for the given task name and configuration
34    pub fn new(task_name: String, config: TaskConfig) -> Self {
35        Self {
36            task_name,
37            config,
38            state: Arc::new(RwLock::new(TaskState::Pending)),
39            terminate_tx: Arc::new(Mutex::new(None)),
40            process_id: Arc::new(RwLock::new(None)),
41            created_at: Instant::now(),
42            finished_at: Arc::new(RwLock::new(None)),
43            stdin_rx: None,
44        }
45    }
46
47    /// Set the stdin receiver for the task, enabling asynchronous input
48    ///
49    /// Has no effect if `enable_stdin` is false in the configuration
50    pub fn set_stdin(mut self, stdin_rx: mpsc::Receiver<String>) -> Self {
51        if self.config.enable_stdin.unwrap_or_default() {
52            self.stdin_rx = Some(stdin_rx);
53        }
54        self
55    }
56
57    /// Get the current state of the task
58    pub async fn get_state(&self) -> TaskState {
59        self.state.read().await.clone()
60    }
61
62    /// Check if the task is currently running
63    pub async fn is_running(&self) -> bool {
64        let state = self.state.read().await.clone();
65        state == TaskState::Running
66    }
67    /// Check if the task is currently ready
68    pub async fn is_ready(&self) -> bool {
69        let state = self.state.read().await.clone();
70        state == TaskState::Ready
71    }
72
73    /// Get the uptime of the task since creation
74    pub fn uptime(&self) -> Duration {
75        self.created_at.elapsed()
76    }
77
78    /// Get information about the task, including name, state, and uptime
79    pub async fn get_task_info(&self) -> TaskInfo {
80        TaskInfo {
81            name: self.task_name.clone(),
82            state: self.get_state().await,
83            uptime: self.uptime(),
84            created_at: self.created_at,
85            finished_at: self.finished_at.read().await.clone(),
86        }
87    }
88
89    /// Get the process ID of the running task (if any)
90    pub async fn get_process_id(&self) -> Option<u32> {
91        self.process_id.read().await.clone()
92    }
93
94    /// Update the state of the task to Ready
95    pub async fn update_state_to_ready(&self) {
96        let mut state = self.state.write().await;
97        *state = TaskState::Ready;
98    }
99
100    /// Update the state of the task
101    pub(crate) async fn update_state(&self, new_state: TaskState) {
102        let mut state = self.state.write().await;
103        *state = new_state;
104    }
105
106    /// Send a termination signal to the running task
107    ///
108    /// Returns an error if the signal cannot be sent
109    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
110    pub async fn send_terminate_signal(
111        &self,
112        reason: TaskTerminateReason,
113    ) -> Result<(), TaskError> {
114        if let Some(tx) = self.terminate_tx.lock().await.take() {
115            if tx.send(reason.clone()).is_err() {
116                let msg = "Terminate channel closed while sending signal";
117                #[cfg(feature = "tracing")]
118                tracing::warn!(terminate_reason=?reason, msg);
119                return Err(TaskError::Channel(msg.to_string()));
120            }
121        } else {
122            let msg = "Terminate signal already sent or channel missing";
123            #[cfg(feature = "tracing")]
124            tracing::warn!(msg);
125            return Err(TaskError::Channel(msg.to_string()));
126        }
127
128        Ok(())
129    }
130}
131
132/// Waits for all spawned task handles to complete, with a timeout
133///
134/// Returns an error if any handle fails or times out
135pub(crate) async fn join_all_handles(
136    task_handles: &mut Vec<JoinHandle<()>>,
137) -> Result<(), TaskError> {
138    if task_handles.is_empty() {
139        return Ok(());
140    }
141
142    let handles = std::mem::take(task_handles);
143    let mut errors = Vec::new();
144
145    for (_index, mut handle) in handles.into_iter().enumerate() {
146        match timeout(Duration::from_secs(5), &mut handle).await {
147            Ok(Ok(())) => {}
148            Ok(Err(join_err)) => {
149                let err_msg = format!("Handle [{}] join failed: {:?}", handle.id(), join_err);
150
151                errors.push(err_msg);
152            }
153            Err(_) => {
154                let err_msg = format!("Handle [{}] join timeout, aborting", handle.id());
155                handle.abort(); // ensure it’s killed
156                errors.push(err_msg);
157            }
158        }
159    }
160
161    if !errors.is_empty() {
162        return Err(TaskError::Handle(format!(
163            "Multiple task handles join failures: {}",
164            errors.join("; ")
165        )));
166    }
167
168    Ok(())
169}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::time::sleep;

    use crate::tasks::{
        async_tokio::spawner::{TaskInfo, TaskSpawner},
        config::TaskConfig,
        state::TaskState,
    };

    /// Builds a spawner for a trivial `echo` command under the given name.
    fn echo_spawner(name: &str) -> TaskSpawner {
        TaskSpawner::new(name.to_string(), TaskConfig::new("echo"))
    }

    #[tokio::test]
    async fn task_spawner_is_running_returns_true_when_state_running() {
        let spawner = echo_spawner("running_task");
        assert!(
            !spawner.is_running().await,
            "Should not be running initially"
        );
        spawner.update_state(TaskState::Running).await;
        assert!(spawner.is_running().await, "Should be running after update");
    }

    #[tokio::test]
    async fn task_spawner_update_state_to_ready_sets_ready() {
        let spawner = echo_spawner("ready_method_task");
        assert!(!spawner.is_ready().await, "Should not be ready initially");
        spawner.update_state_to_ready().await;
        assert!(
            spawner.is_ready().await,
            "Should be ready after update_state_to_ready()"
        );
        assert_eq!(
            spawner.get_state().await,
            TaskState::Ready,
            "State should be Ready after update_state_to_ready()"
        );
    }

    #[tokio::test]
    async fn task_spawner_is_ready_returns_true_when_state_ready() {
        let spawner = echo_spawner("ready_task");
        assert!(!spawner.is_ready().await, "Should not be ready initially");
        spawner.update_state(TaskState::Ready).await;
        assert!(spawner.is_ready().await, "Should be ready after update");
    }

    #[tokio::test]
    async fn task_spawner_initial_state_is_pending() {
        let spawner = echo_spawner("pending_task");
        assert_eq!(
            spawner.get_state().await,
            TaskState::Pending,
            "Initial state should be Pending"
        );
    }

    #[tokio::test]
    async fn task_spawner_process_id_is_none_initially() {
        let spawner = echo_spawner("pid_task");
        assert_eq!(
            spawner.get_process_id().await,
            None,
            "No process ID should be recorded before the task is spawned"
        );
    }

    #[tokio::test]
    async fn task_spawner_update_state_changes_state() {
        let spawner = echo_spawner("update_task");
        spawner.update_state(TaskState::Running).await;
        assert_eq!(
            spawner.get_state().await,
            TaskState::Running,
            "State should be Running after update"
        );
    }

    #[tokio::test]
    async fn task_spawner_uptime_increases_over_time() {
        let spawner = echo_spawner("uptime_task");
        let uptime1 = spawner.uptime();
        sleep(Duration::from_millis(20)).await;
        let uptime2 = spawner.uptime();
        assert!(uptime2 > uptime1, "Uptime should increase after sleep");
    }

    #[tokio::test]
    async fn task_spawner_get_task_info_returns_correct_info() {
        let spawner = echo_spawner("info_task");
        let info: TaskInfo = spawner.get_task_info().await;
        assert_eq!(info.name, "info_task");
        assert_eq!(info.state, TaskState::Pending);
        assert_eq!(
            info.created_at, spawner.created_at,
            "Snapshot should carry the spawner's creation time"
        );
        assert!(
            info.finished_at.is_none(),
            "A fresh task should have no finish time"
        );
        // Uptime is monotonic, so the snapshot's value cannot exceed a later reading.
        assert!(info.uptime <= spawner.uptime());
    }
}