1use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
6use ring::hmac;
7use serde::{Deserialize, Serialize};
8use std::num::NonZeroU32;
9
10use crate::error::{Result, ShieldError};
11
12pub struct PAKEExchange;
14
15impl PAKEExchange {
16 pub const ITERATIONS: u32 = 200_000;
18
19 #[must_use]
21 pub fn derive(password: &str, salt: &[u8], role: &str, iterations: Option<u32>) -> [u8; 32] {
22 let iters = iterations.unwrap_or(Self::ITERATIONS);
23
24 let mut base_key = [0u8; 32];
25 ring::pbkdf2::derive(
26 ring::pbkdf2::PBKDF2_HMAC_SHA256,
27 NonZeroU32::new(iters).unwrap_or(NonZeroU32::new(Self::ITERATIONS).unwrap()),
28 salt,
29 password.as_bytes(),
30 &mut base_key,
31 );
32
33 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, &base_key);
35 let tag = hmac::sign(&hmac_key, role.as_bytes());
36 let mut result = [0u8; 32];
37 result.copy_from_slice(&tag.as_ref()[..32]);
38 result
39 }
40
41 #[must_use]
43 pub fn combine(contributions: &[[u8; 32]]) -> [u8; 32] {
44 let mut sorted: Vec<&[u8; 32]> = contributions.iter().collect();
45 sorted.sort();
46
47 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, sorted[0]);
49 let mut data = Vec::with_capacity((sorted.len() - 1) * 32);
50 for c in &sorted[1..] {
51 data.extend_from_slice(*c);
52 }
53
54 let tag = hmac::sign(&hmac_key, &data);
55 let mut result = [0u8; 32];
56 result.copy_from_slice(&tag.as_ref()[..32]);
57 result
58 }
59
60 pub fn generate_salt() -> Result<[u8; 16]> {
62 crate::random::random_bytes()
63 }
64}
65
66pub struct QRExchange;
68
69#[derive(Serialize, Deserialize)]
70struct ExchangeData {
71 v: u8,
72 k: String,
73 #[serde(skip_serializing_if = "Option::is_none")]
74 m: Option<serde_json::Value>,
75}
76
77impl QRExchange {
78 #[must_use]
80 pub fn encode(key: &[u8]) -> String {
81 URL_SAFE_NO_PAD.encode(key)
82 }
83
84 pub fn decode(encoded: &str) -> Result<Vec<u8>> {
86 URL_SAFE_NO_PAD
87 .decode(encoded)
88 .map_err(|_| ShieldError::InvalidFormat)
89 }
90
91 #[must_use]
93 pub fn generate_exchange_data(key: &[u8], metadata: Option<serde_json::Value>) -> String {
94 let data = ExchangeData {
95 v: 1,
96 k: URL_SAFE_NO_PAD.encode(key),
97 m: metadata,
98 };
99 serde_json::to_string(&data).unwrap_or_default()
100 }
101
102 pub fn parse_exchange_data(data: &str) -> Result<(Vec<u8>, Option<serde_json::Value>)> {
104 let parsed: ExchangeData =
105 serde_json::from_str(data).map_err(|_| ShieldError::InvalidFormat)?;
106 let key = URL_SAFE_NO_PAD
107 .decode(&parsed.k)
108 .map_err(|_| ShieldError::InvalidFormat)?;
109 Ok((key, parsed.m))
110 }
111}
112
113pub struct KeySplitter;
115
116impl KeySplitter {
117 pub fn split(key: &[u8], num_shares: usize) -> Result<Vec<Vec<u8>>> {
119 if num_shares < 2 {
120 return Err(ShieldError::InvalidShareCount);
121 }
122
123 let mut shares = Vec::with_capacity(num_shares);
124
125 for _ in 0..num_shares - 1 {
126 let share = crate::random::random_vec(key.len())?;
127 shares.push(share);
128 }
129
130 let mut final_share = key.to_vec();
132 for share in &shares {
133 for (i, &b) in share.iter().enumerate() {
134 final_share[i] ^= b;
135 }
136 }
137 shares.push(final_share);
138
139 Ok(shares)
140 }
141
142 pub fn combine(shares: &[Vec<u8>]) -> Result<Vec<u8>> {
144 if shares.len() < 2 {
145 return Err(ShieldError::InvalidShareCount);
146 }
147
148 let len = shares[0].len();
149 let mut result = vec![0u8; len];
150
151 for share in shares {
152 if share.len() != len {
153 return Err(ShieldError::InvalidFormat);
154 }
155 for (i, &b) in share.iter().enumerate() {
156 result[i] ^= b;
157 }
158 }
159
160 Ok(result)
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Reduced iteration count for tests that do not exercise the default.
    const FAST_ITERATIONS: Option<u32> = Some(1_000);

    /// `derive` yields a 32-byte value with the default iteration count.
    #[test]
    fn test_pake_derive() {
        let salt = PAKEExchange::generate_salt().unwrap();
        let key = PAKEExchange::derive("password", &salt, "client", None);
        assert_eq!(key.len(), 32);
    }

    /// Identical inputs give identical output.
    #[test]
    fn test_pake_deterministic() {
        let salt = PAKEExchange::generate_salt().unwrap();
        let key1 = PAKEExchange::derive("password", &salt, "client", None);
        let key2 = PAKEExchange::derive("password", &salt, "client", None);
        assert_eq!(key1, key2);
    }

    /// The role string is bound into the derived value.
    #[test]
    fn test_pake_roles_differ() {
        let salt = PAKEExchange::generate_salt().unwrap();
        let client = PAKEExchange::derive("password", &salt, "client", FAST_ITERATIONS);
        let server = PAKEExchange::derive("password", &salt, "server", FAST_ITERATIONS);
        assert_ne!(client, server);
    }

    /// A zero iteration count falls back to the default count.
    #[test]
    fn test_pake_zero_iterations_uses_default() {
        let salt = PAKEExchange::generate_salt().unwrap();
        let defaulted = PAKEExchange::derive("password", &salt, "client", None);
        let zeroed = PAKEExchange::derive("password", &salt, "client", Some(0));
        assert_eq!(defaulted, zeroed);
    }

    /// `combine` does not depend on the order of its inputs.
    #[test]
    fn test_pake_combine_order_independent() {
        let salt = PAKEExchange::generate_salt().unwrap();
        let client = PAKEExchange::derive("password", &salt, "client", None);
        let server = PAKEExchange::derive("password", &salt, "server", None);

        let shared1 = PAKEExchange::combine(&[client, server]);
        let shared2 = PAKEExchange::combine(&[server, client]);
        assert_eq!(shared1, shared2);
    }

    /// `encode` followed by `decode` returns the original bytes.
    #[test]
    fn test_qr_roundtrip() {
        let key = [42u8; 32];
        let encoded = QRExchange::encode(&key);
        let decoded = QRExchange::decode(&encoded).unwrap();
        assert_eq!(key.as_slice(), decoded.as_slice());
    }

    /// Input outside the URL-safe base64 alphabet is rejected.
    #[test]
    fn test_qr_decode_rejects_invalid() {
        let outcome = QRExchange::decode("not base64!");
        assert!(matches!(outcome, Err(ShieldError::InvalidFormat)));
    }

    /// Key and metadata survive a trip through the JSON envelope.
    #[test]
    fn test_qr_exchange_data() {
        let key = [1u8; 32];
        let metadata = serde_json::json!({"name": "test"});
        let data = QRExchange::generate_exchange_data(&key, Some(metadata.clone()));
        let (parsed_key, parsed_meta) = QRExchange::parse_exchange_data(&data).unwrap();
        assert_eq!(key.as_slice(), parsed_key.as_slice());
        assert_eq!(parsed_meta, Some(metadata));
    }

    /// Without metadata the `m` field is omitted and parses back as `None`.
    #[test]
    fn test_qr_exchange_data_without_metadata() {
        let key = [7u8; 32];
        let data = QRExchange::generate_exchange_data(&key, None);
        assert!(!data.contains("\"m\""));

        let (parsed_key, parsed_meta) = QRExchange::parse_exchange_data(&data).unwrap();
        assert_eq!(key.as_slice(), parsed_key.as_slice());
        assert!(parsed_meta.is_none());
    }

    /// Non-JSON input, missing fields and a bad `k` value are all rejected.
    #[test]
    fn test_qr_parse_rejects_malformed() {
        for bad in ["not json", "{}", r#"{"v":1,"k":"***"}"#] {
            let outcome = QRExchange::parse_exchange_data(bad);
            assert!(matches!(outcome, Err(ShieldError::InvalidFormat)));
        }
    }

    /// Combining every share recovers the key.
    #[test]
    fn test_key_splitter() {
        let key = [42u8; 32];
        let shares = KeySplitter::split(&key, 3).unwrap();
        assert_eq!(shares.len(), 3);

        let recovered = KeySplitter::combine(&shares).unwrap();
        assert_eq!(key.as_slice(), recovered.as_slice());
    }

    /// A strict subset of the shares does not recover the key.
    #[test]
    fn test_key_splitter_partial() {
        let key = [42u8; 32];
        let shares = KeySplitter::split(&key, 3).unwrap();

        let partial = KeySplitter::combine(&shares[..2]).unwrap();
        assert_ne!(key.as_slice(), partial.as_slice());
    }

    /// Splitting into fewer than two shares is refused.
    #[test]
    fn test_key_splitter_min_shares() {
        let key = [42u8; 32];
        assert!(KeySplitter::split(&key, 1).is_err());
    }

    /// Combining fewer than two shares is refused.
    #[test]
    fn test_key_splitter_combine_min_shares() {
        let outcome = KeySplitter::combine(&[vec![0u8; 4]]);
        assert!(matches!(outcome, Err(ShieldError::InvalidShareCount)));
    }

    /// Shares of differing lengths are refused.
    #[test]
    fn test_key_splitter_combine_length_mismatch() {
        let outcome = KeySplitter::combine(&[vec![0u8; 2], vec![0u8; 3]]);
        assert!(matches!(outcome, Err(ShieldError::InvalidFormat)));
    }
}