1use self::ParseSessionIdError::*;
2use hmac::digest::{Digest, FixedOutput, HashMarker, Update};
3use serde::{Deserialize, Serialize};
4use std::{error, fmt};
5use zino_core::{SharedString, encoding::base64, error::Error, validation::Validation};
6
/// A session identifier rendered as `SID:ANON:<realm>:<identifier>[-<thread>][:<count>]`.
///
/// The textual form is produced by the `Display` implementation and read back by
/// [`SessionId::parse`]; `thread` and `count` are written in lowercase hexadecimal
/// and omitted when they are `0`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionId {
    /// Realm the session belongs to; hashed together with the key to derive `identifier`.
    realm: SharedString,
    /// Base64-encoded digest of the realm bytes followed by the key bytes.
    identifier: String,
    /// Thread value; `0` is the initial value and is not rendered in the textual form.
    thread: u8,
    /// Counter bumped by `increment_count`; `0` is the initial value and is not rendered.
    count: u8,
}
24
25impl SessionId {
26 #[inline]
28 pub fn new<D>(realm: impl Into<SharedString>, key: impl AsRef<[u8]>) -> Self
29 where
30 D: Default + FixedOutput + HashMarker + Update,
31 {
32 fn inner<D>(realm: SharedString, key: &[u8]) -> SessionId
33 where
34 D: Default + FixedOutput + HashMarker + Update,
35 {
36 let data = [realm.as_ref().as_bytes(), key].concat();
37 let mut hasher = D::new();
38 hasher.update(data.as_ref());
39
40 let identifier = base64::encode(hasher.finalize().as_slice());
41 SessionId {
42 realm,
43 identifier,
44 thread: 0,
45 count: 0,
46 }
47 }
48 inner::<D>(realm.into(), key.as_ref())
49 }
50
51 pub fn validate_with<D>(&self, realm: &str, key: impl AsRef<[u8]>) -> Validation
53 where
54 D: Default + FixedOutput + HashMarker + Update,
55 {
56 fn inner<D>(session_id: &SessionId, realm: &str, key: &[u8]) -> Validation
57 where
58 D: Default + FixedOutput + HashMarker + Update,
59 {
60 let mut validation = Validation::new();
61 let identifier = &session_id.identifier;
62 match base64::decode(identifier) {
63 Ok(hash) => {
64 let data = [realm.as_bytes(), key].concat();
65 let mut hasher = D::new();
66 hasher.update(data.as_ref());
67
68 if hasher.finalize().as_slice() != hash {
69 validation.record("identifier", "invalid session identifier");
70 }
71 }
72 Err(err) => {
73 validation.record_fail("identifier", err);
74 }
75 }
76 validation
77 }
78 inner::<D>(self, realm, key.as_ref())
79 }
80
81 pub fn accepts(&self, session_id: &SessionId) -> bool {
83 if self.identifier() != session_id.identifier() {
84 return false;
85 }
86
87 let realm = self.realm();
88 let domain = session_id.realm();
89 if domain == realm {
90 self.count() <= session_id.count()
91 } else {
92 let remainder = if realm.len() > domain.len() {
93 realm.strip_suffix(domain)
94 } else {
95 domain.strip_suffix(realm)
96 };
97 remainder.is_some_and(|s| s.ends_with('.'))
98 }
99 }
100
101 #[inline]
103 pub fn set_thread(&mut self, thread: u8) {
104 self.thread = thread;
105 }
106
107 #[inline]
109 pub fn increment_count(&mut self) {
110 self.count = self.count.saturating_add(1);
111 }
112
113 #[inline]
115 pub fn realm(&self) -> &str {
116 self.realm.as_ref()
117 }
118
119 #[inline]
121 pub fn identifier(&self) -> &str {
122 self.identifier.as_ref()
123 }
124
125 #[inline]
127 pub fn thread(&self) -> u8 {
128 self.thread
129 }
130
131 #[inline]
133 pub fn count(&self) -> u8 {
134 self.count
135 }
136
137 pub fn parse(s: &str) -> Result<SessionId, ParseSessionIdError> {
139 if let Some(s) = s.strip_prefix("SID:ANON:")
140 && let Some((realm, s)) = s.split_once(':')
141 {
142 if let Some((identifier, s)) = s.split_once('-') {
143 if let Some((thread, count)) = s.split_once(':') {
144 return u8::from_str_radix(thread, 16)
145 .map_err(|err| ParseThreadError(err.into()))
146 .and_then(|thread| {
147 u8::from_str_radix(count, 16)
148 .map_err(|err| ParseCountError(err.into()))
149 .map(|count| Self {
150 realm: realm.to_owned().into(),
151 identifier: identifier.to_owned(),
152 thread,
153 count,
154 })
155 });
156 } else {
157 return u8::from_str_radix(s, 16)
158 .map_err(|err| ParseThreadError(err.into()))
159 .map(|thread| Self {
160 realm: realm.to_owned().into(),
161 identifier: identifier.to_owned(),
162 thread,
163 count: 0,
164 });
165 }
166 } else if let Some((identifier, count)) = s.split_once(':') {
167 return u8::from_str_radix(count, 16)
168 .map_err(|err| ParseCountError(err.into()))
169 .map(|count| Self {
170 realm: realm.to_owned().into(),
171 identifier: identifier.to_owned(),
172 thread: 0,
173 count,
174 });
175 } else {
176 return Ok(Self {
177 realm: realm.to_owned().into(),
178 identifier: s.to_owned(),
179 thread: 0,
180 count: 0,
181 });
182 }
183 }
184 Err(InvalidFormat)
185 }
186}
187
188impl fmt::Display for SessionId {
189 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
190 let realm = &self.realm;
191 let identifier = &self.identifier;
192 let thread = self.thread;
193 let count = self.count;
194 if thread > 0 {
195 if count > 0 {
196 write!(f, "SID:ANON:{realm}:{identifier}-{thread:x}:{count:x}")
197 } else {
198 write!(f, "SID:ANON:{realm}:{identifier}-{thread:x}")
199 }
200 } else if count > 0 {
201 write!(f, "SID:ANON:{realm}:{identifier}:{count:x}")
202 } else {
203 write!(f, "SID:ANON:{realm}:{identifier}")
204 }
205 }
206}
207
/// Errors returned by [`SessionId::parse`].
#[derive(Debug)]
pub enum ParseSessionIdError {
    /// The thread component could not be parsed as a hexadecimal `u8`.
    ParseThreadError(Error),
    /// The count component could not be parsed as a hexadecimal `u8`.
    ParseCountError(Error),
    /// The input lacks the `SID:ANON:` prefix or the `:` that terminates the realm.
    InvalidFormat,
}
218
219impl fmt::Display for ParseSessionIdError {
220 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
221 match self {
222 ParseThreadError(err) => write!(f, "fail to parse thread: {err}"),
223 ParseCountError(err) => write!(f, "fail to parse count: {err}"),
224 InvalidFormat => write!(f, "invalid format"),
225 }
226 }
227}
228
// Uses only the default `std::error::Error` methods, so `source()` returns `None`;
// the wrapped error is surfaced through the `Display` output instead.
impl error::Error for ParseSessionIdError {}