1use std::borrow::Cow;
2use std::num::NonZeroU32;
3
4use base64::engine::general_purpose::STANDARD as BASE64;
5use base64::Engine;
6use rand::distributions::{Distribution, Uniform};
7use rand::{rngs::OsRng, Rng};
8use ring::digest::SHA256_OUTPUT_LEN;
9use ring::hmac;
10
11use error::{Error, Field, Kind};
12use utils::{find_proofs, hash_password};
13use NONCE_LENGTH;
14
15#[deprecated(
16 since = "0.2.0",
17 note = "Please use `ScramClient` instead. (exported at crate root)"
18)]
19pub type ClientFirst<'a> = ScramClient<'a>;
20
21fn parse_server_first(data: &str) -> Result<(&str, Vec<u8>, NonZeroU32), Error> {
23 if data.len() < 2 {
24 return Err(Error::Protocol(Kind::ExpectedField(Field::Nonce)));
25 }
26 let mut parts = data.split(',').peekable();
27 match parts.peek() {
28 Some(part) if &part.as_bytes()[..2] == b"m=" => {
29 return Err(Error::UnsupportedExtension);
30 }
31 Some(_) => {}
32 None => {
33 return Err(Error::Protocol(Kind::ExpectedField(Field::Nonce)));
34 }
35 }
36 let nonce = match parts.next() {
37 Some(part) if &part.as_bytes()[..2] == b"r=" => &part[2..],
38 _ => {
39 return Err(Error::Protocol(Kind::ExpectedField(Field::Nonce)));
40 }
41 };
42 let salt = match parts.next() {
43 Some(part) if &part.as_bytes()[..2] == b"s=" => BASE64.decode(part[2..].as_bytes())
44 .map_err(|_| Error::Protocol(Kind::InvalidField(Field::Salt)))?,
45 _ => {
46 return Err(Error::Protocol(Kind::ExpectedField(Field::Salt)));
47 }
48 };
49 let iterations = match parts.next() {
50 Some(part) if &part.as_bytes()[..2] == b"i=" => part[2..]
51 .parse()
52 .map_err(|_| Error::Protocol(Kind::InvalidField(Field::Iterations)))?,
53 _ => {
54 return Err(Error::Protocol(Kind::ExpectedField(Field::Iterations)));
55 }
56 };
57 Ok((nonce, salt, iterations))
58}
59
60fn parse_server_final(data: &str) -> Result<Vec<u8>, Error> {
61 if data.len() < 2 {
62 return Err(Error::Protocol(Kind::ExpectedField(Field::VerifyOrError)));
63 }
64 match &data[..2] {
65 "v=" => BASE64.decode(&data.as_bytes()[2..])
66 .map_err(|_| Error::Protocol(Kind::InvalidField(Field::VerifyOrError))),
67 "e=" => Err(Error::Authentication(data[2..].to_string())),
68 _ => Err(Error::Protocol(Kind::ExpectedField(Field::VerifyOrError))),
69 }
70}
71
/// Initial state of a SCRAM client exchange, before the client-first-message is built.
///
/// Created by `ScramClient::new` / `ScramClient::with_rng` and consumed by
/// `ScramClient::client_first`.
// NOTE(review): the derived `Debug` prints every field, including the plaintext password;
// avoid logging this value or consider a hand-written, redacting `Debug` impl.
#[derive(Debug)]
pub struct ScramClient<'a> {
    /// GS2 header prefix: `"n,,"`, or `"n,a=<authzid>,"` when an authzid was supplied.
    gs2header: Cow<'static, str>,
    /// Plaintext password, later passed to `hash_password`.
    password: &'a str,
    /// Randomly generated client nonce.
    nonce: String,
    /// Authentication identity as supplied by the caller (not yet escaped).
    authcid: &'a str,
}
80
81impl<'a> ScramClient<'a> {
82 pub fn new(authcid: &'a str, password: &'a str, authzid: Option<&'a str>) -> Self {
96 Self::with_rng(authcid, password, authzid, &mut OsRng)
97 }
98
99 pub fn with_rng<R: Rng + ?Sized>(
112 authcid: &'a str,
113 password: &'a str,
114 authzid: Option<&'a str>,
115 rng: &mut R,
116 ) -> Self {
117 let gs2header: Cow<'static, str> = match authzid {
118 Some(authzid) => format!("n,a={},", authzid).into(),
119 None => "n,,".into(),
120 };
121 let nonce: String = Uniform::from(33..125)
122 .sample_iter(rng)
123 .map(|x: u8| if x > 43 { (x + 1) as char } else { x as char })
124 .take(NONCE_LENGTH)
125 .collect();
126 ScramClient {
127 gs2header,
128 password,
129 authcid,
130 nonce,
131 }
132 }
133
134 pub fn client_first(self) -> (ServerFirst<'a>, String) {
138 let escaped_authcid: Cow<'a, str> =
139 if self.authcid.chars().any(|chr| chr == ',' || chr == '=') {
140 self.authcid.into()
141 } else {
142 self.authcid.replace(',', "=2C").replace('=', "=3D").into()
143 };
144 let client_first_bare = format!("n={},r={}", escaped_authcid, self.nonce);
145 let client_first = format!("{}{}", self.gs2header, client_first_bare);
146 let server_first = ServerFirst {
147 gs2header: self.gs2header,
148 password: self.password,
149 client_nonce: self.nonce,
150 client_first_bare,
151 };
152 (server_first, client_first)
153 }
154}
155
/// Client state after the client-first-message was sent, waiting for the server-first-message.
///
/// Produced by `ScramClient::client_first`.
// NOTE(review): the derived `Debug` includes the plaintext password field.
#[derive(Debug)]
pub struct ServerFirst<'a> {
    /// GS2 header sent at the start of the client-first-message.
    gs2header: Cow<'static, str>,
    /// Plaintext password used to derive the salted password.
    password: &'a str,
    /// Client nonce; the server's combined nonce must start with it.
    client_nonce: String,
    /// The client-first-message without the GS2 header (`n=...,r=...`).
    client_first_bare: String,
}
164
165impl<'a> ServerFirst<'a> {
166 pub fn handle_server_first(self, server_first: &str) -> Result<ClientFinal, Error> {
178 let (nonce, salt, iterations) = parse_server_first(server_first)?;
179 if !nonce.starts_with(&self.client_nonce) {
180 return Err(Error::Protocol(Kind::InvalidNonce));
181 }
182 let salted_password = hash_password(self.password, iterations, &salt);
183 let (client_proof, server_signature): ([u8; SHA256_OUTPUT_LEN], hmac::Tag) = find_proofs(
184 &self.gs2header,
185 &self.client_first_bare,
186 server_first,
187 &salted_password,
188 nonce,
189 );
190 let client_final = format!(
191 "c={},r={},p={}",
192 BASE64.encode(self.gs2header.as_bytes()),
193 nonce,
194 BASE64.encode(client_proof)
195 );
196 Ok(ClientFinal {
197 server_signature,
198 client_final,
199 })
200 }
201}
202
/// Client state holding the client-final-message that is ready to be sent.
///
/// Produced by `ServerFirst::handle_server_first`.
#[derive(Debug)]
pub struct ClientFinal {
    /// Server signature the client expects in the server-final-message.
    server_signature: hmac::Tag,
    /// The client-final-message text (`c=...,r=...,p=...`).
    client_final: String,
}
210
211impl ClientFinal {
212 #[inline]
217 pub fn client_final(self) -> (ServerFinal, String) {
218 let server_final = ServerFinal {
219 server_signature: self.server_signature,
220 };
221 (server_final, self.client_final)
222 }
223}
224
/// Final client state: waiting for the server-final-message to verify the server.
///
/// Produced by `ClientFinal::client_final`.
#[derive(Debug)]
pub struct ServerFinal {
    /// Server signature the server-final-message's `v=` attribute must match.
    server_signature: hmac::Tag,
}
230
231impl ServerFinal {
232 pub fn handle_server_final(self, server_final: &str) -> Result<(), Error> {
244 if self.server_signature.as_ref() == &*parse_server_final(server_final)? {
245 Ok(())
246 } else {
247 Err(Error::InvalidServer)
248 }
249 }
250}