scram_2/
server.rs

1use std::borrow::Cow;
2
3use base64::engine::general_purpose::STANDARD as BASE64;
4use base64::Engine;
5use rand::distributions::{Distribution, Uniform};
6use rand::{rngs::OsRng, Rng};
7use ring::digest::SHA256_OUTPUT_LEN;
8use ring::hmac;
9
10use error::{Error, Field, Kind};
11use utils::find_proofs;
12use NONCE_LENGTH;
13
14/// Responds to client authentication challenges. It's the entrypoint for the SCRAM server side
15/// implementation.
16pub struct ScramServer<P: AuthenticationProvider> {
17    /// The ['AuthenticationProvider'] that will find passwords and check authorization.
18    provider: P,
19}
20
/// Stored credentials for a single user.
///
/// Holds the password after it has been salted and hashed, together with the salt and the
/// iteration count that were used to produce it.
pub struct PasswordInfo {
    /// The salted, hashed password.
    hashed_password: Vec<u8>,
    /// The salt used when hashing the password.
    salt: Vec<u8>,
    /// The number of iterations of the hashing algorithm.
    iterations: u16,
}
28
/// The status of authentication after the final client message has been received by the server.
///
/// Equality is total (`Eq`) and the type is `Hash`, so statuses can be used as map/set keys.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum AuthenticationStatus {
    /// The client has correctly authenticated, and has been authorized.
    Authenticated,
    /// The client was not correctly authenticated, meaning they supplied an incorrect password.
    NotAuthenticated,
    /// The client authenticated correctly, but was not authorized for the alternate user they
    /// requested.
    NotAuthorized,
}
40
41impl PasswordInfo {
42    /// Create a new `PasswordInfo` from the given information. The password is assumed to have
43    /// already been hashed using the given salt and iterations.
44    pub fn new(hashed_password: Vec<u8>, iterations: u16, salt: Vec<u8>) -> Self {
45        PasswordInfo {
46            hashed_password,
47            iterations,
48            salt,
49        }
50    }
51}
52
53/// An `AuthenticationProvider` looks up password information for a given user, and also checks if a
54/// user is authorized to act on another user's behalf. The authorization component is optional, and
55/// if not implemented will simply allow users to act on their own behalf, and no one else's.
56///
57/// To ensure the password is hashed correctly, cleartext passwords can be hased using the
58/// [`hash_password`](crate::utils::hash_password) function provided in the crate root.
59pub trait AuthenticationProvider {
60    /// Gets the [`PasswordInfo`] for the given user.
61    fn get_password_for(&self, username: &str) -> Option<PasswordInfo>;
62
63    /// Checks to see if the user given by `authcid` is authorized to act as the user given by
64    /// `authzid.` Implementors do not need to implement this method. The default implementation
65    /// just checks if the two are equal
66    fn authorize(&self, authcid: &str, authzid: &str) -> bool {
67        authcid == authzid
68    }
69}
70
71/// Parses a client's first message by splitting it on commas and analyzing each part. Gives an
72/// error if the data was malformed in any way
73fn parse_client_first(data: &str) -> Result<(&str, Option<&str>, &str), Error> {
74    let mut parts = data.split(',');
75
76    // Channel binding
77    if let Some(part) = parts.next() {
78        if let Some(cb) = part.chars().next() {
79            if cb == 'p' {
80                return Err(Error::UnsupportedExtension);
81            }
82            if cb != 'n' && cb != 'y' || part.len() > 1 {
83                return Err(Error::Protocol(Kind::InvalidField(Field::ChannelBinding)));
84            }
85        } else {
86            return Err(Error::Protocol(Kind::ExpectedField(Field::ChannelBinding)));
87        }
88    } else {
89        return Err(Error::Protocol(Kind::ExpectedField(Field::ChannelBinding)));
90    }
91
92    // Authzid
93    let authzid = if let Some(part) = parts.next() {
94        if part.is_empty() {
95            None
96        } else if part.len() < 2 || &part.as_bytes()[..2] != b"a=" {
97            return Err(Error::Protocol(Kind::ExpectedField(Field::Authzid)));
98        } else {
99            Some(&part[2..])
100        }
101    } else {
102        return Err(Error::Protocol(Kind::ExpectedField(Field::Authzid)));
103    };
104
105    // Authcid
106    let authcid = parse_part!(parts, Authcid, b"n=");
107
108    // Nonce
109    let nonce = match parts.next() {
110        Some(part) if &part.as_bytes()[..2] == b"r=" => &part[2..],
111        _ => {
112            return Err(Error::Protocol(Kind::ExpectedField(Field::Nonce)));
113        }
114    };
115    Ok((authcid, authzid, nonce))
116}
117
118/// Parses the client's final message. Gives an error if the data was malformed.
119fn parse_client_final(data: &str) -> Result<(&str, &str, &str), Error> {
120    // 6 is the length of the required parts of the message
121    let mut parts = data.split(',');
122    let gs2header = parse_part!(parts, GS2Header, b"c=");
123    let nonce = parse_part!(parts, Nonce, b"r=");
124    let proof = parse_part!(parts, Proof, b"p=");
125    Ok((gs2header, nonce, proof))
126}
127
128impl<P: AuthenticationProvider> ScramServer<P> {
129    /// Creates a new `ScramServer` using the given authentication provider.
130    pub fn new(provider: P) -> Self {
131        ScramServer { provider }
132    }
133
134    /// Handle a challenge message sent by the client to the server. If the message is well formed,
135    /// and the requested user exists, then this will progress to the next stage of the
136    /// authentication process, [`ServerFirst`]. Otherwise, it will return an error.
137    pub fn handle_client_first<'a>(
138        &'a self,
139        client_first: &'a str,
140    ) -> Result<ServerFirst<'a, P>, Error> {
141        let (authcid, authzid, client_nonce) = parse_client_first(client_first)?;
142        let password_info = self
143            .provider
144            .get_password_for(authcid)
145            .ok_or_else(|| Error::InvalidUser(authcid.to_string()))?;
146        Ok(ServerFirst {
147            client_nonce,
148            authcid,
149            authzid,
150            provider: &self.provider,
151            password_info,
152        })
153    }
154}
155
156/// Represents the first stage in the authentication process, after the client has submitted their
157/// first message. This struct is responsible for responding to the message
158pub struct ServerFirst<'a, P: 'a + AuthenticationProvider> {
159    client_nonce: &'a str,
160    authcid: &'a str,
161    authzid: Option<&'a str>,
162    provider: &'a P,
163    password_info: PasswordInfo,
164}
165
166impl<'a, P: AuthenticationProvider> ServerFirst<'a, P> {
167    /// Creates the server's first message in response to the client's first message. By default,
168    /// this method uses [`OsRng`] as its source of randomness for the nonce. To specify the
169    /// randomness source, use [`server_first_with_rng`](Self::server_first_with_rng). This method
170    /// will return an error when it cannot initialize the OS's randomness source. See the
171    /// documentation on `OsRng` for more information.
172    pub fn server_first(self) -> (ClientFinal<'a, P>, String) {
173        self.server_first_with_rng(&mut OsRng)
174    }
175
176    /// Creates the server's first message in response to the client's first message, with the
177    /// given source of randomness used for the server's nonce. The randomness is assigned here
178    /// instead of universally in [`ScramServer`] for increased flexibility, and also to keep
179    /// `ScramServer` immutable.
180    pub fn server_first_with_rng<R: Rng>(self, rng: &mut R) -> (ClientFinal<'a, P>, String) {
181        let mut nonce = String::with_capacity(self.client_nonce.len() + NONCE_LENGTH);
182        nonce.push_str(self.client_nonce);
183        nonce.extend(
184            Uniform::from(33..125)
185                .sample_iter(rng)
186                .map(|x: u8| if x > 43 { (x + 1) as char } else { x as char })
187                .take(NONCE_LENGTH),
188        );
189
190        let gs2header: Cow<'static, str> = match self.authzid {
191            Some(authzid) => format!("n,a={},", authzid).into(),
192            None => "n,,".into(),
193        };
194        let client_first_bare: Cow<'static, str> =
195            format!("n={},r={}", self.authcid, self.client_nonce).into();
196        let server_first: Cow<'static, str> = format!(
197            "r={},s={},i={}",
198            nonce,
199            BASE64.encode(self.password_info.salt.as_slice()),
200            self.password_info.iterations
201        )
202        .into();
203        (
204            ClientFinal {
205                hashed_password: self.password_info.hashed_password,
206                nonce,
207                gs2header,
208                client_first_bare,
209                server_first: server_first.clone(),
210                authcid: self.authcid,
211                authzid: self.authzid,
212                provider: self.provider,
213            },
214            server_first.into_owned(),
215        )
216    }
217}
218
219/// Represents the stage after the server has generated its first response to the client. This
220/// struct is responsible for handling the client's final message.
221pub struct ClientFinal<'a, P: 'a + AuthenticationProvider> {
222    hashed_password: Vec<u8>,
223    nonce: String,
224    gs2header: Cow<'static, str>,
225    client_first_bare: Cow<'static, str>,
226    server_first: Cow<'static, str>,
227    authcid: &'a str,
228    authzid: Option<&'a str>,
229    provider: &'a P,
230}
231
232impl<'a, P: AuthenticationProvider> ClientFinal<'a, P> {
233    /// Handle the final client message. If the message is not well formed, or the authorization
234    /// header is invalid, then this will return an error. In all other cases (including when
235    /// authentication or authorization has failed), this will return `Ok` along with a message to
236    /// send the client. In cases where authentication or authorization has failed, the message will
237    /// contain error information for the client. To check if authentication and authorization have
238    /// succeeded, use [`server_final`](ServerFinal::server_final) on the return value.
239    pub fn handle_client_final(self, client_final: &str) -> Result<ServerFinal, Error> {
240        let (gs2header_enc, nonce, proof) = parse_client_final(client_final)?;
241        if !self.verify_header(gs2header_enc) {
242            return Err(Error::Protocol(Kind::InvalidField(Field::GS2Header)));
243        }
244        if !self.verify_nonce(nonce) {
245            return Err(Error::Protocol(Kind::InvalidField(Field::Nonce)));
246        }
247        if let Some(signature) = self.verify_proof(proof)? {
248            if let Some(authzid) = self.authzid {
249                if self.provider.authorize(self.authcid, authzid) {
250                    Ok(ServerFinal {
251                        status: AuthenticationStatus::Authenticated,
252                        signature,
253                    })
254                } else {
255                    Ok(ServerFinal {
256                        status: AuthenticationStatus::NotAuthorized,
257                        signature: format!(
258                            "e=User '{}' not authorized to act as '{}'",
259                            self.authcid, authzid
260                        ),
261                    })
262                }
263            } else {
264                Ok(ServerFinal {
265                    status: AuthenticationStatus::Authenticated,
266                    signature,
267                })
268            }
269        } else {
270            Ok(ServerFinal {
271                status: AuthenticationStatus::NotAuthenticated,
272                signature: "e=Invalid Password".to_string(),
273            })
274        }
275    }
276
277    /// Checks that the gs2header received from the client is the same as the one we've stored
278    fn verify_header(&self, gs2header: &str) -> bool {
279        let server_gs2header = BASE64.encode(self.gs2header.as_bytes());
280        server_gs2header == gs2header
281    }
282
283    /// Checks that the client has sent the same nonce
284    fn verify_nonce(&self, nonce: &str) -> bool {
285        nonce == self.nonce
286    }
287
288    /// Checks that the proof from the client matches our saved credentials
289    fn verify_proof(&self, proof: &str) -> Result<Option<String>, Error> {
290        let (client_proof, server_signature): ([u8; SHA256_OUTPUT_LEN], hmac::Tag) = find_proofs(
291            &self.gs2header,
292            &self.client_first_bare,
293            &self.server_first,
294            self.hashed_password.as_slice(),
295            &self.nonce,
296        );
297        let proof = if let Ok(proof) = BASE64.decode(proof.as_bytes()) {
298            proof
299        } else {
300            return Err(Error::Protocol(Kind::InvalidField(Field::Proof)));
301        };
302        if proof != client_proof {
303            return Ok(None);
304        }
305
306        let server_signature_string = format!("v={}", BASE64.encode(server_signature.as_ref()));
307        Ok(Some(server_signature_string))
308    }
309}
310
311/// Represents the final stage of authentication, after we have generated the final server message
312/// to send to the client
313pub struct ServerFinal {
314    status: AuthenticationStatus,
315    signature: String,
316}
317
318impl ServerFinal {
319    /// Get the [`AuthenticationStatus`] of the exchange. This status can be successful, failed
320    /// because of invalid authentication or failed because of invalid authorization.
321    pub fn server_final(self) -> (AuthenticationStatus, String) {
322        (self.status, self.signature)
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::super::{Error, Field, Kind};
    use super::{parse_client_final, parse_client_first};

    /// Error reported when a required attribute is absent or has the wrong key.
    fn expected(field: Field) -> Error {
        Error::Protocol(Kind::ExpectedField(field))
    }

    /// Error reported when an attribute is present but its value is malformed.
    fn invalid(field: Field) -> Error {
        Error::Protocol(Kind::InvalidField(field))
    }

    #[test]
    fn test_parse_client_first_success() {
        let cases = vec![
            ("n,,n=user,r=abcdefghijk", ("user", None, "abcdefghijk")),
            (
                "y,a=other user,n=user,r=abcdef=hijk",
                ("user", Some("other user"), "abcdef=hijk"),
            ),
            ("n,,n=,r=", ("", None, "")),
        ];
        for (message, want) in cases {
            assert_eq!(parse_client_first(message).unwrap(), want, "input: {:?}", message);
        }
    }

    #[test]
    fn test_parse_client_first_missing_fields() {
        let cases = vec![
            ("n,,n=user", expected(Field::Nonce)),
            ("n,,r=user", expected(Field::Authcid)),
            ("n,n=user,r=abc", expected(Field::Authzid)),
            (",,n=user,r=abc", expected(Field::ChannelBinding)),
            ("", expected(Field::ChannelBinding)),
            (",,,", expected(Field::ChannelBinding)),
        ];
        for (message, want) in cases {
            assert_eq!(parse_client_first(message).unwrap_err(), want, "input: {:?}", message);
        }
    }

    #[test]
    fn test_parse_client_first_invalid_data() {
        let cases = vec![
            ("a,,n=user,r=abc", invalid(Field::ChannelBinding)),
            ("p,,n=user,r=abc", Error::UnsupportedExtension),
            ("nn,,n=user,r=abc", invalid(Field::ChannelBinding)),
            ("n,,n,r=abc", expected(Field::Authcid)),
        ];
        for (message, want) in cases {
            assert_eq!(parse_client_first(message).unwrap_err(), want, "input: {:?}", message);
        }
    }

    #[test]
    fn test_parse_client_final_success() {
        let cases = vec![
            ("c=abc,r=abcefg,p=783232", ("abc", "abcefg", "783232")),
            ("c=,r=,p=", ("", "", "")),
        ];
        for (message, want) in cases {
            assert_eq!(parse_client_final(message).unwrap(), want, "input: {:?}", message);
        }
    }

    #[test]
    fn test_parse_client_final_missing_fields() {
        let cases = vec![
            ("c=whatever,r=something", expected(Field::Proof)),
            ("c=whatever,p=words", expected(Field::Nonce)),
            ("c=whatever", expected(Field::Nonce)),
            ("c=", expected(Field::Nonce)),
            ("", expected(Field::GS2Header)),
            ("r=anonce", expected(Field::GS2Header)),
        ];
        for (message, want) in cases {
            assert_eq!(parse_client_final(message).unwrap_err(), want, "input: {:?}", message);
        }
    }
}
435        );
436    }
437}