smol_workflow_engine/environment/
local.rs1use super::{
4 path_to_environment_path, AgentExecutionEnvironment, EnvironmentPath, ExecEvent, ExecEventSink,
5 ExecOutput, ExecRequest, SpawnOutput,
6};
7use anyhow::{anyhow, Context};
8use std::path::PathBuf;
9use std::process::Stdio;
10use std::sync::{Arc, Mutex};
11use tempfile::TempDir;
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
13use tokio::process::{Child, ChildStdin, Command};
14use tokio::sync::mpsc;
15use tokio::task::JoinHandle;
16
17#[derive(Debug, Clone)]
19pub struct LocalExecutionEnvironment {
20 cwd: Option<EnvironmentPath>,
21 state: Arc<LocalEnvironmentState>,
22}
23
24#[derive(Debug, Default)]
25struct LocalEnvironmentState {
26 temp_dirs: Mutex<Vec<TempDir>>,
27 spawned: Mutex<Vec<JoinHandle<()>>>,
28}
29
30impl Drop for LocalEnvironmentState {
31 fn drop(&mut self) {
32 if let Ok(mut tasks) = self.spawned.lock() {
33 for task in tasks.drain(..) {
34 task.abort();
35 }
36 }
37 }
40}
41
42impl LocalExecutionEnvironment {
43 pub fn new(cwd: Option<PathBuf>) -> anyhow::Result<Self> {
44 let cwd = cwd.map(path_to_environment_path).transpose()?;
45 Ok(Self {
46 cwd,
47 state: Arc::new(LocalEnvironmentState::default()),
48 })
49 }
50
51 pub fn with_cwd(cwd: impl Into<PathBuf>) -> anyhow::Result<Self> {
52 Self::new(Some(cwd.into()))
53 }
54
55 fn resolve_path(&self, path: &EnvironmentPath) -> PathBuf {
56 let path = PathBuf::from(path.as_str());
57 if path.is_absolute() {
58 path
59 } else if let Some(cwd) = &self.cwd {
60 PathBuf::from(cwd.as_str()).join(path)
61 } else {
62 path
63 }
64 }
65
66 fn request_cwd(&self, cwd: Option<&EnvironmentPath>) -> Option<PathBuf> {
67 cwd.map(|path| self.resolve_path(path))
68 .or_else(|| self.cwd.as_ref().map(|path| PathBuf::from(path.as_str())))
69 }
70}
71
72#[async_trait::async_trait]
73impl AgentExecutionEnvironment for LocalExecutionEnvironment {
74 fn cwd(&self) -> Option<&EnvironmentPath> {
75 self.cwd.as_ref()
76 }
77
78 async fn create_dir_all(&self, path: &EnvironmentPath) -> anyhow::Result<()> {
79 tokio::fs::create_dir_all(self.resolve_path(path))
80 .await
81 .with_context(|| format!("failed to create directory `{}`", path.as_str()))
82 }
83
84 async fn write_file(&self, path: &EnvironmentPath, content: &[u8]) -> anyhow::Result<()> {
85 tokio::fs::write(self.resolve_path(path), content)
86 .await
87 .with_context(|| format!("failed to write file `{}`", path.as_str()))
88 }
89
90 async fn read_file(&self, path: &EnvironmentPath) -> anyhow::Result<Vec<u8>> {
91 tokio::fs::read(self.resolve_path(path))
92 .await
93 .with_context(|| format!("failed to read file `{}`", path.as_str()))
94 }
95
96 async fn remove(&self, path: &EnvironmentPath) -> anyhow::Result<()> {
97 let resolved = self.resolve_path(path);
98 let metadata = match tokio::fs::symlink_metadata(&resolved).await {
99 Ok(metadata) => metadata,
100 Err(error) if error.kind() == std::io::ErrorKind::NotFound => return Ok(()),
101 Err(error) => {
102 return Err(error)
103 .with_context(|| format!("failed to inspect path `{}`", path.as_str()))
104 }
105 };
106
107 if metadata.is_dir() {
108 tokio::fs::remove_dir_all(&resolved)
109 .await
110 .with_context(|| format!("failed to remove directory `{}`", path.as_str()))
111 } else {
112 tokio::fs::remove_file(&resolved)
113 .await
114 .with_context(|| format!("failed to remove file `{}`", path.as_str()))
115 }
116 }
117
118 async fn create_temp_dir(&self, prefix: &str) -> anyhow::Result<EnvironmentPath> {
119 let temp_dir = tempfile::Builder::new()
120 .prefix(prefix)
121 .tempdir()
122 .with_context(|| format!("failed to create temp directory with prefix `{prefix}`"))?;
123 let path = path_to_environment_path(temp_dir.path())?;
124 self.state
125 .temp_dirs
126 .lock()
127 .map_err(|_| anyhow!("local environment temp-dir lock poisoned"))?
128 .push(temp_dir);
129 Ok(path)
130 }
131
132 async fn exec(
133 &self,
134 request: ExecRequest,
135 sink: &mut dyn ExecEventSink,
136 ) -> anyhow::Result<ExecOutput> {
137 let (command, args) = split_argv(&request.argv)?;
138 let mut command_builder = Command::new(command);
139 command_builder.args(args);
140 if let Some(cwd) = self.request_cwd(request.cwd.as_ref()) {
141 command_builder.current_dir(cwd);
142 }
143 command_builder.envs(&request.env);
144 command_builder
145 .stdout(Stdio::piped())
146 .stderr(Stdio::piped());
147 if request.stdin.is_some() {
148 command_builder.stdin(Stdio::piped());
149 } else {
150 command_builder.stdin(Stdio::null());
151 }
152
153 let mut child = command_builder
154 .kill_on_drop(true)
155 .spawn()
156 .with_context(|| format!("failed to spawn command `{}`", request.argv[0]))?;
157 let process_id = child.id().map(|id| id.to_string());
158 sink.event(ExecEvent::Started { process_id }).await?;
159
160 let stdin_task = spawn_stdin_writer(child.stdin.take(), request.stdin);
161 let stdout = child.stdout.take();
162 let stderr = child.stderr.take();
163 let (event_tx, mut event_rx) = mpsc::channel::<PipeEvent>(32);
164 spawn_pipe_reader(stdout, PipeKind::Stdout, event_tx.clone());
165 spawn_pipe_reader(stderr, PipeKind::Stderr, event_tx.clone());
166 drop(event_tx);
167
168 let wait = child.wait();
169 tokio::pin!(wait);
170 let mut stdout_acc = Vec::new();
171 let mut stderr_acc = Vec::new();
172 let mut exit_code = None;
173 let mut pipes_open = true;
174
175 while exit_code.is_none() || pipes_open {
176 tokio::select! {
177 status = &mut wait, if exit_code.is_none() => {
178 let status = status.context("failed to wait for command")?;
179 exit_code = Some(status.code().unwrap_or(-1));
180 }
181 event = event_rx.recv(), if pipes_open => {
182 match event {
183 Some(PipeEvent::Stdout(chunk)) => {
184 stdout_acc.extend_from_slice(&chunk);
185 sink.event(ExecEvent::Stdout { chunk }).await?;
186 }
187 Some(PipeEvent::Stderr(chunk)) => {
188 stderr_acc.extend_from_slice(&chunk);
189 sink.event(ExecEvent::Stderr { chunk }).await?;
190 }
191 None => pipes_open = false,
192 }
193 }
194 }
195 }
196
197 await_stdin_writer(stdin_task).await?;
198 let exit_code = exit_code.unwrap_or(-1);
199 sink.event(ExecEvent::Exited { exit_code }).await?;
200 Ok(ExecOutput {
201 exit_code,
202 stdout: stdout_acc,
203 stderr: stderr_acc,
204 })
205 }
206
207 async fn spawn(
208 &self,
209 request: ExecRequest,
210 sink: Option<Box<dyn ExecEventSink>>,
211 ) -> anyhow::Result<SpawnOutput> {
212 let (command, args) = split_argv(&request.argv)?;
213 let mut command_builder = Command::new(command);
214 command_builder.args(args);
215 if let Some(cwd) = self.request_cwd(request.cwd.as_ref()) {
216 command_builder.current_dir(cwd);
217 }
218 command_builder.envs(&request.env);
219 command_builder.kill_on_drop(true);
220 if request.stdin.is_some() {
221 command_builder.stdin(Stdio::piped());
222 } else {
223 command_builder.stdin(Stdio::null());
224 }
225 command_builder
226 .stdout(Stdio::piped())
227 .stderr(Stdio::piped());
228
229 let mut child = command_builder
230 .spawn()
231 .with_context(|| format!("failed to spawn command `{}`", request.argv[0]))?;
232 let stdin_task = spawn_stdin_writer(child.stdin.take(), request.stdin);
233 self.track_spawned_child(child, sink, stdin_task).await
234 }
235}
236
237impl LocalExecutionEnvironment {
238 async fn track_spawned_child(
239 &self,
240 mut child: Child,
241 mut sink: Option<Box<dyn ExecEventSink>>,
242 stdin_task: Option<JoinHandle<anyhow::Result<()>>>,
243 ) -> anyhow::Result<SpawnOutput> {
244 let process_id = child.id().map(|id| id.to_string());
245 if let Some(sink) = sink.as_mut() {
246 sink.event(ExecEvent::Started {
247 process_id: process_id.clone(),
248 })
249 .await?;
250 }
251
252 let stdout = child.stdout.take();
253 let stderr = child.stderr.take();
254 let task = tokio::spawn(async move {
255 let (event_tx, mut event_rx) = mpsc::channel::<PipeEvent>(32);
256 spawn_pipe_reader(stdout, PipeKind::Stdout, event_tx.clone());
257 spawn_pipe_reader(stderr, PipeKind::Stderr, event_tx.clone());
258 drop(event_tx);
259
260 let wait = child.wait();
261 tokio::pin!(wait);
262 let mut exit_code = None;
263 let mut pipes_open = true;
264
265 while exit_code.is_none() || pipes_open {
266 tokio::select! {
267 status = &mut wait, if exit_code.is_none() => {
268 exit_code = status.ok().map(|status| status.code().unwrap_or(-1));
269 }
270 event = event_rx.recv(), if pipes_open => {
271 match event {
272 Some(PipeEvent::Stdout(chunk)) => {
273 let failed = if let Some(sink_ref) = sink.as_mut() {
274 sink_ref.event(ExecEvent::Stdout { chunk }).await.is_err()
275 } else {
276 false
277 };
278 if failed {
279 sink = None;
280 }
281 }
282 Some(PipeEvent::Stderr(chunk)) => {
283 let failed = if let Some(sink_ref) = sink.as_mut() {
284 sink_ref.event(ExecEvent::Stderr { chunk }).await.is_err()
285 } else {
286 false
287 };
288 if failed {
289 sink = None;
290 }
291 }
292 None => pipes_open = false,
293 }
294 }
295 }
296 }
297
298 let _ = await_stdin_writer(stdin_task).await;
299 if let Some(sink) = sink.as_mut() {
300 let _ = sink
301 .event(ExecEvent::Exited {
302 exit_code: exit_code.unwrap_or(-1),
303 })
304 .await;
305 }
306 });
307
308 self.state
309 .spawned
310 .lock()
311 .map_err(|_| anyhow!("local environment spawned-process lock poisoned"))?
312 .push(task);
313 Ok(SpawnOutput { process_id })
314 }
315}
316
/// A chunk of bytes read from one of the child's output pipes, tagged with its source.
#[derive(Debug)]
enum PipeEvent {
    /// Bytes read from the child's stdout.
    Stdout(Vec<u8>),
    /// Bytes read from the child's stderr.
    Stderr(Vec<u8>),
}
322
/// Identifies which output pipe a reader task is attached to.
#[derive(Debug, Clone, Copy)]
enum PipeKind {
    /// The child's standard output.
    Stdout,
    /// The child's standard error.
    Stderr,
}
328
329fn spawn_stdin_writer(
330 child_stdin: Option<ChildStdin>,
331 stdin: Option<Vec<u8>>,
332) -> Option<JoinHandle<anyhow::Result<()>>> {
333 match (child_stdin, stdin) {
334 (Some(mut child_stdin), Some(stdin)) => Some(tokio::spawn(async move {
335 child_stdin
336 .write_all(&stdin)
337 .await
338 .context("failed to write command stdin")?;
339 Ok(())
340 })),
341 (None, Some(_)) => Some(tokio::spawn(async {
342 Err(anyhow!("failed to open command stdin"))
343 })),
344 _ => None,
345 }
346}
347
348async fn await_stdin_writer(task: Option<JoinHandle<anyhow::Result<()>>>) -> anyhow::Result<()> {
349 if let Some(task) = task {
350 task.await
351 .context("stdin writer task failed to complete")??;
352 }
353 Ok(())
354}
355
356fn spawn_pipe_reader<R>(reader: Option<R>, kind: PipeKind, event_tx: mpsc::Sender<PipeEvent>)
357where
358 R: AsyncRead + Unpin + Send + 'static,
359{
360 if let Some(reader) = reader {
361 tokio::spawn(async move {
362 read_pipe(reader, kind, event_tx).await;
363 });
364 }
365}
366
367async fn read_pipe<R>(mut reader: R, kind: PipeKind, event_tx: mpsc::Sender<PipeEvent>)
368where
369 R: AsyncRead + Unpin,
370{
371 let mut buffer = vec![0u8; 8192];
372 loop {
373 match reader.read(&mut buffer).await {
374 Ok(0) | Err(_) => break,
375 Ok(n) => {
376 let chunk = buffer[..n].to_vec();
377 let event = match kind {
378 PipeKind::Stdout => PipeEvent::Stdout(chunk),
379 PipeKind::Stderr => PipeEvent::Stderr(chunk),
380 };
381 if event_tx.send(event).await.is_err() {
382 break;
383 }
384 }
385 }
386 }
387}
388
389fn split_argv(argv: &[String]) -> anyhow::Result<(&str, &[String])> {
390 let Some((command, args)) = argv.split_first() else {
391 return Err(anyhow!("ExecRequest.argv must not be empty"));
392 };
393 Ok((command.as_str(), args))
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex as StdMutex;
    use tokio::time::{sleep, timeout, Duration, Instant};

    /// Builds an argv that runs `script` through `sh -c`.
    fn sh(script: &str) -> Vec<String> {
        vec!["sh".to_string(), "-c".to_string(), script.to_string()]
    }

    /// True when at least one recorded event satisfies `predicate`.
    fn any_event(events: &[ExecEvent], predicate: impl Fn(&ExecEvent) -> bool) -> bool {
        events.iter().any(|event| predicate(event))
    }

    /// Sink that keeps every event, for inspection after `exec` returns.
    #[derive(Default)]
    struct RecordingSink {
        events: Vec<ExecEvent>,
    }

    #[async_trait::async_trait]
    impl ExecEventSink for RecordingSink {
        async fn event(&mut self, event: ExecEvent) -> anyhow::Result<()> {
            self.events.push(event);
            Ok(())
        }
    }

    /// Sink whose event log stays readable after the sink is moved into `spawn`.
    #[derive(Clone, Default)]
    struct SharedRecordingSink {
        events: Arc<StdMutex<Vec<ExecEvent>>>,
    }

    #[async_trait::async_trait]
    impl ExecEventSink for SharedRecordingSink {
        async fn event(&mut self, event: ExecEvent) -> anyhow::Result<()> {
            self.events.lock().unwrap().push(event);
            Ok(())
        }
    }

    #[tokio::test]
    async fn local_environment_file_io_uses_relative_cwd() {
        let root = tempfile::tempdir().unwrap();
        let env = LocalExecutionEnvironment::with_cwd(root.path()).unwrap();
        let nested = EnvironmentPath::from("nested");
        let file = EnvironmentPath::from("nested/data.bin");
        let payload: &[u8] = b"hello\0world";

        env.create_dir_all(&nested).await.unwrap();
        env.write_file(&file, payload).await.unwrap();
        assert_eq!(env.read_file(&file).await.unwrap(), payload);

        env.remove(&nested).await.unwrap();
        assert!(!root.path().join("nested").exists());
        // Removing a path that does not exist is a no-op.
        env.remove(&EnvironmentPath::from("missing")).await.unwrap();
    }

    #[tokio::test]
    async fn local_environment_exec_streams_and_accumulates_bytes() {
        let env = LocalExecutionEnvironment::new(None).unwrap();
        let mut sink = RecordingSink::default();
        let request = ExecRequest {
            argv: sh("cat; printf err >&2"),
            stdin: Some(b"out".to_vec()),
            ..Default::default()
        };
        let output = env.exec(request, &mut sink).await.unwrap();

        assert_eq!(output.exit_code, 0);
        assert_eq!(output.stdout, b"out");
        assert_eq!(output.stderr, b"err");

        let events = &sink.events;
        assert!(any_event(events, |event| matches!(
            event,
            ExecEvent::Started { .. }
        )));
        assert!(any_event(events, |event| matches!(
            event,
            ExecEvent::Stdout { chunk } if chunk == b"out"
        )));
        assert!(any_event(events, |event| matches!(
            event,
            ExecEvent::Stderr { chunk } if chunk == b"err"
        )));
        assert!(any_event(events, |event| matches!(
            event,
            ExecEvent::Exited { exit_code: 0 }
        )));
    }

    #[tokio::test]
    async fn local_environment_exec_reads_output_while_writing_large_stdin() {
        let env = LocalExecutionEnvironment::new(None).unwrap();
        let mut sink = RecordingSink::default();
        // Much larger than a typical pipe buffer, so stdin and stdout must be serviced
        // concurrently for this to finish.
        let request = ExecRequest {
            argv: sh("printf ready; cat >/dev/null"),
            stdin: Some(vec![b'x'; 2 * 1024 * 1024]),
            ..Default::default()
        };
        let output = timeout(Duration::from_secs(5), env.exec(request, &mut sink))
            .await
            .expect("exec should not deadlock")
            .unwrap();

        assert_eq!(output.exit_code, 0);
        assert_eq!(output.stdout, b"ready");
    }

    #[tokio::test]
    async fn local_environment_spawn_reaps_and_emits_exit() {
        let env = LocalExecutionEnvironment::new(None).unwrap();
        let sink = SharedRecordingSink::default();
        let events = Arc::clone(&sink.events);
        let request = ExecRequest {
            argv: sh("printf spawned; printf err >&2"),
            ..Default::default()
        };
        let output = env.spawn(request, Some(Box::new(sink))).await.unwrap();
        assert!(output.process_id.is_some());

        // Poll until the background task reports the exit, giving up after 5 seconds.
        let deadline = Instant::now() + Duration::from_secs(5);
        let snapshot = loop {
            let snapshot = events.lock().unwrap().clone();
            if any_event(&snapshot, |event| {
                matches!(event, ExecEvent::Exited { exit_code: 0 })
            }) {
                break snapshot;
            }
            assert!(Instant::now() < deadline);
            sleep(Duration::from_millis(10)).await;
        };

        assert!(any_event(&snapshot, |event| matches!(
            event,
            ExecEvent::Started { .. }
        )));
        assert!(any_event(&snapshot, |event| matches!(
            event,
            ExecEvent::Stdout { chunk } if chunk == b"spawned"
        )));
        assert!(any_event(&snapshot, |event| matches!(
            event,
            ExecEvent::Stderr { chunk } if chunk == b"err"
        )));
    }

    #[tokio::test]
    async fn local_environment_create_temp_dir_cleans_up_on_drop() {
        let temp_path = {
            let env = LocalExecutionEnvironment::new(None).unwrap();
            let created = env.create_temp_dir("smol-wf-test-").await.unwrap();
            let on_disk = PathBuf::from(created.as_str());
            assert!(on_disk.exists());
            on_disk
        };
        // `env` has been dropped, so its temp directory must be gone too.
        assert!(!temp_path.exists());
    }
}