Skip to main content

vtcode_core/a2a/
server.rs

1//! A2A HTTP Server using axum
2//!
3//! Provides HTTP endpoints for the A2A Protocol, enabling VT Code to operate as an A2A agent.
4//! The server exposes:
5//! - Agent discovery via `/.well-known/agent-card.json`
6//! - RPC endpoints at `/a2a` for message sending and task management
7//! - Streaming endpoint at `/a2a/stream` for real-time updates via Server-Sent Events
8
9use crate::a2a::WebhookNotifier;
10use crate::a2a::agent_card::AgentCard;
11use crate::a2a::errors::{A2aError, A2aErrorCode, A2aResult};
12use crate::a2a::rpc::{
13    JSONRPC_VERSION, JsonRpcError, JsonRpcRequest, JsonRpcResponse, ListTasksParams,
14    METHOD_MESSAGE_SEND, METHOD_MESSAGE_STREAM, METHOD_TASKS_CANCEL, METHOD_TASKS_GET,
15    METHOD_TASKS_LIST, METHOD_TASKS_PUSH_CONFIG_GET, METHOD_TASKS_PUSH_CONFIG_SET,
16    MessageSendParams, SendStreamingMessageResponse, StreamingEvent, TaskIdParams, TaskQueryParams,
17};
18use crate::a2a::task_manager::TaskManager;
19use crate::a2a::types::TaskState;
20use axum::{
21    Json, Router,
22    extract::State,
23    http::StatusCode,
24    response::{
25        IntoResponse, Response,
26        sse::{Event, Sse},
27    },
28    routing::post,
29};
30use serde_json::{Value, json};
31use std::convert::Infallible;
32use std::future::Future;
33use std::net::SocketAddr;
34use std::sync::Arc;
35use std::time::Duration;
36use tower_http::cors::CorsLayer;
37
38// ============================================================================
39// Server State
40// ============================================================================
41
42/// A2A Server State containing shared resources
43#[derive(Debug, Clone)]
44pub struct A2aServerState {
45    /// Task manager for handling task lifecycle
46    pub task_manager: Arc<TaskManager>,
47    /// Agent card for discovery
48    pub agent_card: Arc<AgentCard>,
49    /// Broadcast channel for streaming events
50    pub event_tx: Arc<tokio::sync::broadcast::Sender<StreamingEvent>>,
51    /// Webhook notifier for push notifications
52    pub webhook_notifier: Arc<WebhookNotifier>,
53}
54
55impl A2aServerState {
56    /// Create a new server state
57    pub fn new(task_manager: TaskManager, agent_card: AgentCard) -> Self {
58        let (event_tx, _) = tokio::sync::broadcast::channel(100);
59        Self {
60            task_manager: Arc::new(task_manager),
61            agent_card: Arc::new(agent_card),
62            event_tx: Arc::new(event_tx),
63            webhook_notifier: Arc::new(WebhookNotifier::new()),
64        }
65    }
66
67    /// Create a server state with default settings for VT Code
68    pub fn vtcode_default(base_url: impl Into<String>) -> Self {
69        Self::new(TaskManager::new(), AgentCard::vtcode_default(base_url))
70    }
71}
72
73// ============================================================================
74// Router Creation
75// ============================================================================
76
77/// Create the A2A HTTP router
78pub fn create_router(state: A2aServerState) -> Router {
79    Router::new()
80        .route(
81            "/.well-known/agent-card.json",
82            axum::routing::get(get_agent_card),
83        )
84        .route("/a2a", post(handle_rpc))
85        .route("/a2a/stream", post(handle_stream))
86        .with_state(state)
87        .layer(CorsLayer::permissive())
88}
89
90// ============================================================================
91// Handlers
92// ============================================================================
93
94/// Get agent card for discovery
95async fn get_agent_card(State(state): State<A2aServerState>) -> Json<AgentCard> {
96    Json(state.agent_card.as_ref().clone())
97}
98
99/// Handle JSON-RPC requests
100async fn handle_rpc(
101    State(state): State<A2aServerState>,
102    Json(request): Json<JsonRpcRequest>,
103) -> Result<Json<JsonRpcResponse>, A2aErrorResponse> {
104    // Validate request
105    if request.jsonrpc != JSONRPC_VERSION {
106        return Err(A2aErrorResponse::invalid_request(
107            "Invalid JSON-RPC version",
108            request.id,
109        ));
110    }
111
112    // Dispatch to method handler
113    let result = match request.method.as_str() {
114        METHOD_MESSAGE_SEND => {
115            handle_message_send(&state, request.params, request.id.clone()).await
116        }
117        METHOD_MESSAGE_STREAM => {
118            handle_message_stream(&state, request.params, request.id.clone()).await
119        }
120        METHOD_TASKS_GET => handle_tasks_get(&state, request.params, request.id.clone()).await,
121        METHOD_TASKS_LIST => handle_tasks_list(&state, request.params, request.id.clone()).await,
122        METHOD_TASKS_CANCEL => {
123            handle_tasks_cancel(&state, request.params, request.id.clone()).await
124        }
125        METHOD_TASKS_PUSH_CONFIG_SET => {
126            handle_push_config_set(&state, request.params, request.id.clone()).await
127        }
128        METHOD_TASKS_PUSH_CONFIG_GET => {
129            handle_push_config_get(&state, request.params, request.id.clone()).await
130        }
131        _ => {
132            return Err(A2aErrorResponse::method_not_found(
133                &request.method,
134                request.id,
135            ));
136        }
137    };
138
139    match result {
140        Ok(result_value) => Ok(Json(JsonRpcResponse::success(result_value, request.id))),
141        Err(err) => Err(A2aErrorResponse::from_error(err, request.id)),
142    }
143}
144
145/// Handle Server-Sent Events streaming
146async fn handle_stream(
147    State(state): State<A2aServerState>,
148    Json(request): Json<JsonRpcRequest>,
149) -> impl IntoResponse {
150    if request.jsonrpc != JSONRPC_VERSION {
151        return Err(A2aErrorResponse::invalid_request(
152            "Invalid JSON-RPC version",
153            request.id.clone(),
154        ));
155    }
156
157    if request.method != METHOD_MESSAGE_STREAM {
158        return Err(A2aErrorResponse::method_not_found(
159            &request.method,
160            request.id.clone(),
161        ));
162    }
163
164    // Parse params
165    let params: MessageSendParams = serde_json::from_value(request.params.unwrap_or_default())
166        .map_err(|_| {
167            A2aErrorResponse::invalid_request("Invalid message/stream params", request.id.clone())
168        })?;
169
170    // Create or get task
171    let task_id = if let Some(task_id) = params.task_id.clone() {
172        task_id
173    } else {
174        let task = state
175            .task_manager
176            .create_task(params.context_id.clone())
177            .await;
178        task.id.clone()
179    };
180
181    // Add initial message
182    state
183        .task_manager
184        .add_message(&task_id, params.message.clone())
185        .await
186        .map_err(|e| A2aErrorResponse::from_error(e, request.id.clone()))?;
187
188    // Subscribe to broadcast channel
189    let mut rx = state.event_tx.subscribe();
190    let task_id_clone = task_id.clone();
191    let context_id = params.context_id.clone();
192    let notifier = state.webhook_notifier.clone();
193    let task_manager = state.task_manager.clone();
194
195    // Create stream from broadcast receiver using async_stream
196    let stream = async_stream::stream! {
197        while let Ok(event) = rx.recv().await {
198            // Filter events for this task/context
199            let matches = match &event {
200                StreamingEvent::Message { context_id: ctx, .. } => {
201                    context_id.as_ref() == ctx.as_ref()
202                }
203                StreamingEvent::TaskStatus { task_id: tid, .. } => tid == &task_id_clone,
204                StreamingEvent::TaskArtifact { task_id: tid, .. } => tid == &task_id_clone,
205            };
206
207            if matches {
208                // Fire webhook asynchronously (best-effort)
209                let notifier = notifier.clone();
210                let task_manager = task_manager.clone();
211                let task_id_for_hook = task_id_clone.clone();
212                let event_for_hook = event.clone();
213                tokio::spawn(async move {
214                    if let Some(cfg) = task_manager.get_webhook_config(&task_id_for_hook).await {
215                        let _ = notifier.send_event(&cfg, event_for_hook).await;
216                    }
217                });
218
219                let is_final = event.is_final();
220                let json = serde_json::to_string(&SendStreamingMessageResponse { event })
221                    .unwrap_or_default();
222                yield Ok::<_, Infallible>(Event::default().data(json));
223
224                if is_final {
225                    break;
226                }
227            }
228        }
229    };
230
231    // Start background task to process and emit events
232    let state_clone = state.clone();
233    let task_id_clone = task_id.clone();
234    tokio::spawn(async move {
235        // Simulate agent processing
236        tokio::time::sleep(Duration::from_millis(100)).await;
237
238        // Update task to working
239        let _ = state_clone
240            .task_manager
241            .update_status(&task_id_clone, TaskState::Working, None)
242            .await;
243
244        // Send status update event
245        let status_event = StreamingEvent::TaskStatus {
246            task_id: task_id_clone.clone(),
247            context_id: params.context_id.clone(),
248            status: crate::a2a::types::TaskStatus::new(TaskState::Working),
249            kind: "status-update".to_string(),
250            r#final: false,
251        };
252        let _ = state_clone.event_tx.send(status_event.clone());
253
254        // Fire webhook if configured
255        let notifier = state_clone.webhook_notifier.clone();
256        let task_manager = state_clone.task_manager.clone();
257        let task_id_for_hook = task_id_clone.clone();
258        tokio::spawn(async move {
259            if let Some(cfg) = task_manager.get_webhook_config(&task_id_for_hook).await {
260                let _ = notifier.send_event(&cfg, status_event).await;
261            }
262        });
263
264        // Simulate generating a response message
265        tokio::time::sleep(Duration::from_millis(200)).await;
266        let response_msg = crate::a2a::types::Message::agent_text("Processing your request...");
267        let message_event = StreamingEvent::Message {
268            message: response_msg,
269            context_id: params.context_id.clone(),
270            kind: "streaming-response".to_string(),
271            r#final: false,
272        };
273        let _ = state_clone.event_tx.send(message_event.clone());
274
275        // Fire webhook if configured
276        let notifier = state_clone.webhook_notifier.clone();
277        let task_manager = state_clone.task_manager.clone();
278        let task_id_for_hook = task_id_clone.clone();
279        tokio::spawn(async move {
280            if let Some(cfg) = task_manager.get_webhook_config(&task_id_for_hook).await {
281                let _ = notifier.send_event(&cfg, message_event).await;
282            }
283        });
284
285        // Complete the task
286        tokio::time::sleep(Duration::from_millis(300)).await;
287        let _ = state_clone
288            .task_manager
289            .update_status(&task_id_clone, TaskState::Completed, None)
290            .await;
291
292        // Send final status event
293        let final_status_event = StreamingEvent::TaskStatus {
294            task_id: task_id_clone,
295            context_id: params.context_id,
296            status: crate::a2a::types::TaskStatus::new(TaskState::Completed),
297            kind: "status-update".to_string(),
298            r#final: true,
299        };
300        let _ = state_clone.event_tx.send(final_status_event.clone());
301
302        // Fire webhook if configured
303        let notifier = state_clone.webhook_notifier.clone();
304        let task_manager = state_clone.task_manager.clone();
305        let task_id_for_hook = final_status_event.task_id().unwrap_or_default().to_string();
306        tokio::spawn(async move {
307            if let Some(cfg) = task_manager.get_webhook_config(&task_id_for_hook).await {
308                let _ = notifier.send_event(&cfg, final_status_event).await;
309            }
310        });
311    });
312
313    Ok(Sse::new(Box::pin(stream)).keep_alive(
314        axum::response::sse::KeepAlive::new()
315            .interval(Duration::from_secs(15))
316            .text("keep-alive"),
317    ))
318}
319
320// ============================================================================
321// RPC Method Handlers
322// ============================================================================
323
324/// Handle message/send RPC method
325async fn handle_message_send(
326    state: &A2aServerState,
327    params: Option<Value>,
328    _id: Value,
329) -> A2aResult<Value> {
330    let params: MessageSendParams = serde_json::from_value(params.unwrap_or_default())
331        .map_err(|_| A2aError::rpc(A2aErrorCode::InvalidParams, "Invalid message/send params"))?;
332
333    // Create or get task
334    let task_id = if let Some(task_id) = params.task_id {
335        task_id
336    } else {
337        let task = state.task_manager.create_task(params.context_id).await;
338        task.id.clone()
339    };
340
341    // Add message to history
342    state
343        .task_manager
344        .add_message(&task_id, params.message)
345        .await?;
346
347    // Update status to working
348    let task = state
349        .task_manager
350        .update_status(&task_id, TaskState::Working, None)
351        .await?;
352
353    // Return task as response
354    Ok(serde_json::to_value(task)?)
355}
356
357/// Handle tasks/pushNotificationConfig/set RPC method
358async fn handle_push_config_set(
359    state: &A2aServerState,
360    params: Option<Value>,
361    _id: Value,
362) -> A2aResult<Value> {
363    let config: crate::a2a::rpc::TaskPushNotificationConfig =
364        serde_json::from_value(params.unwrap_or_default()).map_err(|_| {
365            A2aError::rpc(
366                A2aErrorCode::InvalidParams,
367                "Invalid pushNotificationConfig/set params",
368            )
369        })?;
370
371    state.task_manager.set_webhook_config(config).await?;
372
373    Ok(json!({ "success": true }))
374}
375
376/// Handle tasks/pushNotificationConfig/get RPC method
377async fn handle_push_config_get(
378    state: &A2aServerState,
379    params: Option<Value>,
380    _id: Value,
381) -> A2aResult<Value> {
382    let params: TaskIdParams =
383        serde_json::from_value(params.unwrap_or_default()).map_err(|_| {
384            A2aError::rpc(
385                A2aErrorCode::InvalidParams,
386                "Invalid pushNotificationConfig/get params",
387            )
388        })?;
389
390    let config = state.task_manager.get_webhook_config(&params.id).await;
391
392    Ok(serde_json::to_value(config)?)
393}
394
395/// Handle message/stream RPC method
396fn handle_message_stream<'a>(
397    state: &'a A2aServerState,
398    params: Option<Value>,
399    id: Value,
400) -> impl Future<Output = A2aResult<Value>> + 'a {
401    // Same as message_send for now, but would support streaming
402    handle_message_send(state, params, id)
403}
404
405/// Handle tasks/get RPC method
406async fn handle_tasks_get(
407    state: &A2aServerState,
408    params: Option<Value>,
409    _id: Value,
410) -> A2aResult<Value> {
411    let params: TaskQueryParams = serde_json::from_value(params.unwrap_or_default())
412        .map_err(|_| A2aError::rpc(A2aErrorCode::InvalidParams, "Invalid tasks/get params"))?;
413
414    let task = state.task_manager.get_task_or_error(&params.id).await?;
415
416    Ok(serde_json::to_value(task)?)
417}
418
419/// Handle tasks/list RPC method
420async fn handle_tasks_list(
421    state: &A2aServerState,
422    params: Option<Value>,
423    _id: Value,
424) -> A2aResult<Value> {
425    let params: ListTasksParams =
426        serde_json::from_value(params.unwrap_or_default()).unwrap_or_default();
427
428    let result = state.task_manager.list_tasks(params).await;
429
430    Ok(serde_json::to_value(result)?)
431}
432
433/// Handle tasks/cancel RPC method
434async fn handle_tasks_cancel(
435    state: &A2aServerState,
436    params: Option<Value>,
437    _id: Value,
438) -> A2aResult<Value> {
439    let params: TaskIdParams = serde_json::from_value(params.unwrap_or_default())
440        .map_err(|_| A2aError::rpc(A2aErrorCode::InvalidParams, "Invalid tasks/cancel params"))?;
441
442    let task = state.task_manager.cancel_task(&params.id).await?;
443
444    Ok(serde_json::to_value(task)?)
445}
446
447// ============================================================================
448// Error Response Handler
449// ============================================================================
450
451/// A2A error response for Axum
452pub struct A2aErrorResponse {
453    response: JsonRpcResponse,
454    status_code: StatusCode,
455}
456
457impl A2aErrorResponse {
458    /// Create a new error response
459    pub fn new(error: JsonRpcError, id: Value, status_code: StatusCode) -> Self {
460        Self {
461            response: JsonRpcResponse::error(error, id),
462            status_code,
463        }
464    }
465
466    /// Create an invalid request error response
467    pub fn invalid_request(message: &str, id: Value) -> Self {
468        Self::new(
469            JsonRpcError::invalid_request(message),
470            id,
471            StatusCode::BAD_REQUEST,
472        )
473    }
474
475    /// Create a method not found error response
476    pub fn method_not_found(method: &str, id: Value) -> Self {
477        Self::new(
478            JsonRpcError::method_not_found(method),
479            id,
480            StatusCode::NOT_FOUND,
481        )
482    }
483
484    /// Create an error response from an A2aError
485    pub fn from_error(error: A2aError, id: Value) -> Self {
486        let code: i32 = error.code().into();
487        let message = error.to_string();
488        let status_code = match error {
489            A2aError::TaskNotFound(_) => StatusCode::NOT_FOUND,
490            A2aError::TaskNotCancelable(_) => StatusCode::UNPROCESSABLE_ENTITY,
491            A2aError::InvalidStateTransition { .. } => StatusCode::UNPROCESSABLE_ENTITY,
492            _ => StatusCode::INTERNAL_SERVER_ERROR,
493        };
494
495        Self::new(JsonRpcError::new(code, message), id, status_code)
496    }
497}
498
499impl IntoResponse for A2aErrorResponse {
500    fn into_response(self) -> Response {
501        (self.status_code, Json(self.response)).into_response()
502    }
503}
504
505// ============================================================================
506// Server Startup
507// ============================================================================
508
509/// Run the A2A server
510pub async fn run(
511    state: A2aServerState,
512    addr: SocketAddr,
513) -> Result<(), Box<dyn std::error::Error>> {
514    let listener = tokio::net::TcpListener::bind(addr).await?;
515    tracing::info!("A2A server listening on {}", addr);
516    axum::serve(listener, create_router(state))
517        .with_graceful_shutdown(crate::shutdown::shutdown_signal_logged("A2A"))
518        .await?;
519    Ok(())
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Base URL used by every test that needs a default server state.
    const BASE_URL: &str = "http://localhost:8080";

    /// The default state carries the VT Code agent card.
    #[test]
    fn test_server_state_creation() {
        let state = A2aServerState::vtcode_default(BASE_URL);
        assert_eq!(state.agent_card.name, "vtcode-agent");
    }

    /// A missing task is reported as HTTP 404.
    #[test]
    fn test_error_response_task_not_found() {
        let error = A2aError::TaskNotFound("test-id".to_string());
        let response = A2aErrorResponse::from_error(error, json!(1));
        assert_eq!(response.status_code, StatusCode::NOT_FOUND);
    }

    /// A task that cannot be cancelled is reported as HTTP 422.
    #[test]
    fn test_error_response_task_not_cancelable() {
        let error = A2aError::TaskNotCancelable("Cannot cancel completed task".to_string());
        let response = A2aErrorResponse::from_error(error, json!(1));
        assert_eq!(response.status_code, StatusCode::UNPROCESSABLE_ENTITY);
    }

    /// An invalid request is reported as HTTP 400.
    #[test]
    fn test_error_response_invalid_request() {
        let response = A2aErrorResponse::invalid_request("Invalid JSON", json!(1));
        assert_eq!(response.status_code, StatusCode::BAD_REQUEST);
    }

    /// An event sent on the state's channel reaches a subscriber.
    #[tokio::test]
    async fn test_server_state_with_broadcast() {
        let state = A2aServerState::vtcode_default(BASE_URL);

        // Subscribe first so the event below has a receiver.
        let mut receiver = state.event_tx.subscribe();

        let outgoing = StreamingEvent::Message {
            message: crate::a2a::types::Message::agent_text("Test"),
            context_id: Some("test".to_string()),
            kind: "streaming-response".to_string(),
            r#final: false,
        };
        state.event_tx.send(outgoing.clone()).expect("send event");

        let incoming = receiver.recv().await.expect("receive event");
        assert!(!incoming.is_final());
    }
}