Skip to main content

wae_authentication/jwt/
codec.rs

//! JWT encoding and decoding (JWT 编解码模块).

3use base64::{Engine as _, engine::general_purpose};
4use hmac::{Hmac, Mac as _};
5use serde::{Deserialize, Serialize};
6use sha2::{Sha256, Sha384, Sha512};
7use std::fmt;
8use wae_types::{WaeError, WaeErrorKind};
9
10/// JWT 编解码错误
11#[derive(Debug)]
12pub enum JwtCodecError {
13    /// 无效的令牌格式
14    InvalidFormat,
15
16    /// Base64 解码失败
17    Base64Error(base64::DecodeError),
18
19    /// JSON 序列化/反序列化失败
20    JsonError(serde_json::Error),
21
22    /// 无效的签名
23    InvalidSignature,
24
25    /// 无效的算法
26    InvalidAlgorithm,
27
28    /// 密钥错误
29    KeyError,
30}
31
32impl fmt::Display for JwtCodecError {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        match self {
35            JwtCodecError::InvalidFormat => write!(f, "invalid token format"),
36            JwtCodecError::Base64Error(e) => write!(f, "base64 decode error: {}", e),
37            JwtCodecError::JsonError(e) => write!(f, "json error: {}", e),
38            JwtCodecError::InvalidSignature => write!(f, "invalid signature"),
39            JwtCodecError::InvalidAlgorithm => write!(f, "invalid algorithm"),
40            JwtCodecError::KeyError => write!(f, "key error"),
41        }
42    }
43}
44
45impl std::error::Error for JwtCodecError {}
46
47impl From<base64::DecodeError> for JwtCodecError {
48    fn from(err: base64::DecodeError) -> Self {
49        JwtCodecError::Base64Error(err)
50    }
51}
52
53impl From<serde_json::Error> for JwtCodecError {
54    fn from(err: serde_json::Error) -> Self {
55        JwtCodecError::JsonError(err)
56    }
57}
58
59impl From<JwtCodecError> for WaeError {
60    fn from(err: JwtCodecError) -> Self {
61        match err {
62            JwtCodecError::InvalidFormat => WaeError::invalid_token("malformed token"),
63            JwtCodecError::Base64Error(_) => WaeError::invalid_token("invalid base64"),
64            JwtCodecError::JsonError(_) => WaeError::invalid_token("invalid json"),
65            JwtCodecError::InvalidSignature => WaeError::invalid_signature(),
66            JwtCodecError::InvalidAlgorithm => WaeError::new(WaeErrorKind::InvalidAlgorithm),
67            JwtCodecError::KeyError => WaeError::new(WaeErrorKind::KeyError),
68        }
69    }
70}
71
72/// JWT 结果类型
73pub type JwtCodecResult<T> = Result<T, JwtCodecError>;
74
75/// JWT 头部
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct JwtHeader {
78    /// 算法
79    pub alg: String,
80    /// 类型
81    pub typ: String,
82}
83
84impl JwtHeader {
85    /// 创建新的 JWT 头部
86    pub fn new(alg: impl Into<String>) -> Self {
87        Self { alg: alg.into(), typ: "JWT".to_string() }
88    }
89}
90
91/// Base64URL 编码(URL_SAFE_NO_PAD)
92pub fn base64url_encode(input: &[u8]) -> String {
93    general_purpose::URL_SAFE_NO_PAD.encode(input)
94}
95
96/// Base64URL 解码(URL_SAFE_NO_PAD)
97pub fn base64url_decode(input: &str) -> JwtCodecResult<Vec<u8>> {
98    Ok(general_purpose::URL_SAFE_NO_PAD.decode(input)?)
99}
100
101/// 计算 HMAC 签名
102pub fn hmac_sign(algorithm: &str, secret: &[u8], data: &[u8]) -> JwtCodecResult<Vec<u8>> {
103    match algorithm {
104        "HS256" => {
105            let mut mac = Hmac::<Sha256>::new_from_slice(secret).map_err(|_| JwtCodecError::KeyError)?;
106            mac.update(data);
107            Ok(mac.finalize().into_bytes().to_vec())
108        }
109        "HS384" => {
110            let mut mac = Hmac::<Sha384>::new_from_slice(secret).map_err(|_| JwtCodecError::KeyError)?;
111            mac.update(data);
112            Ok(mac.finalize().into_bytes().to_vec())
113        }
114        "HS512" => {
115            let mut mac = Hmac::<Sha512>::new_from_slice(secret).map_err(|_| JwtCodecError::KeyError)?;
116            mac.update(data);
117            Ok(mac.finalize().into_bytes().to_vec())
118        }
119        _ => Err(JwtCodecError::InvalidAlgorithm),
120    }
121}
122
123/// 验证 HMAC 签名
124pub fn hmac_verify(algorithm: &str, secret: &[u8], data: &[u8], signature: &[u8]) -> JwtCodecResult<bool> {
125    match algorithm {
126        "HS256" => {
127            let mut mac = Hmac::<Sha256>::new_from_slice(secret).map_err(|_| JwtCodecError::KeyError)?;
128            mac.update(data);
129            mac.verify_slice(signature).map_err(|_| JwtCodecError::InvalidSignature)?;
130            Ok(true)
131        }
132        "HS384" => {
133            let mut mac = Hmac::<Sha384>::new_from_slice(secret).map_err(|_| JwtCodecError::KeyError)?;
134            mac.update(data);
135            mac.verify_slice(signature).map_err(|_| JwtCodecError::InvalidSignature)?;
136            Ok(true)
137        }
138        "HS512" => {
139            let mut mac = Hmac::<Sha512>::new_from_slice(secret).map_err(|_| JwtCodecError::KeyError)?;
140            mac.update(data);
141            mac.verify_slice(signature).map_err(|_| JwtCodecError::InvalidSignature)?;
142            Ok(true)
143        }
144        _ => Err(JwtCodecError::InvalidAlgorithm),
145    }
146}
147
148/// 编码 JWT 令牌
149pub fn encode_jwt<T: Serialize>(header: &JwtHeader, claims: &T, secret: &[u8]) -> JwtCodecResult<String> {
150    let header_json = serde_json::to_string(header)?;
151    let claims_json = serde_json::to_string(claims)?;
152
153    let header_b64 = base64url_encode(header_json.as_bytes());
154    let claims_b64 = base64url_encode(claims_json.as_bytes());
155
156    let message = format!("{}.{}", header_b64, claims_b64);
157    let signature = hmac_sign(&header.alg, secret, message.as_bytes())?;
158    let signature_b64 = base64url_encode(&signature);
159
160    Ok(format!("{}.{}", message, signature_b64))
161}
162
163/// 解码 JWT 令牌
164pub fn decode_jwt<T: for<'de> Deserialize<'de>>(token: &str, secret: &[u8], validate_signature: bool) -> JwtCodecResult<T> {
165    let parts: Vec<&str> = token.split('.').collect();
166    if parts.len() != 3 {
167        return Err(JwtCodecError::InvalidFormat);
168    }
169
170    let header_b64 = parts[0];
171    let claims_b64 = parts[1];
172    let signature_b64 = parts[2];
173
174    let header_bytes = base64url_decode(header_b64)?;
175    let header: JwtHeader = serde_json::from_slice(&header_bytes)?;
176
177    let claims_bytes = base64url_decode(claims_b64)?;
178    let claims: T = serde_json::from_slice(&claims_bytes)?;
179
180    if validate_signature {
181        let message = format!("{}.{}", header_b64, claims_b64);
182        let signature = base64url_decode(signature_b64)?;
183        hmac_verify(&header.alg, secret, message.as_bytes(), &signature)?;
184    }
185
186    Ok(claims)
187}