1use std::{
2 collections::BTreeMap,
3 path::PathBuf,
4 pin::Pin,
5 process::Stdio,
6 task::{Context, Poll},
7 time::{Duration, Instant},
8};
9
10use futures_core::Stream;
11use tokio::{
12 io::{AsyncBufReadExt, AsyncReadExt, BufReader},
13 sync::{mpsc, oneshot},
14};
15
16use crate::{
17 DynOpencodeRunJsonCompletion, DynOpencodeRunJsonEventStream, OpencodeError,
18 OpencodeRunCompletion, OpencodeRunJsonControlHandle, OpencodeRunJsonEvent,
19 OpencodeRunJsonHandle, OpencodeRunJsonParseError, OpencodeRunJsonParser, OpencodeRunRequest,
20 OpencodeTerminationHandle,
21};
22
/// Maximum number of stderr bytes kept for failure classification; any further stderr output
/// is still read from the pipe but discarded.
const STDERR_CAPTURE_MAX_BYTES: usize = 4 * 1024;
/// Message used both for the `TerminalError` event and for `OpencodeError::RunFailed` when a
/// run exits unsuccessfully for an unclassified reason.
const RUN_FAILED_MESSAGE: &'static str = "opencode run failed";
25
/// How a run request picks an existing session; used to word session-selection failures.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SelectionMode {
    /// The request asked to continue the most recent session.
    Last,
    /// The request named a specific session id.
    Id,
}
31
/// Client that spawns the configured `opencode` binary for JSON runs.
///
/// Each spawned process inherits the current environment, with the entries of `env` applied
/// on top; `timeout`, when set, bounds each whole run.
#[derive(Debug, Clone)]
pub struct OpencodeClient {
    /// Executable spawned for every run.
    pub(crate) binary: PathBuf,
    /// Extra environment variables set on every spawned process.
    pub(crate) env: BTreeMap<String, String>,
    /// Optional limit on the duration of a single run.
    pub(crate) timeout: Option<Duration>,
}
38
39impl OpencodeClient {
40 pub fn builder() -> crate::OpencodeClientBuilder {
41 crate::OpencodeClientBuilder::default()
42 }
43
44 pub async fn run_json(
45 &self,
46 request: OpencodeRunRequest,
47 ) -> Result<OpencodeRunJsonHandle, OpencodeError> {
48 let (events, completion, _termination) = self.spawn_run_json(request).await?;
49 Ok(OpencodeRunJsonHandle { events, completion })
50 }
51
52 pub async fn run_json_control(
53 &self,
54 request: OpencodeRunRequest,
55 ) -> Result<OpencodeRunJsonControlHandle, OpencodeError> {
56 let (events, completion, termination) = self.spawn_run_json(request).await?;
57 Ok(OpencodeRunJsonControlHandle {
58 events,
59 completion,
60 termination,
61 })
62 }
63
64 async fn spawn_run_json(
65 &self,
66 request: OpencodeRunRequest,
67 ) -> Result<
68 (
69 DynOpencodeRunJsonEventStream,
70 DynOpencodeRunJsonCompletion,
71 OpencodeTerminationHandle,
72 ),
73 OpencodeError,
74 > {
75 let selection_mode = selection_mode(&request);
76 let argv = request.argv()?;
77 let mut command = tokio::process::Command::new(&self.binary);
78 command
79 .args(argv)
80 .stdin(Stdio::null())
81 .stdout(Stdio::piped())
82 .stderr(Stdio::piped());
83
84 for (key, value) in &self.env {
85 command.env(key, value);
86 }
87
88 let mut child = command.spawn().map_err(|source| {
89 if source.kind() == std::io::ErrorKind::NotFound {
90 OpencodeError::MissingBinary
91 } else {
92 OpencodeError::Spawn {
93 binary: self.binary.clone(),
94 source,
95 }
96 }
97 })?;
98
99 let stdout = child.stdout.take().ok_or(OpencodeError::MissingStdout)?;
100 let stderr_capture = child
101 .stderr
102 .take()
103 .map(|stderr| tokio::spawn(async move { capture_stderr(stderr).await }));
104 let timeout = self.timeout;
105 let termination = OpencodeTerminationHandle::new();
106 let termination_for_runner = termination.clone();
107
108 let (events_tx, events_rx) = mpsc::channel(32);
109 let (completion_tx, completion_rx) = oneshot::channel();
110
111 tokio::spawn(async move {
112 let result = run_opencode_child(
113 child,
114 stdout,
115 stderr_capture,
116 events_tx,
117 timeout,
118 termination_for_runner,
119 selection_mode,
120 )
121 .await;
122 let _ = completion_tx.send(result);
123 });
124
125 let events: DynOpencodeRunJsonEventStream =
126 Box::pin(OpencodeRunJsonEventChannelStream::new(events_rx));
127
128 let completion: DynOpencodeRunJsonCompletion = Box::pin(async move {
129 completion_rx
130 .await
131 .map_err(|_| OpencodeError::Join("run-json task dropped".to_string()))?
132 });
133
134 Ok((events, completion, termination))
135 }
136}
137
138struct OpencodeRunJsonEventChannelStream {
139 rx: mpsc::Receiver<Result<OpencodeRunJsonEvent, OpencodeRunJsonParseError>>,
140}
141
142impl OpencodeRunJsonEventChannelStream {
143 fn new(rx: mpsc::Receiver<Result<OpencodeRunJsonEvent, OpencodeRunJsonParseError>>) -> Self {
144 Self { rx }
145 }
146}
147
148impl Stream for OpencodeRunJsonEventChannelStream {
149 type Item = Result<OpencodeRunJsonEvent, OpencodeRunJsonParseError>;
150
151 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
152 self.get_mut().rx.poll_recv(cx)
153 }
154}
155
156async fn run_opencode_child(
157 mut child: tokio::process::Child,
158 stdout: tokio::process::ChildStdout,
159 stderr_capture: Option<tokio::task::JoinHandle<Result<Vec<u8>, std::io::Error>>>,
160 events_tx: mpsc::Sender<Result<OpencodeRunJsonEvent, OpencodeRunJsonParseError>>,
161 timeout: Option<Duration>,
162 termination: OpencodeTerminationHandle,
163 selection_mode: Option<SelectionMode>,
164) -> Result<OpencodeRunCompletion, OpencodeError> {
165 let mut reader = BufReader::new(stdout);
166 let mut parser = OpencodeRunJsonParser::new();
167 let mut line = String::new();
168 let mut events_open = true;
169 let mut final_text = String::new();
170 let mut saw_finish = false;
171 let mut termination_requested = false;
172 let deadline = timeout.map(|value| Instant::now() + value);
173
174 loop {
175 if let Some(deadline) = deadline {
176 if Instant::now() >= deadline {
177 match wait_for_child_exit(&mut child, timeout, Some(deadline)).await {
178 Ok(_) => {
179 let _ = consume_stderr_capture(stderr_capture).await;
180 return Err(OpencodeError::Timeout {
181 timeout: timeout.expect("deadline implies timeout"),
182 });
183 }
184 Err(err) => return Err(err),
185 }
186 }
187 }
188
189 line.clear();
190 let read_result = if let Some(deadline) = deadline {
191 let remaining = deadline.saturating_duration_since(Instant::now());
192 tokio::select! {
193 _ = termination.requested() => {
194 termination_requested = true;
195 let _ = child.start_kill();
196 break;
197 }
198 read = tokio::time::timeout(remaining, reader.read_line(&mut line)) => {
199 match read {
200 Ok(result) => result,
201 Err(_) => {
202 match wait_for_child_exit(&mut child, timeout, Some(deadline)).await {
203 Ok(_) => {
204 let _ = consume_stderr_capture(stderr_capture).await;
205 return Err(OpencodeError::Timeout {
206 timeout: timeout.expect("deadline implies timeout"),
207 });
208 }
209 Err(err) => return Err(err),
210 }
211 }
212 }
213 }
214 }
215 } else {
216 tokio::select! {
217 _ = termination.requested() => {
218 termination_requested = true;
219 let _ = child.start_kill();
220 break;
221 }
222 read = reader.read_line(&mut line) => read,
223 }
224 };
225
226 let bytes = match read_result {
227 Ok(bytes) => bytes,
228 Err(err) => {
229 let _ = child.start_kill();
230 let _ = child.wait().await;
231 let _ = consume_stderr_capture(stderr_capture).await;
232 return Err(OpencodeError::StdoutRead(err));
233 }
234 };
235
236 if bytes == 0 {
237 break;
238 }
239
240 let parsed = parser.parse_line(line.trim_end_matches('\n'));
241 match parsed {
242 Ok(Some(event)) => {
243 if let OpencodeRunJsonEvent::Text { text, .. } = &event {
244 final_text.push_str(text);
245 } else if matches!(event, OpencodeRunJsonEvent::StepFinish { .. }) {
246 saw_finish = true;
247 }
248
249 if events_open && events_tx.send(Ok(event)).await.is_err() {
250 events_open = false;
251 }
252 }
253 Ok(None) => {}
254 Err(error) => {
255 if events_open && events_tx.send(Err(error)).await.is_err() {
256 events_open = false;
257 }
258 }
259 }
260 }
261
262 let status = match wait_for_child_exit(&mut child, timeout, deadline).await {
263 Ok(status) => status,
264 Err(err @ OpencodeError::Timeout { .. }) => {
265 let _ = consume_stderr_capture(stderr_capture).await;
266 return Err(err);
267 }
268 Err(err) => return Err(err),
269 };
270 let stderr = consume_stderr_capture(stderr_capture).await?;
271 if !status.success() {
272 if termination_requested {
273 drop(events_tx);
274 return Ok(OpencodeRunCompletion {
275 status,
276 final_text: None,
277 });
278 }
279 if let Some(message) = classify_selection_failure(&stderr, selection_mode) {
280 if events_open {
281 let _ = events_tx
282 .send(Ok(OpencodeRunJsonEvent::TerminalError {
283 message: message.clone(),
284 raw: serde_json::Value::Null,
285 }))
286 .await;
287 }
288 drop(events_tx);
289 return Err(OpencodeError::SelectionFailed { message });
290 }
291 if events_open {
292 let _ = events_tx
293 .send(Ok(OpencodeRunJsonEvent::TerminalError {
294 message: RUN_FAILED_MESSAGE.to_string(),
295 raw: serde_json::Value::Null,
296 }))
297 .await;
298 }
299 drop(events_tx);
300 return Err(OpencodeError::RunFailed {
301 status,
302 message: RUN_FAILED_MESSAGE.to_string(),
303 });
304 }
305 drop(events_tx);
306
307 let final_text = (saw_finish && !final_text.is_empty()).then_some(final_text);
308
309 Ok(OpencodeRunCompletion { status, final_text })
310}
311
312async fn wait_for_child_exit(
313 child: &mut tokio::process::Child,
314 timeout: Option<Duration>,
315 deadline: Option<Instant>,
316) -> Result<std::process::ExitStatus, OpencodeError> {
317 match deadline {
318 None => child.wait().await.map_err(OpencodeError::Wait),
319 Some(deadline) => {
320 let remaining = deadline.saturating_duration_since(Instant::now());
321 if remaining.is_zero() {
322 let timeout = timeout.expect("deadline implies timeout");
323 let _ = child.start_kill();
324 match child.wait().await {
325 Ok(_status) => Err(OpencodeError::Timeout { timeout }),
326 Err(err) => Err(OpencodeError::Wait(err)),
327 }
328 } else {
329 match tokio::time::timeout(remaining, child.wait()).await {
330 Ok(result) => result.map_err(OpencodeError::Wait),
331 Err(_) => {
332 let timeout = timeout.expect("deadline implies timeout");
333 let _ = child.start_kill();
334 match child.wait().await {
335 Ok(_status) => Err(OpencodeError::Timeout { timeout }),
336 Err(err) => Err(OpencodeError::Wait(err)),
337 }
338 }
339 }
340 }
341 }
342 }
343}
344
345fn selection_mode(request: &OpencodeRunRequest) -> Option<SelectionMode> {
346 if request.session_id().is_some() {
347 Some(SelectionMode::Id)
348 } else if request.continue_requested() {
349 Some(SelectionMode::Last)
350 } else {
351 None
352 }
353}
354
355async fn capture_stderr(
356 mut stderr: tokio::process::ChildStderr,
357) -> Result<Vec<u8>, std::io::Error> {
358 let mut captured = Vec::new();
359 let mut buffer = [0u8; 1024];
360
361 loop {
362 let read = stderr.read(&mut buffer).await?;
363 if read == 0 {
364 break;
365 }
366
367 if captured.len() < STDERR_CAPTURE_MAX_BYTES {
368 let remaining = STDERR_CAPTURE_MAX_BYTES - captured.len();
369 captured.extend_from_slice(&buffer[..read.min(remaining)]);
370 }
371 }
372
373 Ok(captured)
374}
375
376async fn consume_stderr_capture(
377 stderr_capture: Option<tokio::task::JoinHandle<Result<Vec<u8>, std::io::Error>>>,
378) -> Result<String, OpencodeError> {
379 let Some(stderr_capture) = stderr_capture else {
380 return Ok(String::new());
381 };
382
383 let captured = stderr_capture
384 .await
385 .map_err(|err| OpencodeError::Join(format!("stderr capture task failed: {err}")))?
386 .map_err(OpencodeError::StderrRead)?;
387
388 Ok(String::from_utf8_lossy(&captured).into_owned())
389}
390
391fn classify_selection_failure(
392 stderr: &str,
393 selection_mode: Option<SelectionMode>,
394) -> Option<String> {
395 let selection_mode = selection_mode?;
396 let stderr = stderr.to_ascii_lowercase();
397
398 let saw_not_found = (stderr.contains("not found")
399 && (stderr.contains("session")
400 || stderr.contains("thread")
401 || stderr.contains("conversation")))
402 || stderr.contains("no session")
403 || stderr.contains("no sessions")
404 || stderr.contains("unknown session")
405 || stderr.contains("no thread")
406 || stderr.contains("no threads")
407 || stderr.contains("unknown thread")
408 || stderr.contains("no conversation")
409 || stderr.contains("unknown conversation");
410
411 if !saw_not_found {
412 return None;
413 }
414
415 Some(match selection_mode {
416 SelectionMode::Last => "no session found".to_string(),
417 SelectionMode::Id => "session not found".to_string(),
418 })
419}