winter_air/proof/
context.rs

// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
5
6use alloc::{string::ToString, vec::Vec};
7
8use math::{StarkField, ToElements};
9use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
10
11use crate::{ProofOptions, TraceInfo};
12
13// PROOF CONTEXT
14// ================================================================================================
15/// Basic metadata about a specific execution of a computation.
16#[derive(Debug, Clone, Eq, PartialEq)]
17pub struct Context {
18    trace_info: TraceInfo,
19    field_modulus_bytes: Vec<u8>,
20    options: ProofOptions,
21    num_constraints: usize,
22}
23
24impl Context {
25    // CONSTRUCTOR
26    // --------------------------------------------------------------------------------------------
27    /// Creates a new context for a computation described by the specified field, trace info, proof
28    /// options, and total number of constraints.
29    ///
30    /// # Panics
31    /// - If either trace length or the LDE domain size implied by the trace length and the blowup
32    ///   factor is greater then [u32::MAX].
33    /// - If the number of constraints is not greater than 0.
34    /// - If the number of constraints is greater than [u32::MAX].
35    pub fn new<B: StarkField>(
36        trace_info: TraceInfo,
37        options: ProofOptions,
38        num_constraints: usize,
39    ) -> Self {
40        // TODO: return errors instead of panicking?
41
42        let trace_length = trace_info.length();
43        assert!(trace_length <= u32::MAX as usize, "trace length too big");
44
45        let lde_domain_size = trace_length * options.blowup_factor();
46        assert!(lde_domain_size <= u32::MAX as usize, "LDE domain size too big");
47
48        assert!(num_constraints > 0, "number of constraints should be greater than zero");
49        assert!(num_constraints <= u32::MAX as usize, "number of constraints is too big");
50
51        Context {
52            trace_info,
53            field_modulus_bytes: B::get_modulus_le_bytes(),
54            options,
55            num_constraints,
56        }
57    }
58
59    // PUBLIC ACCESSORS
60    // --------------------------------------------------------------------------------------------
61
62    /// Returns execution trace info for the computation described by this context.
63    pub fn trace_info(&self) -> &TraceInfo {
64        &self.trace_info
65    }
66
67    /// Returns the size of the LDE domain for the computation described by this context.
68    pub fn lde_domain_size(&self) -> usize {
69        self.trace_info.length() * self.options.blowup_factor()
70    }
71
72    /// Returns modulus of the field for the computation described by this context.
73    pub fn field_modulus_bytes(&self) -> &[u8] {
74        &self.field_modulus_bytes
75    }
76
77    /// Returns number of bits in the base field modulus for the computation described by this
78    /// context.
79    ///
80    /// The modulus is assumed to be encoded in little-endian byte order.
81    pub fn num_modulus_bits(&self) -> u32 {
82        let mut num_bits = self.field_modulus_bytes.len() as u32 * 8;
83        for &byte in self.field_modulus_bytes.iter().rev() {
84            if byte != 0 {
85                num_bits -= byte.leading_zeros();
86                return num_bits;
87            }
88            num_bits -= 8;
89        }
90
91        0
92    }
93
94    /// Returns proof options which were used to a proof in this context.
95    pub fn options(&self) -> &ProofOptions {
96        &self.options
97    }
98
99    /// Returns the total number of constraints.
100    pub fn num_constraints(&self) -> usize {
101        self.num_constraints
102    }
103}
104
105impl<E: StarkField> ToElements<E> for Context {
106    /// Converts this [Context] into a vector of field elements.
107    ///
108    /// The elements are laid out as follows:
109    /// - trace info:
110    ///   - trace segment widths and the number of aux random values [1 element].
111    ///   - trace length [1 element].
112    ///   - trace metadata [0 or more elements].
113    /// - field modulus bytes [2 field elements].
114    /// - number of constraints (1 element).
115    /// - proof options:
116    ///   - field extension, FRI parameters, and grinding factor [1 element].
117    ///   - blowup factor [1 element].
118    ///   - number of queries [1 element].
119    fn to_elements(&self) -> Vec<E> {
120        // convert trace layout
121        let mut result = self.trace_info.to_elements();
122
123        // convert field modulus bytes into 2 elements
124        let num_modulus_bytes = self.field_modulus_bytes.len();
125        let (m1, m2) = self.field_modulus_bytes.split_at(num_modulus_bytes / 2);
126        result.push(E::from_bytes_with_padding(m1));
127        result.push(E::from_bytes_with_padding(m2));
128
129        // convert the number of constraints
130        result.push(E::from(self.num_constraints as u32));
131
132        // convert proof options to elements
133        result.append(&mut self.options.to_elements());
134
135        result
136    }
137}
138
139// SERIALIZATION
140// ================================================================================================
141
142impl Serializable for Context {
143    /// Serializes `self` and writes the resulting bytes into the `target`.
144    fn write_into<W: ByteWriter>(&self, target: &mut W) {
145        self.trace_info.write_into(target);
146        assert!(self.field_modulus_bytes.len() < u8::MAX as usize);
147        target.write_u8(self.field_modulus_bytes.len() as u8);
148        target.write_bytes(&self.field_modulus_bytes);
149        self.options.write_into(target);
150        self.num_constraints.write_into(target);
151    }
152}
153
154impl Deserializable for Context {
155    /// Reads proof context from the specified `source` and returns the result.
156    ///
157    /// # Errors
158    /// Returns an error of a valid Context struct could not be read from the specified `source`.
159    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
160        // read and validate trace info
161        let trace_info = TraceInfo::read_from(source)?;
162
163        // read and validate field modulus bytes
164        let num_modulus_bytes = source.read_u8()? as usize;
165        if num_modulus_bytes == 0 {
166            return Err(DeserializationError::InvalidValue(
167                "field modulus cannot be an empty value".to_string(),
168            ));
169        }
170        let field_modulus_bytes = source.read_vec(num_modulus_bytes)?;
171
172        // read options
173        let options = ProofOptions::read_from(source)?;
174
175        // read total number of constraints
176        let num_constraints = source.read_usize()?;
177
178        Ok(Context {
179            trace_info,
180            field_modulus_bytes,
181            options,
182            num_constraints,
183        })
184    }
185}
186
// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use math::fields::f64::BaseElement;

    use super::{Context, Deserializable, ProofOptions, Serializable, ToElements, TraceInfo};
    use crate::{options::BatchingMethod, FieldExtension};

    /// Checks that a context is flattened into: the trace info elements, the two halves of the
    /// field modulus, the number of constraints, and the proof option elements.
    #[test]
    fn context_to_elements() {
        // proof parameters
        let field_extension = FieldExtension::None;
        let fri_folding_factor = 8;
        let fri_remainder_max_degree = 127;
        let grinding_factor = 20;
        let blowup_factor = 8;
        let num_queries = 30;
        let num_constraints = 128;

        // trace shape
        let main_width = 20;
        let aux_width = 9;
        let aux_rands = 12;
        let trace_length = 4096;

        let trace_info =
            TraceInfo::new_multi_segment(main_width, aux_width, aux_rands, trace_length, vec![]);

        // blowup factor, FRI remainder degree, FRI folding factor and field extension share a
        // single element, least significant byte first
        let packed_options = u32::from_le_bytes([
            blowup_factor as u8,
            fri_remainder_max_degree,
            fri_folding_factor,
            field_extension as u8,
        ]);

        let mut expected = trace_info.to_elements();
        expected.extend([
            // the f64 modulus is 2^64 - 2^32 + 1: lower half is 1, upper half is 2^32 - 1
            BaseElement::from(1_u32),
            BaseElement::from(u32::MAX),
            BaseElement::from(num_constraints as u32),
            BaseElement::from(packed_options),
            BaseElement::from(grinding_factor),
            BaseElement::from(num_queries as u32),
        ]);

        let options = ProofOptions::new(
            num_queries,
            blowup_factor,
            grinding_factor,
            field_extension,
            fri_folding_factor as usize,
            fri_remainder_max_degree as usize,
            BatchingMethod::Linear,
            BatchingMethod::Linear,
        );
        let context = Context::new::<BaseElement>(trace_info, options, num_constraints);

        assert_eq!(expected, context.to_elements());
    }

    /// Checks that a serialized context reads back into an identical context.
    #[test]
    fn context_serialization() {
        let options = ProofOptions::new(
            1, // number of queries
            2, // blowup factor
            2, // grinding factor
            FieldExtension::None,
            8, // FRI folding factor
            1, // FRI remainder max degree
            BatchingMethod::Linear,
            BatchingMethod::Linear,
        );
        let original = Context::new::<BaseElement>(TraceInfo::new(1, 8), options, 100);

        let round_tripped = Context::read_from_bytes(&original.to_bytes()).unwrap();

        assert_eq!(original, round_tripped);
    }
}