1use super::GatewayConfig;
4use super::auth::GatewayAuth;
5use super::connection::ConnectionManager;
6use super::events::{ClientMessage, GatewayEvent, ServerMessage};
7use super::session::SessionManager;
8use axum::{
9 Router,
10 extract::{
11 Path, State,
12 ws::{Message as WsMessage, WebSocket, WebSocketUpgrade},
13 },
14 http::StatusCode,
15 response::IntoResponse,
16 routing::{get, post},
17};
18use chrono::Utc;
19use futures::SinkExt;
20use std::sync::Arc;
21use tokio::sync::{Mutex, broadcast};
22use uuid::Uuid;
23
/// Supplies live status information for the gateway.
///
/// The gateway consults an installed provider when answering the
/// `ListChannels` / `ListNodes` client messages and when building the
/// `/api/status` response. When no provider is installed, empty lists are
/// reported. The `Send + Sync` bound lets a boxed provider live inside the
/// shared `GatewayServer`.
pub trait StatusProvider: Send + Sync {
    /// Returns one `(name, status)` pair per known channel.
    fn channel_statuses(&self) -> Vec<(String, String)>;

    /// Returns one `(name, status)` pair per known node.
    fn node_statuses(&self) -> Vec<(String, String)>;
}
34
35pub type SharedGateway = Arc<Mutex<GatewayServer>>;
37
38pub struct GatewayServer {
40 config: GatewayConfig,
41 auth: GatewayAuth,
42 connections: ConnectionManager,
43 sessions: SessionManager,
44 event_tx: broadcast::Sender<GatewayEvent>,
45 started_at: chrono::DateTime<Utc>,
46 status_provider: Option<Box<dyn StatusProvider>>,
47 total_tool_calls: u64,
49 total_llm_requests: u64,
50 pending_approvals: std::collections::HashMap<Uuid, PendingApproval>,
52 config_json: String,
54 toggle_state: Option<Arc<crate::voice::toggle::ToggleState>>,
56}
57
58#[derive(Debug, Clone)]
60pub struct PendingApproval {
61 pub id: Uuid,
63 pub tool_name: String,
65 pub description: String,
67 pub risk_level: String,
69}
70
71impl std::fmt::Debug for GatewayServer {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 f.debug_struct("GatewayServer")
74 .field("config", &self.config)
75 .field("connections", &self.connections.active_count())
76 .field("sessions", &self.sessions.total_count())
77 .finish()
78 }
79}
80
81impl GatewayServer {
82 pub fn new(config: GatewayConfig) -> Self {
84 let auth = GatewayAuth::from_config(&config);
85 let connections = ConnectionManager::new(config.max_connections);
86 let sessions = SessionManager::new();
87 let (event_tx, _) = broadcast::channel(config.broadcast_capacity);
88
89 Self {
90 config,
91 auth,
92 connections,
93 sessions,
94 event_tx,
95 started_at: Utc::now(),
96 status_provider: None,
97 total_tool_calls: 0,
98 total_llm_requests: 0,
99 pending_approvals: std::collections::HashMap::new(),
100 config_json: "{}".to_string(),
101 toggle_state: None,
102 }
103 }
104
105 pub fn config(&self) -> &GatewayConfig {
107 &self.config
108 }
109
110 pub fn auth(&self) -> &GatewayAuth {
112 &self.auth
113 }
114
115 pub fn connections_mut(&mut self) -> &mut ConnectionManager {
117 &mut self.connections
118 }
119
120 pub fn connections(&self) -> &ConnectionManager {
122 &self.connections
123 }
124
125 pub fn sessions_mut(&mut self) -> &mut SessionManager {
127 &mut self.sessions
128 }
129
130 pub fn sessions(&self) -> &SessionManager {
132 &self.sessions
133 }
134
135 pub fn subscribe(&self) -> broadcast::Receiver<GatewayEvent> {
137 self.event_tx.subscribe()
138 }
139
140 pub fn broadcast(&self, event: GatewayEvent) -> usize {
142 self.event_tx.send(event).unwrap_or(0)
143 }
144
145 pub fn uptime_secs(&self) -> u64 {
147 let elapsed = Utc::now() - self.started_at;
148 elapsed.num_seconds().max(0) as u64
149 }
150
151 pub fn set_status_provider(&mut self, provider: Box<dyn StatusProvider>) {
153 self.status_provider = Some(provider);
154 }
155
156 pub fn set_toggle_state(&mut self, state: Arc<crate::voice::toggle::ToggleState>) {
158 self.toggle_state = Some(state);
159 }
160
161 pub fn toggle_state(&self) -> Option<&Arc<crate::voice::toggle::ToggleState>> {
163 self.toggle_state.as_ref()
164 }
165
166 pub fn active_connections(&self) -> usize {
168 self.connections.active_count()
169 }
170
171 pub fn active_sessions(&self) -> usize {
173 self.sessions.active_count()
174 }
175
176 pub fn record_tool_call(&mut self) {
178 self.total_tool_calls += 1;
179 }
180
181 pub fn record_llm_request(&mut self) {
183 self.total_llm_requests += 1;
184 }
185
186 pub fn total_tool_calls(&self) -> u64 {
188 self.total_tool_calls
189 }
190
191 pub fn total_llm_requests(&self) -> u64 {
193 self.total_llm_requests
194 }
195
196 pub fn add_approval(&mut self, approval: PendingApproval) {
198 let id = approval.id;
199 let tool_name = approval.tool_name.clone();
200 let description = approval.description.clone();
201 let risk_level = approval.risk_level.clone();
202 self.pending_approvals.insert(id, approval);
203 self.broadcast(GatewayEvent::ApprovalRequest {
204 approval_id: id,
205 tool_name,
206 description,
207 risk_level,
208 });
209 }
210
211 pub fn resolve_approval(&mut self, approval_id: &Uuid, _approved: bool) -> bool {
213 self.pending_approvals.remove(approval_id).is_some()
214 }
215
216 pub fn pending_approvals(&self) -> Vec<&PendingApproval> {
218 self.pending_approvals.values().collect()
219 }
220
221 pub fn set_config_json(&mut self, json: String) {
223 self.config_json = json;
224 }
225
226 pub fn config_json(&self) -> &str {
228 &self.config_json
229 }
230
231 pub fn handle_client_message(&mut self, msg: ClientMessage, conn_id: Uuid) -> ServerMessage {
233 match msg {
234 ClientMessage::Authenticate { token } => {
235 if self.auth.validate(&token) {
236 self.connections.authenticate(&conn_id);
237 self.broadcast(GatewayEvent::Connected {
238 connection_id: conn_id,
239 });
240 ServerMessage::Authenticated {
241 connection_id: conn_id,
242 }
243 } else {
244 ServerMessage::AuthFailed {
245 reason: "Invalid token".to_string(),
246 }
247 }
248 }
249 ClientMessage::SubmitTask { description } => {
250 if !self.connections.is_authenticated(&conn_id) {
251 return ServerMessage::AuthFailed {
252 reason: "Not authenticated".to_string(),
253 };
254 }
255 let task_id = Uuid::new_v4();
256 let _session_id = self.sessions.create_session(conn_id);
257 self.broadcast(GatewayEvent::TaskSubmitted {
258 task_id,
259 description: description.clone(),
260 });
261 ServerMessage::Event {
262 event: GatewayEvent::TaskSubmitted {
263 task_id,
264 description,
265 },
266 }
267 }
268 ClientMessage::CancelTask { task_id } => {
269 if !self.connections.is_authenticated(&conn_id) {
270 return ServerMessage::AuthFailed {
271 reason: "Not authenticated".to_string(),
272 };
273 }
274 self.broadcast(GatewayEvent::TaskCompleted {
275 task_id,
276 success: false,
277 summary: "Cancelled by client".to_string(),
278 });
279 ServerMessage::Event {
280 event: GatewayEvent::TaskCompleted {
281 task_id,
282 success: false,
283 summary: "Cancelled by client".to_string(),
284 },
285 }
286 }
287 ClientMessage::GetStatus => ServerMessage::StatusResponse {
288 connected_clients: self.connections.active_count(),
289 active_tasks: self.sessions.active_count(),
290 uptime_secs: self.uptime_secs(),
291 },
292 ClientMessage::Ping { timestamp } => ServerMessage::Pong { timestamp },
293 ClientMessage::ListChannels => {
294 let channels = self
295 .status_provider
296 .as_ref()
297 .map(|p| p.channel_statuses())
298 .unwrap_or_default();
299 ServerMessage::ChannelStatus { channels }
300 }
301 ClientMessage::ListNodes => {
302 let nodes = self
303 .status_provider
304 .as_ref()
305 .map(|p| p.node_statuses())
306 .unwrap_or_default();
307 ServerMessage::NodeStatus { nodes }
308 }
309 ClientMessage::GetMetrics => ServerMessage::MetricsResponse {
310 active_connections: self.connections.active_count(),
311 active_sessions: self.sessions.active_count(),
312 total_tool_calls: self.total_tool_calls,
313 total_llm_requests: self.total_llm_requests,
314 uptime_secs: self.uptime_secs(),
315 },
316 ClientMessage::GetConfig => ServerMessage::ConfigResponse {
317 config_json: self.config_json.clone(),
318 },
319 ClientMessage::ApprovalDecision {
320 approval_id,
321 approved,
322 reason: _,
323 } => {
324 let found = self.resolve_approval(&approval_id, approved);
325 ServerMessage::ApprovalAck {
326 approval_id,
327 accepted: found,
328 }
329 }
330 }
331 }
332}
333
334pub fn router(shared: SharedGateway) -> Router {
336 Router::new()
337 .route("/ws", get(ws_handler))
338 .route("/health", get(health_handler))
339 .route("/api/status", get(api_status_handler))
340 .route("/api/sessions", get(api_sessions_handler))
341 .route("/api/config", get(api_config_handler))
342 .route("/api/metrics", get(api_metrics_handler))
343 .route("/api/audit", get(api_audit_handler))
344 .route("/api/approvals", get(api_approvals_handler))
345 .route("/api/approval/{id}", post(api_approval_decision_handler))
346 .route("/api/voice/start", post(api_voice_start_handler))
347 .route("/api/voice/stop", post(api_voice_stop_handler))
348 .route("/api/voice/status", get(api_voice_status_handler))
349 .route("/api/meeting/start", post(api_meeting_start_handler))
350 .route("/api/meeting/stop", post(api_meeting_stop_handler))
351 .route("/api/meeting/status", get(api_meeting_status_handler))
352 .with_state(shared)
353}
354
355async fn ws_handler(ws: WebSocketUpgrade, State(gw): State<SharedGateway>) -> impl IntoResponse {
357 ws.on_upgrade(move |socket| handle_socket(socket, gw))
358}
359
360async fn health_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
362 let gw = gw.lock().await;
363 let body = serde_json::json!({
364 "status": "ok",
365 "connections": gw.active_connections(),
366 "sessions": gw.active_sessions(),
367 "uptime_secs": gw.uptime_secs(),
368 });
369 axum::Json(body)
370}
371
372async fn api_status_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
374 let gw = gw.lock().await;
375 let channels = gw
376 .status_provider
377 .as_ref()
378 .map(|p| p.channel_statuses())
379 .unwrap_or_default();
380 let nodes = gw
381 .status_provider
382 .as_ref()
383 .map(|p| p.node_statuses())
384 .unwrap_or_default();
385
386 let body = serde_json::json!({
387 "version": env!("CARGO_PKG_VERSION"),
388 "uptime_secs": gw.uptime_secs(),
389 "active_connections": gw.active_connections(),
390 "active_sessions": gw.active_sessions(),
391 "total_tool_calls": gw.total_tool_calls(),
392 "total_llm_requests": gw.total_llm_requests(),
393 "channels": channels.iter().map(|(n, s)| serde_json::json!({"name": n, "status": s})).collect::<Vec<_>>(),
394 "nodes": nodes.iter().map(|(n, s)| serde_json::json!({"name": n, "status": s})).collect::<Vec<_>>(),
395 "pending_approvals": gw.pending_approvals().len(),
396 });
397 axum::Json(body)
398}
399
400async fn api_sessions_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
402 let gw = gw.lock().await;
403 let body = serde_json::json!({
404 "total": gw.active_sessions(),
405 "sessions": gw.sessions().list_active().iter().map(|s| {
406 serde_json::json!({
407 "id": s.session_id.to_string(),
408 "connection_id": s.connection_id.to_string(),
409 "state": format!("{:?}", s.state),
410 "created_at": s.created_at.to_rfc3339(),
411 })
412 }).collect::<Vec<_>>(),
413 });
414 axum::Json(body)
415}
416
417async fn api_config_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
419 let gw = gw.lock().await;
420 let config_json = gw.config_json();
421 match serde_json::from_str::<serde_json::Value>(config_json) {
422 Ok(val) => axum::Json(val),
423 Err(_) => axum::Json(serde_json::json!({"error": "Invalid config JSON"})),
424 }
425}
426
427async fn api_metrics_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
429 let gw = gw.lock().await;
430 let body = serde_json::json!({
431 "active_connections": gw.active_connections(),
432 "active_sessions": gw.active_sessions(),
433 "total_tool_calls": gw.total_tool_calls(),
434 "total_llm_requests": gw.total_llm_requests(),
435 "uptime_secs": gw.uptime_secs(),
436 });
437 axum::Json(body)
438}
439
440async fn api_audit_handler(State(_gw): State<SharedGateway>) -> impl IntoResponse {
442 let body = serde_json::json!({
445 "entries": [],
446 "total": 0,
447 });
448 axum::Json(body)
449}
450
451async fn api_approvals_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
453 let gw = gw.lock().await;
454 let approvals: Vec<serde_json::Value> = gw
455 .pending_approvals()
456 .iter()
457 .map(|a| {
458 serde_json::json!({
459 "id": a.id.to_string(),
460 "tool_name": a.tool_name,
461 "description": a.description,
462 "risk_level": a.risk_level,
463 })
464 })
465 .collect();
466 axum::Json(serde_json::json!({ "approvals": approvals }))
467}
468
469async fn api_approval_decision_handler(
471 Path(id): Path<String>,
472 State(gw): State<SharedGateway>,
473 axum::Json(body): axum::Json<serde_json::Value>,
474) -> impl IntoResponse {
475 let approval_id = match Uuid::parse_str(&id) {
476 Ok(uuid) => uuid,
477 Err(_) => {
478 return (
479 StatusCode::BAD_REQUEST,
480 axum::Json(serde_json::json!({"error": "Invalid UUID"})),
481 );
482 }
483 };
484
485 let approved = body
486 .get("approved")
487 .and_then(|v| v.as_bool())
488 .unwrap_or(false);
489 let mut gw = gw.lock().await;
490 let found = gw.resolve_approval(&approval_id, approved);
491
492 if found {
493 (
494 StatusCode::OK,
495 axum::Json(serde_json::json!({"status": "resolved", "approved": approved})),
496 )
497 } else {
498 (
499 StatusCode::NOT_FOUND,
500 axum::Json(serde_json::json!({"error": "Approval not found"})),
501 )
502 }
503}
504
505async fn handle_socket(mut socket: WebSocket, gw: SharedGateway) {
507 let conn_id = {
509 let mut gw = gw.lock().await;
510 match gw.connections_mut().add_connection() {
511 Some(id) => id,
512 None => {
513 let err = ServerMessage::Event {
515 event: GatewayEvent::Error {
516 code: "CAPACITY_FULL".to_string(),
517 message: "Server at maximum connections".to_string(),
518 },
519 };
520 if let Ok(json) = serde_json::to_string(&err) {
521 let _ = socket.send(WsMessage::Text(json.into())).await;
522 }
523 let _ = socket.close().await;
524 return;
525 }
526 }
527 };
528
529 while let Some(Ok(ws_msg)) = socket.recv().await {
531 let text = match ws_msg {
532 WsMessage::Text(t) => t.to_string(),
533 WsMessage::Close(_) => break,
534 _ => continue,
535 };
536
537 let client_msg: ClientMessage = match serde_json::from_str(&text) {
538 Ok(m) => m,
539 Err(e) => {
540 let err = ServerMessage::Event {
541 event: GatewayEvent::Error {
542 code: "PARSE_ERROR".to_string(),
543 message: format!("Invalid message: {}", e),
544 },
545 };
546 if let Ok(json) = serde_json::to_string(&err) {
547 let _ = socket.send(WsMessage::Text(json.into())).await;
548 }
549 continue;
550 }
551 };
552
553 let response = {
554 let mut gw = gw.lock().await;
555 gw.connections_mut().touch(&conn_id);
556 gw.handle_client_message(client_msg, conn_id)
557 };
558
559 if let Ok(json) = serde_json::to_string(&response)
560 && socket.send(WsMessage::Text(json.into())).await.is_err()
561 {
562 break;
563 }
564 }
565
566 {
568 let mut gw = gw.lock().await;
569 gw.connections_mut().remove_connection(&conn_id);
570 gw.broadcast(GatewayEvent::Disconnected {
571 connection_id: conn_id,
572 });
573 }
574}
575
576async fn api_voice_start_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
580 let gw = gw.lock().await;
581 let ts = match gw.toggle_state() {
582 Some(ts) => ts.clone(),
583 None => {
584 return (
585 StatusCode::SERVICE_UNAVAILABLE,
586 axum::Json(serde_json::json!({"error": "Toggle state not configured"})),
587 );
588 }
589 };
590 drop(gw); if ts.voice_active().await {
593 return (
594 StatusCode::CONFLICT,
595 axum::Json(serde_json::json!({"error": "Voice session already active"})),
596 );
597 }
598
599 (
601 StatusCode::OK,
602 axum::Json(serde_json::json!({
603 "status": "voice_start_requested",
604 "message": "Voice session start requires agent config. Use /voicecmd on in the REPL or Ctrl+V in TUI."
605 })),
606 )
607}
608
609async fn api_voice_stop_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
611 let gw = gw.lock().await;
612 let ts = match gw.toggle_state() {
613 Some(ts) => ts.clone(),
614 None => {
615 return (
616 StatusCode::SERVICE_UNAVAILABLE,
617 axum::Json(serde_json::json!({"error": "Toggle state not configured"})),
618 );
619 }
620 };
621 drop(gw);
622
623 match ts.voice_stop().await {
624 Ok(()) => (
625 StatusCode::OK,
626 axum::Json(serde_json::json!({"status": "stopped"})),
627 ),
628 Err(e) => (
629 StatusCode::BAD_REQUEST,
630 axum::Json(serde_json::json!({"error": e.to_string()})),
631 ),
632 }
633}
634
635async fn api_voice_status_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
637 let gw = gw.lock().await;
638 let ts = match gw.toggle_state() {
639 Some(ts) => ts.clone(),
640 None => {
641 return axum::Json(serde_json::json!({"active": false, "available": false}));
642 }
643 };
644 drop(gw);
645
646 axum::Json(serde_json::json!({
647 "active": ts.voice_active().await,
648 "available": true,
649 }))
650}
651
652async fn api_meeting_start_handler(
654 State(gw): State<SharedGateway>,
655 axum::Json(body): axum::Json<serde_json::Value>,
656) -> impl IntoResponse {
657 let gw_guard = gw.lock().await;
658 let ts = match gw_guard.toggle_state() {
659 Some(ts) => ts.clone(),
660 None => {
661 return (
662 StatusCode::SERVICE_UNAVAILABLE,
663 axum::Json(serde_json::json!({"error": "Toggle state not configured"})),
664 );
665 }
666 };
667 drop(gw_guard);
668
669 if ts.meeting_active().await {
670 return (
671 StatusCode::CONFLICT,
672 axum::Json(serde_json::json!({"error": "Meeting recording already active"})),
673 );
674 }
675
676 let title = body.get("title").and_then(|v| v.as_str()).map(String::from);
677 let config = crate::config::MeetingConfig::default();
678
679 match ts.meeting_start(config, title).await {
680 Ok(()) => (
681 StatusCode::OK,
682 axum::Json(serde_json::json!({"status": "recording"})),
683 ),
684 Err(e) => (
685 StatusCode::INTERNAL_SERVER_ERROR,
686 axum::Json(serde_json::json!({"error": e})),
687 ),
688 }
689}
690
691async fn api_meeting_stop_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
693 let gw_guard = gw.lock().await;
694 let ts = match gw_guard.toggle_state() {
695 Some(ts) => ts.clone(),
696 None => {
697 return (
698 StatusCode::SERVICE_UNAVAILABLE,
699 axum::Json(serde_json::json!({"error": "Toggle state not configured"})),
700 );
701 }
702 };
703 drop(gw_guard);
704
705 match ts.meeting_stop().await {
706 Ok(result) => (
707 StatusCode::OK,
708 axum::Json(serde_json::json!({
709 "status": "stopped",
710 "duration_secs": result.duration_secs,
711 "transcript_length": result.transcript.len(),
712 "notes_saved": result.notes_saved,
713 })),
714 ),
715 Err(e) => (
716 StatusCode::BAD_REQUEST,
717 axum::Json(serde_json::json!({"error": e})),
718 ),
719 }
720}
721
722async fn api_meeting_status_handler(State(gw): State<SharedGateway>) -> impl IntoResponse {
724 let gw_guard = gw.lock().await;
725 let ts = match gw_guard.toggle_state() {
726 Some(ts) => ts.clone(),
727 None => {
728 return axum::Json(serde_json::json!({
729 "active": false,
730 "available": false,
731 }));
732 }
733 };
734 drop(gw_guard);
735
736 match ts.meeting_status().await {
737 Some(status) => axum::Json(serde_json::json!({
738 "active": true,
739 "available": true,
740 "title": status.title,
741 "started_at": status.started_at,
742 "elapsed_secs": status.elapsed_secs,
743 })),
744 None => axum::Json(serde_json::json!({
745 "active": false,
746 "available": true,
747 })),
748 }
749}
750
751pub async fn run(gw: SharedGateway) -> Result<(), std::io::Error> {
755 let (host, port) = {
756 let gw = gw.lock().await;
757 (gw.config().host.clone(), gw.config().port)
758 };
759 let app = router(gw);
760 let addr = format!("{}:{}", host, port);
761 let listener = tokio::net::TcpListener::bind(&addr).await?;
762 axum::serve(listener, app).await?;
763 Ok(())
764}
765
#[cfg(test)]
mod tests {
    //! Unit tests for `GatewayServer` message handling and the HTTP router.

