spark_rust/wallet/
graphql.rs1use std::collections::HashMap;
4
5use bitcoin::secp256k1::PublicKey;
6use once_cell::sync::Lazy;
7use regex::Regex;
8use reqwest::Client;
9use serde_json::{from_slice, to_vec, Value};
10use url::Url;
11use zstd::{decode_all, encode_all};
12
13use crate::{
14 constants::spark::LIGHTSPARK_SSP_ENDPOINT,
15 error::{IoError, NetworkError, SparkSdkError},
16};
17
18#[derive(Clone)]
20pub struct GraphqlClient {
21 base_url: Url,
22 identity_public_key: PublicKey,
23 http_client: Client,
24}
25
26impl GraphqlClient {
27 pub fn with_base_url(
29 identity_public_key: PublicKey,
30 base_url: Option<String>,
31 ) -> Result<Self, SparkSdkError> {
32 let url = validate_base_url(&base_url.unwrap_or(LIGHTSPARK_SSP_ENDPOINT.to_string()))?;
33
34 Ok(GraphqlClient {
35 base_url: url,
36 identity_public_key,
37 http_client: Client::new(),
38 })
39 }
40
41 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
43 pub async fn execute_graphql(
44 &self,
45 query: &str,
46 variables: HashMap<String, Value>,
47 ) -> Result<HashMap<String, Value>, SparkSdkError> {
48 let operation_name = validate_operation(query)?;
49
50 let payload = serde_json::json!({
52 "operationName": operation_name,
53 "query": query,
54 "variables": variables,
55 });
56 let encoded_payload = to_vec(&payload).map_err(|_| {
57 SparkSdkError::from(NetworkError::InvalidGraphQLOperation(query.to_string()))
58 })?;
59
60 let (body, compressed) = if encoded_payload.len() > 1024 {
62 let compressed = encode_all(&encoded_payload[..], 0)
63 .map_err(|err| SparkSdkError::from(IoError::Io(err)))?;
64 (compressed, true)
65 } else {
66 (encoded_payload, false)
67 };
68
69 let mut request = self
71 .http_client
72 .post(self.base_url.clone())
73 .header(
74 "Spark-Identity-Public-Key",
75 hex::encode(self.identity_public_key.serialize()),
76 )
77 .header("Content-Type", "application/json")
78 .header("Accept-Encoding", "zstd")
79 .header("X-GraphQL-Operation", operation_name)
80 .header("User-Agent", self.get_user_agent())
81 .header("X-Polarity-SDK", self.get_user_agent());
82
83 if compressed {
84 request = request.header("Content-Encoding", "zstd");
85 }
86
87 let response = request
88 .body(body)
89 .send()
90 .await
91 .map_err(|err| SparkSdkError::from(NetworkError::Http(err)))?;
92
93 let status = response.status();
95
96 if !status.is_success() {
97 return Err(SparkSdkError::from(NetworkError::GraphQLRequestFailed {
98 status_code: status.as_u16(),
99 }));
100 }
101
102 let is_zstd_encoded = response
104 .headers()
105 .get("Content-Encoding")
106 .is_some_and(|v| v == "zstd");
107
108 let mut data = response
110 .bytes()
111 .await
112 .map_err(|err| SparkSdkError::from(NetworkError::Http(err)))?
113 .to_vec();
114
115 if is_zstd_encoded {
117 data = decode_all(&data[..]).map_err(|err| SparkSdkError::from(IoError::Io(err)))?;
118 }
119
120 let result: HashMap<String, Value> =
122 from_slice(&data).map_err(|err| SparkSdkError::from(IoError::SerdeJson(err)))?;
123 if let Some(errors) = result.get("errors") {
124 let err = errors
125 .as_array()
126 .and_then(|arr| arr.first())
127 .and_then(|v| v.as_object())
128 .ok_or("invalid error format")?;
129 let error_message = err
130 .get("message")
131 .and_then(|v| v.as_str())
132 .ok_or("missing error message")?
133 .to_string();
134
135 if let Some(extensions) = err.get("extensions").and_then(|v| v.as_object()) {
136 if extensions
137 .get("error_name")
138 .and_then(|v| v.as_str())
139 .is_some()
140 {
141 return Err(SparkSdkError::from(NetworkError::GraphQL(error_message)));
142 }
143 }
144 return Err(SparkSdkError::from(NetworkError::GraphQL(error_message)));
145 }
146
147 result
148 .get("data")
149 .and_then(|v| {
150 v.as_object()
151 .map(|o| o.into_iter().map(|(k, v)| (k.clone(), v.clone())).collect())
152 })
153 .ok_or_else(|| "missing data field".into())
154 }
155
156 fn get_user_agent(&self) -> &str {
157 "spark"
158 }
159}
160
161fn validate_operation(query: &str) -> Result<String, SparkSdkError> {
163 static RE: Lazy<Regex> =
164 Lazy::new(|| Regex::new(r"(?i)\s*(?:query|mutation)\s+(?P<OperationName>\w+)").unwrap());
165
166 let captures =
167 RE.captures(query)
168 .ok_or(SparkSdkError::from(NetworkError::InvalidGraphQLOperation(
169 query.to_string(),
170 )))?;
171
172 captures
173 .name("OperationName")
174 .ok_or(SparkSdkError::from(NetworkError::InvalidGraphQLOperation(
175 query.to_string(),
176 )))
177 .map(|m| m.as_str().to_string())
178}
179
/// Returns `true` for hosts that may be reached over plain HTTP (local development).
///
/// Accepts `localhost`, any loopback IP address (IPv4 `127.0.0.0/8` or IPv6 `::1`, with or
/// without the surrounding brackets that `Url::host_str` uses for IPv6 literals), and
/// names whose last label is `internal` or `local`.
fn is_localhost(hostname: &str) -> bool {
    // `Url::host_str` renders IPv6 literals as `[::1]`; strip the brackets before parsing.
    let host = hostname
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(hostname);

    if host == "localhost" {
        return true;
    }
    if let Ok(ip) = host.parse::<std::net::IpAddr>() {
        return ip.is_loopback();
    }
    matches!(host.rsplit('.').next(), Some("internal" | "local"))
}
190
191fn validate_base_url(base_url: &str) -> Result<Url, SparkSdkError> {
192 let url = Url::parse(base_url).map_err(|_| {
193 SparkSdkError::from(NetworkError::InvalidUrl {
194 url: base_url.to_string(),
195 details: None,
196 })
197 })?;
198
199 if url.scheme() != "https" && !url.host_str().is_some_and(is_localhost) {
200 return Err(SparkSdkError::from(NetworkError::InvalidUrl {
201 url: base_url.to_string(),
202 details: Some("Only HTTPS is supported for non-localhost".to_string()),
203 }));
204 }
205 Ok(url)
206}