troika_rust/
troika.rs

1/*
2 * Copyright (C) 2019 Yu-Wei Wu
3 * All Rights Reserved.
4 * This is free software; you can redistribute it and/or modify it under the
5 * terms of the MIT license. A copy of the license can be found in the file
6 * "LICENSE" at the root of this distribution.
7 */
8
9use super::macros::{
10    Trit, COLUMNS, NUM_ROUNDS, NUM_SBOXES, PADDING, ROUND_CONSTANTS, ROWS, SBOX_LOOKUP,
11    SHIFT_ROWS_LANES, SLICES, SLICESIZE, STATE_SIZE, TROIKA_RATE,
12};
13use crate::Result;
14use core::fmt;
15
16/// The Troika struct is a Sponge that uses the Troika
17/// hashing algorithm.
18/// ```rust
19/// extern crate troika_rust;
20/// use troika_rust::Troika;
21/// // Create an array of 243 1s
22/// let input = [1; 243];
23/// // Create an array of 243 0s
24/// let mut out = [0; 243];
25/// let mut troika = Troika::default();
26/// troika.absorb(&input);
27/// troika.squeeze(&mut out);
28/// ```
29#[derive(Clone, Copy)]
30pub struct Troika {
31    num_rounds: usize,
32    state: [Trit; STATE_SIZE],
33}
34
35impl Default for Troika {
36    fn default() -> Troika {
37        Troika {
38            num_rounds: NUM_ROUNDS,
39            state: [0u8; STATE_SIZE],
40        }
41    }
42}
43
44impl fmt::Debug for Troika {
45    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
46        write!(
47            f,
48            "Troika: [rounds: [{}], state: {:?}",
49            self.num_rounds,
50            &self.state[..],
51        )
52    }
53}
54
55impl Troika {
56    pub fn new(num_rounds: usize) -> Result<Troika> {
57        let mut troika = Troika::default();
58        troika.num_rounds = num_rounds;
59        Ok(troika)
60    }
61
62    pub fn state(&self) -> &[Trit] {
63        &self.state
64    }
65
66    pub fn reset(&mut self) {
67        self.state = [0; STATE_SIZE];
68    }
69
70    pub fn absorb(&mut self, message: &[Trit]) {
71        let mut message_length = message.len();
72        let mut message_idx = 0;
73        let mut trit_idx = 0;
74
75        while message_length >= TROIKA_RATE {
76            // Copy message block over the state
77            for trit_idx in 0..TROIKA_RATE {
78                self.state[trit_idx] = message[message_idx + trit_idx];
79            }
80            self.permutation();
81            message_length -= TROIKA_RATE;
82            message_idx += TROIKA_RATE;
83        }
84
85        // Pad last block
86        let mut last_block = [0u8; TROIKA_RATE];
87
88        // Copy over last incomplete message block
89        for _ in 0..message_length {
90            last_block[trit_idx] = message[trit_idx];
91            trit_idx += 1;
92        }
93
94        // TODO: Check trit_idx is right here
95        // Apply padding
96        last_block[trit_idx] = PADDING;
97
98        // Insert last message block
99        for trit_idx in 0..TROIKA_RATE {
100            self.state[trit_idx] = last_block[trit_idx];
101        }
102    }
103
104    pub fn squeeze(&mut self, hash: &mut [Trit]) {
105        let mut hash_length = hash.len();
106        let mut hash_idx = 0;
107
108        while hash_length >= TROIKA_RATE {
109            self.permutation();
110            // Extract rate output
111            for trit_idx in 0..TROIKA_RATE {
112                hash[hash_idx + trit_idx] = self.state[trit_idx];
113            }
114            hash_idx += TROIKA_RATE;
115            hash_length -= TROIKA_RATE;
116        }
117
118        // Check if there is a last incomplete block
119        if hash_length % TROIKA_RATE != 0 {
120            self.permutation();
121            for trit_idx in 0..hash_length {
122                hash[trit_idx] = self.state[trit_idx];
123            }
124        }
125    }
126
127    pub fn permutation(&mut self) {
128        assert!(self.num_rounds <= NUM_ROUNDS);
129
130        for round in 0..self.num_rounds {
131            self.sub_trytes();
132            self.shift_rows_lanes();
133            self.add_column_parity();
134            self.add_round_constant(round);
135        }
136    }
137
138    fn sub_trytes(&mut self) {
139        for sbox_idx in 0..NUM_SBOXES {
140            let sbox_input = 9 * self.state[3 * sbox_idx]
141                + 3 * self.state[3 * sbox_idx + 1]
142                + self.state[3 * sbox_idx + 2];
143            let mut sbox_output = SBOX_LOOKUP[sbox_input as usize];
144            self.state[3 * sbox_idx + 2] = sbox_output % 3;
145            sbox_output /= 3;
146            self.state[3 * sbox_idx + 1] = sbox_output % 3;
147            sbox_output /= 3;
148            self.state[3 * sbox_idx] = sbox_output % 3;
149        }
150    }
151
152    fn shift_rows_lanes(&mut self) {
153        let mut new_state = [0u8; STATE_SIZE];
154        for i in 0..STATE_SIZE {
155            new_state[i] = self.state[SHIFT_ROWS_LANES[i]];
156        }
157
158        self.state = new_state;
159    }
160
161    fn add_column_parity(&mut self) {
162        let mut parity = [0u8; SLICES * COLUMNS];
163
164        // First compute parity for each column
165        for slice in 0..SLICES {
166            for col in 0..COLUMNS {
167                let mut col_sum = 0;
168                for row in 0..ROWS {
169                    col_sum += self.state[SLICESIZE * slice + COLUMNS * row + col];
170                }
171                parity[COLUMNS * slice + col] = col_sum % 3;
172            }
173        }
174
175        // Add parity
176        for slice in 0..SLICES {
177            for row in 0..ROWS {
178                for col in 0..COLUMNS {
179                    let idx = SLICESIZE * slice + COLUMNS * row + col;
180                    let sum_to_add = parity[(col + 8) % 9 + COLUMNS * slice]
181                        + parity[(col + 1) % 9 + COLUMNS * ((slice + 1) % SLICES)];
182                    self.state[idx] = (self.state[idx] + sum_to_add) % 3;
183                }
184            }
185        }
186    }
187
188    fn add_round_constant(&mut self, round: usize) {
189        for slice in 0..SLICES {
190            for col in 0..COLUMNS {
191                let idx = SLICESIZE * slice + col;
192                self.state[idx] =
193                    (self.state[idx] + ROUND_CONSTANTS[round][slice * COLUMNS + col]) % 3;
194            }
195        }
196    }
197}
198
#[cfg(test)]
mod test_troika {
    use super::*;

