1use std::{
2 collections::BTreeMap,
3 path::PathBuf,
4 pin::Pin,
5 process::Stdio,
6 task::{Context, Poll},
7 time::{Duration, Instant},
8};
9
10use futures_core::Stream;
11use tokio::{
12 io::{AsyncBufReadExt, AsyncReadExt, BufReader},
13 sync::{mpsc, oneshot},
14};
15
16use crate::{
17 DynOpencodeRunJsonCompletion, DynOpencodeRunJsonEventStream, OpencodeError,
18 OpencodeRunCompletion, OpencodeRunJsonControlHandle, OpencodeRunJsonEvent,
19 OpencodeRunJsonHandle, OpencodeRunJsonParseError, OpencodeRunJsonParser, OpencodeRunRequest,
20 OpencodeTerminationHandle,
21};
22
/// Maximum number of stderr bytes retained from the child; anything beyond this is still read
/// (so the pipe keeps draining) but discarded.
const STDERR_CAPTURE_MAX_BYTES: usize = 4 * 1024;
/// Generic message reported when `opencode run` exits unsuccessfully and the failure is not
/// recognised as a session-selection problem.
const RUN_FAILED_MESSAGE: &str = "opencode run failed";
25
/// Which kind of existing-session selection a run request asked for.
///
/// Only used to word the error reported when stderr indicates the selected session is missing.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SelectionMode {
    /// The request asked to continue (`continue_requested()`) without naming a session id.
    Last,
    /// The request named an explicit session id (`session_id()` is `Some`).
    Id,
}
31
/// Client that launches the configured `opencode` binary for JSON runs.
///
/// Instances are normally created through [`OpencodeClient::builder`].
#[derive(Clone, Debug)]
pub struct OpencodeClient {
    /// Executable spawned for every run.
    pub(crate) binary: PathBuf,
    /// Extra environment variables set on each child, on top of the inherited environment.
    pub(crate) env: BTreeMap<String, String>,
    /// Optional overall deadline for a run; `None` means the run is never timed out.
    pub(crate) timeout: Option<Duration>,
}
38
39impl OpencodeClient {
40 pub fn builder() -> crate::OpencodeClientBuilder {
41 crate::OpencodeClientBuilder::default()
42 }
43
44 pub async fn run_json(
45 &self,
46 request: OpencodeRunRequest,
47 ) -> Result<OpencodeRunJsonHandle, OpencodeError> {
48 let (events, completion, _termination) = self.spawn_run_json(request).await?;
49 Ok(OpencodeRunJsonHandle { events, completion })
50 }
51
52 pub async fn run_json_control(
53 &self,
54 request: OpencodeRunRequest,
55 ) -> Result<OpencodeRunJsonControlHandle, OpencodeError> {
56 let (events, completion, termination) = self.spawn_run_json(request).await?;
57 Ok(OpencodeRunJsonControlHandle {
58 events,
59 completion,
60 termination,
61 })
62 }
63
64 async fn spawn_run_json(
65 &self,
66 request: OpencodeRunRequest,
67 ) -> Result<
68 (
69 DynOpencodeRunJsonEventStream,
70 DynOpencodeRunJsonCompletion,
71 OpencodeTerminationHandle,
72 ),
73 OpencodeError,
74 > {
75 let selection_mode = selection_mode(&request);
76 let argv = request.argv()?;
77 let mut command = tokio::process::Command::new(&self.binary);
78 command
79 .args(argv)
80 .stdin(Stdio::null())
81 .stdout(Stdio::piped())
82 .stderr(Stdio::piped());
83
84 for (key, value) in &self.env {
85 command.env(key, value);
86 }
87
88 let mut child = command.spawn().map_err(|source| {
89 if source.kind() == std::io::ErrorKind::NotFound {
90 OpencodeError::MissingBinary
91 } else {
92 OpencodeError::Spawn {
93 binary: self.binary.clone(),
94 source,
95 }
96 }
97 })?;
98
99 let stdout = child.stdout.take().ok_or(OpencodeError::MissingStdout)?;
100 let stderr_capture = child
101 .stderr
102 .take()
103 .map(|stderr| tokio::spawn(async move { capture_stderr(stderr).await }));
104 let timeout = self.timeout;
105 let termination = OpencodeTerminationHandle::new();
106 let termination_for_runner = termination.clone();
107
108 let (events_tx, events_rx) = mpsc::channel(32);
109 let (completion_tx, completion_rx) = oneshot::channel();
110
111 tokio::spawn(async move {
112 let result = run_opencode_child(
113 child,
114 stdout,
115 stderr_capture,
116 events_tx,
117 timeout,
118 termination_for_runner,
119 selection_mode,
120 )
121 .await;
122 let _ = completion_tx.send(result);
123 });
124
125 let events: DynOpencodeRunJsonEventStream =
126 Box::pin(OpencodeRunJsonEventChannelStream::new(events_rx));
127
128 let completion: DynOpencodeRunJsonCompletion = Box::pin(async move {
129 completion_rx
130 .await
131 .map_err(|_| OpencodeError::Join("run-json task dropped".to_string()))?
132 });
133
134 Ok((events, completion, termination))
135 }
136}
137
138struct OpencodeRunJsonEventChannelStream {
139 rx: mpsc::Receiver<Result<OpencodeRunJsonEvent, OpencodeRunJsonParseError>>,
140}
141
142impl OpencodeRunJsonEventChannelStream {
143 fn new(rx: mpsc::Receiver<Result<OpencodeRunJsonEvent, OpencodeRunJsonParseError>>) -> Self {
144 Self { rx }
145 }
146}
147
148impl Stream for OpencodeRunJsonEventChannelStream {
149 type Item = Result<OpencodeRunJsonEvent, OpencodeRunJsonParseError>;
150
151 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
152 self.get_mut().rx.poll_recv(cx)
153 }
154}
155
156async fn run_opencode_child(
157 mut child: tokio::process::Child,
158 stdout: tokio::process::ChildStdout,
159 stderr_capture: Option<tokio::task::JoinHandle<Result<Vec<u8>, std::io::Error>>>,
160 events_tx: mpsc::Sender<Result<OpencodeRunJsonEvent, OpencodeRunJsonParseError>>,
161 timeout: Option<Duration>,
162 termination: OpencodeTerminationHandle,
163 selection_mode: Option<SelectionMode>,
164) -> Result<OpencodeRunCompletion, OpencodeError> {
165 let mut reader = BufReader::new(stdout);
166 let mut parser = OpencodeRunJsonParser::new();
167 let mut line = String::new();
168 let mut events_open = true;
169 let mut final_text = String::new();
170 let mut saw_finish = false;
171 let mut termination_requested = false;
172 let deadline = timeout.map(|value| Instant::now() + value);
173 let mut exit_status = None;
174
175 loop {
176 if let Some(deadline) = deadline {
177 if Instant::now() >= deadline {
178 match wait_for_child_exit(&mut child, timeout, Some(deadline)).await {
179 Ok(ChildExit::Exited(status)) => {
180 exit_status = Some(status);
181 break;
182 }
183 Ok(ChildExit::TimedOut) => {
184 let _ = consume_stderr_capture(stderr_capture).await;
185 return Err(OpencodeError::Timeout {
186 timeout: timeout.expect("deadline implies timeout"),
187 });
188 }
189 Err(err) => return Err(err),
190 }
191 }
192 }
193
194 line.clear();
195 let read_result = if let Some(deadline) = deadline {
196 let remaining = deadline.saturating_duration_since(Instant::now());
197 tokio::select! {
198 _ = termination.requested() => {
199 termination_requested = true;
200 let _ = child.start_kill();
201 break;
202 }
203 read = tokio::time::timeout(remaining, reader.read_line(&mut line)) => {
204 match read {
205 Ok(result) => result,
206 Err(_) => {
207 match wait_for_child_exit(&mut child, timeout, Some(deadline)).await {
208 Ok(ChildExit::Exited(status)) => {
209 exit_status = Some(status);
210 break;
211 }
212 Ok(ChildExit::TimedOut) => {
213 let _ = consume_stderr_capture(stderr_capture).await;
214 return Err(OpencodeError::Timeout {
215 timeout: timeout.expect("deadline implies timeout"),
216 });
217 }
218 Err(err) => return Err(err),
219 }
220 }
221 }
222 }
223 }
224 } else {
225 tokio::select! {
226 _ = termination.requested() => {
227 termination_requested = true;
228 let _ = child.start_kill();
229 break;
230 }
231 read = reader.read_line(&mut line) => read,
232 }
233 };
234
235 let bytes = match read_result {
236 Ok(bytes) => bytes,
237 Err(err) => {
238 let _ = child.start_kill();
239 let _ = child.wait().await;
240 let _ = consume_stderr_capture(stderr_capture).await;
241 return Err(OpencodeError::StdoutRead(err));
242 }
243 };
244
245 if bytes == 0 {
246 break;
247 }
248
249 let parsed = parser.parse_line(line.trim_end_matches('\n'));
250 match parsed {
251 Ok(Some(event)) => {
252 if let OpencodeRunJsonEvent::Text { text, .. } = &event {
253 final_text.push_str(text);
254 } else if matches!(event, OpencodeRunJsonEvent::StepFinish { .. }) {
255 saw_finish = true;
256 }
257
258 if events_open && events_tx.send(Ok(event)).await.is_err() {
259 events_open = false;
260 }
261 }
262 Ok(None) => {}
263 Err(error) => {
264 if events_open && events_tx.send(Err(error)).await.is_err() {
265 events_open = false;
266 }
267 }
268 }
269 }
270
271 let status = match exit_status {
272 Some(status) => status,
273 None => match wait_for_child_exit(&mut child, timeout, deadline).await {
274 Ok(ChildExit::Exited(status)) => status,
275 Ok(ChildExit::TimedOut) => {
276 let _ = consume_stderr_capture(stderr_capture).await;
277 return Err(OpencodeError::Timeout {
278 timeout: timeout.expect("deadline implies timeout"),
279 });
280 }
281 Err(err) => return Err(err),
282 },
283 };
284 let stderr = consume_stderr_capture(stderr_capture).await?;
285 if !status.success() {
286 if termination_requested {
287 drop(events_tx);
288 return Ok(OpencodeRunCompletion {
289 status,
290 final_text: None,
291 });
292 }
293 if let Some(message) = classify_selection_failure(&stderr, selection_mode) {
294 if events_open {
295 let _ = events_tx
296 .send(Ok(OpencodeRunJsonEvent::TerminalError {
297 message: message.clone(),
298 raw: serde_json::Value::Null,
299 }))
300 .await;
301 }
302 drop(events_tx);
303 return Err(OpencodeError::SelectionFailed { message });
304 }
305 if events_open {
306 let _ = events_tx
307 .send(Ok(OpencodeRunJsonEvent::TerminalError {
308 message: RUN_FAILED_MESSAGE.to_string(),
309 raw: serde_json::Value::Null,
310 }))
311 .await;
312 }
313 drop(events_tx);
314 return Err(OpencodeError::RunFailed {
315 status,
316 message: RUN_FAILED_MESSAGE.to_string(),
317 });
318 }
319 drop(events_tx);
320
321 let final_text = (saw_finish && !final_text.is_empty()).then_some(final_text);
322
323 Ok(OpencodeRunCompletion { status, final_text })
324}
325
/// Outcome of waiting for the child with an optional deadline.
#[derive(Debug, Clone, Copy)]
enum ChildExit {
    /// The child exited on its own (or was already exited) with this status.
    Exited(std::process::ExitStatus),
    /// The deadline passed while the child was still running; it was killed and reaped.
    TimedOut,
}
331
332async fn wait_for_child_exit(
333 child: &mut tokio::process::Child,
334 timeout: Option<Duration>,
335 deadline: Option<Instant>,
336) -> Result<ChildExit, OpencodeError> {
337 match deadline {
338 None => child
339 .wait()
340 .await
341 .map(ChildExit::Exited)
342 .map_err(OpencodeError::Wait),
343 Some(deadline) => {
344 let remaining = deadline.saturating_duration_since(Instant::now());
345 if remaining.is_zero() {
346 match child.try_wait().map_err(OpencodeError::Wait)? {
347 Some(status) => Ok(ChildExit::Exited(status)),
348 None => {
349 timeout.expect("deadline implies timeout");
350 let _ = child.start_kill();
351 match child.wait().await {
352 Ok(_status) => Ok(ChildExit::TimedOut),
353 Err(err) => Err(OpencodeError::Wait(err)),
354 }
355 }
356 }
357 } else {
358 match tokio::time::timeout(remaining, child.wait()).await {
359 Ok(result) => result.map(ChildExit::Exited).map_err(OpencodeError::Wait),
360 Err(_) => match child.try_wait().map_err(OpencodeError::Wait)? {
361 Some(status) => Ok(ChildExit::Exited(status)),
362 None => {
363 timeout.expect("deadline implies timeout");
364 let _ = child.start_kill();
365 match child.wait().await {
366 Ok(_status) => Ok(ChildExit::TimedOut),
367 Err(err) => Err(OpencodeError::Wait(err)),
368 }
369 }
370 },
371 }
372 }
373 }
374 }
375}
376
377fn selection_mode(request: &OpencodeRunRequest) -> Option<SelectionMode> {
378 if request.session_id().is_some() {
379 Some(SelectionMode::Id)
380 } else if request.continue_requested() {
381 Some(SelectionMode::Last)
382 } else {
383 None
384 }
385}
386
387async fn capture_stderr(
388 mut stderr: tokio::process::ChildStderr,
389) -> Result<Vec<u8>, std::io::Error> {
390 let mut captured = Vec::new();
391 let mut buffer = [0u8; 1024];
392
393 loop {
394 let read = stderr.read(&mut buffer).await?;
395 if read == 0 {
396 break;
397 }
398
399 if captured.len() < STDERR_CAPTURE_MAX_BYTES {
400 let remaining = STDERR_CAPTURE_MAX_BYTES - captured.len();
401 captured.extend_from_slice(&buffer[..read.min(remaining)]);
402 }
403 }
404
405 Ok(captured)
406}
407
408async fn consume_stderr_capture(
409 stderr_capture: Option<tokio::task::JoinHandle<Result<Vec<u8>, std::io::Error>>>,
410) -> Result<String, OpencodeError> {
411 let Some(stderr_capture) = stderr_capture else {
412 return Ok(String::new());
413 };
414
415 let captured = stderr_capture
416 .await
417 .map_err(|err| OpencodeError::Join(format!("stderr capture task failed: {err}")))?
418 .map_err(OpencodeError::StderrRead)?;
419
420 Ok(String::from_utf8_lossy(&captured).into_owned())
421}
422
423fn classify_selection_failure(
424 stderr: &str,
425 selection_mode: Option<SelectionMode>,
426) -> Option<String> {
427 let selection_mode = selection_mode?;
428 let stderr = stderr.to_ascii_lowercase();
429
430 let saw_not_found = (stderr.contains("not found")
431 && (stderr.contains("session")
432 || stderr.contains("thread")
433 || stderr.contains("conversation")))
434 || stderr.contains("no session")
435 || stderr.contains("no sessions")
436 || stderr.contains("unknown session")
437 || stderr.contains("no thread")
438 || stderr.contains("no threads")
439 || stderr.contains("unknown thread")
440 || stderr.contains("no conversation")
441 || stderr.contains("unknown conversation");
442
443 if !saw_not_found {
444 return None;
445 }
446
447 Some(match selection_mode {
448 SelectionMode::Last => "no session found".to_string(),
449 SelectionMode::Id => "session not found".to_string(),
450 })
451}
452
#[cfg(test)]
mod tests {
    use std::process::Stdio;
    use std::time::{Duration, Instant};

