1use std::convert::Infallible;
9
10use axum::{
11 Json,
12 extract::{
13 State,
14 ws::{Message, WebSocket, WebSocketUpgrade},
15 },
16 response::{
17 IntoResponse, Response,
18 sse::{Event as SseEvent, KeepAlive, Sse},
19 },
20};
21use futures_util::{SinkExt, Stream, StreamExt};
22use serde::Deserialize;
23use serde_json::json;
24use tokio_stream::wrappers::BroadcastStream;
25use tracing::{debug, warn};
26
27use super::{AgentMessage, RunAgentInput, ServerState};
28
29pub async fn health() -> Json<serde_json::Value> {
31 Json(json!({
32 "status": "ok",
33 "service": "syncable-cli-agent",
34 "protocol": "ag-ui"
35 }))
36}
37
38pub async fn info() -> Json<serde_json::Value> {
44 Json(json!({
45 "version": "1.0.0",
46 "agents": {
47 "syncable": {
48 "name": "syncable",
49 "className": "HttpAgent",
50 "description": "Syncable CLI Agent - Kubernetes and DevOps assistant"
51 }
52 },
53 "actions": {},
54 "audioFileTranscriptionEnabled": false
55 }))
56}
57
58#[derive(Debug, Clone, Deserialize)]
61pub struct CopilotKitRequest {
62 pub method: Option<String>,
64 pub params: Option<CopilotKitParams>,
66 pub body: Option<CopilotKitBody>,
68 #[serde(rename = "threadId")]
70 pub thread_id: Option<String>,
71 #[serde(rename = "runId")]
72 pub run_id: Option<String>,
73 pub messages: Option<Vec<serde_json::Value>>,
74 pub tools: Option<Vec<serde_json::Value>>,
75 pub context: Option<Vec<serde_json::Value>>,
76 pub state: Option<serde_json::Value>,
77 #[serde(rename = "forwardedProps")]
78 pub forwarded_props: Option<serde_json::Value>,
79}
80
81#[derive(Debug, Clone, Deserialize)]
82pub struct CopilotKitParams {
83 #[serde(rename = "agentId")]
84 pub agent_id: Option<String>,
85 #[serde(rename = "threadId")]
86 pub thread_id: Option<String>,
87}
88
89#[derive(Debug, Clone, Deserialize)]
90pub struct CopilotKitBody {
91 pub messages: Option<Vec<serde_json::Value>>,
92 #[serde(rename = "threadId")]
93 pub thread_id: Option<String>,
94 #[serde(rename = "runId")]
95 pub run_id: Option<String>,
96 pub tools: Option<Vec<serde_json::Value>>,
97 pub context: Option<Vec<serde_json::Value>>,
98 pub state: Option<serde_json::Value>,
99 #[serde(rename = "forwardedProps")]
100 pub forwarded_props: Option<serde_json::Value>,
101}
102
103pub async fn post_message(
109 State(state): State<ServerState>,
110 Json(raw): Json<serde_json::Value>,
111) -> Response {
112 debug!(
113 "Received POST request body: {}",
114 serde_json::to_string_pretty(&raw).unwrap_or_default()
115 );
116
117 let copilot_req: Result<CopilotKitRequest, _> = serde_json::from_value(raw.clone());
119
120 let (input, original_thread_id, original_run_id) = match copilot_req {
122 Ok(req) => {
123 if let Some(ref method) = req.method {
125 debug!("Detected CopilotKit envelope format, method: {:?}", method);
126
127 if method == "info" {
129 return Json(json!({
130 "version": "1.0.0",
131 "agents": {
132 "syncable": {
133 "name": "syncable",
134 "className": "HttpAgent",
135 "description": "Syncable CLI Agent - Kubernetes and DevOps assistant"
136 }
137 },
138 "actions": {},
139 "audioFileTranscriptionEnabled": false
140 })).into_response();
141 }
142
143 let body = req.body.unwrap_or(CopilotKitBody {
145 messages: None,
146 thread_id: None,
147 run_id: None,
148 tools: None,
149 context: None,
150 state: None,
151 forwarded_props: None,
152 });
153
154 let thread_id_str = body
155 .thread_id
156 .or(req.params.as_ref().and_then(|p| p.thread_id.clone()))
157 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
158 let run_id_str = body
159 .run_id
160 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
161
162 let thread_id: syncable_ag_ui_core::ThreadId = thread_id_str
164 .parse()
165 .unwrap_or_else(|_| syncable_ag_ui_core::ThreadId::random());
166 let run_id: syncable_ag_ui_core::RunId = run_id_str
167 .parse()
168 .unwrap_or_else(|_| syncable_ag_ui_core::RunId::random());
169
170 let messages = convert_messages(body.messages.unwrap_or_default());
172 let tools = convert_tools(body.tools.unwrap_or_default());
173 let context = convert_context(body.context.unwrap_or_default());
174
175 let input = RunAgentInput::new(thread_id, run_id)
176 .with_messages(messages)
177 .with_tools(tools)
178 .with_context(context)
179 .with_state(body.state.unwrap_or(serde_json::Value::Null))
180 .with_forwarded_props(body.forwarded_props.unwrap_or(serde_json::Value::Null));
181
182 (input, thread_id_str, run_id_str)
183 } else if req.thread_id.is_some() || req.messages.is_some() {
184 debug!("Detected direct RunAgentInput format");
186
187 let thread_id_str = req
188 .thread_id
189 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
190 let run_id_str = req
191 .run_id
192 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
193
194 let thread_id: syncable_ag_ui_core::ThreadId = thread_id_str
196 .parse()
197 .unwrap_or_else(|_| syncable_ag_ui_core::ThreadId::random());
198 let run_id: syncable_ag_ui_core::RunId = run_id_str
199 .parse()
200 .unwrap_or_else(|_| syncable_ag_ui_core::RunId::random());
201
202 let messages = convert_messages(req.messages.unwrap_or_default());
203 let tools = convert_tools(req.tools.unwrap_or_default());
204 let context = convert_context(req.context.unwrap_or_default());
205
206 let input = RunAgentInput::new(thread_id, run_id)
207 .with_messages(messages)
208 .with_tools(tools)
209 .with_context(context)
210 .with_state(req.state.unwrap_or(serde_json::Value::Null))
211 .with_forwarded_props(req.forwarded_props.unwrap_or(serde_json::Value::Null));
212
213 (input, thread_id_str, run_id_str)
214 } else {
215 warn!("Could not parse request format: {:?}", raw);
216 return Json(json!({
217 "status": "error",
218 "message": "Invalid request format"
219 }))
220 .into_response();
221 }
222 }
223 Err(e) => {
224 warn!("Failed to parse request: {}", e);
225 return Json(json!({
226 "status": "error",
227 "message": format!("Failed to parse request: {}", e)
228 }))
229 .into_response();
230 }
231 };
232
233 let thread_id = original_thread_id;
235 let run_id = original_run_id;
236
237 debug!(
238 thread_id = %thread_id,
239 run_id = %run_id,
240 message_count = input.messages.len(),
241 "Processed RunAgentInput via POST"
242 );
243
244 let mut event_rx = state.subscribe();
246
247 let message_tx = state.message_sender();
248 let agent_msg = AgentMessage::new(input);
249
250 if let Err(e) = message_tx.send(agent_msg).await {
251 warn!("Failed to route message to agent processor: {}", e);
252 return Json(json!({
253 "status": "error",
254 "message": "Failed to route message to agent processor"
255 }))
256 .into_response();
257 }
258
259 let stream = async_stream::stream! {
261 use syncable_ag_ui_core::Event;
262
263 loop {
264 match event_rx.recv().await {
265 Ok(event) => {
266 let is_terminal = matches!(&event, Event::RunFinished(_) | Event::RunError(_));
267
268 if let Ok(json) = serde_json::to_string(&event) {
270 let event_type = event.event_type().as_str().to_string();
271 yield Ok::<_, Infallible>(SseEvent::default()
272 .event(event_type)
273 .data(json));
274 }
275
276 if is_terminal {
278 break;
279 }
280 }
281 Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
282 continue;
284 }
285 Err(tokio::sync::broadcast::error::RecvError::Closed) => {
286 break;
288 }
289 }
290 }
291 };
292
293 Sse::new(stream)
294 .keep_alive(KeepAlive::default())
295 .into_response()
296}
297
298fn convert_messages(
300 raw_messages: Vec<serde_json::Value>,
301) -> Vec<syncable_ag_ui_core::types::Message> {
302 use syncable_ag_ui_core::MessageId;
303
304 raw_messages
305 .into_iter()
306 .filter_map(|msg| {
307 let role = msg.get("role")?.as_str()?;
308 let content = msg.get("content").and_then(|c| c.as_str()).unwrap_or("");
309 let id_str = msg
310 .get("id")
311 .and_then(|i| i.as_str())
312 .map(|s| s.to_string())
313 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
314
315 let id: MessageId = id_str.parse().unwrap_or_else(|_| MessageId::random());
317
318 match role {
319 "user" => Some(syncable_ag_ui_core::types::Message::User {
320 id,
321 content: content.to_string(),
322 name: msg.get("name").and_then(|n| n.as_str()).map(String::from),
323 }),
324 "assistant" => Some(syncable_ag_ui_core::types::Message::Assistant {
325 id,
326 content: Some(content.to_string()),
327 name: msg.get("name").and_then(|n| n.as_str()).map(String::from),
328 tool_calls: None,
329 }),
330 "system" => Some(syncable_ag_ui_core::types::Message::System {
331 id,
332 content: content.to_string(),
333 name: msg.get("name").and_then(|n| n.as_str()).map(String::from),
334 }),
335 _ => {
336 debug!("Unknown message role: {}", role);
337 None
338 }
339 }
340 })
341 .collect()
342}
343
344fn convert_tools(raw_tools: Vec<serde_json::Value>) -> Vec<syncable_ag_ui_core::types::Tool> {
346 raw_tools
347 .into_iter()
348 .filter_map(|tool| {
349 let name = tool.get("name")?.as_str()?.to_string();
350 let description = tool
351 .get("description")
352 .and_then(|d| d.as_str())
353 .unwrap_or("")
354 .to_string();
355 let parameters = tool
356 .get("parameters")
357 .cloned()
358 .unwrap_or(serde_json::json!({}));
359
360 Some(syncable_ag_ui_core::types::Tool::new(
361 name,
362 description,
363 parameters,
364 ))
365 })
366 .collect()
367}
368
369fn convert_context(
371 raw_context: Vec<serde_json::Value>,
372) -> Vec<syncable_ag_ui_core::types::Context> {
373 raw_context
374 .into_iter()
375 .filter_map(|ctx| {
376 let description = ctx.get("description")?.as_str()?.to_string();
377 let value = ctx.get("value")?.as_str()?.to_string();
378 Some(syncable_ag_ui_core::types::Context::new(description, value))
379 })
380 .collect()
381}
382
383pub async fn sse_handler(
385 State(state): State<ServerState>,
386) -> Sse<impl Stream<Item = Result<SseEvent, Infallible>>> {
387 let rx = state.subscribe();
388 let stream = BroadcastStream::new(rx);
389
390 let event_stream = stream.filter_map(|result| async move {
391 match result {
392 Ok(event) => {
393 let json = serde_json::to_string(&event).ok()?;
395 let event_type = event.event_type().as_str().to_string();
396
397 Some(Ok(SseEvent::default().event(event_type).data(json)))
398 }
399 Err(_) => None, }
401 });
402
403 Sse::new(event_stream).keep_alive(KeepAlive::default())
404}
405
406pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<ServerState>) -> Response {
408 ws.on_upgrade(move |socket| handle_websocket(socket, state))
409}
410
411async fn handle_websocket(socket: WebSocket, state: ServerState) {
413 let (mut sender, mut receiver) = socket.split();
414 let mut event_rx = state.subscribe();
415 let message_tx = state.message_sender();
416
417 let send_task = tokio::spawn(async move {
419 while let Ok(event) = event_rx.recv().await {
420 if let Ok(json) = serde_json::to_string(&event) {
421 if sender.send(Message::Text(json.into())).await.is_err() {
422 break; }
424 }
425 }
426 });
427
428 let recv_task = tokio::spawn(async move {
430 while let Some(msg) = receiver.next().await {
431 match msg {
432 Ok(Message::Close(_)) => break,
433 Ok(Message::Ping(_)) => {
434 }
436 Ok(Message::Text(text)) => {
437 match serde_json::from_str::<RunAgentInput>(&text) {
439 Ok(input) => {
440 debug!(
441 thread_id = %input.thread_id,
442 run_id = %input.run_id,
443 message_count = input.messages.len(),
444 "Received RunAgentInput via WebSocket"
445 );
446 let agent_msg = AgentMessage::new(input);
447 if let Err(e) = message_tx.send(agent_msg).await {
448 warn!("Failed to send message to agent processor: {}", e);
449 }
450 }
451 Err(e) => {
452 warn!("Failed to parse WebSocket message as RunAgentInput: {}", e);
453 }
455 }
456 }
457 Ok(Message::Binary(_)) => {
458 debug!("Received binary WebSocket message, ignoring");
460 }
461 Ok(Message::Pong(_)) => {
462 }
464 Err(e) => {
465 warn!("WebSocket error: {}", e);
466 break;
467 }
468 }
469 }
470 });
471
472 tokio::select! {
474 _ = send_task => {}
475 _ = recv_task => {}
476 }
477}
478
#[cfg(test)]
mod tests {
    //! Handler tests. The envelope/direct POST paths are exercised end to end
    //! against a fresh `ServerState`; the `convert_*` helpers are tested
    //! directly for their filtering rules.

