use crate::{
prelude::{FromBytes, Identifier, IoResult, Network, Read, ToBytes},
synthesizer::{snark::ProvingKey, Program},
};
use anyhow::{anyhow, bail, ensure, Result};
use std::{
fs::{self, File},
io::Write,
path::Path,
};
/// File extension used for serialized prover files (`<function_name>.prover`).
const PROVER_FILE_EXTENSION: &str = "prover";
/// An on-disk prover artifact: the name of a program function together with the
/// SNARK proving key for that function.
///
/// Serialized as the function name followed by the proving key.
pub struct ProverFile<N: Network> {
    /// Name of the function this proving key was synthesized for.
    function_name: Identifier<N>,
    /// Proving key associated with `function_name`.
    proving_key: ProvingKey<N>,
}
impl<N: Network> ProverFile<N> {
    /// Writes `proving_key` for `function_name` to `<directory>/<function_name>.prover`,
    /// then reads that file back and returns the parsed prover file.
    ///
    /// # Errors
    /// Fails if `directory` does not exist, if `function_name` is a reserved keyword,
    /// or if creating, writing, or re-reading the file fails.
    pub fn create(directory: &Path, function_name: &Identifier<N>, proving_key: ProvingKey<N>) -> Result<Self> {
        ensure!(directory.exists(), "The build directory does not exist: '{}'", directory.display());
        ensure!(!Program::is_reserved_keyword(function_name), "Function name is invalid (reserved): {}", function_name);

        let target = directory.join(format!("{function_name}.{PROVER_FILE_EXTENSION}"));
        let prover = Self { function_name: *function_name, proving_key };

        // The file is created (and truncated) before serialization runs; the handle is
        // closed before the file is read back below.
        let mut output = File::create(&target)?;
        output.write_all(&prover.to_bytes_le()?)?;
        drop(output);

        // Return the value as parsed from disk rather than the in-memory one.
        Self::from_filepath(&target)
    }

    /// Loads `<directory>/<function_name>.prover`.
    ///
    /// # Errors
    /// Fails if `directory` or the prover file is missing, if the file cannot be parsed,
    /// or if the function name stored in the file differs from `function_name`.
    pub fn open(directory: &Path, function_name: &Identifier<N>) -> Result<Self> {
        ensure!(directory.exists(), "The build directory does not exist: '{}'", directory.display());

        let source = directory.join(format!("{function_name}.{PROVER_FILE_EXTENSION}"));
        ensure!(source.exists(), "The prover file is missing: '{}'", source.display());

        let loaded = Self::from_filepath(&source)?;
        ensure!(
            loaded.function_name() == function_name,
            "The prover file for '{}' contains an incorrect function name of '{}'",
            function_name,
            loaded.function_name()
        );
        Ok(loaded)
    }

    /// Returns `true` when `<directory>/<function_name>.prover` passes `check_path`
    /// and exists on disk.
    pub fn exists_at(directory: &Path, function_name: &Identifier<N>) -> bool {
        let candidate = directory.join(format!("{function_name}.{PROVER_FILE_EXTENSION}"));
        let well_formed = Self::check_path(&candidate).is_ok();
        well_formed && candidate.exists()
    }

    /// Returns the name of the function this prover file belongs to.
    pub const fn function_name(&self) -> &Identifier<N> {
        &self.function_name
    }

    /// Returns the stored proving key.
    pub const fn proving_key(&self) -> &ProvingKey<N> {
        &self.proving_key
    }

    /// Deletes the file at `path` after validating it with `check_path`.
    /// A path that does not exist is treated as already removed. `self` is not consulted.
    pub fn remove(&self, path: &Path) -> Result<()> {
        if path.exists() {
            Self::check_path(path)?;
            fs::remove_file(path)?;
        }
        Ok(())
    }
}
impl<N: Network> ProverFile<N> {
    /// Checks that `path` is an existing regular file whose extension is `PROVER_FILE_EXTENSION`.
    ///
    /// # Errors
    /// Fails if the path does not exist, is not a regular file, has no extension, or has a
    /// different extension.
    fn check_path(path: &Path) -> Result<()> {
        // Existence is checked first: `is_file()` already returns `false` for missing paths,
        // so checking existence afterwards could never fail and hid this clearer message.
        ensure!(path.exists(), "File does not exist: {}", path.display());
        ensure!(path.is_file(), "The path is not a file.");
        let extension = path.extension().ok_or_else(|| anyhow!("File extension not found."))?;
        ensure!(extension == PROVER_FILE_EXTENSION, "File extension is incorrect.");
        Ok(())
    }

