use crate::traits::{
algorithms::PRFGadget,
utilities::{
alloc::AllocGadget,
bits::Xor,
boolean::Boolean,
eq::{ConditionalEqGadget, EqGadget},
integer::Integer,
select::CondSelectGadget,
uint::unsigned_integer::{UInt, UInt32, UInt8},
ToBytesGadget,
},
};
use snarkvm_algorithms::prf::Blake2s;
use snarkvm_fields::PrimeField;
use snarkvm_r1cs::{errors::SynthesisError, ConstraintSystem};
use std::borrow::Borrow;
/// First rotation amount of the BLAKE2s mixing function `G` (applied to `v[d]` in `mixing_g`).
/// R1..R4 = 16, 12, 8, 7 are the BLAKE2s rotation constants from RFC 7693.
const R1: usize = 16;
/// Second rotation amount of `G` (applied to `v[b]`).
const R2: usize = 12;
/// Third rotation amount of `G` (applied to `v[d]`).
const R3: usize = 8;
/// Fourth rotation amount of `G` (applied to `v[b]`).
const R4: usize = 7;
/// BLAKE2s message-schedule permutations, one row per round (10 rounds).
///
/// In round `r`, the `k`-th `G` invocation (k = 0..8) consumes message words
/// `m[SIGMA[r][2k]]` and `m[SIGMA[r][2k + 1]]`; see `blake2s_compression`.
const SIGMA: [[usize; 16]; 10] = [
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
[14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3],
[11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4],
[7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8],
[9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13],
[2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9],
[12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11],
[13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10],
[6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5],
[10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0],
];
/// BLAKE2s mixing function `G` (RFC 7693, section 3.1) expressed as constraints.
///
/// Mixes the four state words `v[a]`, `v[b]`, `v[c]`, `v[d]` in place, injecting the
/// message word `x` in the first half and `y` in the second. The eight primitive steps
/// are issued, in order, under the namespaces `"mixing step 1"` .. `"mixing step 8"`.
///
/// # Errors
/// Propagates any [`SynthesisError`] from `addmany` or `xor`; on error `v` may be left
/// partially updated.
#[allow(clippy::many_single_char_names)]
#[allow(clippy::too_many_arguments)]
fn mixing_g<F: PrimeField, CS: ConstraintSystem<F>>(
    mut cs: CS,
    v: &mut [UInt32],
    a: usize,
    b: usize,
    c: usize,
    d: usize,
    x: &UInt32,
    y: &UInt32,
) -> Result<(), SynthesisError> {
    // G is two structurally identical halves that differ only in the message word added
    // to `v[a]` and in the two rotation amounts: (x, R1, R2) first, then (y, R3, R4).
    let halves = [(x, R1, R2), (y, R3, R4)];
    for (half, &(word, rot_d, rot_b)) in halves.iter().enumerate() {
        // Steps 1-4 belong to the first half, 5-8 to the second.
        let step = 4 * half + 1;
        v[a] = UInt32::addmany(
            cs.ns(|| format!("mixing step {}", step)),
            &[v[a].clone(), v[b].clone(), word.clone()],
        )?;
        v[d] = v[d].xor(cs.ns(|| format!("mixing step {}", step + 1)), &v[a])?.rotr(rot_d);
        v[c] = UInt32::addmany(
            cs.ns(|| format!("mixing step {}", step + 2)),
            &[v[c].clone(), v[d].clone()],
        )?;
        v[b] = v[b].xor(cs.ns(|| format!("mixing step {}", step + 3)), &v[c])?.rotr(rot_b);
    }
    Ok(())
}
/// BLAKE2s compression function `F` (RFC 7693, section 3.2) expressed as constraints.
///
/// Updates the 8-word chaining state `h` in place using the 16-word message block `m`.
/// `t` is the byte counter: its low and high 32-bit halves are XORed into `v[12]` and
/// `v[13]`. `f` marks the final block: when set, every bit of `v[14]` is inverted.
///
/// # Panics
/// Panics if `h.len() != 8` or `m.len() != 16`.
///
/// # Errors
/// Propagates any [`SynthesisError`] from the underlying gadgets.
#[allow(clippy::many_single_char_names)]
fn blake2s_compression<F: PrimeField, CS: ConstraintSystem<F>>(
    mut cs: CS,
    h: &mut [UInt32],
    m: &[UInt32],
    t: u64,
    f: bool,
) -> Result<(), SynthesisError> {
    // Initialization vector that fills the lower half of the working vector.
    const IV: [u32; 8] = [
        0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
    ];
    // State indices (a, b, c, d) for the eight G calls of one round:
    // four column mixes followed by four diagonal mixes.
    const LANES: [[usize; 4]; 8] = [
        [0, 4, 8, 12],
        [1, 5, 9, 13],
        [2, 6, 10, 14],
        [3, 7, 11, 15],
        [0, 5, 10, 15],
        [1, 6, 11, 12],
        [2, 7, 8, 13],
        [3, 4, 9, 14],
    ];

    assert_eq!(h.len(), 8);
    assert_eq!(m.len(), 16);

    // Working vector: the current chaining state followed by the IV.
    let mut v: Vec<UInt32> = h
        .iter()
        .cloned()
        .chain(IV.iter().map(|&word| UInt32::constant(word)))
        .collect();
    assert_eq!(v.len(), 16);

    // Mix in the 64-bit byte counter as two 32-bit words (the truncation is intentional).
    let counter_low = UInt32::constant(t as u32);
    let counter_high = UInt32::constant((t >> 32) as u32);
    v[12] = v[12].xor(cs.ns(|| "first xor"), &counter_low)?;
    v[13] = v[13].xor(cs.ns(|| "second xor"), &counter_high)?;
    if f {
        // Final-block flag: XOR with all-ones inverts v[14].
        v[14] = v[14].xor(cs.ns(|| "third xor"), &UInt32::constant(u32::MAX))?;
    }

    // Ten rounds, one per row of SIGMA.
    for (round, sigma) in SIGMA.iter().enumerate() {
        let mut cs = cs.ns(|| format!("round {}", round));
        // Each G call consumes the next two message indices of this round's permutation.
        for (call, (lane, pair)) in LANES.iter().zip(sigma.chunks(2)).enumerate() {
            let [a, b, c, d] = *lane;
            mixing_g(
                cs.ns(|| format!("mixing invocation {}", call + 1)),
                &mut v,
                a,
                b,
                c,
                d,
                &m[pair[0]],
                &m[pair[1]],
            )?;
        }
    }

    // Feed-forward: h[i] ^= v[i] ^ v[i + 8].
    let (low, high) = v.split_at(8);
    for (i, (lo, hi)) in low.iter().zip(high).enumerate() {
        let mut cs = cs.ns(|| format!("h[{0}] ^ v[{0}] ^ v[{0} + 8]", i));
        h[i] = h[i].xor(cs.ns(|| "first xor"), lo)?;
        h[i] = h[i].xor(cs.ns(|| "second xor"), hi)?;
    }
    Ok(())
}
/// Computes a 32-byte-digest, unkeyed BLAKE2s hash of `input` inside the constraint system.
///
/// `input` is a bit string whose length must be a multiple of 8. It is split into 512-bit
/// blocks, and each block into 32-bit words built with `UInt32::from_bits_le`. A trailing
/// partial word or block is zero-padded. Empty input is hashed as a single all-zero block
/// with a byte counter of 0.
///
/// Returns the eight 32-bit words of the final chaining state.
///
/// # Panics
/// Panics if `input.len()` is not a multiple of 8.
///
/// # Errors
/// Propagates any [`SynthesisError`] from the compression gadget.
pub fn blake2s_gadget<F: PrimeField, CS: ConstraintSystem<F>>(
    mut cs: CS,
    input: &[Boolean],
) -> Result<Vec<UInt32>, SynthesisError> {
    assert!(input.len() % 8 == 0);

    // IV with the parameter block folded into word 0:
    // digest length 32, key length 0, fanout 1, depth 1 (0x01010000 ^ 32).
    let initial_state: [u32; 8] = [
        0x6A09E667 ^ 0x01010000 ^ 32,
        0xBB67AE85,
        0x3C6EF372,
        0xA54FF53A,
        0x510E527F,
        0x9B05688C,
        0x1F83D9AB,
        0x5BE0CD19,
    ];
    let mut state: Vec<UInt32> = initial_state.iter().map(|&word| UInt32::constant(word)).collect();

    // Group bits into 16-word blocks, zero-padding the last word and the last block.
    let mut blocks: Vec<Vec<UInt32>> = input
        .chunks(512)
        .map(|block_bits| {
            let mut words: Vec<UInt32> = block_bits
                .chunks(32)
                .map(|word_bits| {
                    let mut padded = word_bits.to_vec();
                    padded.resize(32, Boolean::constant(false));
                    UInt32::from_bits_le(&padded)
                })
                .collect();
            words.resize(16, UInt32::constant(0));
            words
        })
        .collect();
    if blocks.is_empty() {
        blocks.push(vec![UInt32::constant(0); 16]);
    }

    // `blocks` is non-empty here, so the final slice holds exactly one block.
    let (leading_blocks, final_block) = blocks.split_at(blocks.len() - 1);
    for (i, block) in leading_blocks.iter().enumerate() {
        // Every non-final block is a full 64 bytes, so the counter is a multiple of 64.
        let bytes_consumed = (i as u64 + 1) * 64;
        blake2s_compression(cs.ns(|| format!("block {}", i)), &mut state, block, bytes_consumed, false)?;
    }
    let total_bytes = (input.len() / 8) as u64;
    blake2s_compression(cs.ns(|| "final block"), &mut state, &final_block[0], total_bytes, true)?;
    Ok(state)
}
/// Zero-sized type implementing `PRFGadget<Blake2s, F>`; its output is BLAKE2s(seed || input)
/// computed in-circuit.
pub struct Blake2sGadget;
/// In-circuit BLAKE2s output, held as a vector of `UInt8` byte gadgets.
///
/// Values created through the `AllocGadget<[u8; 32], _>` impl always hold 32 bytes. The
/// field is public, so values built directly may hold any length.
#[derive(Clone, Debug)]
pub struct Blake2sOutputGadget(pub Vec<UInt8>);
/// Native equality: compares the two byte-gadget vectors element-wise without touching any
/// constraint system.
impl PartialEq for Blake2sOutputGadget {
    fn eq(&self, other: &Self) -> bool {
        self.0.as_slice() == other.0.as_slice()
    }
}
/// Marker impl: equality is the `PartialEq` comparison of the inner byte vectors.
impl Eq for Blake2sOutputGadget {}
/// Relies entirely on the default `EqGadget` methods; nothing is overridden here.
// NOTE(review): presumably those defaults delegate to this type's `ConditionalEqGadget` impl — confirm against the trait.
impl<F: PrimeField> EqGadget<F> for Blake2sOutputGadget {}
impl<F: PrimeField> ConditionalEqGadget<F> for Blake2sOutputGadget {
    /// Enforces byte-wise equality of `self` and `other` whenever `condition` is true.
    ///
    /// Each byte pair is constrained under its own namespace `blake2s_equal_{i}`.
    ///
    /// # Panics
    /// Panics if the two outputs hold a different number of bytes. Iterating them with
    /// `zip` would otherwise stop at the shorter one and leave the remaining bytes
    /// unconstrained, silently weakening the equality check.
    ///
    /// # Errors
    /// Propagates any [`SynthesisError`] from the per-byte equality gadget.
    #[inline]
    fn conditional_enforce_equal<CS: ConstraintSystem<F>>(
        &self,
        mut cs: CS,
        other: &Self,
        condition: &Boolean,
    ) -> Result<(), SynthesisError> {
        // A length mismatch is a caller bug, and `zip` would hide it as a partial check.
        assert_eq!(
            self.0.len(),
            other.0.len(),
            "Blake2sOutputGadget::conditional_enforce_equal: outputs have different lengths"
        );
        for (i, (a, b)) in self.0.iter().zip(other.0.iter()).enumerate() {
            a.conditional_enforce_equal(&mut cs.ns(|| format!("blake2s_equal_{}", i)), b, condition)?;
        }
        Ok(())
    }

