1use std::{error::Error as StdError, fmt, num::NonZeroU128};
4
5use base64::Engine;
6use rand::TryCryptoRng;
7
/// An opaque 128-bit session key that is guaranteed to be non-zero.
///
/// The key material is stored as a `NonZeroU128`. The `Debug` impl for this
/// type is written by hand and prints a fixed placeholder instead of the value.
///
/// NOTE(review): the derived `PartialEq` is an ordinary (not constant-time)
/// comparison. Confirm keys are never compared against attacker-supplied
/// values in a timing-sensitive path.
#[derive(PartialEq, Eq, Hash, Clone)]
pub struct SessionKey(NonZeroU128);
13
14impl fmt::Debug for SessionKey {
16 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17 f.write_str("SessionKey(..)")
18 }
19}
20
21impl SessionKey {
22 const BASE64_ENGINE: base64::engine::GeneralPurpose =
23 base64::engine::general_purpose::URL_SAFE_NO_PAD;
24
25 pub const ENCODED_LEN: usize = 22;
29
30 const DECODED_LEN: usize = 16;
35
36 #[must_use]
44 pub fn generate() -> SessionKey {
45 SessionKey::generate_from_rng(&mut rand::rng())
46 }
47
48 #[must_use]
61 pub fn generate_from_rng<R: TryCryptoRng>(rng: &mut R) -> SessionKey {
62 fn generate_u128<R: TryCryptoRng>(rng: &mut R) -> u128 {
63 let x = u128::from(rng.try_next_u64().unwrap());
64 let y = u128::from(rng.try_next_u64().unwrap());
65 (y << 64) | x
66 }
67
68 loop {
69 if let Some(n) = NonZeroU128::new(generate_u128(rng)) {
70 return SessionKey(n);
71 }
72 }
73 }
74
75 #[must_use]
82 pub fn encode(&self) -> String {
83 SessionKey::BASE64_ENGINE.encode(self.0.get().to_le_bytes())
84 }
85
86 pub fn decode<B: AsRef<[u8]>>(b: B) -> Result<SessionKey, DecodeSessionKeyError> {
91 fn _decode(b: &[u8]) -> Result<SessionKey, DecodeSessionKeyError> {
92 use base64::DecodeError;
93
94 let mut buf = [0; const { SessionKey::DECODED_LEN }];
95 SessionKey::BASE64_ENGINE
96 .decode_slice(b, &mut buf)
97 .and_then(|decoded_len| {
98 if decoded_len == SessionKey::DECODED_LEN {
99 Ok(())
100 } else {
101 Err(DecodeError::InvalidLength(decoded_len).into())
102 }
103 })?;
104
105 match u128::from_le_bytes(buf).try_into() {
106 Ok(v) => Ok(SessionKey(v)),
107 Err(_) => Err(DecodeSessionKeyError::Zero),
108 }
109 }
110
111 _decode(b.as_ref())
112 }
113}
114
115impl SessionKey {
116 #[doc(hidden)]
118 #[inline]
119 pub fn from_non_zero_u128(value: NonZeroU128) -> SessionKey {
120 SessionKey(value)
121 }
122
123 #[doc(hidden)]
125 #[inline]
126 pub fn try_from_u128(value: u128) -> Result<SessionKey, std::num::TryFromIntError> {
127 value.try_into().map(SessionKey::from_non_zero_u128)
128 }
129}
130
131#[derive(Debug)]
133pub enum DecodeSessionKeyError {
134 Base64(base64::DecodeSliceError),
135 Zero,
136}
137
138impl StdError for DecodeSessionKeyError {
139 fn source(&self) -> Option<&(dyn StdError + 'static)> {
140 match self {
141 DecodeSessionKeyError::Base64(err) => Some(err),
142 DecodeSessionKeyError::Zero => None,
143 }
144 }
145}
146
147impl fmt::Display for DecodeSessionKeyError {
148 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149 match self {
150 DecodeSessionKeyError::Base64(_err) => f.write_str("failed to parse base64 string"),
151 DecodeSessionKeyError::Zero => f.write_str("session id must be non-zero"),
152 }
153 }
154}
155
156impl From<base64::DecodeSliceError> for DecodeSessionKeyError {
157 fn from(value: base64::DecodeSliceError) -> Self {
158 DecodeSessionKeyError::Base64(value)
159 }
160}
161
#[cfg(test)]
mod test {
    use super::*;
    use quickcheck::{quickcheck, Arbitrary};

    #[test]
    fn parse_error_zero() {
        const INPUT: &str = "AAAAAAAAAAAAAAAAAAAAAA";
        let result = SessionKey::decode(INPUT);
        assert!(
            matches!(result, Err(DecodeSessionKeyError::Zero)),
            "expected decoding to fail"
        );
    }

    /// Input of the wrong length (too short, or decoding to more than 16 bytes)
    /// must be rejected as a base64 error.
    #[test]
    fn parse_error_wrong_length() {
        for len in [0, 4, 21, 23, 24] {
            let input = "A".repeat(len);
            assert!(
                matches!(SessionKey::decode(&input), Err(DecodeSessionKeyError::Base64(_))),
                "expected input of length {len} to be rejected"
            );
        }
    }

    /// Symbols outside the URL-safe alphabet (including the standard
    /// alphabet's `+` and `/`) must be rejected.
    #[test]
    fn parse_error_invalid_symbol() {
        for bad in ["+", "/", "!"] {
            let input = format!("{bad}{}", "A".repeat(SessionKey::ENCODED_LEN - 1));
            assert!(
                matches!(SessionKey::decode(&input), Err(DecodeSessionKeyError::Base64(_))),
                "expected {input:?} to be rejected"
            );
        }
    }

    #[test]
    fn try_from_u128_rejects_zero() {
        assert!(SessionKey::try_from_u128(0).is_err());
        assert_eq!(
            SessionKey::try_from_u128(1).unwrap(),
            SessionKey::from_non_zero_u128(NonZeroU128::new(1).unwrap())
        );
    }

    impl Arbitrary for SessionKey {
        fn arbitrary(g: &mut quickcheck::Gen) -> Self {
            SessionKey::from_non_zero_u128(NonZeroU128::arbitrary(g))
        }
    }

    quickcheck! {
        fn encode_decode(id: SessionKey) -> bool {
            let encoded = id.encode();
            let decoded = SessionKey::decode(&encoded).unwrap();
            id == decoded
        }

        fn encoded_len_is_constant(id: SessionKey) -> bool {
            id.encode().len() == SessionKey::ENCODED_LEN
        }
    }

    #[test]
    fn debug_redacts_content() {
        let s = SessionKey::generate();
        assert_eq!(format!("{:?}", s), "SessionKey(..)");
    }
}