snarkvm_circuit_types_scalar/
lib.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#![forbid(unsafe_code)]
17#![allow(clippy::too_many_arguments)]
18#![cfg_attr(test, allow(clippy::assertions_on_result_states))]
19
20mod helpers;
21
22pub mod add;
23pub mod compare;
24pub mod equal;
25pub mod ternary;
26
27#[cfg(test)]
28use console::{TestRng, Uniform};
29#[cfg(test)]
30use snarkvm_circuit_environment::{assert_count, assert_output_mode, assert_scope, count, output_mode};
31
32use snarkvm_circuit_environment::prelude::*;
33use snarkvm_circuit_types_boolean::Boolean;
34use snarkvm_circuit_types_field::Field;
35
/// A circuit representation of an element of the scalar field.
///
/// The value is held in-circuit as a base field element (`Field<E>`). This embedding is
/// lossless because the scalar field modulus is smaller than the base field modulus (see the
/// note in `Inject::new`).
#[derive(Clone)]
pub struct Scalar<E: Environment> {
    /// The primary representation of the scalar element.
    field: Field<E>,
    /// An optional secondary representation in little-endian bits is provided,
    /// so that calls to `ToBits` only incur constraint costs once.
    // NOTE(review): `Inject::new` leaves this cell empty; it is presumably filled lazily by the
    // `ToBits` helper on first use (not visible in this file) — confirm.
    bits_le: OnceCell<Vec<Boolean<E>>>,
}
44
/// Marks `Scalar` as satisfying `ScalarTrait`.
// The impl body is empty, so everything the trait requires must come from supertrait impls or
// default methods — presumably the operator impls in the `add`, `compare`, `equal`, `ternary`,
// and `helpers` modules (TODO confirm).
impl<E: Environment> ScalarTrait for Scalar<E> {}
46
#[cfg(feature = "console")]
impl<E: Environment> Inject for Scalar<E> {
    type Primitive = console::Scalar<E::Network>;

    /// Allocates a scalar circuit in the given `mode` that holds the console value `scalar`.
    ///
    /// The scalar is embedded into a single base field element. This is sound because the
    /// scalar field modulus is smaller than the base field modulus. The environment halts if
    /// the console conversion to a field element fails.
    fn new(mode: Mode, scalar: Self::Primitive) -> Self {
        // The embedding is only lossless while the scalar field is strictly narrower than the base field.
        debug_assert!(console::Scalar::<E::Network>::size_in_bits() < console::Field::<E::Network>::size_in_bits());

        // Reinterpret the scalar as a base field element before allocating it.
        let base_element = console::ToField::to_field(&scalar).unwrap_or_else(|error| {
            E::halt(format!("Unable to initialize a scalar circuit as a field element: {error}"))
        });

        Self { field: Field::new(mode, base_element), bits_le: OnceCell::new() }
    }
}
65
#[cfg(feature = "console")]
impl<E: Environment> Eject for Scalar<E> {
    type Primitive = console::Scalar<E::Network>;

    /// Returns the mode of the scalar, which is the mode of its underlying field element.
    fn eject_mode(&self) -> Mode {
        self.field.eject_mode()
    }

    /// Recovers the console scalar held by this circuit.
    ///
    /// The base field value is decomposed into little-endian bits and re-read as a scalar;
    /// the environment halts if the console conversion rejects those bits.
    fn eject_value(&self) -> Self::Primitive {
        let field_bits_le = self.field.eject_value().to_bits_le();
        console::Scalar::<E::Network>::from_bits_le(&field_bits_le)
            .unwrap_or_else(|error| E::halt(format!("Failed to eject scalar value: {error}")))
    }
}
83
#[cfg(feature = "console")]
impl<E: Environment> Parser for Scalar<E> {
    /// Parses a scalar literal with an optional `.mode` suffix (e.g. `5scalar.private`).
    ///
    /// When no suffix is present, the scalar is injected as a constant.
    #[inline]
    fn parse(string: &str) -> ParserResult<Self> {
        // First the console literal, e.g. `5scalar`.
        let (string, value) = console::Scalar::parse(string)?;
        // Then an optional `.constant` / `.public` / `.private` suffix.
        let (string, suffix) = opt(pair(tag("."), Mode::parse))(string)?;

        let mode = suffix.map_or(Mode::Constant, |(_, mode)| mode);
        Ok((string, Scalar::new(mode, value)))
    }
}
100
#[cfg(feature = "console")]
impl<E: Environment> FromStr for Scalar<E> {
    type Err = Error;

