Skip to main content

spikard_http/testing/
test_client.rs

1//! Core test client for Spikard applications
2//!
//! This module provides a language-agnostic `TestClient` that language bindings
//! (PyO3, napi-rs, magnus) wrap to expose idiomatic Python, JavaScript, and Ruby
//! APIs, respectively.
//!
//! The core client handles HTTP method dispatch, query parameters, header management,
//! body encoding (JSON, URL-encoded forms, multipart), response snapshot capture, and
//! GraphQL queries and subscriptions (the latter over WebSocket).
9
10use super::{ResponseSnapshot, SnapshotError, snapshot_response};
11use axum::http::{HeaderName, HeaderValue, Method};
12use axum_test::TestServer;
13use bytes::Bytes;
14use serde_json::Value;
15use std::sync::{Arc, Mutex};
16use std::time::Duration;
17use tokio::time::timeout;
18use urlencoding::encode;
19
/// Optional multipart body for `TestClient::post`: text form fields as `(name, value)`
/// pairs plus file parts, both handed to `super::build_multipart_body`.
type MultipartPayload = Option<(Vec<(String, String)>, Vec<super::MultipartFilePart>)>;
/// Maximum time to wait for each individual GraphQL WebSocket frame.
const GRAPHQL_WS_MESSAGE_TIMEOUT: Duration = Duration::from_secs(2);
/// Maximum number of frames read while waiting for `connection_ack` or for a subscription
/// result, so a server that keeps sending unrelated frames (e.g. pings) cannot stall a test.
const GRAPHQL_WS_MAX_CONTROL_MESSAGES: usize = 32;
23
/// Snapshot of a single GraphQL subscription exchange over WebSocket
/// (`graphql-transport-ws` protocol).
#[derive(Debug, Clone, PartialEq)]
pub struct GraphQLSubscriptionSnapshot {
    /// Operation id sent in the `subscribe` frame.
    pub operation_id: String,
    /// Whether the server answered `connection_init` with `connection_ack`. Snapshots
    /// produced by `TestClient::graphql_subscription_at` always set this to `true`,
    /// because a missing ack is reported as an error instead.
    pub acknowledged: bool,
    /// `payload` of the first `next` frame for this operation, if one arrived.
    pub event: Option<Value>,
    /// Payloads of `error` frames received for this operation (the frame itself when it
    /// has no `payload`).
    pub errors: Vec<Value>,
    /// Whether a `complete` frame was observed for this operation.
    pub complete_received: bool,
}
38
/// Core test client for making HTTP requests to a Spikard application.
///
/// This struct wraps axum-test's `TestServer` and exposes a language-agnostic interface
/// for HTTP requests and GraphQL operations, including subscriptions over WebSocket.
/// Language bindings wrap it to provide native API surfaces.
///
/// NOTE(review): earlier docs also mentioned Server-Sent Events; nothing in this file
/// handles SSE — confirm where that support lives.
pub struct TestClient {
    /// Server created by `from_router` via `TestServer::try_new`; every plain HTTP helper
    /// (`get`, `post`, ..., `graphql`) sends through it.
    mock_server: Arc<TestServer>,
    /// Clone of the application router, kept so the socket-backed server can be built on
    /// demand by `http_server`.
    router: axum::Router,
    /// Lazily created HTTP-transport server used for WebSocket traffic; `None` until
    /// `http_server` is first called.
    http_server: Mutex<Option<Arc<TestServer>>>,
}
50
51impl TestClient {
52    /// Create a new test client from an Axum router
53    pub fn from_router(router: axum::Router) -> Result<Self, String> {
54        let mock_server =
55            TestServer::try_new(router.clone()).map_err(|e| format!("Failed to create test server: {}", e))?;
56
57        Ok(Self {
58            mock_server: Arc::new(mock_server),
59            router,
60            http_server: Mutex::new(None),
61        })
62    }
63
64    /// Get or initialize the underlying socket-backed test server for WebSocket traffic.
65    pub fn http_server(&self) -> Result<Arc<TestServer>, SnapshotError> {
66        let mut guard = self
67            .http_server
68            .lock()
69            .map_err(|_| SnapshotError::Decompression("Failed to lock HTTP test server state".to_string()))?;
70
71        if let Some(server) = guard.as_ref() {
72            return Ok(Arc::clone(server));
73        }
74
75        if tokio::runtime::Handle::try_current().is_err() {
76            return Err(SnapshotError::Decompression(
77                "WebSocket test transport requires an active Tokio runtime".to_string(),
78            ));
79        }
80
81        let server = Arc::new(
82            TestServer::builder()
83                .http_transport()
84                .try_build(self.router.clone())
85                .map_err(|e| SnapshotError::Decompression(format!("Failed to create test server: {}", e)))?,
86        );
87        *guard = Some(Arc::clone(&server));
88        Ok(server)
89    }
90
91    /// Make a GET request
92    pub async fn get(
93        &self,
94        path: &str,
95        query_params: Option<Vec<(String, String)>>,
96        headers: Option<Vec<(String, String)>>,
97    ) -> Result<ResponseSnapshot, SnapshotError> {
98        let full_path = build_full_path(path, query_params.as_deref());
99        let mut request = self.mock_server.get(&full_path);
100
101        if let Some(headers_vec) = headers {
102            request = self.add_headers(request, headers_vec)?;
103        }
104
105        let response = request.await;
106        snapshot_response(response).await
107    }
108
109    /// Make a POST request
110    pub async fn post(
111        &self,
112        path: &str,
113        json: Option<Value>,
114        form_data: Option<Vec<(String, String)>>,
115        multipart: MultipartPayload,
116        query_params: Option<Vec<(String, String)>>,
117        headers: Option<Vec<(String, String)>>,
118    ) -> Result<ResponseSnapshot, SnapshotError> {
119        let full_path = build_full_path(path, query_params.as_deref());
120        let mut request = self.mock_server.post(&full_path);
121
122        if let Some(headers_vec) = headers {
123            request = self.add_headers(request, headers_vec)?;
124        }
125
126        if let Some((form_fields, files)) = multipart {
127            let (body, boundary) = super::build_multipart_body(&form_fields, &files);
128            let content_type = format!("multipart/form-data; boundary={}", boundary);
129            request = request.add_header("content-type", &content_type);
130            request = request.bytes(Bytes::from(body));
131        } else if let Some(form_fields) = form_data {
132            let fields_value = serde_json::to_value(&form_fields)
133                .map_err(|e| SnapshotError::Decompression(format!("Failed to serialize form fields: {}", e)))?;
134            let encoded = super::encode_urlencoded_body(&fields_value)
135                .map_err(|e| SnapshotError::Decompression(format!("Form encoding failed: {}", e)))?;
136            request = request.add_header("content-type", "application/x-www-form-urlencoded");
137            request = request.bytes(Bytes::from(encoded));
138        } else if let Some(json_value) = json {
139            request = request.json(&json_value);
140        }
141
142        let response = request.await;
143        snapshot_response(response).await
144    }
145
146    /// Make a request with a raw body payload.
147    pub async fn request_raw(
148        &self,
149        method: Method,
150        path: &str,
151        body: Bytes,
152        query_params: Option<Vec<(String, String)>>,
153        headers: Option<Vec<(String, String)>>,
154    ) -> Result<ResponseSnapshot, SnapshotError> {
155        let full_path = build_full_path(path, query_params.as_deref());
156        let mut request = self.mock_server.method(method, &full_path);
157
158        if let Some(headers_vec) = headers {
159            request = self.add_headers(request, headers_vec)?;
160        }
161
162        request = request.bytes(body);
163        let response = request.await;
164        snapshot_response(response).await
165    }
166
167    /// Make a PUT request
168    pub async fn put(
169        &self,
170        path: &str,
171        json: Option<Value>,
172        query_params: Option<Vec<(String, String)>>,
173        headers: Option<Vec<(String, String)>>,
174    ) -> Result<ResponseSnapshot, SnapshotError> {
175        let full_path = build_full_path(path, query_params.as_deref());
176        let mut request = self.mock_server.put(&full_path);
177
178        if let Some(headers_vec) = headers {
179            request = self.add_headers(request, headers_vec)?;
180        }
181
182        if let Some(json_value) = json {
183            request = request.json(&json_value);
184        }
185
186        let response = request.await;
187        snapshot_response(response).await
188    }
189
190    /// Make a PATCH request
191    pub async fn patch(
192        &self,
193        path: &str,
194        json: Option<Value>,
195        query_params: Option<Vec<(String, String)>>,
196        headers: Option<Vec<(String, String)>>,
197    ) -> Result<ResponseSnapshot, SnapshotError> {
198        let full_path = build_full_path(path, query_params.as_deref());
199        let mut request = self.mock_server.patch(&full_path);
200
201        if let Some(headers_vec) = headers {
202            request = self.add_headers(request, headers_vec)?;
203        }
204
205        if let Some(json_value) = json {
206            request = request.json(&json_value);
207        }
208
209        let response = request.await;
210        snapshot_response(response).await
211    }
212
213    /// Make a DELETE request
214    pub async fn delete(
215        &self,
216        path: &str,
217        query_params: Option<Vec<(String, String)>>,
218        headers: Option<Vec<(String, String)>>,
219    ) -> Result<ResponseSnapshot, SnapshotError> {
220        let full_path = build_full_path(path, query_params.as_deref());
221        let mut request = self.mock_server.delete(&full_path);
222
223        if let Some(headers_vec) = headers {
224            request = self.add_headers(request, headers_vec)?;
225        }
226
227        let response = request.await;
228        snapshot_response(response).await
229    }
230
231    /// Make an OPTIONS request
232    pub async fn options(
233        &self,
234        path: &str,
235        query_params: Option<Vec<(String, String)>>,
236        headers: Option<Vec<(String, String)>>,
237    ) -> Result<ResponseSnapshot, SnapshotError> {
238        let full_path = build_full_path(path, query_params.as_deref());
239        let mut request = self.mock_server.method(Method::OPTIONS, &full_path);
240
241        if let Some(headers_vec) = headers {
242            request = self.add_headers(request, headers_vec)?;
243        }
244
245        let response = request.await;
246        snapshot_response(response).await
247    }
248
249    /// Make a HEAD request
250    pub async fn head(
251        &self,
252        path: &str,
253        query_params: Option<Vec<(String, String)>>,
254        headers: Option<Vec<(String, String)>>,
255    ) -> Result<ResponseSnapshot, SnapshotError> {
256        let full_path = build_full_path(path, query_params.as_deref());
257        let mut request = self.mock_server.method(Method::HEAD, &full_path);
258
259        if let Some(headers_vec) = headers {
260            request = self.add_headers(request, headers_vec)?;
261        }
262
263        let response = request.await;
264        snapshot_response(response).await
265    }
266
267    /// Make a TRACE request
268    pub async fn trace(
269        &self,
270        path: &str,
271        query_params: Option<Vec<(String, String)>>,
272        headers: Option<Vec<(String, String)>>,
273    ) -> Result<ResponseSnapshot, SnapshotError> {
274        let full_path = build_full_path(path, query_params.as_deref());
275        let mut request = self.mock_server.method(Method::TRACE, &full_path);
276
277        if let Some(headers_vec) = headers {
278            request = self.add_headers(request, headers_vec)?;
279        }
280
281        let response = request.await;
282        snapshot_response(response).await
283    }
284
285    /// Send a GraphQL query/mutation to a custom endpoint
286    pub async fn graphql_at(
287        &self,
288        endpoint: &str,
289        query: &str,
290        variables: Option<Value>,
291        operation_name: Option<&str>,
292    ) -> Result<ResponseSnapshot, SnapshotError> {
293        let body = build_graphql_body(query, variables, operation_name);
294        self.post(endpoint, Some(body), None, None, None, None).await
295    }
296
297    /// Send a GraphQL query/mutation
298    pub async fn graphql(
299        &self,
300        query: &str,
301        variables: Option<Value>,
302        operation_name: Option<&str>,
303    ) -> Result<ResponseSnapshot, SnapshotError> {
304        self.graphql_at("/graphql", query, variables, operation_name).await
305    }
306
307    /// Send a GraphQL query and return HTTP status code separately
308    ///
309    /// This method allows tests to distinguish between:
310    /// - HTTP-level errors (400/422 for invalid requests)
311    /// - GraphQL-level errors (200 with errors in response body)
312    ///
313    /// # Example
314    /// ```ignore
315    /// let (status, snapshot) = client.graphql_with_status(
316    ///     "query { invalid syntax",
317    ///     None,
318    ///     None
319    /// ).await?;
320    /// assert_eq!(status, 400); // HTTP parse error
321    /// ```
322    pub async fn graphql_with_status(
323        &self,
324        query: &str,
325        variables: Option<Value>,
326        operation_name: Option<&str>,
327    ) -> Result<(u16, ResponseSnapshot), SnapshotError> {
328        let snapshot = self.graphql(query, variables, operation_name).await?;
329        let status = snapshot.status;
330        Ok((status, snapshot))
331    }
332
333    /// Send a GraphQL subscription (WebSocket) to a custom endpoint.
334    ///
335    /// Uses the `graphql-transport-ws` protocol and captures the first `next` payload.
336    /// After the first payload is received, this client sends `complete` to unsubscribe.
337    pub async fn graphql_subscription_at(
338        &self,
339        endpoint: &str,
340        query: &str,
341        variables: Option<Value>,
342        operation_name: Option<&str>,
343    ) -> Result<GraphQLSubscriptionSnapshot, SnapshotError> {
344        let operation_id = "spikard-subscription-1".to_string();
345        let http_server = self.http_server()?;
346        let upgrade = http_server
347            .get_websocket(endpoint)
348            .add_header("sec-websocket-protocol", "graphql-transport-ws")
349            .await;
350
351        if upgrade.status_code().as_u16() != 101 {
352            return Err(SnapshotError::Decompression(format!(
353                "GraphQL subscription upgrade failed with status {}",
354                upgrade.status_code()
355            )));
356        }
357
358        let mut websocket = super::WebSocketConnection::new(upgrade.into_websocket().await);
359
360        websocket
361            .send_json(&serde_json::json!({"type": "connection_init"}))
362            .await;
363        wait_for_graphql_ack(&mut websocket).await?;
364
365        websocket
366            .send_json(&serde_json::json!({
367                "id": operation_id,
368                "type": "subscribe",
369                "payload": build_graphql_body(query, variables, operation_name),
370            }))
371            .await;
372
373        let mut event = None;
374        let mut errors = Vec::new();
375        let mut complete_received = false;
376
377        for _ in 0..GRAPHQL_WS_MAX_CONTROL_MESSAGES {
378            let message = timeout(
379                GRAPHQL_WS_MESSAGE_TIMEOUT,
380                receive_graphql_protocol_message(&mut websocket),
381            )
382            .await
383            .map_err(|_| {
384                SnapshotError::Decompression("Timed out waiting for GraphQL subscription message".to_string())
385            })??;
386
387            let message_type = message.get("type").and_then(Value::as_str).unwrap_or_default();
388            match message_type {
389                "next"
390                    if message
391                        .get("id")
392                        .and_then(Value::as_str)
393                        .is_none_or(|id| id == operation_id) =>
394                {
395                    event = message.get("payload").cloned();
396
397                    websocket
398                        .send_json(&serde_json::json!({
399                            "id": operation_id,
400                            "type": "complete",
401                        }))
402                        .await;
403
404                    if let Ok(next_message) = timeout(
405                        GRAPHQL_WS_MESSAGE_TIMEOUT,
406                        receive_graphql_protocol_message(&mut websocket),
407                    )
408                    .await
409                        && let Ok(next_message) = next_message
410                        && next_message.get("type").and_then(Value::as_str) == Some("complete")
411                        && next_message
412                            .get("id")
413                            .and_then(Value::as_str)
414                            .is_none_or(|id| id == operation_id)
415                    {
416                        complete_received = true;
417                    }
418                    break;
419                }
420                "error" => {
421                    errors.push(message.get("payload").cloned().unwrap_or(message));
422                    break;
423                }
424                "complete"
425                    if message
426                        .get("id")
427                        .and_then(Value::as_str)
428                        .is_none_or(|id| id == operation_id) =>
429                {
430                    complete_received = true;
431                    break;
432                }
433                "ping" => {
434                    let mut pong = serde_json::json!({"type": "pong"});
435                    if let Some(payload) = message.get("payload") {
436                        pong["payload"] = payload.clone();
437                    }
438                    websocket.send_json(&pong).await;
439                }
440                "pong" => {}
441                _ => {}
442            }
443        }
444
445        websocket.close().await;
446
447        if event.is_none() && errors.is_empty() && !complete_received {
448            return Err(SnapshotError::Decompression(
449                "No GraphQL subscription event received before timeout".to_string(),
450            ));
451        }
452
453        Ok(GraphQLSubscriptionSnapshot {
454            operation_id,
455            acknowledged: true,
456            event,
457            errors,
458            complete_received,
459        })
460    }
461
462    /// Send a GraphQL subscription (WebSocket).
463    ///
464    /// Uses `/graphql` as the default subscription endpoint.
465    pub async fn graphql_subscription(
466        &self,
467        query: &str,
468        variables: Option<Value>,
469        operation_name: Option<&str>,
470    ) -> Result<GraphQLSubscriptionSnapshot, SnapshotError> {
471        self.graphql_subscription_at("/graphql", query, variables, operation_name)
472            .await
473    }
474
475    /// Add headers to a test request builder
476    fn add_headers(
477        &self,
478        mut request: axum_test::TestRequest,
479        headers: Vec<(String, String)>,
480    ) -> Result<axum_test::TestRequest, SnapshotError> {
481        for (key, value) in headers {
482            let header_name = HeaderName::from_bytes(key.as_bytes())
483                .map_err(|e| SnapshotError::InvalidHeader(format!("Invalid header name: {}", e)))?;
484            let header_value = HeaderValue::from_str(&value)
485                .map_err(|e| SnapshotError::InvalidHeader(format!("Invalid header value: {}", e)))?;
486            request = request.add_header(header_name, header_value);
487        }
488        Ok(request)
489    }
490}
491
492async fn wait_for_graphql_ack(websocket: &mut super::WebSocketConnection) -> Result<(), SnapshotError> {
493    for _ in 0..GRAPHQL_WS_MAX_CONTROL_MESSAGES {
494        let message = timeout(GRAPHQL_WS_MESSAGE_TIMEOUT, receive_graphql_protocol_message(websocket))
495            .await
496            .map_err(|_| SnapshotError::Decompression("Timed out waiting for GraphQL connection_ack".to_string()))??;
497
498        match message.get("type").and_then(Value::as_str).unwrap_or_default() {
499            "connection_ack" => return Ok(()),
500            "ping" => {
501                let mut pong = serde_json::json!({"type": "pong"});
502                if let Some(payload) = message.get("payload") {
503                    pong["payload"] = payload.clone();
504                }
505                websocket.send_json(&pong).await;
506            }
507            "connection_error" | "error" => {
508                return Err(SnapshotError::Decompression(format!(
509                    "GraphQL subscription rejected during init: {}",
510                    message
511                )));
512            }
513            _ => {}
514        }
515    }
516
517    Err(SnapshotError::Decompression(
518        "No GraphQL connection_ack received".to_string(),
519    ))
520}
521
522async fn receive_graphql_protocol_message(websocket: &mut super::WebSocketConnection) -> Result<Value, SnapshotError> {
523    loop {
524        match websocket.receive_message().await {
525            super::WebSocketMessage::Text(text) => {
526                return serde_json::from_str::<Value>(&text).map_err(|e| {
527                    SnapshotError::Decompression(format!("Failed to parse GraphQL WebSocket message as JSON: {}", e))
528                });
529            }
530            super::WebSocketMessage::Binary(bytes) => {
531                return serde_json::from_slice::<Value>(&bytes).map_err(|e| {
532                    SnapshotError::Decompression(format!(
533                        "Failed to parse GraphQL binary WebSocket message as JSON: {}",
534                        e
535                    ))
536                });
537            }
538            super::WebSocketMessage::Ping(_) | super::WebSocketMessage::Pong(_) => continue,
539            super::WebSocketMessage::Close(reason) => {
540                return Err(SnapshotError::Decompression(format!(
541                    "GraphQL WebSocket connection closed before response: {:?}",
542                    reason
543                )));
544            }
545        }
546    }
547}
548
549/// Build a GraphQL request body from query, variables, and operation name
550pub fn build_graphql_body(query: &str, variables: Option<Value>, operation_name: Option<&str>) -> Value {
551    let mut body = serde_json::json!({ "query": query });
552    if let Some(vars) = variables {
553        body["variables"] = vars;
554    }
555    if let Some(op_name) = operation_name {
556        body["operationName"] = Value::String(op_name.to_string());
557    }
558    body
559}
560
561/// Build a full path with query parameters
562fn build_full_path(path: &str, query_params: Option<&[(String, String)]>) -> String {
563    match query_params {
564        None | Some(&[]) => path.to_string(),
565        Some(params) => {
566            let query_string: Vec<String> = params
567                .iter()
568                .map(|(k, v)| format!("{}={}", encode(k), encode(v)))
569                .collect();
570
571            if path.contains('?') {
572                format!("{}&{}", path, query_string.join("&"))
573            } else {
574                format!("{}?{}", path, query_string.join("&"))
575            }
576        }
577    }
578}
579
580#[cfg(test)]
581mod tests {
582    use super::*;
583    use axum::{
584        Router,
585        extract::ws::{Message, WebSocketUpgrade},
586        routing::get,
587    };
588
589    #[test]
590    fn build_full_path_no_params() {
591        let path = "/users";
592        assert_eq!(build_full_path(path, None), "/users");
593        assert_eq!(build_full_path(path, Some(&[])), "/users");
594    }
595
596    #[test]
597    fn build_full_path_with_params() {
598        let path = "/users";
599        let params = vec![
600            ("id".to_string(), "123".to_string()),
601            ("name".to_string(), "test user".to_string()),
602        ];
603        let result = build_full_path(path, Some(&params));
604        assert!(result.starts_with("/users?"));
605        assert!(result.contains("id=123"));
606        assert!(result.contains("name=test%20user"));
607    }
608
609    #[test]
610    fn build_full_path_existing_query() {
611        let path = "/users?active=true";
612        let params = vec![("id".to_string(), "123".to_string())];
613        let result = build_full_path(path, Some(&params));
614        assert_eq!(result, "/users?active=true&id=123");
615    }
616
617    #[test]
618    fn test_graphql_query_builder() {
619        let query = "{ users { id name } }";
620        let variables = Some(serde_json::json!({ "limit": 10 }));
621        let op_name = Some("GetUsers");
622
623        let mut body = serde_json::json!({ "query": query });
624        if let Some(vars) = variables {
625            body["variables"] = vars;
626        }
627        if let Some(op_name) = op_name {
628            body["operationName"] = Value::String(op_name.to_string());
629        }
630
631        assert_eq!(body["query"], query);
632        assert_eq!(body["variables"]["limit"], 10);
633        assert_eq!(body["operationName"], "GetUsers");
634    }
635
636    #[test]
637    fn test_graphql_with_status_method() {
638        let query = "query { hello }";
639        let body = serde_json::json!({
640            "query": query,
641            "variables": null,
642            "operationName": null
643        });
644
645        // This test validates the method signature and return type
646        // Actual HTTP status testing will happen in integration tests
647        let expected_fields = vec!["query", "variables", "operationName"];
648        for field in expected_fields {
649            assert!(body.get(field).is_some(), "Missing field: {}", field);
650        }
651    }
652
653    #[test]
654    fn test_build_graphql_body_basic() {
655        let query = "{ users { id name } }";
656        let body = build_graphql_body(query, None, None);
657
658        assert_eq!(body["query"], query);
659        assert!(body.get("variables").is_none() || body["variables"].is_null());
660        assert!(body.get("operationName").is_none() || body["operationName"].is_null());
661    }
662
663    #[test]
664    fn test_build_graphql_body_with_variables() {
665        let query = "query GetUser($id: ID!) { user(id: $id) { name } }";
666        let variables = Some(serde_json::json!({ "id": "123" }));
667        let body = build_graphql_body(query, variables, None);
668
669        assert_eq!(body["query"], query);
670        assert_eq!(body["variables"]["id"], "123");
671    }
672
673    #[test]
674    fn test_build_graphql_body_with_operation_name() {
675        let query = "query GetUsers { users { id } }";
676        let op_name = Some("GetUsers");
677        let body = build_graphql_body(query, None, op_name);
678
679        assert_eq!(body["query"], query);
680        assert_eq!(body["operationName"], "GetUsers");
681    }
682
683    #[test]
684    fn test_build_graphql_body_all_fields() {
685        let query = "mutation CreateUser($name: String!) { createUser(name: $name) { id } }";
686        let variables = Some(serde_json::json!({ "name": "Alice" }));
687        let op_name = Some("CreateUser");
688        let body = build_graphql_body(query, variables, op_name);
689
690        assert_eq!(body["query"], query);
691        assert_eq!(body["variables"]["name"], "Alice");
692        assert_eq!(body["operationName"], "CreateUser");
693    }
694
695    #[tokio::test]
696    async fn graphql_subscription_returns_first_event_and_completes() {
697        let app = Router::new().route(
698            "/graphql",
699            get(|ws: WebSocketUpgrade| async move {
700                ws.on_upgrade(|mut socket| async move {
701                    while let Some(result) = socket.recv().await {
702                        let Ok(Message::Text(text)) = result else {
703                            continue;
704                        };
705                        let Ok(message): Result<Value, _> = serde_json::from_str(&text) else {
706                            continue;
707                        };
708
709                        match message.get("type").and_then(Value::as_str) {
710                            Some("connection_init") => {
711                                let _ = socket
712                                    .send(Message::Text(
713                                        serde_json::json!({"type":"connection_ack"}).to_string().into(),
714                                    ))
715                                    .await;
716                            }
717                            Some("subscribe") => {
718                                let id = message.get("id").and_then(Value::as_str).unwrap_or("1");
719                                let _ = socket
720                                    .send(Message::Text(
721                                        serde_json::json!({
722                                            "id": id,
723                                            "type": "next",
724                                            "payload": {"data": {"ticker": "AAPL"}},
725                                        })
726                                        .to_string()
727                                        .into(),
728                                    ))
729                                    .await;
730
731                                if let Some(Ok(Message::Text(complete_text))) = socket.recv().await {
732                                    let Ok(complete_message): Result<Value, _> = serde_json::from_str(&complete_text)
733                                    else {
734                                        break;
735                                    };
736                                    if complete_message.get("type").and_then(Value::as_str) == Some("complete") {
737                                        let _ = socket
738                                            .send(Message::Text(
739                                                serde_json::json!({"id": id, "type":"complete"}).to_string().into(),
740                                            ))
741                                            .await;
742                                    }
743                                }
744                                break;
745                            }
746                            _ => {}
747                        }
748                    }
749                })
750            }),
751        );
752
753        let client = TestClient::from_router(app).expect("client");
754        assert!(client.http_server.lock().expect("lock").is_none());
755
756        let snapshot = client
757            .graphql_subscription("subscription { ticker }", None, None)
758            .await
759            .expect("subscription snapshot");
760
761        assert!(snapshot.acknowledged);
762        assert_eq!(snapshot.errors, Vec::<Value>::new());
763        assert_eq!(snapshot.event, Some(serde_json::json!({"data": {"ticker": "AAPL"}})));
764        assert!(snapshot.complete_received);
765        assert!(client.http_server.lock().expect("lock").is_some());
766    }
767
768    #[tokio::test]
769    async fn graphql_subscription_surfaces_connection_error() {
770        let app = Router::new().route(
771            "/graphql",
772            get(|ws: WebSocketUpgrade| async move {
773                ws.on_upgrade(|mut socket| async move {
774                    while let Some(result) = socket.recv().await {
775                        let Ok(Message::Text(text)) = result else {
776                            continue;
777                        };
778                        let Ok(message): Result<Value, _> = serde_json::from_str(&text) else {
779                            continue;
780                        };
781
782                        if message.get("type").and_then(Value::as_str) == Some("connection_init") {
783                            let _ = socket
784                                .send(Message::Text(
785                                    serde_json::json!({
786                                        "type": "connection_error",
787                                        "payload": {"message": "not authorized"},
788                                    })
789                                    .to_string()
790                                    .into(),
791                                ))
792                                .await;
793                            break;
794                        }
795                    }
796                })
797            }),
798        );
799
800        let client = TestClient::from_router(app).expect("client");
801        assert!(client.http_server.lock().expect("lock").is_none());
802
803        let error = client
804            .graphql_subscription("subscription { privateFeed }", None, None)
805            .await
806            .expect_err("expected connection error");
807
808        assert!(error.to_string().contains("connection_error"));
809        assert!(client.http_server.lock().expect("lock").is_some());
810    }
811
812    #[tokio::test]
813    async fn http_requests_do_not_initialize_socket_transport() {
814        let app = Router::new().route("/health", get(|| async { "ok" }));
815
816        let client = TestClient::from_router(app).expect("client");
817        assert!(client.http_server.lock().expect("lock").is_none());
818
819        let snapshot = client.get("/health", None, None).await.expect("response snapshot");
820
821        assert_eq!(snapshot.status, 200);
822        assert_eq!(snapshot.text().expect("body"), "ok");
823        assert!(client.http_server.lock().expect("lock").is_none());
824    }
825}