spark_rust/wallet/
graphql.rs

1// Copyright ©, 2023-present, Lightspark Group, Inc. - All Rights Reserved
2
3//! GraphQL client for Lightspark API communication
4//!
5//! Provides a client for executing GraphQL operations against the Lightspark API,
6//! with features for request compression, authentication, and error handling.
7
8use std::collections::HashMap;
9
10use bitcoin::secp256k1::PublicKey;
11use once_cell::sync::Lazy;
12use regex::Regex;
13use reqwest::Client;
14use serde_json::{from_slice, to_vec, Value};
15use url::Url;
16use zstd::{decode_all, encode_all};
17
18use crate::{
19    constants::spark::LIGHTSPARK_SSP_ENDPOINT,
20    error::{IoError, NetworkError, SparkSdkError},
21};
22
23/// Client for making GraphQL requests to the Lightspark API
24#[derive(Clone)]
25pub struct GraphqlClient {
26    base_url: Url,
27    http_client: Client,
28}
29
30impl GraphqlClient {
31    /// Creates a new GraphQL client
32    ///
33    /// Uses LIGHTSPARK_SSP_ENDPOINT environment variable if available,
34    /// otherwise falls back to the default endpoint
35    pub fn new() -> Result<Self, SparkSdkError> {
36        let base_url = validate_base_url(
37            std::env::var("LIGHTSPARK_SSP_ENDPOINT")
38                .unwrap_or(LIGHTSPARK_SSP_ENDPOINT.to_string())
39                .as_str(),
40        )?;
41
42        Ok(GraphqlClient {
43            base_url,
44            http_client: Client::new(),
45        })
46    }
47
48    /// Executes a GraphQL query or mutation
49    ///
50    /// Handles payload compression, authentication, and response processing.
51    /// Large payloads (>1KB) are automatically compressed using zstd.
52    ///
53    /// # Arguments
54    /// * `identity_public_key` - Public key used for authentication
55    /// * `query` - GraphQL operation (must include operation name)
56    /// * `variables` - Variables to include with the GraphQL operation
57    ///
58    /// # Returns
59    /// * `Ok(HashMap<String, Value>)` - Data returned from the GraphQL operation
60    /// * `Err(SparkSdkError)` - If the request or response handling fails
61    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
62    pub async fn execute_graphql(
63        &self,
64        identity_public_key: &PublicKey,
65        query: &str,
66        variables: HashMap<String, Value>,
67    ) -> Result<HashMap<String, Value>, SparkSdkError> {
68        let operation_name = validate_operation(query)?;
69
70        // Prepare payload
71        let payload = serde_json::json!({
72            "operationName": operation_name,
73            "query": query,
74            "variables": variables,
75        });
76        let encoded_payload = to_vec(&payload).map_err(|_| {
77            SparkSdkError::from(NetworkError::InvalidGraphQLOperation(query.to_string()))
78        })?;
79
80        // Compress if payload > 1024 bytes
81        let (body, compressed) = if encoded_payload.len() > 1024 {
82            let compressed = encode_all(&encoded_payload[..], 0)
83                .map_err(|err| SparkSdkError::from(IoError::Io(err)))?;
84            (compressed, true)
85        } else {
86            (encoded_payload, false)
87        };
88
89        // Build HTTP request
90        let mut request = self
91            .http_client
92            .post(self.base_url.clone())
93            .header(
94                "Spark-Identity-Public-Key",
95                hex::encode(identity_public_key.serialize()),
96            )
97            .header("Content-Type", "application/json")
98            .header("Accept-Encoding", "zstd")
99            .header("X-GraphQL-Operation", operation_name)
100            .header("User-Agent", self.get_user_agent())
101            .header("X-Polarity-SDK", self.get_user_agent());
102
103        if compressed {
104            request = request.header("Content-Encoding", "zstd");
105        }
106
107        let response = request
108            .body(body)
109            .send()
110            .await
111            .map_err(|err| SparkSdkError::from(NetworkError::Http(err)))?;
112
113        // Handle response
114        let status = response.status();
115
116        if !status.is_success() {
117            return Err(SparkSdkError::from(NetworkError::GraphQLRequestFailed {
118                status_code: status.as_u16(),
119            }));
120        }
121
122        // Check headers before consuming response
123        let is_zstd_encoded = response
124            .headers()
125            .get("Content-Encoding")
126            .is_some_and(|v| v == "zstd");
127
128        // Now consume response to get the body
129        let mut data = response
130            .bytes()
131            .await
132            .map_err(|err| SparkSdkError::from(NetworkError::Http(err)))?
133            .to_vec();
134
135        // Decode if zstd encoded
136        if is_zstd_encoded {
137            data = decode_all(&data[..]).map_err(|err| SparkSdkError::from(IoError::Io(err)))?;
138        }
139
140        // Parse JSON response
141        let result: HashMap<String, Value> =
142            from_slice(&data).map_err(|err| SparkSdkError::from(IoError::SerdeJson(err)))?;
143        if let Some(errors) = result.get("errors") {
144            let err = errors
145                .as_array()
146                .and_then(|arr| arr.first())
147                .and_then(|v| v.as_object())
148                .ok_or("invalid error format")?;
149            let error_message = err
150                .get("message")
151                .and_then(|v| v.as_str())
152                .ok_or("missing error message")?
153                .to_string();
154
155            if let Some(extensions) = err.get("extensions").and_then(|v| v.as_object()) {
156                if extensions
157                    .get("error_name")
158                    .and_then(|v| v.as_str())
159                    .is_some()
160                {
161                    return Err(SparkSdkError::from(NetworkError::GraphQL(error_message)));
162                }
163            }
164            return Err(SparkSdkError::from(NetworkError::GraphQL(error_message)));
165        }
166
167        result
168            .get("data")
169            .and_then(|v| {
170                v.as_object()
171                    .map(|o| o.into_iter().map(|(k, v)| (k.clone(), v.clone())).collect())
172            })
173            .ok_or_else(|| "missing data field".into())
174    }
175
176    /// Returns the user agent string for requests
177    fn get_user_agent(&self) -> &str {
178        "spark"
179    }
180}
181
182/// Extracts the operation name from a GraphQL query or mutation
183///
184/// # Arguments
185/// * `query` - GraphQL operation string
186///
187/// # Returns
188/// * `Ok(String)` - The extracted operation name
189/// * `Err(SparkSdkError)` - If the operation name cannot be extracted
190fn validate_operation(query: &str) -> Result<String, SparkSdkError> {
191    static RE: Lazy<Regex> =
192        Lazy::new(|| Regex::new(r"(?i)\s*(?:query|mutation)\s+(?P<OperationName>\w+)").unwrap());
193
194    let captures =
195        RE.captures(query)
196            .ok_or(SparkSdkError::from(NetworkError::InvalidGraphQLOperation(
197                query.to_string(),
198            )))?;
199
200    captures
201        .name("OperationName")
202        .ok_or(SparkSdkError::from(NetworkError::InvalidGraphQLOperation(
203            query.to_string(),
204        )))
205        .map(|m| m.as_str().to_string())
206}
207
/// Checks whether a hostname refers to the local machine or a local-only domain.
///
/// Returns `true` for:
/// * `localhost`
/// * any loopback IP literal (IPv4 `127.0.0.0/8`, IPv6 `::1`), with or without the
///   surrounding brackets that `Url::host_str` adds to IPv6 literals
/// * names whose last dot-separated label is `local` or `internal`
///
/// It does not detect private-network IP ranges such as `10.0.0.0/8` or `192.168.0.0/16`.
///
/// # Arguments
/// * `hostname` - Hostname to check
fn is_localhost(hostname: &str) -> bool {
    if hostname == "localhost" {
        return true;
    }

    // IPv6 literals come from `Url::host_str` wrapped in brackets, e.g. "[::1]".
    let ip_literal = hostname
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(hostname);
    if let Ok(ip) = ip_literal.parse::<std::net::IpAddr>() {
        return ip.is_loopback();
    }

    hostname
        .split('.')
        .next_back()
        .is_some_and(|tld| tld == "internal" || tld == "local")
}
224
225/// Validates and parses a base URL for the GraphQL API
226///
227/// Ensures the URL uses HTTPS unless it's a localhost address
228///
229/// # Arguments
230/// * `base_url` - URL string to validate
231///
232/// # Returns
233/// * `Ok(Url)` - Parsed and validated URL
234/// * `Err(SparkSdkError)` - If the URL is invalid or uses an insecure scheme
235fn validate_base_url(base_url: &str) -> Result<Url, SparkSdkError> {
236    let url = Url::parse(base_url).map_err(|_| {
237        SparkSdkError::from(NetworkError::InvalidUrl {
238            url: base_url.to_string(),
239            details: None,
240        })
241    })?;
242
243    if url.scheme() != "https" && !url.host_str().is_some_and(is_localhost) {
244        return Err(SparkSdkError::from(NetworkError::InvalidUrl {
245            url: base_url.to_string(),
246            details: Some("Only HTTPS is supported for non-localhost".to_string()),
247        }));
248    }
249    Ok(url)
250}