volli_core/
token.rs

1use base64::{Engine as _, engine::general_purpose};
2use eyre::Report;
3use hmac::{Hmac, Mac};
4use serde::{Deserialize, Serialize};
5use sha2::Sha256;
6use std::time::{SystemTime, UNIX_EPOCH};
7use tracing::info;
8
/// HMAC keyed with SHA-256; used to sign payloads in `issue_token` and to check
/// them in `verify_token`.
type HmacSha256 = Hmac<Sha256>;
10
/// Claims carried by a [`Token`] and covered by its signature.
///
/// The signature is computed over the bincode encoding of this struct, so the
/// field order and field types are part of the signed format: reordering or
/// retyping fields will make previously issued tokens fail verification.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct TokenPayload {
    /// Tenant the token was issued for.
    pub tenant: String,
    /// Cluster the token was issued for.
    pub cluster: String,
    /// Agent identifier the token was issued for.
    pub agent_id: String,
    /// Issue time, in whole seconds since the Unix epoch.
    pub iat: u64,
    /// Expiry time, in whole seconds since the Unix epoch; `verify_token`
    /// rejects the token once the current second is past this value.
    pub exp: u64,
}
19
/// A [`TokenPayload`] together with its HMAC-SHA256 signature.
///
/// `encode_token` / `decode_token` convert this struct to and from its
/// bincode + base64 text form, so the field order is part of that format.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct Token {
    /// The signed claims.
    pub payload: TokenPayload,
    /// HMAC-SHA256 tag over the bincode encoding of `payload`, as produced by
    /// `issue_token`.
    pub sig: Vec<u8>,
}
25
26pub fn issue_token(
27    key: &[u8; 32],
28    tenant: &str,
29    cluster: &str,
30    agent_id: &str,
31    ttl_secs: u64,
32) -> Result<Token, Report> {
33    let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
34    let payload = TokenPayload {
35        tenant: tenant.to_string(),
36        cluster: cluster.to_string(),
37        agent_id: agent_id.to_string(),
38        iat: now,
39        exp: now + ttl_secs,
40    };
41    let encoded = bincode::serialize(&payload)?;
42    let mut mac = HmacSha256::new_from_slice(key)?;
43    mac.update(&encoded);
44    let sig = mac.finalize().into_bytes().to_vec();
45    Ok(Token { payload, sig })
46}
47
48pub fn verify_token(token: &Token, key: &[u8; 32]) -> Result<(), Report> {
49    let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
50    if token.payload.exp < now {
51        info!("Token expired");
52        return Err(eyre::eyre!("token expired"));
53    }
54    let encoded = bincode::serialize(&token.payload)?;
55    info!("Encoded token: {:?}", encoded);
56    let mut mac = HmacSha256::new_from_slice(key)?;
57    mac.update(&encoded);
58    mac.verify_slice(&token.sig)?;
59    Ok(())
60}
61
62pub fn refresh_token(token: &Token, key: &[u8; 32], ttl_secs: u64) -> Result<Token, Report> {
63    verify_token(token, key)?;
64    issue_token(
65        key,
66        &token.payload.tenant,
67        &token.payload.cluster,
68        &token.payload.agent_id,
69        ttl_secs,
70    )
71}
72
73pub fn encode_token(token: &Token) -> Result<String, Report> {
74    let bytes = bincode::serialize(token)?;
75    Ok(general_purpose::STANDARD_NO_PAD.encode(bytes))
76}
77
78pub fn decode_token(encoded: &str) -> Result<Token, Report> {
79    let bytes = general_purpose::STANDARD_NO_PAD.decode(encoded)?;
80    Ok(bincode::deserialize(&bytes)?)
81}
82
#[cfg(test)]
mod tests {
    use super::*;

    /// Signs an arbitrary payload with `key`, bypassing `issue_token`'s clock so
    /// tests can build tokens with chosen `iat`/`exp` values.
    fn sign(key: &[u8; 32], payload: TokenPayload) -> Token {
        let encoded = bincode::serialize(&payload).unwrap();
        let mut mac = HmacSha256::new_from_slice(key).unwrap();
        mac.update(&encoded);
        let sig = mac.finalize().into_bytes().to_vec();
        Token { payload, sig }
    }

    #[test]
    fn token_issue_verify_roundtrip() {
        let key = [1u8; 32];
        let token = issue_token(&key, "t", "c", "aid", 60).unwrap();
        verify_token(&token, &key).unwrap();
    }

    #[test]
    fn token_refresh_extends_expiry() {
        let key = [2u8; 32];
        let token = issue_token(&key, "t", "c", "a", 2).unwrap();
        verify_token(&token, &key).unwrap();
        // No sleep needed: the refreshed token is issued no earlier than the
        // original (barring a backwards wall-clock step), so the longer TTL
        // alone guarantees a later expiry.
        let new_token = refresh_token(&token, &key, 5).unwrap();
        verify_token(&new_token, &key).unwrap();
        assert!(new_token.payload.exp > token.payload.exp);
        assert_eq!(new_token.payload.tenant, token.payload.tenant);
        assert_eq!(new_token.payload.cluster, token.payload.cluster);
        assert_eq!(new_token.payload.agent_id, token.payload.agent_id);
    }

    #[test]
    fn encode_decode_roundtrip() {
        let key = [3u8; 32];
        let token = issue_token(&key, "t", "c", "aid", 30).unwrap();
        let encoded = encode_token(&token).unwrap();
        let decoded = decode_token(&encoded).unwrap();
        assert_eq!(decoded, token);
        verify_token(&decoded, &key).unwrap();
    }

    #[test]
    fn verify_rejects_wrong_key() {
        let token = issue_token(&[4u8; 32], "t", "c", "aid", 60).unwrap();
        assert!(verify_token(&token, &[5u8; 32]).is_err());
    }

    #[test]
    fn verify_rejects_tampered_payload() {
        let key = [6u8; 32];
        let mut token = issue_token(&key, "t", "c", "aid", 60).unwrap();
        token.payload.agent_id = "someone-else".to_string();
        assert!(verify_token(&token, &key).is_err());
    }

    #[test]
    fn verify_rejects_truncated_signature() {
        let key = [7u8; 32];
        let mut token = issue_token(&key, "t", "c", "aid", 60).unwrap();
        token.sig.pop();
        assert!(verify_token(&token, &key).is_err());
    }

    #[test]
    fn verify_and_refresh_reject_expired_token() {
        let key = [8u8; 32];
        let token = sign(
            &key,
            TokenPayload {
                tenant: "t".to_string(),
                cluster: "c".to_string(),
                agent_id: "aid".to_string(),
                iat: 0,
                exp: 1,
            },
        );
        let err = verify_token(&token, &key).unwrap_err();
        assert_eq!(err.to_string(), "token expired");
        assert!(refresh_token(&token, &key, 60).is_err());
    }

    #[test]
    fn decode_rejects_malformed_input() {
        assert!(decode_token("not base64!").is_err());
        assert!(decode_token("").is_err());
    }
}