1use hmac::{Hmac, Mac};
2use reqwest::Method;
3use serde_json::Value;
4use sha2::Sha256;
5use std::collections::BTreeMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::time::Duration;
10use subtle::ConstantTimeEq;
11
12use http::StatusCode;
13
14use crate::errors::StripeErrorCode;
15
16mod paginated_list;
17
/// Heap-allocated, `Send` future that may borrow data for `'a`.
///
/// Used as the return type of [`StripeTransport::send`] so the trait stays usable as
/// `dyn StripeTransport`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Future returned by [`StripeTransport::send`]: the raw response or an error.
pub type StripeTransportFuture<'a> = BoxFuture<'a, Result<StripeResponse, StripeApiError>>;
20
/// Transport-agnostic description of one Stripe API call, as built by `StripeClient`.
///
/// `Debug` is implemented by hand so the `Authorization` header, which carries the
/// secret API key, is never printed.
#[derive(Clone, PartialEq, Eq)]
pub struct StripeRequest {
    /// HTTP method name, e.g. `"GET"` or `"POST"`.
    pub method: String,
    /// Path relative to the API base, e.g. `/v1/customers`.
    pub path: String,
    /// Request headers keyed by header name.
    pub headers: BTreeMap<String, String>,
    /// Form-encoded parameters; empty when the call has none.
    pub body: String,
}

