Skip to main content

rustrails_record/
secure_token.rs

1use std::collections::HashSet;
2
3use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
4
/// Declarative description of one secure-token field on a record type.
///
/// Build one with [`SecureTokenConfig::new`] and optionally change the length of
/// generated tokens with [`SecureTokenConfig::length`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SecureTokenConfig {
    /// Name of the record field that stores the token.
    pub field: String,
    /// Number of characters in tokens generated for this field.
    pub length: usize,
}

impl SecureTokenConfig {
    /// Token length used when no override is supplied.
    const DEFAULT_LENGTH: usize = 24;

    /// Describes a token stored in `field`, generated with the default length of
    /// 24 characters.
    #[must_use]
    pub fn new(field: &str) -> Self {
        let field = String::from(field);
        Self {
            length: Self::DEFAULT_LENGTH,
            field,
        }
    }

    /// Returns this configuration with the generated token length set to `length`.
    ///
    /// A requested length of `0` is raised to `1`, so generated tokens are never empty.
    #[must_use]
    pub fn length(self, length: usize) -> Self {
        Self {
            length: std::cmp::max(length, 1),
            ..self
        }
    }
}
31
32/// Errors returned by secure-token helpers.
33#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
34pub enum SecureTokenError {
35    /// The requested token field is unknown.
36    #[error("unknown secure token field: {0}")]
37    UnknownField(String),
38    /// Token generation exhausted uniqueness attempts.
39    #[error("could not generate a unique token for {0}")]
40    ExhaustedAttempts(String),
41    /// A provided token collides with an existing token.
42    #[error("token for {0} must be unique")]
43    DuplicateToken(String),
44}
45
46/// Declares a secure token for `field`.
47#[must_use]
48pub fn has_secure_token(field: &str) -> SecureTokenConfig {
49    SecureTokenConfig::new(field)
50}
51
52/// Trait implemented by records that expose secure-token fields.
53pub trait SecureToken {
54    /// Returns secure-token metadata for the record type.
55    fn secure_token_configurations() -> &'static [SecureTokenConfig];
56    /// Reads the current token value for `field`.
57    fn get_secure_token(&self, field: &str) -> Option<&str>;
58    /// Stores a token value for `field`.
59    fn set_secure_token(&mut self, field: &str, token: String);
60
61    /// Generates a cryptographically random token string.
62    fn generate_token() -> String {
63        generate_token_with_length(32)
64    }
65
66    /// Ensures every declared secure-token field has a unique value.
67    fn ensure_secure_tokens(
68        &mut self,
69        existing_tokens: &HashSet<String>,
70    ) -> Result<(), SecureTokenError> {
71        let mut reserved = existing_tokens.clone();
72        for config in Self::secure_token_configurations() {
73            match self.get_secure_token(&config.field) {
74                Some(token) if reserved.contains(token) => {
75                    return Err(SecureTokenError::DuplicateToken(config.field.clone()));
76                }
77                Some(token) => {
78                    reserved.insert(token.to_owned());
79                }
80                None => {
81                    let token = generate_unique_token(config.length, &reserved, &config.field)?;
82                    reserved.insert(token.clone());
83                    self.set_secure_token(&config.field, token);
84                }
85            }
86        }
87        Ok(())
88    }
89
90    /// Replaces the token stored in `field` and returns the new value.
91    fn regenerate_token(
92        &mut self,
93        field: &str,
94        existing_tokens: &HashSet<String>,
95    ) -> Result<String, SecureTokenError> {
96        let config = Self::secure_token_configurations()
97            .iter()
98            .find(|config| config.field == field)
99            .ok_or_else(|| SecureTokenError::UnknownField(field.to_owned()))?;
100
101        let mut reserved = existing_tokens.clone();
102        if let Some(current) = self.get_secure_token(field) {
103            reserved.remove(current);
104        }
105
106        let token = generate_unique_token(config.length, &reserved, field)?;
107        self.set_secure_token(field, token.clone());
108        Ok(token)
109    }
110}
111
112fn generate_unique_token(
113    length: usize,
114    existing_tokens: &HashSet<String>,
115    field: &str,
116) -> Result<String, SecureTokenError> {
117    for _ in 0..32 {
118        let token = generate_token_with_length(length);
119        if !existing_tokens.contains(&token) {
120            return Ok(token);
121        }
122    }
123
124    Err(SecureTokenError::ExhaustedAttempts(field.to_owned()))
125}
126
127fn generate_token_with_length(length: usize) -> String {
128    let mut token = String::new();
129    while token.len() < length {
130        let bytes: [u8; 24] = rand::random();
131        token.push_str(&URL_SAFE_NO_PAD.encode(bytes));
132    }
133    token.truncate(length);
134    token
135}
136
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::sync::LazyLock;

    use super::{SecureToken, SecureTokenConfig, SecureTokenError, has_secure_token};

    /// Test record with a default-length `token` and a 12-character `recovery_token`.
    #[derive(Debug, Default)]
    struct ApiKeyRecord {
        token: Option<String>,
        recovery_token: Option<String>,
    }

    static TOKEN_CONFIGS: LazyLock<Vec<SecureTokenConfig>> = LazyLock::new(|| {
        vec![
            has_secure_token("token"),
            has_secure_token("recovery_token").length(12),
        ]
    });

    impl SecureToken for ApiKeyRecord {
        fn secure_token_configurations() -> &'static [SecureTokenConfig] {
            TOKEN_CONFIGS.as_slice()
        }

        fn get_secure_token(&self, field: &str) -> Option<&str> {
            match field {
                "token" => self.token.as_deref(),
                "recovery_token" => self.recovery_token.as_deref(),
                _ => None,
            }
        }

        fn set_secure_token(&mut self, field: &str, token: String) {
            match field {
                "token" => self.token = Some(token),
                "recovery_token" => self.recovery_token = Some(token),
                _ => {}
            }
        }
    }

    /// Returns `true` when `token` only contains URL-safe base64 characters.
    fn is_url_safe(token: &str) -> bool {
        token
            .bytes()
            .all(|byte| byte.is_ascii_alphanumeric() || byte == b'-' || byte == b'_')
    }

    #[test]
    fn ensure_secure_tokens_generates_missing_tokens() {
        let mut record = ApiKeyRecord::default();
        record
            .ensure_secure_tokens(&HashSet::new())
            .expect("tokens should be generated");

        let token = record.token.as_deref().expect("token should be generated");
        let recovery = record
            .recovery_token
            .as_deref()
            .expect("recovery token should be generated");
        // `has_secure_token` defaults to 24 characters; the recovery token overrides it.
        assert_eq!(token.len(), 24);
        assert_eq!(recovery.len(), 12);
        assert!(is_url_safe(token));
        assert!(is_url_safe(recovery));
    }

    #[test]
    fn ensure_secure_tokens_preserves_unique_existing_tokens() {
        let mut record = ApiKeyRecord {
            token: Some("existing-token".to_owned()),
            recovery_token: None,
        };
        record
            .ensure_secure_tokens(&HashSet::new())
            .expect("existing token should be preserved");

        assert_eq!(record.token.as_deref(), Some("existing-token"));
        assert!(record.recovery_token.is_some());
    }

    #[test]
    fn ensure_secure_tokens_keeps_complete_records_unchanged() {
        let mut record = ApiKeyRecord {
            token: Some("alpha".to_owned()),
            recovery_token: Some("beta".to_owned()),
        };
        record
            .ensure_secure_tokens(&HashSet::from(["gamma".to_owned()]))
            .expect("unique tokens should be accepted");

        assert_eq!(record.token.as_deref(), Some("alpha"));
        assert_eq!(record.recovery_token.as_deref(), Some("beta"));
    }

    #[test]
    fn ensure_secure_tokens_rejects_duplicate_existing_tokens() {
        let mut record = ApiKeyRecord {
            token: Some("taken".to_owned()),
            recovery_token: None,
        };
        let existing = HashSet::from(["taken".to_owned()]);

        assert_eq!(
            record.ensure_secure_tokens(&existing),
            Err(SecureTokenError::DuplicateToken("token".to_owned()))
        );
    }

    #[test]
    fn ensure_secure_tokens_rejects_tokens_shared_between_fields() {
        let mut record = ApiKeyRecord {
            token: Some("shared".to_owned()),
            recovery_token: Some("shared".to_owned()),
        };

        assert_eq!(
            record.ensure_secure_tokens(&HashSet::new()),
            Err(SecureTokenError::DuplicateToken("recovery_token".to_owned()))
        );
    }

    #[test]
    fn ensure_secure_tokens_still_honors_length_overrides() {
        let mut record = ApiKeyRecord::default();
        record
            .ensure_secure_tokens(&HashSet::new())
            .expect("tokens should be generated");

        assert_eq!(record.recovery_token.as_deref().map(str::len), Some(12));
    }

    #[test]
    fn regenerate_token_replaces_existing_value() {
        let mut record = ApiKeyRecord {
            token: Some("current".to_owned()),
            recovery_token: None,
        };
        let token = record
            .regenerate_token("token", &HashSet::new())
            .expect("token should regenerate");

        assert_eq!(record.token.as_deref(), Some(token.as_str()));
        assert_ne!(token, "current");
    }

    #[test]
    fn regenerate_token_honors_length_overrides() {
        let mut record = ApiKeyRecord {
            token: None,
            recovery_token: Some("old-recovery".to_owned()),
        };
        let token = record
            .regenerate_token("recovery_token", &HashSet::new())
            .expect("recovery token should regenerate");

        assert_eq!(token.len(), 12);
        assert!(is_url_safe(&token));
        assert_eq!(record.recovery_token.as_deref(), Some(token.as_str()));
        // Regenerating one field must not touch the others.
        assert!(record.token.is_none());
    }

    #[test]
    fn regenerate_token_rejects_unknown_fields() {
        let mut record = ApiKeyRecord::default();
        assert_eq!(
            record.regenerate_token("missing", &HashSet::new()),
            Err(SecureTokenError::UnknownField("missing".to_owned()))
        );
    }

    #[test]
    fn metadata_builder_preserves_length_overrides() {
        let config = has_secure_token("auth_token").length(10);
        assert_eq!(config.field, "auth_token");
        assert_eq!(config.length, 10);
    }

    #[test]
    fn metadata_builder_defaults_and_clamps_length() {
        assert_eq!(has_secure_token("auth_token").length, 24);
        assert_eq!(has_secure_token("auth_token").length(0).length, 1);
    }

    #[test]
    fn generate_token_returns_32_url_safe_characters() {
        let token = <ApiKeyRecord as SecureToken>::generate_token();

        assert_eq!(token.len(), 32);
        assert!(is_url_safe(&token));
    }
}