1use axum::{
2 Json,
3 extract::Request,
4 middleware::Next,
5 response::{IntoResponse, Response},
6};
7
8use crate::{
9 concepts::Facilitator,
10 seller::toolkit::{
11 extract_payment_payload, select_payment_with_payload, settle_payment,
12 update_supported_kinds, verify_payment,
13 },
14 transport::{
15 Base64EncodedHeader, FacilitatorSettleSuccess, FacilitatorVerifyValid, PaymentRequirements,
16 PaymentResponse,
17 },
18};
19
20#[derive(Debug, Clone)]
21pub struct PaymentErrorResponse(pub crate::seller::toolkit::ErrorResponse);
22
23impl From<crate::seller::toolkit::ErrorResponse> for PaymentErrorResponse {
24 fn from(err: crate::seller::toolkit::ErrorResponse) -> Self {
25 PaymentErrorResponse(err)
26 }
27}
28
29impl IntoResponse for PaymentErrorResponse {
30 fn into_response(self) -> Response {
31 (
32 self.0.status,
33 Json(self.0.into_payment_requirements_response()),
34 )
35 .into_response()
36 }
37}
38
39#[derive(Debug)]
40pub struct PaymentSuccessResponse {
41 pub response: Response,
42 pub payment_response: PaymentResponse,
43}
44
45impl IntoResponse for PaymentSuccessResponse {
46 fn into_response(self) -> Response {
47 let PaymentSuccessResponse {
48 mut response,
49 payment_response,
50 } = self;
51
52 if let Some(header) = Base64EncodedHeader::try_from(payment_response)
53 .ok()
54 .and_then(|h| h.0.parse().ok())
55 {
56 response.headers_mut().insert("X-Payment-Response", header);
57 }
58
59 response
60 }
61}
62
63pub struct PaymentHandler<F: Facilitator> {
64 pub facilitator: F,
65 pub payment_requirements: Vec<PaymentRequirements>,
66}
67
68#[derive(Debug, Clone)]
70pub enum PaymentProcessingState {
71 Verified(FacilitatorVerifyValid),
72 NotVerified,
73 Settled(FacilitatorSettleSuccess),
74}
75
76#[bon::bon]
77impl<F: Facilitator> PaymentHandler<F> {
78 pub fn builder(facilitator: F) -> PaymentHandlerBuilder<F> {
79 PaymentHandlerBuilder {
80 facilitator,
81 payment_requirements: Vec::new(),
82 }
83 }
84
85 #[builder]
86 pub async fn handle_payment(
87 self,
88 #[builder(with = || ())] no_update_supported: Option<()>,
89 #[builder(with = || ())] no_verify: Option<()>,
90 #[builder(with = || ())] settle_after_next: Option<()>,
91 mut req: Request,
92 next: Next,
93 ) -> Result<PaymentSuccessResponse, PaymentErrorResponse> {
94 let payment_requirements = if no_update_supported.is_none() {
95 update_supported_kinds(&self.facilitator, self.payment_requirements).await?
97 } else {
98 self.payment_requirements
99 };
100
101 let x_payment_header = extract_payment_payload(req.headers(), &payment_requirements)?;
102 let selected = select_payment_with_payload(&payment_requirements, &x_payment_header)?;
103
104 let verify = if no_verify.is_none() {
105 let valid = verify_payment(
107 &self.facilitator,
108 &x_payment_header,
109 &selected,
110 &payment_requirements,
111 )
112 .await?;
113
114 #[cfg(feature = "tracing")]
115 tracing::debug!("Payment verified: payer='{}'", valid.payer);
116
117 Some(valid)
118 } else {
119 None
120 };
121
122 if settle_after_next.is_none() {
123 let settled = settle_payment(
125 &self.facilitator,
126 &x_payment_header,
127 &selected,
128 &payment_requirements,
129 )
130 .await?;
131
132 #[cfg(feature = "tracing")]
133 tracing::debug!(
134 "Payment settled: payer='{}', transaction='{}', network='{}'",
135 settled.payer,
136 settled.transaction,
137 settled.network
138 );
139
140 let extension = PaymentProcessingState::Settled(settled.clone());
141 req.extensions_mut().insert(extension.clone());
142
143 #[cfg(feature = "tracing")]
144 tracing::debug!("Calling next handler with extension {:?}", extension);
145 let response = next.run(req).await;
146
147 Ok(PaymentSuccessResponse {
148 response,
149 payment_response: settled.into(),
150 })
151 } else {
152 let extension = verify
154 .map(PaymentProcessingState::Verified)
155 .unwrap_or(PaymentProcessingState::NotVerified);
156
157 req.extensions_mut().insert(extension.clone());
158
159 #[cfg(feature = "tracing")]
160 tracing::debug!("Calling next handler with extension {:?}", extension);
161 let response = next.run(req).await;
162
163 let settled = settle_payment(
164 &self.facilitator,
165 &x_payment_header,
166 &selected,
167 &payment_requirements,
168 )
169 .await?;
170
171 #[cfg(feature = "tracing")]
172 tracing::debug!(
173 "Payment settled: payer='{}', transaction='{}', network='{}'",
174 settled.payer,
175 settled.transaction,
176 settled.network
177 );
178
179 Ok(PaymentSuccessResponse {
180 response,
181 payment_response: settled.into(),
182 })
183 }
184 }
185}
186
187pub struct PaymentHandlerBuilder<F: Facilitator> {
188 pub facilitator: F,
189 pub payment_requirements: Vec<PaymentRequirements>,
190}
191
192impl<F: Facilitator> PaymentHandlerBuilder<F> {
193 pub fn add_payment(mut self, payment_requirements: impl Into<PaymentRequirements>) -> Self {
194 self.payment_requirements.push(payment_requirements.into());
195 self
196 }
197
198 pub fn build(self) -> PaymentHandler<F> {
199 PaymentHandler {
200 facilitator: self.facilitator,
201 payment_requirements: self.payment_requirements,
202 }
203 }
204}
205
#[cfg(test)]
mod tests {
    use alloy_primitives::address;
    use axum::middleware::from_fn;
    use tower::ServiceBuilder;
    use url_macro::url;

    use crate::{
        config::Resource, facilitator_client::RemoteFacilitatorClient,
        networks::evm::assets::UsdcBase, schemes::exact_evm::ExactEvm,
    };

    use super::*;

    /// Sample middleware with every optional `handle_payment` flag enabled.
    /// It is only type-checked by the test below, never executed.
    async fn middleware_fn(
        req: Request,
        next: Next,
    ) -> Result<PaymentSuccessResponse, PaymentErrorResponse> {
        PaymentHandler::builder(RemoteFacilitatorClient::from_url(url!(
            "https://facilitator.example.com"
        )))
        .add_payment(
            ExactEvm::builder()
                .asset(UsdcBase)
                // Grouped in thousands; same value as the former `1000_000`.
                .amount(1_000_000)
                .pay_to(address!("0x17d2e11d0405fa8d0ad2dca6409c499c0132c017"))
                .resource(
                    Resource::builder()
                        .url(url!("https://my-site.com/api"))
                        .description("")
                        .mime_type("")
                        .build(),
                )
                .build(),
        )
        .build()
        .handle_payment()
        .no_verify()
        .no_update_supported()
        .settle_after_next()
        .req(req)
        .next(next)
        .call()
        .await
    }

    /// Compile-time check: the handler's builder call chain has a signature
    /// that `axum::middleware::from_fn` accepts as a layer.
    #[test]
    fn test_build_axum_middleware() {
        let _ = ServiceBuilder::new().layer(from_fn::<_, (Request,)>(middleware_fn));
    }
}