// vtcode_config/auth/pkce.rs
use anyhow::{Context, Result};
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use sha2::{Digest, Sha256};
/// Number of characters in a generated PKCE `code_verifier`.
///
/// RFC 7636 §4.1 requires a verifier of 43 to 128 characters; the compile-time
/// assertion below rejects any edit that moves this value out of that range.
const CODE_VERIFIER_LENGTH: usize = 64;

/// Characters permitted in a `code_verifier`: the RFC 7636 §4.1 "unreserved" set
/// `ALPHA / DIGIT / "-" / "." / "_" / "~"` (66 characters).
const CODE_VERIFIER_CHARSET: &[u8] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";

// Fail the build, rather than emit non-compliant verifiers, if the length is edited
// out of the RFC range.
const _: () = assert!(
    CODE_VERIFIER_LENGTH >= 43 && CODE_VERIFIER_LENGTH <= 128,
    "RFC 7636 requires a code_verifier of 43..=128 characters"
);

// Characters are selected by taking an index modulo the charset length, so an empty
// charset would divide by zero at runtime; reject it at compile time instead.
const _: () = assert!(
    !CODE_VERIFIER_CHARSET.is_empty(),
    "code_verifier charset must not be empty"
);

/// A PKCE (RFC 7636) verifier together with the challenge derived from it.
///
/// NOTE(review): the derived `Debug` prints every field, including the secret
/// `code_verifier`; avoid logging values of this type.
#[derive(Clone, Debug)]
pub struct PkceChallenge {
    /// The secret verifier string (RFC 7636 `code_verifier`).
    pub code_verifier: String,
    /// The challenge computed from `code_verifier`.
    pub code_challenge: String,
    /// Name of the transform used to compute `code_challenge` (e.g. `"S256"`).
    pub code_challenge_method: String,
}

28impl PkceChallenge {
29 pub fn from_verifier(code_verifier: String) -> Result<Self> {
31 let code_challenge = compute_s256_challenge(&code_verifier)?;
32 Ok(Self {
33 code_verifier,
34 code_challenge,
35 code_challenge_method: "S256".to_string(),
36 })
37 }
38}
39
40pub fn generate_pkce_challenge() -> Result<PkceChallenge> {
54 let code_verifier = generate_code_verifier()?;
55 PkceChallenge::from_verifier(code_verifier)
56}
57
58fn generate_code_verifier() -> Result<String> {
60 use std::time::{SystemTime, UNIX_EPOCH};
61
62 let mut verifier = String::with_capacity(CODE_VERIFIER_LENGTH);
65
66 let nanos = SystemTime::now()
68 .duration_since(UNIX_EPOCH)
69 .context("System time before UNIX epoch")?
70 .as_nanos();
71
72 let pid = std::process::id() as u128;
73 let mut state = nanos.wrapping_add(pid);
74
75 for _ in 0..CODE_VERIFIER_LENGTH {
77 state ^= state << 13;
79 state ^= state >> 7;
80 state ^= state << 17;
81
82 let extra = SystemTime::now()
84 .duration_since(UNIX_EPOCH)
85 .map(|d| d.as_nanos())
86 .unwrap_or(0);
87 state = state.wrapping_add(extra);
88
89 let idx = (state % CODE_VERIFIER_CHARSET.len() as u128) as usize;
90 verifier.push(CODE_VERIFIER_CHARSET[idx] as char);
91 }
92
93 Ok(verifier)
94}
95
96fn compute_s256_challenge(code_verifier: &str) -> Result<String> {
100 let mut hasher = Sha256::new();
101 hasher.update(code_verifier.as_bytes());
102 let hash = hasher.finalize();
103
104 Ok(URL_SAFE_NO_PAD.encode(hash))
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// A generated pair has the configured verifier length, only allowed characters,
    /// the `S256` method, and a 43-character challenge.
    #[test]
    fn test_generate_pkce_challenge() {
        let challenge = generate_pkce_challenge().unwrap();

        assert_eq!(challenge.code_verifier.len(), CODE_VERIFIER_LENGTH);

        for c in challenge.code_verifier.chars() {
            // `is_ascii` first: `c as u8` truncates, so a non-ASCII char could
            // otherwise masquerade as an allowed byte.
            assert!(
                c.is_ascii() && CODE_VERIFIER_CHARSET.contains(&(c as u8)),
                "Invalid character in verifier: {}",
                c
            );
        }

        assert_eq!(challenge.code_challenge_method, "S256");

        // A 32-byte SHA-256 digest encodes to 43 unpadded base64url characters.
        assert_eq!(challenge.code_challenge.len(), 43);
    }

    /// The same verifier always yields the same challenge.
    #[test]
    fn test_deterministic_challenge() {
        let verifier = "test_verifier_string_for_deterministic_test";

        let challenge1 = PkceChallenge::from_verifier(verifier.to_string()).unwrap();
        let challenge2 = PkceChallenge::from_verifier(verifier.to_string()).unwrap();

        assert_eq!(challenge1.code_challenge, challenge2.code_challenge);
    }

    /// Known-answer test from RFC 7636 Appendix B; pins the exact S256 transform.
    #[test]
    fn test_rfc7636_appendix_b_vector() {
        let challenge = PkceChallenge::from_verifier(
            "dBjftJeZ4CVP-mJ0jGdaYmYhBdONPpw7N7qi3ZSMsNY".to_string(),
        )
        .unwrap();

        assert_eq!(
            challenge.code_challenge,
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        );
        assert_eq!(challenge.code_challenge_method, "S256");
    }

    /// Two independently generated verifiers differ.
    #[test]
    fn test_unique_verifiers() {
        let c1 = generate_pkce_challenge().unwrap();
        let c2 = generate_pkce_challenge().unwrap();

        assert_ne!(c1.code_verifier, c2.code_verifier);
    }
}