Skip to main content

syncable_cli/server/
routes.rs

1//! HTTP Routes for AG-UI Server
2//!
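//! This module provides the HTTP handlers for the AG-UI protocol:
//! - `/sse` - Server-Sent Events endpoint
//! - `/ws` - WebSocket endpoint
//! - `/health` - Health check endpoint
//! - [`info`] - runtime info (agents/actions) for CopilotKit
//! - [`post_message`] - starts an agent run from a POSTed request and streams its events over SSE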
7
8use std::convert::Infallible;
9
10use axum::{
11    Json,
12    extract::{
13        State,
14        ws::{Message, WebSocket, WebSocketUpgrade},
15    },
16    response::{
17        IntoResponse, Response,
18        sse::{Event as SseEvent, KeepAlive, Sse},
19    },
20};
21use futures_util::{SinkExt, Stream, StreamExt};
22use serde::Deserialize;
23use serde_json::json;
24use tokio_stream::wrappers::BroadcastStream;
25use tracing::{debug, warn};
26
27use super::{AgentMessage, RunAgentInput, ServerState};
28
29/// Health check endpoint.
30pub async fn health() -> Json<serde_json::Value> {
31    Json(json!({
32        "status": "ok",
33        "service": "syncable-cli-agent",
34        "protocol": "ag-ui"
35    }))
36}
37
38/// Runtime info endpoint for CopilotKit.
39///
40/// CopilotKit expects this endpoint to return information about
41/// available agents and actions. Agents must be an object/map with
42/// agent ID as key, not an array.
43pub async fn info() -> Json<serde_json::Value> {
44    Json(json!({
45        "version": "1.0.0",
46        "agents": {
47            "syncable": {
48                "name": "syncable",
49                "className": "HttpAgent",
50                "description": "Syncable CLI Agent - Kubernetes and DevOps assistant"
51            }
52        },
53        "actions": {},
54        "audioFileTranscriptionEnabled": false
55    }))
56}
57
58/// CopilotKit request body structure.
59/// CopilotKit wraps requests in an envelope with method, params, and body.
60#[derive(Debug, Clone, Deserialize)]
61pub struct CopilotKitRequest {
62    /// The method being called (e.g., "agent/run")
63    pub method: Option<String>,
64    /// Method parameters
65    pub params: Option<CopilotKitParams>,
66    /// The actual request body
67    pub body: Option<CopilotKitBody>,
68    /// Direct fields for RunAgentInput format (non-envelope)
69    #[serde(rename = "threadId")]
70    pub thread_id: Option<String>,
71    #[serde(rename = "runId")]
72    pub run_id: Option<String>,
73    pub messages: Option<Vec<serde_json::Value>>,
74    pub tools: Option<Vec<serde_json::Value>>,
75    pub context: Option<Vec<serde_json::Value>>,
76    pub state: Option<serde_json::Value>,
77    #[serde(rename = "forwardedProps")]
78    pub forwarded_props: Option<serde_json::Value>,
79}
80
81#[derive(Debug, Clone, Deserialize)]
82pub struct CopilotKitParams {
83    #[serde(rename = "agentId")]
84    pub agent_id: Option<String>,
85    #[serde(rename = "threadId")]
86    pub thread_id: Option<String>,
87}
88
89#[derive(Debug, Clone, Deserialize)]
90pub struct CopilotKitBody {
91    pub messages: Option<Vec<serde_json::Value>>,
92    #[serde(rename = "threadId")]
93    pub thread_id: Option<String>,
94    #[serde(rename = "runId")]
95    pub run_id: Option<String>,
96    pub tools: Option<Vec<serde_json::Value>>,
97    pub context: Option<Vec<serde_json::Value>>,
98    pub state: Option<serde_json::Value>,
99    #[serde(rename = "forwardedProps")]
100    pub forwarded_props: Option<serde_json::Value>,
101}
102
103/// POST endpoint for receiving messages via HTTP.
104///
105/// Accepts both CopilotKit envelope format and direct RunAgentInput format.
106/// Routes messages to the agent processor and returns an SSE stream of events.
107/// Also handles CopilotKit's "info" method requests.
108pub async fn post_message(
109    State(state): State<ServerState>,
110    Json(raw): Json<serde_json::Value>,
111) -> Response {
112    debug!(
113        "Received POST request body: {}",
114        serde_json::to_string_pretty(&raw).unwrap_or_default()
115    );
116
117    // Try to parse as CopilotKit request
118    let copilot_req: Result<CopilotKitRequest, _> = serde_json::from_value(raw.clone());
119
120    // Track original thread/run IDs for response (may not be valid UUIDs)
121    let (input, original_thread_id, original_run_id) = match copilot_req {
122        Ok(req) => {
123            // Check if this is an envelope format (has method field)
124            if let Some(ref method) = req.method {
125                debug!("Detected CopilotKit envelope format, method: {:?}", method);
126
127                // Handle "info" method - return runtime info
128                if method == "info" {
129                    return Json(json!({
130                        "version": "1.0.0",
131                        "agents": {
132                            "syncable": {
133                                "name": "syncable",
134                                "className": "HttpAgent",
135                                "description": "Syncable CLI Agent - Kubernetes and DevOps assistant"
136                            }
137                        },
138                        "actions": {},
139                        "audioFileTranscriptionEnabled": false
140                    })).into_response();
141                }
142
143                // Extract from envelope body
144                let body = req.body.unwrap_or(CopilotKitBody {
145                    messages: None,
146                    thread_id: None,
147                    run_id: None,
148                    tools: None,
149                    context: None,
150                    state: None,
151                    forwarded_props: None,
152                });
153
154                let thread_id_str = body
155                    .thread_id
156                    .or(req.params.as_ref().and_then(|p| p.thread_id.clone()))
157                    .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
158                let run_id_str = body
159                    .run_id
160                    .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
161
162                // Parse IDs, falling back to random if invalid UUID
163                let thread_id: syncable_ag_ui_core::ThreadId = thread_id_str
164                    .parse()
165                    .unwrap_or_else(|_| syncable_ag_ui_core::ThreadId::random());
166                let run_id: syncable_ag_ui_core::RunId = run_id_str
167                    .parse()
168                    .unwrap_or_else(|_| syncable_ag_ui_core::RunId::random());
169
170                // Convert messages from JSON to Message type
171                let messages = convert_messages(body.messages.unwrap_or_default());
172                let tools = convert_tools(body.tools.unwrap_or_default());
173                let context = convert_context(body.context.unwrap_or_default());
174
175                let input = RunAgentInput::new(thread_id, run_id)
176                    .with_messages(messages)
177                    .with_tools(tools)
178                    .with_context(context)
179                    .with_state(body.state.unwrap_or(serde_json::Value::Null))
180                    .with_forwarded_props(body.forwarded_props.unwrap_or(serde_json::Value::Null));
181
182                (input, thread_id_str, run_id_str)
183            } else if req.thread_id.is_some() || req.messages.is_some() {
184                // Direct RunAgentInput format
185                debug!("Detected direct RunAgentInput format");
186
187                let thread_id_str = req
188                    .thread_id
189                    .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
190                let run_id_str = req
191                    .run_id
192                    .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
193
194                // Parse IDs, falling back to random if invalid UUID
195                let thread_id: syncable_ag_ui_core::ThreadId = thread_id_str
196                    .parse()
197                    .unwrap_or_else(|_| syncable_ag_ui_core::ThreadId::random());
198                let run_id: syncable_ag_ui_core::RunId = run_id_str
199                    .parse()
200                    .unwrap_or_else(|_| syncable_ag_ui_core::RunId::random());
201
202                let messages = convert_messages(req.messages.unwrap_or_default());
203                let tools = convert_tools(req.tools.unwrap_or_default());
204                let context = convert_context(req.context.unwrap_or_default());
205
206                let input = RunAgentInput::new(thread_id, run_id)
207                    .with_messages(messages)
208                    .with_tools(tools)
209                    .with_context(context)
210                    .with_state(req.state.unwrap_or(serde_json::Value::Null))
211                    .with_forwarded_props(req.forwarded_props.unwrap_or(serde_json::Value::Null));
212
213                (input, thread_id_str, run_id_str)
214            } else {
215                warn!("Could not parse request format: {:?}", raw);
216                return Json(json!({
217                    "status": "error",
218                    "message": "Invalid request format"
219                }))
220                .into_response();
221            }
222        }
223        Err(e) => {
224            warn!("Failed to parse request: {}", e);
225            return Json(json!({
226                "status": "error",
227                "message": format!("Failed to parse request: {}", e)
228            }))
229            .into_response();
230        }
231    };
232
233    // Use original string IDs for response (preserves non-UUID IDs like "thread-123")
234    let thread_id = original_thread_id;
235    let run_id = original_run_id;
236
237    debug!(
238        thread_id = %thread_id,
239        run_id = %run_id,
240        message_count = input.messages.len(),
241        "Processed RunAgentInput via POST"
242    );
243
244    // Subscribe to events BEFORE sending message to avoid race condition
245    let mut event_rx = state.subscribe();
246
247    let message_tx = state.message_sender();
248    let agent_msg = AgentMessage::new(input);
249
250    if let Err(e) = message_tx.send(agent_msg).await {
251        warn!("Failed to route message to agent processor: {}", e);
252        return Json(json!({
253            "status": "error",
254            "message": "Failed to route message to agent processor"
255        }))
256        .into_response();
257    }
258
259    // Create SSE stream that filters events and ends on RunFinished/RunError
260    let stream = async_stream::stream! {
261        use syncable_ag_ui_core::Event;
262
263        loop {
264            match event_rx.recv().await {
265                Ok(event) => {
266                    let is_terminal = matches!(&event, Event::RunFinished(_) | Event::RunError(_));
267
268                    // Serialize event to JSON
269                    if let Ok(json) = serde_json::to_string(&event) {
270                        let event_type = event.event_type().as_str().to_string();
271                        yield Ok::<_, Infallible>(SseEvent::default()
272                            .event(event_type)
273                            .data(json));
274                    }
275
276                    // Stop streaming after terminal event
277                    if is_terminal {
278                        break;
279                    }
280                }
281                Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
282                    // Missed some events, continue
283                    continue;
284                }
285                Err(tokio::sync::broadcast::error::RecvError::Closed) => {
286                    // Channel closed, stop streaming
287                    break;
288                }
289            }
290        }
291    };
292
293    Sse::new(stream)
294        .keep_alive(KeepAlive::default())
295        .into_response()
296}
297
298/// Convert JSON messages to AG-UI Message type
299fn convert_messages(
300    raw_messages: Vec<serde_json::Value>,
301) -> Vec<syncable_ag_ui_core::types::Message> {
302    use syncable_ag_ui_core::MessageId;
303
304    raw_messages
305        .into_iter()
306        .filter_map(|msg| {
307            let role = msg.get("role")?.as_str()?;
308            let content = msg.get("content").and_then(|c| c.as_str()).unwrap_or("");
309            let id_str = msg
310                .get("id")
311                .and_then(|i| i.as_str())
312                .map(|s| s.to_string())
313                .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
314
315            // Parse ID, falling back to random if invalid UUID format
316            let id: MessageId = id_str.parse().unwrap_or_else(|_| MessageId::random());
317
318            match role {
319                "user" => Some(syncable_ag_ui_core::types::Message::User {
320                    id,
321                    content: content.to_string(),
322                    name: msg.get("name").and_then(|n| n.as_str()).map(String::from),
323                }),
324                "assistant" => Some(syncable_ag_ui_core::types::Message::Assistant {
325                    id,
326                    content: Some(content.to_string()),
327                    name: msg.get("name").and_then(|n| n.as_str()).map(String::from),
328                    tool_calls: None,
329                }),
330                "system" => Some(syncable_ag_ui_core::types::Message::System {
331                    id,
332                    content: content.to_string(),
333                    name: msg.get("name").and_then(|n| n.as_str()).map(String::from),
334                }),
335                _ => {
336                    debug!("Unknown message role: {}", role);
337                    None
338                }
339            }
340        })
341        .collect()
342}
343
344/// Convert JSON tools to AG-UI Tool type
345fn convert_tools(raw_tools: Vec<serde_json::Value>) -> Vec<syncable_ag_ui_core::types::Tool> {
346    raw_tools
347        .into_iter()
348        .filter_map(|tool| {
349            let name = tool.get("name")?.as_str()?.to_string();
350            let description = tool
351                .get("description")
352                .and_then(|d| d.as_str())
353                .unwrap_or("")
354                .to_string();
355            let parameters = tool
356                .get("parameters")
357                .cloned()
358                .unwrap_or(serde_json::json!({}));
359
360            Some(syncable_ag_ui_core::types::Tool::new(
361                name,
362                description,
363                parameters,
364            ))
365        })
366        .collect()
367}
368
369/// Convert JSON context to AG-UI Context type
370fn convert_context(
371    raw_context: Vec<serde_json::Value>,
372) -> Vec<syncable_ag_ui_core::types::Context> {
373    raw_context
374        .into_iter()
375        .filter_map(|ctx| {
376            let description = ctx.get("description")?.as_str()?.to_string();
377            let value = ctx.get("value")?.as_str()?.to_string();
378            Some(syncable_ag_ui_core::types::Context::new(description, value))
379        })
380        .collect()
381}
382
383/// SSE endpoint for streaming AG-UI events.
384pub async fn sse_handler(
385    State(state): State<ServerState>,
386) -> Sse<impl Stream<Item = Result<SseEvent, Infallible>>> {
387    let rx = state.subscribe();
388    let stream = BroadcastStream::new(rx);
389
390    let event_stream = stream.filter_map(|result| async move {
391        match result {
392            Ok(event) => {
393                // Serialize event to JSON
394                let json = serde_json::to_string(&event).ok()?;
395                let event_type = event.event_type().as_str().to_string();
396
397                Some(Ok(SseEvent::default().event(event_type).data(json)))
398            }
399            Err(_) => None, // Lagged, skip this event
400        }
401    });
402
403    Sse::new(event_stream).keep_alive(KeepAlive::default())
404}
405
406/// WebSocket endpoint for streaming AG-UI events.
407pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<ServerState>) -> Response {
408    ws.on_upgrade(move |socket| handle_websocket(socket, state))
409}
410
411/// Handles a WebSocket connection.
412async fn handle_websocket(socket: WebSocket, state: ServerState) {
413    let (mut sender, mut receiver) = socket.split();
414    let mut event_rx = state.subscribe();
415    let message_tx = state.message_sender();
416
417    // Spawn task to send events to client
418    let send_task = tokio::spawn(async move {
419        while let Ok(event) = event_rx.recv().await {
420            if let Ok(json) = serde_json::to_string(&event) {
421                if sender.send(Message::Text(json.into())).await.is_err() {
422                    break; // Client disconnected
423                }
424            }
425        }
426    });
427
428    // Handle incoming messages from client
429    let recv_task = tokio::spawn(async move {
430        while let Some(msg) = receiver.next().await {
431            match msg {
432                Ok(Message::Close(_)) => break,
433                Ok(Message::Ping(_)) => {
434                    // Pong is handled automatically by axum
435                }
436                Ok(Message::Text(text)) => {
437                    // Parse as RunAgentInput and route to agent processor
438                    match serde_json::from_str::<RunAgentInput>(&text) {
439                        Ok(input) => {
440                            debug!(
441                                thread_id = %input.thread_id,
442                                run_id = %input.run_id,
443                                message_count = input.messages.len(),
444                                "Received RunAgentInput via WebSocket"
445                            );
446                            let agent_msg = AgentMessage::new(input);
447                            if let Err(e) = message_tx.send(agent_msg).await {
448                                warn!("Failed to send message to agent processor: {}", e);
449                            }
450                        }
451                        Err(e) => {
452                            warn!("Failed to parse WebSocket message as RunAgentInput: {}", e);
453                            // Log but continue - don't crash the connection
454                        }
455                    }
456                }
457                Ok(Message::Binary(_)) => {
458                    // Binary messages not supported yet
459                    debug!("Received binary WebSocket message, ignoring");
460                }
461                Ok(Message::Pong(_)) => {
462                    // Pong response, ignore
463                }
464                Err(e) => {
465                    warn!("WebSocket error: {}", e);
466                    break;
467                }
468            }
469        }
470    });
471
472    // Wait for either task to complete
473    tokio::select! {
474        _ = send_task => {}
475        _ = recv_task => {}
476    }
477}
478
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::ServerState;
    use axum::extract::State;
    use http::StatusCode;
    use syncable_ag_ui_core::{RunId, ThreadId};

    /// Posts `input_json` to [`post_message`] with a fresh [`ServerState`].
    ///
    /// Returns the response status together with the message routed to the agent processor.
    async fn post_and_receive(input_json: serde_json::Value) -> (StatusCode, AgentMessage) {
        let state = ServerState::new();
        let mut msg_rx = state
            .take_message_receiver()
            .await
            .expect("Should get receiver");

        let response = post_message(State(state), Json(input_json)).await;
        let received = msg_rx.recv().await.expect("Should receive message");
        (response.status(), received)
    }

    #[tokio::test]
    async fn test_health_endpoint() {
        let response = health().await;
        assert_eq!(response.0["status"], "ok");
        assert_eq!(response.0["protocol"], "ag-ui");
    }

    #[tokio::test]
    async fn test_post_message_accepted() {
        let thread_id = ThreadId::random();
        let run_id = RunId::random();
        let input_json = json!({
            "threadId": thread_id.to_string(),
            "runId": run_id.to_string(),
            "messages": [{
                "id": "msg-1",
                "role": "user",
                "content": "Hello from POST"
            }],
            "tools": [],
            "context": [],
            "state": null,
            "forwardedProps": null
        });

        let (status, received) = post_and_receive(input_json).await;

        // Success is an SSE stream (HTTP 200), and the message reached the processor.
        assert_eq!(status, StatusCode::OK);
        assert_eq!(received.input.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_post_message_copilotkit_envelope() {
        let input_json = json!({
            "method": "agent/run",
            "params": { "agentId": "syncable" },
            "body": {
                "threadId": "thread-123",
                "messages": [{
                    "id": "msg-1",
                    "role": "user",
                    "content": "Hello from CopilotKit"
                }]
            }
        });

        let (status, received) = post_and_receive(input_json).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(received.input.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_post_message_info_method_returns_runtime_info() {
        let response = post_message(
            State(ServerState::new()),
            Json(json!({ "method": "info" })),
        )
        .await;

        // The info method answers with plain JSON rather than an SSE stream.
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response
                .headers()
                .get(http::header::CONTENT_TYPE)
                .and_then(|v| v.to_str().ok()),
            Some("application/json")
        );
    }
}