1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
// Copyright (C) 2019-2023 Aleo Systems Inc.
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::traits::{FftParameters, Field};

/// The interface for fields that are able to be used in FFTs.
pub trait FftField: Field + From<u128> + From<u64> + From<u32> + From<u16> + From<u8> {
    /// Supplies the constants (`TWO_ADICITY`, `SMALL_SUBGROUP_BASE`,
    /// `SMALL_SUBGROUP_BASE_ADICITY`) that `get_root_of_unity` reads.
    type FftParameters: FftParameters;

    /// Returns the 2^s root of unity.
    fn two_adic_root_of_unity() -> Self;

    /// Returns the 2^s * small_subgroup_base^small_subgroup_base_adicity root of unity
    /// if a small subgroup is defined.
    fn large_subgroup_root_of_unity() -> Option<Self>;

    /// Returns the multiplicative generator of `char()` - 1 order.
    fn multiplicative_generator() -> Self;

    /// Returns the root of unity of order n, if one exists.
    ///
    /// If no small multiplicative subgroup is defined, this is the 2-adic root of unity of
    /// order n (for n a power of 2).
    /// If a small multiplicative subgroup is defined, this is the root of unity of order n for
    /// the larger subgroup generated by `large_subgroup_root_of_unity()`
    /// (for n = 2^i * FftParameters::SMALL_SUBGROUP_BASE^j for some i, j).
    ///
    /// Returns `None` when `n` does not have the required shape, or when it needs more
    /// 2-adicity (or small-subgroup adicity) than `FftParameters` provides.
    ///
    /// # Panics
    ///
    /// Panics if `large_subgroup_root_of_unity()` returns `Some` while
    /// `FftParameters::SMALL_SUBGROUP_BASE` or `FftParameters::SMALL_SUBGROUP_BASE_ADICITY`
    /// is `None`.
    fn get_root_of_unity(n: usize) -> Option<Self> {
        let mut omega: Self;
        if let Some(large_subgroup_root_of_unity) = Self::large_subgroup_root_of_unity() {
            let q = Self::FftParameters::SMALL_SUBGROUP_BASE
                .expect("LARGE_SUBGROUP_ROOT_OF_UNITY should only be set in conjunction with SMALL_SUBGROUP_BASE")
                as usize;
            let small_subgroup_base_adicity = Self::FftParameters::SMALL_SUBGROUP_BASE_ADICITY.expect(
                "LARGE_SUBGROUP_ROOT_OF_UNITY should only be set in conjunction with SMALL_SUBGROUP_BASE_ADICITY",
            );

            // Split n into its q-part and its 2-part.
            let q_adicity = Self::k_adicity(q, n);
            let q_part = q.pow(q_adicity);

            let two_adicity = Self::k_adicity(2, n);
            let two_part = 1usize << two_adicity;

            // `checked_mul` guards the case where the two parts overlap (an even base) and
            // their product no longer fits in a `usize`; such a product can never equal n.
            if two_part.checked_mul(q_part) != Some(n)
                || (two_adicity > Self::FftParameters::TWO_ADICITY)
                || (q_adicity > small_subgroup_base_adicity)
            {
                return None;
            }

            // Start from the generator of the large subgroup and strip the surplus factors
            // of q, then the surplus factors of 2, from its order.
            omega = large_subgroup_root_of_unity;
            for _ in q_adicity..small_subgroup_base_adicity {
                omega = omega.pow([q as u64]);
            }

            for _ in two_adicity..Self::FftParameters::TWO_ADICITY {
                omega.square_in_place();
            }
        } else {
            // Without a small subgroup only power-of-two orders are supported.
            if !n.is_power_of_two() {
                return None;
            }
            let log_size_of_group = n.trailing_zeros();

            if log_size_of_group > Self::FftParameters::TWO_ADICITY {
                return None;
            }

            // Compute the generator for the multiplicative subgroup.
            // It should be 2^(log_size_of_group) root of unity.
            omega = Self::two_adic_root_of_unity();
            for _ in log_size_of_group..Self::FftParameters::TWO_ADICITY {
                omega.square_in_place();
            }
        }
        Some(omega)
    }

    /// Calculates the k-adicity of n, i.e., the number of trailing 0s in a base-k
    /// representation.
    ///
    /// Returns 0 when `n <= 1`. Also returns 0 when `k < 2`, since bases 0 and 1 have no
    /// meaningful adicity.
    fn k_adicity(k: usize, mut n: usize) -> u32 {
        // Dividing by 0 would panic and dividing by 1 would never shrink n, so bail out early.
        if k < 2 {
            return 0;
        }

        let mut r = 0;
        while n > 1 && n % k == 0 {
            r += 1;
            n /= k;
        }
        r
    }
}