//! Local in-process execution environment implementation.
2
3use super::{
4    path_to_environment_path, AgentExecutionEnvironment, EnvironmentPath, ExecEvent, ExecEventSink,
5    ExecOutput, ExecRequest, SpawnOutput,
6};
7use anyhow::{anyhow, Context};
8use std::path::PathBuf;
9use std::process::Stdio;
10use std::sync::{Arc, Mutex};
11use tempfile::TempDir;
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
13use tokio::process::{Child, ChildStdin, Command};
14use tokio::sync::mpsc;
15use tokio::task::JoinHandle;
16
17/// Local, in-process environment implementation.
18#[derive(Debug, Clone)]
19pub struct LocalExecutionEnvironment {
20    cwd: Option<EnvironmentPath>,
21    state: Arc<LocalEnvironmentState>,
22}
23
24#[derive(Debug, Default)]
25struct LocalEnvironmentState {
26    temp_dirs: Mutex<Vec<TempDir>>,
27    spawned: Mutex<Vec<JoinHandle<()>>>,
28}
29
30impl Drop for LocalEnvironmentState {
31    fn drop(&mut self) {
32        if let Ok(mut tasks) = self.spawned.lock() {
33            for task in tasks.drain(..) {
34                task.abort();
35            }
36        }
37        // TempDir cleanup happens through Drop after this state is dropped.
38        // Aborting spawned-process tasks drops their kill_on_drop children.
39    }
40}
41
42impl LocalExecutionEnvironment {
43    pub fn new(cwd: Option<PathBuf>) -> anyhow::Result<Self> {
44        let cwd = cwd.map(path_to_environment_path).transpose()?;
45        Ok(Self {
46            cwd,
47            state: Arc::new(LocalEnvironmentState::default()),
48        })
49    }
50
51    pub fn with_cwd(cwd: impl Into<PathBuf>) -> anyhow::Result<Self> {
52        Self::new(Some(cwd.into()))
53    }
54
55    fn resolve_path(&self, path: &EnvironmentPath) -> PathBuf {
56        let path = PathBuf::from(path.as_str());
57        if path.is_absolute() {
58            path
59        } else if let Some(cwd) = &self.cwd {
60            PathBuf::from(cwd.as_str()).join(path)
61        } else {
62            path
63        }
64    }
65
66    fn request_cwd(&self, cwd: Option<&EnvironmentPath>) -> Option<PathBuf> {
67        cwd.map(|path| self.resolve_path(path))
68            .or_else(|| self.cwd.as_ref().map(|path| PathBuf::from(path.as_str())))
69    }
70}
71
72#[async_trait::async_trait]
73impl AgentExecutionEnvironment for LocalExecutionEnvironment {
74    fn cwd(&self) -> Option<&EnvironmentPath> {
75        self.cwd.as_ref()
76    }
77
78    async fn create_dir_all(&self, path: &EnvironmentPath) -> anyhow::Result<()> {
79        tokio::fs::create_dir_all(self.resolve_path(path))
80            .await
81            .with_context(|| format!("failed to create directory `{}`", path.as_str()))
82    }
83
84    async fn write_file(&self, path: &EnvironmentPath, content: &[u8]) -> anyhow::Result<()> {
85        tokio::fs::write(self.resolve_path(path), content)
86            .await
87            .with_context(|| format!("failed to write file `{}`", path.as_str()))
88    }
89
90    async fn read_file(&self, path: &EnvironmentPath) -> anyhow::Result<Vec<u8>> {
91        tokio::fs::read(self.resolve_path(path))
92            .await
93            .with_context(|| format!("failed to read file `{}`", path.as_str()))
94    }
95
96    async fn remove(&self, path: &EnvironmentPath) -> anyhow::Result<()> {
97        let resolved = self.resolve_path(path);
98        let metadata = match tokio::fs::symlink_metadata(&resolved).await {
99            Ok(metadata) => metadata,
100            Err(error) if error.kind() == std::io::ErrorKind::NotFound => return Ok(()),
101            Err(error) => {
102                return Err(error)
103                    .with_context(|| format!("failed to inspect path `{}`", path.as_str()))
104            }
105        };
106
107        if metadata.is_dir() {
108            tokio::fs::remove_dir_all(&resolved)
109                .await
110                .with_context(|| format!("failed to remove directory `{}`", path.as_str()))
111        } else {
112            tokio::fs::remove_file(&resolved)
113                .await
114                .with_context(|| format!("failed to remove file `{}`", path.as_str()))
115        }
116    }
117
118    async fn create_temp_dir(&self, prefix: &str) -> anyhow::Result<EnvironmentPath> {
119        let temp_dir = tempfile::Builder::new()
120            .prefix(prefix)
121            .tempdir()
122            .with_context(|| format!("failed to create temp directory with prefix `{prefix}`"))?;
123        let path = path_to_environment_path(temp_dir.path())?;
124        self.state
125            .temp_dirs
126            .lock()
127            .map_err(|_| anyhow!("local environment temp-dir lock poisoned"))?
128            .push(temp_dir);
129        Ok(path)
130    }
131
132    async fn exec(
133        &self,
134        request: ExecRequest,
135        sink: &mut dyn ExecEventSink,
136    ) -> anyhow::Result<ExecOutput> {
137        let (command, args) = split_argv(&request.argv)?;
138        let mut command_builder = Command::new(command);
139        command_builder.args(args);
140        if let Some(cwd) = self.request_cwd(request.cwd.as_ref()) {
141            command_builder.current_dir(cwd);
142        }
143        command_builder.envs(&request.env);
144        command_builder
145            .stdout(Stdio::piped())
146            .stderr(Stdio::piped());
147        if request.stdin.is_some() {
148            command_builder.stdin(Stdio::piped());
149        } else {
150            command_builder.stdin(Stdio::null());
151        }
152
153        let mut child = command_builder
154            .kill_on_drop(true)
155            .spawn()
156            .with_context(|| format!("failed to spawn command `{}`", request.argv[0]))?;
157        let process_id = child.id().map(|id| id.to_string());
158        sink.event(ExecEvent::Started { process_id }).await?;
159
160        let stdin_task = spawn_stdin_writer(child.stdin.take(), request.stdin);
161        let stdout = child.stdout.take();
162        let stderr = child.stderr.take();
163        let (event_tx, mut event_rx) = mpsc::channel::<PipeEvent>(32);
164        spawn_pipe_reader(stdout, PipeKind::Stdout, event_tx.clone());
165        spawn_pipe_reader(stderr, PipeKind::Stderr, event_tx.clone());
166        drop(event_tx);
167
168        let wait = child.wait();
169        tokio::pin!(wait);
170        let mut stdout_acc = Vec::new();
171        let mut stderr_acc = Vec::new();
172        let mut exit_code = None;
173        let mut pipes_open = true;
174
175        while exit_code.is_none() || pipes_open {
176            tokio::select! {
177                status = &mut wait, if exit_code.is_none() => {
178                    let status = status.context("failed to wait for command")?;
179                    exit_code = Some(status.code().unwrap_or(-1));
180                }
181                event = event_rx.recv(), if pipes_open => {
182                    match event {
183                        Some(PipeEvent::Stdout(chunk)) => {
184                            stdout_acc.extend_from_slice(&chunk);
185                            sink.event(ExecEvent::Stdout { chunk }).await?;
186                        }
187                        Some(PipeEvent::Stderr(chunk)) => {
188                            stderr_acc.extend_from_slice(&chunk);
189                            sink.event(ExecEvent::Stderr { chunk }).await?;
190                        }
191                        None => pipes_open = false,
192                    }
193                }
194            }
195        }
196
197        await_stdin_writer(stdin_task).await?;
198        let exit_code = exit_code.unwrap_or(-1);
199        sink.event(ExecEvent::Exited { exit_code }).await?;
200        Ok(ExecOutput {
201            exit_code,
202            stdout: stdout_acc,
203            stderr: stderr_acc,
204        })
205    }
206
207    async fn spawn(
208        &self,
209        request: ExecRequest,
210        sink: Option<Box<dyn ExecEventSink>>,
211    ) -> anyhow::Result<SpawnOutput> {
212        let (command, args) = split_argv(&request.argv)?;
213        let mut command_builder = Command::new(command);
214        command_builder.args(args);
215        if let Some(cwd) = self.request_cwd(request.cwd.as_ref()) {
216            command_builder.current_dir(cwd);
217        }
218        command_builder.envs(&request.env);
219        command_builder.kill_on_drop(true);
220        if request.stdin.is_some() {
221            command_builder.stdin(Stdio::piped());
222        } else {
223            command_builder.stdin(Stdio::null());
224        }
225        command_builder
226            .stdout(Stdio::piped())
227            .stderr(Stdio::piped());
228
229        let mut child = command_builder
230            .spawn()
231            .with_context(|| format!("failed to spawn command `{}`", request.argv[0]))?;
232        let stdin_task = spawn_stdin_writer(child.stdin.take(), request.stdin);
233        self.track_spawned_child(child, sink, stdin_task).await
234    }
235}
236
237impl LocalExecutionEnvironment {
238    async fn track_spawned_child(
239        &self,
240        mut child: Child,
241        mut sink: Option<Box<dyn ExecEventSink>>,
242        stdin_task: Option<JoinHandle<anyhow::Result<()>>>,
243    ) -> anyhow::Result<SpawnOutput> {
244        let process_id = child.id().map(|id| id.to_string());
245        if let Some(sink) = sink.as_mut() {
246            sink.event(ExecEvent::Started {
247                process_id: process_id.clone(),
248            })
249            .await?;
250        }
251
252        let stdout = child.stdout.take();
253        let stderr = child.stderr.take();
254        let task = tokio::spawn(async move {
255            let (event_tx, mut event_rx) = mpsc::channel::<PipeEvent>(32);
256            spawn_pipe_reader(stdout, PipeKind::Stdout, event_tx.clone());
257            spawn_pipe_reader(stderr, PipeKind::Stderr, event_tx.clone());
258            drop(event_tx);
259
260            let wait = child.wait();
261            tokio::pin!(wait);
262            let mut exit_code = None;
263            let mut pipes_open = true;
264
265            while exit_code.is_none() || pipes_open {
266                tokio::select! {
267                    status = &mut wait, if exit_code.is_none() => {
268                        exit_code = status.ok().map(|status| status.code().unwrap_or(-1));
269                    }
270                    event = event_rx.recv(), if pipes_open => {
271                        match event {
272                            Some(PipeEvent::Stdout(chunk)) => {
273                                let failed = if let Some(sink_ref) = sink.as_mut() {
274                                    sink_ref.event(ExecEvent::Stdout { chunk }).await.is_err()
275                                } else {
276                                    false
277                                };
278                                if failed {
279                                    sink = None;
280                                }
281                            }
282                            Some(PipeEvent::Stderr(chunk)) => {
283                                let failed = if let Some(sink_ref) = sink.as_mut() {
284                                    sink_ref.event(ExecEvent::Stderr { chunk }).await.is_err()
285                                } else {
286                                    false
287                                };
288                                if failed {
289                                    sink = None;
290                                }
291                            }
292                            None => pipes_open = false,
293                        }
294                    }
295                }
296            }
297
298            let _ = await_stdin_writer(stdin_task).await;
299            if let Some(sink) = sink.as_mut() {
300                let _ = sink
301                    .event(ExecEvent::Exited {
302                        exit_code: exit_code.unwrap_or(-1),
303                    })
304                    .await;
305            }
306        });
307
308        self.state
309            .spawned
310            .lock()
311            .map_err(|_| anyhow!("local environment spawned-process lock poisoned"))?
312            .push(task);
313        Ok(SpawnOutput { process_id })
314    }
315}
316
/// A chunk of bytes read from one of the child's output pipes by a reader task.
#[derive(Debug)]
enum PipeEvent {
    /// Bytes read from the child's stdout.
    Stdout(Vec<u8>),
    /// Bytes read from the child's stderr.
    Stderr(Vec<u8>),
}
322
/// Identifies which output pipe a reader task drains; it selects the `PipeEvent`
/// variant that task emits.
#[derive(Clone, Copy, Debug)]
enum PipeKind {
    /// The child's standard output.
    Stdout,
    /// The child's standard error.
    Stderr,
}
328
329fn spawn_stdin_writer(
330    child_stdin: Option<ChildStdin>,
331    stdin: Option<Vec<u8>>,
332) -> Option<JoinHandle<anyhow::Result<()>>> {
333    match (child_stdin, stdin) {
334        (Some(mut child_stdin), Some(stdin)) => Some(tokio::spawn(async move {
335            child_stdin
336                .write_all(&stdin)
337                .await
338                .context("failed to write command stdin")?;
339            Ok(())
340        })),
341        (None, Some(_)) => Some(tokio::spawn(async {
342            Err(anyhow!("failed to open command stdin"))
343        })),
344        _ => None,
345    }
346}
347
348async fn await_stdin_writer(task: Option<JoinHandle<anyhow::Result<()>>>) -> anyhow::Result<()> {
349    if let Some(task) = task {
350        task.await
351            .context("stdin writer task failed to complete")??;
352    }
353    Ok(())
354}
355
356fn spawn_pipe_reader<R>(reader: Option<R>, kind: PipeKind, event_tx: mpsc::Sender<PipeEvent>)
357where
358    R: AsyncRead + Unpin + Send + 'static,
359{
360    if let Some(reader) = reader {
361        tokio::spawn(async move {
362            read_pipe(reader, kind, event_tx).await;
363        });
364    }
365}
366
367async fn read_pipe<R>(mut reader: R, kind: PipeKind, event_tx: mpsc::Sender<PipeEvent>)
368where
369    R: AsyncRead + Unpin,
370{
371    let mut buffer = vec![0u8; 8192];
372    loop {
373        match reader.read(&mut buffer).await {
374            Ok(0) | Err(_) => break,
375            Ok(n) => {
376                let chunk = buffer[..n].to_vec();
377                let event = match kind {
378                    PipeKind::Stdout => PipeEvent::Stdout(chunk),
379                    PipeKind::Stderr => PipeEvent::Stderr(chunk),
380                };
381                if event_tx.send(event).await.is_err() {
382                    break;
383                }
384            }
385        }
386    }
387}
388
389fn split_argv(argv: &[String]) -> anyhow::Result<(&str, &[String])> {
390    let Some((command, args)) = argv.split_first() else {
391        return Err(anyhow!("ExecRequest.argv must not be empty"));
392    };
393    Ok((command.as_str(), args))
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex as StdMutex;
    use tokio::time::{sleep, timeout, Duration, Instant};

    /// Sink that records every event it receives; used with `exec`'s borrowed sink.
    #[derive(Default)]
    struct RecordingSink {
        events: Vec<ExecEvent>,
    }

    #[async_trait::async_trait]
    impl ExecEventSink for RecordingSink {
        async fn event(&mut self, event: ExecEvent) -> anyhow::Result<()> {
            self.events.push(event);
            Ok(())
        }
    }

    /// Recording sink whose event log is shared through an `Arc`. The test can keep a
    /// handle to the log after boxing the sink for `spawn`, whose events arrive from a
    /// background task.
    #[derive(Clone, Default)]
    struct SharedRecordingSink {
        events: Arc<StdMutex<Vec<ExecEvent>>>,
    }

    #[async_trait::async_trait]
    impl ExecEventSink for SharedRecordingSink {
        async fn event(&mut self, event: ExecEvent) -> anyhow::Result<()> {
            self.events.lock().unwrap().push(event);
            Ok(())
        }
    }

    #[tokio::test]
    async fn local_environment_file_io_uses_relative_cwd() {
        let temp = tempfile::tempdir().unwrap();
        let env = LocalExecutionEnvironment::with_cwd(temp.path()).unwrap();
        let path = EnvironmentPath::from("nested/data.bin");

        env.create_dir_all(&EnvironmentPath::from("nested"))
            .await
            .unwrap();
        env.write_file(&path, b"hello\0world").await.unwrap();
        assert_eq!(env.read_file(&path).await.unwrap(), b"hello\0world");
        env.remove(&EnvironmentPath::from("nested")).await.unwrap();
        assert!(!temp.path().join("nested").exists());
        // Removing a path that does not exist is a no-op, not an error.
        env.remove(&EnvironmentPath::from("missing")).await.unwrap();
    }

    #[tokio::test]
    async fn local_environment_exec_streams_and_accumulates_bytes() {
        let env = LocalExecutionEnvironment::new(None).unwrap();
        let mut sink = RecordingSink::default();
        let output = env
            .exec(
                ExecRequest {
                    argv: vec!["sh".into(), "-c".into(), "cat; printf err >&2".into()],
                    stdin: Some(b"out".to_vec()),
                    ..Default::default()
                },
                &mut sink,
            )
            .await
            .unwrap();

        assert_eq!(output.exit_code, 0);
        assert_eq!(output.stdout, b"out");
        assert_eq!(output.stderr, b"err");
        assert!(sink
            .events
            .iter()
            .any(|event| matches!(event, ExecEvent::Started { .. })));
        assert!(sink
            .events
            .iter()
            .any(|event| matches!(event, ExecEvent::Stdout { chunk } if chunk == b"out")));
        assert!(sink
            .events
            .iter()
            .any(|event| matches!(event, ExecEvent::Stderr { chunk } if chunk == b"err")));
        assert!(sink
            .events
            .iter()
            .any(|event| matches!(event, ExecEvent::Exited { exit_code: 0 })));
    }

    /// 2 MiB of stdin exceeds any pipe buffer. This only completes if output is read
    /// concurrently with the stdin write.
    #[tokio::test]
    async fn local_environment_exec_reads_output_while_writing_large_stdin() {
        let env = LocalExecutionEnvironment::new(None).unwrap();
        let mut sink = RecordingSink::default();
        let stdin = vec![b'x'; 2 * 1024 * 1024];
        let output = timeout(
            Duration::from_secs(5),
            env.exec(
                ExecRequest {
                    argv: vec![
                        "sh".into(),
                        "-c".into(),
                        "printf ready; cat >/dev/null".into(),
                    ],
                    stdin: Some(stdin),
                    ..Default::default()
                },
                &mut sink,
            ),
        )
        .await
        .expect("exec should not deadlock")
        .unwrap();

        assert_eq!(output.exit_code, 0);
        assert_eq!(output.stdout, b"ready");
    }

    #[tokio::test]
    async fn local_environment_exec_rejects_empty_argv() {
        let env = LocalExecutionEnvironment::new(None).unwrap();
        let mut sink = RecordingSink::default();
        let result = env
            .exec(
                ExecRequest {
                    argv: Vec::new(),
                    ..Default::default()
                },
                &mut sink,
            )
            .await;

        let error = match result {
            Ok(_) => panic!("exec with an empty argv should fail"),
            Err(error) => error,
        };
        assert!(error.to_string().contains("argv must not be empty"));
        assert!(
            sink.events.is_empty(),
            "no events should be emitted when nothing was spawned"
        );
    }

    #[tokio::test]
    async fn local_environment_exec_reports_nonzero_exit_code() {
        let env = LocalExecutionEnvironment::new(None).unwrap();
        let mut sink = RecordingSink::default();
        let output = env
            .exec(
                ExecRequest {
                    argv: vec!["sh".into(), "-c".into(), "exit 3".into()],
                    ..Default::default()
                },
                &mut sink,
            )
            .await
            .unwrap();

        assert_eq!(output.exit_code, 3);
        assert!(output.stdout.is_empty());
        assert!(sink
            .events
            .iter()
            .any(|event| matches!(event, ExecEvent::Exited { exit_code: 3 })));
    }

    /// A relative request cwd is resolved against the environment's default cwd.
    #[tokio::test]
    async fn local_environment_exec_resolves_relative_request_cwd() {
        let temp = tempfile::tempdir().unwrap();
        std::fs::create_dir(temp.path().join("sub")).unwrap();
        let env = LocalExecutionEnvironment::with_cwd(temp.path()).unwrap();
        let mut sink = RecordingSink::default();
        let output = env
            .exec(
                ExecRequest {
                    argv: vec!["sh".into(), "-c".into(), "printf x > marker".into()],
                    cwd: Some(EnvironmentPath::from("sub")),
                    ..Default::default()
                },
                &mut sink,
            )
            .await
            .unwrap();

        assert_eq!(output.exit_code, 0);
        assert!(temp.path().join("sub").join("marker").exists());
    }

    /// `spawn` returns immediately; poll the shared log until `Exited` shows up.
    #[tokio::test]
    async fn local_environment_spawn_reaps_and_emits_exit() {
        let env = LocalExecutionEnvironment::new(None).unwrap();
        let sink = SharedRecordingSink::default();
        let events = Arc::clone(&sink.events);
        let output = env
            .spawn(
                ExecRequest {
                    argv: vec![
                        "sh".into(),
                        "-c".into(),
                        "printf spawned; printf err >&2".into(),
                    ],
                    ..Default::default()
                },
                Some(Box::new(sink)),
            )
            .await
            .unwrap();

        assert!(output.process_id.is_some());
        let started = Instant::now();
        loop {
            let snapshot = events.lock().unwrap().clone();
            if snapshot
                .iter()
                .any(|event| matches!(event, ExecEvent::Exited { exit_code: 0 }))
            {
                assert!(snapshot
                    .iter()
                    .any(|event| matches!(event, ExecEvent::Started { .. })));
                assert!(snapshot.iter().any(
                    |event| matches!(event, ExecEvent::Stdout { chunk } if chunk == b"spawned")
                ));
                assert!(snapshot
                    .iter()
                    .any(|event| matches!(event, ExecEvent::Stderr { chunk } if chunk == b"err")));
                break;
            }
            assert!(started.elapsed() < Duration::from_secs(5));
            sleep(Duration::from_millis(10)).await;
        }
    }

    #[tokio::test]
    async fn local_environment_create_temp_dir_cleans_up_on_drop() {
        let path = {
            let env = LocalExecutionEnvironment::new(None).unwrap();
            let path = env.create_temp_dir("smol-wf-test-").await.unwrap();
            let pathbuf = PathBuf::from(path.as_str());
            assert!(pathbuf.exists());
            pathbuf
        };
        // The environment (and with it the `TempDir`) was dropped at the end of the block.
        assert!(!path.exists());
    }
}