1use super::{ResponseSnapshot, SnapshotError, snapshot_response};
11use axum::http::{HeaderName, HeaderValue, Method};
12use axum_test::TestServer;
13use bytes::Bytes;
14use serde_json::Value;
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::time::timeout;
18use urlencoding::encode;
19
/// Optional multipart body for [`TestClient::post`]: a list of text form fields (key/value
/// pairs) plus the file parts, both passed to `super::build_multipart_body`. `None` means the
/// request is not sent as multipart.
type MultipartPayload = Option<(Vec<(String, String)>, Vec<super::MultipartFilePart>)>;
/// Longest time to wait for any single GraphQL WebSocket protocol message.
const GRAPHQL_WS_MESSAGE_TIMEOUT: Duration = Duration::from_millis(2_000);
/// Upper bound on protocol messages read while waiting for an ack or a subscription result.
const GRAPHQL_WS_MAX_CONTROL_MESSAGES: usize = 32;
23
/// Outcome of one GraphQL subscription round-trip over the `graphql-transport-ws` protocol,
/// as returned by `TestClient::graphql_subscription_at`.
#[derive(Debug, Clone, PartialEq)]
pub struct GraphQLSubscriptionSnapshot {
    /// Id the client used for its `subscribe` message.
    pub operation_id: String,
    /// Whether the server sent `connection_ack`. `graphql_subscription_at` only builds a
    /// snapshot after the ack succeeded, so it always sets this to `true`.
    pub acknowledged: bool,
    /// `payload` of the first `next` message for the operation, if one arrived.
    pub event: Option<Value>,
    /// Payloads of `error` messages received for the subscription (the whole message when it
    /// had no `payload`).
    pub errors: Vec<Value>,
    /// Whether a `complete` message for the operation was received from the server.
    pub complete_received: bool,
}
38
39pub struct TestClient {
46 server: Arc<TestServer>,
47}
48
49impl TestClient {
50 pub fn from_router(router: axum::Router) -> Result<Self, String> {
52 let server = if tokio::runtime::Handle::try_current().is_ok() {
53 TestServer::builder()
54 .http_transport()
55 .try_build(router)
56 .map_err(|e| format!("Failed to create test server: {}", e))?
57 } else {
58 TestServer::try_new(router).map_err(|e| format!("Failed to create test server: {}", e))?
59 };
60
61 Ok(Self {
62 server: Arc::new(server),
63 })
64 }
65
66 pub fn server(&self) -> &TestServer {
68 &self.server
69 }
70
71 pub async fn get(
73 &self,
74 path: &str,
75 query_params: Option<Vec<(String, String)>>,
76 headers: Option<Vec<(String, String)>>,
77 ) -> Result<ResponseSnapshot, SnapshotError> {
78 let full_path = build_full_path(path, query_params.as_deref());
79 let mut request = self.server.get(&full_path);
80
81 if let Some(headers_vec) = headers {
82 request = self.add_headers(request, headers_vec)?;
83 }
84
85 let response = request.await;
86 snapshot_response(response).await
87 }
88
89 pub async fn post(
91 &self,
92 path: &str,
93 json: Option<Value>,
94 form_data: Option<Vec<(String, String)>>,
95 multipart: MultipartPayload,
96 query_params: Option<Vec<(String, String)>>,
97 headers: Option<Vec<(String, String)>>,
98 ) -> Result<ResponseSnapshot, SnapshotError> {
99 let full_path = build_full_path(path, query_params.as_deref());
100 let mut request = self.server.post(&full_path);
101
102 if let Some(headers_vec) = headers {
103 request = self.add_headers(request, headers_vec)?;
104 }
105
106 if let Some((form_fields, files)) = multipart {
107 let (body, boundary) = super::build_multipart_body(&form_fields, &files);
108 let content_type = format!("multipart/form-data; boundary={}", boundary);
109 request = request.add_header("content-type", &content_type);
110 request = request.bytes(Bytes::from(body));
111 } else if let Some(form_fields) = form_data {
112 let fields_value = serde_json::to_value(&form_fields)
113 .map_err(|e| SnapshotError::Decompression(format!("Failed to serialize form fields: {}", e)))?;
114 let encoded = super::encode_urlencoded_body(&fields_value)
115 .map_err(|e| SnapshotError::Decompression(format!("Form encoding failed: {}", e)))?;
116 request = request.add_header("content-type", "application/x-www-form-urlencoded");
117 request = request.bytes(Bytes::from(encoded));
118 } else if let Some(json_value) = json {
119 request = request.json(&json_value);
120 }
121
122 let response = request.await;
123 snapshot_response(response).await
124 }
125
126 pub async fn request_raw(
128 &self,
129 method: Method,
130 path: &str,
131 body: Bytes,
132 query_params: Option<Vec<(String, String)>>,
133 headers: Option<Vec<(String, String)>>,
134 ) -> Result<ResponseSnapshot, SnapshotError> {
135 let full_path = build_full_path(path, query_params.as_deref());
136 let mut request = self.server.method(method, &full_path);
137
138 if let Some(headers_vec) = headers {
139 request = self.add_headers(request, headers_vec)?;
140 }
141
142 request = request.bytes(body);
143 let response = request.await;
144 snapshot_response(response).await
145 }
146
147 pub async fn put(
149 &self,
150 path: &str,
151 json: Option<Value>,
152 query_params: Option<Vec<(String, String)>>,
153 headers: Option<Vec<(String, String)>>,
154 ) -> Result<ResponseSnapshot, SnapshotError> {
155 let full_path = build_full_path(path, query_params.as_deref());
156 let mut request = self.server.put(&full_path);
157
158 if let Some(headers_vec) = headers {
159 request = self.add_headers(request, headers_vec)?;
160 }
161
162 if let Some(json_value) = json {
163 request = request.json(&json_value);
164 }
165
166 let response = request.await;
167 snapshot_response(response).await
168 }
169
170 pub async fn patch(
172 &self,
173 path: &str,
174 json: Option<Value>,
175 query_params: Option<Vec<(String, String)>>,
176 headers: Option<Vec<(String, String)>>,
177 ) -> Result<ResponseSnapshot, SnapshotError> {
178 let full_path = build_full_path(path, query_params.as_deref());
179 let mut request = self.server.patch(&full_path);
180
181 if let Some(headers_vec) = headers {
182 request = self.add_headers(request, headers_vec)?;
183 }
184
185 if let Some(json_value) = json {
186 request = request.json(&json_value);
187 }
188
189 let response = request.await;
190 snapshot_response(response).await
191 }
192
193 pub async fn delete(
195 &self,
196 path: &str,
197 query_params: Option<Vec<(String, String)>>,
198 headers: Option<Vec<(String, String)>>,
199 ) -> Result<ResponseSnapshot, SnapshotError> {
200 let full_path = build_full_path(path, query_params.as_deref());
201 let mut request = self.server.delete(&full_path);
202
203 if let Some(headers_vec) = headers {
204 request = self.add_headers(request, headers_vec)?;
205 }
206
207 let response = request.await;
208 snapshot_response(response).await
209 }
210
211 pub async fn options(
213 &self,
214 path: &str,
215 query_params: Option<Vec<(String, String)>>,
216 headers: Option<Vec<(String, String)>>,
217 ) -> Result<ResponseSnapshot, SnapshotError> {
218 let full_path = build_full_path(path, query_params.as_deref());
219 let mut request = self.server.method(Method::OPTIONS, &full_path);
220
221 if let Some(headers_vec) = headers {
222 request = self.add_headers(request, headers_vec)?;
223 }
224
225 let response = request.await;
226 snapshot_response(response).await
227 }
228
229 pub async fn head(
231 &self,
232 path: &str,
233 query_params: Option<Vec<(String, String)>>,
234 headers: Option<Vec<(String, String)>>,
235 ) -> Result<ResponseSnapshot, SnapshotError> {
236 let full_path = build_full_path(path, query_params.as_deref());
237 let mut request = self.server.method(Method::HEAD, &full_path);
238
239 if let Some(headers_vec) = headers {
240 request = self.add_headers(request, headers_vec)?;
241 }
242
243 let response = request.await;
244 snapshot_response(response).await
245 }
246
247 pub async fn trace(
249 &self,
250 path: &str,
251 query_params: Option<Vec<(String, String)>>,
252 headers: Option<Vec<(String, String)>>,
253 ) -> Result<ResponseSnapshot, SnapshotError> {
254 let full_path = build_full_path(path, query_params.as_deref());
255 let mut request = self.server.method(Method::TRACE, &full_path);
256
257 if let Some(headers_vec) = headers {
258 request = self.add_headers(request, headers_vec)?;
259 }
260
261 let response = request.await;
262 snapshot_response(response).await
263 }
264
265 pub async fn graphql_at(
267 &self,
268 endpoint: &str,
269 query: &str,
270 variables: Option<Value>,
271 operation_name: Option<&str>,
272 ) -> Result<ResponseSnapshot, SnapshotError> {
273 let body = build_graphql_body(query, variables, operation_name);
274 self.post(endpoint, Some(body), None, None, None, None).await
275 }
276
277 pub async fn graphql(
279 &self,
280 query: &str,
281 variables: Option<Value>,
282 operation_name: Option<&str>,
283 ) -> Result<ResponseSnapshot, SnapshotError> {
284 self.graphql_at("/graphql", query, variables, operation_name).await
285 }
286
287 pub async fn graphql_with_status(
303 &self,
304 query: &str,
305 variables: Option<Value>,
306 operation_name: Option<&str>,
307 ) -> Result<(u16, ResponseSnapshot), SnapshotError> {
308 let snapshot = self.graphql(query, variables, operation_name).await?;
309 let status = snapshot.status;
310 Ok((status, snapshot))
311 }
312
313 pub async fn graphql_subscription_at(
318 &self,
319 endpoint: &str,
320 query: &str,
321 variables: Option<Value>,
322 operation_name: Option<&str>,
323 ) -> Result<GraphQLSubscriptionSnapshot, SnapshotError> {
324 let operation_id = "spikard-subscription-1".to_string();
325 let upgrade = self
326 .server
327 .get_websocket(endpoint)
328 .add_header("sec-websocket-protocol", "graphql-transport-ws")
329 .await;
330
331 if upgrade.status_code().as_u16() != 101 {
332 return Err(SnapshotError::Decompression(format!(
333 "GraphQL subscription upgrade failed with status {}",
334 upgrade.status_code()
335 )));
336 }
337
338 let mut websocket = super::WebSocketConnection::new(upgrade.into_websocket().await);
339
340 websocket
341 .send_json(&serde_json::json!({"type": "connection_init"}))
342 .await;
343 wait_for_graphql_ack(&mut websocket).await?;
344
345 websocket
346 .send_json(&serde_json::json!({
347 "id": operation_id,
348 "type": "subscribe",
349 "payload": build_graphql_body(query, variables, operation_name),
350 }))
351 .await;
352
353 let mut event = None;
354 let mut errors = Vec::new();
355 let mut complete_received = false;
356
357 for _ in 0..GRAPHQL_WS_MAX_CONTROL_MESSAGES {
358 let message = timeout(
359 GRAPHQL_WS_MESSAGE_TIMEOUT,
360 receive_graphql_protocol_message(&mut websocket),
361 )
362 .await
363 .map_err(|_| {
364 SnapshotError::Decompression("Timed out waiting for GraphQL subscription message".to_string())
365 })??;
366
367 let message_type = message.get("type").and_then(Value::as_str).unwrap_or_default();
368 match message_type {
369 "next" => {
370 if message
371 .get("id")
372 .and_then(Value::as_str)
373 .is_none_or(|id| id == operation_id)
374 {
375 event = message.get("payload").cloned();
376
377 websocket
378 .send_json(&serde_json::json!({
379 "id": operation_id,
380 "type": "complete",
381 }))
382 .await;
383
384 if let Ok(next_message) = timeout(
385 GRAPHQL_WS_MESSAGE_TIMEOUT,
386 receive_graphql_protocol_message(&mut websocket),
387 )
388 .await
389 && let Ok(next_message) = next_message
390 && next_message.get("type").and_then(Value::as_str) == Some("complete")
391 && next_message
392 .get("id")
393 .and_then(Value::as_str)
394 .is_none_or(|id| id == operation_id)
395 {
396 complete_received = true;
397 }
398 break;
399 }
400 }
401 "error" => {
402 errors.push(message.get("payload").cloned().unwrap_or(message));
403 break;
404 }
405 "complete" => {
406 if message
407 .get("id")
408 .and_then(Value::as_str)
409 .is_none_or(|id| id == operation_id)
410 {
411 complete_received = true;
412 break;
413 }
414 }
415 "ping" => {
416 let mut pong = serde_json::json!({"type": "pong"});
417 if let Some(payload) = message.get("payload") {
418 pong["payload"] = payload.clone();
419 }
420 websocket.send_json(&pong).await;
421 }
422 "pong" => {}
423 _ => {}
424 }
425 }
426
427 websocket.close().await;
428
429 if event.is_none() && errors.is_empty() && !complete_received {
430 return Err(SnapshotError::Decompression(
431 "No GraphQL subscription event received before timeout".to_string(),
432 ));
433 }
434
435 Ok(GraphQLSubscriptionSnapshot {
436 operation_id,
437 acknowledged: true,
438 event,
439 errors,
440 complete_received,
441 })
442 }
443
444 pub async fn graphql_subscription(
448 &self,
449 query: &str,
450 variables: Option<Value>,
451 operation_name: Option<&str>,
452 ) -> Result<GraphQLSubscriptionSnapshot, SnapshotError> {
453 self.graphql_subscription_at("/graphql", query, variables, operation_name)
454 .await
455 }
456
457 fn add_headers(
459 &self,
460 mut request: axum_test::TestRequest,
461 headers: Vec<(String, String)>,
462 ) -> Result<axum_test::TestRequest, SnapshotError> {
463 for (key, value) in headers {
464 let header_name = HeaderName::from_bytes(key.as_bytes())
465 .map_err(|e| SnapshotError::InvalidHeader(format!("Invalid header name: {}", e)))?;
466 let header_value = HeaderValue::from_str(&value)
467 .map_err(|e| SnapshotError::InvalidHeader(format!("Invalid header value: {}", e)))?;
468 request = request.add_header(header_name, header_value);
469 }
470 Ok(request)
471 }
472}
473
474async fn wait_for_graphql_ack(websocket: &mut super::WebSocketConnection) -> Result<(), SnapshotError> {
475 for _ in 0..GRAPHQL_WS_MAX_CONTROL_MESSAGES {
476 let message = timeout(GRAPHQL_WS_MESSAGE_TIMEOUT, receive_graphql_protocol_message(websocket))
477 .await
478 .map_err(|_| SnapshotError::Decompression("Timed out waiting for GraphQL connection_ack".to_string()))??;
479
480 match message.get("type").and_then(Value::as_str).unwrap_or_default() {
481 "connection_ack" => return Ok(()),
482 "ping" => {
483 let mut pong = serde_json::json!({"type": "pong"});
484 if let Some(payload) = message.get("payload") {
485 pong["payload"] = payload.clone();
486 }
487 websocket.send_json(&pong).await;
488 }
489 "connection_error" | "error" => {
490 return Err(SnapshotError::Decompression(format!(
491 "GraphQL subscription rejected during init: {}",
492 message
493 )));
494 }
495 _ => {}
496 }
497 }
498
499 Err(SnapshotError::Decompression(
500 "No GraphQL connection_ack received".to_string(),
501 ))
502}
503
504async fn receive_graphql_protocol_message(websocket: &mut super::WebSocketConnection) -> Result<Value, SnapshotError> {
505 loop {
506 match websocket.receive_message().await {
507 super::WebSocketMessage::Text(text) => {
508 return serde_json::from_str::<Value>(&text).map_err(|e| {
509 SnapshotError::Decompression(format!("Failed to parse GraphQL WebSocket message as JSON: {}", e))
510 });
511 }
512 super::WebSocketMessage::Binary(bytes) => {
513 return serde_json::from_slice::<Value>(&bytes).map_err(|e| {
514 SnapshotError::Decompression(format!(
515 "Failed to parse GraphQL binary WebSocket message as JSON: {}",
516 e
517 ))
518 });
519 }
520 super::WebSocketMessage::Ping(_) | super::WebSocketMessage::Pong(_) => continue,
521 super::WebSocketMessage::Close(reason) => {
522 return Err(SnapshotError::Decompression(format!(
523 "GraphQL WebSocket connection closed before response: {:?}",
524 reason
525 )));
526 }
527 }
528 }
529}
530
531pub fn build_graphql_body(query: &str, variables: Option<Value>, operation_name: Option<&str>) -> Value {
533 let mut body = serde_json::json!({ "query": query });
534 if let Some(vars) = variables {
535 body["variables"] = vars;
536 }
537 if let Some(op_name) = operation_name {
538 body["operationName"] = Value::String(op_name.to_string());
539 }
540 body
541}
542
543fn build_full_path(path: &str, query_params: Option<&[(String, String)]>) -> String {
545 match query_params {
546 None | Some(&[]) => path.to_string(),
547 Some(params) => {
548 let query_string: Vec<String> = params
549 .iter()
550 .map(|(k, v)| format!("{}={}", encode(k), encode(v)))
551 .collect();
552
553 if path.contains('?') {
554 format!("{}&{}", path, query_string.join("&"))
555 } else {
556 format!("{}?{}", path, query_string.join("&"))
557 }
558 }
559 }
560}
561
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        Router,
        extract::ws::{Message, WebSocketUpgrade},
        routing::{get, post},
    };

    #[test]
    fn build_full_path_no_params() {
        let path = "/users";
        assert_eq!(build_full_path(path, None), "/users");
        assert_eq!(build_full_path(path, Some(&[])), "/users");
    }

    #[test]
    fn build_full_path_with_params() {
        let path = "/users";
        let params = vec![
            ("id".to_string(), "123".to_string()),
            ("name".to_string(), "test user".to_string()),
        ];
        let result = build_full_path(path, Some(&params));
        assert!(result.starts_with("/users?"));
        assert!(result.contains("id=123"));
        assert!(result.contains("name=test%20user"));
    }

    #[test]
    fn build_full_path_existing_query() {
        let path = "/users?active=true";
        let params = vec![("id".to_string(), "123".to_string())];
        let result = build_full_path(path, Some(&params));
        assert_eq!(result, "/users?active=true&id=123");
    }

    #[test]
    fn test_graphql_query_builder() {
        // Exercise the real builder instead of re-implementing its logic inline.
        let query = "{ users { id name } }";
        let body = build_graphql_body(query, Some(serde_json::json!({ "limit": 10 })), Some("GetUsers"));

        assert_eq!(body["query"], query);
        assert_eq!(body["variables"]["limit"], 10);
        assert_eq!(body["operationName"], "GetUsers");
    }

    #[tokio::test]
    async fn test_graphql_with_status_method() {
        // A non-200 status proves the status is taken from the real response.
        let app = Router::new().route(
            "/graphql",
            post(|| async { (axum::http::StatusCode::BAD_REQUEST, "bad graphql request") }),
        );

        let client = TestClient::from_router(app).expect("client");
        let (status, snapshot) = client
            .graphql_with_status("query { hello }", None, None)
            .await
            .expect("graphql response");

        assert_eq!(status, 400);
        assert_eq!(snapshot.status, status);
    }

    #[test]
    fn test_build_graphql_body_basic() {
        let query = "{ users { id name } }";
        let body = build_graphql_body(query, None, None);

        // Absent optional fields are omitted entirely, not serialized as null.
        assert_eq!(body, serde_json::json!({ "query": query }));
    }

    #[test]
    fn test_build_graphql_body_with_variables() {
        let query = "query GetUser($id: ID!) { user(id: $id) { name } }";
        let variables = Some(serde_json::json!({ "id": "123" }));
        let body = build_graphql_body(query, variables, None);

        assert_eq!(body["query"], query);
        assert_eq!(body["variables"]["id"], "123");
    }

    #[test]
    fn test_build_graphql_body_with_operation_name() {
        let query = "query GetUsers { users { id } }";
        let op_name = Some("GetUsers");
        let body = build_graphql_body(query, None, op_name);

        assert_eq!(body["query"], query);
        assert_eq!(body["operationName"], "GetUsers");
    }

    #[test]
    fn test_build_graphql_body_all_fields() {
        let query = "mutation CreateUser($name: String!) { createUser(name: $name) { id } }";
        let variables = Some(serde_json::json!({ "name": "Alice" }));
        let op_name = Some("CreateUser");
        let body = build_graphql_body(query, variables, op_name);

        assert_eq!(body["query"], query);
        assert_eq!(body["variables"]["name"], "Alice");
        assert_eq!(body["operationName"], "CreateUser");
    }

    #[tokio::test]
    async fn graphql_subscription_returns_first_event_and_completes() {
        let app = Router::new().route(
            "/graphql",
            get(|ws: WebSocketUpgrade| async move {
                ws.on_upgrade(|mut socket| async move {
                    while let Some(result) = socket.recv().await {
                        let Ok(Message::Text(text)) = result else {
                            continue;
                        };
                        let Ok(message): Result<Value, _> = serde_json::from_str(&text) else {
                            continue;
                        };

                        match message.get("type").and_then(Value::as_str) {
                            Some("connection_init") => {
                                let _ = socket
                                    .send(Message::Text(
                                        serde_json::json!({"type":"connection_ack"}).to_string().into(),
                                    ))
                                    .await;
                            }
                            Some("subscribe") => {
                                let id = message.get("id").and_then(Value::as_str).unwrap_or("1");
                                let _ = socket
                                    .send(Message::Text(
                                        serde_json::json!({
                                            "id": id,
                                            "type": "next",
                                            "payload": {"data": {"ticker": "AAPL"}},
                                        })
                                        .to_string()
                                        .into(),
                                    ))
                                    .await;

                                if let Some(Ok(Message::Text(complete_text))) = socket.recv().await {
                                    let Ok(complete_message): Result<Value, _> = serde_json::from_str(&complete_text)
                                    else {
                                        break;
                                    };
                                    if complete_message.get("type").and_then(Value::as_str) == Some("complete") {
                                        let _ = socket
                                            .send(Message::Text(
                                                serde_json::json!({"id": id, "type":"complete"}).to_string().into(),
                                            ))
                                            .await;
                                    }
                                }
                                break;
                            }
                            _ => {}
                        }
                    }
                })
            }),
        );

        let client = TestClient::from_router(app).expect("client");
        let snapshot = client
            .graphql_subscription("subscription { ticker }", None, None)
            .await
            .expect("subscription snapshot");

        assert!(snapshot.acknowledged);
        assert_eq!(snapshot.errors, Vec::<Value>::new());
        assert_eq!(snapshot.event, Some(serde_json::json!({"data": {"ticker": "AAPL"}})));
        assert!(snapshot.complete_received);
    }

    #[tokio::test]
    async fn graphql_subscription_surfaces_connection_error() {
        let app = Router::new().route(
            "/graphql",
            get(|ws: WebSocketUpgrade| async move {
                ws.on_upgrade(|mut socket| async move {
                    while let Some(result) = socket.recv().await {
                        let Ok(Message::Text(text)) = result else {
                            continue;
                        };
                        let Ok(message): Result<Value, _> = serde_json::from_str(&text) else {
                            continue;
                        };

                        if message.get("type").and_then(Value::as_str) == Some("connection_init") {
                            let _ = socket
                                .send(Message::Text(
                                    serde_json::json!({
                                        "type": "connection_error",
                                        "payload": {"message": "not authorized"},
                                    })
                                    .to_string()
                                    .into(),
                                ))
                                .await;
                            break;
                        }
                    }
                })
            }),
        );

        let client = TestClient::from_router(app).expect("client");
        let error = client
            .graphql_subscription("subscription { privateFeed }", None, None)
            .await
            .expect_err("expected connection error");

        assert!(error.to_string().contains("connection_error"));
    }
}