/// Build-generated guest artifacts; only compiled when the `zk-vm` feature is enabled.
///
/// Textually includes `methods.rs` from Cargo's `OUT_DIR`. That file must define
/// `ZKCG_ZKVM_GUEST_ELF` and `ZKCG_ZKVM_GUEST_ID`, which `elf` and `method_id` read.
/// NOTE(review): presumably written by this crate's build script (e.g. via
/// risc0-build) — confirm against build.rs.
#[cfg(feature = "zk-vm")]
mod methods {
    include!(concat!(env!("OUT_DIR"), "/methods.rs"));
}
5
6use risc0_zkp::core::digest::Digest;
7use risc0_zkvm::{ExecutorEnv, default_prover};
8use zkcg_common::types::ZkVmInput;
9
/// Failure modes reported by `prove` when generating a zkVM receipt.
///
/// The variants are deliberately coarse: the underlying error values are not
/// retained, only the category of failure.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ZkVmProverError {
    /// The prover reported that the guest program panicked (its rendered error
    /// message contained `Guest panicked`). NOTE(review): presumably the guest
    /// panics when the input violates the policy (e.g. score vs. threshold) —
    /// confirm against the guest source.
    PolicyViolation,
    /// Any other failure: writing the guest input, building the executor
    /// environment, proving for a reason other than a guest panic, or
    /// serializing the receipt.
    ExecutionFailed,
}

impl std::fmt::Display for ZkVmProverError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::PolicyViolation => "zkVM guest rejected the input (policy violation)",
            Self::ExecutionFailed => "zkVM proving failed",
        };
        f.write_str(text)
    }
}

// Lets callers propagate this error with `?` into `Box<dyn Error>` or other
// `std::error::Error`-based error types. There is no underlying `source()`,
// because the original error values are dropped when mapping into this enum.
impl std::error::Error for ZkVmProverError {}
15
16fn classify_prover_failure(message: &str) -> ZkVmProverError {
17 if message.contains("Guest panicked") {
18 ZkVmProverError::PolicyViolation
19 } else {
20 ZkVmProverError::ExecutionFailed
21 }
22}
23
/// Runs the zkVM guest on the given inputs and returns the bincode-encoded receipt.
///
/// The four arguments are packed into a single `ZkVmInput`. That value is written
/// to the executor environment before the guest ELF (see `elf`) is proven with
/// the default prover.
///
/// # Errors
///
/// * [`ZkVmProverError::PolicyViolation`] if the prover's error chain reports a
///   guest panic.
/// * [`ZkVmProverError::ExecutionFailed`] if the input cannot be written, the
///   environment cannot be built, proving fails for any other reason, or the
///   receipt cannot be serialized.
#[cfg(feature = "zk-vm")]
pub fn prove(
    score: u64,
    threshold: u64,
    old_state_root: [u8; 32],
    nonce: u64,
) -> Result<Vec<u8>, ZkVmProverError> {
    // Collapses any non-proving failure into the generic variant; the concrete
    // error value is intentionally discarded.
    fn execution_failed<E>(_: E) -> ZkVmProverError {
        ZkVmProverError::ExecutionFailed
    }

    let input = ZkVmInput {
        score,
        threshold,
        old_state_root,
        nonce,
    };

    let env = ExecutorEnv::builder()
        .write(&input)
        .map_err(execution_failed)?
        .build()
        .map_err(execution_failed)?;

    // The alternate (`#`) format renders the error together with its causes, so
    // a guest panic reported deeper in the chain is still seen by the classifier.
    let outcome = default_prover()
        .prove(env, elf())
        .map_err(|err| classify_prover_failure(&format!("{err:#}")))?;

    bincode::serialize(&outcome.receipt).map_err(execution_failed)
}
51
/// Returns the guest method ID from the build-generated `methods` module as a [`Digest`].
///
/// NOTE(review): presumably the ID verifiers should check receipts produced by
/// `prove` against — confirm with the verifying side.
#[cfg(feature = "zk-vm")]
pub fn method_id() -> Digest {
    methods::ZKCG_ZKVM_GUEST_ID.into()
}
56
/// Returns the compiled guest program bytes that `prove` hands to the prover.
#[cfg(feature = "zk-vm")]
pub fn elf() -> &'static [u8] {
    use self::methods::ZKCG_ZKVM_GUEST_ELF as GUEST_ELF;

    GUEST_ELF
}
61
#[cfg(test)]
mod tests {
    use super::{ZkVmProverError, classify_prover_failure};

    #[test]
    fn guest_panic_maps_to_policy_violation() {
        let err = classify_prover_failure("Guest panicked: score exceeds threshold");

        assert_eq!(err, ZkVmProverError::PolicyViolation);
    }

    #[test]
    fn guest_panic_after_context_maps_to_policy_violation() {
        // `prove` classifies the alternate (`{:#}`) rendering of the prover error,
        // which presumably puts outer context before the guest's own message, so
        // the marker need not be at the start of the string.
        let err =
            classify_prover_failure("proving failed: Guest panicked: score exceeds threshold");

        assert_eq!(err, ZkVmProverError::PolicyViolation);
    }

    #[test]
    fn non_guest_failure_maps_to_execution_failed() {
        let err = classify_prover_failure("failed to start prover");

        assert_eq!(err, ZkVmProverError::ExecutionFailed);
    }

    #[test]
    fn empty_message_maps_to_execution_failed() {
        let err = classify_prover_failure("");

        assert_eq!(err, ZkVmProverError::ExecutionFailed);
    }
}