impl std::fmt::Debug for StripeRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Header names are case-insensitive, so match `Authorization` in any casing.
        let headers: BTreeMap<&str, &str> = self
            .headers
            .iter()
            .map(|(name, value)| {
                let shown = if name.eq_ignore_ascii_case("authorization") {
                    "[REDACTED]"
                } else {
                    value.as_str()
                };
                (name.as_str(), shown)
            })
            .collect();
        f.debug_struct("StripeRequest")
            .field("method", &self.method)
            .field("path", &self.path)
            .field("headers", &headers)
            .field("body", &self.body)
            .finish()
    }
}
28
/// Raw HTTP response handed back by a [`StripeTransport`].
///
/// `StripeClient` treats statuses in `200..300` as success and converts everything else
/// into a [`StripeApiError::Stripe`].
#[derive(Debug, Clone, PartialEq)]
pub struct StripeResponse {
    /// HTTP status code.
    pub status: u16,
    /// Response body as a JSON value (how a non-JSON body is represented is up to the
    /// transport).
    pub body: Value,
}
34
/// Errors produced by the Stripe client, its transports, and webhook verification.
#[derive(Debug, thiserror::Error)]
pub enum StripeApiError {
    /// Non-2xx response from the Stripe API. Displays as `message`.
    #[error("{message}")]
    Stripe {
        /// HTTP status code of the response.
        status: u16,
        /// Stripe's `error.code`, when the response carried one.
        code: Option<String>,
        /// Stripe's `error.message`, or a generic fallback text.
        message: String,
    },
    /// Failure outside the Stripe API's own error reporting: unparsable method, network
    /// or body-decoding errors in the transport, or HMAC setup during webhook verification.
    #[error("transport error: {0}")]
    Transport(String),
    /// Webhook signature verification failed; carries the plugin error code to report.
    #[error("{0}")]
    Webhook(StripeErrorCode),
}
48
49impl StripeApiError {
50 pub fn code(&self) -> &str {
51 match self {
52 Self::Stripe {
53 code: Some(code), ..
54 } => code,
55 Self::Stripe { .. } => "STRIPE_API_ERROR",
56 Self::Transport(_) => "STRIPE_TRANSPORT_ERROR",
57 Self::Webhook(code) => code.code(),
58 }
59 }
60
61 pub fn is_already_scheduled_cancel(&self) -> bool {
62 match self {
63 Self::Stripe { code, message, .. } => {
64 matches!(
65 code.as_deref(),
66 Some(
67 "subscription_already_canceled"
68 | "resource_already_exists"
69 | "invalid_request_error"
70 )
71 ) || message.contains("already set to be canceled")
72 }
73 _ => false,
74 }
75 }
76
77 pub fn plugin_response(&self, default: StripeErrorCode) -> (StatusCode, StripeErrorCode) {
78 match self {
79 Self::Webhook(code) => (StatusCode::BAD_REQUEST, *code),
80 Self::Transport(_) => (StatusCode::BAD_GATEWAY, StripeErrorCode::FailedToFetchPlans),
81 Self::Stripe { status, code, .. } if *status >= 500 => {
82 (StatusCode::BAD_GATEWAY, StripeErrorCode::FailedToFetchPlans)
83 }
84 Self::Stripe { code, .. } => (
85 StatusCode::BAD_REQUEST,
86 map_stripe_code_to_plugin(default, code.as_deref()),
87 ),
88 }
89 }
90}
91
92fn map_stripe_code_to_plugin(
93 default: StripeErrorCode,
94 stripe_code: Option<&str>,
95) -> StripeErrorCode {
96 match (default, stripe_code) {
97 (StripeErrorCode::UnableToCreateCustomer, Some("resource_missing")) => {
98 StripeErrorCode::CustomerNotFound
99 }
100 (StripeErrorCode::UnableToCreateBillingPortal, Some("resource_missing")) => {
101 StripeErrorCode::SubscriptionNotFound
102 }
103 (StripeErrorCode::SubscriptionNotFound, Some("resource_missing")) => {
104 StripeErrorCode::SubscriptionNotFound
105 }
106 _ => default,
107 }
108}
109
110pub trait StripeTransport: Send + Sync {
111 fn send<'a>(&'a self, request: StripeRequest) -> StripeTransportFuture<'a>;
112}
113
/// Async Stripe REST client that form-encodes parameters and sends them through a
/// pluggable [`StripeTransport`].
///
/// Deliberately not `Debug`: it holds the secret API key.
#[derive(Clone)]
pub struct StripeClient {
    /// Secret API key, sent as `Authorization: Bearer <key>`.
    secret_key: String,
    /// API origin, e.g. `https://api.stripe.com`. In this file it is only used by
    /// `with_api_base` to build a reqwest transport.
    api_base: String,
    /// Value for the `Stripe-Version` header; omitted when `None`.
    api_version: Option<String>,
    /// Sends built requests; shared between clones via `Arc`.
    transport: Arc<dyn StripeTransport>,
}
121
122impl StripeClient {
123 pub fn new(secret_key: impl Into<String>) -> Self {
124 Self {
125 secret_key: secret_key.into(),
126 api_base: "https://api.stripe.com".to_owned(),
127 api_version: None,
128 transport: Arc::new(ReqwestStripeTransport::new("https://api.stripe.com")),
129 }
130 }
131
132 pub fn with_transport(
133 secret_key: impl Into<String>,
134 transport: Arc<dyn StripeTransport>,
135 ) -> Self {
136 Self {
137 secret_key: secret_key.into(),
138 api_base: "https://api.stripe.com".to_owned(),
139 api_version: None,
140 transport,
141 }
142 }
143
144 pub fn with_api_base(mut self, api_base: impl Into<String>) -> Self {
145 self.api_base = api_base.into();
146 self.transport = Arc::new(ReqwestStripeTransport::new(self.api_base.clone()));
147 self
148 }
149
150 pub fn api_version(mut self, api_version: impl Into<String>) -> Self {
151 self.api_version = Some(api_version.into());
152 self
153 }
154
155 pub async fn create_customer(&self, params: Value) -> Result<Value, StripeApiError> {
156 self.post("/v1/customers", params).await
157 }
158
159 pub async fn update_customer(
160 &self,
161 customer_id: &str,
162 params: Value,
163 ) -> Result<Value, StripeApiError> {
164 self.post(&format!("/v1/customers/{customer_id}"), params)
165 .await
166 }
167
168 pub async fn retrieve_customer(&self, customer_id: &str) -> Result<Value, StripeApiError> {
169 self.get(&format!("/v1/customers/{customer_id}"), Value::Null)
170 .await
171 }
172
173 pub async fn search_customers(&self, query: &str) -> Result<Value, StripeApiError> {
174 self.search_customers_page(query, None).await
175 }
176
177 pub(crate) async fn search_customers_page(
178 &self,
179 query: &str,
180 page: Option<&str>,
181 ) -> Result<Value, StripeApiError> {
182 let mut params = serde_json::json!({
183 "query": query,
184 "limit": paginated_list::STRIPE_LIST_PAGE_LIMIT,
185 });
186 if let Some(page) = page {
187 if let Some(object) = params.as_object_mut() {
188 object.insert("page".to_owned(), serde_json::json!(page));
189 }
190 }
191 self.get("/v1/customers/search", params).await
192 }
193
194 pub async fn list_customers(&self, params: Value) -> Result<Value, StripeApiError> {
195 self.get("/v1/customers", params).await
196 }
197
198 pub async fn retrieve_price(&self, price_id: &str) -> Result<Value, StripeApiError> {
199 self.get(&format!("/v1/prices/{price_id}"), Value::Null)
200 .await
201 }
202
203 pub async fn list_prices(&self, params: Value) -> Result<Value, StripeApiError> {
204 self.get("/v1/prices", params).await
205 }
206
207 pub async fn price_by_lookup_key(&self, lookup_key: &str) -> Result<Value, StripeApiError> {
208 self.list_prices(serde_json::json!({
209 "lookup_keys": [lookup_key],
210 "active": true,
211 "limit": 1
212 }))
213 .await
214 }
215
216 pub async fn create_checkout_session(&self, params: Value) -> Result<Value, StripeApiError> {
217 self.post("/v1/checkout/sessions", params).await
218 }
219
220 pub async fn retrieve_checkout_session(
221 &self,
222 session_id: &str,
223 ) -> Result<Value, StripeApiError> {
224 self.get(&format!("/v1/checkout/sessions/{session_id}"), Value::Null)
225 .await
226 }
227
228 pub async fn create_billing_portal_session(
229 &self,
230 params: Value,
231 ) -> Result<Value, StripeApiError> {
232 self.post("/v1/billing_portal/sessions", params).await
233 }
234
235 pub async fn list_subscriptions(&self, params: Value) -> Result<Value, StripeApiError> {
236 self.get("/v1/subscriptions", params).await
237 }
238
239 pub async fn retrieve_subscription(
240 &self,
241 subscription_id: &str,
242 ) -> Result<Value, StripeApiError> {
243 self.get(&format!("/v1/subscriptions/{subscription_id}"), Value::Null)
244 .await
245 }
246
247 pub async fn update_subscription(
248 &self,
249 subscription_id: &str,
250 params: Value,
251 ) -> Result<Value, StripeApiError> {
252 self.post(&format!("/v1/subscriptions/{subscription_id}"), params)
253 .await
254 }
255
256 pub async fn create_subscription_schedule(
257 &self,
258 params: Value,
259 ) -> Result<Value, StripeApiError> {
260 self.post("/v1/subscription_schedules", params).await
261 }
262
263 pub async fn list_subscription_schedules(
264 &self,
265 params: Value,
266 ) -> Result<Value, StripeApiError> {
267 self.get("/v1/subscription_schedules", params).await
268 }
269
270 pub async fn retrieve_subscription_schedule(
271 &self,
272 schedule_id: &str,
273 ) -> Result<Value, StripeApiError> {
274 self.get(
275 &format!("/v1/subscription_schedules/{schedule_id}"),
276 Value::Null,
277 )
278 .await
279 }
280
281 pub async fn update_subscription_schedule(
282 &self,
283 schedule_id: &str,
284 params: Value,
285 ) -> Result<Value, StripeApiError> {
286 self.post(&format!("/v1/subscription_schedules/{schedule_id}"), params)
287 .await
288 }
289
290 pub async fn release_subscription_schedule(
291 &self,
292 schedule_id: &str,
293 ) -> Result<Value, StripeApiError> {
294 self.post(
295 &format!("/v1/subscription_schedules/{schedule_id}/release"),
296 Value::Object(Default::default()),
297 )
298 .await
299 }
300
301 async fn post(&self, path: &str, params: Value) -> Result<Value, StripeApiError> {
302 self.send("POST", path, params).await
303 }
304
305 async fn get(&self, path: &str, params: Value) -> Result<Value, StripeApiError> {
306 self.send("GET", path, params).await
307 }
308
309 async fn send(&self, method: &str, path: &str, params: Value) -> Result<Value, StripeApiError> {
310 let body = if params.is_null() {
311 String::new()
312 } else {
313 encode_form(¶ms)
314 };
315 let mut headers = BTreeMap::new();
316 headers.insert(
317 "Authorization".to_owned(),
318 format!("Bearer {}", self.secret_key),
319 );
320 headers.insert(
321 "Content-Type".to_owned(),
322 "application/x-www-form-urlencoded".to_owned(),
323 );
324 if let Some(api_version) = &self.api_version {
325 headers.insert("Stripe-Version".to_owned(), api_version.clone());
326 }
327 let request = StripeRequest {
328 method: method.to_owned(),
329 path: path.to_owned(),
330 headers,
331 body,
332 };
333 let response = self.transport.send(request).await?;
334 if (200..300).contains(&response.status) {
335 Ok(response.body)
336 } else {
337 Err(stripe_error_from_response(response))
338 }
339 }
340}
341
342pub struct ReqwestStripeTransport {
343 client: reqwest::Client,
344 api_base: String,
345}
346
347const DEFAULT_STRIPE_HTTP_TIMEOUT: Duration = Duration::from_secs(30);
348
349impl ReqwestStripeTransport {
350 pub fn new(api_base: impl Into<String>) -> Self {
351 Self::with_timeout(api_base, DEFAULT_STRIPE_HTTP_TIMEOUT)
352 }
353
354 pub fn with_timeout(api_base: impl Into<String>, timeout: Duration) -> Self {
355 let client = reqwest::Client::builder()
356 .timeout(timeout)
357 .build()
358 .unwrap_or_else(|_| reqwest::Client::new());
359 Self {
360 client,
361 api_base: api_base.into(),
362 }
363 }
364}
365
366impl StripeTransport for ReqwestStripeTransport {
367 fn send<'a>(&'a self, request: StripeRequest) -> StripeTransportFuture<'a> {
368 Box::pin(async move {
369 let method = request
370 .method
371 .parse::<Method>()
372 .map_err(|error| StripeApiError::Transport(error.to_string()))?;
373 let url = if request.method == "GET" && !request.body.is_empty() {
374 format!("{}{}?{}", self.api_base, request.path, request.body)
375 } else {
376 format!("{}{}", self.api_base, request.path)
377 };
378 let mut builder = self.client.request(method, url);
379 for (name, value) in request.headers {
380 builder = builder.header(name, value);
381 }
382 if request.method != "GET" {
383 builder = builder.body(request.body);
384 }
385 let response = builder
386 .send()
387 .await
388 .map_err(|error| StripeApiError::Transport(error.to_string()))?;
389 let status = response.status().as_u16();
390 let body = response
391 .json::<Value>()
392 .await
393 .map_err(|error| StripeApiError::Transport(error.to_string()))?;
394 Ok(StripeResponse { status, body })
395 })
396 }
397}
398
399pub fn encode_form(value: &Value) -> String {
400 let mut pairs = Vec::new();
401 collect_form_pairs(None, value, &mut pairs);
402 pairs
403 .into_iter()
404 .map(|(key, value)| format!("{}={}", form_encode(&key), form_encode(&value)))
405 .collect::<Vec<_>>()
406 .join("&")
407}
408
409fn collect_form_pairs(prefix: Option<String>, value: &Value, pairs: &mut Vec<(String, String)>) {
410 match value {
411 Value::Object(map) => {
412 for (key, value) in map {
413 let key = match &prefix {
414 Some(prefix) => format!("{prefix}[{key}]"),
415 None => key.clone(),
416 };
417 collect_form_pairs(Some(key), value, pairs);
418 }
419 }
420 Value::Array(values) => {
421 for (index, value) in values.iter().enumerate() {
422 if let Some(prefix) = &prefix {
423 collect_form_pairs(Some(format!("{prefix}[{index}]")), value, pairs);
424 }
425 }
426 }
427 Value::String(value) => {
428 if let Some(prefix) = prefix {
429 pairs.push((prefix, value.clone()));
430 }
431 }
432 Value::Number(value) => {
433 if let Some(prefix) = prefix {
434 pairs.push((prefix, value.to_string()));
435 }
436 }
437 Value::Bool(value) => {
438 if let Some(prefix) = prefix {
439 pairs.push((prefix, value.to_string()));
440 }
441 }
442 Value::Null => {}
443 }
444}
445
446fn form_encode(value: &str) -> String {
447 url::form_urlencoded::byte_serialize(value.as_bytes()).collect()
448}
449
450pub fn verify_webhook_signature(
451 payload: &[u8],
452 signature_header: &str,
453 secret: &str,
454 tolerance_seconds: i64,
455 now_unix: i64,
456) -> Result<(), StripeApiError> {
457 let timestamp = signature_header
458 .split(',')
459 .find_map(|part| part.strip_prefix("t="))
460 .and_then(|value| value.parse::<i64>().ok())
461 .ok_or(StripeApiError::Webhook(
462 StripeErrorCode::FailedToConstructStripeEvent,
463 ))?;
464 if (now_unix - timestamp).abs() > tolerance_seconds {
465 return Err(StripeApiError::Webhook(
466 StripeErrorCode::FailedToConstructStripeEvent,
467 ));
468 }
469 let expected = webhook_signature(payload, secret, timestamp)?;
470 let verified = signature_header
471 .split(',')
472 .filter_map(|part| part.strip_prefix("v1="))
473 .filter_map(|signature| hex::decode(signature).ok())
474 .any(|candidate| candidate.ct_eq(expected.as_slice()).into());
475 if verified {
476 Ok(())
477 } else {
478 Err(StripeApiError::Webhook(
479 StripeErrorCode::FailedToConstructStripeEvent,
480 ))
481 }
482}
483
484fn webhook_signature(
485 payload: &[u8],
486 secret: &str,
487 timestamp: i64,
488) -> Result<Vec<u8>, StripeApiError> {
489 let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).map_err(|error| {
492 StripeApiError::Transport(format!("failed to initialize webhook verifier: {error}"))
493 })?;
494 mac.update(timestamp.to_string().as_bytes());
495 mac.update(b".");
496 mac.update(payload);
497 Ok(mac.finalize().into_bytes().to_vec())
498}
499
500fn stripe_error_from_response(response: StripeResponse) -> StripeApiError {
501 let error = response.body.get("error").unwrap_or(&response.body);
502 let code = error.get("code").and_then(Value::as_str).map(str::to_owned);
503 let message = error
504 .get("message")
505 .and_then(Value::as_str)
506 .unwrap_or("Stripe API request failed")
507 .to_owned();
508 StripeApiError::Stripe {
509 status: response.status,
510 code,
511 message,
512 }
513}