    use super::{wait_for_child_exit, ChildExit};

    /// Spawns `sh -c <script>` with output discarded; the child is killed if the test drops it.
    #[cfg(unix)]
    fn spawn_sh(script: &str) -> tokio::process::Child {
        tokio::process::Command::new("sh")
            .args(["-c", script])
            .stdout(Stdio::null())
            .stderr(Stdio::null())
            .kill_on_drop(true)
            .spawn()
            .expect("spawn child")
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn wait_for_child_exit_returns_status_when_deadline_has_elapsed() {
        let mut child = spawn_sh("exit 0");
        // Wait for the exit deterministically instead of sleeping for a fixed time (which is
        // flaky on loaded machines). Tokio caches the exit status, so the helper's `try_wait`
        // still observes the exited child.
        child.wait().await.expect("child exits");

        let outcome = wait_for_child_exit(
            &mut child,
            Some(Duration::from_millis(1)),
            Some(Instant::now()),
        )
        .await
        .expect("wait helper succeeds");

        match outcome {
            ChildExit::Exited(status) => assert!(status.success()),
            ChildExit::TimedOut => panic!("expected exited status, got timeout"),
        }
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn wait_for_child_exit_kills_running_child_when_deadline_has_elapsed() {
        let mut child = spawn_sh("sleep 30");

        let outcome = wait_for_child_exit(
            &mut child,
            Some(Duration::from_millis(1)),
            Some(Instant::now()),
        )
        .await
        .expect("wait helper succeeds");

        assert!(
            matches!(outcome, ChildExit::TimedOut),
            "expected timeout, got {outcome:?}"
        );
        // The helper reaps the killed child, so it is reported as finished afterwards.
        assert!(child.try_wait().expect("try_wait").is_some());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn wait_for_child_exit_without_deadline_waits_for_exit() {
        let mut child = spawn_sh("exit 3");

        let outcome = wait_for_child_exit(&mut child, None, None)
            .await
            .expect("wait helper succeeds");

        match outcome {
            ChildExit::Exited(status) => assert_eq!(status.code(), Some(3)),
            ChildExit::TimedOut => panic!("expected exited status, got timeout"),
        }
    }
}