rustrails_record/
secure_token.rs1use std::collections::HashSet;
2
3use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
4
/// Describes a single secure-token field on a record.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SecureTokenConfig {
    /// Name of the field that stores the token.
    pub field: String,
    /// Length of the tokens generated for this field.
    pub length: usize,
}
13
14impl SecureTokenConfig {
15 #[must_use]
17 pub fn new(field: &str) -> Self {
18 Self {
19 field: field.to_owned(),
20 length: 24,
21 }
22 }
23
24 #[must_use]
26 pub fn length(mut self, length: usize) -> Self {
27 self.length = length.max(1);
28 self
29 }
30}
31
32#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
34pub enum SecureTokenError {
35 #[error("unknown secure token field: {0}")]
37 UnknownField(String),
38 #[error("could not generate a unique token for {0}")]
40 ExhaustedAttempts(String),
41 #[error("token for {0} must be unique")]
43 DuplicateToken(String),
44}
45
46#[must_use]
48pub fn has_secure_token(field: &str) -> SecureTokenConfig {
49 SecureTokenConfig::new(field)
50}
51
52pub trait SecureToken {
54 fn secure_token_configurations() -> &'static [SecureTokenConfig];
56 fn get_secure_token(&self, field: &str) -> Option<&str>;
58 fn set_secure_token(&mut self, field: &str, token: String);
60
61 fn generate_token() -> String {
63 generate_token_with_length(32)
64 }
65
66 fn ensure_secure_tokens(
68 &mut self,
69 existing_tokens: &HashSet<String>,
70 ) -> Result<(), SecureTokenError> {
71 let mut reserved = existing_tokens.clone();
72 for config in Self::secure_token_configurations() {
73 match self.get_secure_token(&config.field) {
74 Some(token) if reserved.contains(token) => {
75 return Err(SecureTokenError::DuplicateToken(config.field.clone()));
76 }
77 Some(token) => {
78 reserved.insert(token.to_owned());
79 }
80 None => {
81 let token = generate_unique_token(config.length, &reserved, &config.field)?;
82 reserved.insert(token.clone());
83 self.set_secure_token(&config.field, token);
84 }
85 }
86 }
87 Ok(())
88 }
89
90 fn regenerate_token(
92 &mut self,
93 field: &str,
94 existing_tokens: &HashSet<String>,
95 ) -> Result<String, SecureTokenError> {
96 let config = Self::secure_token_configurations()
97 .iter()
98 .find(|config| config.field == field)
99 .ok_or_else(|| SecureTokenError::UnknownField(field.to_owned()))?;
100
101 let mut reserved = existing_tokens.clone();
102 if let Some(current) = self.get_secure_token(field) {
103 reserved.remove(current);
104 }
105
106 let token = generate_unique_token(config.length, &reserved, field)?;
107 self.set_secure_token(field, token.clone());
108 Ok(token)
109 }
110}
111
112fn generate_unique_token(
113 length: usize,
114 existing_tokens: &HashSet<String>,
115 field: &str,
116) -> Result<String, SecureTokenError> {
117 for _ in 0..32 {
118 let token = generate_token_with_length(length);
119 if !existing_tokens.contains(&token) {
120 return Ok(token);
121 }
122 }
123
124 Err(SecureTokenError::ExhaustedAttempts(field.to_owned()))
125}
126
127fn generate_token_with_length(length: usize) -> String {
128 let mut token = String::new();
129 while token.len() < length {
130 let bytes: [u8; 24] = rand::random();
131 token.push_str(&URL_SAFE_NO_PAD.encode(bytes));
132 }
133 token.truncate(length);
134 token
135}
136
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::sync::LazyLock;

    use super::{SecureToken, SecureTokenConfig, SecureTokenError, has_secure_token};

    /// Minimal in-memory record with two secure-token fields.
    #[derive(Debug, Default)]
    struct ApiKeyRecord {
        token: Option<String>,
        recovery_token: Option<String>,
    }

    impl ApiKeyRecord {
        /// Builds a record whose `token` is set and whose `recovery_token` is missing.
        fn with_token(value: &str) -> Self {
            Self {
                token: Some(value.to_owned()),
                recovery_token: None,
            }
        }
    }

    /// `token` keeps the default length; `recovery_token` overrides it to 12.
    static TOKEN_CONFIGS: LazyLock<Vec<SecureTokenConfig>> = LazyLock::new(|| {
        vec![
            has_secure_token("token"),
            has_secure_token("recovery_token").length(12),
        ]
    });

    impl SecureToken for ApiKeyRecord {
        fn secure_token_configurations() -> &'static [SecureTokenConfig] {
            TOKEN_CONFIGS.as_slice()
        }

        fn get_secure_token(&self, field: &str) -> Option<&str> {
            let slot = match field {
                "token" => &self.token,
                "recovery_token" => &self.recovery_token,
                _ => return None,
            };
            slot.as_deref()
        }

        fn set_secure_token(&mut self, field: &str, token: String) {
            let slot = match field {
                "token" => &mut self.token,
                "recovery_token" => &mut self.recovery_token,
                // Unknown fields are ignored.
                _ => return,
            };
            *slot = Some(token);
        }
    }

    #[test]
    fn ensure_secure_tokens_generates_missing_tokens() {
        let nothing_taken = HashSet::new();
        let mut record = ApiKeyRecord::default();

        let outcome = record.ensure_secure_tokens(&nothing_taken);
        outcome.expect("tokens should be generated");

        assert!(record.token.is_some());
        assert!(record.recovery_token.is_some());
        let recovery_len = record.recovery_token.as_deref().map(str::len);
        assert_eq!(recovery_len, Some(12));
    }

    #[test]
    fn ensure_secure_tokens_preserves_unique_existing_tokens() {
        let nothing_taken = HashSet::new();
        let mut record = ApiKeyRecord::with_token("existing-token");

        let outcome = record.ensure_secure_tokens(&nothing_taken);
        outcome.expect("existing token should be preserved");

        assert_eq!(record.token.as_deref(), Some("existing-token"));
        assert!(record.recovery_token.is_some());
    }

    #[test]
    fn ensure_secure_tokens_rejects_duplicate_existing_tokens() {
        let mut record = ApiKeyRecord::with_token("taken");
        let existing: HashSet<String> = std::iter::once("taken".to_owned()).collect();

        let outcome = record.ensure_secure_tokens(&existing);

        let expected = Err(SecureTokenError::DuplicateToken("token".to_owned()));
        assert_eq!(outcome, expected);
    }

    #[test]
    fn regenerate_token_replaces_existing_value() {
        let nothing_taken = HashSet::new();
        let mut record = ApiKeyRecord::with_token("current");

        let token = record
            .regenerate_token("token", &nothing_taken)
            .expect("token should regenerate");

        assert_eq!(record.token.as_deref(), Some(token.as_str()));
        assert_ne!(token, "current");
    }

    #[test]
    fn regenerate_token_rejects_unknown_fields() {
        let nothing_taken = HashSet::new();
        let mut record = ApiKeyRecord::default();

        let outcome = record.regenerate_token("missing", &nothing_taken);

        let expected = Err(SecureTokenError::UnknownField("missing".to_owned()));
        assert_eq!(outcome, expected);
    }

    #[test]
    fn metadata_builder_preserves_length_overrides() {
        let SecureTokenConfig { field, length } = has_secure_token("auth_token").length(10);

        assert_eq!(field, "auth_token");
        assert_eq!(length, 10);
    }

    #[test]
    fn generate_token_returns_minimum_length() {
        let token = <ApiKeyRecord as SecureToken>::generate_token();

        assert!(token.len() >= 32);
    }

    #[test]
    fn ensure_secure_tokens_still_honors_length_overrides() {
        let nothing_taken = HashSet::new();
        let mut record = ApiKeyRecord::default();

        let outcome = record.ensure_secure_tokens(&nothing_taken);
        outcome.expect("tokens should be generated");

        let recovery_len = record.recovery_token.as_deref().map(str::len);
        assert_eq!(recovery_len, Some(12));
    }
}