1use crate::a2a::WebhookNotifier;
10use crate::a2a::agent_card::AgentCard;
11use crate::a2a::errors::{A2aError, A2aErrorCode, A2aResult};
12use crate::a2a::rpc::{
13 JSONRPC_VERSION, JsonRpcError, JsonRpcRequest, JsonRpcResponse, ListTasksParams,
14 METHOD_MESSAGE_SEND, METHOD_MESSAGE_STREAM, METHOD_TASKS_CANCEL, METHOD_TASKS_GET,
15 METHOD_TASKS_LIST, METHOD_TASKS_PUSH_CONFIG_GET, METHOD_TASKS_PUSH_CONFIG_SET,
16 MessageSendParams, SendStreamingMessageResponse, StreamingEvent, TaskIdParams, TaskQueryParams,
17};
18use crate::a2a::task_manager::TaskManager;
19use crate::a2a::types::TaskState;
20use axum::{
21 Json, Router,
22 extract::State,
23 http::StatusCode,
24 response::{
25 IntoResponse, Response,
26 sse::{Event, Sse},
27 },
28 routing::post,
29};
30use serde_json::{Value, json};
31use std::convert::Infallible;
32use std::future::Future;
33use std::net::SocketAddr;
34use std::sync::Arc;
35use std::time::Duration;
36use tower_http::cors::CorsLayer;
37
/// Shared state handed to every A2A HTTP handler.
///
/// Every field is wrapped in an [`Arc`], so cloning the state only bumps
/// reference counts.
#[derive(Debug, Clone)]
pub struct A2aServerState {
    /// Task store used by the RPC handlers to create, update, list and cancel tasks.
    pub task_manager: Arc<TaskManager>,
    /// Agent card returned by the `/.well-known/agent-card.json` route.
    pub agent_card: Arc<AgentCard>,
    /// Broadcast sender for streaming events; the streaming handler subscribes to it.
    pub event_tx: Arc<tokio::sync::broadcast::Sender<StreamingEvent>>,
    /// Notifier used to deliver events to a task's configured webhook.
    pub webhook_notifier: Arc<WebhookNotifier>,
}
54
55impl A2aServerState {
56 pub fn new(task_manager: TaskManager, agent_card: AgentCard) -> Self {
58 let (event_tx, _) = tokio::sync::broadcast::channel(100);
59 Self {
60 task_manager: Arc::new(task_manager),
61 agent_card: Arc::new(agent_card),
62 event_tx: Arc::new(event_tx),
63 webhook_notifier: Arc::new(WebhookNotifier::new()),
64 }
65 }
66
67 pub fn vtcode_default(base_url: impl Into<String>) -> Self {
69 Self::new(TaskManager::new(), AgentCard::vtcode_default(base_url))
70 }
71}
72
73pub fn create_router(state: A2aServerState) -> Router {
79 Router::new()
80 .route(
81 "/.well-known/agent-card.json",
82 axum::routing::get(get_agent_card),
83 )
84 .route("/a2a", post(handle_rpc))
85 .route("/a2a/stream", post(handle_stream))
86 .with_state(state)
87 .layer(CorsLayer::permissive())
88}
89
90async fn get_agent_card(State(state): State<A2aServerState>) -> Json<AgentCard> {
96 Json(state.agent_card.as_ref().clone())
97}
98
99async fn handle_rpc(
101 State(state): State<A2aServerState>,
102 Json(request): Json<JsonRpcRequest>,
103) -> Result<Json<JsonRpcResponse>, A2aErrorResponse> {
104 if request.jsonrpc != JSONRPC_VERSION {
106 return Err(A2aErrorResponse::invalid_request(
107 "Invalid JSON-RPC version",
108 request.id,
109 ));
110 }
111
112 let result = match request.method.as_str() {
114 METHOD_MESSAGE_SEND => {
115 handle_message_send(&state, request.params, request.id.clone()).await
116 }
117 METHOD_MESSAGE_STREAM => {
118 handle_message_stream(&state, request.params, request.id.clone()).await
119 }
120 METHOD_TASKS_GET => handle_tasks_get(&state, request.params, request.id.clone()).await,
121 METHOD_TASKS_LIST => handle_tasks_list(&state, request.params, request.id.clone()).await,
122 METHOD_TASKS_CANCEL => {
123 handle_tasks_cancel(&state, request.params, request.id.clone()).await
124 }
125 METHOD_TASKS_PUSH_CONFIG_SET => {
126 handle_push_config_set(&state, request.params, request.id.clone()).await
127 }
128 METHOD_TASKS_PUSH_CONFIG_GET => {
129 handle_push_config_get(&state, request.params, request.id.clone()).await
130 }
131 _ => {
132 return Err(A2aErrorResponse::method_not_found(
133 &request.method,
134 request.id,
135 ));
136 }
137 };
138
139 match result {
140 Ok(result_value) => Ok(Json(JsonRpcResponse::success(result_value, request.id))),
141 Err(err) => Err(A2aErrorResponse::from_error(err, request.id)),
142 }
143}
144
145async fn handle_stream(
147 State(state): State<A2aServerState>,
148 Json(request): Json<JsonRpcRequest>,
149) -> impl IntoResponse {
150 if request.jsonrpc != JSONRPC_VERSION {
151 return Err(A2aErrorResponse::invalid_request(
152 "Invalid JSON-RPC version",
153 request.id.clone(),
154 ));
155 }
156
157 if request.method != METHOD_MESSAGE_STREAM {
158 return Err(A2aErrorResponse::method_not_found(
159 &request.method,
160 request.id.clone(),
161 ));
162 }
163
164 let params: MessageSendParams = serde_json::from_value(request.params.unwrap_or_default())
166 .map_err(|_| {
167 A2aErrorResponse::invalid_request("Invalid message/stream params", request.id.clone())
168 })?;
169
170 let task_id = if let Some(task_id) = params.task_id.clone() {
172 task_id
173 } else {
174 let task = state
175 .task_manager
176 .create_task(params.context_id.clone())
177 .await;
178 task.id.clone()
179 };
180
181 state
183 .task_manager
184 .add_message(&task_id, params.message.clone())
185 .await
186 .map_err(|e| A2aErrorResponse::from_error(e, request.id.clone()))?;
187
188 let mut rx = state.event_tx.subscribe();
190 let task_id_clone = task_id.clone();
191 let context_id = params.context_id.clone();
192 let notifier = state.webhook_notifier.clone();
193 let task_manager = state.task_manager.clone();
194
195 let stream = async_stream::stream! {
197 while let Ok(event) = rx.recv().await {
198 let matches = match &event {
200 StreamingEvent::Message { context_id: ctx, .. } => {
201 context_id.as_ref() == ctx.as_ref()
202 }
203 StreamingEvent::TaskStatus { task_id: tid, .. } => tid == &task_id_clone,
204 StreamingEvent::TaskArtifact { task_id: tid, .. } => tid == &task_id_clone,
205 };
206
207 if matches {
208 let notifier = notifier.clone();
210 let task_manager = task_manager.clone();
211 let task_id_for_hook = task_id_clone.clone();
212 let event_for_hook = event.clone();
213 tokio::spawn(async move {
214 if let Some(cfg) = task_manager.get_webhook_config(&task_id_for_hook).await {
215 let _ = notifier.send_event(&cfg, event_for_hook).await;
216 }
217 });
218
219 let is_final = event.is_final();
220 let json = serde_json::to_string(&SendStreamingMessageResponse { event })
221 .unwrap_or_default();
222 yield Ok::<_, Infallible>(Event::default().data(json));
223
224 if is_final {
225 break;
226 }
227 }
228 }
229 };
230
231 let state_clone = state.clone();
233 let task_id_clone = task_id.clone();
234 tokio::spawn(async move {
235 tokio::time::sleep(Duration::from_millis(100)).await;
237
238 let _ = state_clone
240 .task_manager
241 .update_status(&task_id_clone, TaskState::Working, None)
242 .await;
243
244 let status_event = StreamingEvent::TaskStatus {
246 task_id: task_id_clone.clone(),
247 context_id: params.context_id.clone(),
248 status: crate::a2a::types::TaskStatus::new(TaskState::Working),
249 kind: "status-update".to_string(),
250 r#final: false,
251 };
252 let _ = state_clone.event_tx.send(status_event.clone());
253
254 let notifier = state_clone.webhook_notifier.clone();
256 let task_manager = state_clone.task_manager.clone();
257 let task_id_for_hook = task_id_clone.clone();
258 tokio::spawn(async move {
259 if let Some(cfg) = task_manager.get_webhook_config(&task_id_for_hook).await {
260 let _ = notifier.send_event(&cfg, status_event).await;
261 }
262 });
263
264 tokio::time::sleep(Duration::from_millis(200)).await;
266 let response_msg = crate::a2a::types::Message::agent_text("Processing your request...");
267 let message_event = StreamingEvent::Message {
268 message: response_msg,
269 context_id: params.context_id.clone(),
270 kind: "streaming-response".to_string(),
271 r#final: false,
272 };
273 let _ = state_clone.event_tx.send(message_event.clone());
274
275 let notifier = state_clone.webhook_notifier.clone();
277 let task_manager = state_clone.task_manager.clone();
278 let task_id_for_hook = task_id_clone.clone();
279 tokio::spawn(async move {
280 if let Some(cfg) = task_manager.get_webhook_config(&task_id_for_hook).await {
281 let _ = notifier.send_event(&cfg, message_event).await;
282 }
283 });
284
285 tokio::time::sleep(Duration::from_millis(300)).await;
287 let _ = state_clone
288 .task_manager
289 .update_status(&task_id_clone, TaskState::Completed, None)
290 .await;
291
292 let final_status_event = StreamingEvent::TaskStatus {
294 task_id: task_id_clone,
295 context_id: params.context_id,
296 status: crate::a2a::types::TaskStatus::new(TaskState::Completed),
297 kind: "status-update".to_string(),
298 r#final: true,
299 };
300 let _ = state_clone.event_tx.send(final_status_event.clone());
301
302 let notifier = state_clone.webhook_notifier.clone();
304 let task_manager = state_clone.task_manager.clone();
305 let task_id_for_hook = final_status_event.task_id().unwrap_or_default().to_string();
306 tokio::spawn(async move {
307 if let Some(cfg) = task_manager.get_webhook_config(&task_id_for_hook).await {
308 let _ = notifier.send_event(&cfg, final_status_event).await;
309 }
310 });
311 });
312
313 Ok(Sse::new(Box::pin(stream)).keep_alive(
314 axum::response::sse::KeepAlive::new()
315 .interval(Duration::from_secs(15))
316 .text("keep-alive"),
317 ))
318}
319
320async fn handle_message_send(
326 state: &A2aServerState,
327 params: Option<Value>,
328 _id: Value,
329) -> A2aResult<Value> {
330 let params: MessageSendParams = serde_json::from_value(params.unwrap_or_default())
331 .map_err(|_| A2aError::rpc(A2aErrorCode::InvalidParams, "Invalid message/send params"))?;
332
333 let task_id = if let Some(task_id) = params.task_id {
335 task_id
336 } else {
337 let task = state.task_manager.create_task(params.context_id).await;
338 task.id.clone()
339 };
340
341 state
343 .task_manager
344 .add_message(&task_id, params.message)
345 .await?;
346
347 let task = state
349 .task_manager
350 .update_status(&task_id, TaskState::Working, None)
351 .await?;
352
353 Ok(serde_json::to_value(task)?)
355}
356
357async fn handle_push_config_set(
359 state: &A2aServerState,
360 params: Option<Value>,
361 _id: Value,
362) -> A2aResult<Value> {
363 let config: crate::a2a::rpc::TaskPushNotificationConfig =
364 serde_json::from_value(params.unwrap_or_default()).map_err(|_| {
365 A2aError::rpc(
366 A2aErrorCode::InvalidParams,
367 "Invalid pushNotificationConfig/set params",
368 )
369 })?;
370
371 state.task_manager.set_webhook_config(config).await?;
372
373 Ok(json!({ "success": true }))
374}
375
376async fn handle_push_config_get(
378 state: &A2aServerState,
379 params: Option<Value>,
380 _id: Value,
381) -> A2aResult<Value> {
382 let params: TaskIdParams =
383 serde_json::from_value(params.unwrap_or_default()).map_err(|_| {
384 A2aError::rpc(
385 A2aErrorCode::InvalidParams,
386 "Invalid pushNotificationConfig/get params",
387 )
388 })?;
389
390 let config = state.task_manager.get_webhook_config(¶ms.id).await;
391
392 Ok(serde_json::to_value(config)?)
393}
394
/// Handles `message/stream` when it arrives on the plain JSON-RPC endpoint.
///
/// No streaming happens here: the call is forwarded unchanged to
/// [`handle_message_send`] and resolves to that handler's result.
///
/// This is a plain `fn` that returns the delegated future directly; the
/// returned future borrows `state` for `'a`.
fn handle_message_stream<'a>(
    state: &'a A2aServerState,
    params: Option<Value>,
    id: Value,
) -> impl Future<Output = A2aResult<Value>> + 'a {
    handle_message_send(state, params, id)
}
404
405async fn handle_tasks_get(
407 state: &A2aServerState,
408 params: Option<Value>,
409 _id: Value,
410) -> A2aResult<Value> {
411 let params: TaskQueryParams = serde_json::from_value(params.unwrap_or_default())
412 .map_err(|_| A2aError::rpc(A2aErrorCode::InvalidParams, "Invalid tasks/get params"))?;
413
414 let task = state.task_manager.get_task_or_error(¶ms.id).await?;
415
416 Ok(serde_json::to_value(task)?)
417}
418
419async fn handle_tasks_list(
421 state: &A2aServerState,
422 params: Option<Value>,
423 _id: Value,
424) -> A2aResult<Value> {
425 let params: ListTasksParams =
426 serde_json::from_value(params.unwrap_or_default()).unwrap_or_default();
427
428 let result = state.task_manager.list_tasks(params).await;
429
430 Ok(serde_json::to_value(result)?)
431}
432
433async fn handle_tasks_cancel(
435 state: &A2aServerState,
436 params: Option<Value>,
437 _id: Value,
438) -> A2aResult<Value> {
439 let params: TaskIdParams = serde_json::from_value(params.unwrap_or_default())
440 .map_err(|_| A2aError::rpc(A2aErrorCode::InvalidParams, "Invalid tasks/cancel params"))?;
441
442 let task = state.task_manager.cancel_task(¶ms.id).await?;
443
444 Ok(serde_json::to_value(task)?)
445}
446
/// Error reply for the A2A endpoints: a JSON-RPC error body paired with the
/// HTTP status code it is sent with.
pub struct A2aErrorResponse {
    /// JSON-RPC response carrying the error object and the request id.
    response: JsonRpcResponse,
    /// HTTP status used when this value is turned into a response.
    status_code: StatusCode,
}
456
457impl A2aErrorResponse {
458 pub fn new(error: JsonRpcError, id: Value, status_code: StatusCode) -> Self {
460 Self {
461 response: JsonRpcResponse::error(error, id),
462 status_code,
463 }
464 }
465
466 pub fn invalid_request(message: &str, id: Value) -> Self {
468 Self::new(
469 JsonRpcError::invalid_request(message),
470 id,
471 StatusCode::BAD_REQUEST,
472 )
473 }
474
475 pub fn method_not_found(method: &str, id: Value) -> Self {
477 Self::new(
478 JsonRpcError::method_not_found(method),
479 id,
480 StatusCode::NOT_FOUND,
481 )
482 }
483
484 pub fn from_error(error: A2aError, id: Value) -> Self {
486 let code: i32 = error.code().into();
487 let message = error.to_string();
488 let status_code = match error {
489 A2aError::TaskNotFound(_) => StatusCode::NOT_FOUND,
490 A2aError::TaskNotCancelable(_) => StatusCode::UNPROCESSABLE_ENTITY,
491 A2aError::InvalidStateTransition { .. } => StatusCode::UNPROCESSABLE_ENTITY,
492 _ => StatusCode::INTERNAL_SERVER_ERROR,
493 };
494
495 Self::new(JsonRpcError::new(code, message), id, status_code)
496 }
497}
498
499impl IntoResponse for A2aErrorResponse {
500 fn into_response(self) -> Response {
501 (self.status_code, Json(self.response)).into_response()
502 }
503}
504
505pub async fn run(
511 state: A2aServerState,
512 addr: SocketAddr,
513) -> Result<(), Box<dyn std::error::Error>> {
514 let listener = tokio::net::TcpListener::bind(addr).await?;
515 tracing::info!("A2A server listening on {}", addr);
516 axum::serve(listener, create_router(state))
517 .with_graceful_shutdown(crate::shutdown::shutdown_signal_logged("A2A"))
518 .await?;
519 Ok(())
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// The default state carries the default agent card.
    #[test]
    fn test_server_state_creation() {
        let server_state = A2aServerState::vtcode_default("http://localhost:8080");
        assert_eq!(server_state.agent_card.name, "vtcode-agent");
    }

    /// `TaskNotFound` is reported as HTTP 404.
    #[test]
    fn test_error_response_task_not_found() {
        let error = A2aError::TaskNotFound("test-id".to_string());
        let reply = A2aErrorResponse::from_error(error, json!(1));
        assert_eq!(reply.status_code, StatusCode::NOT_FOUND);
    }

    /// `TaskNotCancelable` is reported as HTTP 422.
    #[test]
    fn test_error_response_task_not_cancelable() {
        let error = A2aError::TaskNotCancelable("Cannot cancel completed task".to_string());
        let reply = A2aErrorResponse::from_error(error, json!(1));
        assert_eq!(reply.status_code, StatusCode::UNPROCESSABLE_ENTITY);
    }

    /// Invalid-request errors are reported as HTTP 400.
    #[test]
    fn test_error_response_invalid_request() {
        let reply = A2aErrorResponse::invalid_request("Invalid JSON", json!(1));
        assert_eq!(reply.status_code, StatusCode::BAD_REQUEST);
    }

    /// An event sent on the state's broadcast channel reaches a subscriber.
    #[tokio::test]
    async fn test_server_state_with_broadcast() {
        let server_state = A2aServerState::vtcode_default("http://localhost:8080");
        let mut subscriber = server_state.event_tx.subscribe();

        let outgoing = StreamingEvent::Message {
            message: super::super::types::Message::agent_text("Test"),
            context_id: Some("test".to_string()),
            kind: "streaming-response".to_string(),
            r#final: false,
        };

        server_state
            .event_tx
            .send(outgoing.clone())
            .expect("send event");

        let incoming = subscriber.recv().await.expect("receive event");
        assert!(!incoming.is_final());
    }
}