1use num_bigint::{BigInt, Sign, ToBigInt};
2use sha1::{Digest, Sha1};
3
4#[derive(Debug)]
5pub struct Srp {
6 pub session_key: Vec<u8>,
7 modulus: BigInt,
8 generator: BigInt,
9 private_ephemeral: BigInt,
10 public_ephemeral: BigInt,
11 server_ephemeral: BigInt,
12 salt: [u8; 32],
13 client_proof: Option<[u8; 20]>,
14}
15
16impl Srp {
18 pub fn new(n: &[u8], g: &[u8], server_ephemeral: &[u8; 32], salt: [u8; 32]) -> Self {
19 let private_ephemeral: [u8; 19] = rand::random();
20
21 let modulus = BigInt::from_bytes_le(Sign::Plus, n);
22 let generator = BigInt::from_bytes_le(Sign::Plus, g);
23
24 let private_ephemeral = BigInt::from_bytes_le(Sign::Plus, &private_ephemeral);
25 let public_ephemeral = generator.modpow(&private_ephemeral, &modulus);
26 let server_ephemeral = BigInt::from_bytes_le(Sign::Plus, server_ephemeral);
27
28 Self {
29 session_key: Vec::new(),
30 modulus,
31 generator,
32 private_ephemeral,
33 public_ephemeral,
34 server_ephemeral,
35 salt,
36 client_proof: None,
37 }
38 }
39
40 pub fn public_ephemeral(&mut self) -> [u8; 32] {
41 Self::pad_to_32_bytes(self.public_ephemeral.to_bytes_le().1)
42 }
43
44 pub fn session_key(&mut self) -> Vec<u8> {
45 self.session_key.to_vec()
46 }
47
48 pub fn calculate_proof(&mut self, account: &str) -> [u8; 20] {
49 let result = Sha1::new()
50 .chain(self.calculate_xor_hash())
51 .chain(Self::calculate_account_hash(account))
52 .chain(self.salt)
53 .chain(self.public_ephemeral.to_bytes_le().1)
54 .chain(self.server_ephemeral.to_bytes_le().1)
55 .chain(&self.session_key)
56 .finalize()
57 .to_vec();
58
59 let mut output = [0u8; 20];
60 output.copy_from_slice(&result);
61
62 self.client_proof = Some(output);
63
64 output
65 }
66
67 pub fn calculate_session_key(&mut self, account: &str, password: &str) {
68 let salt = self.salt;
69 let x = self.calculate_x(account, password, &salt);
70 let verifier = self.generator.modpow(
71 &x,
72 &self.modulus,
73 );
74
75 let mut session_key = Self::calculate_interleaved(
76 self.calculate_s(x, verifier)
77 );
78
79 while let Some(&0) = session_key.last() {
82 session_key.truncate(session_key.len() - 1);
83 }
84
85 self.session_key = session_key;
86 }
87
88 pub fn validate_proof(&mut self, server_proof: [u8; 20]) -> bool {
89 let client_proof = {
90 let hasher = Sha1::new();
91
92 let result = hasher
93 .chain(self.public_ephemeral())
94 .chain(self.client_proof.unwrap())
95 .chain(self.session_key.clone())
96 .finalize();
97
98 let mut hashed_proof = [0u8; 20];
99 hashed_proof.copy_from_slice(&result);
100 hashed_proof
101 };
102
103 client_proof == server_proof
104 }
105}
106
107impl Srp {
109 fn calculate_account_hash(account: &str) -> Vec<u8> {
110 Sha1::new()
111 .chain(account.as_bytes())
112 .finalize()
113 .to_vec()
114 }
115
116 fn calculate_xor_hash(&mut self) -> Vec<u8> {
117 let n_hash = Sha1::new().chain(self.modulus.to_bytes_le().1).finalize();
118 let g_hash = Sha1::new().chain(self.generator.to_bytes_le().1).finalize();
119
120 let mut xor_hash = Vec::new();
121 for (index, value) in g_hash.iter().enumerate() {
122 xor_hash.push(value ^ n_hash[index]);
123 }
124
125 xor_hash
126 }
127
128 fn calculate_x(&mut self, account: &str, password: &str, salt: &[u8]) -> BigInt {
129 let identity_hash = Sha1::new()
130 .chain(format!("{}:{}", account, password).as_bytes())
131 .finalize()
132 .to_vec();
133
134 let x = Sha1::new()
135 .chain(salt)
136 .chain(identity_hash)
137 .finalize()
138 .to_vec();
139
140 BigInt::from_bytes_le(Sign::Plus, &x)
141 }
142
143 fn calculate_u(&mut self) -> BigInt {
144 let u = Sha1::new()
145 .chain(self.public_ephemeral.to_bytes_le().1)
146 .chain(self.server_ephemeral.to_bytes_le().1)
147 .finalize()
148 .to_vec();
149
150 BigInt::from_bytes_le(Sign::Plus, &u)
151 }
152
153 fn calculate_s(&mut self, x: BigInt, verifier: BigInt) -> BigInt {
154 const K: u8 = 3;
155 let u = self.calculate_u();
156 let mut s = &self.server_ephemeral - K.to_bigint().unwrap() * verifier;
157 s = s.modpow(
158 &(&self.private_ephemeral + u * x),
159 &self.modulus,
160 );
161 s
162 }
163
164 fn calculate_interleaved(s: BigInt) -> Vec<u8> {
165 let (even, odd): (Vec<_>, Vec<_>) =
166 Self::pad_to_32_bytes(s.to_bytes_le().1)
167 .into_iter()
168 .enumerate()
169 .partition(|(i, _)| i % 2 == 0);
170
171 let part1 = even.iter().map(|(_, v)| *v).collect::<Vec<u8>>();
172 let part2 = odd.iter().map(|(_, v)| *v).collect::<Vec<u8>>();
173
174 let hashed1 = Sha1::new().chain(part1).finalize();
175 let hashed2 = Sha1::new().chain(part2).finalize();
176
177 let mut session_key = Vec::new();
178 for (index, _) in hashed1.iter().enumerate() {
179 session_key.push(hashed1[index]);
180 session_key.push(hashed2[index]);
181 }
182
183 session_key
184 }
185
186 fn pad_to_32_bytes(bytes: Vec<u8>) -> [u8; 32] {
187 let mut buffer = [0u8; 32];
188 buffer[..bytes.len()].copy_from_slice(&bytes);
189 buffer
190 }
191}