spark_rust/wallet/
graphql.rs1use std::collections::HashMap;
9
10use bitcoin::secp256k1::PublicKey;
11use once_cell::sync::Lazy;
12use regex::Regex;
13use reqwest::Client;
14use serde_json::{from_slice, to_vec, Value};
15use url::Url;
16use zstd::{decode_all, encode_all};
17
18use crate::{
19 constants::spark::LIGHTSPARK_SSP_ENDPOINT,
20 error::{IoError, NetworkError, SparkSdkError},
21};
22
23#[derive(Clone)]
25pub struct GraphqlClient {
26 base_url: Url,
27 http_client: Client,
28}
29
30impl GraphqlClient {
31 pub fn new() -> Result<Self, SparkSdkError> {
36 let base_url = validate_base_url(
37 std::env::var("LIGHTSPARK_SSP_ENDPOINT")
38 .unwrap_or(LIGHTSPARK_SSP_ENDPOINT.to_string())
39 .as_str(),
40 )?;
41
42 Ok(GraphqlClient {
43 base_url,
44 http_client: Client::new(),
45 })
46 }
47
48 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
62 pub async fn execute_graphql(
63 &self,
64 identity_public_key: &PublicKey,
65 query: &str,
66 variables: HashMap<String, Value>,
67 ) -> Result<HashMap<String, Value>, SparkSdkError> {
68 let operation_name = validate_operation(query)?;
69
70 let payload = serde_json::json!({
72 "operationName": operation_name,
73 "query": query,
74 "variables": variables,
75 });
76 let encoded_payload = to_vec(&payload).map_err(|_| {
77 SparkSdkError::from(NetworkError::InvalidGraphQLOperation(query.to_string()))
78 })?;
79
80 let (body, compressed) = if encoded_payload.len() > 1024 {
82 let compressed = encode_all(&encoded_payload[..], 0)
83 .map_err(|err| SparkSdkError::from(IoError::Io(err)))?;
84 (compressed, true)
85 } else {
86 (encoded_payload, false)
87 };
88
89 let mut request = self
91 .http_client
92 .post(self.base_url.clone())
93 .header(
94 "Spark-Identity-Public-Key",
95 hex::encode(identity_public_key.serialize()),
96 )
97 .header("Content-Type", "application/json")
98 .header("Accept-Encoding", "zstd")
99 .header("X-GraphQL-Operation", operation_name)
100 .header("User-Agent", self.get_user_agent())
101 .header("X-Polarity-SDK", self.get_user_agent());
102
103 if compressed {
104 request = request.header("Content-Encoding", "zstd");
105 }
106
107 let response = request
108 .body(body)
109 .send()
110 .await
111 .map_err(|err| SparkSdkError::from(NetworkError::Http(err)))?;
112
113 let status = response.status();
115
116 if !status.is_success() {
117 return Err(SparkSdkError::from(NetworkError::GraphQLRequestFailed {
118 status_code: status.as_u16(),
119 }));
120 }
121
122 let is_zstd_encoded = response
124 .headers()
125 .get("Content-Encoding")
126 .is_some_and(|v| v == "zstd");
127
128 let mut data = response
130 .bytes()
131 .await
132 .map_err(|err| SparkSdkError::from(NetworkError::Http(err)))?
133 .to_vec();
134
135 if is_zstd_encoded {
137 data = decode_all(&data[..]).map_err(|err| SparkSdkError::from(IoError::Io(err)))?;
138 }
139
140 let result: HashMap<String, Value> =
142 from_slice(&data).map_err(|err| SparkSdkError::from(IoError::SerdeJson(err)))?;
143 if let Some(errors) = result.get("errors") {
144 let err = errors
145 .as_array()
146 .and_then(|arr| arr.first())
147 .and_then(|v| v.as_object())
148 .ok_or("invalid error format")?;
149 let error_message = err
150 .get("message")
151 .and_then(|v| v.as_str())
152 .ok_or("missing error message")?
153 .to_string();
154
155 if let Some(extensions) = err.get("extensions").and_then(|v| v.as_object()) {
156 if extensions
157 .get("error_name")
158 .and_then(|v| v.as_str())
159 .is_some()
160 {
161 return Err(SparkSdkError::from(NetworkError::GraphQL(error_message)));
162 }
163 }
164 return Err(SparkSdkError::from(NetworkError::GraphQL(error_message)));
165 }
166
167 result
168 .get("data")
169 .and_then(|v| {
170 v.as_object()
171 .map(|o| o.into_iter().map(|(k, v)| (k.clone(), v.clone())).collect())
172 })
173 .ok_or_else(|| "missing data field".into())
174 }
175
176 fn get_user_agent(&self) -> &str {
178 "spark"
179 }
180}
181
182fn validate_operation(query: &str) -> Result<String, SparkSdkError> {
191 static RE: Lazy<Regex> =
192 Lazy::new(|| Regex::new(r"(?i)\s*(?:query|mutation)\s+(?P<OperationName>\w+)").unwrap());
193
194 let captures =
195 RE.captures(query)
196 .ok_or(SparkSdkError::from(NetworkError::InvalidGraphQLOperation(
197 query.to_string(),
198 )))?;
199
200 captures
201 .name("OperationName")
202 .ok_or(SparkSdkError::from(NetworkError::InvalidGraphQLOperation(
203 query.to_string(),
204 )))
205 .map(|m| m.as_str().to_string())
206}
207
/// Returns `true` if `hostname` names a local/development host for which plain
/// HTTP is tolerated.
///
/// Accepted hosts:
/// * `localhost`;
/// * any IPv4 or IPv6 loopback address (`127.0.0.0/8`, `::1`), with or without
///   the surrounding brackets that `Url::host_str` uses for IPv6 literals;
/// * names whose last dot-separated label is `internal` or `local`.
fn is_localhost(hostname: &str) -> bool {
    if hostname == "localhost" {
        return true;
    }

    // Strip IPv6 literal brackets (e.g. `[::1]`) before trying to parse an IP.
    let ip_literal = hostname
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(hostname);
    if let Ok(ip) = ip_literal.parse::<std::net::IpAddr>() {
        return ip.is_loopback();
    }

    hostname
        .rsplit('.')
        .next()
        .is_some_and(|tld| tld == "internal" || tld == "local")
}
224
225fn validate_base_url(base_url: &str) -> Result<Url, SparkSdkError> {
236 let url = Url::parse(base_url).map_err(|_| {
237 SparkSdkError::from(NetworkError::InvalidUrl {
238 url: base_url.to_string(),
239 details: None,
240 })
241 })?;
242
243 if url.scheme() != "https" && !url.host_str().is_some_and(is_localhost) {
244 return Err(SparkSdkError::from(NetworkError::InvalidUrl {
245 url: base_url.to_string(),
246 details: Some("Only HTTPS is supported for non-localhost".to_string()),
247 }));
248 }
249 Ok(url)
250}