1use super::macros::{
10 Trit, COLUMNS, NUM_ROUNDS, NUM_SBOXES, PADDING, ROUND_CONSTANTS, ROWS, SBOX_LOOKUP,
11 SHIFT_ROWS_LANES, SLICES, SLICESIZE, STATE_SIZE, TROIKA_RATE,
12};
13use crate::Result;
14use core::fmt;
15
16#[derive(Clone, Copy)]
30pub struct Troika {
31 num_rounds: usize,
32 state: [Trit; STATE_SIZE],
33}
34
35impl Default for Troika {
36 fn default() -> Troika {
37 Troika {
38 num_rounds: NUM_ROUNDS,
39 state: [0u8; STATE_SIZE],
40 }
41 }
42}
43
44impl fmt::Debug for Troika {
45 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
46 write!(
47 f,
48 "Troika: [rounds: [{}], state: {:?}",
49 self.num_rounds,
50 &self.state[..],
51 )
52 }
53}
54
55impl Troika {
56 pub fn new(num_rounds: usize) -> Result<Troika> {
57 let mut troika = Troika::default();
58 troika.num_rounds = num_rounds;
59 Ok(troika)
60 }
61
62 pub fn state(&self) -> &[Trit] {
63 &self.state
64 }
65
66 pub fn reset(&mut self) {
67 self.state = [0; STATE_SIZE];
68 }
69
70 pub fn absorb(&mut self, message: &[Trit]) {
71 let mut message_length = message.len();
72 let mut message_idx = 0;
73 let mut trit_idx = 0;
74
75 while message_length >= TROIKA_RATE {
76 for trit_idx in 0..TROIKA_RATE {
78 self.state[trit_idx] = message[message_idx + trit_idx];
79 }
80 self.permutation();
81 message_length -= TROIKA_RATE;
82 message_idx += TROIKA_RATE;
83 }
84
85 let mut last_block = [0u8; TROIKA_RATE];
87
88 for _ in 0..message_length {
90 last_block[trit_idx] = message[trit_idx];
91 trit_idx += 1;
92 }
93
94 last_block[trit_idx] = PADDING;
97
98 for trit_idx in 0..TROIKA_RATE {
100 self.state[trit_idx] = last_block[trit_idx];
101 }
102 }
103
104 pub fn squeeze(&mut self, hash: &mut [Trit]) {
105 let mut hash_length = hash.len();
106 let mut hash_idx = 0;
107
108 while hash_length >= TROIKA_RATE {
109 self.permutation();
110 for trit_idx in 0..TROIKA_RATE {
112 hash[hash_idx + trit_idx] = self.state[trit_idx];
113 }
114 hash_idx += TROIKA_RATE;
115 hash_length -= TROIKA_RATE;
116 }
117
118 if hash_length % TROIKA_RATE != 0 {
120 self.permutation();
121 for trit_idx in 0..hash_length {
122 hash[trit_idx] = self.state[trit_idx];
123 }
124 }
125 }
126
127 pub fn permutation(&mut self) {
128 assert!(self.num_rounds <= NUM_ROUNDS);
129
130 for round in 0..self.num_rounds {
131 self.sub_trytes();
132 self.shift_rows_lanes();
133 self.add_column_parity();
134 self.add_round_constant(round);
135 }
136 }
137
138 fn sub_trytes(&mut self) {
139 for sbox_idx in 0..NUM_SBOXES {
140 let sbox_input = 9 * self.state[3 * sbox_idx]
141 + 3 * self.state[3 * sbox_idx + 1]
142 + self.state[3 * sbox_idx + 2];
143 let mut sbox_output = SBOX_LOOKUP[sbox_input as usize];
144 self.state[3 * sbox_idx + 2] = sbox_output % 3;
145 sbox_output /= 3;
146 self.state[3 * sbox_idx + 1] = sbox_output % 3;
147 sbox_output /= 3;
148 self.state[3 * sbox_idx] = sbox_output % 3;
149 }
150 }
151
152 fn shift_rows_lanes(&mut self) {
153 let mut new_state = [0u8; STATE_SIZE];
154 for i in 0..STATE_SIZE {
155 new_state[i] = self.state[SHIFT_ROWS_LANES[i]];
156 }
157
158 self.state = new_state;
159 }
160
161 fn add_column_parity(&mut self) {
162 let mut parity = [0u8; SLICES * COLUMNS];
163
164 for slice in 0..SLICES {
166 for col in 0..COLUMNS {
167 let mut col_sum = 0;
168 for row in 0..ROWS {
169 col_sum += self.state[SLICESIZE * slice + COLUMNS * row + col];
170 }
171 parity[COLUMNS * slice + col] = col_sum % 3;
172 }
173 }
174
175 for slice in 0..SLICES {
177 for row in 0..ROWS {
178 for col in 0..COLUMNS {
179 let idx = SLICESIZE * slice + COLUMNS * row + col;
180 let sum_to_add = parity[(col + 8) % 9 + COLUMNS * slice]
181 + parity[(col + 1) % 9 + COLUMNS * ((slice + 1) % SLICES)];
182 self.state[idx] = (self.state[idx] + sum_to_add) % 3;
183 }
184 }
185 }
186 }
187
188 fn add_round_constant(&mut self, round: usize) {
189 for slice in 0..SLICES {
190 for col in 0..COLUMNS {
191 let idx = SLICESIZE * slice + col;
192 self.state[idx] =
193 (self.state[idx] + ROUND_CONSTANTS[round][slice * COLUMNS + col]) % 3;
194 }
195 }
196 }
197}
198
#[cfg(test)]
mod test_troika {
    use super::*;