    use super::*;
    use crate::server::ServerState;
    use axum::extract::State;
    use http::StatusCode;
    use syncable_ag_ui_core::types::Message as AgUiProtocolMessage;
    use syncable_ag_ui_core::{RunId, ThreadId};

    #[tokio::test]
    async fn test_health_endpoint() {
        let response = health().await;
        assert_eq!(response.0["status"], "ok");
        assert_eq!(response.0["protocol"], "ag-ui");
    }

    #[tokio::test]
    async fn test_info_endpoint() {
        let response = info().await;
        assert_eq!(response.0["agents"]["syncable"]["name"], "syncable");
        assert_eq!(response.0["agents"]["syncable"]["className"], "HttpAgent");
        assert_eq!(response.0["audioFileTranscriptionEnabled"], false);
    }

    /// A direct run-input body is routed to the agent processor.
    #[tokio::test]
    async fn test_post_message_accepted() {
        let state = ServerState::new();
        let mut msg_rx = state
            .take_message_receiver()
            .await
            .expect("Should get receiver");

        let thread_id = ThreadId::random();
        let run_id = RunId::random();
        let input_json = json!({
            "threadId": thread_id.to_string(),
            "runId": run_id.to_string(),
            "messages": [{
                "id": "msg-1",
                "role": "user",
                "content": "Hello from POST"
            }],
            "tools": [],
            "context": [],
            "state": null,
            "forwardedProps": null
        });

        let response = post_message(State(state), Json(input_json)).await;
        assert_eq!(response.status(), StatusCode::OK);

        let received = msg_rx.recv().await.expect("Should receive message");
        assert_eq!(received.input.messages.len(), 1);
    }

