rustfs_policy/
utils.rs

// Copyright 2024 RustFS Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header};
16use rand::{Rng, RngCore};
17use serde::{Serialize, de::DeserializeOwned};
18use std::io::{Error, Result};
19
20pub fn gen_access_key(length: usize) -> Result<String> {
21    const ALPHA_NUMERIC_TABLE: [char; 36] = [
22        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N',
23        'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
24    ];
25
26    if length < 3 {
27        return Err(Error::other("access key length is too short"));
28    }
29
30    let mut result = String::with_capacity(length);
31    let mut rng = rand::rng();
32
33    for _ in 0..length {
34        result.push(ALPHA_NUMERIC_TABLE[rng.random_range(0..ALPHA_NUMERIC_TABLE.len())]);
35    }
36
37    Ok(result)
38}
39
40pub fn gen_secret_key(length: usize) -> Result<String> {
41    use base64_simd::URL_SAFE_NO_PAD;
42
43    if length < 8 {
44        return Err(Error::other("secret key length is too short"));
45    }
46    let mut rng = rand::rng();
47
48    let mut key = vec![0u8; URL_SAFE_NO_PAD.estimated_decoded_length(length)];
49    rng.fill_bytes(&mut key);
50
51    let encoded = URL_SAFE_NO_PAD.encode_to_string(&key);
52    let key_str = encoded.replace("/", "+");
53
54    Ok(key_str)
55}
56
57pub fn generate_jwt<T: Serialize>(claims: &T, secret: &str) -> std::result::Result<String, jsonwebtoken::errors::Error> {
58    let header = Header::new(Algorithm::HS512);
59    jsonwebtoken::encode(&header, &claims, &EncodingKey::from_secret(secret.as_bytes()))
60}
61
62pub fn extract_claims<T: DeserializeOwned>(
63    token: &str,
64    secret: &str,
65) -> std::result::Result<jsonwebtoken::TokenData<T>, jsonwebtoken::errors::Error> {
66    jsonwebtoken::decode::<T>(
67        token,
68        &DecodingKey::from_secret(secret.as_bytes()),
69        &jsonwebtoken::Validation::new(Algorithm::HS512),
70    )
71}
72
#[cfg(test)]
mod tests {
    use super::{extract_claims, gen_access_key, gen_secret_key, generate_jwt};
    use serde::{Deserialize, Serialize};
    use std::time::{SystemTime, UNIX_EPOCH};

    #[test]
    fn test_gen_access_key() {
        let a = gen_access_key(10).unwrap();
        let b = gen_access_key(10).unwrap();

        assert_eq!(a.len(), 10);
        assert_eq!(b.len(), 10);
        assert_ne!(a, b);
        // Access keys are drawn only from `0-9` and `A-Z`.
        assert!(a.chars().chain(b.chars()).all(|c| c.is_ascii_digit() || c.is_ascii_uppercase()));
    }

    #[test]
    fn test_gen_access_key_rejects_short_length() {
        assert!(gen_access_key(2).is_err());
        assert_eq!(gen_access_key(3).unwrap().len(), 3);
    }

    #[test]
    fn test_gen_secret_key() {
        let a = gen_secret_key(10).unwrap();
        let b = gen_secret_key(10).unwrap();
        assert_ne!(a, b);
    }

    #[test]
    fn test_gen_secret_key_rejects_short_length() {
        assert!(gen_secret_key(7).is_err());
        assert!(gen_secret_key(8).is_ok());
    }

    /// Claims without an expiry. Enough for signing, but `extract_claims`
    /// rejects them because its validation requires `exp`.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct Claims {
        sub: String,
        company: String,
    }

    /// Claims carrying an `exp` timestamp (seconds since the Unix epoch), so they
    /// pass the default `HS512` validation used by `extract_claims`.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct ExpiringClaims {
        sub: String,
        company: String,
        exp: u64,
    }

    /// Current wall-clock time in whole seconds since the Unix epoch.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is before the Unix epoch")
            .as_secs()
    }

    /// Builds claims that expire one hour from now.
    fn expiring_claims() -> ExpiringClaims {
        ExpiringClaims {
            sub: "user1".to_string(),
            company: "example".to_string(),
            exp: now_secs() + 3600,
        }
    }

    #[test]
    fn test_generate_jwt() {
        let claims = Claims {
            sub: "user1".to_string(),
            company: "example".to_string(),
        };
        let secret = "my_secret";
        let token = generate_jwt(&claims, secret).unwrap();

        assert!(!token.is_empty());
    }

    #[test]
    fn test_extract_claims_round_trip() {
        let claims = expiring_claims();
        let secret = "my_secret";
        let token = generate_jwt(&claims, secret).unwrap();

        let decoded = extract_claims::<ExpiringClaims>(&token, secret).unwrap();
        assert_eq!(decoded.claims, claims);
    }

    #[test]
    fn test_extract_claims_rejects_wrong_secret() {
        let token = generate_jwt(&expiring_claims(), "my_secret").unwrap();

        assert!(extract_claims::<ExpiringClaims>(&token, "other_secret").is_err());
    }

    #[test]
    fn test_extract_claims_rejects_missing_exp() {
        let claims = Claims {
            sub: "user1".to_string(),
            company: "example".to_string(),
        };
        let secret = "my_secret";
        let token = generate_jwt(&claims, secret).unwrap();

        assert!(extract_claims::<Claims>(&token, secret).is_err());
    }
}