    /// Expected 243-trit output of a default-configured `Troika` after
    /// absorbing 243 zero trits and squeezing 243 trits.
    const HASH: [u8; 243] = [
        0, 2, 2, 1, 2, 1, 0, 1, 2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 0, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 2,
        1, 1, 1, 0, 1, 0, 2, 1, 0, 0, 0, 1, 2, 0, 2, 1, 0, 0, 2, 1, 1, 1, 1, 1, 2, 0, 1, 0, 2, 1,
        1, 2, 0, 1, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 2, 2, 0, 0, 2, 2, 2, 1, 2, 2, 0, 2, 1, 1, 2, 1,
        1, 1, 2, 2, 1, 1, 0, 0, 0, 2, 2, 2, 0, 2, 1, 1, 1, 1, 0, 0, 1, 0, 2, 0, 2, 0, 2, 0, 0, 0,
        0, 1, 1, 1, 0, 2, 1, 1, 1, 0, 2, 0, 0, 1, 0, 1, 0, 2, 0, 2, 2, 0, 0, 2, 2, 0, 1, 2, 1, 0,
        0, 1, 2, 1, 1, 0, 0, 1, 1, 0, 2, 1, 1, 0, 1, 2, 0, 0, 0, 1, 2, 2, 1, 1, 1, 0, 0, 2, 0, 1,
        1, 2, 1, 1, 2, 1, 0, 1, 2, 2, 2, 2, 1, 2, 0, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1, 1, 2, 0, 2, 1,
        0, 1, 1, 1, 0, 2, 2, 0, 0, 2, 0, 2, 0, 1, 2, 0, 0, 2, 2, 1, 1, 2, 0, 1, 0, 0, 0, 0, 2, 0,
        2, 2, 2,
    ];

    /// Hashes an all-zero single-block message and compares it trit by trit
    /// with the stored `HASH` vector.
    #[test]
    fn test_hash() {
        let input = [0u8; 243];
        let mut output = [0u8; 243];

        let mut troika = Troika::default();
        troika.absorb(&input);
        troika.squeeze(&mut output);

        assert_eq!(&output[..], &HASH[..], "Arrays are not equal");
    }
}