1use anyhow::anyhow;
4use axum::body::{self, BoxBody, Full};
5use axum::extract::State;
6use axum::http::request::Request;
7use axum::http::StatusCode;
8use axum::middleware::{self, Next};
9use axum::response::{IntoResponse, Response};
10use axum::routing::post;
11use axum::{Json, Router};
12use ring::hmac::{self, Key};
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use tower::ServiceBuilder;
16use tower_http::ServiceBuilderExt;
17
18use crate::{hex, Strike};
19
20#[derive(Debug, Clone)]
22pub struct WebhookState {
23 pub webhook_secret: String,
25 pub sender: tokio::sync::mpsc::Sender<String>,
27}
28
29#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
31#[serde(rename_all = "camelCase")]
32pub struct WebhookRequest {
33 pub webhook_url: String,
35 pub webhook_version: String,
37 pub secret: String,
39 pub enabled: bool,
41 pub event_types: Vec<String>,
43}
44
45#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
47#[serde(rename_all = "camelCase")]
48pub struct WebhookInfoResponse {
49 pub id: String,
51 pub webhook_url: String,
53 pub webhook_version: String,
55 pub enabled: bool,
57 pub event_types: Vec<String>,
59}
60
61impl Strike {
62 pub async fn create_invoice_webhook_router(
64 &self,
65 webhook_endpoint: &str,
66 sender: tokio::sync::mpsc::Sender<String>,
67 ) -> anyhow::Result<Router> {
68 let state = WebhookState {
69 sender,
70 webhook_secret: self.webhook_secret.clone(),
71 };
72
73 let router = Router::new()
74 .route(webhook_endpoint, post(handle_invoice))
75 .layer(ServiceBuilder::new().map_request_body(body::boxed).layer(
76 middleware::from_fn_with_state(state.clone(), verify_request_body),
77 ))
78 .with_state(state);
79
80 Ok(router)
81 }
82
83 pub async fn subscribe_to_invoice_webhook(&self, webhook_url: String) -> anyhow::Result<()> {
85 let url = self.base_url.join("/v1/subscriptions")?;
86
87 let subscription = WebhookRequest {
88 webhook_url,
89 webhook_version: "v1".to_string(),
90 secret: self.webhook_secret.clone(),
91 enabled: true,
92 event_types: vec!["invoice.updated".to_string()],
93 };
94
95 let res = self
96 .make_post(url, Some(serde_json::to_value(subscription)?))
97 .await?;
98
99 log::debug!("Created Webhook subscription: {}", res);
100
101 Ok(())
102 }
103
104 pub async fn get_current_subscriptions(&self) -> anyhow::Result<Vec<WebhookInfoResponse>> {
106 let url = self.base_url.join("/v1/subscriptions")?;
107
108 let res = self.make_get(url).await?;
109
110 let webhooks: Vec<WebhookInfoResponse> = serde_json::from_value(res)?;
111
112 Ok(webhooks)
113 }
114
115 pub async fn delete_subscription(&self, webhook_id: &str) -> anyhow::Result<()> {
117 let url = self
118 .base_url
119 .join(&format!("/v1/subscriptions/{}", webhook_id))?;
120
121 self.make_delete(url).await
122 }
123}
124
125async fn verify_request_body(
127 State(state): State<WebhookState>,
128 request: Request<BoxBody>,
129 next: Next<BoxBody>,
130) -> Result<impl IntoResponse, Response> {
131 let request = buffer_request_body(request, &state.webhook_secret).await?;
132
133 Ok(next.run(request).await)
134}
135
136async fn buffer_request_body(
139 request: Request<BoxBody>,
140 secret: &str,
141) -> Result<Request<BoxBody>, Response> {
142 let (parts, body) = request.into_parts();
143
144 let bytes = hyper::body::to_bytes(body)
145 .await
146 .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?;
147
148 let headers = parts.headers.clone();
149
150 let signature = headers
151 .get("X-Webhook-Signature")
152 .ok_or_else(|| {
153 log::warn!("Post to webhook did not include signature");
154 StatusCode::UNAUTHORIZED.into_response()
155 })?
156 .to_str()
157 .map_err(|_| {
158 log::warn!("Webhook signature is not a valid string");
159 StatusCode::UNAUTHORIZED.into_response()
160 })?;
161
162 verify_request_signature(signature, &bytes, secret.as_bytes())
163 .map_err(|_| StatusCode::UNAUTHORIZED)
164 .into_response();
165
166 Ok(Request::from_parts(parts, body::boxed(Full::from(bytes))))
167}
168
169#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
171#[serde(rename_all = "camelCase")]
172pub struct WebHookData {
173 entity_id: String,
175 changes: Vec<String>,
177}
178
179#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
181#[serde(rename_all = "camelCase")]
182struct WebHookResponse {
183 id: String,
185 event_type: String,
187 webhook_version: String,
189 data: WebHookData,
191 created: String,
193 delivery_success: Option<bool>,
195}
196fn compute_hmac(content: &[u8], key: &Key) -> Vec<u8> {
198 let tag = hmac::sign(key, content);
199 tag.as_ref().to_vec()
200}
201
202fn verify_request_signature(
204 request_signature: &str,
205 body: &[u8],
206 secret: &[u8],
207) -> anyhow::Result<()> {
208 let key = hmac::Key::new(hmac::HMAC_SHA256, secret);
209
210 let body = serde_json::from_slice(body)?;
211 let content_signature = compute_hmac(body, &key);
212
213 hmac::verify(&key, &hex::decode(request_signature)?, &content_signature).map_err(|_| {
214 log::warn!("Request did not have a valid signature");
215
216 anyhow!("Invalid signature")
217 })
218}
219
220async fn handle_invoice(
221 State(state): State<WebhookState>,
222 Json(payload): Json<Value>,
223) -> Result<StatusCode, StatusCode> {
224 let webhook_response: WebHookResponse = serde_json::from_value(payload).map_err(|_err| {
225 log::warn!("Got an invalid payload on webhook");
226
227 StatusCode::UNPROCESSABLE_ENTITY
228 })?;
229
230 log::debug!(
231 "Received webhook update for: {}",
232 webhook_response.data.entity_id
233 );
234
235 if let Err(err) = state.sender.send(webhook_response.data.entity_id).await {
236 log::warn!("Could not send on channel: {}", err);
237 }
238 Ok(StatusCode::OK)
239}