1use axum::body::{Body, to_bytes};
5use axum::extract::Request;
6use axum::http::{HeaderMap, StatusCode};
7use axum::response::{IntoResponse, Response};
8use base64::{Engine, engine::general_purpose::STANDARD};
9use md5::{Digest, Md5};
10use rsa::pkcs8::DecodePublicKey;
11use rsa::signature::hazmat::PrehashVerifier;
12use rsa::{RsaPublicKey, pkcs1v15};
13use std::collections::HashMap;
14use std::pin::Pin;
15use std::sync::{Arc, RwLock};
16use std::task::{Context, Poll};
17use tower::{Layer, Service};
18
19#[derive(Debug, thiserror::Error)]
20enum OssVerifyError<'a> {
21 #[error("missing required header `{0}`")]
22 MissingHeader(&'a str),
23
24 #[error("invalid header `{0}`")]
25 InvalidHeader(&'a str),
26
27 #[error("invalid oss callback signature")]
28 InvalidSignature,
29
30 #[error("failed to read request body: {0}")]
31 BodyRead(#[from] axum::Error),
32
33 #[error("http error when verifying oss public key: {0}")]
34 Http(#[from] reqwest::Error),
35
36 #[error("base64 decode error: {0}")]
37 Base64(#[from] base64::DecodeError),
38
39 #[error("utf-8 error: {0}")]
40 Utf8(#[from] std::string::FromUtf8Error),
41
42 #[error("error: {0}")]
43 Common(&'a str),
44}
45
46impl IntoResponse for OssVerifyError<'_> {
47 fn into_response(self) -> Response {
48 match self {
49 OssVerifyError::MissingHeader(name) => (
50 StatusCode::BAD_REQUEST,
51 format!("missing required header `{name}`"),
52 )
53 .into_response(),
54
55 OssVerifyError::InvalidHeader(name) => {
56 (StatusCode::BAD_REQUEST, format!("invalid header `{name}`")).into_response()
57 }
58
59 OssVerifyError::InvalidSignature => {
60 (StatusCode::BAD_REQUEST, "invalid oss callback signature").into_response()
61 }
62
63 OssVerifyError::BodyRead(e) => (
64 StatusCode::BAD_REQUEST,
65 format!("failed to read request body: {e}"),
66 )
67 .into_response(),
68
69 OssVerifyError::Http(e) => (
70 StatusCode::BAD_GATEWAY,
71 format!("http error when verifying oss public key: {e}"),
72 )
73 .into_response(),
74
75 OssVerifyError::Base64(e) => {
76 (StatusCode::BAD_REQUEST, format!("base64 decode error: {e}")).into_response()
77 }
78
79 OssVerifyError::Utf8(e) => {
80 (StatusCode::BAD_REQUEST, format!("utf-8 error: {e}")).into_response()
81 }
82
83 OssVerifyError::Common(msg) => {
84 (StatusCode::BAD_REQUEST, format!("error: {msg}")).into_response()
85 }
86 }
87 }
88}
89
/// Text of an OSS callback body whose signature has already been verified.
///
/// The verification middleware inserts it into the request extensions, so
/// downstream handlers can read the body text directly.
#[derive(Clone, Debug)]
pub struct VerifiedOssCallbackBody(pub String);
166
167#[derive(Clone)]
170pub struct OssCallbackVerifyLayer {
171 client: reqwest::Client,
172 cache: Arc<RwLock<HashMap<String, Vec<u8>>>>,
174 callback_path: String,
176}
177
178impl OssCallbackVerifyLayer {
179 pub fn new(callback_url_path: &str) -> Self {
188 Self {
189 client: reqwest::Client::new(),
190 cache: Arc::new(RwLock::new(HashMap::new())),
192 callback_path: callback_url_path.to_owned(),
193 }
194 }
195
196 }
204
205impl<S> Layer<S> for OssCallbackVerifyLayer {
206 type Service = OssCallbackVerifyService<S>;
207
208 fn layer(&self, inner: S) -> Self::Service {
209 OssCallbackVerifyService {
210 inner,
211 client: self.client.clone(),
212 cache: Arc::clone(&self.cache),
213 callback_path: self.callback_path.clone(),
214 }
215 }
216}
217
218#[derive(Clone)]
220pub struct OssCallbackVerifyService<S> {
221 inner: S,
222 client: reqwest::Client,
223 cache: Arc<RwLock<HashMap<String, Vec<u8>>>>,
224 callback_path: String,
225}
226
227impl<S> Service<Request<Body>> for OssCallbackVerifyService<S>
229where
230 S: Service<Request<Body>, Response = Response> + Clone + Send + 'static,
234 S::Error: Into<axum::BoxError>,
237 S::Future: Send + 'static,
240{
241 type Response = S::Response;
242 type Error = S::Error;
243 type Future =
245 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
246
247 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
248 self.inner.poll_ready(cx)
249 }
250
251 fn call(&mut self, req: Request<Body>) -> Self::Future {
252 let clone = self.inner.clone();
254 let mut inner = std::mem::replace(&mut self.inner, clone);
255 let client = self.client.clone();
256 let cache = Arc::clone(&self.cache);
257 let callback_path = self.callback_path.clone();
258
259 Box::pin(async move {
260 match verify_oss_request(req, &client, &cache, callback_path).await {
262 Ok(verified_req) => inner.call(verified_req).await,
264 Err(resp) => Ok(resp.into_response()),
265 }
266 })
267 }
268}
269
270async fn verify_oss_request<'a>(
274 req: Request<Body>,
275 client: &reqwest::Client,
276 cache: &Arc<RwLock<HashMap<String, Vec<u8>>>>,
277 callback_path: String,
278) -> Result<Request<Body>, OssVerifyError<'a>> {
279 let (parts, body) = req.into_parts();
280 let headers = parts.headers.clone();
281 let uri = parts.uri.clone();
282
283 let body_bytes = to_bytes(body, 5 * 1024 * 1024).await?;
285 let body_str = String::from_utf8(body_bytes.to_vec())?;
286
287 let pub_key_url_b64 = header_required(&headers, "x-oss-pub-key-url")?;
289 let pub_key_url_raw = STANDARD.decode(pub_key_url_b64.as_bytes())?;
290 let pub_key_url = String::from_utf8(pub_key_url_raw)?;
291
292 if !pub_key_url.starts_with("http://gosspublic.alicdn.com/")
293 && !pub_key_url.starts_with("https://gosspublic.alicdn.com/")
294 {
295 return Err(OssVerifyError::Common("invalid oss public key url"));
296 }
297
298 let pub_key_pem = get_or_fetch_pub_key(&pub_key_url, client, cache).await?;
300 let pub_key_pem_str = String::from_utf8(pub_key_pem)?;
301
302 let auth_b64 = header_required(&headers, "authorization")?;
304 let auth_bytes = STANDARD.decode(auth_b64.as_bytes())?;
305
306 let decoded_path = callback_path;
315
316 let auth_path = match uri.query() {
317 Some(q) => format!("{}?{}", decoded_path, q),
318 None => decoded_path,
319 };
320
321 let auth_str = format!("{}\n{}", auth_path, body_str);
322
323 let mut hasher = Md5::new();
325 hasher.update(auth_str.as_bytes());
326 let digest = hasher.finalize();
327
328 let rsa_pub_key = RsaPublicKey::from_public_key_pem(&pub_key_pem_str)
330 .map_err(|_| OssVerifyError::Common("failed to parse oss public key pem"))?;
331
332 let verifying_key = pkcs1v15::VerifyingKey::<Md5>::new(rsa_pub_key);
333 let signature = pkcs1v15::Signature::try_from(auth_bytes.as_slice())
334 .map_err(|_| OssVerifyError::Common("failed to parse oss signature"))?;
335
336 verifying_key
337 .verify_prehash(&digest, &signature)
338 .map_err(|_| OssVerifyError::InvalidSignature)?;
339
340 let mut new_req = Request::from_parts(parts, Body::from(body_bytes));
342 new_req
343 .extensions_mut()
344 .insert(VerifiedOssCallbackBody(body_str));
345
346 Ok(new_req)
347}
348
349fn header_required<'a>(headers: &HeaderMap, name: &'a str) -> Result<String, OssVerifyError<'a>> {
351 let value = headers
352 .get(name)
353 .ok_or(OssVerifyError::MissingHeader(name))?;
354
355 let s = value
356 .to_str()
357 .map_err(|_| OssVerifyError::InvalidHeader(name))?;
358 Ok(s.to_owned())
359}
360
361async fn get_or_fetch_pub_key<'a>(
363 url: &str,
364 client: &reqwest::Client,
365 cache: &Arc<RwLock<HashMap<String, Vec<u8>>>>,
366) -> Result<Vec<u8>, OssVerifyError<'a>> {
367 {
371 let cache_read = cache.read().unwrap();
372 if let Some(v) = cache_read.get(url) {
373 return Ok(v.clone());
374 }
375 }
376
377 let resp = client.get(url).send().await?;
379 let bytes = resp.bytes().await?;
380
381 {
383 let mut cache_write = cache.write().unwrap();
384 if let Some(v) = cache_write.get(url) {
385 return Ok(v.clone());
387 } else {
388 cache_write.insert(url.to_string(), bytes.to_vec());
389 }
390 }
391
392 Ok(bytes.to_vec())
393}