    use super::*;
    use axum::body::Body;
    use tower::ServiceExt;

    // ---------------------------------------------------------------------
    // Helpers
    // ---------------------------------------------------------------------

    /// Wraps a fresh server in the shared handle the router expects.
    fn make_shared_gateway(config: GatewayConfig) -> SharedGateway {
        Arc::new(Mutex::new(GatewayServer::new(config)))
    }

    /// Builds a pending approval with fixed field values and the given id.
    fn sample_approval(id: Uuid) -> PendingApproval {
        PendingApproval {
            id,
            tool_name: "shell".into(),
            description: "run a command".into(),
            risk_level: "high".into(),
        }
    }

    /// Builds a body-less GET request for `uri`.
    fn get_request(uri: &str) -> axum::http::Request<Body> {
        axum::http::Request::builder()
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// Builds a POST request for `uri` carrying `body` as JSON.
    fn post_json_request(uri: &str, body: &str) -> axum::http::Request<Body> {
        axum::http::Request::builder()
            .method("POST")
            .uri(uri)
            .header("content-type", "application/json")
            .body(Body::from(body.to_owned()))
            .unwrap()
    }

    /// Sends `req` through `app` and returns the status and parsed JSON body.
    async fn request_json(
        app: Router,
        req: axum::http::Request<Body>,
    ) -> (StatusCode, serde_json::Value) {
        let resp = ServiceExt::<axum::http::Request<Body>>::oneshot(app, req)
            .await
            .unwrap();
        let status = resp.status();
        let body = axum::body::to_bytes(resp.into_body(), 10_000)
            .await
            .unwrap();
        (status, serde_json::from_slice(&body).unwrap())
    }

    /// Status provider returning fixed channel and node lists.
    struct MockStatusProvider {
        channels: Vec<(String, String)>,
        nodes: Vec<(String, String)>,
    }

    impl StatusProvider for MockStatusProvider {
        fn channel_statuses(&self) -> Vec<(String, String)> {
            self.channels.clone()
        }
        fn node_statuses(&self) -> Vec<(String, String)> {
            self.nodes.clone()
        }
    }

    // ---------------------------------------------------------------------
    // Construction, broadcast, lifecycle
    // ---------------------------------------------------------------------

    #[test]
    fn test_server_construction() {
        let config = GatewayConfig::default();
        let server = GatewayServer::new(config);
        assert_eq!(server.active_connections(), 0);
        assert_eq!(server.active_sessions(), 0);
        assert_eq!(server.config_json(), "{}");
    }

    #[test]
    fn test_server_with_auth_tokens() {
        let config = GatewayConfig {
            auth_tokens: vec!["tok1".into(), "tok2".into()],
            ..GatewayConfig::default()
        };
        let server = GatewayServer::new(config);
        assert!(server.auth().validate("tok1"));
        assert!(!server.auth().validate("wrong"));
    }

    #[test]
    fn test_server_broadcast_no_subscribers() {
        let server = GatewayServer::new(GatewayConfig::default());
        let sent = server.broadcast(GatewayEvent::Connected {
            connection_id: Uuid::new_v4(),
        });
        assert_eq!(sent, 0);
    }

    #[test]
    fn test_server_broadcast_with_subscriber() {
        let server = GatewayServer::new(GatewayConfig::default());
        let mut rx = server.subscribe();

        let sent = server.broadcast(GatewayEvent::AssistantMessage {
            content: "hello".into(),
        });
        assert_eq!(sent, 1);

        let event = rx.try_recv().unwrap();
        match event {
            GatewayEvent::AssistantMessage { content } => {
                assert_eq!(content, "hello");
            }
            _ => panic!("Wrong event type"),
        }
    }

    #[test]
    fn test_server_uptime() {
        let server = GatewayServer::new(GatewayConfig::default());
        assert!(server.uptime_secs() < 2);
    }

    #[test]
    fn test_server_connection_lifecycle() {
        let config = GatewayConfig {
            max_connections: 5,
            ..GatewayConfig::default()
        };
        let mut server = GatewayServer::new(config);

        let conn_id = server.connections_mut().add_connection().unwrap();
        assert_eq!(server.active_connections(), 1);

        let session_id = server.sessions_mut().create_session(conn_id);
        assert_eq!(server.active_sessions(), 1);

        server.sessions_mut().end_session(&session_id);
        assert_eq!(server.active_sessions(), 0);

        server.connections_mut().remove_connection(&conn_id);
        assert_eq!(server.active_connections(), 0);
    }

    // ---------------------------------------------------------------------
    // Approvals and counters
    // ---------------------------------------------------------------------

    #[test]
    fn test_add_approval_broadcasts_and_resolves_once() {
        let mut server = GatewayServer::new(GatewayConfig::default());
        let mut rx = server.subscribe();
        let id = Uuid::new_v4();

        server.add_approval(sample_approval(id));
        assert_eq!(server.pending_approvals().len(), 1);

        match rx.try_recv().unwrap() {
            GatewayEvent::ApprovalRequest {
                approval_id,
                tool_name,
                risk_level,
                ..
            } => {
                assert_eq!(approval_id, id);
                assert_eq!(tool_name, "shell");
                assert_eq!(risk_level, "high");
            }
            _ => panic!("Expected ApprovalRequest"),
        }

        assert!(server.resolve_approval(&id, true));
        assert!(!server.resolve_approval(&id, true));
        assert!(server.pending_approvals().is_empty());
    }

    #[test]
    fn test_handle_get_metrics_reports_counters() {
        let mut server = GatewayServer::new(GatewayConfig::default());
        server.record_tool_call();
        server.record_tool_call();
        server.record_llm_request();
        assert_eq!(server.total_tool_calls(), 2);
        assert_eq!(server.total_llm_requests(), 1);
        let conn_id = server.connections_mut().add_connection().unwrap();

        match server.handle_client_message(ClientMessage::GetMetrics, conn_id) {
            ServerMessage::MetricsResponse {
                active_connections,
                active_sessions,
                total_tool_calls,
                total_llm_requests,
                ..
            } => {
                assert_eq!(active_connections, 1);
                assert_eq!(active_sessions, 0);
                assert_eq!(total_tool_calls, 2);
                assert_eq!(total_llm_requests, 1);
            }
            _ => panic!("Expected MetricsResponse"),
        }
    }

    #[test]
    fn test_handle_get_config_returns_stored_json() {
        let mut server = GatewayServer::new(GatewayConfig::default());
        server.set_config_json(r#"{"model":"x"}"#.to_string());
        let conn_id = server.connections_mut().add_connection().unwrap();

        match server.handle_client_message(ClientMessage::GetConfig, conn_id) {
            ServerMessage::ConfigResponse { config_json } => {
                assert_eq!(config_json, r#"{"model":"x"}"#);
            }
            _ => panic!("Expected ConfigResponse"),
        }
    }

    // ---------------------------------------------------------------------
    // Client message handling
    // ---------------------------------------------------------------------

    #[test]
    fn test_handle_authenticate_valid() {
        let config = GatewayConfig {
            auth_tokens: vec!["secret".into()],
            ..GatewayConfig::default()
        };
        let mut server = GatewayServer::new(config);
        let conn_id = server.connections_mut().add_connection().unwrap();

        let resp = server.handle_client_message(
            ClientMessage::Authenticate {
                token: "secret".into(),
            },
            conn_id,
        );
        match resp {
            ServerMessage::Authenticated { connection_id } => {
                assert_eq!(connection_id, conn_id);
            }
            _ => panic!("Expected Authenticated, got {:?}", resp),
        }
        assert!(server.connections().is_authenticated(&conn_id));
    }

    #[test]
    fn test_handle_authenticate_invalid() {
        let config = GatewayConfig {
            auth_tokens: vec!["secret".into()],
            ..GatewayConfig::default()
        };
        let mut server = GatewayServer::new(config);
        let conn_id = server.connections_mut().add_connection().unwrap();

        let resp = server.handle_client_message(
            ClientMessage::Authenticate {
                token: "wrong".into(),
            },
            conn_id,
        );
        match resp {
            ServerMessage::AuthFailed { reason } => {
                assert!(reason.contains("Invalid"));
            }
            _ => panic!("Expected AuthFailed, got {:?}", resp),
        }
        assert!(!server.connections().is_authenticated(&conn_id));
    }

    #[test]
    fn test_handle_get_status() {
        let mut server = GatewayServer::new(GatewayConfig::default());
        let conn_id = server.connections_mut().add_connection().unwrap();

        let resp = server.handle_client_message(ClientMessage::GetStatus, conn_id);
        match resp {
            ServerMessage::StatusResponse {
                connected_clients,
                active_tasks,
                ..
            } => {
                assert_eq!(connected_clients, 1);
                assert_eq!(active_tasks, 0);
            }
            _ => panic!("Expected StatusResponse"),
        }
    }

    #[test]
    fn test_handle_ping_pong() {
        let mut server = GatewayServer::new(GatewayConfig::default());
        let conn_id = server.connections_mut().add_connection().unwrap();
        let now = Utc::now();

        let resp = server.handle_client_message(ClientMessage::Ping { timestamp: now }, conn_id);
        match resp {
            ServerMessage::Pong { timestamp } => {
                assert_eq!(timestamp, now);
            }
            _ => panic!("Expected Pong"),
        }
    }

    #[test]
    fn test_handle_submit_task_unauthenticated() {
        let config = GatewayConfig {
            auth_tokens: vec!["secret".into()],
            ..GatewayConfig::default()
        };
        let mut server = GatewayServer::new(config);
        let conn_id = server.connections_mut().add_connection().unwrap();

        let resp = server.handle_client_message(
            ClientMessage::SubmitTask {
                description: "test task".into(),
            },
            conn_id,
        );
        match resp {
            ServerMessage::AuthFailed { reason } => {
                assert!(reason.contains("Not authenticated"));
            }
            _ => panic!("Expected AuthFailed for unauthenticated submit"),
        }
    }

    #[test]
    fn test_handle_submit_task_authenticated() {
        let mut server = GatewayServer::new(GatewayConfig::default());
        let conn_id = server.connections_mut().add_connection().unwrap();
        server.connections_mut().authenticate(&conn_id);

        let resp = server.handle_client_message(
            ClientMessage::SubmitTask {
                description: "build feature X".into(),
            },
            conn_id,
        );
        match resp {
            ServerMessage::Event {
                event: GatewayEvent::TaskSubmitted { description, .. },
            } => {
                assert_eq!(description, "build feature X");
            }
            _ => panic!("Expected TaskSubmitted event"),
        }
        assert_eq!(server.active_sessions(), 1);
    }

    #[test]
    fn test_handle_cancel_task() {
        let mut server = GatewayServer::new(GatewayConfig::default());
        let conn_id = server.connections_mut().add_connection().unwrap();
        server.connections_mut().authenticate(&conn_id);
        let task_id = Uuid::new_v4();

        let resp = server.handle_client_message(ClientMessage::CancelTask { task_id }, conn_id);
        match resp {
            ServerMessage::Event {
                event:
                    GatewayEvent::TaskCompleted {
                        task_id: tid,
                        success,
                        summary,
                    },
            } => {
                assert_eq!(tid, task_id);
                assert!(!success);
                assert!(summary.contains("Cancelled"));
            }
            _ => panic!("Expected TaskCompleted with cancel"),
        }
    }

    #[test]
    fn test_handle_cancel_task_unauthenticated() {
        let config = GatewayConfig {
            auth_tokens: vec!["secret".into()],
            ..GatewayConfig::default()
        };
        let mut server = GatewayServer::new(config);
        let conn_id = server.connections_mut().add_connection().unwrap();

        let resp = server.handle_client_message(
            ClientMessage::CancelTask {
                task_id: Uuid::new_v4(),
            },
            conn_id,
        );
        assert!(matches!(resp, ServerMessage::AuthFailed { .. }));
    }

    // ---------------------------------------------------------------------
    // Status provider
    // ---------------------------------------------------------------------

    #[test]
    fn test_list_channels_without_provider() {
        let mut server = GatewayServer::new(GatewayConfig::default());
        let conn_id = server.connections_mut().add_connection().unwrap();

        let resp = server.handle_client_message(ClientMessage::ListChannels, conn_id);
        match resp {
            ServerMessage::ChannelStatus { channels } => {
                assert!(channels.is_empty());
            }
            _ => panic!("Expected ChannelStatus"),
        }
    }

    #[test]
    fn test_list_nodes_without_provider() {
        let mut server = GatewayServer::new(GatewayConfig::default());
        let conn_id = server.connections_mut().add_connection().unwrap();

        let resp = server.handle_client_message(ClientMessage::ListNodes, conn_id);
        match resp {
            ServerMessage::NodeStatus { nodes } => {
                assert!(nodes.is_empty());
            }
            _ => panic!("Expected NodeStatus"),
        }
    }

    #[test]
    fn test_list_channels_with_provider() {
        let mut server = GatewayServer::new(GatewayConfig::default());
        server.set_status_provider(Box::new(MockStatusProvider {
            channels: vec![
                ("slack".into(), "Connected".into()),
                ("telegram".into(), "Disconnected".into()),
            ],
            nodes: vec![],
        }));
        let conn_id = server.connections_mut().add_connection().unwrap();

        let resp = server.handle_client_message(ClientMessage::ListChannels, conn_id);
        match resp {
            ServerMessage::ChannelStatus { channels } => {
                assert_eq!(channels.len(), 2);
                assert_eq!(channels[0].0, "slack");
                assert_eq!(channels[0].1, "Connected");
                assert_eq!(channels[1].0, "telegram");
                assert_eq!(channels[1].1, "Disconnected");
            }
            _ => panic!("Expected ChannelStatus"),
        }
    }

    #[test]
    fn test_list_nodes_with_provider() {
        let mut server = GatewayServer::new(GatewayConfig::default());
        server.set_status_provider(Box::new(MockStatusProvider {
            channels: vec![],
            nodes: vec![
                ("macos-local".into(), "Healthy".into()),
                ("linux-remote".into(), "Degraded".into()),
            ],
        }));
        let conn_id = server.connections_mut().add_connection().unwrap();

        let resp = server.handle_client_message(ClientMessage::ListNodes, conn_id);
        match resp {
            ServerMessage::NodeStatus { nodes } => {
                assert_eq!(nodes.len(), 2);
                assert_eq!(nodes[0].0, "macos-local");
                assert_eq!(nodes[0].1, "Healthy");
                assert_eq!(nodes[1].0, "linux-remote");
                assert_eq!(nodes[1].1, "Degraded");
            }
            _ => panic!("Expected NodeStatus"),
        }
    }

    #[test]
    fn test_status_provider_can_be_replaced() {
        let mut server = GatewayServer::new(GatewayConfig::default());
        server.set_status_provider(Box::new(MockStatusProvider {
            channels: vec![("a".into(), "x".into())],
            nodes: vec![],
        }));
        // The second provider must fully replace the first.
        server.set_status_provider(Box::new(MockStatusProvider {
            channels: vec![("b".into(), "y".into()), ("c".into(), "z".into())],
            nodes: vec![],
        }));
        let conn_id = server.connections_mut().add_connection().unwrap();

        let resp = server.handle_client_message(ClientMessage::ListChannels, conn_id);
        match resp {
            ServerMessage::ChannelStatus { channels } => {
                assert_eq!(channels.len(), 2);
                assert_eq!(channels[0].0, "b");
            }
            _ => panic!("Expected ChannelStatus"),
        }
    }

    // ---------------------------------------------------------------------
    // HTTP router
    // ---------------------------------------------------------------------

    #[test]
    fn test_router_builds() {
        let gw = make_shared_gateway(GatewayConfig::default());
        let _app = router(gw);
    }

    #[tokio::test]
    async fn test_health_endpoint() {
        let gw = make_shared_gateway(GatewayConfig::default());
        let (status, json) = request_json(router(gw), get_request("/health")).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(json["status"], "ok");
        assert_eq!(json["connections"], 0);
        assert_eq!(json["sessions"], 0);
    }

    #[tokio::test]
    async fn test_config_endpoint_returns_stored_json() {
        let gw = make_shared_gateway(GatewayConfig::default());
        gw.lock()
            .await
            .set_config_json(r#"{"port":8080}"#.to_string());

        let (status, json) = request_json(router(gw), get_request("/api/config")).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(json["port"], 8080);
    }

    #[tokio::test]
    async fn test_approvals_endpoint_lists_pending() {
        let gw = make_shared_gateway(GatewayConfig::default());
        let id = Uuid::new_v4();
        gw.lock().await.add_approval(sample_approval(id));

        let (status, json) = request_json(router(gw), get_request("/api/approvals")).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(json["approvals"].as_array().unwrap().len(), 1);
        assert_eq!(json["approvals"][0]["id"], id.to_string());
        assert_eq!(json["approvals"][0]["tool_name"], "shell");
    }

    #[tokio::test]
    async fn test_approval_decision_endpoint_rejects_bad_uuid() {
        let gw = make_shared_gateway(GatewayConfig::default());
        let req = post_json_request("/api/approval/not-a-uuid", r#"{"approved":true}"#);

        let (status, json) = request_json(router(gw), req).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(json["error"], "Invalid UUID");
    }

    #[tokio::test]
    async fn test_approval_decision_endpoint_resolves_then_not_found() {
        let gw = make_shared_gateway(GatewayConfig::default());
        let id = Uuid::new_v4();
        gw.lock().await.add_approval(sample_approval(id));
        let uri = format!("/api/approval/{id}");

        let (status, json) = request_json(
            router(gw.clone()),
            post_json_request(&uri, r#"{"approved":true}"#),
        )
        .await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(json["status"], "resolved");
        assert_eq!(json["approved"], true);

        // The approval was consumed by the first decision.
        let (status, _) =
            request_json(router(gw), post_json_request(&uri, r#"{"approved":true}"#)).await;
        assert_eq!(status, StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_voice_and_meeting_status_without_toggle_state() {
        let gw = make_shared_gateway(GatewayConfig::default());
        for uri in ["/api/voice/status", "/api/meeting/status"] {
            let (status, json) = request_json(router(gw.clone()), get_request(uri)).await;
            assert_eq!(status, StatusCode::OK);
            assert_eq!(json["active"], false);
            assert_eq!(json["available"], false);
        }
    }

    #[tokio::test]
    async fn test_voice_start_without_toggle_state_is_unavailable() {
        let gw = make_shared_gateway(GatewayConfig::default());
        let req = axum::http::Request::builder()
            .method("POST")
            .uri("/api/voice/start")
            .body(Body::empty())
            .unwrap();

        let (status, json) = request_json(router(gw), req).await;
        assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(json["error"], "Toggle state not configured");
    }
}