    /// Constraint cost of comparing two 32-byte outputs.
    fn cost() -> usize {
        32 * <UInt8 as ConditionalEqGadget<F>>::cost()
    }
}
/// Conditional selection is not supported for `Blake2sOutputGadget`.
///
/// Both methods are stubs that panic via `unimplemented!()` if called.
impl<F: PrimeField> CondSelectGadget<F> for Blake2sOutputGadget {
/// Not implemented: always panics.
fn conditionally_select<CS: ConstraintSystem<F>>(
_cs: CS,
_cond: &Boolean,
_first: &Self,
_second: &Self,
) -> Result<Self, SynthesisError> {
// NOTE(review): a byte-wise select over `.0` would presumably work if `UInt8`
// implements `CondSelectGadget` — confirm before implementing.
unimplemented!()
}
/// Not implemented: always panics.
fn cost() -> usize {
unimplemented!()
}
}
impl<F: PrimeField> ToBytesGadget<F> for Blake2sOutputGadget {
    /// Returns a copy of the stored byte gadgets; the constraint system is not used.
    #[inline]
    fn to_bytes<CS: ConstraintSystem<F>>(&self, _cs: CS) -> Result<Vec<UInt8>, SynthesisError> {
        Ok(self.0.to_vec())
    }

    /// Same as `to_bytes`: the bytes are already the canonical representation.
    #[inline]
    fn to_bytes_strict<CS: ConstraintSystem<F>>(&self, cs: CS) -> Result<Vec<UInt8>, SynthesisError> {
        <Self as ToBytesGadget<F>>::to_bytes(self, cs)
    }
}
impl<F: PrimeField> AllocGadget<[u8; 32], F> for Blake2sOutputGadget {
#[inline]
fn alloc<Fn: FnOnce() -> Result<T, SynthesisError>, T: Borrow<[u8; 32]>, CS: ConstraintSystem<F>>(
cs: CS,
value_gen: Fn,
) -> Result<Self, SynthesisError> {
Ok(Blake2sOutputGadget(<UInt8>::alloc_vec(cs, &match value_gen() {
Ok(val) => *(val.borrow()),
Err(_) => [0u8; 32],
})?))
}
#[inline]
fn alloc_input<Fn: FnOnce() -> Result<T, SynthesisError>, T: Borrow<[u8; 32]>, CS: ConstraintSystem<F>>(
cs: CS,
value_gen: Fn,
) -> Result<Self, SynthesisError> {
Ok(Blake2sOutputGadget(<UInt8>::alloc_input_vec_le(
cs,
&match value_gen() {
Ok(val) => *(val.borrow()),
Err(_) => [0u8; 32],
},
)?))
}
}
/// BLAKE2s-based PRF gadget: the output is BLAKE2s(seed || input), computed in-circuit.
impl<F: PrimeField> PRFGadget<Blake2s, F> for Blake2sGadget {
    type OutputGadget = Blake2sOutputGadget;

