Skip to main content

spikard_http/testing/
test_client.rs

//! Core test client for Spikard applications
//!
//! This module provides a language-agnostic `TestClient` that can be wrapped by
//! language bindings (PyO3, napi-rs, magnus) to provide Pythonic, JavaScript-idiomatic,
//! and Ruby-like APIs respectively.
//!
//! The core client handles all HTTP method dispatch, query params, header management,
//! body encoding (JSON, form-data, multipart), and response snapshot capture.
9
10use super::{ResponseSnapshot, SnapshotError, snapshot_response};
11use axum::http::{HeaderName, HeaderValue, Method};
12use axum_test::TestServer;
13use bytes::Bytes;
14use serde_json::Value;
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::time::timeout;
18use urlencoding::encode;
19
/// Optional multipart body: plain form fields as `(name, value)` pairs plus the file parts.
/// Both halves are handed to `super::build_multipart_body` when a POST carries multipart data.
type MultipartPayload = Option<(Vec<(String, String)>, Vec<super::MultipartFilePart>)>;
/// Maximum time to wait for a single GraphQL-over-WebSocket protocol message.
const GRAPHQL_WS_MESSAGE_TIMEOUT: Duration = Duration::from_secs(2);
/// Upper bound on the number of protocol messages processed per phase (ack wait and
/// subscription loop) before giving up, so a chatty server cannot keep a test spinning.
const GRAPHQL_WS_MAX_CONTROL_MESSAGES: usize = 32;
23
24/// Snapshot of a GraphQL subscription exchange over WebSocket.
25#[derive(Debug, Clone, PartialEq)]
26pub struct GraphQLSubscriptionSnapshot {
27    /// Operation id used for the subscription request.
28    pub operation_id: String,
29    /// Whether the server acknowledged the GraphQL WebSocket connection.
30    pub acknowledged: bool,
31    /// First `next.payload` received for this subscription, if any.
32    pub event: Option<Value>,
33    /// GraphQL protocol errors emitted by the server.
34    pub errors: Vec<Value>,
35    /// Whether a `complete` frame was observed for this operation.
36    pub complete_received: bool,
37}
38
/// Core test client for making HTTP requests to a Spikard application.
///
/// This struct wraps axum-test's `TestServer` and provides a language-agnostic
/// interface for making HTTP requests and GraphQL calls. WebSocket and
/// Server-Sent Events connections are opened through the underlying server,
/// which is exposed via `TestClient::server`. Language bindings wrap this to
/// provide native API surfaces.
pub struct TestClient {
    /// Shared handle to the axum-test server that every request is sent to.
    server: Arc<TestServer>,
}
48
49impl TestClient {
50    /// Create a new test client from an Axum router
51    pub fn from_router(router: axum::Router) -> Result<Self, String> {
52        let server = if tokio::runtime::Handle::try_current().is_ok() {
53            TestServer::builder()
54                .http_transport()
55                .try_build(router)
56                .map_err(|e| format!("Failed to create test server: {}", e))?
57        } else {
58            TestServer::try_new(router).map_err(|e| format!("Failed to create test server: {}", e))?
59        };
60
61        Ok(Self {
62            server: Arc::new(server),
63        })
64    }
65
66    /// Get the underlying test server (for WebSocket and SSE connections)
67    pub fn server(&self) -> &TestServer {
68        &self.server
69    }
70
71    /// Make a GET request
72    pub async fn get(
73        &self,
74        path: &str,
75        query_params: Option<Vec<(String, String)>>,
76        headers: Option<Vec<(String, String)>>,
77    ) -> Result<ResponseSnapshot, SnapshotError> {
78        let full_path = build_full_path(path, query_params.as_deref());
79        let mut request = self.server.get(&full_path);
80
81        if let Some(headers_vec) = headers {
82            request = self.add_headers(request, headers_vec)?;
83        }
84
85        let response = request.await;
86        snapshot_response(response).await
87    }
88
89    /// Make a POST request
90    pub async fn post(
91        &self,
92        path: &str,
93        json: Option<Value>,
94        form_data: Option<Vec<(String, String)>>,
95        multipart: MultipartPayload,
96        query_params: Option<Vec<(String, String)>>,
97        headers: Option<Vec<(String, String)>>,
98    ) -> Result<ResponseSnapshot, SnapshotError> {
99        let full_path = build_full_path(path, query_params.as_deref());
100        let mut request = self.server.post(&full_path);
101
102        if let Some(headers_vec) = headers {
103            request = self.add_headers(request, headers_vec)?;
104        }
105
106        if let Some((form_fields, files)) = multipart {
107            let (body, boundary) = super::build_multipart_body(&form_fields, &files);
108            let content_type = format!("multipart/form-data; boundary={}", boundary);
109            request = request.add_header("content-type", &content_type);
110            request = request.bytes(Bytes::from(body));
111        } else if let Some(form_fields) = form_data {
112            let fields_value = serde_json::to_value(&form_fields)
113                .map_err(|e| SnapshotError::Decompression(format!("Failed to serialize form fields: {}", e)))?;
114            let encoded = super::encode_urlencoded_body(&fields_value)
115                .map_err(|e| SnapshotError::Decompression(format!("Form encoding failed: {}", e)))?;
116            request = request.add_header("content-type", "application/x-www-form-urlencoded");
117            request = request.bytes(Bytes::from(encoded));
118        } else if let Some(json_value) = json {
119            request = request.json(&json_value);
120        }
121
122        let response = request.await;
123        snapshot_response(response).await
124    }
125
126    /// Make a request with a raw body payload.
127    pub async fn request_raw(
128        &self,
129        method: Method,
130        path: &str,
131        body: Bytes,
132        query_params: Option<Vec<(String, String)>>,
133        headers: Option<Vec<(String, String)>>,
134    ) -> Result<ResponseSnapshot, SnapshotError> {
135        let full_path = build_full_path(path, query_params.as_deref());
136        let mut request = self.server.method(method, &full_path);
137
138        if let Some(headers_vec) = headers {
139            request = self.add_headers(request, headers_vec)?;
140        }
141
142        request = request.bytes(body);
143        let response = request.await;
144        snapshot_response(response).await
145    }
146
147    /// Make a PUT request
148    pub async fn put(
149        &self,
150        path: &str,
151        json: Option<Value>,
152        query_params: Option<Vec<(String, String)>>,
153        headers: Option<Vec<(String, String)>>,
154    ) -> Result<ResponseSnapshot, SnapshotError> {
155        let full_path = build_full_path(path, query_params.as_deref());
156        let mut request = self.server.put(&full_path);
157
158        if let Some(headers_vec) = headers {
159            request = self.add_headers(request, headers_vec)?;
160        }
161
162        if let Some(json_value) = json {
163            request = request.json(&json_value);
164        }
165
166        let response = request.await;
167        snapshot_response(response).await
168    }
169
170    /// Make a PATCH request
171    pub async fn patch(
172        &self,
173        path: &str,
174        json: Option<Value>,
175        query_params: Option<Vec<(String, String)>>,
176        headers: Option<Vec<(String, String)>>,
177    ) -> Result<ResponseSnapshot, SnapshotError> {
178        let full_path = build_full_path(path, query_params.as_deref());
179        let mut request = self.server.patch(&full_path);
180
181        if let Some(headers_vec) = headers {
182            request = self.add_headers(request, headers_vec)?;
183        }
184
185        if let Some(json_value) = json {
186            request = request.json(&json_value);
187        }
188
189        let response = request.await;
190        snapshot_response(response).await
191    }
192
193    /// Make a DELETE request
194    pub async fn delete(
195        &self,
196        path: &str,
197        query_params: Option<Vec<(String, String)>>,
198        headers: Option<Vec<(String, String)>>,
199    ) -> Result<ResponseSnapshot, SnapshotError> {
200        let full_path = build_full_path(path, query_params.as_deref());
201        let mut request = self.server.delete(&full_path);
202
203        if let Some(headers_vec) = headers {
204            request = self.add_headers(request, headers_vec)?;
205        }
206
207        let response = request.await;
208        snapshot_response(response).await
209    }
210
211    /// Make an OPTIONS request
212    pub async fn options(
213        &self,
214        path: &str,
215        query_params: Option<Vec<(String, String)>>,
216        headers: Option<Vec<(String, String)>>,
217    ) -> Result<ResponseSnapshot, SnapshotError> {
218        let full_path = build_full_path(path, query_params.as_deref());
219        let mut request = self.server.method(Method::OPTIONS, &full_path);
220
221        if let Some(headers_vec) = headers {
222            request = self.add_headers(request, headers_vec)?;
223        }
224
225        let response = request.await;
226        snapshot_response(response).await
227    }
228
229    /// Make a HEAD request
230    pub async fn head(
231        &self,
232        path: &str,
233        query_params: Option<Vec<(String, String)>>,
234        headers: Option<Vec<(String, String)>>,
235    ) -> Result<ResponseSnapshot, SnapshotError> {
236        let full_path = build_full_path(path, query_params.as_deref());
237        let mut request = self.server.method(Method::HEAD, &full_path);
238
239        if let Some(headers_vec) = headers {
240            request = self.add_headers(request, headers_vec)?;
241        }
242
243        let response = request.await;
244        snapshot_response(response).await
245    }
246
247    /// Make a TRACE request
248    pub async fn trace(
249        &self,
250        path: &str,
251        query_params: Option<Vec<(String, String)>>,
252        headers: Option<Vec<(String, String)>>,
253    ) -> Result<ResponseSnapshot, SnapshotError> {
254        let full_path = build_full_path(path, query_params.as_deref());
255        let mut request = self.server.method(Method::TRACE, &full_path);
256
257        if let Some(headers_vec) = headers {
258            request = self.add_headers(request, headers_vec)?;
259        }
260
261        let response = request.await;
262        snapshot_response(response).await
263    }
264
265    /// Send a GraphQL query/mutation to a custom endpoint
266    pub async fn graphql_at(
267        &self,
268        endpoint: &str,
269        query: &str,
270        variables: Option<Value>,
271        operation_name: Option<&str>,
272    ) -> Result<ResponseSnapshot, SnapshotError> {
273        let body = build_graphql_body(query, variables, operation_name);
274        self.post(endpoint, Some(body), None, None, None, None).await
275    }
276
277    /// Send a GraphQL query/mutation
278    pub async fn graphql(
279        &self,
280        query: &str,
281        variables: Option<Value>,
282        operation_name: Option<&str>,
283    ) -> Result<ResponseSnapshot, SnapshotError> {
284        self.graphql_at("/graphql", query, variables, operation_name).await
285    }
286
287    /// Send a GraphQL query and return HTTP status code separately
288    ///
289    /// This method allows tests to distinguish between:
290    /// - HTTP-level errors (400/422 for invalid requests)
291    /// - GraphQL-level errors (200 with errors in response body)
292    ///
293    /// # Example
294    /// ```ignore
295    /// let (status, snapshot) = client.graphql_with_status(
296    ///     "query { invalid syntax",
297    ///     None,
298    ///     None
299    /// ).await?;
300    /// assert_eq!(status, 400); // HTTP parse error
301    /// ```
302    pub async fn graphql_with_status(
303        &self,
304        query: &str,
305        variables: Option<Value>,
306        operation_name: Option<&str>,
307    ) -> Result<(u16, ResponseSnapshot), SnapshotError> {
308        let snapshot = self.graphql(query, variables, operation_name).await?;
309        let status = snapshot.status;
310        Ok((status, snapshot))
311    }
312
313    /// Send a GraphQL subscription (WebSocket) to a custom endpoint.
314    ///
315    /// Uses the `graphql-transport-ws` protocol and captures the first `next` payload.
316    /// After the first payload is received, this client sends `complete` to unsubscribe.
317    pub async fn graphql_subscription_at(
318        &self,
319        endpoint: &str,
320        query: &str,
321        variables: Option<Value>,
322        operation_name: Option<&str>,
323    ) -> Result<GraphQLSubscriptionSnapshot, SnapshotError> {
324        let operation_id = "spikard-subscription-1".to_string();
325        let upgrade = self
326            .server
327            .get_websocket(endpoint)
328            .add_header("sec-websocket-protocol", "graphql-transport-ws")
329            .await;
330
331        if upgrade.status_code().as_u16() != 101 {
332            return Err(SnapshotError::Decompression(format!(
333                "GraphQL subscription upgrade failed with status {}",
334                upgrade.status_code()
335            )));
336        }
337
338        let mut websocket = super::WebSocketConnection::new(upgrade.into_websocket().await);
339
340        websocket
341            .send_json(&serde_json::json!({"type": "connection_init"}))
342            .await;
343        wait_for_graphql_ack(&mut websocket).await?;
344
345        websocket
346            .send_json(&serde_json::json!({
347                "id": operation_id,
348                "type": "subscribe",
349                "payload": build_graphql_body(query, variables, operation_name),
350            }))
351            .await;
352
353        let mut event = None;
354        let mut errors = Vec::new();
355        let mut complete_received = false;
356
357        for _ in 0..GRAPHQL_WS_MAX_CONTROL_MESSAGES {
358            let message = timeout(
359                GRAPHQL_WS_MESSAGE_TIMEOUT,
360                receive_graphql_protocol_message(&mut websocket),
361            )
362            .await
363            .map_err(|_| {
364                SnapshotError::Decompression("Timed out waiting for GraphQL subscription message".to_string())
365            })??;
366
367            let message_type = message.get("type").and_then(Value::as_str).unwrap_or_default();
368            match message_type {
369                "next" => {
370                    if message
371                        .get("id")
372                        .and_then(Value::as_str)
373                        .is_none_or(|id| id == operation_id)
374                    {
375                        event = message.get("payload").cloned();
376
377                        websocket
378                            .send_json(&serde_json::json!({
379                                "id": operation_id,
380                                "type": "complete",
381                            }))
382                            .await;
383
384                        if let Ok(next_message) = timeout(
385                            GRAPHQL_WS_MESSAGE_TIMEOUT,
386                            receive_graphql_protocol_message(&mut websocket),
387                        )
388                        .await
389                            && let Ok(next_message) = next_message
390                            && next_message.get("type").and_then(Value::as_str) == Some("complete")
391                            && next_message
392                                .get("id")
393                                .and_then(Value::as_str)
394                                .is_none_or(|id| id == operation_id)
395                        {
396                            complete_received = true;
397                        }
398                        break;
399                    }
400                }
401                "error" => {
402                    errors.push(message.get("payload").cloned().unwrap_or(message));
403                    break;
404                }
405                "complete" => {
406                    if message
407                        .get("id")
408                        .and_then(Value::as_str)
409                        .is_none_or(|id| id == operation_id)
410                    {
411                        complete_received = true;
412                        break;
413                    }
414                }
415                "ping" => {
416                    let mut pong = serde_json::json!({"type": "pong"});
417                    if let Some(payload) = message.get("payload") {
418                        pong["payload"] = payload.clone();
419                    }
420                    websocket.send_json(&pong).await;
421                }
422                "pong" => {}
423                _ => {}
424            }
425        }
426
427        websocket.close().await;
428
429        if event.is_none() && errors.is_empty() && !complete_received {
430            return Err(SnapshotError::Decompression(
431                "No GraphQL subscription event received before timeout".to_string(),
432            ));
433        }
434
435        Ok(GraphQLSubscriptionSnapshot {
436            operation_id,
437            acknowledged: true,
438            event,
439            errors,
440            complete_received,
441        })
442    }
443
444    /// Send a GraphQL subscription (WebSocket).
445    ///
446    /// Uses `/graphql` as the default subscription endpoint.
447    pub async fn graphql_subscription(
448        &self,
449        query: &str,
450        variables: Option<Value>,
451        operation_name: Option<&str>,
452    ) -> Result<GraphQLSubscriptionSnapshot, SnapshotError> {
453        self.graphql_subscription_at("/graphql", query, variables, operation_name)
454            .await
455    }
456
457    /// Add headers to a test request builder
458    fn add_headers(
459        &self,
460        mut request: axum_test::TestRequest,
461        headers: Vec<(String, String)>,
462    ) -> Result<axum_test::TestRequest, SnapshotError> {
463        for (key, value) in headers {
464            let header_name = HeaderName::from_bytes(key.as_bytes())
465                .map_err(|e| SnapshotError::InvalidHeader(format!("Invalid header name: {}", e)))?;
466            let header_value = HeaderValue::from_str(&value)
467                .map_err(|e| SnapshotError::InvalidHeader(format!("Invalid header value: {}", e)))?;
468            request = request.add_header(header_name, header_value);
469        }
470        Ok(request)
471    }
472}
473
474async fn wait_for_graphql_ack(websocket: &mut super::WebSocketConnection) -> Result<(), SnapshotError> {
475    for _ in 0..GRAPHQL_WS_MAX_CONTROL_MESSAGES {
476        let message = timeout(GRAPHQL_WS_MESSAGE_TIMEOUT, receive_graphql_protocol_message(websocket))
477            .await
478            .map_err(|_| SnapshotError::Decompression("Timed out waiting for GraphQL connection_ack".to_string()))??;
479
480        match message.get("type").and_then(Value::as_str).unwrap_or_default() {
481            "connection_ack" => return Ok(()),
482            "ping" => {
483                let mut pong = serde_json::json!({"type": "pong"});
484                if let Some(payload) = message.get("payload") {
485                    pong["payload"] = payload.clone();
486                }
487                websocket.send_json(&pong).await;
488            }
489            "connection_error" | "error" => {
490                return Err(SnapshotError::Decompression(format!(
491                    "GraphQL subscription rejected during init: {}",
492                    message
493                )));
494            }
495            _ => {}
496        }
497    }
498
499    Err(SnapshotError::Decompression(
500        "No GraphQL connection_ack received".to_string(),
501    ))
502}
503
504async fn receive_graphql_protocol_message(websocket: &mut super::WebSocketConnection) -> Result<Value, SnapshotError> {
505    loop {
506        match websocket.receive_message().await {
507            super::WebSocketMessage::Text(text) => {
508                return serde_json::from_str::<Value>(&text).map_err(|e| {
509                    SnapshotError::Decompression(format!("Failed to parse GraphQL WebSocket message as JSON: {}", e))
510                });
511            }
512            super::WebSocketMessage::Binary(bytes) => {
513                return serde_json::from_slice::<Value>(&bytes).map_err(|e| {
514                    SnapshotError::Decompression(format!(
515                        "Failed to parse GraphQL binary WebSocket message as JSON: {}",
516                        e
517                    ))
518                });
519            }
520            super::WebSocketMessage::Ping(_) | super::WebSocketMessage::Pong(_) => continue,
521            super::WebSocketMessage::Close(reason) => {
522                return Err(SnapshotError::Decompression(format!(
523                    "GraphQL WebSocket connection closed before response: {:?}",
524                    reason
525                )));
526            }
527        }
528    }
529}
530
531/// Build a GraphQL request body from query, variables, and operation name
532pub fn build_graphql_body(query: &str, variables: Option<Value>, operation_name: Option<&str>) -> Value {
533    let mut body = serde_json::json!({ "query": query });
534    if let Some(vars) = variables {
535        body["variables"] = vars;
536    }
537    if let Some(op_name) = operation_name {
538        body["operationName"] = Value::String(op_name.to_string());
539    }
540    body
541}
542
543/// Build a full path with query parameters
544fn build_full_path(path: &str, query_params: Option<&[(String, String)]>) -> String {
545    match query_params {
546        None | Some(&[]) => path.to_string(),
547        Some(params) => {
548            let query_string: Vec<String> = params
549                .iter()
550                .map(|(k, v)| format!("{}={}", encode(k), encode(v)))
551                .collect();
552
553            if path.contains('?') {
554                format!("{}&{}", path, query_string.join("&"))
555            } else {
556                format!("{}?{}", path, query_string.join("&"))
557            }
558        }
559    }
560}
561
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        Router,
        extract::ws::{Message, WebSocketUpgrade},
        routing::get,
    };

    /// `None` and an empty slice must both leave the path untouched.
    #[test]
    fn build_full_path_no_params() {
        let path = "/users";
        assert_eq!(build_full_path(path, None), "/users");
        assert_eq!(build_full_path(path, Some(&[])), "/users");
    }

    /// Parameters are appended after `?` and percent-encoded.
    #[test]
    fn build_full_path_with_params() {
        let path = "/users";
        let params = vec![
            ("id".to_string(), "123".to_string()),
            ("name".to_string(), "test user".to_string()),
        ];
        let result = build_full_path(path, Some(&params));
        assert!(result.starts_with("/users?"));
        assert!(result.contains("id=123"));
        assert!(result.contains("name=test%20user"));
    }

    /// A path that already has a query string is extended with `&`.
    #[test]
    fn build_full_path_existing_query() {
        let path = "/users?active=true";
        let params = vec![("id".to_string(), "123".to_string())];
        let result = build_full_path(path, Some(&params));
        assert_eq!(result, "/users?active=true&id=123");
    }

    /// The production builder must emit exactly the three standard GraphQL fields.
    #[test]
    fn test_graphql_query_builder() {
        let query = "{ users { id name } }";
        let variables = Some(serde_json::json!({ "limit": 10 }));
        let op_name = Some("GetUsers");

        // Exercise the real builder rather than a hand-written copy of it.
        let body = build_graphql_body(query, variables, op_name);

        assert_eq!(body["query"], query);
        assert_eq!(body["variables"]["limit"], 10);
        assert_eq!(body["operationName"], "GetUsers");
        assert_eq!(body.as_object().map(|fields| fields.len()), Some(3));
    }

    /// `graphql_with_status` must return the HTTP status alongside the snapshot.
    #[tokio::test]
    async fn test_graphql_with_status_method() {
        let app = Router::new().route(
            "/graphql",
            axum::routing::post(|| async { r#"{"data":{"hello":"world"}}"# }),
        );

        let client = TestClient::from_router(app).expect("client");
        let (status, snapshot) = client
            .graphql_with_status("query { hello }", None, None)
            .await
            .expect("graphql response");

        assert_eq!(status, 200);
        // The separately returned status must mirror the one stored on the snapshot.
        assert_eq!(status, snapshot.status);
    }

    /// Without variables or an operation name only `query` is populated.
    #[test]
    fn test_build_graphql_body_basic() {
        let query = "{ users { id name } }";
        let body = build_graphql_body(query, None, None);

        assert_eq!(body["query"], query);
        assert!(body.get("variables").is_none() || body["variables"].is_null());
        assert!(body.get("operationName").is_none() || body["operationName"].is_null());
    }

    /// Variables are passed through unchanged under `variables`.
    #[test]
    fn test_build_graphql_body_with_variables() {
        let query = "query GetUser($id: ID!) { user(id: $id) { name } }";
        let variables = Some(serde_json::json!({ "id": "123" }));
        let body = build_graphql_body(query, variables, None);

        assert_eq!(body["query"], query);
        assert_eq!(body["variables"]["id"], "123");
    }

    /// The operation name is emitted under the camelCase `operationName` key.
    #[test]
    fn test_build_graphql_body_with_operation_name() {
        let query = "query GetUsers { users { id } }";
        let op_name = Some("GetUsers");
        let body = build_graphql_body(query, None, op_name);

        assert_eq!(body["query"], query);
        assert_eq!(body["operationName"], "GetUsers");
    }

    /// All three fields can be supplied together.
    #[test]
    fn test_build_graphql_body_all_fields() {
        let query = "mutation CreateUser($name: String!) { createUser(name: $name) { id } }";
        let variables = Some(serde_json::json!({ "name": "Alice" }));
        let op_name = Some("CreateUser");
        let body = build_graphql_body(query, variables, op_name);

        assert_eq!(body["query"], query);
        assert_eq!(body["variables"]["name"], "Alice");
        assert_eq!(body["operationName"], "CreateUser");
    }

    /// Happy path: ack, one `next` event, then a `complete` handshake in both directions.
    #[tokio::test]
    async fn graphql_subscription_returns_first_event_and_completes() {
        // Minimal graphql-transport-ws server: acks the init, answers `subscribe`
        // with a single event, and echoes `complete` once the client unsubscribes.
        let app = Router::new().route(
            "/graphql",
            get(|ws: WebSocketUpgrade| async move {
                ws.on_upgrade(|mut socket| async move {
                    while let Some(result) = socket.recv().await {
                        // Ignore anything that is not a JSON text frame.
                        let Ok(Message::Text(text)) = result else {
                            continue;
                        };
                        let Ok(message): Result<Value, _> = serde_json::from_str(&text) else {
                            continue;
                        };

                        match message.get("type").and_then(Value::as_str) {
                            Some("connection_init") => {
                                let _ = socket
                                    .send(Message::Text(
                                        serde_json::json!({"type":"connection_ack"}).to_string().into(),
                                    ))
                                    .await;
                            }
                            Some("subscribe") => {
                                let id = message.get("id").and_then(Value::as_str).unwrap_or("1");
                                let _ = socket
                                    .send(Message::Text(
                                        serde_json::json!({
                                            "id": id,
                                            "type": "next",
                                            "payload": {"data": {"ticker": "AAPL"}},
                                        })
                                        .to_string()
                                        .into(),
                                    ))
                                    .await;

                                // The client is expected to unsubscribe after the first event.
                                if let Some(Ok(Message::Text(complete_text))) = socket.recv().await {
                                    let Ok(complete_message): Result<Value, _> = serde_json::from_str(&complete_text)
                                    else {
                                        break;
                                    };
                                    if complete_message.get("type").and_then(Value::as_str) == Some("complete") {
                                        let _ = socket
                                            .send(Message::Text(
                                                serde_json::json!({"id": id, "type":"complete"}).to_string().into(),
                                            ))
                                            .await;
                                    }
                                }
                                break;
                            }
                            _ => {}
                        }
                    }
                })
            }),
        );

        let client = TestClient::from_router(app).expect("client");
        let snapshot = client
            .graphql_subscription("subscription { ticker }", None, None)
            .await
            .expect("subscription snapshot");

        assert!(snapshot.acknowledged);
        assert_eq!(snapshot.errors, Vec::<Value>::new());
        assert_eq!(snapshot.event, Some(serde_json::json!({"data": {"ticker": "AAPL"}})));
        assert!(snapshot.complete_received);
    }

    /// A `connection_error` during the handshake must surface as an error, not a snapshot.
    #[tokio::test]
    async fn graphql_subscription_surfaces_connection_error() {
        // Server that rejects every connection attempt right after `connection_init`.
        let app = Router::new().route(
            "/graphql",
            get(|ws: WebSocketUpgrade| async move {
                ws.on_upgrade(|mut socket| async move {
                    while let Some(result) = socket.recv().await {
                        let Ok(Message::Text(text)) = result else {
                            continue;
                        };
                        let Ok(message): Result<Value, _> = serde_json::from_str(&text) else {
                            continue;
                        };

                        if message.get("type").and_then(Value::as_str) == Some("connection_init") {
                            let _ = socket
                                .send(Message::Text(
                                    serde_json::json!({
                                        "type": "connection_error",
                                        "payload": {"message": "not authorized"},
                                    })
                                    .to_string()
                                    .into(),
                                ))
                                .await;
                            break;
                        }
                    }
                })
            }),
        );

        let client = TestClient::from_router(app).expect("client");
        let error = client
            .graphql_subscription("subscription { privateFeed }", None, None)
            .await
            .expect_err("expected connection error");

        // The rejection frame is embedded verbatim in the error message.
        assert!(error.to_string().contains("connection_error"));
    }
}