1use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
6use ring::rand::{SecureRandom, SystemRandom};
7use serde::{Deserialize, Serialize};
8use std::num::NonZeroU32;
9
10use crate::error::{Result, ShieldError};
11
12pub struct PAKEExchange;
14
15impl PAKEExchange {
16 pub const ITERATIONS: u32 = 200_000;
18
19 #[must_use]
21 pub fn derive(password: &str, salt: &[u8], role: &str, iterations: Option<u32>) -> [u8; 32] {
22 let iters = iterations.unwrap_or(Self::ITERATIONS);
23
24 let mut base_key = [0u8; 32];
25 ring::pbkdf2::derive(
26 ring::pbkdf2::PBKDF2_HMAC_SHA256,
27 NonZeroU32::new(iters).unwrap(),
28 salt,
29 password.as_bytes(),
30 &mut base_key,
31 );
32
33 let mut data = Vec::with_capacity(32 + role.len());
34 data.extend_from_slice(&base_key);
35 data.extend_from_slice(role.as_bytes());
36
37 let hash = ring::digest::digest(&ring::digest::SHA256, &data);
38 let mut result = [0u8; 32];
39 result.copy_from_slice(hash.as_ref());
40 result
41 }
42
43 #[must_use]
45 pub fn combine(contributions: &[[u8; 32]]) -> [u8; 32] {
46 let mut sorted: Vec<&[u8; 32]> = contributions.iter().collect();
47 sorted.sort();
48
49 let mut combined = Vec::with_capacity(contributions.len() * 32);
50 for c in sorted {
51 combined.extend_from_slice(c);
52 }
53
54 let hash = ring::digest::digest(&ring::digest::SHA256, &combined);
55 let mut result = [0u8; 32];
56 result.copy_from_slice(hash.as_ref());
57 result
58 }
59
60 pub fn generate_salt() -> Result<[u8; 16]> {
62 let rng = SystemRandom::new();
63 let mut salt = [0u8; 16];
64 rng.fill(&mut salt).map_err(|_| ShieldError::RandomFailed)?;
65 Ok(salt)
66 }
67}
68
69pub struct QRExchange;
71
72#[derive(Serialize, Deserialize)]
73struct ExchangeData {
74 v: u8,
75 k: String,
76 #[serde(skip_serializing_if = "Option::is_none")]
77 m: Option<serde_json::Value>,
78}
79
80impl QRExchange {
81 #[must_use]
83 pub fn encode(key: &[u8]) -> String {
84 URL_SAFE_NO_PAD.encode(key)
85 }
86
87 pub fn decode(encoded: &str) -> Result<Vec<u8>> {
89 URL_SAFE_NO_PAD
90 .decode(encoded)
91 .map_err(|_| ShieldError::InvalidFormat)
92 }
93
94 #[must_use]
96 pub fn generate_exchange_data(key: &[u8], metadata: Option<serde_json::Value>) -> String {
97 let data = ExchangeData {
98 v: 1,
99 k: URL_SAFE_NO_PAD.encode(key),
100 m: metadata,
101 };
102 serde_json::to_string(&data).unwrap()
103 }
104
105 pub fn parse_exchange_data(data: &str) -> Result<(Vec<u8>, Option<serde_json::Value>)> {
107 let parsed: ExchangeData =
108 serde_json::from_str(data).map_err(|_| ShieldError::InvalidFormat)?;
109 let key = URL_SAFE_NO_PAD
110 .decode(&parsed.k)
111 .map_err(|_| ShieldError::InvalidFormat)?;
112 Ok((key, parsed.m))
113 }
114}
115
116pub struct KeySplitter;
118
119impl KeySplitter {
120 pub fn split(key: &[u8], num_shares: usize) -> Result<Vec<Vec<u8>>> {
122 if num_shares < 2 {
123 return Err(ShieldError::InvalidShareCount);
124 }
125
126 let rng = SystemRandom::new();
127 let mut shares = Vec::with_capacity(num_shares);
128
129 for _ in 0..num_shares - 1 {
130 let mut share = vec![0u8; key.len()];
131 rng.fill(&mut share)
132 .map_err(|_| ShieldError::RandomFailed)?;
133 shares.push(share);
134 }
135
136 let mut final_share = key.to_vec();
138 for share in &shares {
139 for (i, &b) in share.iter().enumerate() {
140 final_share[i] ^= b;
141 }
142 }
143 shares.push(final_share);
144
145 Ok(shares)
146 }
147
148 pub fn combine(shares: &[Vec<u8>]) -> Result<Vec<u8>> {
150 if shares.len() < 2 {
151 return Err(ShieldError::InvalidShareCount);
152 }
153
154 let len = shares[0].len();
155 let mut result = vec![0u8; len];
156
157 for share in shares {
158 if share.len() != len {
159 return Err(ShieldError::InvalidFormat);
160 }
161 for (i, &b) in share.iter().enumerate() {
162 result[i] ^= b;
163 }
164 }
165
166 Ok(result)
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pake_derive() {
        let salt = PAKEExchange::generate_salt().unwrap();
        let client = PAKEExchange::derive("password", &salt, "client", None);
        let server = PAKEExchange::derive("password", &salt, "server", None);
        // The length is fixed by the `[u8; 32]` return type, so check the
        // content instead: non-trivial output that is bound to the role.
        assert_ne!(client, [0u8; 32]);
        assert_ne!(client, server);
    }

    #[test]
    fn test_pake_deterministic() {
        let salt = PAKEExchange::generate_salt().unwrap();
        let key1 = PAKEExchange::derive("password", &salt, "client", None);
        let key2 = PAKEExchange::derive("password", &salt, "client", None);
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_pake_inputs_change_output() {
        // A reduced iteration count keeps this test fast; only the
        // sensitivity to each input is being checked.
        let base = PAKEExchange::derive("password", &[0u8; 16], "client", Some(1_000));
        assert_ne!(
            base,
            PAKEExchange::derive("password", &[1u8; 16], "client", Some(1_000))
        );
        assert_ne!(
            base,
            PAKEExchange::derive("other", &[0u8; 16], "client", Some(1_000))
        );
        assert_ne!(
            base,
            PAKEExchange::derive("password", &[0u8; 16], "client", Some(2_000))
        );
    }

    #[test]
    fn test_pake_combine_order_independent() {
        let salt = PAKEExchange::generate_salt().unwrap();
        let client = PAKEExchange::derive("password", &salt, "client", None);
        let server = PAKEExchange::derive("password", &salt, "server", None);

        let shared1 = PAKEExchange::combine(&[client, server]);
        let shared2 = PAKEExchange::combine(&[server, client]);
        assert_eq!(shared1, shared2);
    }

    #[test]
    fn test_qr_roundtrip() {
        let key = [42u8; 32];
        let encoded = QRExchange::encode(&key);
        let decoded = QRExchange::decode(&encoded).unwrap();
        assert_eq!(key.as_slice(), decoded.as_slice());
    }

    #[test]
    fn test_qr_exchange_data() {
        let key = [1u8; 32];
        let metadata = serde_json::json!({"name": "test"});
        let data = QRExchange::generate_exchange_data(&key, Some(metadata.clone()));
        let (parsed_key, parsed_meta) = QRExchange::parse_exchange_data(&data).unwrap();
        assert_eq!(key.as_slice(), parsed_key.as_slice());
        assert_eq!(parsed_meta, Some(metadata));
    }

    #[test]
    fn test_qr_exchange_data_without_metadata() {
        let key = [7u8; 32];
        let data = QRExchange::generate_exchange_data(&key, None);
        // The metadata field is skipped entirely when absent.
        assert!(!data.contains("\"m\""));

        let (parsed_key, parsed_meta) = QRExchange::parse_exchange_data(&data).unwrap();
        assert_eq!(key.as_slice(), parsed_key.as_slice());
        assert_eq!(parsed_meta, None);
    }

    #[test]
    fn test_qr_rejects_malformed_input() {
        assert!(QRExchange::decode("not base64!").is_err());
        assert!(QRExchange::parse_exchange_data("{not json").is_err());
        // Well-formed JSON whose key field is not valid Base64.
        assert!(QRExchange::parse_exchange_data(r#"{"v":1,"k":"!!!"}"#).is_err());
    }

    #[test]
    fn test_key_splitter() {
        let key = [42u8; 32];
        let shares = KeySplitter::split(&key, 3).unwrap();
        assert_eq!(shares.len(), 3);
        assert!(shares.iter().all(|share| share.len() == key.len()));

        let recovered = KeySplitter::combine(&shares).unwrap();
        assert_eq!(key.as_slice(), recovered.as_slice());
    }

    #[test]
    fn test_key_splitter_partial() {
        let key = [42u8; 32];
        let shares = KeySplitter::split(&key, 3).unwrap();

        // With one share missing the XOR of the rest is unrelated to the key;
        // an accidental match would need a 1-in-2^256 coincidence.
        let partial = KeySplitter::combine(&shares[..2]).unwrap();
        assert_ne!(key.as_slice(), partial.as_slice());
    }

    #[test]
    fn test_key_splitter_min_shares() {
        let key = [42u8; 32];
        assert!(KeySplitter::split(&key, 0).is_err());
        assert!(KeySplitter::split(&key, 1).is_err());
    }

    #[test]
    fn test_key_splitter_combine_rejects_bad_input() {
        // Fewer than two shares.
        assert!(KeySplitter::combine(&[vec![1u8, 2, 3]]).is_err());
        // Shares of different lengths.
        assert!(KeySplitter::combine(&[vec![1u8, 2, 3], vec![1u8, 2]]).is_err());
    }
}