    /// A CopilotKit envelope body is unwrapped and routed to the processor.
    #[tokio::test]
    async fn test_post_message_copilotkit_envelope() {
        let state = ServerState::new();
        let mut msg_rx = state
            .take_message_receiver()
            .await
            .expect("Should get receiver");

        let input_json = json!({
            "method": "agent/run",
            "params": { "agentId": "syncable" },
            "body": {
                "threadId": "thread-123",
                "messages": [{
                    "id": "msg-1",
                    "role": "user",
                    "content": "Hello from CopilotKit"
                }]
            }
        });

        let response = post_message(State(state), Json(input_json)).await;
        assert_eq!(response.status(), StatusCode::OK);

        let received = msg_rx.recv().await.expect("Should receive message");
        assert_eq!(received.input.messages.len(), 1);
    }

    /// `method: "info"` answers with plain JSON instead of starting a run.
    #[tokio::test]
    async fn test_post_message_info_method_returns_json() {
        let state = ServerState::new();
        let response = post_message(State(state), Json(json!({ "method": "info" }))).await;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers()[http::header::CONTENT_TYPE],
            "application/json"
        );
    }

    #[test]
    fn test_convert_messages_keeps_known_roles_only() {
        let messages = convert_messages(vec![
            json!({ "id": "m1", "role": "user", "content": "hi", "name": "alice" }),
            json!({ "role": "assistant", "content": "hello" }),
            json!({ "role": "system", "content": "be helpful" }),
            json!({ "role": "unknown-role", "content": "dropped" }),
            json!({ "content": "no role, dropped" }),
        ]);

        assert_eq!(messages.len(), 3);
        match &messages[0] {
            AgUiProtocolMessage::User { content, name, .. } => {
                assert_eq!(content.as_str(), "hi");
                assert_eq!(name.as_deref(), Some("alice"));
            }
            _ => panic!("expected the first message to be a user message"),
        }
        assert!(matches!(messages[1], AgUiProtocolMessage::Assistant { .. }));
        assert!(matches!(messages[2], AgUiProtocolMessage::System { .. }));
    }

    #[test]
    fn test_convert_tools_requires_string_name() {
        let tools = convert_tools(vec![
            json!({ "name": "kubectl", "description": "run kubectl", "parameters": { "type": "object" } }),
            json!({ "name": "bare" }),
            json!({ "description": "nameless" }),
            json!({ "name": 7 }),
        ]);
        assert_eq!(tools.len(), 2);
    }

    #[test]
    fn test_convert_context_requires_string_fields() {
        let context = convert_context(vec![
            json!({ "description": "cluster", "value": "prod" }),
            json!({ "description": "missing value" }),
            json!({ "description": "non-string value", "value": 42 }),
        ]);
        assert_eq!(context.len(), 1);
    }
}