    /// Allocates the seed bytes with `UInt8::alloc_vec` under the `"alloc_seed"` namespace.
    ///
    /// # Panics
    /// Panics if allocation fails; this trait method has no way to return an error.
    fn new_seed<CS: ConstraintSystem<F>>(mut cs: CS, seed: &[u8; 32]) -> Vec<UInt8> {
        let seed_cs = &mut cs.ns(|| "alloc_seed");
        UInt8::alloc_vec(seed_cs, seed).unwrap()
    }

    /// Evaluates the PRF on `seed || input` and returns the digest as byte gadgets.
    ///
    /// # Panics
    /// Panics if `seed` is not exactly 32 bytes long.
    ///
    /// # Errors
    /// Propagates any [`SynthesisError`] from hashing or byte conversion.
    fn check_evaluation_gadget<CS: ConstraintSystem<F>>(
        mut cs: CS,
        seed: &[UInt8],
        input: &[UInt8],
    ) -> Result<Self::OutputGadget, SynthesisError> {
        assert_eq!(seed.len(), 32);

        // Bit-decompose seed || input byte by byte, little-endian bit order within each byte.
        let mut message_bits = Vec::new();
        for byte in seed.iter().chain(input.iter()) {
            message_bits.extend_from_slice(&byte.to_bits_le());
        }

        let digest_words = blake2s_gadget(cs.ns(|| "blake2s_prf"), &message_bits)?;
        let mut digest_bytes = Vec::new();
        for (i, word) in digest_words.into_iter().enumerate() {
            let bytes = word.to_bytes(&mut cs.ns(|| format!("to_bytes_{}", i)))?;
            digest_bytes.extend_from_slice(&bytes);
        }
        Ok(Blake2sOutputGadget(digest_bytes))
    }
}