running_process/broker/server/handoff/
handoff_token.rs1use std::collections::HashMap;
8use std::fmt;
9use std::time::{Duration, Instant};
10
/// Length of a [`HandoffToken`] in bytes (128 bits).
pub const HANDOFF_TOKEN_BYTES: usize = 16;

/// Default cap on the number of tokens that may be pending at the same time.
pub const DEFAULT_MAX_PENDING_HANDOFF_TOKENS: usize = 1024;

/// Default lifetime of a freshly issued token, measured from the issuing instant.
pub const DEFAULT_HANDOFF_TOKEN_TTL: Duration = Duration::from_secs(30);

/// Default number of random draws attempted before token allocation gives up
/// because every draw collided with an already-pending token.
pub const DEFAULT_HANDOFF_TOKEN_COLLISION_ATTEMPTS: usize = 16;
23
24#[derive(Clone, Copy, PartialEq, Eq, Hash)]
26pub struct HandoffToken([u8; HANDOFF_TOKEN_BYTES]);
27
28impl HandoffToken {
29 pub fn generate() -> Result<Self, HandoffTokenError> {
31 let mut bytes = [0_u8; HANDOFF_TOKEN_BYTES];
32 getrandom::fill(&mut bytes)?;
33 Ok(Self(bytes))
34 }
35
36 pub fn from_bytes(bytes: [u8; HANDOFF_TOKEN_BYTES]) -> Self {
38 Self(bytes)
39 }
40
41 pub fn as_bytes(&self) -> &[u8; HANDOFF_TOKEN_BYTES] {
43 &self.0
44 }
45
46 pub fn into_bytes(self) -> [u8; HANDOFF_TOKEN_BYTES] {
48 self.0
49 }
50}
51
52impl fmt::Debug for HandoffToken {
53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 f.write_str("HandoffToken(<redacted>)")
55 }
56}
57
58impl From<[u8; HANDOFF_TOKEN_BYTES]> for HandoffToken {
59 fn from(value: [u8; HANDOFF_TOKEN_BYTES]) -> Self {
60 Self::from_bytes(value)
61 }
62}
63
/// Limits and timing applied by a handoff token store.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct HandoffTokenStoreConfig {
    /// Maximum number of unexpired tokens that may be outstanding at once.
    pub max_pending_tokens: usize,
    /// How long a token stays redeemable after the instant it was issued.
    pub token_ttl: Duration,
    /// Number of random draws tried before allocation reports exhaustion.
    pub collision_attempts: usize,
}
74
75impl HandoffTokenStoreConfig {
76 pub fn new(max_pending_tokens: usize, token_ttl: Duration) -> Self {
78 Self {
79 max_pending_tokens: max_pending_tokens.max(1),
80 token_ttl: if token_ttl.is_zero() {
81 Duration::from_millis(1)
82 } else {
83 token_ttl
84 },
85 collision_attempts: DEFAULT_HANDOFF_TOKEN_COLLISION_ATTEMPTS,
86 }
87 }
88
89 pub fn with_collision_attempts(mut self, collision_attempts: usize) -> Self {
91 self.collision_attempts = collision_attempts.max(1);
92 self
93 }
94}
95
96impl Default for HandoffTokenStoreConfig {
97 fn default() -> Self {
98 Self {
99 max_pending_tokens: DEFAULT_MAX_PENDING_HANDOFF_TOKENS,
100 token_ttl: DEFAULT_HANDOFF_TOKEN_TTL,
101 collision_attempts: DEFAULT_HANDOFF_TOKEN_COLLISION_ATTEMPTS,
102 }
103 }
104}
105
106#[derive(Debug)]
108pub struct HandoffTokenStore {
109 config: HandoffTokenStoreConfig,
110 pending: HashMap<HandoffToken, PendingHandoffToken>,
111}
112
113impl HandoffTokenStore {
114 pub fn new() -> Self {
116 Self::with_config(HandoffTokenStoreConfig::default())
117 }
118
119 pub fn with_config(config: HandoffTokenStoreConfig) -> Self {
121 Self {
122 config,
123 pending: HashMap::new(),
124 }
125 }
126
127 pub fn pending_len(&self) -> usize {
129 self.pending.len()
130 }
131
132 pub fn issue(&mut self, now: Instant) -> Result<HandoffToken, HandoffTokenError> {
134 self.issue_with_random128(now, || {
135 let mut bytes = [0_u8; HANDOFF_TOKEN_BYTES];
136 getrandom::fill(&mut bytes)?;
137 Ok(bytes)
138 })
139 }
140
141 pub fn issue_with_random128<F>(
146 &mut self,
147 now: Instant,
148 mut next_random128: F,
149 ) -> Result<HandoffToken, HandoffTokenError>
150 where
151 F: FnMut() -> Result<[u8; HANDOFF_TOKEN_BYTES], HandoffTokenError>,
152 {
153 self.prune_expired(now);
154 if self.pending.len() >= self.config.max_pending_tokens {
155 return Err(HandoffTokenError::PendingLimitReached {
156 max_pending_tokens: self.config.max_pending_tokens,
157 });
158 }
159
160 for _ in 0..self.config.collision_attempts {
161 let token = HandoffToken::from_bytes(next_random128()?);
162 if self.pending.contains_key(&token) {
163 continue;
164 }
165
166 self.pending.insert(
167 token,
168 PendingHandoffToken {
169 expires_at: expires_at(now, self.config.token_ttl),
170 },
171 );
172 return Ok(token);
173 }
174
175 Err(HandoffTokenError::CollisionExhausted {
176 attempts: self.config.collision_attempts,
177 })
178 }
179
180 pub fn consume_matching(
185 &mut self,
186 expected: &HandoffToken,
187 presented: &HandoffToken,
188 now: Instant,
189 ) -> Result<(), HandoffTokenError> {
190 self.prune_expired_except(now, Some(expected));
191
192 let Some(pending) = self.pending.get(expected) else {
193 return Err(HandoffTokenError::TokenNotPending);
194 };
195 if now >= pending.expires_at {
196 self.pending.remove(expected);
197 return Err(HandoffTokenError::TokenExpired);
198 }
199 if expected != presented {
200 return Err(HandoffTokenError::TokenMismatch);
201 }
202
203 self.pending.remove(expected);
204 Ok(())
205 }
206
207 pub fn revoke(&mut self, token: &HandoffToken) -> bool {
213 self.pending.remove(token).is_some()
214 }
215
216 pub fn prune_expired(&mut self, now: Instant) -> usize {
218 self.prune_expired_except(now, None)
219 }
220
221 fn prune_expired_except(&mut self, now: Instant, except: Option<&HandoffToken>) -> usize {
222 let before = self.pending.len();
223 self.pending.retain(|token, pending| {
224 except.is_some_and(|expected| expected == token) || now < pending.expires_at
225 });
226 before - self.pending.len()
227 }
228}
229
230impl Default for HandoffTokenStore {
231 fn default() -> Self {
232 Self::new()
233 }
234}
235
236#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
238pub enum HandoffTokenError {
239 #[error("handoff token random generation failed: {0}")]
241 Random(String),
242 #[error("handoff token pending limit reached ({max_pending_tokens})")]
244 PendingLimitReached {
245 max_pending_tokens: usize,
247 },
248 #[error("handoff token allocation exhausted after {attempts} collision attempts")]
250 CollisionExhausted {
251 attempts: usize,
253 },
254 #[error("handoff token mismatch")]
256 TokenMismatch,
257 #[error("handoff token expired")]
259 TokenExpired,
260 #[error("handoff token is not pending")]
262 TokenNotPending,
263}
264
265impl From<getrandom::Error> for HandoffTokenError {
266 fn from(value: getrandom::Error) -> Self {
267 Self::Random(value.to_string())
268 }
269}
270
/// Bookkeeping stored alongside each outstanding token.
#[derive(Clone, Debug)]
struct PendingHandoffToken {
    /// First instant at which the token is treated as expired.
    expires_at: Instant,
}
275
/// Computes the deadline `now + ttl`, saturating instead of failing on overflow.
///
/// Previously an unrepresentable deadline fell back to `now`, so a very large
/// TTL (e.g. `Duration::MAX` meant as "effectively never") produced tokens that
/// were already expired when issued. That contradicts the config's clamping of
/// a zero TTL to 1 ms. Instead, this walks toward the latest instant that
/// `now` can be advanced to: it adds `step` while that succeeds and halves
/// `step` on overflow. The loop is bounded by roughly the bit width of the
/// TTL in nanoseconds.
fn expires_at(now: Instant, ttl: Duration) -> Instant {
    if let Some(deadline) = now.checked_add(ttl) {
        return deadline;
    }

    let mut deadline = now;
    let mut step = ttl;
    while !step.is_zero() {
        match deadline.checked_add(step) {
            Some(next) => deadline = next,
            None => step /= 2,
        }
    }
    deadline
}