zino_auth/
session_id.rs

1use self::ParseSessionIdError::*;
2use hmac::digest::{Digest, FixedOutput, HashMarker, Update};
3use serde::{Deserialize, Serialize};
4use std::{error, fmt};
5use zino_core::{SharedString, encoding::base64, error::Error, validation::Validation};
6
7/// Session Identification URI.
8/// See [the spec](https://www.w3.org/TR/WD-session-id).
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct SessionId {
11    /// Specifies the realm within which linkage of the identifier is possible.
12    /// Realms have the same format as DNS names.
13    realm: SharedString,
14    /// Unstructured random integer specific to realm generated using a procedure with
15    /// a negligible probability of collision. The identifier is encoded using base64.
16    identifier: String,
17    /// Optional extension of identifier field used to differentiate concurrent uses of
18    /// the same session identifier. The thread field is an integer encoded in hexadecimal.
19    thread: u8,
20    /// Optional Hexadecimal encoded integer containing a monotonically increasing counter value.
21    /// A client should increment the count field after each operation.
22    count: u8,
23}
24
25impl SessionId {
26    /// Creates a new instance.
27    #[inline]
28    pub fn new<D>(realm: impl Into<SharedString>, key: impl AsRef<[u8]>) -> Self
29    where
30        D: Default + FixedOutput + HashMarker + Update,
31    {
32        fn inner<D>(realm: SharedString, key: &[u8]) -> SessionId
33        where
34            D: Default + FixedOutput + HashMarker + Update,
35        {
36            let data = [realm.as_ref().as_bytes(), key].concat();
37            let mut hasher = D::new();
38            hasher.update(data.as_ref());
39
40            let identifier = base64::encode(hasher.finalize().as_slice());
41            SessionId {
42                realm,
43                identifier,
44                thread: 0,
45                count: 0,
46            }
47        }
48        inner::<D>(realm.into(), key.as_ref())
49    }
50
51    /// Validates the session identifier using the realm and the key.
52    pub fn validate_with<D>(&self, realm: &str, key: impl AsRef<[u8]>) -> Validation
53    where
54        D: Default + FixedOutput + HashMarker + Update,
55    {
56        fn inner<D>(session_id: &SessionId, realm: &str, key: &[u8]) -> Validation
57        where
58            D: Default + FixedOutput + HashMarker + Update,
59        {
60            let mut validation = Validation::new();
61            let identifier = &session_id.identifier;
62            match base64::decode(identifier) {
63                Ok(hash) => {
64                    let data = [realm.as_bytes(), key].concat();
65                    let mut hasher = D::new();
66                    hasher.update(data.as_ref());
67
68                    if hasher.finalize().as_slice() != hash {
69                        validation.record("identifier", "invalid session identifier");
70                    }
71                }
72                Err(err) => {
73                    validation.record_fail("identifier", err);
74                }
75            }
76            validation
77        }
78        inner::<D>(self, realm, key.as_ref())
79    }
80
81    /// Returns `true` if the given `SessionId` can be accepted by `self`.
82    pub fn accepts(&self, session_id: &SessionId) -> bool {
83        if self.identifier() != session_id.identifier() {
84            return false;
85        }
86
87        let realm = self.realm();
88        let domain = session_id.realm();
89        if domain == realm {
90            self.count() <= session_id.count()
91        } else {
92            let remainder = if realm.len() > domain.len() {
93                realm.strip_suffix(domain)
94            } else {
95                domain.strip_suffix(realm)
96            };
97            remainder.is_some_and(|s| s.ends_with('.'))
98        }
99    }
100
101    /// Sets the thread used to differentiate concurrent uses of the same session identifier.
102    #[inline]
103    pub fn set_thread(&mut self, thread: u8) {
104        self.thread = thread;
105    }
106
107    /// Increments the count used to prevent replay attacks.
108    #[inline]
109    pub fn increment_count(&mut self) {
110        self.count = self.count.saturating_add(1);
111    }
112
113    /// Returns the realm as `&str`.
114    #[inline]
115    pub fn realm(&self) -> &str {
116        self.realm.as_ref()
117    }
118
119    /// Returns the identifier as `&str`.
120    #[inline]
121    pub fn identifier(&self) -> &str {
122        self.identifier.as_ref()
123    }
124
125    /// Returns the thread.
126    #[inline]
127    pub fn thread(&self) -> u8 {
128        self.thread
129    }
130
131    /// Returns the count.
132    #[inline]
133    pub fn count(&self) -> u8 {
134        self.count
135    }
136
137    /// Parses the `SessionId`.
138    pub fn parse(s: &str) -> Result<SessionId, ParseSessionIdError> {
139        if let Some(s) = s.strip_prefix("SID:ANON:")
140            && let Some((realm, s)) = s.split_once(':')
141        {
142            if let Some((identifier, s)) = s.split_once('-') {
143                if let Some((thread, count)) = s.split_once(':') {
144                    return u8::from_str_radix(thread, 16)
145                        .map_err(|err| ParseThreadError(err.into()))
146                        .and_then(|thread| {
147                            u8::from_str_radix(count, 16)
148                                .map_err(|err| ParseCountError(err.into()))
149                                .map(|count| Self {
150                                    realm: realm.to_owned().into(),
151                                    identifier: identifier.to_owned(),
152                                    thread,
153                                    count,
154                                })
155                        });
156                } else {
157                    return u8::from_str_radix(s, 16)
158                        .map_err(|err| ParseThreadError(err.into()))
159                        .map(|thread| Self {
160                            realm: realm.to_owned().into(),
161                            identifier: identifier.to_owned(),
162                            thread,
163                            count: 0,
164                        });
165                }
166            } else if let Some((identifier, count)) = s.split_once(':') {
167                return u8::from_str_radix(count, 16)
168                    .map_err(|err| ParseCountError(err.into()))
169                    .map(|count| Self {
170                        realm: realm.to_owned().into(),
171                        identifier: identifier.to_owned(),
172                        thread: 0,
173                        count,
174                    });
175            } else {
176                return Ok(Self {
177                    realm: realm.to_owned().into(),
178                    identifier: s.to_owned(),
179                    thread: 0,
180                    count: 0,
181                });
182            }
183        }
184        Err(InvalidFormat)
185    }
186}
187
188impl fmt::Display for SessionId {
189    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
190        let realm = &self.realm;
191        let identifier = &self.identifier;
192        let thread = self.thread;
193        let count = self.count;
194        if thread > 0 {
195            if count > 0 {
196                write!(f, "SID:ANON:{realm}:{identifier}-{thread:x}:{count:x}")
197            } else {
198                write!(f, "SID:ANON:{realm}:{identifier}-{thread:x}")
199            }
200        } else if count > 0 {
201            write!(f, "SID:ANON:{realm}:{identifier}:{count:x}")
202        } else {
203            write!(f, "SID:ANON:{realm}:{identifier}")
204        }
205    }
206}
207
208/// An error which can be returned when parsing a `SessionId`.
209#[derive(Debug)]
210pub enum ParseSessionIdError {
211    /// An error that can occur when parsing thread.
212    ParseThreadError(Error),
213    /// An error that can occur when parsing count.
214    ParseCountError(Error),
215    /// Invalid format.
216    InvalidFormat,
217}
218
219impl fmt::Display for ParseSessionIdError {
220    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
221        match self {
222            ParseThreadError(err) => write!(f, "fail to parse thread: {err}"),
223            ParseCountError(err) => write!(f, "fail to parse count: {err}"),
224            InvalidFormat => write!(f, "invalid format"),
225        }
226    }
227}
228
229impl error::Error for ParseSessionIdError {}