use crate::septic_curve::SepticCurve;
use crate::septic_digest::SepticDigest;
use crate::septic_extension::SepticExtension;
use crate::{air::InteractionScope, AirOpenedValues, ChipOpenedValues, ShardOpenedValues};
use core::fmt::Display;
use itertools::Itertools;
use p3_air::Air;
use p3_challenger::{CanObserve, FieldChallenger};
use p3_commit::{Pcs, PolynomialSpace};
use p3_field::{AbstractExtensionField, AbstractField, PrimeField32};
use p3_matrix::{dense::RowMajorMatrix, Matrix};
use p3_maybe_rayon::prelude::*;
use p3_uni_stark::SymbolicAirBuilder;
use p3_util::log2_strict_usize;
use serde::{de::DeserializeOwned, Serialize};
use std::{cmp::Reverse, error::Error, time::Instant};
use super::{
quotient_values, Com, OpeningProof, StarkGenericConfig, StarkMachine, StarkProvingKey, Val,
VerifierConstraintFolder,
};
use crate::{
air::MachineAir, lookup::InteractionBuilder, opts::SP1CoreOpts, record::MachineRecord,
Challenger, DebugConstraintBuilder, MachineChip, MachineProof, PackedChallenge, PcsProverData,
ProverConstraintFolder, ShardCommitment, ShardMainData, ShardProof, StarkVerifyingKey,
};
/// A prover for a [`StarkMachine`].
///
/// Implementors supply the backend-specific pieces (key setup, trace commitment, opening and
/// whole-machine proving); trace generation, challenger observation and a few accessors that
/// simply forward to the wrapped machine are provided as default methods.
pub trait MachineProver<SC: StarkGenericConfig, A: MachineAir<SC::Val>>:
    'static + Send + Sync
{
    /// Matrix type in which this prover holds traces.
    type DeviceMatrix: Matrix<SC::Val>;

    /// Prover-side data kept alongside a commitment.
    type DeviceProverData;

    /// Proving-key representation used by this prover.
    type DeviceProvingKey: MachineProvingKey<SC>;

    /// Error type returned by [`Self::open`] and [`Self::prove`].
    type Error: Error + Send + Sync;

    /// Builds a prover around `machine`.
    fn new(machine: StarkMachine<SC, A>) -> Self;

    /// Returns the machine this prover was built around.
    fn machine(&self) -> &StarkMachine<SC, A>;

    /// Produces the proving key and verifying key for `program`.
    fn setup(&self, program: &A::Program) -> (Self::DeviceProvingKey, StarkVerifyingKey<SC>);

    /// Produces a proving key for `program`, given an already existing verifying key.
    fn pk_from_vk(
        &self,
        program: &A::Program,
        vk: &StarkVerifyingKey<SC>,
    ) -> Self::DeviceProvingKey;

    /// Converts a host proving key into this prover's key representation.
    fn pk_to_device(&self, pk: &StarkProvingKey<SC>) -> Self::DeviceProvingKey;

    /// Converts this prover's key representation back into a host proving key.
    fn pk_to_host(&self, pk: &Self::DeviceProvingKey) -> StarkProvingKey<SC>;

    /// Generates the main trace of every chip returned by [`Self::shard_chips`] for `record`.
    ///
    /// Chips are processed in parallel. The result pairs each chip name with its trace, and the
    /// time spent per chip is logged at debug level.
    fn generate_traces(&self, record: &A::Record) -> Vec<(String, RowMajorMatrix<Val<SC>>)> {
        let active_chips = self.shard_chips(record).collect::<Vec<_>>();
        let shard_span = tracing::debug_span!("generate traces for shard");
        shard_span.in_scope(|| {
            active_chips
                .par_iter()
                .map(|chip| {
                    let name = chip.name();
                    let started = Instant::now();
                    // The second argument is a throwaway default record created per chip.
                    let matrix = chip.generate_trace(record, &mut A::Record::default());
                    tracing::debug!(
                        parent: &shard_span,
                        "generated trace for chip {} in {:?}",
                        name,
                        started.elapsed()
                    );
                    (name, matrix)
                })
                .collect::<Vec<_>>()
        })
    }

    /// Commits to the named main `traces` of `record`.
    fn commit(
        &self,
        record: &A::Record,
        traces: Vec<(String, RowMajorMatrix<Val<SC>>)>,
    ) -> ShardMainData<SC, Self::DeviceMatrix, Self::DeviceProverData>;

    /// Feeds `commitment` and then `public_values` into `challenger`, in that order.
    fn observe(
        &self,
        challenger: &mut SC::Challenger,
        commitment: Com<SC>,
        public_values: &[SC::Val],
    ) {
        challenger.observe(commitment);
        challenger.observe_slice(public_values);
    }

    /// Produces the proof for a single shard from its committed main data.
    fn open(
        &self,
        pk: &Self::DeviceProvingKey,
        data: ShardMainData<SC, Self::DeviceMatrix, Self::DeviceProverData>,
        challenger: &mut SC::Challenger,
    ) -> Result<ShardProof<SC>, Self::Error>;

    /// Produces a proof covering all of `records`.
    fn prove(
        &self,
        pk: &Self::DeviceProvingKey,
        records: Vec<A::Record>,
        challenger: &mut SC::Challenger,
        opts: <A::Record as MachineRecord>::Config,
    ) -> Result<MachineProof<SC>, Self::Error>
    where
        A: for<'a> Air<DebugConstraintBuilder<'a, Val<SC>, SC::Challenge>>;

    /// Returns the STARK configuration of the wrapped machine.
    fn config(&self) -> &SC {
        let machine = self.machine();
        machine.config()
    }

    /// Returns the wrapped machine's `num_pv_elts()`.
    fn num_pv_elts(&self) -> usize {
        let machine = self.machine();
        machine.num_pv_elts()
    }

    /// Returns the chips the wrapped machine selects for `record`.
    fn shard_chips<'a, 'b>(
        &'a self,
        record: &'b A::Record,
    ) -> impl Iterator<Item = &'b MachineChip<SC, A>>
    where
        'a: 'b,
        SC: 'b,
    {
        self.machine().shard_chips(record)
    }

    /// Forwards to the wrapped machine's `debug_constraints` with the same arguments.
    fn debug_constraints(
        &self,
        pk: &StarkProvingKey<SC>,
        records: Vec<A::Record>,
        challenger: &mut SC::Challenger,
    ) where
        SC::Val: PrimeField32,
        A: for<'a> Air<DebugConstraintBuilder<'a, Val<SC>, SC::Challenge>>,
    {
        let machine = self.machine();
        machine.debug_constraints(pk, records, challenger);
    }
}
/// Read-only view of a proving key, as required of `MachineProver::DeviceProvingKey`.
///
/// The `StarkProvingKey` implementation in this file backs each accessor with the field of the
/// same meaning (`commit`, `pc_start`, `initial_global_cumulative_sum`).
pub trait MachineProvingKey<SC: StarkGenericConfig>: Send + Sync {
/// Returns the key's commitment (`StarkProvingKey` returns a clone of its `commit` field).
fn preprocessed_commit(&self) -> Com<SC>;
/// Returns the key's `pc_start` value.
fn pc_start(&self) -> Val<SC>;
/// Returns the key's initial global cumulative sum digest.
fn initial_global_cumulative_sum(&self) -> SepticDigest<Val<SC>>;
/// Feeds the key's contents into `challenger`.
fn observe_into(&self, challenger: &mut Challenger<SC>);
}
/// Prover that wraps a `StarkMachine`; its `MachineProver` implementation uses host-side types
/// (`RowMajorMatrix`, `PcsProverData`, `StarkProvingKey`) for all associated "device" types.
pub struct CpuProver<SC: StarkGenericConfig, A> {
// The machine every trait method delegates to.
machine: StarkMachine<SC, A>,
}
/// Unit error type used as `MachineProver::Error` for `CpuProver`.
///
/// It carries no payload; its `Display` implementation prints the fixed text
/// `DefaultProverError`.
#[derive(Debug, Clone, Copy)]
pub struct CpuProverError;
impl<SC, A> MachineProver<SC, A> for CpuProver<SC, A>
where
SC: 'static + StarkGenericConfig + Send + Sync,
A: MachineAir<SC::Val>
+ for<'a> Air<ProverConstraintFolder<'a, SC>>
+ Air<InteractionBuilder<Val<SC>>>
+ for<'a> Air<VerifierConstraintFolder<'a, SC>>
+ for<'a> Air<SymbolicAirBuilder<Val<SC>>>,
A::Record: MachineRecord<Config = SP1CoreOpts>,
SC::Val: PrimeField32,
Com<SC>: Send + Sync,
PcsProverData<SC>: Send + Sync + Serialize + DeserializeOwned,
OpeningProof<SC>: Send + Sync,
SC::Challenger: Clone,
{
type DeviceMatrix = RowMajorMatrix<Val<SC>>;
type DeviceProverData = PcsProverData<SC>;
type DeviceProvingKey = StarkProvingKey<SC>;
type Error = CpuProverError;
fn new(machine: StarkMachine<SC, A>) -> Self {
Self { machine }
}
fn machine(&self) -> &StarkMachine<SC, A> {
&self.machine
}
fn setup(&self, program: &A::Program) -> (Self::DeviceProvingKey, StarkVerifyingKey<SC>) {
self.machine().setup(program)
}
fn pk_from_vk(
&self,
program: &A::Program,
vk: &StarkVerifyingKey<SC>,
) -> Self::DeviceProvingKey {
self.machine().setup_core(program, vk.initial_global_cumulative_sum).0
}
fn pk_to_device(&self, pk: &StarkProvingKey<SC>) -> Self::DeviceProvingKey {
pk.clone()
}
fn pk_to_host(&self, pk: &Self::DeviceProvingKey) -> StarkProvingKey<SC> {
pk.clone()
}
fn commit(
&self,
record: &A::Record,
mut named_traces: Vec<(String, RowMajorMatrix<Val<SC>>)>,
) -> ShardMainData<SC, Self::DeviceMatrix, Self::DeviceProverData> {
named_traces.sort_by_key(|(name, trace)| (Reverse(trace.height()), name.clone()));
let pcs = self.config().pcs();
let domains_and_traces = named_traces
.iter()
.map(|(_, trace)| {
let domain = pcs.natural_domain_for_degree(trace.height());
(domain, trace.to_owned())
})
.collect::<Vec<_>>();
let (main_commit, main_data) = pcs.commit(domains_and_traces);
let chip_ordering =
named_traces.iter().enumerate().map(|(i, (name, _))| (name.to_owned(), i)).collect();
let traces = named_traces.into_iter().map(|(_, trace)| trace).collect::<Vec<_>>();
ShardMainData {
traces,
main_commit,
main_data,
chip_ordering,
public_values: record.public_values(),
}
}
#[allow(clippy::too_many_lines)]
#[allow(clippy::redundant_closure_for_method_calls)]
#[allow(clippy::map_unwrap_or)]
fn open(
&self,
pk: &StarkProvingKey<SC>,
data: ShardMainData<SC, Self::DeviceMatrix, Self::DeviceProverData>,
challenger: &mut <SC as StarkGenericConfig>::Challenger,
) -> Result<ShardProof<SC>, Self::Error> {
let chips = self.machine().shard_chips_ordered(&data.chip_ordering).collect::<Vec<_>>();
let traces = data.traces;
let config = self.machine().config();
let degrees = traces.iter().map(|trace| trace.height()).collect::<Vec<_>>();
let log_degrees =
degrees.iter().map(|degree| log2_strict_usize(*degree)).collect::<Vec<_>>();
let log_quotient_degrees =
chips.iter().map(|chip| chip.log_quotient_degree()).collect::<Vec<_>>();
let pcs = config.pcs();
let trace_domains =
degrees.iter().map(|degree| pcs.natural_domain_for_degree(*degree)).collect::<Vec<_>>();
challenger.observe_slice(&data.public_values[0..self.num_pv_elts()]);
challenger.observe(data.main_commit.clone());
let mut local_permutation_challenges: Vec<SC::Challenge> = Vec::new();
for _ in 0..2 {
local_permutation_challenges.push(challenger.sample_ext_element());
}
let packed_perm_challenges = local_permutation_challenges
.iter()
.map(|c| PackedChallenge::<SC>::from_f(*c))
.collect::<Vec<_>>();
let ((permutation_traces, prep_traces), (global_cumulative_sums, local_cumulative_sums)): (
(Vec<_>, Vec<_>),
(Vec<_>, Vec<_>),
) = tracing::debug_span!("generate permutation traces").in_scope(|| {
chips
.par_iter()
.zip(traces.par_iter())
.map(|(chip, main_trace)| {
let preprocessed_trace =
pk.chip_ordering.get(&chip.name()).map(|&index| &pk.traces[index]);
let (perm_trace, local_sum) = chip.generate_permutation_trace(
preprocessed_trace,
main_trace,
&local_permutation_challenges,
);
let global_sum = if chip.commit_scope() == InteractionScope::Local {
SepticDigest::<Val<SC>>::zero()
} else {
let main_trace_size = main_trace.height() * main_trace.width();
let last_row = &main_trace.values[main_trace_size - 14..main_trace_size];
SepticDigest(SepticCurve {
x: SepticExtension::<Val<SC>>::from_base_fn(|i| last_row[i]),
y: SepticExtension::<Val<SC>>::from_base_fn(|i| last_row[i + 7]),
})
};
((perm_trace, preprocessed_trace), (global_sum, local_sum))
})
.unzip()
});
for i in 0..chips.len() {
let trace_width = traces[i].width();
let trace_height = traces[i].height();
let prep_width = prep_traces[i].map_or(0, |x| x.width());
let permutation_width = permutation_traces[i].width();
let total_width = trace_width
+ prep_width
+ permutation_width * <SC::Challenge as AbstractExtensionField<SC::Val>>::D;
tracing::debug!(
"{:<15} | Main Cols = {:<5} | Pre Cols = {:<5} | Perm Cols = {:<5} | Rows = {:<5} | Cells = {:<10}",
chips[i].name(),
trace_width,
prep_width,
permutation_width * <SC::Challenge as AbstractExtensionField<SC::Val>>::D,
trace_height,
total_width * trace_height,
);
}
let domains_and_perm_traces =
tracing::debug_span!("flatten permutation traces and collect domains").in_scope(|| {
permutation_traces
.into_iter()
.zip(trace_domains.iter())
.map(|(perm_trace, domain)| {
let trace = perm_trace.flatten_to_base();
(*domain, trace.clone())
})
.collect::<Vec<_>>()
});
let pcs = config.pcs();
let (permutation_commit, permutation_data) =
tracing::debug_span!("commit to permutation traces")
.in_scope(|| pcs.commit(domains_and_perm_traces));
challenger.observe(permutation_commit.clone());
for (local_sum, global_sum) in
local_cumulative_sums.iter().zip(global_cumulative_sums.iter())
{
challenger.observe_slice(local_sum.as_base_slice());
challenger.observe_slice(&global_sum.0.x.0);
challenger.observe_slice(&global_sum.0.y.0);
}
let quotient_domains = trace_domains
.iter()
.zip_eq(log_degrees.iter())
.zip_eq(log_quotient_degrees.iter())
.map(|((domain, log_degree), log_quotient_degree)| {
domain.create_disjoint_domain(1 << (log_degree + log_quotient_degree))
})
.collect::<Vec<_>>();
let alpha: SC::Challenge = challenger.sample_ext_element::<SC::Challenge>();
let parent_span = tracing::debug_span!("compute quotient values");
let quotient_values = parent_span.in_scope(|| {
quotient_domains
.into_par_iter()
.enumerate()
.map(|(i, quotient_domain)| {
tracing::debug_span!(parent: &parent_span, "compute quotient values for domain")
.in_scope(|| {
let preprocessed_trace_on_quotient_domains =
pk.chip_ordering.get(&chips[i].name()).map(|&index| {
pcs.get_evaluations_on_domain(&pk.data, index, *quotient_domain)
.to_row_major_matrix()
});
let main_trace_on_quotient_domains = pcs
.get_evaluations_on_domain(&data.main_data, i, *quotient_domain)
.to_row_major_matrix();
let permutation_trace_on_quotient_domains = pcs
.get_evaluations_on_domain(&permutation_data, i, *quotient_domain)
.to_row_major_matrix();
let chip_num_constraints =
pk.constraints_map.get(&chips[i].name()).unwrap();
let powers_of_alpha =
alpha.powers().take(*chip_num_constraints).collect::<Vec<_>>();
let mut powers_of_alpha_rev = powers_of_alpha.clone();
powers_of_alpha_rev.reverse();
quotient_values(
chips[i],
&local_cumulative_sums[i],
&global_cumulative_sums[i],
trace_domains[i],
*quotient_domain,
preprocessed_trace_on_quotient_domains,
main_trace_on_quotient_domains,
permutation_trace_on_quotient_domains,
&packed_perm_challenges,
&powers_of_alpha_rev,
&data.public_values,
)
})
})
.collect::<Vec<_>>()
});
let quotient_domains_and_chunks = quotient_domains
.into_iter()
.zip_eq(quotient_values)
.zip_eq(log_quotient_degrees.iter())
.flat_map(|((quotient_domain, quotient_values), log_quotient_degree)| {
let quotient_degree = 1 << *log_quotient_degree;
let quotient_flat = RowMajorMatrix::new_col(quotient_values).flatten_to_base();
let quotient_chunks = quotient_domain.split_evals(quotient_degree, quotient_flat);
let qc_domains = quotient_domain.split_domains(quotient_degree);
qc_domains.into_iter().zip_eq(quotient_chunks)
})
.collect::<Vec<_>>();
let num_quotient_chunks = quotient_domains_and_chunks.len();
assert_eq!(
num_quotient_chunks,
chips.iter().map(|c| 1 << c.log_quotient_degree()).sum::<usize>()
);
let (quotient_commit, quotient_data) = tracing::debug_span!("commit to quotient traces")
.in_scope(|| pcs.commit(quotient_domains_and_chunks));
challenger.observe(quotient_commit.clone());
let zeta: SC::Challenge = challenger.sample_ext_element();
let preprocessed_opening_points =
tracing::debug_span!("compute preprocessed opening points").in_scope(|| {
pk.traces
.iter()
.zip(pk.local_only.iter())
.map(|(trace, local_only)| {
let domain = pcs.natural_domain_for_degree(trace.height());
if !local_only {
vec![zeta, domain.next_point(zeta).unwrap()]
} else {
vec![zeta]
}
})
.collect::<Vec<_>>()
});
let main_trace_opening_points = tracing::debug_span!("compute main trace opening points")
.in_scope(|| {
trace_domains
.iter()
.zip(chips.iter())
.map(|(domain, chip)| {
if !chip.local_only() {
vec![zeta, domain.next_point(zeta).unwrap()]
} else {
vec![zeta]
}
})
.collect::<Vec<_>>()
});
let permutation_trace_opening_points =
tracing::debug_span!("compute permutation trace opening points").in_scope(|| {
trace_domains
.iter()
.map(|domain| vec![zeta, domain.next_point(zeta).unwrap()])
.collect::<Vec<_>>()
});
let quotient_opening_points =
(0..num_quotient_chunks).map(|_| vec![zeta]).collect::<Vec<_>>();
let (openings, opening_proof) = tracing::debug_span!("open multi batches").in_scope(|| {
pcs.open(
vec![
(&pk.data, preprocessed_opening_points),
(&data.main_data, main_trace_opening_points.clone()),
(&permutation_data, permutation_trace_opening_points.clone()),
("ient_data, quotient_opening_points),
],
challenger,
)
});
let [preprocessed_values, main_values, permutation_values, mut quotient_values] =
openings.try_into().unwrap();
assert!(main_values.len() == chips.len());
let preprocessed_opened_values = preprocessed_values
.into_iter()
.zip(pk.local_only.iter())
.map(|(op, local_only)| {
if !local_only {
let [local, next] = op.try_into().unwrap();
AirOpenedValues { local, next }
} else {
let [local] = op.try_into().unwrap();
let width = local.len();
AirOpenedValues { local, next: vec![SC::Challenge::zero(); width] }
}
})
.collect::<Vec<_>>();
let main_opened_values = main_values
.into_iter()
.zip(chips.iter())
.map(|(op, chip)| {
if !chip.local_only() {
let [local, next] = op.try_into().unwrap();
AirOpenedValues { local, next }
} else {
let [local] = op.try_into().unwrap();
let width = local.len();
AirOpenedValues { local, next: vec![SC::Challenge::zero(); width] }
}
})
.collect::<Vec<_>>();
let permutation_opened_values = permutation_values
.into_iter()
.map(|op| {
let [local, next] = op.try_into().unwrap();
AirOpenedValues { local, next }
})
.collect::<Vec<_>>();
let mut quotient_opened_values = Vec::with_capacity(log_quotient_degrees.len());
for log_quotient_degree in log_quotient_degrees.iter() {
let degree = 1 << *log_quotient_degree;
let slice = quotient_values.drain(0..degree);
quotient_opened_values.push(slice.map(|mut op| op.pop().unwrap()).collect::<Vec<_>>());
}
let opened_values = main_opened_values
.into_iter()
.zip_eq(permutation_opened_values)
.zip_eq(quotient_opened_values)
.zip_eq(local_cumulative_sums)
.zip_eq(global_cumulative_sums)
.zip_eq(log_degrees.iter())
.enumerate()
.map(
|(
i,
(
(
(((main, permutation), quotient), local_cumulative_sum),
global_cumulative_sum,
),
log_degree,
),
)| {
let preprocessed = pk
.chip_ordering
.get(&chips[i].name())
.map(|&index| preprocessed_opened_values[index].clone())
.unwrap_or(AirOpenedValues { local: vec![], next: vec![] });
ChipOpenedValues {
preprocessed,
main,
permutation,
quotient,
global_cumulative_sum,
local_cumulative_sum,
log_degree: *log_degree,
}
},
)
.collect::<Vec<_>>();
Ok(ShardProof::<SC> {
commitment: ShardCommitment {
main_commit: data.main_commit.clone(),
permutation_commit,
quotient_commit,
},
opened_values: ShardOpenedValues { chips: opened_values },
opening_proof,
chip_ordering: data.chip_ordering,
public_values: data.public_values,
})
}
#[allow(clippy::needless_for_each)]
fn prove(
&self,
pk: &StarkProvingKey<SC>,
mut records: Vec<A::Record>,
challenger: &mut SC::Challenger,
opts: <A::Record as MachineRecord>::Config,
) -> Result<MachineProof<SC>, Self::Error>
where
A: for<'a> Air<DebugConstraintBuilder<'a, Val<SC>, SC::Challenge>>,
{
self.machine().generate_dependencies(&mut records, &opts, None);
pk.observe_into(challenger);
let shard_proofs = tracing::info_span!("prove_shards").in_scope(|| {
records
.into_par_iter()
.map(|record| {
let named_traces = self.generate_traces(&record);
let shard_data = self.commit(&record, named_traces);
self.open(pk, shard_data, &mut challenger.clone())
})
.collect::<Result<Vec<_>, _>>()
})?;
Ok(MachineProof { shard_proofs })
}
}
/// Exposes the host proving key through the [`MachineProvingKey`] accessors.
impl<SC> MachineProvingKey<SC> for StarkProvingKey<SC>
where
    SC: 'static + StarkGenericConfig + Send + Sync,
    PcsProverData<SC>: Send + Sync + Serialize + DeserializeOwned,
    Com<SC>: Send + Sync,
{
    /// Returns a clone of the key's `commit` field.
    fn preprocessed_commit(&self) -> Com<SC> {
        self.commit.clone()
    }

    /// Returns the key's `pc_start` field.
    fn pc_start(&self) -> Val<SC> {
        self.pc_start
    }

    /// Returns the key's `initial_global_cumulative_sum` field.
    fn initial_global_cumulative_sum(&self) -> SepticDigest<Val<SC>> {
        self.initial_global_cumulative_sum
    }

    /// Feeds the key into `challenger` in a fixed order: the commitment, `pc_start`, the `x`
    /// and `y` coordinates of the initial global cumulative sum, and finally one zero element.
    fn observe_into(&self, challenger: &mut Challenger<SC>) {
        let curve_point = &self.initial_global_cumulative_sum.0;
        challenger.observe(self.commit.clone());
        challenger.observe(self.pc_start);
        challenger.observe_slice(&curve_point.x.0);
        challenger.observe_slice(&curve_point.y.0);
        // NOTE(review): the trailing zero looks like transcript padding; confirm with the verifier.
        challenger.observe(Val::<SC>::zero());
    }
}
impl Display for CpuProverError {
    /// Writes the fixed text `DefaultProverError`; formatter width and fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("DefaultProverError")
    }
}
// Marks `CpuProverError` as a standard error type; all `Error` methods keep their defaults.
impl Error for CpuProverError {}