sputnikvm_precompiled_modexp/
lib.rs

1extern crate bigint;
2extern crate num_bigint;
3extern crate sputnikvm;
4
5#[cfg(test)]
6extern crate hexutil;
7
8use std::rc::Rc;
9use bigint::{Gas, U256};
10
11use sputnikvm::Precompiled;
12use sputnikvm::errors::{OnChainError, RuntimeError, NotSupportedError};
13
14pub static MODEXP_PRECOMPILED: ModexpPrecompiled = ModexpPrecompiled;
15
16pub struct ModexpPrecompiled;
17impl Precompiled for ModexpPrecompiled {
18    fn gas_and_step(&self, data: &[u8], gas_limit: Gas) -> Result<(Gas, Rc<Vec<u8>>), RuntimeError> {
19        use std::cmp;
20        use num_bigint::BigUint;
21
22        fn adjusted_exponent_length(exponent_length: U256, base_length: U256, data: &[u8]) -> U256 {
23            let mut exp32_arr = Vec::new();
24            for i in 0..32 {
25                if U256::from(96) + base_length + U256::from(i) >= U256::from(data.len()) {
26                    exp32_arr.push(0u8);
27                } else {
28                    let base_length_usize: usize = base_length.as_usize();
29                    let data_i: usize = 96 + base_length_usize + i;
30                    exp32_arr.push(data[data_i]);
31                }
32            }
33            let exp32 = U256::from(exp32_arr.as_slice());
34
35            if exponent_length <= U256::from(32) && exp32 == U256::zero() {
36                U256::zero()
37            } else if exponent_length <= U256::from(32) {
38                U256::from(exp32.bits())
39            } else {
40                U256::from(8) * (exponent_length - U256::from(32)) + U256::from(exp32.bits())
41            }
42        }
43
44        fn mult_complexity(x: U256) -> Result<U256, RuntimeError> {
45            if x <= U256::from(64) {
46                Ok(x * x)
47            } else if x <= U256::from(1024) {
48                Ok(x * x / U256::from(4) + U256::from(96) * x - U256::from(3072))
49            } else {
50                let (sqr, o) = x.overflowing_mul(x);
51                if o {
52                    Err(RuntimeError::OnChain(OnChainError::EmptyGas))
53                } else {
54                    Ok(sqr / U256::from(16) + U256::from(480) * x - U256::from(199680))
55                }
56            }
57        }
58
59        // Padding data to be at least 32 * 3 bytes.
60        let mut data: Vec<u8> = data.into();
61        while data.len() < 32 * 3 {
62            data.push(0);
63        }
64
65        let base_length = U256::from(&data[0..32]);
66        let exponent_length = U256::from(&data[32..64]);
67        let modulus_length = U256::from(&data[64..96]);
68
69        let op1 = mult_complexity(cmp::max(modulus_length, base_length))?;
70        let op2 = cmp::max(adjusted_exponent_length(exponent_length, base_length, &data), U256::from(1)) / U256::from(20);
71        let (r, o) = op1.overflowing_mul(op2);
72        if o {
73            return Err(RuntimeError::OnChain(OnChainError::EmptyGas));
74        }
75        let gas: Gas = r.into();
76
77        if gas > gas_limit {
78            return Err(RuntimeError::OnChain(OnChainError::EmptyGas));
79        }
80
81        if base_length > U256::from(usize::max_value()) ||
82            exponent_length > U256::from(usize::max_value()) ||
83            modulus_length > U256::from(usize::max_value())
84        {
85            return Err(RuntimeError::NotSupported(NotSupportedError::MemoryIndexNotSupported));
86        }
87
88        let base_length: usize = base_length.as_usize();
89        let exponent_length: usize = exponent_length.as_usize();
90        let modulus_length: usize = modulus_length.as_usize();
91
92        let mut base_arr = Vec::new();
93        let mut exponent_arr = Vec::new();
94        let mut modulus_arr = Vec::new();
95
96        for i in 0..base_length {
97            if 96 + i >= data.len() {
98                base_arr.push(0u8);
99            } else {
100                base_arr.push(data[96 + i]);
101            }
102        }
103        for i in 0..exponent_length {
104            if 96 + base_length + i >= data.len() {
105                exponent_arr.push(0u8);
106            } else {
107                exponent_arr.push(data[96 + base_length + i]);
108            }
109        }
110        for i in 0..modulus_length {
111            if 96 + base_length + exponent_length + i >= data.len() {
112                modulus_arr.push(0u8);
113            } else {
114                modulus_arr.push(data[96 + base_length + exponent_length + i]);
115            }
116        }
117
118        let base = BigUint::from_bytes_be(&base_arr);
119        let exponent = BigUint::from_bytes_be(&exponent_arr);
120        let modulus = BigUint::from_bytes_be(&modulus_arr);
121
122        let mut result = base.modpow(&exponent, &modulus).to_bytes_be();
123        assert!(result.len() <= modulus_length);
124        while result.len() < modulus_length {
125            result.insert(0, 0u8);
126        }
127
128        Ok((gas, Rc::new(result)))
129    }
130}
131
#[cfg(test)]
mod tests {
    use ::*;
    use hexutil::*;

    /// Runs the precompile on `input_hex` with a 10,000,000 gas limit and
    /// checks that it succeeds with exactly the bytes given by `expected_hex`.
    fn assert_modexp(input_hex: &str, expected_hex: &str) {
        let input = read_hex(input_hex).unwrap();
        let (_, output) = MODEXP_PRECOMPILED.gas_and_step(&input, Gas::from(10000000usize)).unwrap();
        let expected = read_hex(expected_hex).unwrap();
        assert_eq!(expected, *output);
    }

    // 3 ^ (p - 2) mod p == 1 for the 32-byte modulus p.
    #[test]
    fn spec_test1() {
        assert_modexp(
            "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000002003fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2efffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f",
            "0000000000000000000000000000000000000000000000000000000000000001",
        );
    }

    // A zero-length base is read as 0, so the result is 0.
    #[test]
    fn spec_test2() {
        assert_modexp(
            "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000020fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2efffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f",
            "0000000000000000000000000000000000000000000000000000000000000000",
        );
    }

    // A modulus length of 2^256 - 1 cannot be priced and must be rejected.
    #[test]
    fn spec_test3() {
        let input = read_hex("00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000020fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffd").unwrap();
        assert!(MODEXP_PRECOMPILED.gas_and_step(&input, Gas::from(10000000usize)).is_err());
    }

    // Input longer than the declared operands: the trailing byte is ignored.
    #[test]
    fn spec_test4() {
        assert_modexp(
            "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000002003ffff800000000000000000000000000000000000000000000000000000000000000007",
            "3b01b01ac41f2d6e917c6d6a221ce793802469026d9ab7578fa2e79e4da6aaab",
        );
    }

    // Truncated input: the missing modulus bytes are read as zeros.
    #[test]
    fn spec_test5() {
        assert_modexp(
            "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000002003ffff80",
            "3b01b01ac41f2d6e917c6d6a221ce793802469026d9ab7578fa2e79e4da6aaab",
        );
    }
}