1use std::borrow::Cow;
2
3use base64::engine::general_purpose::STANDARD as BASE64;
4use base64::Engine;
5use rand::distributions::{Distribution, Uniform};
6use rand::{rngs::OsRng, Rng};
7use ring::digest::SHA256_OUTPUT_LEN;
8use ring::hmac;
9
10use error::{Error, Field, Kind};
11use utils::find_proofs;
12use NONCE_LENGTH;
13
14pub struct ScramServer<P: AuthenticationProvider> {
17 provider: P,
19}
20
/// Stored credentials for a single user, as returned by
/// [`AuthenticationProvider::get_password_for`].
pub struct PasswordInfo {
    /// Password material handed to `find_proofs` when the client proof is checked.
    /// NOTE(review): presumably the RFC 5802 `SaltedPassword` (Hi(password, salt, i)) — confirm.
    hashed_password: Vec<u8>,
    /// Raw salt bytes; sent base64-encoded as the `s=` attribute of the server-first message.
    salt: Vec<u8>,
    /// Iteration count; sent as the `i=` attribute of the server-first message.
    iterations: u16,
}
28
/// Final outcome of a SCRAM exchange, as reported by `ServerFinal::server_final`.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum AuthenticationStatus {
    /// The client proof matched and, when an authorization identity was requested,
    /// the provider allowed it.
    Authenticated,
    /// The client proof did not match the stored credentials.
    NotAuthenticated,
    /// The client proof matched, but the provider refused to let the authentication
    /// identity act as the requested authorization identity.
    NotAuthorized,
}
40
41impl PasswordInfo {
42 pub fn new(hashed_password: Vec<u8>, iterations: u16, salt: Vec<u8>) -> Self {
45 PasswordInfo {
46 hashed_password,
47 iterations,
48 salt,
49 }
50 }
51}
52
53pub trait AuthenticationProvider {
60 fn get_password_for(&self, username: &str) -> Option<PasswordInfo>;
62
63 fn authorize(&self, authcid: &str, authzid: &str) -> bool {
67 authcid == authzid
68 }
69}
70
71fn parse_client_first(data: &str) -> Result<(&str, Option<&str>, &str), Error> {
74 let mut parts = data.split(',');
75
76 if let Some(part) = parts.next() {
78 if let Some(cb) = part.chars().next() {
79 if cb == 'p' {
80 return Err(Error::UnsupportedExtension);
81 }
82 if cb != 'n' && cb != 'y' || part.len() > 1 {
83 return Err(Error::Protocol(Kind::InvalidField(Field::ChannelBinding)));
84 }
85 } else {
86 return Err(Error::Protocol(Kind::ExpectedField(Field::ChannelBinding)));
87 }
88 } else {
89 return Err(Error::Protocol(Kind::ExpectedField(Field::ChannelBinding)));
90 }
91
92 let authzid = if let Some(part) = parts.next() {
94 if part.is_empty() {
95 None
96 } else if part.len() < 2 || &part.as_bytes()[..2] != b"a=" {
97 return Err(Error::Protocol(Kind::ExpectedField(Field::Authzid)));
98 } else {
99 Some(&part[2..])
100 }
101 } else {
102 return Err(Error::Protocol(Kind::ExpectedField(Field::Authzid)));
103 };
104
105 let authcid = parse_part!(parts, Authcid, b"n=");
107
108 let nonce = match parts.next() {
110 Some(part) if &part.as_bytes()[..2] == b"r=" => &part[2..],
111 _ => {
112 return Err(Error::Protocol(Kind::ExpectedField(Field::Nonce)));
113 }
114 };
115 Ok((authcid, authzid, nonce))
116}
117
118fn parse_client_final(data: &str) -> Result<(&str, &str, &str), Error> {
120 let mut parts = data.split(',');
122 let gs2header = parse_part!(parts, GS2Header, b"c=");
123 let nonce = parse_part!(parts, Nonce, b"r=");
124 let proof = parse_part!(parts, Proof, b"p=");
125 Ok((gs2header, nonce, proof))
126}
127
128impl<P: AuthenticationProvider> ScramServer<P> {
129 pub fn new(provider: P) -> Self {
131 ScramServer { provider }
132 }
133
134 pub fn handle_client_first<'a>(
138 &'a self,
139 client_first: &'a str,
140 ) -> Result<ServerFirst<'a, P>, Error> {
141 let (authcid, authzid, client_nonce) = parse_client_first(client_first)?;
142 let password_info = self
143 .provider
144 .get_password_for(authcid)
145 .ok_or_else(|| Error::InvalidUser(authcid.to_string()))?;
146 Ok(ServerFirst {
147 client_nonce,
148 authcid,
149 authzid,
150 provider: &self.provider,
151 password_info,
152 })
153 }
154}
155
156pub struct ServerFirst<'a, P: 'a + AuthenticationProvider> {
159 client_nonce: &'a str,
160 authcid: &'a str,
161 authzid: Option<&'a str>,
162 provider: &'a P,
163 password_info: PasswordInfo,
164}
165
166impl<'a, P: AuthenticationProvider> ServerFirst<'a, P> {
167 pub fn server_first(self) -> (ClientFinal<'a, P>, String) {
173 self.server_first_with_rng(&mut OsRng)
174 }
175
176 pub fn server_first_with_rng<R: Rng>(self, rng: &mut R) -> (ClientFinal<'a, P>, String) {
181 let mut nonce = String::with_capacity(self.client_nonce.len() + NONCE_LENGTH);
182 nonce.push_str(self.client_nonce);
183 nonce.extend(
184 Uniform::from(33..125)
185 .sample_iter(rng)
186 .map(|x: u8| if x > 43 { (x + 1) as char } else { x as char })
187 .take(NONCE_LENGTH),
188 );
189
190 let gs2header: Cow<'static, str> = match self.authzid {
191 Some(authzid) => format!("n,a={},", authzid).into(),
192 None => "n,,".into(),
193 };
194 let client_first_bare: Cow<'static, str> =
195 format!("n={},r={}", self.authcid, self.client_nonce).into();
196 let server_first: Cow<'static, str> = format!(
197 "r={},s={},i={}",
198 nonce,
199 BASE64.encode(self.password_info.salt.as_slice()),
200 self.password_info.iterations
201 )
202 .into();
203 (
204 ClientFinal {
205 hashed_password: self.password_info.hashed_password,
206 nonce,
207 gs2header,
208 client_first_bare,
209 server_first: server_first.clone(),
210 authcid: self.authcid,
211 authzid: self.authzid,
212 provider: self.provider,
213 },
214 server_first.into_owned(),
215 )
216 }
217}
218
219pub struct ClientFinal<'a, P: 'a + AuthenticationProvider> {
222 hashed_password: Vec<u8>,
223 nonce: String,
224 gs2header: Cow<'static, str>,
225 client_first_bare: Cow<'static, str>,
226 server_first: Cow<'static, str>,
227 authcid: &'a str,
228 authzid: Option<&'a str>,
229 provider: &'a P,
230}
231
232impl<'a, P: AuthenticationProvider> ClientFinal<'a, P> {
233 pub fn handle_client_final(self, client_final: &str) -> Result<ServerFinal, Error> {
240 let (gs2header_enc, nonce, proof) = parse_client_final(client_final)?;
241 if !self.verify_header(gs2header_enc) {
242 return Err(Error::Protocol(Kind::InvalidField(Field::GS2Header)));
243 }
244 if !self.verify_nonce(nonce) {
245 return Err(Error::Protocol(Kind::InvalidField(Field::Nonce)));
246 }
247 if let Some(signature) = self.verify_proof(proof)? {
248 if let Some(authzid) = self.authzid {
249 if self.provider.authorize(self.authcid, authzid) {
250 Ok(ServerFinal {
251 status: AuthenticationStatus::Authenticated,
252 signature,
253 })
254 } else {
255 Ok(ServerFinal {
256 status: AuthenticationStatus::NotAuthorized,
257 signature: format!(
258 "e=User '{}' not authorized to act as '{}'",
259 self.authcid, authzid
260 ),
261 })
262 }
263 } else {
264 Ok(ServerFinal {
265 status: AuthenticationStatus::Authenticated,
266 signature,
267 })
268 }
269 } else {
270 Ok(ServerFinal {
271 status: AuthenticationStatus::NotAuthenticated,
272 signature: "e=Invalid Password".to_string(),
273 })
274 }
275 }
276
277 fn verify_header(&self, gs2header: &str) -> bool {
279 let server_gs2header = BASE64.encode(self.gs2header.as_bytes());
280 server_gs2header == gs2header
281 }
282
283 fn verify_nonce(&self, nonce: &str) -> bool {
285 nonce == self.nonce
286 }
287
288 fn verify_proof(&self, proof: &str) -> Result<Option<String>, Error> {
290 let (client_proof, server_signature): ([u8; SHA256_OUTPUT_LEN], hmac::Tag) = find_proofs(
291 &self.gs2header,
292 &self.client_first_bare,
293 &self.server_first,
294 self.hashed_password.as_slice(),
295 &self.nonce,
296 );
297 let proof = if let Ok(proof) = BASE64.decode(proof.as_bytes()) {
298 proof
299 } else {
300 return Err(Error::Protocol(Kind::InvalidField(Field::Proof)));
301 };
302 if proof != client_proof {
303 return Ok(None);
304 }
305
306 let server_signature_string = format!("v={}", BASE64.encode(server_signature.as_ref()));
307 Ok(Some(server_signature_string))
308 }
309}
310
311pub struct ServerFinal {
314 status: AuthenticationStatus,
315 signature: String,
316}
317
318impl ServerFinal {
319 pub fn server_final(self) -> (AuthenticationStatus, String) {
322 (self.status, self.signature)
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::super::{Error, Field, Kind};
    use super::{parse_client_final, parse_client_first};

    /// Error reported when a required field is absent or lacks its prefix.
    fn missing(field: Field) -> Error {
        Error::Protocol(Kind::ExpectedField(field))
    }

    /// Error reported when a field is present but holds an invalid value.
    fn invalid(field: Field) -> Error {
        Error::Protocol(Kind::InvalidField(field))
    }

    #[test]
    fn test_parse_client_first_success() {
        let cases = vec![
            ("n,,n=user,r=abcdefghijk", "user", None, "abcdefghijk"),
            (
                "y,a=other user,n=user,r=abcdef=hijk",
                "user",
                Some("other user"),
                "abcdef=hijk",
            ),
            ("n,,n=,r=", "", None, ""),
        ];
        for (message, want_authcid, want_authzid, want_nonce) in cases {
            let (authcid, authzid, nonce) = parse_client_first(message).unwrap();
            assert_eq!(authcid, want_authcid, "authcid for {:?}", message);
            assert_eq!(authzid, want_authzid, "authzid for {:?}", message);
            assert_eq!(nonce, want_nonce, "nonce for {:?}", message);
        }
    }

    #[test]
    fn test_parse_client_first_missing_fields() {
        let cases = vec![
            ("n,,n=user", missing(Field::Nonce)),
            ("n,,r=user", missing(Field::Authcid)),
            ("n,n=user,r=abc", missing(Field::Authzid)),
            (",,n=user,r=abc", missing(Field::ChannelBinding)),
            ("", missing(Field::ChannelBinding)),
            (",,,", missing(Field::ChannelBinding)),
        ];
        for (message, expected) in cases {
            assert_eq!(
                parse_client_first(message).unwrap_err(),
                expected,
                "input {:?}",
                message
            );
        }
    }

    #[test]
    fn test_parse_client_first_invalid_data() {
        let cases = vec![
            ("a,,n=user,r=abc", invalid(Field::ChannelBinding)),
            ("p,,n=user,r=abc", Error::UnsupportedExtension),
            ("nn,,n=user,r=abc", invalid(Field::ChannelBinding)),
            ("n,,n,r=abc", missing(Field::Authcid)),
        ];
        for (message, expected) in cases {
            assert_eq!(
                parse_client_first(message).unwrap_err(),
                expected,
                "input {:?}",
                message
            );
        }
    }

    #[test]
    fn test_parse_client_final_success() {
        let cases = vec![
            ("c=abc,r=abcefg,p=783232", ("abc", "abcefg", "783232")),
            ("c=,r=,p=", ("", "", "")),
        ];
        for (message, expected) in cases {
            assert_eq!(
                parse_client_final(message).unwrap(),
                expected,
                "input {:?}",
                message
            );
        }
    }

    #[test]
    fn test_parse_client_final_missing_fields() {
        let cases = vec![
            ("c=whatever,r=something", missing(Field::Proof)),
            ("c=whatever,p=words", missing(Field::Nonce)),
            ("c=whatever", missing(Field::Nonce)),
            ("c=", missing(Field::Nonce)),
            ("", missing(Field::GS2Header)),
            ("r=anonce", missing(Field::GS2Header)),
        ];
        for (message, expected) in cases {
            assert_eq!(
                parse_client_final(message).unwrap_err(),
                expected,
                "input {:?}",
                message
            );
        }
    }
}