    /// Parses a complete string into a scalar circuit.
    ///
    /// Fails if the parser rejects the input, or if any characters remain after the literal.
    #[inline]
    fn from_str(string: &str) -> Result<Self> {
        let (remainder, object) = match Self::parse(string) {
            Ok(parsed) => parsed,
            Err(error) => bail!("Failed to parse string. {error}"),
        };
        // The whole input must be consumed; trailing characters are an error.
        ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
        Ok(object)
    }
}
119
#[cfg(feature = "console")]
impl<E: Environment> TypeName for Scalar<E> {
    /// Returns the type name of the circuit as a string.
    ///
    /// Delegates to the console scalar type, so the circuit and console types report the same
    /// name (presumably also the literal suffix seen in strings such as `5scalar`).
    #[inline]
    fn type_name() -> &'static str {
        console::Scalar::<E::Network>::type_name()
    }
}
128
#[cfg(feature = "console")]
impl<E: Environment> Debug for Scalar<E> {
    /// Debug output is identical to the `Display` form, e.g. `5scalar.private`.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        <Self as Display>::fmt(self, f)
    }
}
135
#[cfg(feature = "console")]
impl<E: Environment> Display for Scalar<E> {
    /// Formats the scalar as `<console value>.<mode>`, e.g. `0scalar.constant`.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let value = self.eject_value();
        let mode = self.eject_mode();
        write!(f, "{value}.{mode}")
    }
}
142
143impl<E: Environment> From<Scalar<E>> for LinearCombination<E::BaseField> {
144    fn from(scalar: Scalar<E>) -> Self {
145        From::from(&scalar)
146    }
147}
148
149impl<E: Environment> From<&Scalar<E>> for LinearCombination<E::BaseField> {
150    fn from(scalar: &Scalar<E>) -> Self {
151        scalar.to_field().into()
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use snarkvm_circuit_environment::Circuit;

    use core::str::FromStr;

    const ITERATIONS: u64 = 250;

    /// Injects `expected` as a scalar in the given `mode` inside a scope called `name`, checks
    /// that it ejects back to `expected`, and asserts the scope's constant, public, private,
    /// and constraint counts.
    fn check_new(
        name: &str,
        expected: console::Scalar<<Circuit as Environment>::Network>,
        mode: Mode,
        num_constants: u64,
        num_public: u64,
        num_private: u64,
        num_constraints: u64,
    ) {
        Circuit::scope(name, || {
            let candidate = Scalar::<Circuit>::new(mode, expected);
            assert_eq!(expected, candidate.eject_value());
            assert_scope!(num_constants, num_public, num_private, num_constraints);
        });
    }

    /// Constructs a scalar from the given element and mode, checks that it displays as
    /// `{element}.{mode}`, and checks that parsing that display string recovers the same value.
    fn check_display(mode: Mode, element: console::Scalar<<Circuit as Environment>::Network>) -> Result<()> {
        let candidate = Scalar::<Circuit>::new(mode, element);
        assert_eq!(format!("{element}.{mode}"), format!("{candidate}"));

        let candidate_recovered = Scalar::<Circuit>::from_str(&format!("{candidate}"))?;
        assert_eq!(candidate.eject_value(), candidate_recovered.eject_value());
        Ok(())
    }

    #[test]
    fn test_new_constant() {
        let expected = Uniform::rand(&mut TestRng::default());
        check_new("Constant", expected, Mode::Constant, 1, 0, 0, 0);
    }

    #[test]
    fn test_new_public() {
        let expected = Uniform::rand(&mut TestRng::default());
        check_new("Public", expected, Mode::Public, 0, 1, 0, 0);
    }

    #[test]
    fn test_new_private() {
        let expected = Uniform::rand(&mut TestRng::default());
        check_new("Private", expected, Mode::Private, 0, 0, 1, 0);
    }

    #[test]
    fn test_display() -> Result<()> {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            let element = Uniform::rand(&mut rng);

            // Constant
            check_display(Mode::Constant, element)?;
            // Public
            check_display(Mode::Public, element)?;
            // Private
            check_display(Mode::Private, element)?;
        }
        Ok(())
    }

    #[test]
    fn test_display_zero() {
        let zero = console::Scalar::<<Circuit as Environment>::Network>::zero();

        // Constant
        let candidate = Scalar::<Circuit>::new(Mode::Constant, zero);
        assert_eq!("0scalar.constant", &format!("{candidate}"));

        // Public
        let candidate = Scalar::<Circuit>::new(Mode::Public, zero);
        assert_eq!("0scalar.public", &format!("{candidate}"));

        // Private
        let candidate = Scalar::<Circuit>::new(Mode::Private, zero);
        assert_eq!("0scalar.private", &format!("{candidate}"));
    }

    #[test]
    fn test_display_one() {
        let one = console::Scalar::<<Circuit as Environment>::Network>::one();

        // Constant
        let candidate = Scalar::<Circuit>::new(Mode::Constant, one);
        assert_eq!("1scalar.constant", &format!("{candidate}"));

        // Public
        let candidate = Scalar::<Circuit>::new(Mode::Public, one);
        assert_eq!("1scalar.public", &format!("{candidate}"));

        // Private
        let candidate = Scalar::<Circuit>::new(Mode::Private, one);
        assert_eq!("1scalar.private", &format!("{candidate}"));
    }

    #[test]
    fn test_display_two() {
        let one = console::Scalar::<<Circuit as Environment>::Network>::one();
        let two = one + one;

        // Constant
        let candidate = Scalar::<Circuit>::new(Mode::Constant, two);
        assert_eq!("2scalar.constant", &format!("{candidate}"));

        // Public
        let candidate = Scalar::<Circuit>::new(Mode::Public, two);
        assert_eq!("2scalar.public", &format!("{candidate}"));

        // Private
        let candidate = Scalar::<Circuit>::new(Mode::Private, two);
        assert_eq!("2scalar.private", &format!("{candidate}"));
    }

    #[test]
    fn test_parser() {
        type Primitive = console::Scalar<<Circuit as Environment>::Network>;

        // Each case is (input, canonical literal of the expected value, expected mode).
        // A missing `.mode` suffix must default to a constant.
        let cases = [
            // Constant
            ("5scalar", "5scalar", Mode::Constant),
            ("5_scalar", "5scalar", Mode::Constant),
            ("1_5_scalar", "15scalar", Mode::Constant),
            ("5scalar.constant", "5scalar", Mode::Constant),
            ("5_scalar.constant", "5scalar", Mode::Constant),
            ("1_5_scalar.constant", "15scalar", Mode::Constant),
            // Public
            ("5scalar.public", "5scalar", Mode::Public),
            ("5_scalar.public", "5scalar", Mode::Public),
            ("1_5_scalar.public", "15scalar", Mode::Public),
            // Private
            ("5scalar.private", "5scalar", Mode::Private),
            ("5_scalar.private", "5scalar", Mode::Private),
            ("1_5_scalar.private", "15scalar", Mode::Private),
        ];

        for (input, expected, mode) in cases {
            let (remainder, candidate) = Scalar::<Circuit>::parse(input).unwrap();
            // The whole literal (value and suffix) must be consumed.
            assert!(remainder.is_empty(), "Unparsed input remains for {input}: \"{remainder}\"");
            assert_eq!(Primitive::from_str(expected).unwrap(), candidate.eject_value());
            assert_eq!(mode, candidate.eject_mode());
        }

        // Random round-trips through the display form.

        let mut rng = TestRng::default();

        for mode in [Mode::Constant, Mode::Public, Mode::Private] {
            for _ in 0..ITERATIONS {
                let value = Uniform::rand(&mut rng);
                let expected = Scalar::<Circuit>::new(mode, value);

                let (_, candidate) = Scalar::<Circuit>::parse(&format!("{expected}")).unwrap();
                assert_eq!(expected.eject_value(), candidate.eject_value());
                assert_eq!(mode, candidate.eject_mode());
            }
        }
    }

    #[test]
    fn test_from_str_rejects_trailing_characters() {
        // Well-formed inputs are accepted.
        assert!(Scalar::<Circuit>::from_str("5scalar").is_ok());
        assert!(Scalar::<Circuit>::from_str("5scalar.private").is_ok());

        // Anything left over after the literal (and optional mode) must be rejected.
        assert!(Scalar::<Circuit>::from_str("5scalar!").is_err());
        assert!(Scalar::<Circuit>::from_str("5scalar.").is_err());
        assert!(Scalar::<Circuit>::from_str("5scalar.privatex").is_err());
    }
}