tcrm_task/tasks/async_tokio/direct/
start.rs1use tokio::process::Command;
2use tokio::sync::{mpsc, oneshot, watch};
3
4use crate::tasks::async_tokio::direct::command::setup_command;
5use crate::tasks::async_tokio::direct::watchers::input::spawn_stdin_watcher;
6use crate::tasks::async_tokio::direct::watchers::output::spawn_output_watchers;
7use crate::tasks::async_tokio::direct::watchers::result::spawn_result_watcher;
8use crate::tasks::async_tokio::direct::watchers::timeout::spawn_timeout_watcher;
9use crate::tasks::async_tokio::direct::watchers::wait::spawn_wait_watcher;
10use crate::tasks::async_tokio::spawner::TaskSpawner;
11use crate::tasks::error::TaskError;
12use crate::tasks::event::{TaskEvent, TaskEventStopReason};
13use crate::tasks::state::{TaskState, TaskTerminateReason};
14
15impl TaskSpawner {
16 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, event_tx), fields(task_name = %self.task_name)))]
17 pub async fn start_direct(
18 &mut self,
19 event_tx: mpsc::Sender<TaskEvent>,
20 ) -> Result<u32, TaskError> {
21 self.update_state(TaskState::Initiating).await;
22
23 self.config.validate()?;
24
25 let mut cmd = Command::new(&self.config.command);
26 let mut cmd = cmd.kill_on_drop(true);
27
28 setup_command(&mut cmd, &self.config);
29 let mut child = cmd.spawn()?;
30 let child_id = match child.id() {
31 Some(id) => id,
32 None => {
33 #[cfg(feature = "tracing")]
34 tracing::error!("Failed to get process id");
35 return Err(TaskError::IO(std::io::Error::new(
36 std::io::ErrorKind::Other,
37 "Failed to get process id",
38 )));
39 }
40 };
41 self.process_id = Some(child_id);
42 let mut task_handles = vec![];
43 self.update_state(TaskState::Running).await;
44 if let Err(_) = event_tx
45 .send(TaskEvent::Started {
46 task_name: self.task_name.clone(),
47 })
48 .await
49 {
50 #[cfg(feature = "tracing")]
51 tracing::warn!("Event channel closed while sending TaskEvent::Started");
52 }
53
54 let (result_tx, result_rx) = oneshot::channel::<(Option<i32>, TaskEventStopReason)>();
55 let (terminate_tx, terminate_rx) = oneshot::channel::<TaskTerminateReason>();
56 let (handle_terminator_tx, handle_terminator_rx) = watch::channel(false);
57
58 let handles = spawn_output_watchers(
60 self.task_name.clone(),
61 event_tx.clone(),
62 &mut child,
63 handle_terminator_rx.clone(),
64 );
65 task_handles.extend(handles);
66
67 if let Some((stdin, stdin_rx)) = child.stdin.take().zip(self.stdin_rx.take()) {
69 let handle = spawn_stdin_watcher(stdin, stdin_rx, handle_terminator_rx.clone());
70 task_handles.push(handle);
71 }
72
73 *self.terminate_tx.lock().await = Some(terminate_tx);
75
76 let handle = spawn_wait_watcher(
77 self.task_name.clone(),
78 self.state.clone(),
79 child,
80 terminate_rx,
81 handle_terminator_tx.clone(),
82 result_tx,
83 );
84 task_handles.push(handle);
85
86 if let Some(timeout_ms) = self.config.timeout_ms {
88 let handle =
89 spawn_timeout_watcher(self.terminate_tx.clone(), timeout_ms, handle_terminator_rx);
90 task_handles.push(handle);
91 }
92
93 let _handle = spawn_result_watcher(
95 self.task_name.clone(),
96 self.state.clone(),
97 self.finished_at.clone(),
98 event_tx,
99 result_rx,
100 task_handles,
101 );
102
103 Ok(child_id)
104 }
105}
106
#[cfg(test)]
mod tests {
    use tokio::sync::mpsc;

    use crate::tasks::{
        async_tokio::spawner::TaskSpawner,
        config::{StreamSource, TaskConfig},
        error::TaskError,
        event::{TaskEvent, TaskEventStopReason},
        state::TaskTerminateReason,
    };

    /// Passes every event received on `event_rx` to `on_event`, in arrival order,
    /// until the channel closes (all senders dropped).
    async fn drain_events<F>(event_rx: &mut mpsc::Receiver<TaskEvent>, mut on_event: F)
    where
        F: FnMut(TaskEvent),
    {
        while let Some(event) = event_rx.recv().await {
            on_event(event);
        }
    }

    /// An `echo` command emits `Started`, only `hello` on stdout, and `Stopped` with exit code 0.
    #[tokio::test]
    async fn start_direct_fn_echo_command() {
        let (event_tx, mut event_rx) = mpsc::channel::<TaskEvent>(1024);
        #[cfg(windows)]
        let config = TaskConfig::new("powershell").args(["-Command", "echo hello"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "echo hello"]);

        let mut spawner = TaskSpawner::new(String::from("echo_task"), config);
        let outcome = spawner.start_direct(event_tx).await;
        assert!(outcome.is_ok());

        let (mut saw_started, mut saw_stopped) = (false, false);
        drain_events(&mut event_rx, |event| match event {
            TaskEvent::Started { task_name } => {
                assert_eq!(task_name, "echo_task");
                saw_started = true;
            }
            TaskEvent::Output {
                task_name,
                line,
                src,
            } => {
                assert_eq!(task_name, "echo_task");
                assert_eq!(line, "hello");
                assert_eq!(src, StreamSource::Stdout);
            }
            TaskEvent::Stopped {
                task_name,
                exit_code,
                ..
            } => {
                assert_eq!(task_name, "echo_task");
                assert_eq!(exit_code, Some(0));
                saw_stopped = true;
            }
            _ => {}
        })
        .await;

        assert!(saw_started);
        assert!(saw_stopped);
    }

