snarkvm_circuit_types_scalar/
lib.rs

1// Copyright (c) 2019-2025 Provable Inc.
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#![forbid(unsafe_code)]
17#![allow(clippy::too_many_arguments)]
18#![cfg_attr(test, allow(clippy::assertions_on_result_states))]
19
20extern crate snarkvm_console_types_scalar as console;
21
22mod helpers;
23
24pub mod add;
25pub mod compare;
26pub mod equal;
27pub mod ternary;
28
29#[cfg(test)]
30use console::{TestRng, Uniform};
31#[cfg(test)]
32use snarkvm_circuit_environment::{assert_count, assert_output_mode, assert_scope, count, output_mode};
33
34use snarkvm_circuit_environment::prelude::*;
35use snarkvm_circuit_types_boolean::Boolean;
36use snarkvm_circuit_types_field::Field;
37
38use std::cell::OnceCell;
39
40#[derive(Clone)]
41pub struct Scalar<E: Environment> {
42    /// The primary representation of the scalar element.
43    field: Field<E>,
44    /// An optional secondary representation in little-endian bits is provided,
45    /// so that calls to `ToBits` only incur constraint costs once.
46    bits_le: OnceCell<Vec<Boolean<E>>>,
47}
48
49impl<E: Environment> ScalarTrait for Scalar<E> {}
50
51impl<E: Environment> Inject for Scalar<E> {
52    type Primitive = console::Scalar<E::Network>;
53
54    /// Initializes a scalar circuit from a console scalar.
55    fn new(mode: Mode, scalar: Self::Primitive) -> Self {
56        debug_assert!(console::Scalar::<E::Network>::size_in_bits() < console::Field::<E::Network>::size_in_bits());
57
58        // Initialize the scalar as a field element.
59        match console::ToField::to_field(&scalar) {
60            Ok(field) => Self { field: Field::new(mode, field), bits_le: OnceCell::new() },
61            Err(error) => E::halt(format!("Unable to initialize a scalar circuit as a field element: {error}")),
62        }
63    }
64}
65
66impl<E: Environment> Eject for Scalar<E> {
67    type Primitive = console::Scalar<E::Network>;
68
69    /// Ejects the mode of the scalar.
70    fn eject_mode(&self) -> Mode {
71        self.field.eject_mode()
72    }
73
74    /// Ejects the scalar circuit as a console scalar.
75    fn eject_value(&self) -> Self::Primitive {
76        match console::Scalar::<E::Network>::from_bits_le(&self.field.eject_value().to_bits_le()) {
77            Ok(scalar) => scalar,
78            Err(error) => E::halt(format!("Failed to eject scalar value: {error}")),
79        }
80    }
81}
82
83impl<E: Environment> Parser for Scalar<E> {
84    /// Parses a string into a scalar circuit.
85    #[inline]
86    fn parse(string: &str) -> ParserResult<Self> {
87        // Parse the scalar from the string.
88        let (string, scalar) = console::Scalar::parse(string)?;
89        // Parse the mode from the string.
90        let (string, mode) = opt(pair(tag("."), Mode::parse))(string)?;
91
92        match mode {
93            Some((_, mode)) => Ok((string, Scalar::new(mode, scalar))),
94            None => Ok((string, Scalar::new(Mode::Constant, scalar))),
95        }
96    }
97}
98
99impl<E: Environment> FromStr for Scalar<E> {
100    type Err = Error;
101
102    /// Parses a string into a scalar circuit.
103    #[inline]
104    fn from_str(string: &str) -> Result<Self> {
105        match Self::parse(string) {
106            Ok((remainder, object)) => {
107                // Ensure the remainder is empty.
108                ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
109                // Return the object.
110                Ok(object)
111            }
112            Err(error) => bail!("Failed to parse string. {error}"),
113        }
114    }
115}
116
117impl<E: Environment> TypeName for Scalar<E> {
118    /// Returns the type name of the circuit as a string.
119    #[inline]
120    fn type_name() -> &'static str {
121        console::Scalar::<E::Network>::type_name()
122    }
123}
124
125impl<E: Environment> Debug for Scalar<E> {
126    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
127        Display::fmt(self, f)
128    }
129}
130
131impl<E: Environment> Display for Scalar<E> {
132    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
133        write!(f, "{}.{}", self.eject_value(), self.eject_mode())
134    }
135}
136
137impl<E: Environment> From<Scalar<E>> for LinearCombination<E::BaseField> {
138    fn from(scalar: Scalar<E>) -> Self {
139        From::from(&scalar)
140    }
141}
142
143impl<E: Environment> From<&Scalar<E>> for LinearCombination<E::BaseField> {
144    fn from(scalar: &Scalar<E>) -> Self {
145        scalar.to_field().into()
146    }
147}
148
149#[cfg(test)]
150mod tests {
151    use super::*;
152    use snarkvm_circuit_environment::Circuit;
153
154    use core::str::FromStr;
155
156    const ITERATIONS: u64 = 250;
157
158    fn check_new(
159        name: &str,
160        expected: console::Scalar<<Circuit as Environment>::Network>,
161        mode: Mode,
162        num_constants: u64,
163        num_public: u64,
164        num_private: u64,
165        num_constraints: u64,
166    ) {
167        Circuit::scope(name, || {
168            let candidate = Scalar::<Circuit>::new(mode, expected);
169            assert_eq!(expected, candidate.eject_value());
170            assert_scope!(num_constants, num_public, num_private, num_constraints);
171        });
172    }
173
174    /// Attempts to construct a field from the given element and mode,
175    /// format it in display mode, and recover a field from it.
176    fn check_display(mode: Mode, element: console::Scalar<<Circuit as Environment>::Network>) -> Result<()> {
177        let candidate = Scalar::<Circuit>::new(mode, element);
178        assert_eq!(format!("{element}.{mode}"), format!("{candidate}"));
179
180        let candidate_recovered = Scalar::<Circuit>::from_str(&format!("{candidate}"))?;
181        assert_eq!(candidate.eject_value(), candidate_recovered.eject_value());
182        Ok(())
183    }
184
185    #[test]
186    fn test_new_constant() {
187        let expected = Uniform::rand(&mut TestRng::default());
188        check_new("Constant", expected, Mode::Constant, 1, 0, 0, 0);
189    }
190
191    #[test]
192    fn test_new_public() {
193        let expected = Uniform::rand(&mut TestRng::default());
194        check_new("Public", expected, Mode::Public, 0, 1, 0, 0);
195    }
196
197    #[test]
198    fn test_new_private() {
199        let expected = Uniform::rand(&mut TestRng::default());
200        check_new("Private", expected, Mode::Private, 0, 0, 1, 0);
201    }
202
203    #[test]
204    fn test_display() -> Result<()> {
205        let mut rng = TestRng::default();
206
207        for _ in 0..ITERATIONS {
208            let element = Uniform::rand(&mut rng);
209
210            // Constant
211            check_display(Mode::Constant, element)?;
212            // Public
213            check_display(Mode::Public, element)?;
214            // Private
215            check_display(Mode::Private, element)?;
216        }
217        Ok(())
218    }
219
220    #[test]
221    fn test_display_zero() {
222        let zero = console::Scalar::<<Circuit as Environment>::Network>::zero();
223
224        // Constant
225        let candidate = Scalar::<Circuit>::new(Mode::Constant, zero);
226        assert_eq!("0scalar.constant", &format!("{candidate}"));
227
228        // Public
229        let candidate = Scalar::<Circuit>::new(Mode::Public, zero);
230        assert_eq!("0scalar.public", &format!("{candidate}"));
231
232        // Private
233        let candidate = Scalar::<Circuit>::new(Mode::Private, zero);
234        assert_eq!("0scalar.private", &format!("{candidate}"));
235    }
236
237    #[test]
238    fn test_display_one() {
239        let one = console::Scalar::<<Circuit as Environment>::Network>::one();
240
241        // Constant
242        let candidate = Scalar::<Circuit>::new(Mode::Constant, one);
243        assert_eq!("1scalar.constant", &format!("{candidate}"));
244
245        // Public
246        let candidate = Scalar::<Circuit>::new(Mode::Public, one);
247        assert_eq!("1scalar.public", &format!("{candidate}"));
248
249        // Private
250        let candidate = Scalar::<Circuit>::new(Mode::Private, one);
251        assert_eq!("1scalar.private", &format!("{candidate}"));
252    }
253
254    #[test]
255    fn test_display_two() {
256        let one = console::Scalar::<<Circuit as Environment>::Network>::one();
257        let two = one + one;
258
259        // Constant
260        let candidate = Scalar::<Circuit>::new(Mode::Constant, two);
261        assert_eq!("2scalar.constant", &format!("{candidate}"));
262
263        // Public
264        let candidate = Scalar::<Circuit>::new(Mode::Public, two);
265        assert_eq!("2scalar.public", &format!("{candidate}"));
266
267        // Private
268        let candidate = Scalar::<Circuit>::new(Mode::Private, two);
269        assert_eq!("2scalar.private", &format!("{candidate}"));
270    }
271
272    #[test]
273    fn test_parser() {
274        type Primitive = console::Scalar<<Circuit as Environment>::Network>;
275
276        // Constant
277
278        let (_, candidate) = Scalar::<Circuit>::parse("5scalar").unwrap();
279        assert_eq!(Primitive::from_str("5scalar").unwrap(), candidate.eject_value());
280        assert!(candidate.is_constant());
281
282        let (_, candidate) = Scalar::<Circuit>::parse("5_scalar").unwrap();
283        assert_eq!(Primitive::from_str("5scalar").unwrap(), candidate.eject_value());
284        assert!(candidate.is_constant());
285
286        let (_, candidate) = Scalar::<Circuit>::parse("1_5_scalar").unwrap();
287        assert_eq!(Primitive::from_str("15scalar").unwrap(), candidate.eject_value());
288        assert!(candidate.is_constant());
289
290        let (_, candidate) = Scalar::<Circuit>::parse("5scalar.constant").unwrap();
291        assert_eq!(Primitive::from_str("5scalar").unwrap(), candidate.eject_value());
292        assert!(candidate.is_constant());
293
294        let (_, candidate) = Scalar::<Circuit>::parse("5_scalar.constant").unwrap();
295        assert_eq!(Primitive::from_str("5scalar").unwrap(), candidate.eject_value());
296        assert!(candidate.is_constant());
297
298        let (_, candidate) = Scalar::<Circuit>::parse("1_5_scalar.constant").unwrap();
299        assert_eq!(Primitive::from_str("15scalar").unwrap(), candidate.eject_value());
300        assert!(candidate.is_constant());
301
302        // Public
303
304        let (_, candidate) = Scalar::<Circuit>::parse("5scalar.public").unwrap();
305        assert_eq!(Primitive::from_str("5scalar").unwrap(), candidate.eject_value());
306        assert!(candidate.is_public());
307
308        let (_, candidate) = Scalar::<Circuit>::parse("5_scalar.public").unwrap();
309        assert_eq!(Primitive::from_str("5scalar").unwrap(), candidate.eject_value());
310        assert!(candidate.is_public());
311
312        let (_, candidate) = Scalar::<Circuit>::parse("1_5_scalar.public").unwrap();
313        assert_eq!(Primitive::from_str("15scalar").unwrap(), candidate.eject_value());
314        assert!(candidate.is_public());
315
316        // Private
317
318        let (_, candidate) = Scalar::<Circuit>::parse("5scalar.private").unwrap();
319        assert_eq!(Primitive::from_str("5scalar").unwrap(), candidate.eject_value());
320        assert!(candidate.is_private());
321
322        let (_, candidate) = Scalar::<Circuit>::parse("5_scalar.private").unwrap();
323        assert_eq!(Primitive::from_str("5scalar").unwrap(), candidate.eject_value());
324        assert!(candidate.is_private());
325
326        let (_, candidate) = Scalar::<Circuit>::parse("1_5_scalar.private").unwrap();
327        assert_eq!(Primitive::from_str("15scalar").unwrap(), candidate.eject_value());
328        assert!(candidate.is_private());
329
330        // Random
331
332        let mut rng = TestRng::default();
333
334        for mode in [Mode::Constant, Mode::Public, Mode::Private] {
335            for _ in 0..ITERATIONS {
336                let value = Uniform::rand(&mut rng);
337                let expected = Scalar::<Circuit>::new(mode, value);
338
339                let (_, candidate) = Scalar::<Circuit>::parse(&format!("{expected}")).unwrap();
340                assert_eq!(expected.eject_value(), candidate.eject_value());
341                assert_eq!(mode, candidate.eject_mode());
342            }
343        }
344    }
345}