// shield_core/exchange.rs

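//! Key exchange without public-key cryptography.
//!
//! Provides password-based key derivation (`PAKEExchange`), key transfer via
//! QR codes or copy-paste (`QRExchange`), and XOR key splitting in which every
//! share is required to rebuild the key (`KeySplitter`).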
4
5use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
6use ring::hmac;
7use serde::{Deserialize, Serialize};
8use std::num::NonZeroU32;
9
10use crate::error::{Result, ShieldError};
11
12/// Password-Authenticated Key Exchange.
13pub struct PAKEExchange;
14
15impl PAKEExchange {
16    /// Default PBKDF2 iterations.
17    pub const ITERATIONS: u32 = 200_000;
18
19    /// Derive key contribution from password.
20    #[must_use]
21    pub fn derive(password: &str, salt: &[u8], role: &str, iterations: Option<u32>) -> [u8; 32] {
22        let iters = iterations.unwrap_or(Self::ITERATIONS);
23
24        let mut base_key = [0u8; 32];
25        ring::pbkdf2::derive(
26            ring::pbkdf2::PBKDF2_HMAC_SHA256,
27            NonZeroU32::new(iters).unwrap_or(NonZeroU32::new(Self::ITERATIONS).unwrap()),
28            salt,
29            password.as_bytes(),
30            &mut base_key,
31        );
32
33        // Derive role-specific key using HMAC-SHA256 (keyed PRF)
34        let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, &base_key);
35        let tag = hmac::sign(&hmac_key, role.as_bytes());
36        let mut result = [0u8; 32];
37        result.copy_from_slice(&tag.as_ref()[..32]);
38        result
39    }
40
41    /// Combine key contributions into session key using HMAC-SHA256.
42    #[must_use]
43    pub fn combine(contributions: &[[u8; 32]]) -> [u8; 32] {
44        let mut sorted: Vec<&[u8; 32]> = contributions.iter().collect();
45        sorted.sort();
46
47        // Use first contribution as HMAC key, remaining as data
48        let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, sorted[0]);
49        let mut data = Vec::with_capacity((sorted.len() - 1) * 32);
50        for c in &sorted[1..] {
51            data.extend_from_slice(*c);
52        }
53
54        let tag = hmac::sign(&hmac_key, &data);
55        let mut result = [0u8; 32];
56        result.copy_from_slice(&tag.as_ref()[..32]);
57        result
58    }
59
60    /// Generate random salt.
61    pub fn generate_salt() -> Result<[u8; 16]> {
62        crate::random::random_bytes()
63    }
64}
65
66/// Key exchange via QR codes or manual transfer.
67pub struct QRExchange;
68
69#[derive(Serialize, Deserialize)]
70struct ExchangeData {
71    v: u8,
72    k: String,
73    #[serde(skip_serializing_if = "Option::is_none")]
74    m: Option<serde_json::Value>,
75}
76
77impl QRExchange {
78    /// Encode key for QR code.
79    #[must_use]
80    pub fn encode(key: &[u8]) -> String {
81        URL_SAFE_NO_PAD.encode(key)
82    }
83
84    /// Decode key from QR code.
85    pub fn decode(encoded: &str) -> Result<Vec<u8>> {
86        URL_SAFE_NO_PAD
87            .decode(encoded)
88            .map_err(|_| ShieldError::InvalidFormat)
89    }
90
91    /// Generate complete exchange data with metadata.
92    #[must_use]
93    pub fn generate_exchange_data(key: &[u8], metadata: Option<serde_json::Value>) -> String {
94        let data = ExchangeData {
95            v: 1,
96            k: URL_SAFE_NO_PAD.encode(key),
97            m: metadata,
98        };
99        serde_json::to_string(&data).unwrap_or_default()
100    }
101
102    /// Parse exchange data.
103    pub fn parse_exchange_data(data: &str) -> Result<(Vec<u8>, Option<serde_json::Value>)> {
104        let parsed: ExchangeData =
105            serde_json::from_str(data).map_err(|_| ShieldError::InvalidFormat)?;
106        let key = URL_SAFE_NO_PAD
107            .decode(&parsed.k)
108            .map_err(|_| ShieldError::InvalidFormat)?;
109        Ok((key, parsed.m))
110    }
111}
112
113/// Split keys into shares (all required to reconstruct).
114pub struct KeySplitter;
115
116impl KeySplitter {
117    /// Split key into shares.
118    pub fn split(key: &[u8], num_shares: usize) -> Result<Vec<Vec<u8>>> {
119        if num_shares < 2 {
120            return Err(ShieldError::InvalidShareCount);
121        }
122
123        let mut shares = Vec::with_capacity(num_shares);
124
125        for _ in 0..num_shares - 1 {
126            let share = crate::random::random_vec(key.len())?;
127            shares.push(share);
128        }
129
130        // Final share = XOR of key with all others
131        let mut final_share = key.to_vec();
132        for share in &shares {
133            for (i, &b) in share.iter().enumerate() {
134                final_share[i] ^= b;
135            }
136        }
137        shares.push(final_share);
138
139        Ok(shares)
140    }
141
142    /// Combine shares to recover key.
143    pub fn combine(shares: &[Vec<u8>]) -> Result<Vec<u8>> {
144        if shares.len() < 2 {
145            return Err(ShieldError::InvalidShareCount);
146        }
147
148        let len = shares[0].len();
149        let mut result = vec![0u8; len];
150
151        for share in shares {
152            if share.len() != len {
153                return Err(ShieldError::InvalidFormat);
154            }
155            for (i, &b) in share.iter().enumerate() {
156                result[i] ^= b;
157            }
158        }
159
160        Ok(result)
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Low iteration count for tests that only check separation properties.
    const FAST_ITERS: Option<u32> = Some(1_000);

    #[test]
    fn test_pake_derive() {
        let salt = PAKEExchange::generate_salt().unwrap();
        let key = PAKEExchange::derive("password", &salt, "client", None);
        assert_eq!(key.len(), 32);
    }

    #[test]
    fn test_pake_deterministic() {
        let salt = PAKEExchange::generate_salt().unwrap();
        let key1 = PAKEExchange::derive("password", &salt, "client", None);
        let key2 = PAKEExchange::derive("password", &salt, "client", None);
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_pake_roles_differ() {
        let salt = [7u8; 16];
        let client = PAKEExchange::derive("password", &salt, "client", FAST_ITERS);
        let server = PAKEExchange::derive("password", &salt, "server", FAST_ITERS);
        assert_ne!(client, server);
    }

    #[test]
    fn test_pake_password_matters() {
        let salt = [7u8; 16];
        let a = PAKEExchange::derive("password", &salt, "client", FAST_ITERS);
        let b = PAKEExchange::derive("passw0rd", &salt, "client", FAST_ITERS);
        assert_ne!(a, b);
    }

    #[test]
    fn test_pake_combine_order_independent() {
        let salt = PAKEExchange::generate_salt().unwrap();
        let client = PAKEExchange::derive("password", &salt, "client", None);
        let server = PAKEExchange::derive("password", &salt, "server", None);

        let shared1 = PAKEExchange::combine(&[client, server]);
        let shared2 = PAKEExchange::combine(&[server, client]);
        assert_eq!(shared1, shared2);
    }

    #[test]
    #[should_panic]
    fn test_pake_combine_empty_panics() {
        let _ = PAKEExchange::combine(&[]);
    }

    #[test]
    fn test_qr_roundtrip() {
        let key = [42u8; 32];
        let encoded = QRExchange::encode(&key);
        let decoded = QRExchange::decode(&encoded).unwrap();
        assert_eq!(key.as_slice(), decoded.as_slice());
    }

    #[test]
    fn test_qr_decode_rejects_invalid_base64() {
        assert!(QRExchange::decode("not base64!").is_err());
    }

    #[test]
    fn test_qr_exchange_data() {
        let key = [1u8; 32];
        let metadata = serde_json::json!({"name": "test"});
        let data = QRExchange::generate_exchange_data(&key, Some(metadata.clone()));
        let (parsed_key, parsed_meta) = QRExchange::parse_exchange_data(&data).unwrap();
        assert_eq!(key.as_slice(), parsed_key.as_slice());
        assert_eq!(parsed_meta, Some(metadata));
    }

    #[test]
    fn test_qr_exchange_data_without_metadata() {
        let key = [9u8; 16];
        let data = QRExchange::generate_exchange_data(&key, None);
        // `m` is skipped entirely when there is no metadata.
        assert!(!data.contains("\"m\""));
        let (parsed_key, parsed_meta) = QRExchange::parse_exchange_data(&data).unwrap();
        assert_eq!(key.as_slice(), parsed_key.as_slice());
        assert!(parsed_meta.is_none());
    }

    #[test]
    fn test_qr_parse_rejects_malformed() {
        assert!(QRExchange::parse_exchange_data("not json").is_err());
        assert!(QRExchange::parse_exchange_data(r#"{"v":1,"k":"***"}"#).is_err());
    }

    #[test]
    fn test_key_splitter() {
        let key = [42u8; 32];
        let shares = KeySplitter::split(&key, 3).unwrap();
        assert_eq!(shares.len(), 3);

        let recovered = KeySplitter::combine(&shares).unwrap();
        assert_eq!(key.as_slice(), recovered.as_slice());
    }

    #[test]
    fn test_key_splitter_two_shares() {
        let key: Vec<u8> = (0u8..32).collect();
        let shares = KeySplitter::split(&key, 2).unwrap();
        assert_eq!(shares.len(), 2);
        assert!(shares.iter().all(|s| s.len() == key.len()));
        assert_eq!(KeySplitter::combine(&shares).unwrap(), key);
    }

    #[test]
    fn test_key_splitter_partial() {
        let key = [42u8; 32];
        let shares = KeySplitter::split(&key, 3).unwrap();

        // Partial shares don't recover key
        let partial = KeySplitter::combine(&shares[..2]).unwrap();
        assert_ne!(key.as_slice(), partial.as_slice());
    }

    #[test]
    fn test_key_splitter_combine_rejects_bad_input() {
        // Too few shares.
        assert!(KeySplitter::combine(&[vec![1, 2, 3]]).is_err());
        // Mismatched share lengths.
        assert!(KeySplitter::combine(&[vec![1, 2, 3], vec![4, 5]]).is_err());
    }

    #[test]
    fn test_key_splitter_min_shares() {
        let key = [42u8; 32];
        assert!(KeySplitter::split(&key, 1).is_err());
    }
}