    /// Expected 243-trit output of `Troika::default()` after absorbing 243
    /// zero trits and squeezing 243 trits (checked by `test_hash`).
    const HASH: [u8; 243] = [
        0, 2, 2, 1, 2, 1, 0, 1, 2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 0, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 2,
        1, 1, 1, 0, 1, 0, 2, 1, 0, 0, 0, 1, 2, 0, 2, 1, 0, 0, 2, 1, 1, 1, 1, 1, 2, 0, 1, 0, 2, 1,
        1, 2, 0, 1, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 2, 2, 0, 0, 2, 2, 2, 1, 2, 2, 0, 2, 1, 1, 2, 1,
        1, 1, 2, 2, 1, 1, 0, 0, 0, 2, 2, 2, 0, 2, 1, 1, 1, 1, 0, 0, 1, 0, 2, 0, 2, 0, 2, 0, 0, 0,
        0, 1, 1, 1, 0, 2, 1, 1, 1, 0, 2, 0, 0, 1, 0, 1, 0, 2, 0, 2, 2, 0, 0, 2, 2, 0, 1, 2, 1, 0,
        0, 1, 2, 1, 1, 0, 0, 1, 1, 0, 2, 1, 1, 0, 1, 2, 0, 0, 0, 1, 2, 2, 1, 1, 1, 0, 0, 2, 0, 1,
        1, 2, 1, 1, 2, 1, 0, 1, 2, 2, 2, 2, 1, 2, 0, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1, 1, 2, 0, 2, 1,
        0, 1, 1, 1, 0, 2, 2, 0, 0, 2, 0, 2, 0, 1, 2, 0, 0, 2, 2, 1, 1, 2, 0, 1, 0, 0, 0, 0, 2, 0,
        2, 2, 2,
    ];

    #[test]
    fn test_hash() {
        let mut troika = Troika::default();
        let mut output = [0u8; 243];
        let input = [0u8; 243];
        troika.absorb(&input);
        troika.squeeze(&mut output);

        // Compare as slices: on failure this prints both sequences (and any
        // length mismatch) instead of only a generic message.
        assert_eq!(&output[..], &HASH[..], "Arrays are not equal");
    }
}