vitaminc_permutation/
elementwise.rs1use crate::{private::IsPermutable, PermutationKey};
2use subtle::{ConditionallySelectable, ConstantTimeEq};
3use vitaminc_protected::{Controlled, Zeroed};
4use zeroize::{Zeroize, Zeroizing};
5
/// A keyed rearrangement of a value of type `T`.
///
/// The key lives in `self`; the permuted value is returned by value, with the same type as
/// the input.
pub trait Permute<T> {
    /// Applies the permutation held by `self` to `input` and returns the rearranged value.
    fn permute(&self, input: T) -> T;
}
11
/// Undoes a keyed rearrangement of a value of type `T`.
///
/// Intended as the counterpart of [`Permute`] for the same key, so that
/// `key.depermute(key.permute(x)) == x`; implementors are responsible for upholding that.
pub trait Depermute<T> {
    /// Reverses the permutation held by `self` on `input` and returns the restored value.
    fn depermute(&self, input: T) -> T;
}
15
16impl<const N: usize, T> Permute<[T; N]> for PermutationKey<N>
18where
19 T: Zeroize + Default + Copy + ConditionallySelectable,
20 [T; N]: IsPermutable + Zeroed,
21{
22 fn permute(&self, input: [T; N]) -> [T; N] {
23 permute_array(self, input)
24 }
25}
26
27impl<const N: usize, T> Depermute<[T; N]> for PermutationKey<N>
28where
29 T: Zeroize + Default + Copy + ConditionallySelectable,
30 [T; N]: IsPermutable + Zeroed,
31{
32 fn depermute(&self, input: [T; N]) -> [T; N] {
33 depermute_array(self, input)
34 }
35}
36
37#[inline]
38pub fn permute_array<const N: usize, T>(key: &PermutationKey<N>, input: [T; N]) -> [T; N]
39where
40 [T; N]: IsPermutable + Zeroed,
41 T: Zeroize + Copy + ConditionallySelectable,
42{
43 const { assert!(N <= 256, "permutation length must fit in u8") };
47
48 let input = Zeroizing::new(input);
55 let mut out: Zeroizing<[T; N]> = Zeroizing::new(Zeroed::zeroed());
56 for (i, k) in key.iter().enumerate() {
57 let kv = Zeroizing::new(k.risky_unwrap());
58 let mut selected = Zeroizing::new(input[0]);
59 for (j, src) in input.iter().enumerate().skip(1) {
60 let mask = (j as u8).ct_eq(&*kv);
61 selected.conditional_assign(src, mask);
62 }
63 out[i] = *selected;
64 }
65 core::mem::replace(&mut *out, Zeroed::zeroed())
68}
69
70#[inline]
71pub fn depermute_array<const N: usize, T>(key: &PermutationKey<N>, input: [T; N]) -> [T; N]
72where
73 [T; N]: IsPermutable + Zeroed,
74 T: Zeroize + Copy + ConditionallySelectable,
75{
76 const { assert!(N <= 256, "permutation length must fit in u8") };
78
79 let input = Zeroizing::new(input);
84 let mut out: Zeroizing<[T; N]> = Zeroizing::new(Zeroed::zeroed());
85 for (i, k) in key.iter().enumerate() {
86 let kv = Zeroizing::new(k.risky_unwrap());
87 let src = Zeroizing::new(input[i]);
88 for (j, dst) in out.iter_mut().enumerate() {
89 let mask = (j as u8).ct_eq(&*kv);
90 dst.conditional_assign(&*src, mask);
91 }
92 }
93 core::mem::replace(&mut *out, Zeroed::zeroed())
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests;
    use crate::{Depermute, PermutationKey, Permute};
    use vitaminc_random::{Generatable, SafeRand};

    /// Builds `[0, 1, ..., N - 1]`.
    ///
    /// Every element is distinct, so permuting this array leaves it unchanged only when the
    /// key is the identity permutation.
    fn distinct_input<const N: usize>() -> [u8; N] {
        // The test cases below only use N <= 64, so `i as u8` never truncates.
        core::array::from_fn(|i| i as u8)
    }

    /// Returns a sorted copy of `values`, so two arrays can be compared as multisets.
    fn sorted<const N: usize>(mut values: [u8; N]) -> [u8; N] {
        values.sort_unstable();
        values
    }

    /// Checks that permuting only rearranges elements, and that a fixed key really moves them.
    ///
    /// Asserting `output != input` with a *random* key on *random* input is flaky. It fails
    /// whenever the key happens to fix the input (an identity key, or equal bytes swapped with
    /// each other). The inequality check therefore uses a fixed-seed key and an input whose
    /// elements are all distinct, which makes it deterministic.
    fn test_permute<const N: usize>() -> Result<(), Box<dyn std::error::Error>>
    where
        [u8; N]: IsPermutable + Zeroed,
    {
        // Random key, random input: the output must be a rearrangement of the input.
        let mut rng = SafeRand::from_entropy()?;
        let input: [u8; N] = Generatable::random(&mut rng)?;
        let key: PermutationKey<N> = tests::gen_rand_key()?;
        let output = key.permute(input);
        assert_eq!(sorted(output), sorted(input));

        // Fixed key, distinct input: elements must actually move, and none may be lost.
        let key: PermutationKey<N> = tests::gen_key([0; 32]);
        let input = distinct_input::<N>();
        let output = key.permute(input);
        assert_ne!(output, input);
        assert_eq!(sorted(output), input);
        Ok(())
    }

    /// Checks that `depermute` undoes `permute`, and vice versa, for a random key.
    fn test_depermute<const N: usize>() -> Result<(), Box<dyn std::error::Error>>
    where
        [u8; N]: IsPermutable + Zeroed,
    {
        let mut rng = SafeRand::from_entropy()?;
        let input: [u8; N] = Generatable::random(&mut rng)?;
        let key: PermutationKey<N> = tests::gen_rand_key()?;
        let output = key.permute(input);
        let depermuted = key.depermute(output);
        assert_eq!(depermuted, input);
        // The key is a bijection, so the round trip must also hold in the other order.
        assert_eq!(key.permute(key.depermute(input)), input);
        Ok(())
    }

    /// Checks that composing two keys and then permuting equals permuting twice:
    /// `key_2.permute(key_1).permute(x) == key_2.permute(key_1.permute(x))`.
    fn test_associativity<const N: usize>() -> Result<(), Box<dyn std::error::Error>>
    where
        // `permute` on `[u8; N]` requires `Zeroed` as well; keep the bounds in line with the
        // other helpers.
        [u8; N]: IsPermutable + Zeroed,
    {
        let mut rng = SafeRand::from_entropy()?;
        let key_1 = tests::gen_key([0; 32]);
        let key_2 = tests::gen_key([1; 32]);
        let input: [u8; N] = Generatable::random(&mut rng)?;

        let output_1 = key_2.permute(key_1.permute(input));

        let output_2 = key_2.permute(key_1).permute(input);

        assert_eq!(output_1, output_2);
        Ok(())
    }

    #[test]
    fn permute_case() -> Result<(), Box<dyn std::error::Error>> {
        test_permute::<8>()?;
        test_permute::<16>()?;
        test_permute::<32>()?;
        test_permute::<64>()?;
        Ok(())
    }

    #[test]
    fn depermutation_case() -> Result<(), Box<dyn std::error::Error>> {
        test_depermute::<8>()?;
        test_depermute::<16>()?;
        test_depermute::<32>()?;
        test_depermute::<64>()?;
        Ok(())
    }

    #[test]
    fn associativity_case() -> Result<(), Box<dyn std::error::Error>> {
        test_associativity::<8>()?;
        test_associativity::<16>()?;
        test_associativity::<32>()?;
        test_associativity::<64>()?;
        Ok(())
    }
}