winter_air/proof/
context.rs1use alloc::{string::ToString, vec::Vec};
7
8use math::{StarkField, ToElements};
9use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
10
11use crate::{ProofOptions, TraceInfo};
12
13#[derive(Debug, Clone, Eq, PartialEq)]
17pub struct Context {
18 trace_info: TraceInfo,
19 field_modulus_bytes: Vec<u8>,
20 options: ProofOptions,
21 num_constraints: usize,
22}
23
24impl Context {
25 pub fn new<B: StarkField>(
36 trace_info: TraceInfo,
37 options: ProofOptions,
38 num_constraints: usize,
39 ) -> Self {
40 let trace_length = trace_info.length();
43 assert!(trace_length <= u32::MAX as usize, "trace length too big");
44
45 let lde_domain_size = trace_length * options.blowup_factor();
46 assert!(lde_domain_size <= u32::MAX as usize, "LDE domain size too big");
47
48 assert!(num_constraints > 0, "number of constraints should be greater than zero");
49 assert!(num_constraints <= u32::MAX as usize, "number of constraints is too big");
50
51 Context {
52 trace_info,
53 field_modulus_bytes: B::get_modulus_le_bytes(),
54 options,
55 num_constraints,
56 }
57 }
58
59 pub fn trace_info(&self) -> &TraceInfo {
64 &self.trace_info
65 }
66
67 pub fn lde_domain_size(&self) -> usize {
69 self.trace_info.length() * self.options.blowup_factor()
70 }
71
72 pub fn field_modulus_bytes(&self) -> &[u8] {
74 &self.field_modulus_bytes
75 }
76
77 pub fn num_modulus_bits(&self) -> u32 {
82 let mut num_bits = self.field_modulus_bytes.len() as u32 * 8;
83 for &byte in self.field_modulus_bytes.iter().rev() {
84 if byte != 0 {
85 num_bits -= byte.leading_zeros();
86 return num_bits;
87 }
88 num_bits -= 8;
89 }
90
91 0
92 }
93
94 pub fn options(&self) -> &ProofOptions {
96 &self.options
97 }
98
99 pub fn num_constraints(&self) -> usize {
101 self.num_constraints
102 }
103}
104
105impl<E: StarkField> ToElements<E> for Context {
106 fn to_elements(&self) -> Vec<E> {
120 let mut result = self.trace_info.to_elements();
122
123 let num_modulus_bytes = self.field_modulus_bytes.len();
125 let (m1, m2) = self.field_modulus_bytes.split_at(num_modulus_bytes / 2);
126 result.push(E::from_bytes_with_padding(m1));
127 result.push(E::from_bytes_with_padding(m2));
128
129 result.push(E::from(self.num_constraints as u32));
131
132 result.append(&mut self.options.to_elements());
134
135 result
136 }
137}
138
139impl Serializable for Context {
143 fn write_into<W: ByteWriter>(&self, target: &mut W) {
145 self.trace_info.write_into(target);
146 assert!(self.field_modulus_bytes.len() < u8::MAX as usize);
147 target.write_u8(self.field_modulus_bytes.len() as u8);
148 target.write_bytes(&self.field_modulus_bytes);
149 self.options.write_into(target);
150 self.num_constraints.write_into(target);
151 }
152}
153
154impl Deserializable for Context {
155 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
160 let trace_info = TraceInfo::read_from(source)?;
162
163 let num_modulus_bytes = source.read_u8()? as usize;
165 if num_modulus_bytes == 0 {
166 return Err(DeserializationError::InvalidValue(
167 "field modulus cannot be an empty value".to_string(),
168 ));
169 }
170 let field_modulus_bytes = source.read_vec(num_modulus_bytes)?;
171
172 let options = ProofOptions::read_from(source)?;
174
175 let num_constraints = source.read_usize()?;
177
178 Ok(Context {
179 trace_info,
180 field_modulus_bytes,
181 options,
182 num_constraints,
183 })
184 }
185}
186
#[cfg(test)]
mod tests {
    use math::fields::f64::BaseElement;

    use super::{Context, Deserializable, ProofOptions, Serializable, ToElements, TraceInfo};
    use crate::{options::BatchingMethod, FieldExtension};

    /// Checks that `Context::to_elements()` emits trace info elements, the two modulus halves,
    /// the constraint count, and the proof option elements, in that order.
    #[test]
    fn context_to_elements() {
        // proof option parameters
        let field_extension = FieldExtension::None;
        let fri_folding_factor = 8;
        let fri_remainder_max_degree = 127;
        let grinding_factor = 20;
        let blowup_factor = 8;
        let num_queries = 30;
        let num_constraints = 128;

        // trace shape
        let main_width = 20;
        let aux_width = 9;
        let aux_rands = 12;
        let trace_length = 4096;

        let trace_info =
            TraceInfo::new_multi_segment(main_width, aux_width, aux_rands, trace_length, vec![]);

        // blowup factor, FRI remainder degree, FRI folding factor and field extension are
        // expected to be packed into a single u32, one byte each
        let packed_options = u32::from_le_bytes([
            blowup_factor as u8,
            fri_remainder_max_degree,
            fri_folding_factor,
            field_extension as u8,
        ]);

        // build the expected elements before the trace info is moved into the context
        let mut expected = trace_info.to_elements();
        expected.extend(vec![
            // lower and upper halves of the field modulus
            BaseElement::from(1_u32),
            BaseElement::from(u32::MAX),
            // number of constraints
            BaseElement::from(num_constraints as u32),
            // proof options
            BaseElement::from(packed_options),
            BaseElement::from(grinding_factor),
            BaseElement::from(num_queries as u32),
        ]);

        let options = ProofOptions::new(
            num_queries,
            blowup_factor,
            grinding_factor,
            field_extension,
            fri_folding_factor as usize,
            fri_remainder_max_degree as usize,
            BatchingMethod::Linear,
            BatchingMethod::Linear,
        );
        let context = Context::new::<BaseElement>(trace_info, options, num_constraints);

        assert_eq!(expected, context.to_elements());
    }

    /// Checks that a context survives a serialization round trip unchanged.
    #[test]
    fn context_serialization() {
        let trace_info = TraceInfo::new(1, 8);
        let options = ProofOptions::new(
            1,
            2,
            2,
            FieldExtension::None,
            8,
            1,
            BatchingMethod::Linear,
            BatchingMethod::Linear,
        );
        let context = Context::new::<BaseElement>(trace_info, options, 100);

        let serialized = context.to_bytes();
        let round_tripped = Context::read_from_bytes(&serialized).unwrap();

        assert_eq!(context, round_tripped);
    }
}