Skip to main content

rustauth_stripe/stripe_api/
mod.rs

1use hmac::{Hmac, Mac};
2use reqwest::Method;
3use serde_json::Value;
4use sha2::Sha256;
5use std::collections::BTreeMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::time::Duration;
10use subtle::ConstantTimeEq;
11
12use http::StatusCode;
13
14use crate::errors::StripeErrorCode;
15
16mod paginated_list;
17
/// A heap-allocated, pinned, `Send` future that may borrow data for `'fut`
/// and resolves to `T`.
pub type BoxFuture<'fut, T> = Pin<Box<dyn Future<Output = T> + 'fut + Send>>;
19pub type StripeTransportFuture<'a> = BoxFuture<'a, Result<StripeResponse, StripeApiError>>;
20
/// Transport-agnostic description of a single HTTP call to the Stripe API.
///
/// Built by [`StripeClient`] and executed by a [`StripeTransport`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct StripeRequest {
    /// HTTP verb, e.g. `"GET"` or `"POST"`.
    pub method: String,
    /// Path relative to the API base, e.g. `/v1/customers`.
    pub path: String,
    /// Header name/value pairs (authorization, content type, optional API version).
    pub headers: BTreeMap<String, String>,
    /// Form-urlencoded parameters; empty when the call has none.
    pub body: String,
}
28
29#[derive(Debug, Clone, PartialEq)]
30pub struct StripeResponse {
31    pub status: u16,
32    pub body: Value,
33}
34
35#[derive(Debug, thiserror::Error)]
36pub enum StripeApiError {
37    #[error("{message}")]
38    Stripe {
39        status: u16,
40        code: Option<String>,
41        message: String,
42    },
43    #[error("transport error: {0}")]
44    Transport(String),
45    #[error("{0}")]
46    Webhook(StripeErrorCode),
47}
48
49impl StripeApiError {
50    pub fn code(&self) -> &str {
51        match self {
52            Self::Stripe {
53                code: Some(code), ..
54            } => code,
55            Self::Stripe { .. } => "STRIPE_API_ERROR",
56            Self::Transport(_) => "STRIPE_TRANSPORT_ERROR",
57            Self::Webhook(code) => code.code(),
58        }
59    }
60
61    pub fn is_already_scheduled_cancel(&self) -> bool {
62        match self {
63            Self::Stripe { code, message, .. } => {
64                matches!(
65                    code.as_deref(),
66                    Some(
67                        "subscription_already_canceled"
68                            | "resource_already_exists"
69                            | "invalid_request_error"
70                    )
71                ) || message.contains("already set to be canceled")
72            }
73            _ => false,
74        }
75    }
76
77    pub fn plugin_response(&self, default: StripeErrorCode) -> (StatusCode, StripeErrorCode) {
78        match self {
79            Self::Webhook(code) => (StatusCode::BAD_REQUEST, *code),
80            Self::Transport(_) => (StatusCode::BAD_GATEWAY, StripeErrorCode::FailedToFetchPlans),
81            Self::Stripe { status, code, .. } if *status >= 500 => {
82                (StatusCode::BAD_GATEWAY, StripeErrorCode::FailedToFetchPlans)
83            }
84            Self::Stripe { code, .. } => (
85                StatusCode::BAD_REQUEST,
86                map_stripe_code_to_plugin(default, code.as_deref()),
87            ),
88        }
89    }
90}
91
92fn map_stripe_code_to_plugin(
93    default: StripeErrorCode,
94    stripe_code: Option<&str>,
95) -> StripeErrorCode {
96    match (default, stripe_code) {
97        (StripeErrorCode::UnableToCreateCustomer, Some("resource_missing")) => {
98            StripeErrorCode::CustomerNotFound
99        }
100        (StripeErrorCode::UnableToCreateBillingPortal, Some("resource_missing")) => {
101            StripeErrorCode::SubscriptionNotFound
102        }
103        (StripeErrorCode::SubscriptionNotFound, Some("resource_missing")) => {
104            StripeErrorCode::SubscriptionNotFound
105        }
106        _ => default,
107    }
108}
109
110pub trait StripeTransport: Send + Sync {
111    fn send<'a>(&'a self, request: StripeRequest) -> StripeTransportFuture<'a>;
112}
113
114#[derive(Clone)]
115pub struct StripeClient {
116    secret_key: String,
117    api_base: String,
118    api_version: Option<String>,
119    transport: Arc<dyn StripeTransport>,
120}
121
122impl StripeClient {
123    pub fn new(secret_key: impl Into<String>) -> Self {
124        Self {
125            secret_key: secret_key.into(),
126            api_base: "https://api.stripe.com".to_owned(),
127            api_version: None,
128            transport: Arc::new(ReqwestStripeTransport::new("https://api.stripe.com")),
129        }
130    }
131
132    pub fn with_transport(
133        secret_key: impl Into<String>,
134        transport: Arc<dyn StripeTransport>,
135    ) -> Self {
136        Self {
137            secret_key: secret_key.into(),
138            api_base: "https://api.stripe.com".to_owned(),
139            api_version: None,
140            transport,
141        }
142    }
143
144    pub fn with_api_base(mut self, api_base: impl Into<String>) -> Self {
145        self.api_base = api_base.into();
146        self.transport = Arc::new(ReqwestStripeTransport::new(self.api_base.clone()));
147        self
148    }
149
150    pub fn api_version(mut self, api_version: impl Into<String>) -> Self {
151        self.api_version = Some(api_version.into());
152        self
153    }
154
155    pub async fn create_customer(&self, params: Value) -> Result<Value, StripeApiError> {
156        self.post("/v1/customers", params).await
157    }
158
159    pub async fn update_customer(
160        &self,
161        customer_id: &str,
162        params: Value,
163    ) -> Result<Value, StripeApiError> {
164        self.post(&format!("/v1/customers/{customer_id}"), params)
165            .await
166    }
167
168    pub async fn retrieve_customer(&self, customer_id: &str) -> Result<Value, StripeApiError> {
169        self.get(&format!("/v1/customers/{customer_id}"), Value::Null)
170            .await
171    }
172
173    pub async fn search_customers(&self, query: &str) -> Result<Value, StripeApiError> {
174        self.search_customers_page(query, None).await
175    }
176
177    pub(crate) async fn search_customers_page(
178        &self,
179        query: &str,
180        page: Option<&str>,
181    ) -> Result<Value, StripeApiError> {
182        let mut params = serde_json::json!({
183            "query": query,
184            "limit": paginated_list::STRIPE_LIST_PAGE_LIMIT,
185        });
186        if let Some(page) = page {
187            if let Some(object) = params.as_object_mut() {
188                object.insert("page".to_owned(), serde_json::json!(page));
189            }
190        }
191        self.get("/v1/customers/search", params).await
192    }
193
194    pub async fn list_customers(&self, params: Value) -> Result<Value, StripeApiError> {
195        self.get("/v1/customers", params).await
196    }
197
198    pub async fn retrieve_price(&self, price_id: &str) -> Result<Value, StripeApiError> {
199        self.get(&format!("/v1/prices/{price_id}"), Value::Null)
200            .await
201    }
202
203    pub async fn list_prices(&self, params: Value) -> Result<Value, StripeApiError> {
204        self.get("/v1/prices", params).await
205    }
206
207    pub async fn price_by_lookup_key(&self, lookup_key: &str) -> Result<Value, StripeApiError> {
208        self.list_prices(serde_json::json!({
209            "lookup_keys": [lookup_key],
210            "active": true,
211            "limit": 1
212        }))
213        .await
214    }
215
216    pub async fn create_checkout_session(&self, params: Value) -> Result<Value, StripeApiError> {
217        self.post("/v1/checkout/sessions", params).await
218    }
219
220    pub async fn retrieve_checkout_session(
221        &self,
222        session_id: &str,
223    ) -> Result<Value, StripeApiError> {
224        self.get(&format!("/v1/checkout/sessions/{session_id}"), Value::Null)
225            .await
226    }
227
228    pub async fn create_billing_portal_session(
229        &self,
230        params: Value,
231    ) -> Result<Value, StripeApiError> {
232        self.post("/v1/billing_portal/sessions", params).await
233    }
234
235    pub async fn list_subscriptions(&self, params: Value) -> Result<Value, StripeApiError> {
236        self.get("/v1/subscriptions", params).await
237    }
238
239    pub async fn retrieve_subscription(
240        &self,
241        subscription_id: &str,
242    ) -> Result<Value, StripeApiError> {
243        self.get(&format!("/v1/subscriptions/{subscription_id}"), Value::Null)
244            .await
245    }
246
247    pub async fn update_subscription(
248        &self,
249        subscription_id: &str,
250        params: Value,
251    ) -> Result<Value, StripeApiError> {
252        self.post(&format!("/v1/subscriptions/{subscription_id}"), params)
253            .await
254    }
255
256    pub async fn create_subscription_schedule(
257        &self,
258        params: Value,
259    ) -> Result<Value, StripeApiError> {
260        self.post("/v1/subscription_schedules", params).await
261    }
262
263    pub async fn list_subscription_schedules(
264        &self,
265        params: Value,
266    ) -> Result<Value, StripeApiError> {
267        self.get("/v1/subscription_schedules", params).await
268    }
269
270    pub async fn retrieve_subscription_schedule(
271        &self,
272        schedule_id: &str,
273    ) -> Result<Value, StripeApiError> {
274        self.get(
275            &format!("/v1/subscription_schedules/{schedule_id}"),
276            Value::Null,
277        )
278        .await
279    }
280
281    pub async fn update_subscription_schedule(
282        &self,
283        schedule_id: &str,
284        params: Value,
285    ) -> Result<Value, StripeApiError> {
286        self.post(&format!("/v1/subscription_schedules/{schedule_id}"), params)
287            .await
288    }
289
290    pub async fn release_subscription_schedule(
291        &self,
292        schedule_id: &str,
293    ) -> Result<Value, StripeApiError> {
294        self.post(
295            &format!("/v1/subscription_schedules/{schedule_id}/release"),
296            Value::Object(Default::default()),
297        )
298        .await
299    }
300
301    async fn post(&self, path: &str, params: Value) -> Result<Value, StripeApiError> {
302        self.send("POST", path, params).await
303    }
304
305    async fn get(&self, path: &str, params: Value) -> Result<Value, StripeApiError> {
306        self.send("GET", path, params).await
307    }
308
309    async fn send(&self, method: &str, path: &str, params: Value) -> Result<Value, StripeApiError> {
310        let body = if params.is_null() {
311            String::new()
312        } else {
313            encode_form(&params)
314        };
315        let mut headers = BTreeMap::new();
316        headers.insert(
317            "Authorization".to_owned(),
318            format!("Bearer {}", self.secret_key),
319        );
320        headers.insert(
321            "Content-Type".to_owned(),
322            "application/x-www-form-urlencoded".to_owned(),
323        );
324        if let Some(api_version) = &self.api_version {
325            headers.insert("Stripe-Version".to_owned(), api_version.clone());
326        }
327        let request = StripeRequest {
328            method: method.to_owned(),
329            path: path.to_owned(),
330            headers,
331            body,
332        };
333        let response = self.transport.send(request).await?;
334        if (200..300).contains(&response.status) {
335            Ok(response.body)
336        } else {
337            Err(stripe_error_from_response(response))
338        }
339    }
340}
341
342pub struct ReqwestStripeTransport {
343    client: reqwest::Client,
344    api_base: String,
345}
346
/// Per-request timeout applied by [`ReqwestStripeTransport::new`] (30 seconds).
const DEFAULT_STRIPE_HTTP_TIMEOUT: Duration = Duration::from_millis(30_000);
348
349impl ReqwestStripeTransport {
350    pub fn new(api_base: impl Into<String>) -> Self {
351        Self::with_timeout(api_base, DEFAULT_STRIPE_HTTP_TIMEOUT)
352    }
353
354    pub fn with_timeout(api_base: impl Into<String>, timeout: Duration) -> Self {
355        let client = reqwest::Client::builder()
356            .timeout(timeout)
357            .build()
358            .unwrap_or_else(|_| reqwest::Client::new());
359        Self {
360            client,
361            api_base: api_base.into(),
362        }
363    }
364}
365
366impl StripeTransport for ReqwestStripeTransport {
367    fn send<'a>(&'a self, request: StripeRequest) -> StripeTransportFuture<'a> {
368        Box::pin(async move {
369            let method = request
370                .method
371                .parse::<Method>()
372                .map_err(|error| StripeApiError::Transport(error.to_string()))?;
373            let url = if request.method == "GET" && !request.body.is_empty() {
374                format!("{}{}?{}", self.api_base, request.path, request.body)
375            } else {
376                format!("{}{}", self.api_base, request.path)
377            };
378            let mut builder = self.client.request(method, url);
379            for (name, value) in request.headers {
380                builder = builder.header(name, value);
381            }
382            if request.method != "GET" {
383                builder = builder.body(request.body);
384            }
385            let response = builder
386                .send()
387                .await
388                .map_err(|error| StripeApiError::Transport(error.to_string()))?;
389            let status = response.status().as_u16();
390            let body = response
391                .json::<Value>()
392                .await
393                .map_err(|error| StripeApiError::Transport(error.to_string()))?;
394            Ok(StripeResponse { status, body })
395        })
396    }
397}
398
399pub fn encode_form(value: &Value) -> String {
400    let mut pairs = Vec::new();
401    collect_form_pairs(None, value, &mut pairs);
402    pairs
403        .into_iter()
404        .map(|(key, value)| format!("{}={}", form_encode(&key), form_encode(&value)))
405        .collect::<Vec<_>>()
406        .join("&")
407}
408
409fn collect_form_pairs(prefix: Option<String>, value: &Value, pairs: &mut Vec<(String, String)>) {
410    match value {
411        Value::Object(map) => {
412            for (key, value) in map {
413                let key = match &prefix {
414                    Some(prefix) => format!("{prefix}[{key}]"),
415                    None => key.clone(),
416                };
417                collect_form_pairs(Some(key), value, pairs);
418            }
419        }
420        Value::Array(values) => {
421            for (index, value) in values.iter().enumerate() {
422                if let Some(prefix) = &prefix {
423                    collect_form_pairs(Some(format!("{prefix}[{index}]")), value, pairs);
424                }
425            }
426        }
427        Value::String(value) => {
428            if let Some(prefix) = prefix {
429                pairs.push((prefix, value.clone()));
430            }
431        }
432        Value::Number(value) => {
433            if let Some(prefix) = prefix {
434                pairs.push((prefix, value.to_string()));
435            }
436        }
437        Value::Bool(value) => {
438            if let Some(prefix) = prefix {
439                pairs.push((prefix, value.to_string()));
440            }
441        }
442        Value::Null => {}
443    }
444}
445
446fn form_encode(value: &str) -> String {
447    url::form_urlencoded::byte_serialize(value.as_bytes()).collect()
448}
449
450pub fn verify_webhook_signature(
451    payload: &[u8],
452    signature_header: &str,
453    secret: &str,
454    tolerance_seconds: i64,
455    now_unix: i64,
456) -> Result<(), StripeApiError> {
457    let timestamp = signature_header
458        .split(',')
459        .find_map(|part| part.strip_prefix("t="))
460        .and_then(|value| value.parse::<i64>().ok())
461        .ok_or(StripeApiError::Webhook(
462            StripeErrorCode::FailedToConstructStripeEvent,
463        ))?;
464    if (now_unix - timestamp).abs() > tolerance_seconds {
465        return Err(StripeApiError::Webhook(
466            StripeErrorCode::FailedToConstructStripeEvent,
467        ));
468    }
469    let expected = webhook_signature(payload, secret, timestamp)?;
470    let verified = signature_header
471        .split(',')
472        .filter_map(|part| part.strip_prefix("v1="))
473        .filter_map(|signature| hex::decode(signature).ok())
474        .any(|candidate| candidate.ct_eq(expected.as_slice()).into());
475    if verified {
476        Ok(())
477    } else {
478        Err(StripeApiError::Webhook(
479            StripeErrorCode::FailedToConstructStripeEvent,
480        ))
481    }
482}
483
484fn webhook_signature(
485    payload: &[u8],
486    secret: &str,
487    timestamp: i64,
488) -> Result<Vec<u8>, StripeApiError> {
489    // Stripe signs with the endpoint secret used verbatim as the HMAC key,
490    // including the `whsec_` prefix. Do not strip or base64-decode it.
491    let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).map_err(|error| {
492        StripeApiError::Transport(format!("failed to initialize webhook verifier: {error}"))
493    })?;
494    mac.update(timestamp.to_string().as_bytes());
495    mac.update(b".");
496    mac.update(payload);
497    Ok(mac.finalize().into_bytes().to_vec())
498}
499
500fn stripe_error_from_response(response: StripeResponse) -> StripeApiError {
501    let error = response.body.get("error").unwrap_or(&response.body);
502    let code = error.get("code").and_then(Value::as_str).map(str::to_owned);
503    let message = error
504        .get("message")
505        .and_then(Value::as_str)
506        .unwrap_or("Stripe API request failed")
507        .to_owned();
508    StripeApiError::Stripe {
509        status: response.status,
510        code,
511        message,
512    }
513}