spark_graphql_client/lib.rs
1#![doc = include_str!("../README.md")]
2
3use std::collections::HashMap;
4
5use code_status_macros::{needs, untested};
6use eyre::eyre;
7use graphql_utils::compress_payload_zstd;
8use reqwest::{Client, RequestBuilder, Response, StatusCode};
9use serde::de::DeserializeOwned;
10use serde_json::{Value, from_slice, from_value};
11use url::Url;
12use zstd::decode_all;
13
14mod graphql_utils;
15
/// Specifies the type of payload compression to use for outgoing GraphQL requests.
///
/// This is a plain fieldless enum, so it is `Copy` and `Hash`: it can be passed by value
/// and used as a map/set key without cloning.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum GraphQLPayloadCompression {
    /// Compress the request payload using the Zstandard (zstd) algorithm.
    /// This is the default if no compression type is specified.
    #[default]
    Zstd,

    /// Send the request payload without any compression.
    NeverCompress,
}
27
28impl GraphQLPayloadCompression {
29 /// Returns the corresponding `Accept-Encoding` header value for the compression type.
30 fn apply_encoding_header(&self) -> &str {
31 match self {
32 Self::Zstd => "zstd",
33 Self::NeverCompress => "",
34 }
35 }
36}
37
38/// A client for executing GraphQL operations (queries and mutations).
39///
40/// This client handles constructing the HTTP request, sending it to the specified
41/// GraphQL endpoint, and processing the response.
42#[derive(Clone)]
43pub struct GraphqlClient<'a> {
44 /// The base URL of the GraphQL endpoint.
45 base_url: Url,
46 /// The underlying `reqwest::Client` used for making HTTP requests.
47 http_client: Client,
48 /// The User-Agent string sent with each request.
49 user_agent: &'a str,
50 /// The compression strategy to use for request payloads.
51 payload_compression: GraphQLPayloadCompression,
52}
53
54impl<'a> GraphqlClient<'a> {
55 /// Creates a new `GraphqlClient` instance.
56 ///
57 /// # Arguments
58 ///
59 /// * `base_url` - The base URL of the GraphQL API endpoint (e.g., "https://api.example.com/graphql").
60 /// * `user_agent` - The User-Agent string to identify this client.
61 /// * `payload_compression` - An optional compression strategy. Defaults to `GraphQLPayloadCompression::Zstd` if `None`.
62 ///
63 /// # Panics
64 ///
65 /// Panics if the provided `base_url` is invalid.
66 pub fn new(
67 base_url: &str,
68 user_agent: &'a str,
69 payload_compression: Option<GraphQLPayloadCompression>,
70 ) -> Self {
71 let base_url = Url::parse(base_url)
72 .expect("Provided invalid base URL for the GraphQL client. Please verify your URL.");
73
74 Self {
75 base_url,
76 http_client: Client::new(),
77 user_agent,
78 payload_compression: payload_compression.unwrap_or_default(),
79 }
80 }
81
82 /// Executes a GraphQL query or mutation.
83 ///
84 /// This method builds the request payload, applies compression if configured,
85 /// sends the request, handles potential response compression, and deserializes
86 /// the `data` field of the response into the specified type `T`.
87 ///
88 /// # Arguments
89 ///
90 /// * `query` - The GraphQL query or mutation string.
91 /// This should be a raw string literal (e.g., `r#"query { ... }"#`)
92 /// containing the full text of the operation. This client does not
93 /// use pre-generated query types.
94 /// * `variables` - A map of variables to be included in the request.
95 /// * `identity_public_key` - The public key identifying the client/user.
96 /// * `session_token` - An optional session token for authentication.
97 ///
98 /// # Type Parameters
99 ///
100 /// * `T` - The type to deserialize the `data` field of the GraphQL response into.
101 /// It must implement `serde::de::DeserializeOwned`.
102 ///
103 /// # Returns
104 ///
105 /// Returns a `Result` containing the deserialized data `T` on success, or an `eyre::Report`
106 /// detailing the error on failure (e.g., network issues, non-200 status code,
107 /// deserialization errors, missing `data` field).
108 #[needs("retry logic on network errors.")]
109 pub async fn execute_graphql_request<T: DeserializeOwned>(
110 &self,
111 query: &str,
112 variables: &HashMap<String, Value>,
113 identity_public_key: &[u8],
114 session_token: Option<&str>,
115 ) -> eyre::Result<T> {
116 let operation_name = graphql_utils::extract_operation_name(query)?;
117
118 let payload = graphql_utils::build_graphql_request(&operation_name, query, variables);
119 let payload_bytes = serde_json::to_string(&payload)?.into_bytes();
120 let (payload_bytes, payload_is_compressed) = self.compress_payload(&payload_bytes)?;
121
122 let request = self.construct_request(
123 &operation_name,
124 identity_public_key,
125 session_token,
126 payload_is_compressed,
127 );
128
129 let response = request.body(payload_bytes).send().await;
130 let response = if let Ok(response) = response {
131 response
132 } else {
133 return Err(eyre!("GraphQL request failed: {:?}", response.err()));
134 };
135 let status = response.status();
136
137 if status != StatusCode::OK {
138 return Err(eyre!("GraphQL request failed with status: {}", status));
139 }
140
141 let response_bytes = self.decompress_response(response).await?;
142 let response_variable_map: HashMap<String, Value> = from_slice(&response_bytes)?;
143 let response_data = response_variable_map
144 .get("data")
145 .ok_or(eyre!("No data found in the response"))?;
146
147 let response = from_value(response_data.clone())?;
148 Ok(response)
149 }
150
151 /// Compresses the request payload based on the client's configuration.
152 ///
153 /// # Arguments
154 ///
155 /// * `payload_bytes` - The raw bytes of the request payload.
156 ///
157 /// # Returns
158 ///
159 /// A tuple containing:
160 /// - The potentially compressed payload bytes.
161 /// - A boolean indicating whether compression was applied.
162 /// Returns an error if compression fails.
163 #[untested]
164 fn compress_payload(&self, payload_bytes: &[u8]) -> eyre::Result<(Vec<u8>, bool)> {
165 let (payload_bytes, payload_is_compressed) = match self.payload_compression {
166 GraphQLPayloadCompression::Zstd => compress_payload_zstd(&payload_bytes)?,
167 GraphQLPayloadCompression::NeverCompress => (payload_bytes.to_vec(), false),
168 };
169
170 Ok((payload_bytes, payload_is_compressed))
171 }
172
173 /// Decompresses the response body if it's compressed (currently supports zstd).
174 ///
175 /// Checks the `Content-Encoding` header of the response.
176 ///
177 /// # Arguments
178 ///
179 /// * `graphql_response` - The `reqwest::Response` received from the server.
180 ///
181 /// # Returns
182 ///
183 /// Returns the decompressed response body as bytes. Returns an error if decompression fails
184 /// or an unsupported compression algorithm is used.
185 #[untested]
186 #[needs("a formatted matching for response encoding algorithms used.")]
187 async fn decompress_response(&self, graphql_response: Response) -> eyre::Result<Vec<u8>> {
188 // Clone the header value to avoid borrowing graphql_response
189 let encoding = graphql_response
190 .headers()
191 .get("Content-Encoding")
192 .and_then(|h| h.to_str().ok())
193 .map(|s| s.to_string()); // Clone the relevant part
194
195 let body_bytes = graphql_response.bytes().await?;
196
197 if let Some(encoding) = encoding {
198 if encoding == "zstd" {
199 Ok(decode_all(&*body_bytes)?)
200 } else {
201 Err(eyre!("Unsupported compression algorithm: {}", encoding))
202 }
203 } else {
204 Ok(body_bytes.to_vec())
205 }
206 }
207
208 /// Constructs the basic `reqwest::RequestBuilder` with necessary headers.
209 ///
210 /// Sets headers like `Content-Type`, `User-Agent`, `X-GraphQL-Operation`,
211 /// `Spark-Identity-Public-Key`, `Accept-Encoding` (if compression is used),
212 /// and `Spark-Session-Token` (if provided).
213 ///
214 /// # Arguments
215 ///
216 /// * `operation_name` - The extracted name of the GraphQL operation.
217 /// * `identity_public_key` - The public key identifying the client/user.
218 /// * `session_token` - An optional session token for authentication.
219 /// * `is_compressed` - Indicates if the payload being sent is compressed (affects `Accept-Encoding` header).
220 ///
221 /// # Returns
222 ///
223 /// A `reqwest::RequestBuilder` ready to have the body added and be sent.
224 #[untested]
225 #[needs("refactoring")]
226 fn construct_request(
227 &self,
228 operation_name: &str,
229 identity_public_key: &[u8],
230 session_token: Option<&str>,
231 is_compressed: bool,
232 ) -> RequestBuilder {
233 // Build the HTTP request
234 let mut request = self
235 .http_client
236 .post(self.base_url.clone())
237 .header("Content-Type", "application/json")
238 .header("User-Agent", self.user_agent)
239 .header("X-GraphQL-Operation", operation_name)
240 .header(
241 "Spark-Identity-Public-Key",
242 hex::encode(identity_public_key),
243 );
244
245 if is_compressed {
246 let applied_encoding = self.payload_compression.apply_encoding_header();
247 request = request.header("Accept-Encoding", applied_encoding);
248 }
249
250 if let Some(session_token) = session_token {
251 request = request.header("Spark-Session-Token", session_token);
252 }
253
254 request
255 }
256}
257
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use super::*;

    // Structs mirroring the expected shape of the Rick and Morty GraphQL response.
    #[derive(Deserialize, Debug, PartialEq)]
    struct Character {
        name: String,
    }

    #[derive(Deserialize, Debug, PartialEq)]
    struct Characters {
        results: Vec<Character>,
    }

    #[derive(Deserialize, Debug, PartialEq)]
    struct CharacterData {
        characters: Characters,
    }

    #[test]
    fn payload_compression_defaults_to_zstd() {
        assert_eq!(
            GraphQLPayloadCompression::default(),
            GraphQLPayloadCompression::Zstd
        );
    }

    #[test]
    fn encoding_header_matches_compression_type() {
        assert_eq!(GraphQLPayloadCompression::Zstd.apply_encoding_header(), "zstd");
        assert_eq!(
            GraphQLPayloadCompression::NeverCompress.apply_encoding_header(),
            ""
        );
    }

    #[test]
    #[should_panic(expected = "Provided invalid base URL")]
    fn new_panics_on_invalid_base_url() {
        let _ = GraphqlClient::new("not a url", "test-agent/0.1.0", None);
    }

    /// Live end-to-end request against a public GraphQL API.
    ///
    /// Ignored by default because it needs network access and a third-party service.
    /// Run it with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires network access to https://rickandmortyapi.com/graphql"]
    async fn test_rick_and_morty_graphql_request() {
        let client = GraphqlClient::new(
            "https://rickandmortyapi.com/graphql",
            "test-agent/0.1.0",
            // The public API is not known to accept zstd-compressed request bodies.
            Some(GraphQLPayloadCompression::NeverCompress),
        );

        // Simple query to get the name of the first character
        let query = r#"
        query GetCharacterName {
            characters(page: 1) {
                results {
                    name
                }
            }
        }
        "#;

        let variables = HashMap::new(); // No variables needed for this query
        let dummy_identity_public_key: &[u8] = b""; // Dummy key for testing

        let data = client
            .execute_graphql_request::<CharacterData>(
                query,
                &variables,
                dummy_identity_public_key,
                None,
            )
            .await
            .unwrap_or_else(|err| panic!("Request failed: {err:?}"));

        // Only check for a non-empty result: exact contents of a public API may change.
        let first = data
            .characters
            .results
            .first()
            .expect("No characters found");
        println!("First character name: {}", first.name);
    }
}