1use crate::capabilities::{
12 AgentCapabilities, AuthenticateParams, AuthenticateResult, ClientCapabilities, ClientInfo,
13 InitializeParams, InitializeResult, SUPPORTED_VERSIONS,
14};
15use crate::error::{AcpError, AcpResult};
16use crate::jsonrpc::{JSONRPC_VERSION, JsonRpcId, JsonRpcRequest, JsonRpcResponse};
17use crate::session::{
18 AcpSession, ServerRequestNotification, SessionCancelParams, SessionLoadParams,
19 SessionLoadResult, SessionNewParams, SessionNewResult, SessionPromptParams,
20 SessionPromptResult, SessionState, SessionUpdateNotification, ToolExecutionResult,
21};
22
23use hashbrown::HashMap;
24use reqwest::{Client as HttpClient, StatusCode};
25use serde::Serialize;
26use serde::de::DeserializeOwned;
27use serde_json::Value;
28use std::sync::atomic::{AtomicU64, Ordering};
29use std::time::Duration;
30use tokio::sync::{RwLock, mpsc};
31use tracing::{debug, trace, warn};
32
33pub struct AcpClientV2 {
45 http_client: HttpClient,
47
48 base_url: String,
50
51 #[allow(dead_code)]
53 client_id: String,
54
55 capabilities: ClientCapabilities,
57
58 #[allow(dead_code)]
60 timeout: Duration,
61
62 request_counter: AtomicU64,
64
65 protocol_version: RwLock<Option<String>>,
67
68 agent_capabilities: RwLock<Option<AgentCapabilities>>,
70
71 sessions: RwLock<HashMap<String, AcpSession>>,
73
74 auth_token: RwLock<Option<String>>,
76}
77
78pub struct AcpClientV2Builder {
80 base_url: String,
81 client_id: String,
82 capabilities: ClientCapabilities,
83 timeout: Duration,
84}
85
86impl AcpClientV2Builder {
87 pub fn new(base_url: impl Into<String>) -> Self {
89 Self {
90 base_url: base_url.into(),
91 client_id: format!("vtcode-{}", uuid::Uuid::new_v4()),
92 capabilities: ClientCapabilities::default(),
93 timeout: Duration::from_secs(30),
94 }
95 }
96
97 pub fn with_client_id(mut self, id: impl Into<String>) -> Self {
99 self.client_id = id.into();
100 self
101 }
102
103 pub fn with_capabilities(mut self, caps: ClientCapabilities) -> Self {
105 self.capabilities = caps;
106 self
107 }
108
109 pub fn with_timeout(mut self, timeout: Duration) -> Self {
111 self.timeout = timeout;
112 self
113 }
114
115 pub fn build(self) -> AcpResult<AcpClientV2> {
117 let http_client = HttpClient::builder()
118 .timeout(self.timeout)
119 .build()
120 .map_err(|e| AcpError::ConfigError(format!("Failed to create HTTP client: {}", e)))?;
121
122 Ok(AcpClientV2 {
123 http_client,
124 base_url: self.base_url,
125 client_id: self.client_id,
126 capabilities: self.capabilities,
127 timeout: self.timeout,
128 request_counter: AtomicU64::new(1),
129 protocol_version: RwLock::new(None),
130 agent_capabilities: RwLock::new(None),
131 sessions: RwLock::new(HashMap::new()),
132 auth_token: RwLock::new(None),
133 })
134 }
135}
136
137impl AcpClientV2 {
138 pub fn new(base_url: impl Into<String>) -> AcpResult<Self> {
140 AcpClientV2Builder::new(base_url).build()
141 }
142
143 fn next_request_id(&self) -> JsonRpcId {
145 let id = self.request_counter.fetch_add(1, Ordering::SeqCst);
146 JsonRpcId::Number(id as i64)
147 }
148
149 pub async fn is_initialized(&self) -> bool {
151 self.protocol_version.read().await.is_some()
152 }
153
154 pub async fn protocol_version(&self) -> Option<String> {
156 self.protocol_version.read().await.clone()
157 }
158
159 pub async fn agent_capabilities(&self) -> Option<AgentCapabilities> {
161 self.agent_capabilities.read().await.clone()
162 }
163
164 async fn call<P: Serialize, R: DeserializeOwned>(
170 &self,
171 method: &str,
172 params: Option<P>,
173 ) -> AcpResult<R> {
174 let id = self.next_request_id();
175 let params_value = params
176 .map(|p| serde_json::to_value(p))
177 .transpose()
178 .map_err(|e| AcpError::SerializationError(e.to_string()))?;
179
180 let request = JsonRpcRequest {
181 jsonrpc: JSONRPC_VERSION.to_string(),
182 method: method.to_string(),
183 params: params_value,
184 id: Some(id.clone()),
185 };
186
187 debug!(method = method, id = %id, "Sending JSON-RPC request");
188
189 let url = format!("{}/rpc", self.base_url.trim_end_matches('/'));
190
191 let mut req_builder = self.http_client.post(&url).json(&request);
192
193 if let Some(token) = self.auth_token.read().await.as_ref() {
195 req_builder = req_builder.bearer_auth(token);
196 }
197
198 let response = req_builder
199 .send()
200 .await
201 .map_err(|e| AcpError::NetworkError(format!("Request failed: {}", e)))?;
202
203 let status = response.status();
204
205 match status {
206 StatusCode::OK => {
207 let body = response
208 .text()
209 .await
210 .map_err(|e| AcpError::NetworkError(e.to_string()))?;
211
212 trace!(body_len = body.len(), "Received JSON-RPC response");
213
214 let rpc_response: JsonRpcResponse = serde_json::from_str(&body).map_err(|e| {
215 AcpError::SerializationError(format!("Invalid response: {}", e))
216 })?;
217
218 if let Some(error) = rpc_response.error {
219 return Err(AcpError::RemoteError {
220 agent_id: self.base_url.clone(),
221 message: error.message,
222 code: Some(error.code),
223 });
224 }
225
226 let result = rpc_response.result.unwrap_or(Value::Null);
227 serde_json::from_value(result)
228 .map_err(|e| AcpError::SerializationError(format!("Result parse error: {}", e)))
229 }
230
231 StatusCode::UNAUTHORIZED => Err(AcpError::RemoteError {
232 agent_id: self.base_url.clone(),
233 message: "Authentication required".to_string(),
234 code: Some(401),
235 }),
236
237 StatusCode::REQUEST_TIMEOUT => Err(AcpError::Timeout("Request timed out".to_string())),
238
239 _ => {
240 let body = response.text().await.unwrap_or_default();
241 Err(AcpError::RemoteError {
242 agent_id: self.base_url.clone(),
243 message: format!("HTTP {}: {}", status.as_u16(), body),
244 code: Some(status.as_u16() as i32),
245 })
246 }
247 }
248 }
249
250 async fn notify<P: Serialize>(&self, method: &str, params: Option<P>) -> AcpResult<()> {
252 let params_value = params
253 .map(|p| serde_json::to_value(p))
254 .transpose()
255 .map_err(|e| AcpError::SerializationError(e.to_string()))?;
256
257 let request = JsonRpcRequest::notification(method, params_value);
258
259 debug!(method = method, "Sending JSON-RPC notification");
260
261 let url = format!("{}/rpc", self.base_url.trim_end_matches('/'));
262
263 let mut req_builder = self.http_client.post(&url).json(&request);
264
265 if let Some(token) = self.auth_token.read().await.as_ref() {
266 req_builder = req_builder.bearer_auth(token);
267 }
268
269 if let Err(e) = req_builder.send().await {
271 warn!(method = method, error = %e, "ACP notification send failed");
272 }
273
274 Ok(())
275 }
276
277 pub async fn initialize(&self) -> AcpResult<InitializeResult> {
288 let params = InitializeParams {
289 protocol_versions: SUPPORTED_VERSIONS.iter().map(|s| s.to_string()).collect(),
290 capabilities: self.capabilities.clone(),
291 client_info: ClientInfo::default(),
292 };
293
294 let result: InitializeResult = self.call("initialize", Some(params)).await?;
295
296 if !SUPPORTED_VERSIONS.contains(&result.protocol_version.as_str()) {
298 return Err(AcpError::InvalidRequest(format!(
299 "Agent negotiated unsupported protocol version: {}",
300 result.protocol_version
301 )));
302 }
303
304 *self.protocol_version.write().await = Some(result.protocol_version.clone());
306 *self.agent_capabilities.write().await = Some(result.capabilities.clone());
307
308 debug!(
309 protocol = %result.protocol_version,
310 agent = %result.agent_info.name,
311 "ACP connection initialized"
312 );
313
314 Ok(result)
315 }
316
317 pub async fn authenticate(&self, params: AuthenticateParams) -> AcpResult<AuthenticateResult> {
319 let result: AuthenticateResult = self.call("authenticate", Some(params)).await?;
320
321 if result.authenticated {
322 if let Some(token) = &result.session_token {
323 *self.auth_token.write().await = Some(token.clone());
324 }
325 debug!("Authentication successful");
326 } else {
327 warn!("Authentication failed");
328 }
329
330 Ok(result)
331 }
332
333 pub async fn session_new(&self, params: SessionNewParams) -> AcpResult<SessionNewResult> {
335 if !self.is_initialized().await {
336 return Err(AcpError::InvalidRequest(
337 "Client not initialized. Call initialize() first.".to_string(),
338 ));
339 }
340
341 let result: SessionNewResult = self.call("session/new", Some(params)).await?;
342
343 let session = AcpSession::new(&result.session_id);
345 self.sessions
346 .write()
347 .await
348 .insert(result.session_id.clone(), session);
349
350 debug!(session_id = %result.session_id, "Session created");
351
352 Ok(result)
353 }
354
355 pub async fn session_load(&self, session_id: &str) -> AcpResult<SessionLoadResult> {
357 if !self.is_initialized().await {
358 return Err(AcpError::InvalidRequest(
359 "Client not initialized. Call initialize() first.".to_string(),
360 ));
361 }
362
363 let params = SessionLoadParams {
364 session_id: session_id.to_string(),
365 };
366
367 let result: SessionLoadResult = self.call("session/load", Some(params)).await?;
368
369 self.sessions
371 .write()
372 .await
373 .insert(session_id.to_string(), result.session.clone());
374
375 debug!(
376 session_id = session_id,
377 turns = result.history.len(),
378 "Session loaded"
379 );
380
381 Ok(result)
382 }
383
384 pub async fn session_prompt(
389 &self,
390 params: SessionPromptParams,
391 ) -> AcpResult<SessionPromptResult> {
392 self.session_prompt_with_timeout(params, None).await
393 }
394
395 pub async fn session_prompt_with_timeout(
397 &self,
398 params: SessionPromptParams,
399 timeout: Option<Duration>,
400 ) -> AcpResult<SessionPromptResult> {
401 if !self.is_initialized().await {
402 return Err(AcpError::InvalidRequest(
403 "Client not initialized. Call initialize() first.".to_string(),
404 ));
405 }
406
407 let session_id = params.session_id.clone();
408
409 if let Some(session) = self.sessions.write().await.get_mut(&session_id) {
411 session.set_state(SessionState::Active);
412 session.increment_turn();
413 }
414
415 let result: SessionPromptResult = if let Some(custom_timeout) = timeout {
417 tokio::time::timeout(
418 custom_timeout,
419 self.call::<_, SessionPromptResult>("session/prompt", Some(params)),
420 )
421 .await
422 .map_err(|_| AcpError::Timeout("Prompt request timed out".to_string()))??
423 } else {
424 self.call("session/prompt", Some(params)).await?
425 };
426
427 if let Some(session) = self.sessions.write().await.get_mut(&session_id) {
429 match result.status {
430 crate::session::TurnStatus::Completed => {
431 session.set_state(SessionState::AwaitingInput);
432 }
433 crate::session::TurnStatus::Cancelled => {
434 session.set_state(SessionState::Cancelled);
435 }
436 crate::session::TurnStatus::Failed => {
437 session.set_state(SessionState::Failed);
438 }
439 crate::session::TurnStatus::AwaitingInput => {
440 session.set_state(SessionState::AwaitingInput);
441 }
442 }
443 }
444
445 debug!(
446 session_id = %session_id,
447 turn_id = %result.turn_id,
448 status = ?result.status,
449 "Prompt completed"
450 );
451
452 Ok(result)
453 }
454
455 pub async fn session_cancel(&self, session_id: &str, turn_id: Option<&str>) -> AcpResult<()> {
457 let params = SessionCancelParams {
458 session_id: session_id.to_string(),
459 turn_id: turn_id.map(String::from),
460 };
461
462 self.notify("session/cancel", Some(params)).await?;
463
464 if let Some(session) = self.sessions.write().await.get_mut(session_id) {
466 session.set_state(SessionState::Cancelled);
467 }
468
469 debug!(session_id = session_id, "Session cancelled");
470
471 Ok(())
472 }
473
474 pub async fn get_session(&self, session_id: &str) -> Option<AcpSession> {
476 self.sessions.read().await.get(session_id).cloned()
477 }
478
479 pub async fn list_sessions(&self) -> Vec<AcpSession> {
481 self.sessions.read().await.values().cloned().collect()
482 }
483
484 pub async fn session_tool_response(
493 &self,
494 session_id: &str,
495 result: ToolExecutionResult,
496 ) -> AcpResult<()> {
497 self.notify(
498 "client/response",
499 Some(serde_json::json!({
500 "session_id": session_id,
501 "result": result,
502 })),
503 )
504 .await
505 }
506
507 pub async fn subscribe_updates(
516 &self,
517 session_id: &str,
518 ) -> AcpResult<mpsc::Receiver<SessionUpdateNotification>> {
519 let (tx, rx) = mpsc::channel(100);
520
521 let url = format!(
522 "{}/sse/session/{}",
523 self.base_url.trim_end_matches('/'),
524 session_id
525 );
526
527 let _http_client = self.http_client.clone();
528 let auth_token = self.auth_token.read().await.clone();
529
530 tokio::spawn(async move {
532 if let Err(e) = Self::sse_listener(url, auth_token, tx).await {
533 warn!("SSE listener error: {}", e);
534 }
535 });
536
537 Ok(rx)
538 }
539
540 async fn sse_listener(
542 url: String,
543 auth_token: Option<String>,
544 tx: mpsc::Sender<SessionUpdateNotification>,
545 ) -> AcpResult<()> {
546 let client = HttpClient::new();
547
548 let mut req = client.get(&url);
549 if let Some(token) = auth_token {
550 req = req.bearer_auth(token);
551 }
552
553 let response = req
554 .header("Accept", "text/event-stream")
555 .send()
556 .await
557 .map_err(|e| AcpError::NetworkError(e.to_string()))?;
558
559 if !response.status().is_success() {
560 return Err(AcpError::NetworkError(format!(
561 "SSE connection failed: {}",
562 response.status()
563 )));
564 }
565
566 let mut stream = response.bytes_stream();
567 use futures_util::StreamExt;
568
569 let mut buffer = String::new();
570
571 while let Some(chunk) = stream.next().await {
572 let chunk = chunk.map_err(|e| AcpError::NetworkError(e.to_string()))?;
573 buffer.push_str(&String::from_utf8_lossy(&chunk));
574
575 while let Some(event_end) = buffer.find("\n\n") {
577 let event = buffer[..event_end].to_string();
578 buffer = buffer[event_end + 2..].to_string();
579
580 let mut event_type = None;
582 let mut data_lines = Vec::new();
583
584 for line in event.lines() {
585 if let Some(data) = line.strip_prefix("data:") {
586 data_lines.push(data.trim());
587 } else if let Some(evt) = line.strip_prefix("event:") {
588 event_type = Some(evt.trim());
589 }
590 }
592
593 if (event_type.is_none() || event_type == Some("session/update"))
595 && !data_lines.is_empty()
596 {
597 let data = data_lines.join("\n");
598 if let Ok(notification) =
599 serde_json::from_str::<SessionUpdateNotification>(&data)
600 && tx.send(notification).await.is_err()
601 {
602 return Ok(());
604 }
605 }
606
607 if event_type == Some("server/request") && !data_lines.is_empty() {
610 let data = data_lines.join("\n");
611 match serde_json::from_str::<ServerRequestNotification>(&data) {
612 Ok(server_req) => {
613 let notification = SessionUpdateNotification {
614 session_id: server_req.session_id.clone(),
615 turn_id: String::new(),
616 update: crate::session::SessionUpdate::ServerRequest {
617 request: server_req.request,
618 },
619 };
620 if tx.send(notification).await.is_err() {
621 return Ok(());
622 }
623 }
624 Err(e) => {
625 warn!("Failed to parse server/request SSE event: {e}");
626 }
627 }
628 }
629 }
630 }
631
632 Ok(())
633 }
634}
635
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_builder() {
        let client = AcpClientV2Builder::new("http://localhost:8080")
            .with_client_id("test-client")
            .with_timeout(Duration::from_secs(60))
            .build()
            .unwrap();

        assert_eq!(client.base_url, "http://localhost:8080");
        assert_eq!(client.client_id, "test-client");
        assert_eq!(client.timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_builder_defaults() {
        let client = AcpClientV2Builder::new("http://localhost:8080")
            .build()
            .unwrap();

        assert!(client.client_id.starts_with("vtcode-"));
        assert_eq!(client.timeout, Duration::from_secs(30));
    }

    #[tokio::test]
    async fn test_client_not_initialized() {
        let client = AcpClientV2::new("http://localhost:8080").unwrap();

        assert!(!client.is_initialized().await);
        assert!(client.protocol_version().await.is_none());
        assert!(client.agent_capabilities().await.is_none());
        assert!(client.list_sessions().await.is_empty());
    }

    #[tokio::test]
    async fn test_session_load_requires_initialize() {
        // The init guard runs before any network I/O.
        let client = AcpClientV2::new("http://localhost:8080").unwrap();

        assert!(matches!(
            client.session_load("sess-1").await,
            Err(AcpError::InvalidRequest(_))
        ));
    }

    #[test]
    fn test_request_id_generation() {
        let client = AcpClientV2::new("http://localhost:8080").unwrap();

        let id1 = client.next_request_id();
        let id2 = client.next_request_id();

        assert_ne!(id1, id2);
        // The counter starts at 1 and advances by one per request.
        assert_eq!(id1, JsonRpcId::Number(1));
        assert_eq!(id2, JsonRpcId::Number(2));
    }

    #[tokio::test]
    async fn test_session_tool_response_sends_notification() {
        let client = AcpClientV2::new("http://localhost:9999").unwrap();

        let result = ToolExecutionResult {
            request_id: "req-1".to_string(),
            tool_call_id: "tc-1".to_string(),
            output: serde_json::json!({"result": "ok"}),
            success: true,
            error: None,
        };

        // Notifications are best-effort: an unreachable agent is logged, not surfaced.
        assert!(client.session_tool_response("sess-1", result).await.is_ok());
    }
}