Skip to main content

vitaminc_permutation/
elementwise.rs

1use crate::{private::IsPermutable, PermutationKey};
2use subtle::{ConditionallySelectable, ConstantTimeEq};
3use vitaminc_protected::{Controlled, Zeroed};
4use zeroize::{Zeroize, Zeroizing};
5
// TODO: Make this a private trait
// FIXME: This trait is backwards - self should be T and the argument should be a key
/// Applies the permutation described by `self` to a value of type `T`.
///
/// In this module it is implemented for `PermutationKey<N>` acting on `[T; N]`.
pub trait Permute<T> {
    /// Returns `input` rearranged according to the permutation held by `self`.
    fn permute(&self, input: T) -> T;
}
11
/// Reverses a permutation: intended as the inverse of [`Permute`] for the
/// same `self`.
pub trait Depermute<T> {
    /// Returns `input` with the permutation held by `self` undone.
    fn depermute(&self, input: T) -> T;
}
15
16/// Implement permutation for Protected type containing a permutable array.
17impl<const N: usize, T> Permute<[T; N]> for PermutationKey<N>
18where
19    T: Zeroize + Default + Copy + ConditionallySelectable,
20    [T; N]: IsPermutable + Zeroed,
21{
22    fn permute(&self, input: [T; N]) -> [T; N] {
23        permute_array(self, input)
24    }
25}
26
27impl<const N: usize, T> Depermute<[T; N]> for PermutationKey<N>
28where
29    T: Zeroize + Default + Copy + ConditionallySelectable,
30    [T; N]: IsPermutable + Zeroed,
31{
32    fn depermute(&self, input: [T; N]) -> [T; N] {
33        depermute_array(self, input)
34    }
35}
36
37#[inline]
38pub fn permute_array<const N: usize, T>(key: &PermutationKey<N>, input: [T; N]) -> [T; N]
39where
40    [T; N]: IsPermutable + Zeroed,
41    T: Zeroize + Copy + ConditionallySelectable,
42{
43    // Key bytes are u8, so an `N > 256` permutation could never be expressed by
44    // the key — and `(j as u8)` below would silently wrap, breaking the
45    // ct_eq comparison. Surface the limit at compile time.
46    const { assert!(N <= 256, "permutation length must fit in u8") };
47
48    // Constant-time scan: for each output position `i`, the key byte `kv`
49    // selects which input element to copy. We scan all `j` in 0..N and use
50    // `ConditionallySelectable` so the access pattern is independent of `kv`,
51    // preventing cache-line timing leaks of the secret key bytes. Secret
52    // locals — including the partially-populated `out` — are wrapped in
53    // `Zeroizing` so they are wiped on any unwind path.
54    let input = Zeroizing::new(input);
55    let mut out: Zeroizing<[T; N]> = Zeroizing::new(Zeroed::zeroed());
56    for (i, k) in key.iter().enumerate() {
57        let kv = Zeroizing::new(k.risky_unwrap());
58        let mut selected = Zeroizing::new(input[0]);
59        for (j, src) in input.iter().enumerate().skip(1) {
60            let mask = (j as u8).ct_eq(&*kv);
61            selected.conditional_assign(src, mask);
62        }
63        out[i] = *selected;
64    }
65    // Move the populated array out, leaving a fresh zeroed array for the
66    // `Zeroizing` Drop to clean (a no-op wipe in the success path).
67    core::mem::replace(&mut *out, Zeroed::zeroed())
68}
69
70#[inline]
71pub fn depermute_array<const N: usize, T>(key: &PermutationKey<N>, input: [T; N]) -> [T; N]
72where
73    [T; N]: IsPermutable + Zeroed,
74    T: Zeroize + Copy + ConditionallySelectable,
75{
76    // See `permute_array` — same u8-fit constraint applies here.
77    const { assert!(N <= 256, "permutation length must fit in u8") };
78
79    // Constant-time scatter: for each (i, kv), write `input[i]` to `out[kv]`
80    // by scanning every output slot and conditionally assigning when `j == kv`.
81    // Secret locals — including the partially-populated `out` — are wrapped
82    // in `Zeroizing` for unwind safety.
83    let input = Zeroizing::new(input);
84    let mut out: Zeroizing<[T; N]> = Zeroizing::new(Zeroed::zeroed());
85    for (i, k) in key.iter().enumerate() {
86        let kv = Zeroizing::new(k.risky_unwrap());
87        let src = Zeroizing::new(input[i]);
88        for (j, dst) in out.iter_mut().enumerate() {
89            let mask = (j as u8).ct_eq(&*kv);
90            dst.conditional_assign(&*src, mask);
91        }
92    }
93    core::mem::replace(&mut *out, Zeroed::zeroed())
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests;
    use crate::{Depermute, PermutationKey, Permute};
    use vitaminc_random::{Generatable, SafeRand};

    /// A key must only rearrange elements, and a non-trivial key must
    /// actually move some of them.
    fn test_permute<const N: usize>() -> Result<(), Box<dyn std::error::Error>>
    where
        [u8; N]: IsPermutable + Zeroed,
    {
        // Distinct ascending values (test sizes are <= 64, so `i as u8` does
        // not wrap): the output is a rearrangement of the input exactly when
        // sorting it restores the input.
        let input: [u8; N] = core::array::from_fn(|i| i as u8);

        let key: PermutationKey<N> = tests::gen_rand_key()?;
        let mut sorted = key.permute(input);
        sorted.sort_unstable();
        assert_eq!(sorted, input);

        // An `output != input` check with a random key is flaky: it fails
        // whenever the random key happens to be the identity. A fixed-seed key
        // makes the check deterministic.
        let fixed_key: PermutationKey<N> = tests::gen_key([0; 32]);
        assert_ne!(fixed_key.permute(input), input);
        Ok(())
    }

    /// `depermute` must undo `permute` for the same key.
    fn test_depermute<const N: usize>() -> Result<(), Box<dyn std::error::Error>>
    where
        [u8; N]: IsPermutable + Zeroed,
    {
        let mut rng = SafeRand::from_entropy()?;
        let input: [u8; N] = Generatable::random(&mut rng)?;
        let key: PermutationKey<N> = tests::gen_rand_key()?;
        let output = key.permute(input);
        let depermuted = key.depermute(output);
        assert_eq!(depermuted, input);
        Ok(())
    }

    /// Applying two keys in sequence must equal applying their composition.
    fn test_associativity<const N: usize>() -> Result<(), Box<dyn std::error::Error>>
    where
        // `Zeroed` matches the bound the array `Permute` impl requires and
        // keeps this helper consistent with the other test helpers.
        [u8; N]: IsPermutable + Zeroed,
    {
        let mut rng = SafeRand::from_entropy()?;
        let key_1 = tests::gen_key([0; 32]);
        let key_2 = tests::gen_key([1; 32]);
        let input: [u8; N] = Generatable::random(&mut rng)?;

        // p_2(p_1(input))
        let output_1 = key_2.permute(key_1.permute(input));

        // p_2(p_1)(input)
        let output_2 = key_2.permute(key_1).permute(input);

        assert_eq!(output_1, output_2);
        Ok(())
    }

    #[test]
    fn permute_case() -> Result<(), Box<dyn std::error::Error>> {
        test_permute::<8>()?;
        test_permute::<16>()?;
        test_permute::<32>()?;
        test_permute::<64>()?;
        Ok(())
    }

    #[test]
    fn depermutation_case() -> Result<(), Box<dyn std::error::Error>> {
        test_depermute::<8>()?;
        test_depermute::<16>()?;
        test_depermute::<32>()?;
        test_depermute::<64>()?;
        Ok(())
    }

    #[test]
    fn associativity_case() -> Result<(), Box<dyn std::error::Error>> {
        test_associativity::<8>()?;
        test_associativity::<16>()?;
        test_associativity::<32>()?;
        test_associativity::<64>()?;
        Ok(())
    }
}