// strike_rs/webhooks.rs

//! Strike webhooks: subscription management and an axum router for receiving invoice events.
2
3use anyhow::anyhow;
4use axum::body::{self, BoxBody, Full};
5use axum::extract::State;
6use axum::http::request::Request;
7use axum::http::StatusCode;
8use axum::middleware::{self, Next};
9use axum::response::{IntoResponse, Response};
10use axum::routing::post;
11use axum::{Json, Router};
12use ring::hmac::{self, Key};
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use tower::ServiceBuilder;
16use tower_http::ServiceBuilderExt;
17
18use crate::{hex, Strike};
19
/// Shared state given to the invoice webhook router, its verification
/// middleware and its handler.
#[derive(Debug, Clone)]
pub struct WebhookState {
    /// Secret used as the HMAC-SHA256 key when checking the
    /// `X-Webhook-Signature` header of incoming webhook requests.
    pub webhook_secret: String,
    /// Channel on which the invoice handler forwards the entity id of each
    /// received webhook event.
    pub sender: tokio::sync::mpsc::Sender<String>,
}
28
/// Request body used to create a webhook subscription (`POST /v1/subscriptions`).
///
/// Fields are serialized in camelCase (e.g. `webhookUrl`, `eventTypes`).
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct WebhookRequest {
    /// URL to register for this subscription.
    pub webhook_url: String,
    /// Webhook payload version; this module sends `"v1"`.
    pub webhook_version: String,
    /// Shared secret registered with the subscription. This module sends the
    /// client's `webhook_secret`, which is also the key used to verify
    /// incoming signatures.
    pub secret: String,
    /// Whether the subscription is enabled; this module always sends `true`.
    pub enabled: bool,
    /// Event types to subscribe to, e.g. `"invoice.updated"`.
    pub event_types: Vec<String>,
}
44
/// A webhook subscription as listed by `GET /v1/subscriptions`.
///
/// Deserialized from camelCase JSON. Unlike [`WebhookRequest`], it carries
/// the subscription `id` and no secret.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct WebhookInfoResponse {
    /// Subscription id (the value `delete_subscription` expects).
    pub id: String,
    /// URL registered for the subscription.
    pub webhook_url: String,
    /// Webhook payload version.
    pub webhook_version: String,
    /// Whether the subscription is enabled.
    pub enabled: bool,
    /// Event types the subscription covers.
    pub event_types: Vec<String>,
}
60
61impl Strike {
62    /// Create invoice webhook
63    pub async fn create_invoice_webhook_router(
64        &self,
65        webhook_endpoint: &str,
66        sender: tokio::sync::mpsc::Sender<String>,
67    ) -> anyhow::Result<Router> {
68        let state = WebhookState {
69            sender,
70            webhook_secret: self.webhook_secret.clone(),
71        };
72
73        let router = Router::new()
74            .route(webhook_endpoint, post(handle_invoice))
75            .layer(ServiceBuilder::new().map_request_body(body::boxed).layer(
76                middleware::from_fn_with_state(state.clone(), verify_request_body),
77            ))
78            .with_state(state);
79
80        Ok(router)
81    }
82
83    /// Subscribe to invoice webhook
84    pub async fn subscribe_to_invoice_webhook(&self, webhook_url: String) -> anyhow::Result<()> {
85        let url = self.base_url.join("/v1/subscriptions")?;
86
87        let subscription = WebhookRequest {
88            webhook_url,
89            webhook_version: "v1".to_string(),
90            secret: self.webhook_secret.clone(),
91            enabled: true,
92            event_types: vec!["invoice.updated".to_string()],
93        };
94
95        let res = self
96            .make_post(url, Some(serde_json::to_value(subscription)?))
97            .await?;
98
99        log::debug!("Created Webhook subscription: {}", res);
100
101        Ok(())
102    }
103
104    /// Get current subscriptions
105    pub async fn get_current_subscriptions(&self) -> anyhow::Result<Vec<WebhookInfoResponse>> {
106        let url = self.base_url.join("/v1/subscriptions")?;
107
108        let res = self.make_get(url).await?;
109
110        let webhooks: Vec<WebhookInfoResponse> = serde_json::from_value(res)?;
111
112        Ok(webhooks)
113    }
114
115    /// Delete subscription
116    pub async fn delete_subscription(&self, webhook_id: &str) -> anyhow::Result<()> {
117        let url = self
118            .base_url
119            .join(&format!("/v1/subscriptions/{}", webhook_id))?;
120
121        self.make_delete(url).await
122    }
123}
124
125// middleware to consume the request body upfront
126async fn verify_request_body(
127    State(state): State<WebhookState>,
128    request: Request<BoxBody>,
129    next: Next<BoxBody>,
130) -> Result<impl IntoResponse, Response> {
131    let request = buffer_request_body(request, &state.webhook_secret).await?;
132
133    Ok(next.run(request).await)
134}
135
136// take the request apart, buffer the body,
137// veridy signature, then put the request back together
138async fn buffer_request_body(
139    request: Request<BoxBody>,
140    secret: &str,
141) -> Result<Request<BoxBody>, Response> {
142    let (parts, body) = request.into_parts();
143
144    let bytes = hyper::body::to_bytes(body)
145        .await
146        .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?;
147
148    let headers = parts.headers.clone();
149
150    let signature = headers
151        .get("X-Webhook-Signature")
152        .ok_or_else(|| {
153            log::warn!("Post to webhook did not include signature");
154            StatusCode::UNAUTHORIZED.into_response()
155        })?
156        .to_str()
157        .map_err(|_| {
158            log::warn!("Webhook signature is not a valid string");
159            StatusCode::UNAUTHORIZED.into_response()
160        })?;
161
162    verify_request_signature(signature, &bytes, secret.as_bytes())
163        .map_err(|_| StatusCode::UNAUTHORIZED)
164        .into_response();
165
166    Ok(Request::from_parts(parts, body::boxed(Full::from(bytes))))
167}
168
169/// Webhook data
170#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
171#[serde(rename_all = "camelCase")]
172pub struct WebHookData {
173    /// Entity Id
174    entity_id: String,
175    /// Changes
176    changes: Vec<String>,
177}
178
179/// Webhook Response
180#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
181#[serde(rename_all = "camelCase")]
182struct WebHookResponse {
183    /// Webhook id
184    id: String,
185    /// Event type
186    event_type: String,
187    /// Webhook version
188    webhook_version: String,
189    /// Webhook data
190    data: WebHookData,
191    /// Created
192    created: String,
193    /// Delivery Success
194    delivery_success: Option<bool>,
195}
196// Function to compute HMAC SHA-256
197fn compute_hmac(content: &[u8], key: &Key) -> Vec<u8> {
198    let tag = hmac::sign(key, content);
199    tag.as_ref().to_vec()
200}
201
202// Function to verify request signature
203fn verify_request_signature(
204    request_signature: &str,
205    body: &[u8],
206    secret: &[u8],
207) -> anyhow::Result<()> {
208    let key = hmac::Key::new(hmac::HMAC_SHA256, secret);
209
210    let body = serde_json::from_slice(body)?;
211    let content_signature = compute_hmac(body, &key);
212
213    hmac::verify(&key, &hex::decode(request_signature)?, &content_signature).map_err(|_| {
214        log::warn!("Request did not have a valid signature");
215
216        anyhow!("Invalid signature")
217    })
218}
219
220async fn handle_invoice(
221    State(state): State<WebhookState>,
222    Json(payload): Json<Value>,
223) -> Result<StatusCode, StatusCode> {
224    let webhook_response: WebHookResponse = serde_json::from_value(payload).map_err(|_err| {
225        log::warn!("Got an invalid payload on webhook");
226
227        StatusCode::UNPROCESSABLE_ENTITY
228    })?;
229
230    log::debug!(
231        "Received webhook update for: {}",
232        webhook_response.data.entity_id
233    );
234
235    if let Err(err) = state.sender.send(webhook_response.data.entity_id).await {
236        log::warn!("Could not send on channel: {}", err);
237    }
238    Ok(StatusCode::OK)
239}