1use base64::{Engine as _, engine::general_purpose};
2use eyre::Report;
3use hmac::{Hmac, Mac};
4use serde::{Deserialize, Serialize};
5use sha2::Sha256;
6use std::time::{SystemTime, UNIX_EPOCH};
7use tracing::info;
8
9type HmacSha256 = Hmac<Sha256>;
10
11#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
12pub struct TokenPayload {
13 pub tenant: String,
14 pub cluster: String,
15 pub agent_id: String,
16 pub iat: u64,
17 pub exp: u64,
18}
19
20#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
21pub struct Token {
22 pub payload: TokenPayload,
23 pub sig: Vec<u8>,
24}
25
26pub fn issue_token(
27 key: &[u8; 32],
28 tenant: &str,
29 cluster: &str,
30 agent_id: &str,
31 ttl_secs: u64,
32) -> Result<Token, Report> {
33 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
34 let payload = TokenPayload {
35 tenant: tenant.to_string(),
36 cluster: cluster.to_string(),
37 agent_id: agent_id.to_string(),
38 iat: now,
39 exp: now + ttl_secs,
40 };
41 let encoded = bincode::serialize(&payload)?;
42 let mut mac = HmacSha256::new_from_slice(key)?;
43 mac.update(&encoded);
44 let sig = mac.finalize().into_bytes().to_vec();
45 Ok(Token { payload, sig })
46}
47
48pub fn verify_token(token: &Token, key: &[u8; 32]) -> Result<(), Report> {
49 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
50 if token.payload.exp < now {
51 info!("Token expired");
52 return Err(eyre::eyre!("token expired"));
53 }
54 let encoded = bincode::serialize(&token.payload)?;
55 info!("Encoded token: {:?}", encoded);
56 let mut mac = HmacSha256::new_from_slice(key)?;
57 mac.update(&encoded);
58 mac.verify_slice(&token.sig)?;
59 Ok(())
60}
61
62pub fn refresh_token(token: &Token, key: &[u8; 32], ttl_secs: u64) -> Result<Token, Report> {
63 verify_token(token, key)?;
64 issue_token(
65 key,
66 &token.payload.tenant,
67 &token.payload.cluster,
68 &token.payload.agent_id,
69 ttl_secs,
70 )
71}
72
73pub fn encode_token(token: &Token) -> Result<String, Report> {
74 let bytes = bincode::serialize(token)?;
75 Ok(general_purpose::STANDARD_NO_PAD.encode(bytes))
76}
77
78pub fn decode_token(encoded: &str) -> Result<Token, Report> {
79 let bytes = general_purpose::STANDARD_NO_PAD.decode(encoded)?;
80 Ok(bincode::deserialize(&bytes)?)
81}
82
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly issued token verifies under the key that signed it.
    #[test]
    fn token_issue_verify_roundtrip() {
        let key = [1u8; 32];
        let token = issue_token(&key, "t", "c", "aid", 60).unwrap();
        verify_token(&token, &key).unwrap();
    }

    /// Refreshing with a longer TTL yields a valid token with a later expiry.
    ///
    /// No sleep is needed: the refreshed expiry is `now + 5`, which always
    /// exceeds the original `issued_at + 2` because `now >= issued_at`.
    #[test]
    fn token_refresh_extends_expiry() {
        let key = [2u8; 32];
        let token = issue_token(&key, "t", "c", "a", 2).unwrap();
        verify_token(&token, &key).unwrap();
        let new_token = refresh_token(&token, &key, 5).unwrap();
        verify_token(&new_token, &key).unwrap();
        assert!(new_token.payload.exp > token.payload.exp);
    }

    /// Encoding then decoding returns an identical, still-valid token.
    #[test]
    fn encode_decode_roundtrip() {
        let key = [3u8; 32];
        let token = issue_token(&key, "t", "c", "aid", 30).unwrap();
        let encoded = encode_token(&token).unwrap();
        let decoded = decode_token(&encoded).unwrap();
        assert_eq!(decoded, token);
        verify_token(&decoded, &key).unwrap();
    }

    /// Changing a signed claim invalidates the signature.
    #[test]
    fn tampered_payload_is_rejected() {
        let key = [4u8; 32];
        let mut token = issue_token(&key, "t", "c", "aid", 30).unwrap();
        token.payload.tenant = "other".to_string();
        assert!(verify_token(&token, &key).is_err());
    }
}