1use anyhow::Result;
7use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
8use ring::rand::{SecureRandom, SystemRandom};
9use sha2::{Digest, Sha256};
10
/// Length, in characters, of every generated PKCE `code_verifier`.
///
/// RFC 7636 §4.1 requires a verifier of 43 to 128 characters; 64 sits inside
/// that range.
const CODE_VERIFIER_LENGTH: usize = 64;

// Fail the build, rather than issuing non-compliant verifiers, if the length
// is ever edited outside the RFC 7636 bounds.
const _: () = assert!(CODE_VERIFIER_LENGTH >= 43 && CODE_VERIFIER_LENGTH <= 128);

/// Characters permitted in a PKCE `code_verifier`: the RFC 3986 "unreserved"
/// set (`A-Z`, `a-z`, `0-9`, `-`, `.`, `_`, `~`), as required by RFC 7636 §4.1.
const CODE_VERIFIER_CHARSET: &[u8] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
17
/// A PKCE (RFC 7636) verifier/challenge pair.
///
/// The `code_challenge` and `code_challenge_method` are sent with the
/// authorization request. The `code_verifier` is a secret that is revealed only
/// in the token exchange, so the `Debug` impl below redacts it to keep it out of
/// logs.
#[derive(Clone)]
pub struct PkceChallenge {
    /// Random secret string; must not be logged or exposed before the token exchange.
    pub code_verifier: String,
    /// Challenge value derived from `code_verifier` using `code_challenge_method`.
    pub code_challenge: String,
    /// Transformation used to derive the challenge (e.g. `"S256"`).
    pub code_challenge_method: String,
}

impl std::fmt::Debug for PkceChallenge {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `format_args!` renders the placeholder without surrounding quotes.
        f.debug_struct("PkceChallenge")
            .field("code_verifier", &format_args!("[redacted]"))
            .field("code_challenge", &self.code_challenge)
            .field("code_challenge_method", &self.code_challenge_method)
            .finish()
    }
}
28
29impl PkceChallenge {
30 pub fn from_verifier(code_verifier: String) -> Result<Self> {
32 let code_challenge = compute_s256_challenge(&code_verifier)?;
33 Ok(Self {
34 code_verifier,
35 code_challenge,
36 code_challenge_method: "S256".to_string(),
37 })
38 }
39}
40
41pub fn generate_pkce_challenge() -> Result<PkceChallenge> {
55 let code_verifier = generate_code_verifier()?;
56 PkceChallenge::from_verifier(code_verifier)
57}
58
59fn generate_code_verifier() -> Result<String> {
64 let rng = SystemRandom::new();
65 let charset_len = CODE_VERIFIER_CHARSET.len() as u8;
66 let max_valid = (256u16 - 256u16 % charset_len as u16) as u8;
67 let mut verifier = String::with_capacity(CODE_VERIFIER_LENGTH);
68 let mut buf = [0u8; 1];
69
70 while verifier.len() < CODE_VERIFIER_LENGTH {
71 rng.fill(&mut buf)
72 .map_err(|_| anyhow::anyhow!("failed to read from OS random source"))?;
73 if buf[0] < max_valid {
75 let idx = (buf[0] % charset_len) as usize;
76 verifier.push(CODE_VERIFIER_CHARSET[idx] as char);
77 }
78 }
79
80 Ok(verifier)
81}
82
83fn compute_s256_challenge(code_verifier: &str) -> Result<String> {
87 let mut hasher = Sha256::new();
88 hasher.update(code_verifier.as_bytes());
89 let hash = hasher.finalize();
90
91 Ok(URL_SAFE_NO_PAD.encode(hash))
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_pkce_challenge() {
        let challenge = generate_pkce_challenge().unwrap();

        // The verifier has the configured length.
        assert_eq!(challenge.code_verifier.len(), CODE_VERIFIER_LENGTH);

        // Every character comes from the unreserved charset.
        for c in challenge.code_verifier.chars() {
            assert!(
                CODE_VERIFIER_CHARSET.contains(&(c as u8)),
                "Invalid character in verifier: {}",
                c
            );
        }

        assert_eq!(challenge.code_challenge_method, "S256");

        // Unpadded base64url of a 32-byte SHA-256 digest is always 43 characters.
        assert_eq!(challenge.code_challenge.len(), 43);
    }

    #[test]
    fn test_deterministic_challenge() {
        let verifier = "test_verifier_string_for_deterministic_test";

        let challenge1 = PkceChallenge::from_verifier(verifier.to_string()).unwrap();
        let challenge2 = PkceChallenge::from_verifier(verifier.to_string()).unwrap();

        assert_eq!(challenge1.code_challenge, challenge2.code_challenge);
    }

    #[test]
    fn test_rfc7636_appendix_b_vector() {
        // Known-answer test from RFC 7636 Appendix B. It pins the exact S256
        // output, not just self-consistency.
        let challenge = PkceChallenge::from_verifier(
            "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk".to_string(),
        )
        .unwrap();

        assert_eq!(
            challenge.code_challenge,
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        );
        assert_eq!(challenge.code_challenge_method, "S256");
    }

    #[test]
    fn test_unique_verifiers() {
        let c1 = generate_pkce_challenge().unwrap();
        let c2 = generate_pkce_challenge().unwrap();

        assert_ne!(c1.code_verifier, c2.code_verifier);
    }
}