Skip to main content

rustant_core/gateway/
server.rs

1//! WebSocket gateway server built on axum.
2
3use super::GatewayConfig;
4use super::auth::GatewayAuth;
5use super::connection::ConnectionManager;
6use super::events::{ClientMessage, GatewayEvent, ServerMessage};
7use super::session::SessionManager;
8use axum::{
9    Router,
10    extract::{
11        Path, State,
12        ws::{Message as WsMessage, WebSocket, WebSocketUpgrade},
13    },
14    http::StatusCode,
15    response::IntoResponse,
16    routing::{get, post},
17};
18use chrono::Utc;
19use futures::SinkExt;
20use std::sync::Arc;
21use tokio::sync::{Mutex, broadcast};
22use uuid::Uuid;
23
/// Source of live channel and node health for the gateway.
///
/// The gateway's `ListChannels` / `ListNodes` WebSocket messages and the
/// `/api/status` endpoint read from an implementor of this trait; when none is
/// attached they report empty lists. Implement it to wire real
/// `ChannelManager` / `NodeManager` data into the gateway. The `Send + Sync`
/// bound lets the provider live inside the shared, mutex-wrapped server.
pub trait StatusProvider: Send + Sync {
    /// Snapshot of every registered channel as `(name, status)` pairs.
    fn channel_statuses(&self) -> Vec<(String, String)>;

