spark_graphql_client/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::collections::HashMap;
4
5use code_status_macros::{needs, untested};
6use eyre::eyre;
7use graphql_utils::compress_payload_zstd;
8use reqwest::{Client, RequestBuilder, Response, StatusCode};
9use serde::de::DeserializeOwned;
10use serde_json::{Value, from_slice, from_value};
11use url::Url;
12use zstd::decode_all;
13
14mod graphql_utils;
15
/// Specifies the type of payload compression to use for outgoing GraphQL requests.
///
/// This is a fieldless enum, so it is `Copy` and cheap to pass by value.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum GraphQLPayloadCompression {
    /// Compress the request payload using the Zstandard (zstd) algorithm.
    /// This is the default if no compression type is specified.
    ///
    /// The zstd helper reports per request whether compression was actually
    /// applied, so a payload may still be sent uncompressed under this strategy.
    #[default]
    Zstd,

    /// Send the request payload without any compression.
    NeverCompress,
}
27
28impl GraphQLPayloadCompression {
29    /// Returns the corresponding `Accept-Encoding` header value for the compression type.
30    fn apply_encoding_header(&self) -> &str {
31        match self {
32            Self::Zstd => "zstd",
33            Self::NeverCompress => "",
34        }
35    }
36}
37
38/// A client for executing GraphQL operations (queries and mutations).
39///
40/// This client handles constructing the HTTP request, sending it to the specified
41/// GraphQL endpoint, and processing the response.
42#[derive(Clone)]
43pub struct GraphqlClient<'a> {
44    /// The base URL of the GraphQL endpoint.
45    base_url: Url,
46    /// The underlying `reqwest::Client` used for making HTTP requests.
47    http_client: Client,
48    /// The User-Agent string sent with each request.
49    user_agent: &'a str,
50    /// The compression strategy to use for request payloads.
51    payload_compression: GraphQLPayloadCompression,
52}
53
54impl<'a> GraphqlClient<'a> {
55    /// Creates a new `GraphqlClient` instance.
56    ///
57    /// # Arguments
58    ///
59    /// * `base_url` - The base URL of the GraphQL API endpoint (e.g., "https://api.example.com/graphql").
60    /// * `user_agent` - The User-Agent string to identify this client.
61    /// * `payload_compression` - An optional compression strategy. Defaults to `GraphQLPayloadCompression::Zstd` if `None`.
62    ///
63    /// # Panics
64    ///
65    /// Panics if the provided `base_url` is invalid.
66    pub fn new(
67        base_url: &str,
68        user_agent: &'a str,
69        payload_compression: Option<GraphQLPayloadCompression>,
70    ) -> Self {
71        let base_url = Url::parse(base_url)
72            .expect("Provided invalid base URL for the GraphQL client. Please verify your URL.");
73
74        Self {
75            base_url,
76            http_client: Client::new(),
77            user_agent,
78            payload_compression: payload_compression.unwrap_or_default(),
79        }
80    }
81
82    /// Executes a GraphQL query or mutation.
83    ///
84    /// This method builds the request payload, applies compression if configured,
85    /// sends the request, handles potential response compression, and deserializes
86    /// the `data` field of the response into the specified type `T`.
87    ///
88    /// # Arguments
89    ///
90    /// * `query` - The GraphQL query or mutation string.
91    ///           This should be a raw string literal (e.g., `r#"query { ... }"#`)
92    ///           containing the full text of the operation. This client does not
93    ///           use pre-generated query types.
94    /// * `variables` - A map of variables to be included in the request.
95    /// * `identity_public_key` - The public key identifying the client/user.
96    /// * `session_token` - An optional session token for authentication.
97    ///
98    /// # Type Parameters
99    ///
100    /// * `T` - The type to deserialize the `data` field of the GraphQL response into.
101    ///         It must implement `serde::de::DeserializeOwned`.
102    ///
103    /// # Returns
104    ///
105    /// Returns a `Result` containing the deserialized data `T` on success, or an `eyre::Report`
106    /// detailing the error on failure (e.g., network issues, non-200 status code,
107    /// deserialization errors, missing `data` field).
108    #[needs("retry logic on network errors.")]
109    pub async fn execute_graphql_request<T: DeserializeOwned>(
110        &self,
111        query: &str,
112        variables: &HashMap<String, Value>,
113        identity_public_key: &[u8],
114        session_token: Option<&str>,
115    ) -> eyre::Result<T> {
116        let operation_name = graphql_utils::extract_operation_name(query)?;
117
118        let payload = graphql_utils::build_graphql_request(&operation_name, query, variables);
119        let payload_bytes = serde_json::to_string(&payload)?.into_bytes();
120        let (payload_bytes, payload_is_compressed) = self.compress_payload(&payload_bytes)?;
121
122        let request = self.construct_request(
123            &operation_name,
124            identity_public_key,
125            session_token,
126            payload_is_compressed,
127        );
128
129        let response = request.body(payload_bytes).send().await;
130        let response = if let Ok(response) = response {
131            response
132        } else {
133            return Err(eyre!("GraphQL request failed: {:?}", response.err()));
134        };
135        let status = response.status();
136
137        if status != StatusCode::OK {
138            return Err(eyre!("GraphQL request failed with status: {}", status));
139        }
140
141        let response_bytes = self.decompress_response(response).await?;
142        let response_variable_map: HashMap<String, Value> = from_slice(&response_bytes)?;
143        let response_data = response_variable_map
144            .get("data")
145            .ok_or(eyre!("No data found in the response"))?;
146
147        let response = from_value(response_data.clone())?;
148        Ok(response)
149    }
150
151    /// Compresses the request payload based on the client's configuration.
152    ///
153    /// # Arguments
154    ///
155    /// * `payload_bytes` - The raw bytes of the request payload.
156    ///
157    /// # Returns
158    ///
159    /// A tuple containing:
160    ///  - The potentially compressed payload bytes.
161    ///  - A boolean indicating whether compression was applied.
162    /// Returns an error if compression fails.
163    #[untested]
164    fn compress_payload(&self, payload_bytes: &[u8]) -> eyre::Result<(Vec<u8>, bool)> {
165        let (payload_bytes, payload_is_compressed) = match self.payload_compression {
166            GraphQLPayloadCompression::Zstd => compress_payload_zstd(&payload_bytes)?,
167            GraphQLPayloadCompression::NeverCompress => (payload_bytes.to_vec(), false),
168        };
169
170        Ok((payload_bytes, payload_is_compressed))
171    }
172
173    /// Decompresses the response body if it's compressed (currently supports zstd).
174    ///
175    /// Checks the `Content-Encoding` header of the response.
176    ///
177    /// # Arguments
178    ///
179    /// * `graphql_response` - The `reqwest::Response` received from the server.
180    ///
181    /// # Returns
182    ///
183    /// Returns the decompressed response body as bytes. Returns an error if decompression fails
184    /// or an unsupported compression algorithm is used.
185    #[untested]
186    #[needs("a formatted matching for response encoding algorithms used.")]
187    async fn decompress_response(&self, graphql_response: Response) -> eyre::Result<Vec<u8>> {
188        // Clone the header value to avoid borrowing graphql_response
189        let encoding = graphql_response
190            .headers()
191            .get("Content-Encoding")
192            .and_then(|h| h.to_str().ok())
193            .map(|s| s.to_string()); // Clone the relevant part
194
195        let body_bytes = graphql_response.bytes().await?;
196
197        if let Some(encoding) = encoding {
198            if encoding == "zstd" {
199                Ok(decode_all(&*body_bytes)?)
200            } else {
201                Err(eyre!("Unsupported compression algorithm: {}", encoding))
202            }
203        } else {
204            Ok(body_bytes.to_vec())
205        }
206    }
207
208    /// Constructs the basic `reqwest::RequestBuilder` with necessary headers.
209    ///
210    /// Sets headers like `Content-Type`, `User-Agent`, `X-GraphQL-Operation`,
211    /// `Spark-Identity-Public-Key`, `Accept-Encoding` (if compression is used),
212    /// and `Spark-Session-Token` (if provided).
213    ///
214    /// # Arguments
215    ///
216    /// * `operation_name` - The extracted name of the GraphQL operation.
217    /// * `identity_public_key` - The public key identifying the client/user.
218    /// * `session_token` - An optional session token for authentication.
219    /// * `is_compressed` - Indicates if the payload being sent is compressed (affects `Accept-Encoding` header).
220    ///
221    /// # Returns
222    ///
223    /// A `reqwest::RequestBuilder` ready to have the body added and be sent.
224    #[untested]
225    #[needs("refactoring")]
226    fn construct_request(
227        &self,
228        operation_name: &str,
229        identity_public_key: &[u8],
230        session_token: Option<&str>,
231        is_compressed: bool,
232    ) -> RequestBuilder {
233        // Build the HTTP request
234        let mut request = self
235            .http_client
236            .post(self.base_url.clone())
237            .header("Content-Type", "application/json")
238            .header("User-Agent", self.user_agent)
239            .header("X-GraphQL-Operation", operation_name)
240            .header(
241                "Spark-Identity-Public-Key",
242                hex::encode(identity_public_key),
243            );
244
245        if is_compressed {
246            let applied_encoding = self.payload_compression.apply_encoding_header();
247            request = request.header("Accept-Encoding", applied_encoding);
248        }
249
250        if let Some(session_token) = session_token {
251            request = request.header("Spark-Session-Token", session_token);
252        }
253
254        request
255    }
256}
257
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use super::*;

    // Structs mirroring the expected GraphQL response structure.
    #[derive(Deserialize, Debug, PartialEq)]
    struct Character {
        name: String,
    }

    #[derive(Deserialize, Debug, PartialEq)]
    struct Characters {
        results: Vec<Character>,
    }

    #[derive(Deserialize, Debug, PartialEq)]
    struct CharacterData {
        characters: Characters,
    }

    /// Builds a client for offline tests; the endpoint is never contacted.
    fn offline_client(payload_compression: GraphQLPayloadCompression) -> GraphqlClient<'static> {
        GraphqlClient::new(
            "https://example.com/graphql",
            "test-agent/0.1.0",
            Some(payload_compression),
        )
    }

    #[test]
    fn encoding_header_tokens() {
        assert_eq!(GraphQLPayloadCompression::Zstd.apply_encoding_header(), "zstd");
        assert_eq!(GraphQLPayloadCompression::NeverCompress.apply_encoding_header(), "");
    }

    #[test]
    fn never_compress_passes_payload_through() {
        let client = offline_client(GraphQLPayloadCompression::NeverCompress);
        let payload = br#"{"query":"query Q { a }"}"#;

        let (bytes, compressed) = client
            .compress_payload(payload)
            .expect("uncompressed payloads cannot fail");

        assert_eq!(bytes, payload.to_vec());
        assert!(!compressed);
    }

    #[test]
    fn construct_request_sets_identity_operation_and_session_headers() {
        let client = offline_client(GraphQLPayloadCompression::Zstd);
        let request = client
            .construct_request("GetThing", &[0xab, 0xcd], Some("token-123"), true)
            .build()
            .expect("request should build");
        let headers = request.headers();

        assert_eq!(headers["Content-Type"], "application/json");
        assert_eq!(headers["User-Agent"], "test-agent/0.1.0");
        assert_eq!(headers["X-GraphQL-Operation"], "GetThing");
        assert_eq!(headers["Spark-Identity-Public-Key"], "abcd");
        assert_eq!(headers["Spark-Session-Token"], "token-123");
        assert_eq!(headers["Accept-Encoding"], "zstd");
    }

    #[test]
    fn construct_request_omits_optional_headers() {
        let client = offline_client(GraphQLPayloadCompression::NeverCompress);
        let request = client
            .construct_request("GetThing", b"", None, false)
            .build()
            .expect("request should build");
        let headers = request.headers();

        assert_eq!(headers["Spark-Identity-Public-Key"], "");
        assert!(!headers.contains_key("Spark-Session-Token"));
        assert!(!headers.contains_key("Accept-Encoding"));
        assert!(!headers.contains_key("Content-Encoding"));
    }

    /// End-to-end smoke test against a public GraphQL API; requires network access.
    #[tokio::test]
    async fn test_rick_and_morty_graphql_request() {
        let client = GraphqlClient::new(
            "https://rickandmortyapi.com/graphql",
            "test-agent/0.1.0",
            Some(GraphQLPayloadCompression::NeverCompress), // API might not support compression
        );

        // Simple query to get the name of the first character
        let query = r#"
            query GetCharacterName {
                characters(page: 1) {
                    results {
                        name
                    }
                }
            }
        "#;

        let variables = HashMap::new(); // No variables needed for this query
        let dummy_identity_public_key = b""; // Dummy key for testing
        let session_token = None;

        let data = client
            .execute_graphql_request::<CharacterData>(
                query,
                &variables,
                dummy_identity_public_key,
                session_token,
            )
            .await
            .unwrap_or_else(|err| panic!("Request failed: {:?}", err));

        // A public API's contents may change, so only check that results exist.
        assert!(!data.characters.results.is_empty(), "No characters found");
        println!("First character name: {}", data.characters.results[0].name);
    }
}
323}