    /// Returns the file stem of `path` (the file name without its extension) as a `String`.
    ///
    /// # Errors
    /// Fails if `path` has no file name or the stem is not valid UTF-8.
    fn file_stem_string(path: &Path) -> Result<String> {
        path.file_stem()
            .and_then(|stem| stem.to_str())
            .map(str::to_string)
            .ok_or_else(|| anyhow!("File name not found."))
    }

    /// Reads and deserializes the prover file at `file`.
    ///
    /// # Errors
    /// Fails if `file` does not pass `check_path`, cannot be read or parsed, or if the
    /// function name stored inside differs from the file stem.
    fn from_filepath(file: &Path) -> Result<Self> {
        Self::check_path(file)?;
        let prover = Self::from_bytes_le(&fs::read(file)?)?;
        let file_stem = Self::file_stem_string(file)?;
        ensure!(prover.function_name.to_string() == file_stem, "Function name does not match file stem.");
        Ok(prover)
    }

    /// Serializes this prover file to `path`, replacing its contents.
    ///
    /// The file stem of `path` must equal the function name (e.g. `compute.prover` for
    /// `compute`). NOTE(review): `check_path` requires `path` to already exist as a file,
    /// so this can only overwrite an existing prover file. Confirm this is intended.
    ///
    /// # Errors
    /// Fails if `path` does not pass `check_path`, if its stem does not match the function
    /// name, or if writing fails.
    pub fn write_to(&self, path: &Path) -> Result<()> {
        Self::check_path(path)?;
        // Compare against the stem, not the full file name. The full name always carries
        // the `.prover` extension and so could never equal the function name.
        let file_stem = Self::file_stem_string(path)?;
        ensure!(self.function_name.to_string() == file_stem, "Function name does not match file stem.");
        Ok(File::create(path)?.write_all(&self.to_bytes_le()?)?)
    }
}
impl<N: Network> FromBytes for ProverFile<N> {
    /// Reads a prover file as the function name followed by the proving key.
    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
        // Struct-literal fields are evaluated in the order written, so the function name
        // is consumed from `reader` before the proving key.
        Ok(Self {
            function_name: Identifier::read_le(&mut reader)?,
            proving_key: FromBytes::read_le(&mut reader)?,
        })
    }
}
impl<N: Network> ToBytes for ProverFile<N> {
    /// Writes the function name followed by the proving key, the layout `read_le` expects.
    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
        let Self { function_name, proving_key } = self;
        function_name.write_le(&mut writer)?;
        proving_key.write_le(&mut writer)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        prelude::{FromStr, Parser, TestRng},
        synthesizer::Process,
    };

    type CurrentNetwork = snarkvm_console::network::Testnet3;
    type CurrentAleo = snarkvm_circuit::AleoV0;

    /// Creates a temporary directory and keeps it on disk (`into_path` disables cleanup).
    fn temp_dir() -> std::path::PathBuf {
        let handle = tempfile::tempdir().expect("Failed to open temporary directory");
        handle.into_path()
    }

    /// Writes a proving key with `ProverFile::create`, reloads it with `ProverFile::open`,
    /// and checks that both serialize to the same bytes.
    #[test]
    fn test_create_and_open() {
        let build_dir = temp_dir();

        let source = r"
program token.aleo;
record token:
owner as address.private;
token_amount as u64.private;
function compute:
input r0 as token.record;
add r0.token_amount r0.token_amount into r1;
output r1 as u64.private;";
        let (string, program) = Program::<CurrentNetwork>::parse(source).unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");

        let mut process = Process::load().unwrap();
        process.add_program(&program).unwrap();

        let compute = Identifier::from_str("compute").unwrap();
        let mut rng = TestRng::default();
        process.synthesize_key::<CurrentAleo, _>(program.id(), &compute, &mut rng).unwrap();
        let proving_key = process.get_proving_key(program.id(), compute).unwrap();

        let created = ProverFile::create(&build_dir, &compute, proving_key).unwrap();
        let reopened = ProverFile::open(&build_dir, &compute).unwrap();

        let created_bytes = created.to_bytes_le().unwrap();
        let reopened_bytes = reopened.to_bytes_le().unwrap();
        assert_eq!(created_bytes, reopened_bytes);
    }
}