Skip to main content

spikard_http/testing/
test_client.rs

//! Core test client for Spikard applications.
//!
//! This module provides a language-agnostic `TestClient` that language bindings
//! (PyO3, napi-rs, magnus) wrap to offer idiomatic Python, JavaScript, and Ruby APIs,
//! respectively.
//!
//! The core client handles HTTP method dispatch, query parameters, header management,
//! body encoding (JSON, form-data, multipart), and response snapshot capture. It also
//! sends GraphQL queries over HTTP and drives GraphQL subscriptions over WebSocket using
//! the `graphql-transport-ws` protocol.
9
10use super::{ResponseSnapshot, SnapshotError, snapshot_response};
11use axum::http::{HeaderName, HeaderValue, Method};
12use axum_test::TestServer;
13use bytes::Bytes;
14use serde_json::Value;
15use std::sync::{Arc, Mutex};
16use std::time::Duration;
17use tokio::time::timeout;
18use urlencoding::encode;
19
20type MultipartPayload = Option<(Vec<(String, String)>, Vec<super::MultipartFilePart>)>;
21const GRAPHQL_WS_MESSAGE_TIMEOUT: Duration = Duration::from_secs(2);
22const GRAPHQL_WS_MAX_CONTROL_MESSAGES: usize = 32;
23
24/// Snapshot of a GraphQL subscription exchange over WebSocket.
25#[derive(Debug, Clone, PartialEq)]
26pub struct GraphQLSubscriptionSnapshot {
27    /// Operation id used for the subscription request.
28    pub operation_id: String,
29    /// Whether the server acknowledged the GraphQL WebSocket connection.
30    pub acknowledged: bool,
31    /// First `next.payload` received for this subscription, if any.
32    pub event: Option<serde_json::Value>,
33    /// GraphQL protocol errors emitted by the server.
34    pub errors: Vec<serde_json::Value>,
35    /// Whether a `complete` frame was observed for this operation.
36    pub complete_received: bool,
37}
38
39/// Core test client for making HTTP requests to a Spikard application.
40///
41/// This struct wraps axum-test's TestServer and provides a language-agnostic
42/// interface for making HTTP requests, sending WebSocket connections, and
43/// handling Server-Sent Events. Language bindings wrap this to provide
44/// native API surfaces.
45pub struct TestClient {
46    mock_server: Arc<TestServer>,
47    router: axum::Router,
48    http_server: Mutex<Option<Arc<TestServer>>>,
49}
50
51impl TestClient {
52    /// Create a new test client from an Axum router
53    pub fn from_router(router: axum::Router) -> Result<Self, String> {
54        let mock_server =
55            TestServer::try_new(router.clone()).map_err(|e| format!("Failed to create test server: {}", e))?;
56
57        Ok(Self {
58            mock_server: Arc::new(mock_server),
59            router,
60            http_server: Mutex::new(None),
61        })
62    }
63
64    /// Get or initialize the underlying socket-backed test server for WebSocket traffic.
65    pub fn http_server(&self) -> Result<Arc<TestServer>, SnapshotError> {
66        let mut guard = self
67            .http_server
68            .lock()
69            .map_err(|_| SnapshotError::Decompression("Failed to lock HTTP test server state".to_string()))?;
70
71        if let Some(server) = guard.as_ref() {
72            return Ok(Arc::clone(server));
73        }
74
75        if tokio::runtime::Handle::try_current().is_err() {
76            return Err(SnapshotError::Decompression(
77                "WebSocket test transport requires an active Tokio runtime".to_string(),
78            ));
79        }
80
81        let server = Arc::new(
82            TestServer::builder()
83                .http_transport()
84                .try_build(self.router.clone())
85                .map_err(|e| SnapshotError::Decompression(format!("Failed to create test server: {}", e)))?,
86        );
87        *guard = Some(Arc::clone(&server));
88        Ok(server)
89    }
90
91    /// Make a GET request
92    #[doc(hidden)]
93    pub async fn get(
94        &self,
95        path: &str,
96        query_params: Option<Vec<(String, String)>>,
97        headers: Option<Vec<(String, String)>>,
98    ) -> Result<ResponseSnapshot, SnapshotError> {
99        let full_path = build_full_path(path, query_params.as_deref());
100        let mut request = self.mock_server.get(&full_path);
101
102        if let Some(headers_vec) = headers {
103            request = self.add_headers(request, headers_vec)?;
104        }
105
106        let response = request.await;
107        snapshot_response(response).await
108    }
109
110    /// Make a POST request
111    #[doc(hidden)]
112    pub async fn post(
113        &self,
114        path: &str,
115        json: Option<Value>,
116        form_data: Option<Vec<(String, String)>>,
117        multipart: MultipartPayload,
118        query_params: Option<Vec<(String, String)>>,
119        headers: Option<Vec<(String, String)>>,
120    ) -> Result<ResponseSnapshot, SnapshotError> {
121        let full_path = build_full_path(path, query_params.as_deref());
122        let mut request = self.mock_server.post(&full_path);
123
124        if let Some(headers_vec) = headers {
125            request = self.add_headers(request, headers_vec)?;
126        }
127
128        if let Some((form_fields, files)) = multipart {
129            let (body, boundary) = super::build_multipart_body(&form_fields, &files);
130            let content_type = format!("multipart/form-data; boundary={}", boundary);
131            request = request.add_header("content-type", &content_type);
132            request = request.bytes(Bytes::from(body));
133        } else if let Some(form_fields) = form_data {
134            let fields_value = serde_json::to_value(&form_fields)
135                .map_err(|e| SnapshotError::Decompression(format!("Failed to serialize form fields: {}", e)))?;
136            let encoded = super::encode_urlencoded_body(&fields_value)
137                .map_err(|e| SnapshotError::Decompression(format!("Form encoding failed: {}", e)))?;
138            request = request.add_header("content-type", "application/x-www-form-urlencoded");
139            request = request.bytes(Bytes::from(encoded));
140        } else if let Some(json_value) = json {
141            request = request.json(&json_value);
142        }
143
144        let response = request.await;
145        snapshot_response(response).await
146    }
147
148    /// Make a request with a raw body payload.
149    #[doc(hidden)]
150    pub async fn request_raw(
151        &self,
152        method: Method,
153        path: &str,
154        body: Bytes,
155        query_params: Option<Vec<(String, String)>>,
156        headers: Option<Vec<(String, String)>>,
157    ) -> Result<ResponseSnapshot, SnapshotError> {
158        let full_path = build_full_path(path, query_params.as_deref());
159        let mut request = self.mock_server.method(method, &full_path);
160
161        if let Some(headers_vec) = headers {
162            request = self.add_headers(request, headers_vec)?;
163        }
164
165        request = request.bytes(body);
166        let response = request.await;
167        snapshot_response(response).await
168    }
169
170    /// Make a PUT request
171    #[doc(hidden)]
172    pub async fn put(
173        &self,
174        path: &str,
175        json: Option<Value>,
176        query_params: Option<Vec<(String, String)>>,
177        headers: Option<Vec<(String, String)>>,
178    ) -> Result<ResponseSnapshot, SnapshotError> {
179        let full_path = build_full_path(path, query_params.as_deref());
180        let mut request = self.mock_server.put(&full_path);
181
182        if let Some(headers_vec) = headers {
183            request = self.add_headers(request, headers_vec)?;
184        }
185
186        if let Some(json_value) = json {
187            request = request.json(&json_value);
188        }
189
190        let response = request.await;
191        snapshot_response(response).await
192    }
193
194    /// Make a PATCH request
195    #[doc(hidden)]
196    pub async fn patch(
197        &self,
198        path: &str,
199        json: Option<Value>,
200        query_params: Option<Vec<(String, String)>>,
201        headers: Option<Vec<(String, String)>>,
202    ) -> Result<ResponseSnapshot, SnapshotError> {
203        let full_path = build_full_path(path, query_params.as_deref());
204        let mut request = self.mock_server.patch(&full_path);
205
206        if let Some(headers_vec) = headers {
207            request = self.add_headers(request, headers_vec)?;
208        }
209
210        if let Some(json_value) = json {
211            request = request.json(&json_value);
212        }
213
214        let response = request.await;
215        snapshot_response(response).await
216    }
217
218    /// Make a DELETE request
219    #[doc(hidden)]
220    pub async fn delete(
221        &self,
222        path: &str,
223        query_params: Option<Vec<(String, String)>>,
224        headers: Option<Vec<(String, String)>>,
225    ) -> Result<ResponseSnapshot, SnapshotError> {
226        let full_path = build_full_path(path, query_params.as_deref());
227        let mut request = self.mock_server.delete(&full_path);
228
229        if let Some(headers_vec) = headers {
230            request = self.add_headers(request, headers_vec)?;
231        }
232
233        let response = request.await;
234        snapshot_response(response).await
235    }
236
237    /// Make an OPTIONS request
238    #[doc(hidden)]
239    pub async fn options(
240        &self,
241        path: &str,
242        query_params: Option<Vec<(String, String)>>,
243        headers: Option<Vec<(String, String)>>,
244    ) -> Result<ResponseSnapshot, SnapshotError> {
245        let full_path = build_full_path(path, query_params.as_deref());
246        let mut request = self.mock_server.method(Method::OPTIONS, &full_path);
247
248        if let Some(headers_vec) = headers {
249            request = self.add_headers(request, headers_vec)?;
250        }
251
252        let response = request.await;
253        snapshot_response(response).await
254    }
255
256    /// Make a HEAD request
257    #[doc(hidden)]
258    pub async fn head(
259        &self,
260        path: &str,
261        query_params: Option<Vec<(String, String)>>,
262        headers: Option<Vec<(String, String)>>,
263    ) -> Result<ResponseSnapshot, SnapshotError> {
264        let full_path = build_full_path(path, query_params.as_deref());
265        let mut request = self.mock_server.method(Method::HEAD, &full_path);
266
267        if let Some(headers_vec) = headers {
268            request = self.add_headers(request, headers_vec)?;
269        }
270
271        let response = request.await;
272        snapshot_response(response).await
273    }
274
275    /// Make a TRACE request
276    #[doc(hidden)]
277    pub async fn trace(
278        &self,
279        path: &str,
280        query_params: Option<Vec<(String, String)>>,
281        headers: Option<Vec<(String, String)>>,
282    ) -> Result<ResponseSnapshot, SnapshotError> {
283        let full_path = build_full_path(path, query_params.as_deref());
284        let mut request = self.mock_server.method(Method::TRACE, &full_path);
285
286        if let Some(headers_vec) = headers {
287            request = self.add_headers(request, headers_vec)?;
288        }
289
290        let response = request.await;
291        snapshot_response(response).await
292    }
293
294    /// Send a GraphQL query/mutation to a custom endpoint
295    pub async fn graphql_at(
296        &self,
297        endpoint: &str,
298        query: &str,
299        variables: Option<serde_json::Value>,
300        operation_name: Option<&str>,
301    ) -> Result<ResponseSnapshot, SnapshotError> {
302        let body = build_graphql_body(query, variables, operation_name);
303        self.post(endpoint, Some(body), None, None, None, None).await
304    }
305
306    /// Send a GraphQL query/mutation
307    pub async fn graphql(
308        &self,
309        query: &str,
310        variables: Option<serde_json::Value>,
311        operation_name: Option<&str>,
312    ) -> Result<ResponseSnapshot, SnapshotError> {
313        self.graphql_at("/graphql", query, variables, operation_name).await
314    }
315
316    /// Send a GraphQL query and return HTTP status code separately
317    ///
318    /// This method allows tests to distinguish between:
319    /// - HTTP-level errors (400/422 for invalid requests)
320    /// - GraphQL-level errors (200 with errors in response body)
321    ///
322    /// # Example
323    /// ```ignore
324    /// let (status, snapshot) = client.graphql_with_status(
325    ///     "query { invalid syntax",
326    ///     None,
327    ///     None
328    /// ).await?;
329    /// assert_eq!(status, 400); // HTTP parse error
330    /// ```
331    #[doc(hidden)]
332    pub async fn graphql_with_status(
333        &self,
334        query: &str,
335        variables: Option<serde_json::Value>,
336        operation_name: Option<&str>,
337    ) -> Result<(u16, ResponseSnapshot), SnapshotError> {
338        let snapshot = self.graphql(query, variables, operation_name).await?;
339        let status = snapshot.status;
340        Ok((status, snapshot))
341    }
342
343    /// Send a GraphQL subscription (WebSocket) to a custom endpoint.
344    ///
345    /// Uses the `graphql-transport-ws` protocol and captures the first `next` payload.
346    /// After the first payload is received, this client sends `complete` to unsubscribe.
347    pub async fn graphql_subscription_at(
348        &self,
349        endpoint: &str,
350        query: &str,
351        variables: Option<serde_json::Value>,
352        operation_name: Option<&str>,
353    ) -> Result<GraphQLSubscriptionSnapshot, SnapshotError> {
354        let operation_id = "spikard-subscription-1".to_string();
355        let http_server = self.http_server()?;
356        let upgrade = http_server
357            .get_websocket(endpoint)
358            .add_header("sec-websocket-protocol", "graphql-transport-ws")
359            .await;
360
361        if upgrade.status_code().as_u16() != 101 {
362            return Err(SnapshotError::Decompression(format!(
363                "GraphQL subscription upgrade failed with status {}",
364                upgrade.status_code()
365            )));
366        }
367
368        let mut websocket = super::WebSocketConnection::new(upgrade.into_websocket().await);
369
370        websocket
371            .send_json(&serde_json::json!({"type": "connection_init"}))
372            .await;
373        wait_for_graphql_ack(&mut websocket).await?;
374
375        websocket
376            .send_json(&serde_json::json!({
377                "id": operation_id,
378                "type": "subscribe",
379                "payload": build_graphql_body(query, variables, operation_name),
380            }))
381            .await;
382
383        let mut event = None;
384        let mut errors = Vec::new();
385        let mut complete_received = false;
386
387        for _ in 0..GRAPHQL_WS_MAX_CONTROL_MESSAGES {
388            let message = timeout(
389                GRAPHQL_WS_MESSAGE_TIMEOUT,
390                receive_graphql_protocol_message(&mut websocket),
391            )
392            .await
393            .map_err(|_| {
394                SnapshotError::Decompression("Timed out waiting for GraphQL subscription message".to_string())
395            })??;
396
397            let message_type = message.get("type").and_then(Value::as_str).unwrap_or_default();
398            match message_type {
399                "next"
400                    if message
401                        .get("id")
402                        .and_then(Value::as_str)
403                        .is_none_or(|id| id == operation_id) =>
404                {
405                    event = message.get("payload").cloned();
406
407                    websocket
408                        .send_json(&serde_json::json!({
409                            "id": operation_id,
410                            "type": "complete",
411                        }))
412                        .await;
413
414                    if let Ok(next_message) = timeout(
415                        GRAPHQL_WS_MESSAGE_TIMEOUT,
416                        receive_graphql_protocol_message(&mut websocket),
417                    )
418                    .await
419                        && let Ok(next_message) = next_message
420                        && next_message.get("type").and_then(Value::as_str) == Some("complete")
421                        && next_message
422                            .get("id")
423                            .and_then(Value::as_str)
424                            .is_none_or(|id| id == operation_id)
425                    {
426                        complete_received = true;
427                    }
428                    break;
429                }
430                "error" => {
431                    errors.push(message.get("payload").cloned().unwrap_or(message));
432                    break;
433                }
434                "complete"
435                    if message
436                        .get("id")
437                        .and_then(Value::as_str)
438                        .is_none_or(|id| id == operation_id) =>
439                {
440                    complete_received = true;
441                    break;
442                }
443                "ping" => {
444                    let mut pong = serde_json::json!({"type": "pong"});
445                    if let Some(payload) = message.get("payload") {
446                        pong["payload"] = payload.clone();
447                    }
448                    websocket.send_json(&pong).await;
449                }
450                "pong" => {}
451                _ => {}
452            }
453        }
454
455        let _ = websocket.close().await;
456
457        if event.is_none() && errors.is_empty() && !complete_received {
458            return Err(SnapshotError::Decompression(
459                "No GraphQL subscription event received before timeout".to_string(),
460            ));
461        }
462
463        Ok(GraphQLSubscriptionSnapshot {
464            operation_id,
465            acknowledged: true,
466            event,
467            errors,
468            complete_received,
469        })
470    }
471
472    /// Send a GraphQL subscription (WebSocket).
473    ///
474    /// Uses `/graphql` as the default subscription endpoint.
475    pub async fn graphql_subscription(
476        &self,
477        query: &str,
478        variables: Option<serde_json::Value>,
479        operation_name: Option<&str>,
480    ) -> Result<GraphQLSubscriptionSnapshot, SnapshotError> {
481        self.graphql_subscription_at("/graphql", query, variables, operation_name)
482            .await
483    }
484
485    /// Add headers to a test request builder
486    fn add_headers(
487        &self,
488        mut request: axum_test::TestRequest,
489        headers: Vec<(String, String)>,
490    ) -> Result<axum_test::TestRequest, SnapshotError> {
491        for (key, value) in headers {
492            let header_name = HeaderName::from_bytes(key.as_bytes())
493                .map_err(|e| SnapshotError::InvalidHeader(format!("Invalid header name: {}", e)))?;
494            let header_value = HeaderValue::from_str(&value)
495                .map_err(|e| SnapshotError::InvalidHeader(format!("Invalid header value: {}", e)))?;
496            request = request.add_header(header_name, header_value);
497        }
498        Ok(request)
499    }
500}
501
502async fn wait_for_graphql_ack(websocket: &mut super::WebSocketConnection) -> Result<(), SnapshotError> {
503    for _ in 0..GRAPHQL_WS_MAX_CONTROL_MESSAGES {
504        let message = timeout(GRAPHQL_WS_MESSAGE_TIMEOUT, receive_graphql_protocol_message(websocket))
505            .await
506            .map_err(|_| SnapshotError::Decompression("Timed out waiting for GraphQL connection_ack".to_string()))??;
507
508        match message.get("type").and_then(Value::as_str).unwrap_or_default() {
509            "connection_ack" => return Ok(()),
510            "ping" => {
511                let mut pong = serde_json::json!({"type": "pong"});
512                if let Some(payload) = message.get("payload") {
513                    pong["payload"] = payload.clone();
514                }
515                websocket.send_json(&pong).await;
516            }
517            "connection_error" | "error" => {
518                return Err(SnapshotError::Decompression(format!(
519                    "GraphQL subscription rejected during init: {}",
520                    message
521                )));
522            }
523            _ => {}
524        }
525    }
526
527    Err(SnapshotError::Decompression(
528        "No GraphQL connection_ack received".to_string(),
529    ))
530}
531
532async fn receive_graphql_protocol_message(websocket: &mut super::WebSocketConnection) -> Result<Value, SnapshotError> {
533    loop {
534        match websocket.receive_message().await {
535            super::WebSocketMessage::Text(text) => {
536                return serde_json::from_str::<Value>(&text).map_err(|e| {
537                    SnapshotError::Decompression(format!("Failed to parse GraphQL WebSocket message as JSON: {}", e))
538                });
539            }
540            super::WebSocketMessage::Binary(bytes) => {
541                return serde_json::from_slice::<Value>(&bytes).map_err(|e| {
542                    SnapshotError::Decompression(format!(
543                        "Failed to parse GraphQL binary WebSocket message as JSON: {}",
544                        e
545                    ))
546                });
547            }
548            super::WebSocketMessage::Ping(_) | super::WebSocketMessage::Pong(_) => continue,
549            super::WebSocketMessage::Close { code, reason } => {
550                return Err(SnapshotError::Decompression(format!(
551                    "GraphQL WebSocket connection closed before response: code={} reason={:?}",
552                    code, reason
553                )));
554            }
555        }
556    }
557}
558
559/// Build a GraphQL request body from query, variables, and operation name
560pub fn build_graphql_body(query: &str, variables: Option<serde_json::Value>, operation_name: Option<&str>) -> Value {
561    let mut body = serde_json::json!({ "query": query });
562    if let Some(vars) = variables {
563        body["variables"] = vars;
564    }
565    if let Some(op_name) = operation_name {
566        body["operationName"] = Value::String(op_name.to_string());
567    }
568    body
569}
570
571/// Build a full path with query parameters
572fn build_full_path(path: &str, query_params: Option<&[(String, String)]>) -> String {
573    match query_params {
574        None | Some(&[]) => path.to_string(),
575        Some(params) => {
576            let query_string: Vec<String> = params
577                .iter()
578                .map(|(k, v)| format!("{}={}", encode(k), encode(v)))
579                .collect();
580
581            if path.contains('?') {
582                format!("{}&{}", path, query_string.join("&"))
583            } else {
584                format!("{}?{}", path, query_string.join("&"))
585            }
586        }
587    }
588}
589
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        Router,
        extract::ws::{Message, WebSocketUpgrade},
        http::StatusCode,
        routing::{get, post},
    };

    #[test]
    fn build_full_path_no_params() {
        let path = "/users";
        assert_eq!(build_full_path(path, None), "/users");
        assert_eq!(build_full_path(path, Some(&[])), "/users");
    }

    #[test]
    fn build_full_path_with_params() {
        let path = "/users";
        let params = vec![
            ("id".to_string(), "123".to_string()),
            ("name".to_string(), "test user".to_string()),
        ];
        let result = build_full_path(path, Some(&params));
        assert!(result.starts_with("/users?"));
        assert!(result.contains("id=123"));
        assert!(result.contains("name=test%20user"));
    }

    #[test]
    fn build_full_path_existing_query() {
        let path = "/users?active=true";
        let params = vec![("id".to_string(), "123".to_string())];
        let result = build_full_path(path, Some(&params));
        assert_eq!(result, "/users?active=true&id=123");
    }

    /// Exercises the real body builder instead of re-implementing it inline.
    #[test]
    fn test_graphql_query_builder() {
        let query = "{ users { id name } }";
        let body = build_graphql_body(query, Some(serde_json::json!({ "limit": 10 })), Some("GetUsers"));

        assert_eq!(
            body,
            serde_json::json!({
                "query": query,
                "variables": { "limit": 10 },
                "operationName": "GetUsers",
            })
        );
    }

    /// `graphql_with_status` must surface the HTTP status of the underlying POST.
    #[tokio::test]
    async fn test_graphql_with_status_method() {
        let app = Router::new().route(
            "/graphql",
            post(|| async { (StatusCode::BAD_REQUEST, "invalid query") }),
        );
        let client = TestClient::from_router(app).expect("client");

        let (status, snapshot) = client
            .graphql_with_status("query { invalid syntax", None, None)
            .await
            .expect("graphql response");

        assert_eq!(status, 400);
        assert_eq!(snapshot.status, status);
        assert_eq!(snapshot.text().expect("body"), "invalid query");
    }

    #[test]
    fn test_build_graphql_body_basic() {
        let query = "{ users { id name } }";
        let body = build_graphql_body(query, None, None);

        // Absent optional fields are omitted, not serialized as null.
        assert_eq!(body, serde_json::json!({ "query": query }));
    }

    #[test]
    fn test_build_graphql_body_with_variables() {
        let query = "query GetUser($id: ID!) { user(id: $id) { name } }";
        let variables = Some(serde_json::json!({ "id": "123" }));
        let body = build_graphql_body(query, variables, None);

        assert_eq!(body["query"], query);
        assert_eq!(body["variables"]["id"], "123");
    }

    #[test]
    fn test_build_graphql_body_with_operation_name() {
        let query = "query GetUsers { users { id } }";
        let op_name = Some("GetUsers");
        let body = build_graphql_body(query, None, op_name);

        assert_eq!(body["query"], query);
        assert_eq!(body["operationName"], "GetUsers");
    }

    #[test]
    fn test_build_graphql_body_all_fields() {
        let query = "mutation CreateUser($name: String!) { createUser(name: $name) { id } }";
        let variables = Some(serde_json::json!({ "name": "Alice" }));
        let op_name = Some("CreateUser");
        let body = build_graphql_body(query, variables, op_name);

        assert_eq!(body["query"], query);
        assert_eq!(body["variables"]["name"], "Alice");
        assert_eq!(body["operationName"], "CreateUser");
    }

    #[tokio::test]
    async fn graphql_subscription_returns_first_event_and_completes() {
        let app = Router::new().route(
            "/graphql",
            get(|ws: WebSocketUpgrade| async move {
                ws.on_upgrade(|mut socket| async move {
                    while let Some(result) = socket.recv().await {
                        let Ok(Message::Text(text)) = result else {
                            continue;
                        };
                        let Ok(message): Result<Value, _> = serde_json::from_str(&text) else {
                            continue;
                        };

                        match message.get("type").and_then(Value::as_str) {
                            Some("connection_init") => {
                                let _ = socket
                                    .send(Message::Text(
                                        serde_json::json!({"type":"connection_ack"}).to_string().into(),
                                    ))
                                    .await;
                            }
                            Some("subscribe") => {
                                let id = message.get("id").and_then(Value::as_str).unwrap_or("1");
                                let _ = socket
                                    .send(Message::Text(
                                        serde_json::json!({
                                            "id": id,
                                            "type": "next",
                                            "payload": {"data": {"ticker": "AAPL"}},
                                        })
                                        .to_string()
                                        .into(),
                                    ))
                                    .await;

                                if let Some(Ok(Message::Text(complete_text))) = socket.recv().await {
                                    let Ok(complete_message): Result<Value, _> = serde_json::from_str(&complete_text)
                                    else {
                                        break;
                                    };
                                    if complete_message.get("type").and_then(Value::as_str) == Some("complete") {
                                        let _ = socket
                                            .send(Message::Text(
                                                serde_json::json!({"id": id, "type":"complete"}).to_string().into(),
                                            ))
                                            .await;
                                    }
                                }
                                break;
                            }
                            _ => {}
                        }
                    }
                })
            }),
        );

        let client = TestClient::from_router(app).expect("client");
        assert!(client.http_server.lock().expect("lock").is_none());

        let snapshot = client
            .graphql_subscription("subscription { ticker }", None, None)
            .await
            .expect("subscription snapshot");

        assert!(snapshot.acknowledged);
        assert_eq!(snapshot.errors, Vec::<Value>::new());
        assert_eq!(snapshot.event, Some(serde_json::json!({"data": {"ticker": "AAPL"}})));
        assert!(snapshot.complete_received);
        assert!(client.http_server.lock().expect("lock").is_some());
    }

    #[tokio::test]
    async fn graphql_subscription_surfaces_connection_error() {
        let app = Router::new().route(
            "/graphql",
            get(|ws: WebSocketUpgrade| async move {
                ws.on_upgrade(|mut socket| async move {
                    while let Some(result) = socket.recv().await {
                        let Ok(Message::Text(text)) = result else {
                            continue;
                        };
                        let Ok(message): Result<Value, _> = serde_json::from_str(&text) else {
                            continue;
                        };

                        if message.get("type").and_then(Value::as_str) == Some("connection_init") {
                            let _ = socket
                                .send(Message::Text(
                                    serde_json::json!({
                                        "type": "connection_error",
                                        "payload": {"message": "not authorized"},
                                    })
                                    .to_string()
                                    .into(),
                                ))
                                .await;
                            break;
                        }
                    }
                })
            }),
        );

        let client = TestClient::from_router(app).expect("client");
        assert!(client.http_server.lock().expect("lock").is_none());

        let error = client
            .graphql_subscription("subscription { privateFeed }", None, None)
            .await
            .expect_err("expected connection error");

        assert!(error.to_string().contains("connection_error"));
        assert!(client.http_server.lock().expect("lock").is_some());
    }

    #[tokio::test]
    async fn http_requests_do_not_initialize_socket_transport() {
        let app = Router::new().route("/health", get(|| async { "ok" }));

        let client = TestClient::from_router(app).expect("client");
        assert!(client.http_server.lock().expect("lock").is_none());

        let snapshot = client.get("/health", None, None).await.expect("response snapshot");

        assert_eq!(snapshot.status, 200);
        assert_eq!(snapshot.text().expect("body"), "ok");
        assert!(client.http_server.lock().expect("lock").is_none());
    }
}