tower_sesh_core/
key.rs

1//! `SessionKey` and related items.
2
3use std::{error::Error as StdError, fmt, num::NonZeroU128};
4
5use base64::Engine;
6use rand::TryCryptoRng;
7
/// A 128-bit session identifier.
//
// The inner value is a `NonZeroU128` rather than a plain `u128`: reserving
// zero gives the compiler a niche, so `Option<SessionKey>` costs no more
// space than `SessionKey` itself.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct SessionKey(NonZeroU128);
13
14/// Debug implementation does not leak secret
15impl fmt::Debug for SessionKey {
16    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17        f.write_str("SessionKey(..)")
18    }
19}
20
21impl SessionKey {
22    const BASE64_ENGINE: base64::engine::GeneralPurpose =
23        base64::engine::general_purpose::URL_SAFE_NO_PAD;
24
25    /// Length of a Base64 string returned by the [`encode`] method.
26    ///
27    /// [`encode`]: SessionKey::encode
28    pub const ENCODED_LEN: usize = 22;
29
30    /// Length of output from decoding a Base64-encoded session key string with
31    /// the [`decode`] method.
32    ///
33    /// [`decode`]: SessionKey::decode
34    const DECODED_LEN: usize = 16;
35
36    /// Returns a random [`SessionKey`], generated from [`ThreadRng`].
37    ///
38    /// Alternatively, you may wish to use [`generate_from_rng`] and pass your
39    /// own CSPRNG. See `ThreadRng`'s documentation for notes on security.
40    ///
41    /// [`ThreadRng`]: rand::rngs::ThreadRng
42    /// [`generate_from_rng`]: SessionKey::generate_from_rng
43    #[must_use]
44    pub fn generate() -> SessionKey {
45        SessionKey::generate_from_rng(&mut rand::rng())
46    }
47
48    /// Returns a random [`SessionKey`], generated from `rng`.
49    ///
50    /// Alternatively, you may wish to use [`generate`]. See its documentation
51    /// for more.
52    ///
53    /// # Panics
54    ///
55    /// If the RNG passed is [fallible] and yields an error, this function will
56    /// panic.
57    ///
58    /// [`generate`]: SessionKey::generate
59    /// [fallible]: rand::TryRngCore
60    #[must_use]
61    pub fn generate_from_rng<R: TryCryptoRng>(rng: &mut R) -> SessionKey {
62        fn generate_u128<R: TryCryptoRng>(rng: &mut R) -> u128 {
63            let x = u128::from(rng.try_next_u64().unwrap());
64            let y = u128::from(rng.try_next_u64().unwrap());
65            (y << 64) | x
66        }
67
68        loop {
69            if let Some(n) = NonZeroU128::new(generate_u128(rng)) {
70                return SessionKey(n);
71            }
72        }
73    }
74
75    /// Encodes this session key as a URL-safe Base64 string with no padding.
76    ///
77    /// The returned string uses the URL-safe and filename-safe alphabet (with
78    /// `-` and `_`) specified in [RFC 4648].
79    ///
80    /// [RFC 4648]: https://datatracker.ietf.org/doc/html/rfc4648#section-5
81    #[must_use]
82    pub fn encode(&self) -> String {
83        SessionKey::BASE64_ENGINE.encode(self.0.get().to_le_bytes())
84    }
85
86    /// Decodes a session key string encoded with the URL-safe Base64 alphabet
87    /// specified in [RFC 4648]. There must be no padding present in the input.
88    ///
89    /// [RFC 4648]: https://datatracker.ietf.org/doc/html/rfc4648#section-5
90    pub fn decode<B: AsRef<[u8]>>(b: B) -> Result<SessionKey, DecodeSessionKeyError> {
91        fn _decode(b: &[u8]) -> Result<SessionKey, DecodeSessionKeyError> {
92            use base64::DecodeError;
93
94            let mut buf = [0; const { SessionKey::DECODED_LEN }];
95            SessionKey::BASE64_ENGINE
96                .decode_slice(b, &mut buf)
97                .and_then(|decoded_len| {
98                    if decoded_len == SessionKey::DECODED_LEN {
99                        Ok(())
100                    } else {
101                        Err(DecodeError::InvalidLength(decoded_len).into())
102                    }
103                })?;
104
105            match u128::from_le_bytes(buf).try_into() {
106                Ok(v) => Ok(SessionKey(v)),
107                Err(_) => Err(DecodeSessionKeyError::Zero),
108            }
109        }
110
111        _decode(b.as_ref())
112    }
113}
114
115impl SessionKey {
116    // Not public API. Only tests use this.
117    #[doc(hidden)]
118    #[inline]
119    pub fn from_non_zero_u128(value: NonZeroU128) -> SessionKey {
120        SessionKey(value)
121    }
122
123    // Not public API. Only tests use this.
124    #[doc(hidden)]
125    #[inline]
126    pub fn try_from_u128(value: u128) -> Result<SessionKey, std::num::TryFromIntError> {
127        value.try_into().map(SessionKey::from_non_zero_u128)
128    }
129}
130
131/// The error type returned when decoding a session key fails.
132#[derive(Debug)]
133pub enum DecodeSessionKeyError {
134    Base64(base64::DecodeSliceError),
135    Zero,
136}
137
138impl StdError for DecodeSessionKeyError {
139    fn source(&self) -> Option<&(dyn StdError + 'static)> {
140        match self {
141            DecodeSessionKeyError::Base64(err) => Some(err),
142            DecodeSessionKeyError::Zero => None,
143        }
144    }
145}
146
147impl fmt::Display for DecodeSessionKeyError {
148    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149        match self {
150            DecodeSessionKeyError::Base64(_err) => f.write_str("failed to parse base64 string"),
151            DecodeSessionKeyError::Zero => f.write_str("session id must be non-zero"),
152        }
153    }
154}
155
156impl From<base64::DecodeSliceError> for DecodeSessionKeyError {
157    fn from(value: base64::DecodeSliceError) -> Self {
158        DecodeSessionKeyError::Base64(value)
159    }
160}
161
#[cfg(test)]
mod test {
    use super::*;
    use quickcheck::{quickcheck, Arbitrary};