    /// Snapshot of every registered node as `(name, health)` pairs.
    fn node_statuses(&self) -> Vec<(String, String)>;
}
34
35/// Thread-safe shared gateway reference for axum handlers.
36pub type SharedGateway = Arc<Mutex<GatewayServer>>;
37
38/// The WebSocket gateway server.
39pub struct GatewayServer {
40    config: GatewayConfig,
41    auth: GatewayAuth,
42    connections: ConnectionManager,
43    sessions: SessionManager,
44    event_tx: broadcast::Sender<GatewayEvent>,
45    started_at: chrono::DateTime<Utc>,
46    status_provider: Option<Box<dyn StatusProvider>>,
47    /// Counters for metrics dashboard.
48    total_tool_calls: u64,
49    total_llm_requests: u64,
50    /// Pending approvals for security queue (HashMap for O(1) lookup/removal).
51    pending_approvals: std::collections::HashMap<Uuid, PendingApproval>,
52    /// Snapshot of configuration JSON for the UI.
53    config_json: String,
54    /// Shared toggle state for voice/meeting sessions.
55    toggle_state: Option<Arc<crate::voice::toggle::ToggleState>>,
56}
57
58/// A pending approval request awaiting user decision.
59#[derive(Debug, Clone)]
60pub struct PendingApproval {
61    /// Unique approval ID.
62    pub id: Uuid,
63    /// Tool requesting approval.
64    pub tool_name: String,
65    /// Description of the action.
66    pub description: String,
67    /// Risk level string.
68    pub risk_level: String,
69}
70
71impl std::fmt::Debug for GatewayServer {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        f.debug_struct("GatewayServer")
74            .field("config", &self.config)
75            .field("connections", &self.connections.active_count())
76            .field("sessions", &self.sessions.total_count())
77            .finish()
78    }
79}
80
81impl GatewayServer {
82    /// Create a new gateway server from configuration.
83    pub fn new(config: GatewayConfig) -> Self {
84        let auth = GatewayAuth::from_config(&config);
85        let connections = ConnectionManager::new(config.max_connections);
86        let sessions = SessionManager::new();
87        let (event_tx, _) = broadcast::channel(config.broadcast_capacity);
88
89        Self {
90            config,
91            auth,
92            connections,
93            sessions,
94            event_tx,
95            started_at: Utc::now(),
96            status_provider: None,
97            total_tool_calls: 0,
98            total_llm_requests: 0,
99            pending_approvals: std::collections::HashMap::new(),
100            config_json: "{}".to_string(),
101            toggle_state: None,
102        }
103    }
104
105    /// Get a reference to the gateway configuration.
106    pub fn config(&self) -> &GatewayConfig {
107        &self.config
108    }
109
110    /// Get a reference to the auth validator.
111    pub fn auth(&self) -> &GatewayAuth {
112        &self.auth
113    }
114
115    /// Get a mutable reference to the connection manager.
116    pub fn connections_mut(&mut self) -> &mut ConnectionManager {
117        &mut self.connections
118    }
119
120    /// Get a reference to the connection manager.
121    pub fn connections(&self) -> &ConnectionManager {
122        &self.connections
123    }
124
125    /// Get a mutable reference to the session manager.
126    pub fn sessions_mut(&mut self) -> &mut SessionManager {
127        &mut self.sessions
128    }
129
130    /// Get a reference to the session manager.
131    pub fn sessions(&self) -> &SessionManager {
132        &self.sessions
133    }
134
135    /// Subscribe to gateway events.
136    pub fn subscribe(&self) -> broadcast::Receiver<GatewayEvent> {
137        self.event_tx.subscribe()
138    }
139
140    /// Broadcast an event to all subscribers.
141    pub fn broadcast(&self, event: GatewayEvent) -> usize {
142        self.event_tx.send(event).unwrap_or(0)
143    }
144
145    /// Uptime in seconds since the server was created.
146    pub fn uptime_secs(&self) -> u64 {
147        let elapsed = Utc::now() - self.started_at;
148        elapsed.num_seconds().max(0) as u64
149    }
150
151    /// Set a status provider for channel/node listings.
152    pub fn set_status_provider(&mut self, provider: Box<dyn StatusProvider>) {
153        self.status_provider = Some(provider);
154    }
155
156    /// Set the shared toggle state for voice/meeting controls.
157    pub fn set_toggle_state(&mut self, state: Arc<crate::voice::toggle::ToggleState>) {
158        self.toggle_state = Some(state);
159    }
160
161    /// Get a reference to the toggle state (if set).
162    pub fn toggle_state(&self) -> Option<&Arc<crate::voice::toggle::ToggleState>> {
163        self.toggle_state.as_ref()
164    }
165
166    /// Number of active connections.
167    pub fn active_connections(&self) -> usize {
168        self.connections.active_count()
169    }
170
171    /// Number of active sessions.
172    pub fn active_sessions(&self) -> usize {
173        self.sessions.active_count()
174    }
175
176    /// Increment the tool call counter.
177    pub fn record_tool_call(&mut self) {
178        self.total_tool_calls += 1;
179    }
180
181    /// Increment the LLM request counter.
182    pub fn record_llm_request(&mut self) {
183        self.total_llm_requests += 1;
184    }
185
186    /// Total tool calls since startup.
187    pub fn total_tool_calls(&self) -> u64 {
188        self.total_tool_calls
189    }
190
191    /// Total LLM requests since startup.
192    pub fn total_llm_requests(&self) -> u64 {
193        self.total_llm_requests
194    }
195
196    /// Add a pending approval request.
197    pub fn add_approval(&mut self, approval: PendingApproval) {
198        let id = approval.id;
199        let tool_name = approval.tool_name.clone();
200        let description = approval.description.clone();
201        let risk_level = approval.risk_level.clone();
202        self.pending_approvals.insert(id, approval);
203        self.broadcast(GatewayEvent::ApprovalRequest {
204            approval_id: id,
205            tool_name,
206            description,
207            risk_level,
208        });
209    }
210
211    /// Resolve a pending approval (returns true if found). O(1) via HashMap.
212    pub fn resolve_approval(&mut self, approval_id: &Uuid, _approved: bool) -> bool {
213        self.pending_approvals.remove(approval_id).is_some()
214    }
215
216    /// Get all pending approvals.
217    pub fn pending_approvals(&self) -> Vec<&PendingApproval> {
218        self.pending_approvals.values().collect()
219    }
220
221    /// Set the configuration JSON snapshot for the UI.
222    pub fn set_config_json(&mut self, json: String) {
223        self.config_json = json;
224    }
225
226    /// Get the current configuration JSON snapshot.
227    pub fn config_json(&self) -> &str {
228        &self.config_json
229    }
230
231    /// Handle a client message and produce a server response.
232    pub fn handle_client_message(&mut self, msg: ClientMessage, conn_id: Uuid) -> ServerMessage {
233        match msg {
234            ClientMessage::Authenticate { token } => {
235                if self.auth.validate(&token) {
236                    self.connections.authenticate(&conn_id);
237                    self.broadcast(GatewayEvent::Connected {
238                        connection_id: conn_id,
239                    });
240                    ServerMessage::Authenticated {
241                        connection_id: conn_id,
242                    }
243                } else {
244                    ServerMessage::AuthFailed {
245                        reason: "Invalid token".to_string(),
246                    }
247                }
248            }
249            ClientMessage::SubmitTask { description } => {
250                if !self.connections.is_authenticated(&conn_id) {
251                    return ServerMessage::AuthFailed {
252                        reason: "Not authenticated".to_string(),
253                    };
254                }
255                let task_id = Uuid::new_v4();
256                let _session_id = self.sessions.create_session(conn_id);
257                self.broadcast(GatewayEvent::TaskSubmitted {
258                    task_id,
259                    description: description.clone(),
260                });
261                ServerMessage::Event {
262                    event: GatewayEvent::TaskSubmitted {
263                        task_id,
264                        description,
265                    },
266                }
267            }
268            ClientMessage::CancelTask { task_id } => {
269                if !self.connections.is_authenticated(&conn_id) {
270                    return ServerMessage::AuthFailed {
271                        reason: "Not authenticated".to_string(),
272                    };
273                }
274                self.broadcast(GatewayEvent::TaskCompleted {
275                    task_id,
276                    success: false,
277                    summary: "Cancelled by client".to_string(),
278                });
279                ServerMessage::Event {
280                    event: GatewayEvent::TaskCompleted {
281                        task_id,
282                        success: false,
283                        summary: "Cancelled by client".to_string(),
284                    },
285                }
286            }
287            ClientMessage::GetStatus => ServerMessage::StatusResponse {
288                connected_clients: self.connections.active_count(),
289                active_tasks: self.sessions.active_count(),
290                uptime_secs: self.uptime_secs(),
291            },
292            ClientMessage::Ping { timestamp } => ServerMessage::Pong { timestamp },
293            ClientMessage::ListChannels => {
294                let channels = self
295                    .status_provider
296                    .as_ref()
297                    .map(|p| p.channel_statuses())
298                    .unwrap_or_default();
299                ServerMessage::ChannelStatus { channels }
300            }
301            ClientMessage::ListNodes => {
302                let nodes = self
303                    .status_provider
304                    .as_ref()
305                    .map(|p| p.node_statuses())
306                    .unwrap_or_default();
307                ServerMessage::NodeStatus { nodes }
308            }
309            ClientMessage::GetMetrics => ServerMessage::MetricsResponse {
310                active_connections: self.connections.active_count(),
311                active_sessions: self.sessions.active_count(),
312                total_tool_calls: self.total_tool_calls,
313                total_llm_requests: self.total_llm_requests,
314                uptime_secs: self.uptime_secs(),
315            },
316            ClientMessage::GetConfig => ServerMessage::ConfigResponse {
317                config_json: self.config_json.clone(),
318            },
319            ClientMessage::ApprovalDecision {
320                approval_id,
321                approved,
322                reason: _,
323            } => {
324                let found = self.resolve_approval(&approval_id, approved);
325                ServerMessage::ApprovalAck {
326                    approval_id,
327                    accepted: found,
328                }
329            }
330        }
331    }
332}
333
334/// Build an axum Router with `/ws`, `/health`, and REST API routes.
335pub fn router(shared: SharedGateway) -> Router {
336    Router::new()
337        .route("/ws", get(ws_handler))
338        .route("/health", get(health_handler))
339        .route("/api/status", get(api_status_handler))
340        .route("/api/sessions", get(api_sessions_handler))
341        .route("/api/config", get(api_config_handler))
342        .route("/api/metrics", get(api_metrics_handler))
343        .route("/api/audit", get(api_audit_handler))
344        .route("/api/approvals", get(api_approvals_handler))
345        .route("/api/approval/{id}", post(api_approval_decision_handler))
346        .route("/api/voice/start", post(api_voice_start_handler))
347        .route("/api/voice/stop", post(api_voice_stop_handler))
348        .route("/api/voice/status", get(api_voice_status_handler))
349        .route("/api/meeting/start", post(api_meeting_start_handler))
350        .route("/api/meeting/stop", post(api_meeting_stop_handler))
351        .route("/api/meeting/status", get(api_meeting_status_handler))
352        .with_state(shared)
353}
354
355/// WebSocket upgrade handler.
356async fn ws_handler(ws: WebSocketUpgrade, State(gw): State<SharedGateway>) -> impl IntoResponse {
357    ws.on_upgrade(move |socket| handle_socket(socket, gw))
358}
359
360/// Health check endpoint.
361async fn health_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
362    let gw = gw.lock().await;
363    let body = serde_json::json!({
364        "status": "ok",
365        "connections": gw.active_connections(),
366        "sessions": gw.active_sessions(),
367        "uptime_secs": gw.uptime_secs(),
368    });
369    axum::Json(body)
370}
371
372/// REST API: Get server status overview.
373async fn api_status_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
374    let gw = gw.lock().await;
375    let channels = gw
376        .status_provider
377        .as_ref()
378        .map(|p| p.channel_statuses())
379        .unwrap_or_default();
380    let nodes = gw
381        .status_provider
382        .as_ref()
383        .map(|p| p.node_statuses())
384        .unwrap_or_default();
385
386    let body = serde_json::json!({
387        "version": env!("CARGO_PKG_VERSION"),
388        "uptime_secs": gw.uptime_secs(),
389        "active_connections": gw.active_connections(),
390        "active_sessions": gw.active_sessions(),
391        "total_tool_calls": gw.total_tool_calls(),
392        "total_llm_requests": gw.total_llm_requests(),
393        "channels": channels.iter().map(|(n, s)| serde_json::json!({"name": n, "status": s})).collect::<Vec<_>>(),
394        "nodes": nodes.iter().map(|(n, s)| serde_json::json!({"name": n, "status": s})).collect::<Vec<_>>(),
395        "pending_approvals": gw.pending_approvals().len(),
396    });
397    axum::Json(body)
398}
399
400/// REST API: Get active sessions.
401async fn api_sessions_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
402    let gw = gw.lock().await;
403    let body = serde_json::json!({
404        "total": gw.active_sessions(),
405        "sessions": gw.sessions().list_active().iter().map(|s| {
406            serde_json::json!({
407                "id": s.session_id.to_string(),
408                "connection_id": s.connection_id.to_string(),
409                "state": format!("{:?}", s.state),
410                "created_at": s.created_at.to_rfc3339(),
411            })
412        }).collect::<Vec<_>>(),
413    });
414    axum::Json(body)
415}
416
417/// REST API: Get current configuration snapshot.
418async fn api_config_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
419    let gw = gw.lock().await;
420    let config_json = gw.config_json();
421    match serde_json::from_str::<serde_json::Value>(config_json) {
422        Ok(val) => axum::Json(val),
423        Err(_) => axum::Json(serde_json::json!({"error": "Invalid config JSON"})),
424    }
425}
426
427/// REST API: Get metrics snapshot.
428async fn api_metrics_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
429    let gw = gw.lock().await;
430    let body = serde_json::json!({
431        "active_connections": gw.active_connections(),
432        "active_sessions": gw.active_sessions(),
433        "total_tool_calls": gw.total_tool_calls(),
434        "total_llm_requests": gw.total_llm_requests(),
435        "uptime_secs": gw.uptime_secs(),
436    });
437    axum::Json(body)
438}
439
440/// REST API: Get audit trail (placeholder — returns recent events).
441async fn api_audit_handler(State(_gw): State<SharedGateway>) -> impl IntoResponse {
442    // In a full implementation, this would query the AuditTrail from rustant-core.
443    // For now, return an empty list to indicate the endpoint is functional.
444    let body = serde_json::json!({
445        "entries": [],
446        "total": 0,
447    });
448    axum::Json(body)
449}
450
451/// REST API: Get pending approval requests.
452async fn api_approvals_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
453    let gw = gw.lock().await;
454    let approvals: Vec<serde_json::Value> = gw
455        .pending_approvals()
456        .iter()
457        .map(|a| {
458            serde_json::json!({
459                "id": a.id.to_string(),
460                "tool_name": a.tool_name,
461                "description": a.description,
462                "risk_level": a.risk_level,
463            })
464        })
465        .collect();
466    axum::Json(serde_json::json!({ "approvals": approvals }))
467}
468
469/// REST API: Submit an approval decision.
470async fn api_approval_decision_handler(
471    Path(id): Path<String>,
472    State(gw): State<SharedGateway>,
473    axum::Json(body): axum::Json<serde_json::Value>,
474) -> impl IntoResponse {
475    let approval_id = match Uuid::parse_str(&id) {
476        Ok(uuid) => uuid,
477        Err(_) => {
478            return (
479                StatusCode::BAD_REQUEST,
480                axum::Json(serde_json::json!({"error": "Invalid UUID"})),
481            );
482        }
483    };
484
485    let approved = body
486        .get("approved")
487        .and_then(|v| v.as_bool())
488        .unwrap_or(false);
489    let mut gw = gw.lock().await;
490    let found = gw.resolve_approval(&approval_id, approved);
491
492    if found {
493        (
494            StatusCode::OK,
495            axum::Json(serde_json::json!({"status": "resolved", "approved": approved})),
496        )
497    } else {
498        (
499            StatusCode::NOT_FOUND,
500            axum::Json(serde_json::json!({"error": "Approval not found"})),
501        )
502    }
503}
504
505/// Handle an individual WebSocket connection.
506async fn handle_socket(mut socket: WebSocket, gw: SharedGateway) {
507    // Try to register the connection
508    let conn_id = {
509        let mut gw = gw.lock().await;
510        match gw.connections_mut().add_connection() {
511            Some(id) => id,
512            None => {
513                // At capacity — send error and close
514                let err = ServerMessage::Event {
515                    event: GatewayEvent::Error {
516                        code: "CAPACITY_FULL".to_string(),
517                        message: "Server at maximum connections".to_string(),
518                    },
519                };
520                if let Ok(json) = serde_json::to_string(&err) {
521                    let _ = socket.send(WsMessage::Text(json.into())).await;
522                }
523                let _ = socket.close().await;
524                return;
525            }
526        }
527    };
528
529    // Message loop
530    while let Some(Ok(ws_msg)) = socket.recv().await {
531        let text = match ws_msg {
532            WsMessage::Text(t) => t.to_string(),
533            WsMessage::Close(_) => break,
534            _ => continue,
535        };
536
537        let client_msg: ClientMessage = match serde_json::from_str(&text) {
538            Ok(m) => m,
539            Err(e) => {
540                let err = ServerMessage::Event {
541                    event: GatewayEvent::Error {
542                        code: "PARSE_ERROR".to_string(),
543                        message: format!("Invalid message: {}", e),
544                    },
545                };
546                if let Ok(json) = serde_json::to_string(&err) {
547                    let _ = socket.send(WsMessage::Text(json.into())).await;
548                }
549                continue;
550            }
551        };
552
553        let response = {
554            let mut gw = gw.lock().await;
555            gw.connections_mut().touch(&conn_id);
556            gw.handle_client_message(client_msg, conn_id)
557        };
558
559        if let Ok(json) = serde_json::to_string(&response)
560            && socket.send(WsMessage::Text(json.into())).await.is_err()
561        {
562            break;
563        }
564    }
565
566    // Cleanup
567    {
568        let mut gw = gw.lock().await;
569        gw.connections_mut().remove_connection(&conn_id);
570        gw.broadcast(GatewayEvent::Disconnected {
571            connection_id: conn_id,
572        });
573    }
574}
575
576// ── Voice & Meeting Toggle Endpoints ────────────────────────────────
577
578/// REST API: Start voice command session.
579async fn api_voice_start_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
580    let gw = gw.lock().await;
581    let ts = match gw.toggle_state() {
582        Some(ts) => ts.clone(),
583        None => {
584            return (
585                StatusCode::SERVICE_UNAVAILABLE,
586                axum::Json(serde_json::json!({"error": "Toggle state not configured"})),
587            );
588        }
589    };
590    drop(gw); // Release lock before async operation
591
592    if ts.voice_active().await {
593        return (
594            StatusCode::CONFLICT,
595            axum::Json(serde_json::json!({"error": "Voice session already active"})),
596        );
597    }
598
599    // Voice start requires config and workspace — return instruction to use CLI
600    (
601        StatusCode::OK,
602        axum::Json(serde_json::json!({
603            "status": "voice_start_requested",
604            "message": "Voice session start requires agent config. Use /voicecmd on in the REPL or Ctrl+V in TUI."
605        })),
606    )
607}
608
609/// REST API: Stop voice command session.
610async fn api_voice_stop_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
611    let gw = gw.lock().await;
612    let ts = match gw.toggle_state() {
613        Some(ts) => ts.clone(),
614        None => {
615            return (
616                StatusCode::SERVICE_UNAVAILABLE,
617                axum::Json(serde_json::json!({"error": "Toggle state not configured"})),
618            );
619        }
620    };
621    drop(gw);
622
623    match ts.voice_stop().await {
624        Ok(()) => (
625            StatusCode::OK,
626            axum::Json(serde_json::json!({"status": "stopped"})),
627        ),
628        Err(e) => (
629            StatusCode::BAD_REQUEST,
630            axum::Json(serde_json::json!({"error": e.to_string()})),
631        ),
632    }
633}
634
635/// REST API: Get voice session status.
636async fn api_voice_status_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
637    let gw = gw.lock().await;
638    let ts = match gw.toggle_state() {
639        Some(ts) => ts.clone(),
640        None => {
641            return axum::Json(serde_json::json!({"active": false, "available": false}));
642        }
643    };
644    drop(gw);
645
646    axum::Json(serde_json::json!({
647        "active": ts.voice_active().await,
648        "available": true,
649    }))
650}
651
652/// REST API: Start meeting recording.
653async fn api_meeting_start_handler(
654    State(gw): State<SharedGateway>,
655    axum::Json(body): axum::Json<serde_json::Value>,
656) -> impl IntoResponse {
657    let gw_guard = gw.lock().await;
658    let ts = match gw_guard.toggle_state() {
659        Some(ts) => ts.clone(),
660        None => {
661            return (
662                StatusCode::SERVICE_UNAVAILABLE,
663                axum::Json(serde_json::json!({"error": "Toggle state not configured"})),
664            );
665        }
666    };
667    drop(gw_guard);
668
669    if ts.meeting_active().await {
670        return (
671            StatusCode::CONFLICT,
672            axum::Json(serde_json::json!({"error": "Meeting recording already active"})),
673        );
674    }
675
676    let title = body.get("title").and_then(|v| v.as_str()).map(String::from);
677    let config = crate::config::MeetingConfig::default();
678
679    match ts.meeting_start(config, title).await {
680        Ok(()) => (
681            StatusCode::OK,
682            axum::Json(serde_json::json!({"status": "recording"})),
683        ),
684        Err(e) => (
685            StatusCode::INTERNAL_SERVER_ERROR,
686            axum::Json(serde_json::json!({"error": e})),
687        ),
688    }
689}
690
691/// REST API: Stop meeting recording.
692async fn api_meeting_stop_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
693    let gw_guard = gw.lock().await;
694    let ts = match gw_guard.toggle_state() {
695        Some(ts) => ts.clone(),
696        None => {
697            return (
698                StatusCode::SERVICE_UNAVAILABLE,
699                axum::Json(serde_json::json!({"error": "Toggle state not configured"})),
700            );
701        }
702    };
703    drop(gw_guard);
704
705    match ts.meeting_stop().await {
706        Ok(result) => (
707            StatusCode::OK,
708            axum::Json(serde_json::json!({
709                "status": "stopped",
710                "duration_secs": result.duration_secs,
711                "transcript_length": result.transcript.len(),
712                "notes_saved": result.notes_saved,
713            })),
714        ),
715        Err(e) => (
716            StatusCode::BAD_REQUEST,
717            axum::Json(serde_json::json!({"error": e})),
718        ),
719    }
720}
721
722/// REST API: Get meeting recording status.
723async fn api_meeting_status_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
724    let gw_guard = gw.lock().await;
725    let ts = match gw_guard.toggle_state() {
726        Some(ts) => ts.clone(),
727        None => {
728            return axum::Json(serde_json::json!({
729                "active": false,
730                "available": false,
731            }));
732        }
733    };
734    drop(gw_guard);
735
736    match ts.meeting_status().await {
737        Some(status) => axum::Json(serde_json::json!({
738            "active": true,
739            "available": true,
740            "title": status.title,
741            "started_at": status.started_at,
742            "elapsed_secs": status.elapsed_secs,
743        })),
744        None => axum::Json(serde_json::json!({
745            "active": false,
746            "available": true,
747        })),
748    }
749}
750
751/// Start the gateway server on the configured address.
752///
753/// This is an async function that runs until cancelled.
754pub async fn run(gw: SharedGateway) -> Result<(), std::io::Error> {
755    let (host, port) = {
756        let gw = gw.lock().await;
757        (gw.config().host.clone(), gw.config().port)
758    };
759    let app = router(gw);
760    let addr = format!("{}:{}", host, port);
761    let listener = tokio::net::TcpListener::bind(&addr).await?;
762    axum::serve(listener, app).await?;
763    Ok(())
764}
765
766#[cfg(test)]
767mod tests {
768    use super::*;
769    use axum::body::Body;
770    use tower::ServiceExt;
771
772    #[test]
773    fn test_server_construction() {
774        let config = GatewayConfig::default();
775        let server = GatewayServer::new(config);
776        assert_eq!(server.active_connections(), 0);
777        assert_eq!(server.active_sessions(), 0);
778    }
779
780    #[test]
781    fn test_server_with_auth_tokens() {
782        let config = GatewayConfig {
783            auth_tokens: vec!["tok1".into(), "tok2".into()],
784            ..GatewayConfig::default()
785        };
786        let server = GatewayServer::new(config);
787        assert!(server.auth().validate("tok1"));
788        assert!(!server.auth().validate("wrong"));
789    }
790
791    #[test]
792    fn test_server_broadcast_no_subscribers() {
793        let server = GatewayServer::new(GatewayConfig::default());
794        let sent = server.broadcast(GatewayEvent::Connected {
795            connection_id: Uuid::new_v4(),
796        });
797        assert_eq!(sent, 0);
798    }
799
800    #[test]
801    fn test_server_broadcast_with_subscriber() {
802        let server = GatewayServer::new(GatewayConfig::default());
803        let mut rx = server.subscribe();
804
805        let sent = server.broadcast(GatewayEvent::AssistantMessage {
806            content: "hello".into(),
807        });
808        assert_eq!(sent, 1);
809
810        let event = rx.try_recv().unwrap();
811        match event {
812            GatewayEvent::AssistantMessage { content } => {
813                assert_eq!(content, "hello");
814            }
815            _ => panic!("Wrong event type"),
816        }
817    }
818
819    #[test]
820    fn test_server_uptime() {
821        let server = GatewayServer::new(GatewayConfig::default());
822        assert!(server.uptime_secs() < 2);
823    }
824
825    #[test]
826    fn test_server_connection_lifecycle() {
827        let config = GatewayConfig {
828            max_connections: 5,
829            ..GatewayConfig::default()
830        };
831        let mut server = GatewayServer::new(config);
832
833        let conn_id = server.connections_mut().add_connection().unwrap();
834        assert_eq!(server.active_connections(), 1);
835
836        let session_id = server.sessions_mut().create_session(conn_id);
837        assert_eq!(server.active_sessions(), 1);
838
839        server.sessions_mut().end_session(&session_id);
840        assert_eq!(server.active_sessions(), 0);
841
842        server.connections_mut().remove_connection(&conn_id);
843        assert_eq!(server.active_connections(), 0);
844    }
845
846    // --- A6: WebSocket handler tests ---
847
848    fn make_shared_gateway(config: GatewayConfig) -> SharedGateway {
849        Arc::new(Mutex::new(GatewayServer::new(config)))
850    }
851
852    #[test]
853    fn test_router_builds() {
854        let gw = make_shared_gateway(GatewayConfig::default());
855        let _app = router(gw);
856    }
857
858    #[tokio::test]
859    async fn test_health_endpoint() {
860        let gw = make_shared_gateway(GatewayConfig::default());
861        let app = router(gw);
862
863        let req = axum::http::Request::builder()
864            .uri("/health")
865            .body(Body::empty())
866            .unwrap();
867
868        let resp = ServiceExt::<axum::http::Request<Body>>::oneshot(app, req)
869            .await
870            .unwrap();
871        assert_eq!(resp.status(), 200);
872
873        let body = axum::body::to_bytes(resp.into_body(), 10_000)
874            .await
875            .unwrap();
876        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
877        assert_eq!(json["status"], "ok");
878        assert_eq!(json["connections"], 0);
879        assert_eq!(json["sessions"], 0);
880    }
881
882    #[test]
883    fn test_handle_authenticate_valid() {
884        let config = GatewayConfig {
885            auth_tokens: vec!["secret".into()],
886            ..GatewayConfig::default()
887        };
888        let mut server = GatewayServer::new(config);
889        let conn_id = server.connections_mut().add_connection().unwrap();
890
891        let resp = server.handle_client_message(
892            ClientMessage::Authenticate {
893                token: "secret".into(),
894            },
895            conn_id,
896        );
897        match resp {
898            ServerMessage::Authenticated { connection_id } => {
899                assert_eq!(connection_id, conn_id);
900            }
901            _ => panic!("Expected Authenticated, got {:?}", resp),
902        }
903        assert!(server.connections().is_authenticated(&conn_id));
904    }
905
906    #[test]
907    fn test_handle_authenticate_invalid() {
908        let config = GatewayConfig {
909            auth_tokens: vec!["secret".into()],
910            ..GatewayConfig::default()
911        };
912        let mut server = GatewayServer::new(config);
913        let conn_id = server.connections_mut().add_connection().unwrap();
914
915        let resp = server.handle_client_message(
916            ClientMessage::Authenticate {
917                token: "wrong".into(),
918            },
919            conn_id,
920        );
921        match resp {
922            ServerMessage::AuthFailed { reason } => {
923                assert!(reason.contains("Invalid"));
924            }
925            _ => panic!("Expected AuthFailed, got {:?}", resp),
926        }
927        assert!(!server.connections().is_authenticated(&conn_id));
928    }
929
930    #[test]
931    fn test_handle_get_status() {
932        let mut server = GatewayServer::new(GatewayConfig::default());
933        let conn_id = server.connections_mut().add_connection().unwrap();
934
935        let resp = server.handle_client_message(ClientMessage::GetStatus, conn_id);
936        match resp {
937            ServerMessage::StatusResponse {
938                connected_clients,
939                active_tasks,
940                ..
941            } => {
942                assert_eq!(connected_clients, 1);
943                assert_eq!(active_tasks, 0);
944            }
945            _ => panic!("Expected StatusResponse"),
946        }
947    }
948
949    #[test]
950    fn test_handle_ping_pong() {
951        let mut server = GatewayServer::new(GatewayConfig::default());
952        let conn_id = server.connections_mut().add_connection().unwrap();
953        let now = Utc::now();
954
955        let resp = server.handle_client_message(ClientMessage::Ping { timestamp: now }, conn_id);
956        match resp {
957            ServerMessage::Pong { timestamp } => {
958                assert_eq!(timestamp, now);
959            }
960            _ => panic!("Expected Pong"),
961        }
962    }
963
964    #[test]
965    fn test_handle_submit_task_unauthenticated() {
966        let config = GatewayConfig {
967            auth_tokens: vec!["secret".into()],
968            ..GatewayConfig::default()
969        };
970        let mut server = GatewayServer::new(config);
971        let conn_id = server.connections_mut().add_connection().unwrap();
972
973        let resp = server.handle_client_message(
974            ClientMessage::SubmitTask {
975                description: "test task".into(),
976            },
977            conn_id,
978        );
979        match resp {
980            ServerMessage::AuthFailed { reason } => {
981                assert!(reason.contains("Not authenticated"));
982            }
983            _ => panic!("Expected AuthFailed for unauthenticated submit"),
984        }
985    }
986
987    #[test]
988    fn test_handle_submit_task_authenticated() {
989        let mut server = GatewayServer::new(GatewayConfig::default());
990        let conn_id = server.connections_mut().add_connection().unwrap();
991        // Open mode — auto-authenticated by validate("")
992        server.connections_mut().authenticate(&conn_id);
993
994        let resp = server.handle_client_message(
995            ClientMessage::SubmitTask {
996                description: "build feature X".into(),
997            },
998            conn_id,
999        );
1000        match resp {
1001            ServerMessage::Event {
1002                event: GatewayEvent::TaskSubmitted { description, .. },
1003            } => {
1004                assert_eq!(description, "build feature X");
1005            }
1006            _ => panic!("Expected TaskSubmitted event"),
1007        }
1008        // Session should have been created
1009        assert_eq!(server.active_sessions(), 1);
1010    }
1011
1012    #[test]
1013    fn test_handle_cancel_task() {
1014        let mut server = GatewayServer::new(GatewayConfig::default());
1015        let conn_id = server.connections_mut().add_connection().unwrap();
1016        server.connections_mut().authenticate(&conn_id);
1017        let task_id = Uuid::new_v4();
1018
1019        let resp = server.handle_client_message(ClientMessage::CancelTask { task_id }, conn_id);
1020        match resp {
1021            ServerMessage::Event {
1022                event:
1023                    GatewayEvent::TaskCompleted {
1024                        task_id: tid,
1025                        success,
1026                        summary,
1027                    },
1028            } => {
1029                assert_eq!(tid, task_id);
1030                assert!(!success);
1031                assert!(summary.contains("Cancelled"));
1032            }
1033            _ => panic!("Expected TaskCompleted with cancel"),
1034        }
1035    }
1036
1037    // --- StatusProvider wiring tests ---
1038
1039    struct MockStatusProvider {
1040        channels: Vec<(String, String)>,
1041        nodes: Vec<(String, String)>,
1042    }
1043
1044    impl StatusProvider for MockStatusProvider {
1045        fn channel_statuses(&self) -> Vec<(String, String)> {
1046            self.channels.clone()
1047        }
1048        fn node_statuses(&self) -> Vec<(String, String)> {
1049            self.nodes.clone()
1050        }
1051    }
1052
1053    #[test]
1054    fn test_list_channels_without_provider() {
1055        let mut server = GatewayServer::new(GatewayConfig::default());
1056        let conn_id = server.connections_mut().add_connection().unwrap();
1057
1058        let resp = server.handle_client_message(ClientMessage::ListChannels, conn_id);
1059        match resp {
1060            ServerMessage::ChannelStatus { channels } => {
1061                assert!(channels.is_empty());
1062            }
1063            _ => panic!("Expected ChannelStatus"),
1064        }
1065    }
1066
1067    #[test]
1068    fn test_list_nodes_without_provider() {
1069        let mut server = GatewayServer::new(GatewayConfig::default());
1070        let conn_id = server.connections_mut().add_connection().unwrap();
1071
1072        let resp = server.handle_client_message(ClientMessage::ListNodes, conn_id);
1073        match resp {
1074            ServerMessage::NodeStatus { nodes } => {
1075                assert!(nodes.is_empty());
1076            }
1077            _ => panic!("Expected NodeStatus"),
1078        }
1079    }
1080
1081    #[test]
1082    fn test_list_channels_with_provider() {
1083        let mut server = GatewayServer::new(GatewayConfig::default());
1084        server.set_status_provider(Box::new(MockStatusProvider {
1085            channels: vec![
1086                ("slack".into(), "Connected".into()),
1087                ("telegram".into(), "Disconnected".into()),
1088            ],
1089            nodes: vec![],
1090        }));
1091        let conn_id = server.connections_mut().add_connection().unwrap();
1092
1093        let resp = server.handle_client_message(ClientMessage::ListChannels, conn_id);
1094        match resp {
1095            ServerMessage::ChannelStatus { channels } => {
1096                assert_eq!(channels.len(), 2);
1097                assert_eq!(channels[0].0, "slack");
1098                assert_eq!(channels[0].1, "Connected");
1099                assert_eq!(channels[1].0, "telegram");
1100                assert_eq!(channels[1].1, "Disconnected");
1101            }
1102            _ => panic!("Expected ChannelStatus"),
1103        }
1104    }
1105
1106    #[test]
1107    fn test_list_nodes_with_provider() {
1108        let mut server = GatewayServer::new(GatewayConfig::default());
1109        server.set_status_provider(Box::new(MockStatusProvider {
1110            channels: vec![],
1111            nodes: vec![
1112                ("macos-local".into(), "Healthy".into()),
1113                ("linux-remote".into(), "Degraded".into()),
1114            ],
1115        }));
1116        let conn_id = server.connections_mut().add_connection().unwrap();
1117
1118        let resp = server.handle_client_message(ClientMessage::ListNodes, conn_id);
1119        match resp {
1120            ServerMessage::NodeStatus { nodes } => {
1121                assert_eq!(nodes.len(), 2);
1122                assert_eq!(nodes[0].0, "macos-local");
1123                assert_eq!(nodes[0].1, "Healthy");
1124                assert_eq!(nodes[1].0, "linux-remote");
1125                assert_eq!(nodes[1].1, "Degraded");
1126            }
1127            _ => panic!("Expected NodeStatus"),
1128        }
1129    }
1130
1131    #[test]
1132    fn test_status_provider_can_be_replaced() {
1133        let mut server = GatewayServer::new(GatewayConfig::default());
1134        server.set_status_provider(Box::new(MockStatusProvider {
1135            channels: vec![("a".into(), "x".into())],
1136            nodes: vec![],
1137        }));
1138        // Replace the provider
1139        server.set_status_provider(Box::new(MockStatusProvider {
1140            channels: vec![("b".into(), "y".into()), ("c".into(), "z".into())],
1141            nodes: vec![],
1142        }));
1143        let conn_id = server.connections_mut().add_connection().unwrap();
1144
1145        let resp = server.handle_client_message(ClientMessage::ListChannels, conn_id);
1146        match resp {
1147            ServerMessage::ChannelStatus { channels } => {
1148                assert_eq!(channels.len(), 2);
1149                assert_eq!(channels[0].0, "b");
1150            }
1151            _ => panic!("Expected ChannelStatus"),
1152        }
1153    }
1154}