1use std::{collections::HashSet, io, path::PathBuf, sync::Arc, time::Duration};
2
3use serde_json::{json, Value};
4use thiserror::Error;
5
6use super::{
7 AppCallHandle, ApprovalDecision, ClientInfo, CodexCallHandle, CodexCallParams, CodexCallResult,
8 CodexReplyParams, InitializeParams, RequestId, StdioServerConfig, METHOD_CODEX,
9 METHOD_CODEX_APPROVAL, METHOD_THREAD_FORK, METHOD_THREAD_LIST, METHOD_THREAD_RESUME,
10 METHOD_THREAD_START, METHOD_TURN_INTERRUPT, METHOD_TURN_START,
11};
12
13use super::jsonrpc::{map_response, JsonRpcTransport};
14
/// Error type returned by the client methods in this module
/// (`CodexMcpServer` and `CodexAppServer`).
#[derive(Debug, Error)]
pub enum McpError {
    /// Spawning `command` failed with the underlying I/O error `source`.
    #[error("failed to spawn `{command}`: {source}")]
    Spawn {
        command: String,
        #[source]
        source: io::Error,
    },
    /// The `initialize` handshake failed; holds the rendered message of the underlying error.
    /// Note: the server constructors use this for *any* `initialize` failure, not only for a
    /// missing response.
    #[error("server did not respond to initialize: {0}")]
    Handshake(String),
    /// A transport-level failure described as text (for example, a repeated `thread/list`
    /// pagination cursor).
    #[error("transport task failed: {0}")]
    Transport(String),
    /// The server replied with a JSON-RPC error object. `data` is kept for inspection but is
    /// not part of the `Display` message.
    #[error("server returned JSON-RPC error {code}: {message}")]
    Rpc {
        code: i64,
        message: String,
        data: Option<Value>,
    },
    /// The server reported an error, carried as text.
    #[error("server reported an error: {0}")]
    Server(String),
    /// The request was cancelled.
    #[error("request was cancelled")]
    Cancelled,
    /// An operation did not complete within the given duration.
    #[error("timed out after {0:?}")]
    Timeout(Duration),
    /// JSON (de)serialization failed. `#[from]` lets `?` on a `serde_json::Error`
    /// (e.g. from `serde_json::to_value`) convert into this variant.
    #[error("serialization failed: {0}")]
    Serialization(#[from] serde_json::Error),
    /// The channel that was to deliver a response was closed before a value arrived.
    #[error("transport channel closed unexpectedly")]
    ChannelClosed,
}
45
46pub struct CodexMcpServer {
48 transport: Arc<JsonRpcTransport>,
49}
50
51impl CodexMcpServer {
52 pub async fn start(config: StdioServerConfig, client: ClientInfo) -> Result<Self, McpError> {
54 Self::with_capabilities(config, client, Value::Object(Default::default())).await
55 }
56
57 pub async fn with_capabilities(
59 config: StdioServerConfig,
60 client: ClientInfo,
61 capabilities: Value,
62 ) -> Result<Self, McpError> {
63 let capabilities = match capabilities {
64 Value::Null => Value::Object(Default::default()),
65 other => other,
66 };
67 let transport = JsonRpcTransport::spawn_mcp(config).await?;
68 let params = InitializeParams {
69 client,
70 protocol_version: "2024-11-05".to_string(),
71 capabilities,
72 };
73
74 transport
75 .initialize(params, transport.startup_timeout())
76 .await
77 .map_err(|err| McpError::Handshake(err.to_string()))?;
78
79 Ok(Self {
80 transport: Arc::new(transport),
81 })
82 }
83
84 pub async fn codex(&self, params: CodexCallParams) -> Result<CodexCallHandle, McpError> {
86 self.invoke_tool_call("codex", serde_json::to_value(params)?)
87 .await
88 }
89
90 pub async fn codex_reply(&self, params: CodexReplyParams) -> Result<CodexCallHandle, McpError> {
92 self.invoke_tool_call("codex-reply", serde_json::to_value(params)?)
93 .await
94 }
95
96 pub async fn send_approval(&self, decision: ApprovalDecision) -> Result<(), McpError> {
98 let (_, rx) = self
99 .transport
100 .request(METHOD_CODEX_APPROVAL, serde_json::to_value(decision)?)
101 .await?;
102
103 match rx.await {
104 Ok(Ok(_)) => Ok(()),
105 Ok(Err(err)) => Err(err),
106 Err(_) => Err(McpError::ChannelClosed),
107 }
108 }
109
110 pub fn cancel(&self, request_id: RequestId) -> Result<(), McpError> {
112 self.transport.cancel(request_id)
113 }
114
115 pub async fn shutdown(&self) -> Result<(), McpError> {
117 self.transport.shutdown().await
118 }
119
120 async fn invoke_tool_call(
121 &self,
122 tool_name: &str,
123 arguments: Value,
124 ) -> Result<CodexCallHandle, McpError> {
125 let events = self.transport.register_codex_listener().await;
126 let request = json!({
127 "name": tool_name,
128 "arguments": arguments,
129 });
130 let (request_id, raw_response) = self.transport.request(METHOD_CODEX, request).await?;
131 let response = map_response::<CodexCallResult>(raw_response);
132
133 Ok(CodexCallHandle {
134 request_id,
135 events,
136 response,
137 })
138 }
139}
140
141pub struct CodexAppServer {
143 transport: Arc<JsonRpcTransport>,
144}
145
146impl CodexAppServer {
147 pub async fn start(config: StdioServerConfig, client: ClientInfo) -> Result<Self, McpError> {
149 Self::with_capabilities(config, client, Value::Object(Default::default())).await
150 }
151
152 pub async fn start_experimental(
154 config: StdioServerConfig,
155 client: ClientInfo,
156 ) -> Result<Self, McpError> {
157 Self::with_capabilities(config, client, json!({ "experimentalApi": true })).await
158 }
159
160 pub async fn with_capabilities(
162 config: StdioServerConfig,
163 client: ClientInfo,
164 capabilities: Value,
165 ) -> Result<Self, McpError> {
166 let capabilities = match capabilities {
167 Value::Null => Value::Object(Default::default()),
168 other => other,
169 };
170 let transport = JsonRpcTransport::spawn_app(config).await?;
171 let params = InitializeParams {
172 client,
173 protocol_version: "2024-11-05".to_string(),
174 capabilities,
175 };
176
177 transport
178 .initialize(params, transport.startup_timeout())
179 .await
180 .map_err(|err| McpError::Handshake(err.to_string()))?;
181
182 Ok(Self {
183 transport: Arc::new(transport),
184 })
185 }
186
187 pub async fn thread_start(
189 &self,
190 params: super::ThreadStartParams,
191 ) -> Result<AppCallHandle, McpError> {
192 self.invoke_app_call(METHOD_THREAD_START, serde_json::to_value(params)?)
193 .await
194 }
195
196 pub async fn thread_resume(
198 &self,
199 params: super::ThreadResumeParams,
200 ) -> Result<AppCallHandle, McpError> {
201 self.invoke_app_call(METHOD_THREAD_RESUME, serde_json::to_value(params)?)
202 .await
203 }
204
205 pub async fn thread_list(
207 &self,
208 params: super::ThreadListParams,
209 ) -> Result<super::ThreadListResponse, McpError> {
210 let (_, rx) = self
211 .transport
212 .request(METHOD_THREAD_LIST, serde_json::to_value(params)?)
213 .await?;
214 let mapped = map_response::<super::ThreadListResponse>(rx);
215 match mapped.await {
216 Ok(result) => result,
217 Err(_) => Err(McpError::ChannelClosed),
218 }
219 }
220
221 pub async fn thread_fork(
223 &self,
224 params: super::ThreadForkParams,
225 ) -> Result<super::ThreadForkResponse, McpError> {
226 let (_, rx) = self
227 .transport
228 .request(METHOD_THREAD_FORK, serde_json::to_value(params)?)
229 .await?;
230 let mapped = map_response::<super::ThreadForkResponse>(rx);
231 match mapped.await {
232 Ok(result) => result,
233 Err(_) => Err(McpError::ChannelClosed),
234 }
235 }
236
237 pub async fn turn_start(
239 &self,
240 params: super::TurnStartParams,
241 ) -> Result<AppCallHandle, McpError> {
242 self.invoke_app_call(METHOD_TURN_START, serde_json::to_value(params)?)
243 .await
244 }
245
246 pub async fn turn_start_v2(
248 &self,
249 params: super::TurnStartParamsV2,
250 ) -> Result<AppCallHandle, McpError> {
251 self.invoke_app_call(METHOD_TURN_START, serde_json::to_value(params)?)
252 .await
253 }
254
255 pub async fn select_last_thread_id(&self, cwd: PathBuf) -> Result<Option<String>, McpError> {
257 let mut cursor: Option<String> = None;
258 let mut seen_cursors: HashSet<String> = HashSet::new();
259 let mut best: Option<(i64, i64, String)> = None;
260
261 loop {
262 let page = self
263 .thread_list(super::ThreadListParams {
264 cwd: Some(cwd.clone()),
265 cursor: cursor.clone(),
266 limit: Some(100),
267 sort_key: Some(super::ThreadListSortKey::UpdatedAt),
268 archived: None,
269 model_providers: None,
270 source_kinds: None,
271 })
272 .await?;
273
274 for thread in page.data {
275 let candidate = (thread.updated_at, thread.created_at, thread.id);
276 let should_replace = match best.as_ref() {
277 None => true,
278 Some(current) => {
279 (candidate.0, candidate.1, &candidate.2)
280 > (current.0, current.1, ¤t.2)
281 }
282 };
283
284 if should_replace {
285 best = Some(candidate);
286 }
287 }
288
289 let Some(next_cursor) = page.next_cursor else {
290 break;
291 };
292
293 if !seen_cursors.insert(next_cursor.clone()) {
294 return Err(McpError::Transport(format!(
295 "thread/list pagination cursor repeated: {next_cursor}"
296 )));
297 }
298 cursor = Some(next_cursor);
299 }
300
301 Ok(best.map(|(_, _, id)| id))
302 }
303
304 pub async fn turn_interrupt(
306 &self,
307 params: super::TurnInterruptParams,
308 ) -> Result<AppCallHandle, McpError> {
309 self.invoke_app_call(METHOD_TURN_INTERRUPT, serde_json::to_value(params)?)
310 .await
311 }
312
313 pub fn cancel(&self, request_id: RequestId) -> Result<(), McpError> {
315 self.transport.cancel(request_id)
316 }
317
318 pub async fn shutdown(&self) -> Result<(), McpError> {
320 self.transport.shutdown().await
321 }
322
323 async fn invoke_app_call(
324 &self,
325 method: &str,
326 params: Value,
327 ) -> Result<AppCallHandle, McpError> {
328 let events = self.transport.register_app_listener().await;
329 let (request_id, raw_response) = self.transport.request(method, params).await?;
330 let response = map_response::<Value>(raw_response);
331
332 Ok(AppCallHandle {
333 request_id,
334 events,
335 response,
336 })
337 }
338}