    /// Asserts that `input` is rejected as malformed Base64 (not as a zero key).
    fn assert_base64_error(input: &str) {
        assert!(
            matches!(
                SessionKey::decode(input),
                Err(DecodeSessionKeyError::Base64(_))
            ),
            "expected {input:?} to be rejected as invalid base64"
        );
    }

    #[test]
    fn parse_error_zero() {
        const INPUT: &str = "AAAAAAAAAAAAAAAAAAAAAA";
        let result = SessionKey::decode(INPUT);
        assert!(
            matches!(result, Err(DecodeSessionKeyError::Zero)),
            "expected decoding to fail"
        );
    }

    #[test]
    fn parse_error_wrong_length() {
        // Empty input, 20 symbols (15 bytes) and 24 symbols (18 bytes) are all
        // valid Base64 but do not decode to exactly 16 bytes.
        for input in ["", "AAAAAAAAAAAAAAAAAAAA", "AAAAAAAAAAAAAAAAAAAAAAAA"] {
            assert_base64_error(input);
        }
    }

    #[test]
    fn parse_error_standard_alphabet() {
        // `+` and `/` belong to the standard alphabet, not the URL-safe one.
        assert_base64_error("AAAAAAAAAAAAAAAAAAAA+/");
    }

    #[test]
    fn parse_error_padding() {
        // Padding is not accepted, even when the unpadded prefix is valid.
        assert_base64_error("AAAAAAAAAAAAAAAAAAAAAA==");
    }

    impl Arbitrary for SessionKey {
        fn arbitrary(g: &mut quickcheck::Gen) -> Self {
            SessionKey::from_non_zero_u128(NonZeroU128::arbitrary(g))
        }
    }

    quickcheck! {
        fn encode_decode(id: SessionKey) -> bool {
            let encoded = id.encode();
            let decoded = SessionKey::decode(&encoded).unwrap();
            id == decoded
        }

        fn encoded_len_matches_constant(id: SessionKey) -> bool {
            id.encode().len() == SessionKey::ENCODED_LEN
        }
    }

    #[test]
    fn debug_redacts_content() {
        let s = SessionKey::generate();
        assert_eq!(format!("{:?}", s), "SessionKey(..)");
    }
}