spark_rust/wallet/
graphql.rs

1// Copyright ©, 2023-present, Lightspark Group, Inc. - All Rights Reserved
2
3use std::collections::HashMap;
4
5use bitcoin::secp256k1::PublicKey;
6use once_cell::sync::Lazy;
7use regex::Regex;
8use reqwest::Client;
9use serde_json::{from_slice, to_vec, Value};
10use url::Url;
11use zstd::{decode_all, encode_all};
12
13use crate::{
14    constants::spark::LIGHTSPARK_SSP_ENDPOINT,
15    error::{IoError, NetworkError, SparkSdkError},
16};
17
18/// Requester struct for making GraphQL requests to Lightspark API.
19#[derive(Clone)]
20pub struct GraphqlClient {
21    base_url: Url,
22    identity_public_key: PublicKey,
23    http_client: Client,
24}
25
26impl GraphqlClient {
27    /// Creates a new Requester with a custom base URL.
28    pub fn with_base_url(
29        identity_public_key: PublicKey,
30        base_url: Option<String>,
31    ) -> Result<Self, SparkSdkError> {
32        let url = validate_base_url(&base_url.unwrap_or(LIGHTSPARK_SSP_ENDPOINT.to_string()))?;
33
34        Ok(GraphqlClient {
35            base_url: url,
36            identity_public_key,
37            http_client: Client::new(),
38        })
39    }
40
41    /// Executes a GraphQL query or mutation with the given context (timeout handled via Client).
42    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
43    pub async fn execute_graphql(
44        &self,
45        query: &str,
46        variables: HashMap<String, Value>,
47    ) -> Result<HashMap<String, Value>, SparkSdkError> {
48        let operation_name = validate_operation(query)?;
49
50        // Prepare payload
51        let payload = serde_json::json!({
52            "operationName": operation_name,
53            "query": query,
54            "variables": variables,
55        });
56        let encoded_payload = to_vec(&payload).map_err(|_| {
57            SparkSdkError::from(NetworkError::InvalidGraphQLOperation(query.to_string()))
58        })?;
59
60        // Compress if payload > 1024 bytes
61        let (body, compressed) = if encoded_payload.len() > 1024 {
62            let compressed = encode_all(&encoded_payload[..], 0)
63                .map_err(|err| SparkSdkError::from(IoError::Io(err)))?;
64            (compressed, true)
65        } else {
66            (encoded_payload, false)
67        };
68
69        // Build HTTP request
70        let mut request = self
71            .http_client
72            .post(self.base_url.clone())
73            .header(
74                "Spark-Identity-Public-Key",
75                hex::encode(self.identity_public_key.serialize()),
76            )
77            .header("Content-Type", "application/json")
78            .header("Accept-Encoding", "zstd")
79            .header("X-GraphQL-Operation", operation_name)
80            .header("User-Agent", self.get_user_agent())
81            .header("X-Polarity-SDK", self.get_user_agent());
82
83        if compressed {
84            request = request.header("Content-Encoding", "zstd");
85        }
86
87        let response = request
88            .body(body)
89            .send()
90            .await
91            .map_err(|err| SparkSdkError::from(NetworkError::Http(err)))?;
92
93        // Handle response
94        let status = response.status();
95
96        if !status.is_success() {
97            return Err(SparkSdkError::from(NetworkError::GraphQLRequestFailed {
98                status_code: status.as_u16(),
99            }));
100        }
101
102        // Check headers before consuming response
103        let is_zstd_encoded = response
104            .headers()
105            .get("Content-Encoding")
106            .is_some_and(|v| v == "zstd");
107
108        // Now consume response to get the body
109        let mut data = response
110            .bytes()
111            .await
112            .map_err(|err| SparkSdkError::from(NetworkError::Http(err)))?
113            .to_vec();
114
115        // Decode if zstd encoded
116        if is_zstd_encoded {
117            data = decode_all(&data[..]).map_err(|err| SparkSdkError::from(IoError::Io(err)))?;
118        }
119
120        // Parse JSON response
121        let result: HashMap<String, Value> =
122            from_slice(&data).map_err(|err| SparkSdkError::from(IoError::SerdeJson(err)))?;
123        if let Some(errors) = result.get("errors") {
124            let err = errors
125                .as_array()
126                .and_then(|arr| arr.first())
127                .and_then(|v| v.as_object())
128                .ok_or("invalid error format")?;
129            let error_message = err
130                .get("message")
131                .and_then(|v| v.as_str())
132                .ok_or("missing error message")?
133                .to_string();
134
135            if let Some(extensions) = err.get("extensions").and_then(|v| v.as_object()) {
136                if extensions
137                    .get("error_name")
138                    .and_then(|v| v.as_str())
139                    .is_some()
140                {
141                    return Err(SparkSdkError::from(NetworkError::GraphQL(error_message)));
142                }
143            }
144            return Err(SparkSdkError::from(NetworkError::GraphQL(error_message)));
145        }
146
147        result
148            .get("data")
149            .and_then(|v| {
150                v.as_object()
151                    .map(|o| o.into_iter().map(|(k, v)| (k.clone(), v.clone())).collect())
152            })
153            .ok_or_else(|| "missing data field".into())
154    }
155
156    fn get_user_agent(&self) -> &str {
157        "spark"
158    }
159}
160
161/// Validates the given query is a valid GraphQL operation, and extracts the operation name.
162fn validate_operation(query: &str) -> Result<String, SparkSdkError> {
163    static RE: Lazy<Regex> =
164        Lazy::new(|| Regex::new(r"(?i)\s*(?:query|mutation)\s+(?P<OperationName>\w+)").unwrap());
165
166    let captures =
167        RE.captures(query)
168            .ok_or(SparkSdkError::from(NetworkError::InvalidGraphQLOperation(
169                query.to_string(),
170            )))?;
171
172    captures
173        .name("OperationName")
174        .ok_or(SparkSdkError::from(NetworkError::InvalidGraphQLOperation(
175            query.to_string(),
176        )))
177        .map(|m| m.as_str().to_string())
178}
179
/// Returns true if `hostname` refers to the local machine or a local network.
///
/// Accepted forms:
/// - the name `localhost`;
/// - any IPv4 or IPv6 loopback address (`127.0.0.0/8`, `::1`), with or without the square
///   brackets that `Url::host_str` puts around IPv6 literals;
/// - names whose last label is `local` or `internal` (e.g. `mycomputer.local`).
///
/// Names are compared ASCII case-insensitively, as DNS names are case-insensitive.
fn is_localhost(hostname: &str) -> bool {
    if hostname.eq_ignore_ascii_case("localhost") {
        return true;
    }

    // `Url::host_str` renders IPv6 hosts as "[::1]"; strip the brackets before parsing.
    let ip_text = hostname
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(hostname);
    if let Ok(ip) = ip_text.parse::<std::net::IpAddr>() {
        return ip.is_loopback();
    }

    hostname.rsplit('.').next().is_some_and(|tld| {
        tld.eq_ignore_ascii_case("local") || tld.eq_ignore_ascii_case("internal")
    })
}
190
191fn validate_base_url(base_url: &str) -> Result<Url, SparkSdkError> {
192    let url = Url::parse(base_url).map_err(|_| {
193        SparkSdkError::from(NetworkError::InvalidUrl {
194            url: base_url.to_string(),
195            details: None,
196        })
197    })?;
198
199    if url.scheme() != "https" && !url.host_str().is_some_and(is_localhost) {
200        return Err(SparkSdkError::from(NetworkError::InvalidUrl {
201            url: base_url.to_string(),
202            details: Some("Only HTTPS is supported for non-localhost".to_string()),
203        }));
204    }
205    Ok(url)
206}