winter_math/utils/mod.rs
1// Copyright (c) Facebook, Inc. and its affiliates.
2//
3// This source code is licensed under the MIT license found in the
4// LICENSE file in the root directory of this source tree.
5
6use alloc::vec::Vec;
7
8#[cfg(feature = "concurrent")]
9use utils::iterators::*;
10use utils::{batch_iter_mut, iter_mut, uninit_vector};
11
12use crate::{field::FieldElement, ExtensionOf};
13
14// MATH FUNCTIONS
15// ================================================================================================
16
17/// Returns a vector containing successive powers of a given base.
18///
19/// More precisely, for base `b`, generates a vector with values [1, b, b^2, b^3, ..., b^(n-1)].
20///
21/// When `concurrent` feature is enabled, series generation is done concurrently in multiple
22/// threads.
23///
24/// # Examples
25/// ```
26/// # use winter_math::get_power_series;
27/// # use winter_math::{fields::{f128::BaseElement}, FieldElement};
28/// let n = 2048;
29/// let b = BaseElement::from(3u8);
30///
31/// let expected = (0..n).map(|p| b.exp((p as u64).into())).collect::<Vec<_>>();
32///
33/// let actual = get_power_series(b, n);
34/// assert_eq!(expected, actual);
35/// ```
36pub fn get_power_series<E>(b: E, n: usize) -> Vec<E>
37where
38 E: FieldElement,
39{
40 let mut result = unsafe { uninit_vector(n) };
41 batch_iter_mut!(&mut result, 1024, |batch: &mut [E], batch_offset: usize| {
42 let start = b.exp((batch_offset as u64).into());
43 fill_power_series(batch, b, start);
44 });
45 result
46}
47
48/// Returns a vector containing successive powers of a given base offset by the specified value.
49///
50/// More precisely, for base `b` and offset `s`, generates a vector with values
51/// [s, s * b, s * b^2, s * b^3, ..., s * b^(n-1)].
52///
53/// When `concurrent` feature is enabled, series generation is done concurrently in multiple
54/// threads.
55///
56/// # Examples
57/// ```
58/// # use winter_math::get_power_series_with_offset;
59/// # use winter_math::{fields::{f128::BaseElement}, FieldElement};
60/// let n = 2048;
61/// let b = BaseElement::from(3u8);
62/// let s = BaseElement::from(7u8);
63///
64/// let expected = (0..n).map(|p| s * b.exp((p as u64).into())).collect::<Vec<_>>();
65///
66/// let actual = get_power_series_with_offset(b, s, n);
67/// assert_eq!(expected, actual);
68/// ```
69pub fn get_power_series_with_offset<E>(b: E, s: E, n: usize) -> Vec<E>
70where
71 E: FieldElement,
72{
73 let mut result = unsafe { uninit_vector(n) };
74 batch_iter_mut!(&mut result, 1024, |batch: &mut [E], batch_offset: usize| {
75 let start = s * b.exp((batch_offset as u64).into());
76 fill_power_series(batch, b, start);
77 });
78 result
79}
80
81/// Computes element-wise sum of the provided vectors, and stores the result in the first vector.
82///
83/// When `concurrent` feature is enabled, the summation is performed concurrently in multiple
84/// threads.
85///
86/// # Panics
87/// Panics if lengths of `a` and `b` vectors are not the same.
88///
89/// # Examples
90/// ```
91/// # use winter_math::add_in_place;
92/// # use winter_math::{fields::{f128::BaseElement}, FieldElement};
93/// # use rand_utils::rand_vector;
94/// let a: Vec<BaseElement> = rand_vector(2048);
95/// let b: Vec<BaseElement> = rand_vector(2048);
96///
97/// let mut c = a.clone();
98/// add_in_place(&mut c, &b);
99///
100/// for ((a, b), c) in a.into_iter().zip(b).zip(c) {
101/// assert_eq!(a + b, c);
102/// }
103/// ```
104pub fn add_in_place<E>(a: &mut [E], b: &[E])
105where
106 E: FieldElement,
107{
108 assert!(a.len() == b.len(), "number of values must be the same for both operands");
109 iter_mut!(a).zip(b).for_each(|(a, &b)| *a += b);
110}
111
112/// Multiplies a sequence of values by a scalar and accumulates the results.
113///
114/// More precisely, computes `a[i]` + `b[i]` * `c` for all `i` and saves result into `a[i]`.
115///
116/// When `concurrent` feature is enabled, the computation is performed concurrently in multiple
117/// threads.
118///
119/// # Panics
120/// Panics if lengths of `a` and `b` slices are not the same.
121///
122/// # Examples
123/// ```
124/// # use winter_math::mul_acc;
125/// # use winter_math::{fields::{f128::BaseElement}, FieldElement};
126/// # use rand_utils::rand_vector;
127/// let a: Vec<BaseElement> = rand_vector(2048);
128/// let b: Vec<BaseElement> = rand_vector(2048);
129/// let c = BaseElement::new(12345);
130///
131/// let mut d = a.clone();
132/// mul_acc(&mut d, &b, c);
133///
134/// for ((a, b), d) in a.into_iter().zip(b).zip(d) {
135/// assert_eq!(a + b * c, d);
136/// }
137/// ```
138pub fn mul_acc<F, E>(a: &mut [E], b: &[F], c: E)
139where
140 F: FieldElement,
141 E: FieldElement<BaseField = F::BaseField> + ExtensionOf<F>,
142{
143 assert!(a.len() == b.len(), "number of values must be the same for both slices");
144 iter_mut!(a).zip(b).for_each(|(a, &b)| *a += c.mul_base(b));
145}
146
147/// Computes a multiplicative inverse of a sequence of elements using batch inversion method.
148///
149/// Any ZEROs in the provided sequence are ignored.
150///
151/// When `concurrent` feature is enabled, the inversion is performed concurrently in multiple
152/// threads.
153///
154/// This function is significantly faster than inverting elements one-by-one because it
155/// essentially transforms `n` inversions into `3 * n` multiplications + 1 inversion.
156///
157/// # Examples
158/// ```
159/// # use winter_math::batch_inversion;
160/// # use winter_math::{fields::{f128::BaseElement}, FieldElement};
161/// # use rand_utils::rand_vector;
162/// let a: Vec<BaseElement> = rand_vector(2048);
163/// let b = batch_inversion(&a);
164///
165/// for (&a, &b) in a.iter().zip(b.iter()) {
166/// assert_eq!(a.inv(), b);
167/// }
168/// ```
169pub fn batch_inversion<E>(values: &[E]) -> Vec<E>
170where
171 E: FieldElement,
172{
173 let mut result: Vec<E> = unsafe { uninit_vector(values.len()) };
174 batch_iter_mut!(&mut result, 1024, |batch: &mut [E], batch_offset: usize| {
175 let start = batch_offset;
176 let end = start + batch.len();
177 serial_batch_inversion(&values[start..end], batch);
178 });
179 result
180}
181
182// HELPER FUNCTIONS
183// ------------------------------------------------------------------------------------------------
184
185#[inline(always)]
186fn fill_power_series<E: FieldElement>(result: &mut [E], base: E, start: E) {
187 result[0] = start;
188 for i in 1..result.len() {
189 result[i] = result[i - 1] * base;
190 }
191}
192
193fn serial_batch_inversion<E: FieldElement>(values: &[E], result: &mut [E]) {
194 let mut last = E::ONE;
195 for (result, &value) in result.iter_mut().zip(values.iter()) {
196 *result = last;
197 if value != E::ZERO {
198 last *= value;
199 }
200 }
201
202 last = last.inv();
203
204 for i in (0..values.len()).rev() {
205 if values[i] == E::ZERO {
206 result[i] = E::ZERO;
207 } else {
208 result[i] *= last;
209 last *= values[i];
210 }
211 }
212}