    /// A command outliving its 1 ms timeout is reported as terminated by timeout, with no exit code.
    #[tokio::test]
    async fn start_direct_timeout_terminated_task() {
        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "sleep 2"])
            .timeout_ms(1);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "sleep 2"])
            .timeout_ms(1);

        let (event_tx, mut event_rx) = mpsc::channel::<TaskEvent>(1024);
        let mut spawner = TaskSpawner::new(String::from("sleep_with_timeout_task"), config);
        let outcome = spawner.start_direct(event_tx).await;
        assert!(outcome.is_ok());

        let (mut saw_started, mut saw_stopped) = (false, false);
        drain_events(&mut event_rx, |event| match event {
            TaskEvent::Started { task_name } => {
                assert_eq!(task_name, "sleep_with_timeout_task");
                saw_started = true;
            }
            TaskEvent::Stopped {
                task_name,
                exit_code,
                reason,
            } => {
                assert_eq!(task_name, "sleep_with_timeout_task");
                assert_eq!(exit_code, None);
                assert_eq!(
                    reason,
                    TaskEventStopReason::Terminated(TaskTerminateReason::Timeout)
                );
                saw_stopped = true;
            }
            _ => {}
        })
        .await;

        assert!(saw_started);
        assert!(saw_stopped);
    }

    /// An empty command is rejected with `TaskError::InvalidConfiguration`.
    #[tokio::test]
    async fn start_direct_fn_invalid_empty_command() {
        // Bound to a named `_event_rx` (not `_`) so the receiver lives for the whole test.
        let (event_tx, _event_rx) = mpsc::channel::<TaskEvent>(1024);
        let config = TaskConfig::new("");
        let mut spawner = TaskSpawner::new(String::from("bad_task"), config);

        let outcome = spawner.start_direct(event_tx).await;
        assert!(matches!(outcome, Err(TaskError::InvalidConfiguration(_))));
    }

    /// With stdin enabled, a line sent through the stdin channel is echoed back on stdout.
    #[tokio::test]
    async fn start_direct_fn_stdin_valid() {
        let (event_tx, mut event_rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"])
            .enable_stdin(true);
        #[cfg(unix)]
        let config = TaskConfig::new("bash")
            .args(["-c", "read line; echo $line"])
            .enable_stdin(true);

        let mut spawner =
            TaskSpawner::new(String::from("stdin_task"), config).set_stdin(stdin_rx);
        let outcome = spawner.start_direct(event_tx).await;
        assert!(outcome.is_ok());

        stdin_tx.send("hello world".to_string()).await.unwrap();

        let (mut saw_started, mut saw_output, mut saw_stopped) = (false, false, false);
        drain_events(&mut event_rx, |event| match event {
            TaskEvent::Started { task_name } => {
                assert_eq!(task_name, "stdin_task");
                saw_started = true;
            }
            TaskEvent::Output {
                task_name,
                line,
                src,
            } => {
                assert_eq!(task_name, "stdin_task");
                assert_eq!(line, "hello world");
                assert_eq!(src, StreamSource::Stdout);
                saw_output = true;
            }
            TaskEvent::Stopped {
                task_name,
                exit_code,
                ..
            } => {
                assert_eq!(task_name, "stdin_task");
                assert_eq!(exit_code, Some(0));
                saw_stopped = true;
            }
            _ => {}
        })
        .await;

        assert!(saw_started);
        assert!(saw_output);
        assert!(saw_stopped);
    }

    /// Without `enable_stdin`, the stdin receiver is dropped and no output events are produced.
    #[tokio::test]
    async fn start_direct_fn_stdin_ignore() {
        let (event_tx, mut event_rx) = mpsc::channel::<TaskEvent>(1024);
        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(1024);

        #[cfg(windows)]
        let config = TaskConfig::new("powershell")
            .args(["-Command", "$line = Read-Host; Write-Output $line"]);
        #[cfg(unix)]
        let config = TaskConfig::new("bash").args(["-c", "read line; echo $line"]);

        let mut spawner =
            TaskSpawner::new(String::from("stdin_task"), config).set_stdin(stdin_rx);
        let outcome = spawner.start_direct(event_tx).await;
        assert!(outcome.is_ok());

        let send_outcome = stdin_tx.send("hello world".to_string()).await;
        assert!(
            send_outcome.is_err(),
            "Sending to stdin_tx should error because receiver is dropped"
        );

        let (mut saw_started, mut saw_output, mut saw_stopped) = (false, false, false);
        drain_events(&mut event_rx, |event| match event {
            TaskEvent::Started { task_name } => {
                assert_eq!(task_name, "stdin_task");
                saw_started = true;
            }
            TaskEvent::Output { .. } => saw_output = true,
            TaskEvent::Stopped {
                task_name,
                exit_code,
                ..
            } => {
                assert_eq!(task_name, "stdin_task");
                assert_eq!(exit_code, Some(0));
                saw_stopped = true;
            }
            _ => {}
        })
        .await;

        assert!(saw_started);
        assert!(
            !saw_output,
            "Should not receive output from stdin when not enabled"
        );
        assert!(saw_stopped);
    }
}