Skip to main content

opencode/
client.rs

1use std::{
2    collections::BTreeMap,
3    path::PathBuf,
4    pin::Pin,
5    process::Stdio,
6    task::{Context, Poll},
7    time::{Duration, Instant},
8};
9
10use futures_core::Stream;
11use tokio::{
12    io::{AsyncBufReadExt, AsyncReadExt, BufReader},
13    sync::{mpsc, oneshot},
14};
15
16use crate::{
17    DynOpencodeRunJsonCompletion, DynOpencodeRunJsonEventStream, OpencodeError,
18    OpencodeRunCompletion, OpencodeRunJsonControlHandle, OpencodeRunJsonEvent,
19    OpencodeRunJsonHandle, OpencodeRunJsonParseError, OpencodeRunJsonParser, OpencodeRunRequest,
20    OpencodeTerminationHandle,
21};
22
/// Upper bound on how many stderr bytes are kept for failure classification (4 KiB);
/// anything beyond this is read and discarded.
const STDERR_CAPTURE_MAX_BYTES: usize = 4 * 1024;
/// Message carried by both the `TerminalError` event and the `RunFailed` error when the
/// child exits unsuccessfully for a reason other than session selection.
const RUN_FAILED_MESSAGE: &str = "opencode run failed";
25
/// Which way a run request picks an existing session; used to word "not found" failures.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SelectionMode {
    /// The request named a specific session id.
    Id,
    /// The request asked to continue (presumably the most recent session — see the
    /// request type to confirm).
    Last,
}
31
/// Client that launches the `opencode` binary and streams its JSON run output.
#[derive(Debug, Clone)]
pub struct OpencodeClient {
    /// Path (or bare name) of the executable to spawn.
    pub(crate) binary: PathBuf,
    /// Extra environment variables applied to every spawned process.
    pub(crate) env: BTreeMap<String, String>,
    /// Optional wall-clock limit for an entire run.
    pub(crate) timeout: Option<Duration>,
}
38
39impl OpencodeClient {
40    pub fn builder() -> crate::OpencodeClientBuilder {
41        crate::OpencodeClientBuilder::default()
42    }
43
44    pub async fn run_json(
45        &self,
46        request: OpencodeRunRequest,
47    ) -> Result<OpencodeRunJsonHandle, OpencodeError> {
48        let (events, completion, _termination) = self.spawn_run_json(request).await?;
49        Ok(OpencodeRunJsonHandle { events, completion })
50    }
51
52    pub async fn run_json_control(
53        &self,
54        request: OpencodeRunRequest,
55    ) -> Result<OpencodeRunJsonControlHandle, OpencodeError> {
56        let (events, completion, termination) = self.spawn_run_json(request).await?;
57        Ok(OpencodeRunJsonControlHandle {
58            events,
59            completion,
60            termination,
61        })
62    }
63
64    async fn spawn_run_json(
65        &self,
66        request: OpencodeRunRequest,
67    ) -> Result<
68        (
69            DynOpencodeRunJsonEventStream,
70            DynOpencodeRunJsonCompletion,
71            OpencodeTerminationHandle,
72        ),
73        OpencodeError,
74    > {
75        let selection_mode = selection_mode(&request);
76        let argv = request.argv()?;
77        let mut command = tokio::process::Command::new(&self.binary);
78        command
79            .args(argv)
80            .stdin(Stdio::null())
81            .stdout(Stdio::piped())
82            .stderr(Stdio::piped());
83
84        for (key, value) in &self.env {
85            command.env(key, value);
86        }
87
88        let mut child = command.spawn().map_err(|source| {
89            if source.kind() == std::io::ErrorKind::NotFound {
90                OpencodeError::MissingBinary
91            } else {
92                OpencodeError::Spawn {
93                    binary: self.binary.clone(),
94                    source,
95                }
96            }
97        })?;
98
99        let stdout = child.stdout.take().ok_or(OpencodeError::MissingStdout)?;
100        let stderr_capture = child
101            .stderr
102            .take()
103            .map(|stderr| tokio::spawn(async move { capture_stderr(stderr).await }));
104        let timeout = self.timeout;
105        let termination = OpencodeTerminationHandle::new();
106        let termination_for_runner = termination.clone();
107
108        let (events_tx, events_rx) = mpsc::channel(32);
109        let (completion_tx, completion_rx) = oneshot::channel();
110
111        tokio::spawn(async move {
112            let result = run_opencode_child(
113                child,
114                stdout,
115                stderr_capture,
116                events_tx,
117                timeout,
118                termination_for_runner,
119                selection_mode,
120            )
121            .await;
122            let _ = completion_tx.send(result);
123        });
124
125        let events: DynOpencodeRunJsonEventStream =
126            Box::pin(OpencodeRunJsonEventChannelStream::new(events_rx));
127
128        let completion: DynOpencodeRunJsonCompletion = Box::pin(async move {
129            completion_rx
130                .await
131                .map_err(|_| OpencodeError::Join("run-json task dropped".to_string()))?
132        });
133
134        Ok((events, completion, termination))
135    }
136}
137
138struct OpencodeRunJsonEventChannelStream {
139    rx: mpsc::Receiver<Result<OpencodeRunJsonEvent, OpencodeRunJsonParseError>>,
140}
141
142impl OpencodeRunJsonEventChannelStream {
143    fn new(rx: mpsc::Receiver<Result<OpencodeRunJsonEvent, OpencodeRunJsonParseError>>) -> Self {
144        Self { rx }
145    }
146}
147
148impl Stream for OpencodeRunJsonEventChannelStream {
149    type Item = Result<OpencodeRunJsonEvent, OpencodeRunJsonParseError>;
150
151    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
152        self.get_mut().rx.poll_recv(cx)
153    }
154}
155
156async fn run_opencode_child(
157    mut child: tokio::process::Child,
158    stdout: tokio::process::ChildStdout,
159    stderr_capture: Option<tokio::task::JoinHandle<Result<Vec<u8>, std::io::Error>>>,
160    events_tx: mpsc::Sender<Result<OpencodeRunJsonEvent, OpencodeRunJsonParseError>>,
161    timeout: Option<Duration>,
162    termination: OpencodeTerminationHandle,
163    selection_mode: Option<SelectionMode>,
164) -> Result<OpencodeRunCompletion, OpencodeError> {
165    let mut reader = BufReader::new(stdout);
166    let mut parser = OpencodeRunJsonParser::new();
167    let mut line = String::new();
168    let mut events_open = true;
169    let mut final_text = String::new();
170    let mut saw_finish = false;
171    let mut termination_requested = false;
172    let deadline = timeout.map(|value| Instant::now() + value);
173    let mut exit_status = None;
174
175    loop {
176        if let Some(deadline) = deadline {
177            if Instant::now() >= deadline {
178                match wait_for_child_exit(&mut child, timeout, Some(deadline)).await {
179                    Ok(ChildExit::Exited(status)) => {
180                        exit_status = Some(status);
181                        break;
182                    }
183                    Ok(ChildExit::TimedOut) => {
184                        let _ = consume_stderr_capture(stderr_capture).await;
185                        return Err(OpencodeError::Timeout {
186                            timeout: timeout.expect("deadline implies timeout"),
187                        });
188                    }
189                    Err(err) => return Err(err),
190                }
191            }
192        }
193
194        line.clear();
195        let read_result = if let Some(deadline) = deadline {
196            let remaining = deadline.saturating_duration_since(Instant::now());
197            tokio::select! {
198                _ = termination.requested() => {
199                    termination_requested = true;
200                    let _ = child.start_kill();
201                    break;
202                }
203                read = tokio::time::timeout(remaining, reader.read_line(&mut line)) => {
204                    match read {
205                        Ok(result) => result,
206                        Err(_) => {
207                            match wait_for_child_exit(&mut child, timeout, Some(deadline)).await {
208                                Ok(ChildExit::Exited(status)) => {
209                                    exit_status = Some(status);
210                                    break;
211                                }
212                                Ok(ChildExit::TimedOut) => {
213                                    let _ = consume_stderr_capture(stderr_capture).await;
214                                    return Err(OpencodeError::Timeout {
215                                        timeout: timeout.expect("deadline implies timeout"),
216                                    });
217                                }
218                                Err(err) => return Err(err),
219                            }
220                        }
221                    }
222                }
223            }
224        } else {
225            tokio::select! {
226                _ = termination.requested() => {
227                    termination_requested = true;
228                    let _ = child.start_kill();
229                    break;
230                }
231                read = reader.read_line(&mut line) => read,
232            }
233        };
234
235        let bytes = match read_result {
236            Ok(bytes) => bytes,
237            Err(err) => {
238                let _ = child.start_kill();
239                let _ = child.wait().await;
240                let _ = consume_stderr_capture(stderr_capture).await;
241                return Err(OpencodeError::StdoutRead(err));
242            }
243        };
244
245        if bytes == 0 {
246            break;
247        }
248
249        let parsed = parser.parse_line(line.trim_end_matches('\n'));
250        match parsed {
251            Ok(Some(event)) => {
252                if let OpencodeRunJsonEvent::Text { text, .. } = &event {
253                    final_text.push_str(text);
254                } else if matches!(event, OpencodeRunJsonEvent::StepFinish { .. }) {
255                    saw_finish = true;
256                }
257
258                if events_open && events_tx.send(Ok(event)).await.is_err() {
259                    events_open = false;
260                }
261            }
262            Ok(None) => {}
263            Err(error) => {
264                if events_open && events_tx.send(Err(error)).await.is_err() {
265                    events_open = false;
266                }
267            }
268        }
269    }
270
271    let status = match exit_status {
272        Some(status) => status,
273        None => match wait_for_child_exit(&mut child, timeout, deadline).await {
274            Ok(ChildExit::Exited(status)) => status,
275            Ok(ChildExit::TimedOut) => {
276                let _ = consume_stderr_capture(stderr_capture).await;
277                return Err(OpencodeError::Timeout {
278                    timeout: timeout.expect("deadline implies timeout"),
279                });
280            }
281            Err(err) => return Err(err),
282        },
283    };
284    let stderr = consume_stderr_capture(stderr_capture).await?;
285    if !status.success() {
286        if termination_requested {
287            drop(events_tx);
288            return Ok(OpencodeRunCompletion {
289                status,
290                final_text: None,
291            });
292        }
293        if let Some(message) = classify_selection_failure(&stderr, selection_mode) {
294            if events_open {
295                let _ = events_tx
296                    .send(Ok(OpencodeRunJsonEvent::TerminalError {
297                        message: message.clone(),
298                        raw: serde_json::Value::Null,
299                    }))
300                    .await;
301            }
302            drop(events_tx);
303            return Err(OpencodeError::SelectionFailed { message });
304        }
305        if events_open {
306            let _ = events_tx
307                .send(Ok(OpencodeRunJsonEvent::TerminalError {
308                    message: RUN_FAILED_MESSAGE.to_string(),
309                    raw: serde_json::Value::Null,
310                }))
311                .await;
312        }
313        drop(events_tx);
314        return Err(OpencodeError::RunFailed {
315            status,
316            message: RUN_FAILED_MESSAGE.to_string(),
317        });
318    }
319    drop(events_tx);
320
321    let final_text = (saw_finish && !final_text.is_empty()).then_some(final_text);
322
323    Ok(OpencodeRunCompletion { status, final_text })
324}
325
/// Outcome of `wait_for_child_exit`.
#[derive(Clone, Copy, Debug)]
enum ChildExit {
    /// The child exited by itself with this status.
    Exited(std::process::ExitStatus),
    /// The deadline passed first; the child was killed and reaped.
    TimedOut,
}
331
332async fn wait_for_child_exit(
333    child: &mut tokio::process::Child,
334    timeout: Option<Duration>,
335    deadline: Option<Instant>,
336) -> Result<ChildExit, OpencodeError> {
337    match deadline {
338        None => child
339            .wait()
340            .await
341            .map(ChildExit::Exited)
342            .map_err(OpencodeError::Wait),
343        Some(deadline) => {
344            let remaining = deadline.saturating_duration_since(Instant::now());
345            if remaining.is_zero() {
346                match child.try_wait().map_err(OpencodeError::Wait)? {
347                    Some(status) => Ok(ChildExit::Exited(status)),
348                    None => {
349                        timeout.expect("deadline implies timeout");
350                        let _ = child.start_kill();
351                        match child.wait().await {
352                            Ok(_status) => Ok(ChildExit::TimedOut),
353                            Err(err) => Err(OpencodeError::Wait(err)),
354                        }
355                    }
356                }
357            } else {
358                match tokio::time::timeout(remaining, child.wait()).await {
359                    Ok(result) => result.map(ChildExit::Exited).map_err(OpencodeError::Wait),
360                    Err(_) => match child.try_wait().map_err(OpencodeError::Wait)? {
361                        Some(status) => Ok(ChildExit::Exited(status)),
362                        None => {
363                            timeout.expect("deadline implies timeout");
364                            let _ = child.start_kill();
365                            match child.wait().await {
366                                Ok(_status) => Ok(ChildExit::TimedOut),
367                                Err(err) => Err(OpencodeError::Wait(err)),
368                            }
369                        }
370                    },
371                }
372            }
373        }
374    }
375}
376
377fn selection_mode(request: &OpencodeRunRequest) -> Option<SelectionMode> {
378    if request.session_id().is_some() {
379        Some(SelectionMode::Id)
380    } else if request.continue_requested() {
381        Some(SelectionMode::Last)
382    } else {
383        None
384    }
385}
386
387async fn capture_stderr(
388    mut stderr: tokio::process::ChildStderr,
389) -> Result<Vec<u8>, std::io::Error> {
390    let mut captured = Vec::new();
391    let mut buffer = [0u8; 1024];
392
393    loop {
394        let read = stderr.read(&mut buffer).await?;
395        if read == 0 {
396            break;
397        }
398
399        if captured.len() < STDERR_CAPTURE_MAX_BYTES {
400            let remaining = STDERR_CAPTURE_MAX_BYTES - captured.len();
401            captured.extend_from_slice(&buffer[..read.min(remaining)]);
402        }
403    }
404
405    Ok(captured)
406}
407
408async fn consume_stderr_capture(
409    stderr_capture: Option<tokio::task::JoinHandle<Result<Vec<u8>, std::io::Error>>>,
410) -> Result<String, OpencodeError> {
411    let Some(stderr_capture) = stderr_capture else {
412        return Ok(String::new());
413    };
414
415    let captured = stderr_capture
416        .await
417        .map_err(|err| OpencodeError::Join(format!("stderr capture task failed: {err}")))?
418        .map_err(OpencodeError::StderrRead)?;
419
420    Ok(String::from_utf8_lossy(&captured).into_owned())
421}
422
423fn classify_selection_failure(
424    stderr: &str,
425    selection_mode: Option<SelectionMode>,
426) -> Option<String> {
427    let selection_mode = selection_mode?;
428    let stderr = stderr.to_ascii_lowercase();
429
430    let saw_not_found = (stderr.contains("not found")
431        && (stderr.contains("session")
432            || stderr.contains("thread")
433            || stderr.contains("conversation")))
434        || stderr.contains("no session")
435        || stderr.contains("no sessions")
436        || stderr.contains("unknown session")
437        || stderr.contains("no thread")
438        || stderr.contains("no threads")
439        || stderr.contains("unknown thread")
440        || stderr.contains("no conversation")
441        || stderr.contains("unknown conversation");
442
443    if !saw_not_found {
444        return None;
445    }
446
447    Some(match selection_mode {
448        SelectionMode::Last => "no session found".to_string(),
449        SelectionMode::Id => "session not found".to_string(),
450    })
451}
452
#[cfg(test)]
mod tests {
    use std::{
        process::Stdio,
        time::{Duration, Instant},
    };

    use super::{wait_for_child_exit, ChildExit};

    /// A child that finished before the check must be reported as exited, even though the
    /// deadline passed in is already in the past.
    #[cfg(unix)]
    #[tokio::test]
    async fn wait_for_child_exit_returns_status_when_deadline_has_elapsed() {
        let mut child = tokio::process::Command::new("sh")
            .arg("-c")
            .arg("exit 0")
            .stdout(Stdio::null())
            .stderr(Stdio::null())
            .spawn()
            .expect("spawn child");
        // Give the shell time to exit before the elapsed deadline is evaluated.
        tokio::time::sleep(Duration::from_millis(50)).await;

        let elapsed_deadline = Instant::now();
        let outcome = wait_for_child_exit(
            &mut child,
            Some(Duration::from_millis(1)),
            Some(elapsed_deadline),
        )
        .await
        .expect("wait helper succeeds");

        let ChildExit::Exited(status) = outcome else {
            panic!("expected exited status, got timeout");
        };
        assert!(status.success());
    }
}