Skip to main content

u_sdk/
oss_callback_verify_layer.rs

1//! OSS callback验证中间件 Layer
2//!
3//! 只适用于 tokio + axum 环境
4use axum::body::{Body, to_bytes};
5use axum::extract::Request;
6use axum::http::{HeaderMap, StatusCode};
7use axum::response::{IntoResponse, Response};
8use base64::{Engine, engine::general_purpose::STANDARD};
9use md5::{Digest, Md5};
10use rsa::pkcs8::DecodePublicKey;
11use rsa::signature::hazmat::PrehashVerifier;
12use rsa::{RsaPublicKey, pkcs1v15};
13use std::collections::HashMap;
14use std::pin::Pin;
15use std::sync::{Arc, RwLock};
16use std::task::{Context, Poll};
17use tower::{Layer, Service};
18
/// Reasons an OSS callback request can fail verification.
///
/// The lifetime only covers the borrowed `&str` payloads (header names and static messages).
/// Each variant's `#[error]` text is also what the client receives in the response body.
#[derive(Debug, thiserror::Error)]
enum OssVerifyError<'a> {
    /// A required request header is absent; holds the header name.
    #[error("missing required header `{0}`")]
    MissingHeader(&'a str),

    /// A required header exists but its value is not a valid visible-ASCII string
    /// (`HeaderValue::to_str` failed); holds the header name.
    #[error("invalid header `{0}`")]
    InvalidHeader(&'a str),

    /// The RSA signature did not verify against the digest of the string to sign.
    #[error("invalid oss callback signature")]
    InvalidSignature,

    /// The request body could not be buffered (e.g. an I/O error or the body exceeding the
    /// read limit).
    #[error("failed to read request body: {0}")]
    BodyRead(#[from] axum::Error),

    /// Fetching the OSS public key over HTTP failed.
    #[error("http error when verifying oss public key: {0}")]
    Http(#[from] reqwest::Error),

    /// A header value that must be base64 (public key URL, signature) failed to decode.
    #[error("base64 decode error: {0}")]
    Base64(#[from] base64::DecodeError),

    /// Bytes that must be UTF-8 (body, decoded public key URL, key PEM) were not.
    #[error("utf-8 error: {0}")]
    Utf8(#[from] std::string::FromUtf8Error),

    /// Any other failure, carrying a static description.
    #[error("error: {0}")]
    Common(&'a str),
}
45
46impl IntoResponse for OssVerifyError<'_> {
47    fn into_response(self) -> Response {
48        match self {
49            OssVerifyError::MissingHeader(name) => (
50                StatusCode::BAD_REQUEST,
51                format!("missing required header `{name}`"),
52            )
53                .into_response(),
54
55            OssVerifyError::InvalidHeader(name) => {
56                (StatusCode::BAD_REQUEST, format!("invalid header `{name}`")).into_response()
57            }
58
59            OssVerifyError::InvalidSignature => {
60                (StatusCode::BAD_REQUEST, "invalid oss callback signature").into_response()
61            }
62
63            OssVerifyError::BodyRead(e) => (
64                StatusCode::BAD_REQUEST,
65                format!("failed to read request body: {e}"),
66            )
67                .into_response(),
68
69            OssVerifyError::Http(e) => (
70                StatusCode::BAD_GATEWAY,
71                format!("http error when verifying oss public key: {e}"),
72            )
73                .into_response(),
74
75            OssVerifyError::Base64(e) => {
76                (StatusCode::BAD_REQUEST, format!("base64 decode error: {e}")).into_response()
77            }
78
79            OssVerifyError::Utf8(e) => {
80                (StatusCode::BAD_REQUEST, format!("utf-8 error: {e}")).into_response()
81            }
82
83            OssVerifyError::Common(msg) => {
84                (StatusCode::BAD_REQUEST, format!("error: {msg}")).into_response()
85            }
86        }
87    }
88}
89
/// Body data sent by OSS, available after the callback has been verified.
///
/// After successful verification this holds the body OSS sent. It is
/// `application/json` or `application/x-www-form-urlencoded`, depending on the callback
/// settings used when the upload API was called.
///
/// In axum this is convenient to wrap in an extractor so handlers can use it directly.
/// Example of an extractor that parses JSON:
/// ```rust,no_run
/// use axum::{
///     extract::FromRequestParts,
///     http::{StatusCode, request::Parts},
///     response::{IntoResponse, Response},
/// };
/// use serde::de::DeserializeOwned;
///
/// // Defined here only so this example compiles on its own. In real code, import this
/// // crate's `VerifiedOssCallbackBody` instead: request extensions are looked up by type,
/// // so a separately defined struct with the same name is never found.
/// #[derive(Debug, Clone)]
/// pub struct VerifiedOssCallbackBody(pub String);
///
/// #[derive(Debug)]
/// pub struct VerifiedOssJson<T>(pub T);
///
/// impl<S, T> FromRequestParts<S> for VerifiedOssJson<T>
/// where
///     S: Send + Sync,
///     T: DeserializeOwned,
/// {
///     type Rejection = Response;
///
///     async fn from_request_parts(
///         parts: &mut Parts,
///         _state: &S,
///     ) -> Result<Self, Self::Rejection> {
///         // 1. Fetch the VerifiedOssCallbackBody that the middleware put into the extensions.
///         //    It is missing when the verification layer was not applied to this route.
///         let ext = parts
///             .extensions
///             .get::<VerifiedOssCallbackBody>()
///             .ok_or_else(|| {
///                 (
///                     StatusCode::INTERNAL_SERVER_ERROR,
///                     "VerifiedOssCallbackBody missing",
///                 )
///                     .into_response()
///             })?;
///
///         // 2. Parse the inner String as JSON into T.
///         let value = serde_json::from_str::<T>(&ext.0).map_err(|e| {
///             (
///                 StatusCode::BAD_REQUEST,
///                 format!("invalid oss callback json: {e}"),
///             )
///                 .into_response()
///         })?;
///
///         Ok(VerifiedOssJson(value))
///     }
/// }
///
/// // Handlers can then receive the parsed data directly via VerifiedOssJson<T>:
/// use serde::Deserialize;
///
/// #[derive(Deserialize, Debug)]
/// struct OssCallbackPayload {
///     // Fields depend on your own callback body.
///     pub user_id: String,
///     pub filename: String,
///     pub size: u64,
/// }
///
/// // Using the VerifiedOssJson extractor:
/// async fn oss_callback(
///     VerifiedOssJson(payload): VerifiedOssJson<OssCallbackPayload>,
/// ) -> impl IntoResponse {
///     dbg!(&payload);
///     "ok"
/// }
/// ```
#[derive(Debug, Clone)]
pub struct VerifiedOssCallbackBody(pub String);
166
/// Tower layer that verifies Alibaba Cloud OSS upload callbacks. Supports tokio + axum only.
///
/// On successful verification the callback body is stored as a `String` in the request
/// extensions as [VerifiedOssCallbackBody]. On failure the request is answered directly with
/// an error response and the inner service is not called.
#[derive(Clone)]
pub struct OssCallbackVerifyLayer {
    client: reqwest::Client,
    // Wrapped in Arc so every Service produced by this layer shares one public-key cache.
    cache: Arc<RwLock<HashMap<String, Vec<u8>>>>,
    // TODO: the current nginx/axum setup strips the URL path, so for now callers supply the
    // callback path through this field.
    callback_path: String,
}
177
178impl OssCallbackVerifyLayer {
179    /// callback_url_path: 在设置`callbackUrl`时的路径部分
180    ///
181    /// 因为如果应用部署在代理如nginx后面,nginx可能会配置把路径前缀剥离掉;
182    /// 或者axum如果是嵌套路由,那么在layer的service里看到的uri.path()也不是完整路径,
183    /// 此时需要使用(如果路径没有被nginx等剥离)[axum::extract::OriginalUri]来获取完整路径。
184    /// 为了简化起见,这里直接让用户传入callbackUrl里的路径部分
185    ///
186    /// 注意:传入的是没有经过url encode的原始路径
187    pub fn new(callback_url_path: &str) -> Self {
188        Self {
189            client: reqwest::Client::new(),
190            // 也可以直接用自己的 HashMap,这里示范复用全局缓存
191            cache: Arc::new(RwLock::new(HashMap::new())),
192            callback_path: callback_url_path.to_owned(),
193        }
194    }
195
196    // 如果你想让多个 Layer / 多 crate 共享同一份缓存,可以用这个构造
197    // pub fn with_global_cache() -> Self {
198    //     Self {
199    //         client: reqwest::Client::new(),
200    //         cache: Arc::new(RwLock::new(GLOBAL_PUB_KEY_CACHE.read().unwrap().clone())),
201    //     }
202    // }
203}
204
205impl<S> Layer<S> for OssCallbackVerifyLayer {
206    type Service = OssCallbackVerifyService<S>;
207
208    fn layer(&self, inner: S) -> Self::Service {
209        OssCallbackVerifyService {
210            inner,
211            client: self.client.clone(),
212            cache: Arc::clone(&self.cache),
213            callback_path: self.callback_path.clone(),
214        }
215    }
216}
217
/// Service produced by [`OssCallbackVerifyLayer`]. Verifies each OSS callback request
/// before forwarding it to the inner service.
#[derive(Clone)]
pub struct OssCallbackVerifyService<S> {
    inner: S,
    client: reqwest::Client,
    // Public-key cache (public-key URL -> PEM bytes), shared with the layer that built this service.
    cache: Arc<RwLock<HashMap<String, Vec<u8>>>>,
    // Path component of `callbackUrl`, used instead of the request URI's path when building the
    // string to sign.
    callback_path: String,
}
226
227// 构建Service的过程tower有一个guide: https://github.com/tower-rs/tower/blob/master/guides/building-a-middleware-from-scratch.md
228impl<S> Service<Request<Body>> for OssCallbackVerifyService<S>
229where
230    // 这里要求S: Clone是因为我们在call里需要clone它
231    // S 必须是一个处理 HTTP 请求的 Service,返回的是 axum 的 Response,
232    // 这样这个中间件才能挂在 axum 的 Router / 其它 HTTP 中间件前后
233    S: Service<Request<Body>, Response = Response> + Clone + Send + 'static,
234    // 不强制 S::Error 是具体什么类型,但要求它能转换成 axum::BoxError,
235    // 方便和 axum/tower 生态里那些统一用 BoxError 的通用组件(如 HandleErrorLayer)组合
236    S::Error: Into<axum::BoxError>,
237    // 这个必须要否则会在call返回的Future里报错,如果没有这个,即使你返回的Future是Send的,编译器也会报错
238    // axum文档中的例子也有:https://docs.rs/axum/latest/axum/middleware/index.html#towerservice-and-pinboxdyn-future
239    S::Future: Send + 'static,
240{
241    type Response = S::Response;
242    type Error = S::Error;
243    // Future 必须是 Send + 'static,因为可能会跨线程(tokio默认是多线程运行时)
244    type Future =
245        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
246
247    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
248        self.inner.poll_ready(cx)
249    }
250
251    fn call(&mut self, req: Request<Body>) -> Self::Future {
252        // 不直接使用clone,而是使用mem::replace,[文档](https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services)
253        let clone = self.inner.clone();
254        let mut inner = std::mem::replace(&mut self.inner, clone);
255        let client = self.client.clone();
256        let cache = Arc::clone(&self.cache);
257        let callback_path = self.callback_path.clone();
258
259        Box::pin(async move {
260            // 先做 OSS 验签,如果失败,直接返回 400 Response
261            match verify_oss_request(req, &client, &cache, callback_path).await {
262                // 这里我们需要把inner clone出来,不能直接用self.inner,否则此时返回的Future的生命周期就不是'static了,而是和self绑定在一起了
263                Ok(verified_req) => inner.call(verified_req).await,
264                Err(resp) => Ok(resp.into_response()),
265            }
266        })
267    }
268}
269
270/// OSS 验签逻辑:
271/// - 成功:返回新的 Request(body 已重建,且 extensions 里挂了 VerifiedOssCallbackBody)
272/// - 失败:返回一个 400 Response
273async fn verify_oss_request<'a>(
274    req: Request<Body>,
275    client: &reqwest::Client,
276    cache: &Arc<RwLock<HashMap<String, Vec<u8>>>>,
277    callback_path: String,
278) -> Result<Request<Body>, OssVerifyError<'a>> {
279    let (parts, body) = req.into_parts();
280    let headers = parts.headers.clone();
281    let uri = parts.uri.clone();
282
283    // 1. 读 body
284    let body_bytes = to_bytes(body, 5 * 1024 * 1024).await?;
285    let body_str = String::from_utf8(body_bytes.to_vec())?;
286
287    // 2. 拿 x-oss-pub-key-url 并解码
288    let pub_key_url_b64 = header_required(&headers, "x-oss-pub-key-url")?;
289    let pub_key_url_raw = STANDARD.decode(pub_key_url_b64.as_bytes())?;
290    let pub_key_url = String::from_utf8(pub_key_url_raw)?;
291
292    if !pub_key_url.starts_with("http://gosspublic.alicdn.com/")
293        && !pub_key_url.starts_with("https://gosspublic.alicdn.com/")
294    {
295        return Err(OssVerifyError::Common("invalid oss public key url"));
296    }
297
298    // 3. 公钥 PEM:先查缓存,再必要时 HTTP 拉取
299    let pub_key_pem = get_or_fetch_pub_key(&pub_key_url, client, cache).await?;
300    let pub_key_pem_str = String::from_utf8(pub_key_pem)?;
301
302    // 4. authorization (签名)Base64 解码
303    let auth_b64 = header_required(&headers, "authorization")?;
304    let auth_bytes = STANDARD.decode(auth_b64.as_bytes())?;
305
306    // 5. 组装 sign_str = url_decode(path) [+ query] + '\n' + body
307    // 下面的注释代码先保留,这是最初的实现方式,但是由于部署时可能会有nginx等代理剥离路径前缀的问题,
308    // 所以改成直接用用户传入的callback_path
309    // let raw_path = uri.path();
310    // let decoded_path = percent_decode_str(raw_path)
311    //     .decode_utf8()
312    //     .map_err(|_| OssVerifyError::Common("failed to percent-decode uri path"))?
313    //     .into_owned();
314    let decoded_path = callback_path;
315
316    let auth_path = match uri.query() {
317        Some(q) => format!("{}?{}", decoded_path, q),
318        None => decoded_path,
319    };
320
321    let auth_str = format!("{}\n{}", auth_path, body_str);
322
323    // 6. MD5(auth_str)
324    let mut hasher = Md5::new();
325    hasher.update(auth_str.as_bytes());
326    let digest = hasher.finalize();
327
328    // 7. RSA(PKCS#1 v1.5, MD5) 验签
329    let rsa_pub_key = RsaPublicKey::from_public_key_pem(&pub_key_pem_str)
330        .map_err(|_| OssVerifyError::Common("failed to parse oss public key pem"))?;
331
332    let verifying_key = pkcs1v15::VerifyingKey::<Md5>::new(rsa_pub_key);
333    let signature = pkcs1v15::Signature::try_from(auth_bytes.as_slice())
334        .map_err(|_| OssVerifyError::Common("failed to parse oss signature"))?;
335
336    verifying_key
337        .verify_prehash(&digest, &signature)
338        .map_err(|_| OssVerifyError::InvalidSignature)?;
339
340    // 8. 验签通过:重建 Request,把 body 塞回去,并在 extensions 里挂一份解析好的 body
341    let mut new_req = Request::from_parts(parts, Body::from(body_bytes));
342    new_req
343        .extensions_mut()
344        .insert(VerifiedOssCallbackBody(body_str));
345
346    Ok(new_req)
347}
348
349/// 取必需 header
350fn header_required<'a>(headers: &HeaderMap, name: &'a str) -> Result<String, OssVerifyError<'a>> {
351    let value = headers
352        .get(name)
353        .ok_or(OssVerifyError::MissingHeader(name))?;
354
355    let s = value
356        .to_str()
357        .map_err(|_| OssVerifyError::InvalidHeader(name))?;
358    Ok(s.to_owned())
359}
360
361/// 按“公钥 URL -> PEM”缓存
362async fn get_or_fetch_pub_key<'a>(
363    url: &str,
364    client: &reqwest::Client,
365    cache: &Arc<RwLock<HashMap<String, Vec<u8>>>>,
366) -> Result<Vec<u8>, OssVerifyError<'a>> {
367    // 这个方法的代码有个问题:在并发场景下,如果url未命中缓存,会对同一个url发起多次HTTP请求
368
369    // 先查缓存
370    {
371        let cache_read = cache.read().unwrap();
372        if let Some(v) = cache_read.get(url) {
373            return Ok(v.clone());
374        }
375    }
376
377    // 缓存未命中,走 HTTP
378    let resp = client.get(url).send().await?;
379    let bytes = resp.bytes().await?;
380
381    // 写回缓存
382    {
383        let mut cache_write = cache.write().unwrap();
384        if let Some(v) = cache_write.get(url) {
385            // Another thread inserted it while we were fetching; use theirs.
386            return Ok(v.clone());
387        } else {
388            cache_write.insert(url.to_string(), bytes.to_vec());
389        }
390    }
391
392    Ok(bytes.to_vec())
393}