shield_core/
exchange.rs

1//! Key exchange without public-key cryptography.
2//!
3//! Provides PAKE, QR exchange, and key splitting.
4
5use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
6use ring::rand::{SecureRandom, SystemRandom};
7use serde::{Deserialize, Serialize};
8use std::num::NonZeroU32;
9
10use crate::error::{Result, ShieldError};
11
12/// Password-Authenticated Key Exchange.
13pub struct PAKEExchange;
14
15impl PAKEExchange {
16    /// Default PBKDF2 iterations.
17    pub const ITERATIONS: u32 = 200_000;
18
19    /// Derive key contribution from password.
20    #[must_use]
21    pub fn derive(password: &str, salt: &[u8], role: &str, iterations: Option<u32>) -> [u8; 32] {
22        let iters = iterations.unwrap_or(Self::ITERATIONS);
23
24        let mut base_key = [0u8; 32];
25        ring::pbkdf2::derive(
26            ring::pbkdf2::PBKDF2_HMAC_SHA256,
27            NonZeroU32::new(iters).unwrap(),
28            salt,
29            password.as_bytes(),
30            &mut base_key,
31        );
32
33        let mut data = Vec::with_capacity(32 + role.len());
34        data.extend_from_slice(&base_key);
35        data.extend_from_slice(role.as_bytes());
36
37        let hash = ring::digest::digest(&ring::digest::SHA256, &data);
38        let mut result = [0u8; 32];
39        result.copy_from_slice(hash.as_ref());
40        result
41    }
42
43    /// Combine key contributions into session key.
44    #[must_use]
45    pub fn combine(contributions: &[[u8; 32]]) -> [u8; 32] {
46        let mut sorted: Vec<&[u8; 32]> = contributions.iter().collect();
47        sorted.sort();
48
49        let mut combined = Vec::with_capacity(contributions.len() * 32);
50        for c in sorted {
51            combined.extend_from_slice(c);
52        }
53
54        let hash = ring::digest::digest(&ring::digest::SHA256, &combined);
55        let mut result = [0u8; 32];
56        result.copy_from_slice(hash.as_ref());
57        result
58    }
59
60    /// Generate random salt.
61    pub fn generate_salt() -> Result<[u8; 16]> {
62        let rng = SystemRandom::new();
63        let mut salt = [0u8; 16];
64        rng.fill(&mut salt).map_err(|_| ShieldError::RandomFailed)?;
65        Ok(salt)
66    }
67}
68
69/// Key exchange via QR codes or manual transfer.
70pub struct QRExchange;
71
72#[derive(Serialize, Deserialize)]
73struct ExchangeData {
74    v: u8,
75    k: String,
76    #[serde(skip_serializing_if = "Option::is_none")]
77    m: Option<serde_json::Value>,
78}
79
80impl QRExchange {
81    /// Encode key for QR code.
82    #[must_use]
83    pub fn encode(key: &[u8]) -> String {
84        URL_SAFE_NO_PAD.encode(key)
85    }
86
87    /// Decode key from QR code.
88    pub fn decode(encoded: &str) -> Result<Vec<u8>> {
89        URL_SAFE_NO_PAD
90            .decode(encoded)
91            .map_err(|_| ShieldError::InvalidFormat)
92    }
93
94    /// Generate complete exchange data with metadata.
95    #[must_use]
96    pub fn generate_exchange_data(key: &[u8], metadata: Option<serde_json::Value>) -> String {
97        let data = ExchangeData {
98            v: 1,
99            k: URL_SAFE_NO_PAD.encode(key),
100            m: metadata,
101        };
102        serde_json::to_string(&data).unwrap()
103    }
104
105    /// Parse exchange data.
106    pub fn parse_exchange_data(data: &str) -> Result<(Vec<u8>, Option<serde_json::Value>)> {
107        let parsed: ExchangeData =
108            serde_json::from_str(data).map_err(|_| ShieldError::InvalidFormat)?;
109        let key = URL_SAFE_NO_PAD
110            .decode(&parsed.k)
111            .map_err(|_| ShieldError::InvalidFormat)?;
112        Ok((key, parsed.m))
113    }
114}
115
116/// Split keys into shares (all required to reconstruct).
117pub struct KeySplitter;
118
119impl KeySplitter {
120    /// Split key into shares.
121    pub fn split(key: &[u8], num_shares: usize) -> Result<Vec<Vec<u8>>> {
122        if num_shares < 2 {
123            return Err(ShieldError::InvalidShareCount);
124        }
125
126        let rng = SystemRandom::new();
127        let mut shares = Vec::with_capacity(num_shares);
128
129        for _ in 0..num_shares - 1 {
130            let mut share = vec![0u8; key.len()];
131            rng.fill(&mut share)
132                .map_err(|_| ShieldError::RandomFailed)?;
133            shares.push(share);
134        }
135
136        // Final share = XOR of key with all others
137        let mut final_share = key.to_vec();
138        for share in &shares {
139            for (i, &b) in share.iter().enumerate() {
140                final_share[i] ^= b;
141            }
142        }
143        shares.push(final_share);
144
145        Ok(shares)
146    }
147
148    /// Combine shares to recover key.
149    pub fn combine(shares: &[Vec<u8>]) -> Result<Vec<u8>> {
150        if shares.len() < 2 {
151            return Err(ShieldError::InvalidShareCount);
152        }
153
154        let len = shares[0].len();
155        let mut result = vec![0u8; len];
156
157        for share in shares {
158            if share.len() != len {
159                return Err(ShieldError::InvalidFormat);
160            }
161            for (i, &b) in share.iter().enumerate() {
162                result[i] ^= b;
163            }
164        }
165
166        Ok(result)
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Small iteration count so the extra PBKDF2 tests stay fast.
    const FAST_ITERS: Option<u32> = Some(1_000);

    #[test]
    fn test_pake_derive() {
        let salt = PAKEExchange::generate_salt().unwrap();
        let key = PAKEExchange::derive("password", &salt, "client", None);
        assert_eq!(key.len(), 32);
    }

    #[test]
    fn test_pake_deterministic() {
        let salt = PAKEExchange::generate_salt().unwrap();
        let key1 = PAKEExchange::derive("password", &salt, "client", None);
        let key2 = PAKEExchange::derive("password", &salt, "client", None);
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_pake_roles_differ() {
        let salt = [7u8; 16];
        let client = PAKEExchange::derive("password", &salt, "client", FAST_ITERS);
        let server = PAKEExchange::derive("password", &salt, "server", FAST_ITERS);
        assert_ne!(client, server);
    }

    #[test]
    fn test_pake_iterations_matter() {
        let salt = [7u8; 16];
        let a = PAKEExchange::derive("password", &salt, "client", Some(1_000));
        let b = PAKEExchange::derive("password", &salt, "client", Some(2_000));
        assert_ne!(a, b);
    }

    #[test]
    fn test_pake_combine_order_independent() {
        let salt = PAKEExchange::generate_salt().unwrap();
        let client = PAKEExchange::derive("password", &salt, "client", None);
        let server = PAKEExchange::derive("password", &salt, "server", None);

        let shared1 = PAKEExchange::combine(&[client, server]);
        let shared2 = PAKEExchange::combine(&[server, client]);
        assert_eq!(shared1, shared2);
    }

    #[test]
    fn test_qr_roundtrip() {
        let key = [42u8; 32];
        let encoded = QRExchange::encode(&key);
        let decoded = QRExchange::decode(&encoded).unwrap();
        assert_eq!(key.as_slice(), decoded.as_slice());
    }

    #[test]
    fn test_qr_decode_invalid() {
        assert!(QRExchange::decode("not base64!").is_err());
    }

    #[test]
    fn test_qr_exchange_data() {
        let key = [1u8; 32];
        let metadata = serde_json::json!({"name": "test"});
        let data = QRExchange::generate_exchange_data(&key, Some(metadata.clone()));
        let (parsed_key, parsed_meta) = QRExchange::parse_exchange_data(&data).unwrap();
        assert_eq!(key.as_slice(), parsed_key.as_slice());
        assert_eq!(parsed_meta, Some(metadata));
    }

    #[test]
    fn test_qr_exchange_data_without_metadata() {
        let key = [9u8; 16];
        let data = QRExchange::generate_exchange_data(&key, None);
        assert!(!data.contains("\"m\""));
        let (parsed_key, parsed_meta) = QRExchange::parse_exchange_data(&data).unwrap();
        assert_eq!(key.as_slice(), parsed_key.as_slice());
        assert!(parsed_meta.is_none());
    }

    #[test]
    fn test_qr_parse_invalid_json() {
        assert!(QRExchange::parse_exchange_data("not json").is_err());
    }

    #[test]
    fn test_key_splitter() {
        let key = [42u8; 32];
        let shares = KeySplitter::split(&key, 3).unwrap();
        assert_eq!(shares.len(), 3);

        let recovered = KeySplitter::combine(&shares).unwrap();
        assert_eq!(key.as_slice(), recovered.as_slice());
    }

    #[test]
    fn test_key_splitter_two_shares_roundtrip() {
        let key: Vec<u8> = (0u8..=31).collect();
        let shares = KeySplitter::split(&key, 2).unwrap();
        assert_eq!(shares.len(), 2);
        assert!(shares.iter().all(|share| share.len() == key.len()));
        assert_eq!(KeySplitter::combine(&shares).unwrap(), key);
    }

    #[test]
    fn test_key_splitter_empty_key() {
        let shares = KeySplitter::split(&[], 3).unwrap();
        assert_eq!(shares.len(), 3);
        assert!(KeySplitter::combine(&shares).unwrap().is_empty());
    }

    #[test]
    fn test_key_splitter_partial() {
        let key = [42u8; 32];
        let shares = KeySplitter::split(&key, 3).unwrap();

        // Partial shares don't recover key
        let partial = KeySplitter::combine(&shares[..2]).unwrap();
        assert_ne!(key.as_slice(), partial.as_slice());
    }

    #[test]
    fn test_key_splitter_min_shares() {
        let key = [42u8; 32];
        assert!(KeySplitter::split(&key, 1).is_err());
        assert!(KeySplitter::split(&key, 0).is_err());
    }

    #[test]
    fn test_key_splitter_combine_requires_two_shares() {
        assert!(KeySplitter::combine(&[vec![1u8; 4]]).is_err());
        assert!(KeySplitter::combine(&[]).is_err());
    }

    #[test]
    fn test_key_splitter_combine_rejects_mismatched_lengths() {
        let shares = vec![vec![1u8; 4], vec![2u8; 5]];
        assert!(KeySplitter::combine(&shares).is_err());
    }
}
243}