1use pyo3::exceptions::{PyRuntimeError, PyValueError};
7use pyo3::prelude::*;
8use pyo3::types::{PyBytes, PyDict};
9use uuid::Uuid;
10
11use ves_stark_air::Policy as RustPolicy;
12use ves_stark_primitives::{
13 CommerceAuthorizationReceipt, CompliancePublicInputs as RustCompliancePublicInputs,
14 PayloadAmountBinding, PolicyParams,
15};
16use ves_stark_prover::{ComplianceProver, ComplianceWitness};
17use ves_stark_verifier::{
18 verify_agent_authorization_proof_auto_with_amount_binding, verify_compliance_proof_auto_bound,
19 verify_compliance_proof_auto_with_amount_binding, VerifierError,
20};
21
/// Python-facing wrapper around the `Policy` type from `ves_stark_air`.
///
/// There is no Python constructor; instances are obtained from the static
/// methods `aml_threshold`, `order_total_cap` and `agent_authorization`.
#[pyclass]
#[derive(Clone)]
pub struct Policy {
    /// The wrapped Rust policy value.
    inner: RustPolicy,
}
28
29#[pymethods]
30impl Policy {
31 #[staticmethod]
41 pub fn aml_threshold(threshold: u64) -> Self {
42 Self {
43 inner: RustPolicy::aml_threshold(threshold),
44 }
45 }
46
47 #[staticmethod]
57 pub fn order_total_cap(cap: u64) -> Self {
58 Self {
59 inner: RustPolicy::order_total_cap(cap),
60 }
61 }
62
63 #[staticmethod]
75 pub fn agent_authorization(max_total: u64, intent_hash: String) -> PyResult<Self> {
76 let inner = RustPolicy::agent_authorization(max_total, intent_hash).map_err(|e| {
77 PyValueError::new_err(format!("Invalid agent authorization policy: {}", e))
78 })?;
79 Ok(Self { inner })
80 }
81
82 #[getter]
84 pub fn policy_id(&self) -> &'static str {
85 self.inner.policy_id()
86 }
87
88 #[getter]
90 pub fn limit(&self) -> u64 {
91 self.inner.limit()
92 }
93
94 fn __repr__(&self) -> String {
95 format!("Policy({}, limit={})", self.policy_id(), self.limit())
96 }
97}
98
/// Public inputs for a compliance proof, exposed to Python as a mutable object.
///
/// Identifier fields are stored as strings; `event_id`, `tenant_id` and
/// `store_id` are only parsed as UUIDs when the inputs are converted for
/// proving or verification. `policy_params` is stored as JSON text and is
/// exposed through a dedicated getter/setter instead of a field attribute.
#[pyclass]
#[derive(Clone)]
pub struct CompliancePublicInputs {
    /// Event identifier (UUID string, parsed on conversion).
    #[pyo3(get, set)]
    pub event_id: String,
    /// Tenant identifier (UUID string, parsed on conversion).
    #[pyo3(get, set)]
    pub tenant_id: String,
    /// Store identifier (UUID string, parsed on conversion).
    #[pyo3(get, set)]
    pub store_id: String,
    /// Event sequence number.
    #[pyo3(get, set)]
    pub sequence_number: u64,
    /// Payload kind code.
    #[pyo3(get, set)]
    pub payload_kind: u32,
    /// Plaintext-payload hash string, passed through unchanged on conversion.
    #[pyo3(get, set)]
    pub payload_plain_hash: String,
    /// Ciphertext-payload hash string, passed through unchanged on conversion.
    #[pyo3(get, set)]
    pub payload_cipher_hash: String,
    /// Event signing hash string, passed through unchanged on conversion.
    #[pyo3(get, set)]
    pub event_signing_hash: String,
    /// Policy identifier string.
    #[pyo3(get, set)]
    pub policy_id: String,
    // `policy_params` as produced by Python's `json.dumps`; not exposed directly.
    policy_params_json: String,
    /// Policy hash string, passed through unchanged on conversion.
    #[pyo3(get, set)]
    pub policy_hash: String,
    /// Optional witness commitment string.
    #[pyo3(get, set)]
    pub witness_commitment: Option<String>,
    /// Optional authorization receipt hash string.
    #[pyo3(get, set)]
    pub authorization_receipt_hash: Option<String>,
    /// Optional payload amount binding hash string.
    #[pyo3(get, set)]
    pub amount_binding_hash: Option<String>,
}
145
146#[pymethods]
147impl CompliancePublicInputs {
148 #[new]
166 #[allow(clippy::too_many_arguments)]
167 #[pyo3(signature = (event_id, tenant_id, store_id, sequence_number, payload_kind, payload_plain_hash, payload_cipher_hash, event_signing_hash, policy_id, policy_params, policy_hash, witness_commitment=None, authorization_receipt_hash=None, amount_binding_hash=None))]
168 pub fn new(
169 event_id: String,
170 tenant_id: String,
171 store_id: String,
172 sequence_number: u64,
173 payload_kind: u32,
174 payload_plain_hash: String,
175 payload_cipher_hash: String,
176 event_signing_hash: String,
177 policy_id: String,
178 policy_params: &Bound<'_, PyDict>,
179 policy_hash: String,
180 witness_commitment: Option<String>,
181 authorization_receipt_hash: Option<String>,
182 amount_binding_hash: Option<String>,
183 ) -> PyResult<Self> {
184 let policy_params_json = Python::with_gil(|py| {
186 let json = py.import("json")?;
187 let dumps = json.getattr("dumps")?;
188 dumps.call1((policy_params,))?.extract::<String>()
189 })?;
190
191 Ok(Self {
192 event_id,
193 tenant_id,
194 store_id,
195 sequence_number,
196 payload_kind,
197 payload_plain_hash,
198 payload_cipher_hash,
199 event_signing_hash,
200 policy_id,
201 policy_params_json,
202 policy_hash,
203 witness_commitment,
204 authorization_receipt_hash,
205 amount_binding_hash,
206 })
207 }
208
209 #[getter]
211 pub fn policy_params(&self, py: Python<'_>) -> PyResult<PyObject> {
212 let json = py.import("json")?;
213 let loads = json.getattr("loads")?;
214 let result = loads.call1((&self.policy_params_json,))?;
215 Ok(result.into())
216 }
217
218 #[setter]
220 pub fn set_policy_params(&mut self, value: &Bound<'_, PyDict>) -> PyResult<()> {
221 let policy_params_json = Python::with_gil(|py| {
222 let json = py.import("json")?;
223 let dumps = json.getattr("dumps")?;
224 dumps.call1((value,))?.extract::<String>()
225 })?;
226 self.policy_params_json = policy_params_json;
227 Ok(())
228 }
229
230 fn __repr__(&self) -> String {
231 format!(
232 "CompliancePublicInputs(event_id='{}', policy_id='{}')",
233 self.event_id, self.policy_id
234 )
235 }
236}
237
238fn verifier_error_to_py(err: VerifierError) -> PyErr {
239 let message = format!("Verification error: {}", err);
240 match err {
241 VerifierError::PublicInputMismatch(_)
242 | VerifierError::InvalidHexFormat { .. }
243 | VerifierError::DeserializationError(_)
244 | VerifierError::InvalidPolicyHash { .. }
245 | VerifierError::PolicyMismatch { .. }
246 | VerifierError::LimitMismatch { .. }
247 | VerifierError::PayloadAmountBindingRequired(_)
248 | VerifierError::WitnessCommitmentMismatch
249 | VerifierError::ProofTooLarge { .. }
250 | VerifierError::UnsupportedProofVersion { .. } => PyValueError::new_err(message),
251 VerifierError::InvalidProofStructure(_)
252 | VerifierError::FriVerificationFailed(_)
253 | VerifierError::ConstraintCheckFailed(_)
254 | VerifierError::VerificationFailed(_) => PyRuntimeError::new_err(message),
255 }
256}
257
258impl CompliancePublicInputs {
259 fn to_rust(&self) -> PyResult<RustCompliancePublicInputs> {
260 let event_id = Uuid::parse_str(&self.event_id)
261 .map_err(|e| PyValueError::new_err(format!("Invalid event_id UUID: {}", e)))?;
262 let tenant_id = Uuid::parse_str(&self.tenant_id)
263 .map_err(|e| PyValueError::new_err(format!("Invalid tenant_id UUID: {}", e)))?;
264 let store_id = Uuid::parse_str(&self.store_id)
265 .map_err(|e| PyValueError::new_err(format!("Invalid store_id UUID: {}", e)))?;
266
267 let policy_params: serde_json::Value = serde_json::from_str(&self.policy_params_json)
268 .map_err(|e| PyValueError::new_err(format!("Invalid policy_params JSON: {}", e)))?;
269
270 Ok(RustCompliancePublicInputs {
271 event_id,
272 tenant_id,
273 store_id,
274 sequence_number: self.sequence_number,
275 payload_kind: self.payload_kind,
276 payload_plain_hash: self.payload_plain_hash.clone(),
277 payload_cipher_hash: self.payload_cipher_hash.clone(),
278 event_signing_hash: self.event_signing_hash.clone(),
279 policy_id: self.policy_id.clone(),
280 policy_params: PolicyParams(policy_params),
281 policy_hash: self.policy_hash.clone(),
282 witness_commitment: self.witness_commitment.clone(),
283 authorization_receipt_hash: self.authorization_receipt_hash.clone(),
284 amount_binding_hash: self.amount_binding_hash.clone(),
285 })
286 }
287}
288
289fn bind_public_inputs_to_commitment(
290 public_inputs: &CompliancePublicInputs,
291 witness_commitment: &[u64; 4],
292) -> PyResult<RustCompliancePublicInputs> {
293 let rust_inputs = public_inputs.to_rust()?;
294 rust_inputs
295 .bind_witness_commitment(witness_commitment)
296 .map_err(|e| PyValueError::new_err(format!("Failed to bind witness commitment: {}", e)))
297}
298
299fn receipt_from_py_dict(value: &Bound<'_, PyDict>) -> PyResult<CommerceAuthorizationReceipt> {
300 let receipt_json = Python::with_gil(|py| {
301 let json = py.import("json")?;
302 let dumps = json.getattr("dumps")?;
303 dumps.call1((value,))?.extract::<String>()
304 })?;
305 serde_json::from_str(&receipt_json)
306 .map_err(|e| PyValueError::new_err(format!("Invalid authorization receipt dict: {}", e)))
307}
308
309fn payload_amount_binding_from_py_dict(
310 value: &Bound<'_, PyDict>,
311) -> PyResult<PayloadAmountBinding> {
312 let binding_json = Python::with_gil(|py| {
313 let json = py.import("json")?;
314 let dumps = json.getattr("dumps")?;
315 dumps.call1((value,))?.extract::<String>()
316 })?;
317 serde_json::from_str(&binding_json)
318 .map_err(|e| PyValueError::new_err(format!("Invalid payload amount binding dict: {}", e)))
319}
320
/// A generated compliance proof together with prover metadata.
///
/// Built by `prove`; read-only from Python.
#[pyclass]
pub struct ComplianceProof {
    // Raw proof bytes, exposed to Python as `bytes` via the `proof_bytes` getter.
    proof_bytes_vec: Vec<u8>,
    /// Proof hash as returned by the prover.
    #[pyo3(get)]
    pub proof_hash: String,
    /// Proving time in milliseconds, taken from the prover's metadata.
    #[pyo3(get)]
    pub proving_time_ms: u64,
    /// Proof size, taken from the prover's metadata.
    #[pyo3(get)]
    pub proof_size: usize,
    // Witness commitment elements, exposed via the `witness_commitment` getter.
    witness_commitment_vec: Vec<u64>,
    /// Hex-encoded witness commitment as returned by the prover.
    #[pyo3(get)]
    pub witness_commitment_hex: String,
}
341
342#[pymethods]
343impl ComplianceProof {
344 #[getter]
346 pub fn proof_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
347 PyBytes::new(py, &self.proof_bytes_vec)
348 }
349
350 #[getter]
352 pub fn witness_commitment(&self) -> Vec<u64> {
353 self.witness_commitment_vec.clone()
354 }
355
356 fn __repr__(&self) -> String {
357 format!(
358 "ComplianceProof(proof_size={}, proving_time_ms={})",
359 self.proof_size, self.proving_time_ms
360 )
361 }
362}
363
/// Outcome of a verification call, copied from the Rust verifier's result.
///
/// Read-only from Python; truthiness follows `valid` (see `__bool__`).
#[pyclass]
pub struct VerificationResult {
    /// Whether the verifier reported the proof as valid.
    #[pyo3(get)]
    pub valid: bool,
    /// Verification time in milliseconds as reported by the verifier.
    #[pyo3(get)]
    pub verification_time_ms: u64,
    /// Optional error message reported by the verifier.
    #[pyo3(get)]
    pub error: Option<String>,
    /// Policy identifier reported by the verifier.
    #[pyo3(get)]
    pub policy_id: String,
    /// Policy limit reported by the verifier.
    #[pyo3(get)]
    pub policy_limit: u64,
}
383
384#[pymethods]
385impl VerificationResult {
386 fn __repr__(&self) -> String {
387 format!(
388 "VerificationResult(valid={}, policy_id='{}', policy_limit={})",
389 self.valid, self.policy_id, self.policy_limit
390 )
391 }
392
393 fn __bool__(&self) -> bool {
394 self.valid
395 }
396}
397
398#[pyfunction]
422pub fn prove(
423 amount: u64,
424 public_inputs: &CompliancePublicInputs,
425 policy: &Policy,
426) -> PyResult<ComplianceProof> {
427 let rust_inputs = public_inputs.to_rust()?;
429
430 let witness = ComplianceWitness::try_new(amount, rust_inputs)
432 .map_err(|e| PyValueError::new_err(format!("Invalid witness/public inputs: {}", e)))?;
433
434 let prover = ComplianceProver::with_policy(policy.inner.clone());
436 let proof = prover
437 .prove(&witness)
438 .map_err(|e| PyRuntimeError::new_err(format!("Proof generation failed: {}", e)))?;
439
440 Ok(ComplianceProof {
441 proof_bytes_vec: proof.proof_bytes,
442 proof_hash: proof.proof_hash,
443 proving_time_ms: proof.metadata.proving_time_ms,
444 proof_size: proof.metadata.proof_size,
445 witness_commitment_vec: proof.witness_commitment.to_vec(),
446 witness_commitment_hex: proof.witness_commitment_hex.ok_or_else(|| {
447 PyRuntimeError::new_err("Missing witness_commitment_hex in proof".to_string())
448 })?,
449 })
450}
451
452#[pyfunction]
470pub fn verify(
471 proof_bytes: &[u8],
472 public_inputs: &CompliancePublicInputs,
473 witness_commitment: Vec<u64>,
474) -> PyResult<VerificationResult> {
475 if witness_commitment.len() != 4 {
477 return Err(PyValueError::new_err(format!(
478 "Witness commitment must have exactly 4 elements, got {}",
479 witness_commitment.len()
480 )));
481 }
482 let commitment: [u64; 4] = [
483 witness_commitment[0],
484 witness_commitment[1],
485 witness_commitment[2],
486 witness_commitment[3],
487 ];
488 let rust_inputs = bind_public_inputs_to_commitment(public_inputs, &commitment)?;
489
490 let result = verify_compliance_proof_auto_bound(proof_bytes, &rust_inputs);
492
493 match result {
494 Ok(verification) => Ok(VerificationResult {
495 valid: verification.valid,
496 verification_time_ms: verification.verification_time_ms,
497 error: verification.error,
498 policy_id: verification.policy_id,
499 policy_limit: verification.policy_limit,
500 }),
501 Err(e) => Err(verifier_error_to_py(e)),
502 }
503}
504
505#[pyfunction]
507pub fn verify_with_amount_binding(
508 proof_bytes: &[u8],
509 public_inputs: &CompliancePublicInputs,
510 amount_binding: &Bound<'_, PyDict>,
511) -> PyResult<VerificationResult> {
512 let rust_inputs = public_inputs.to_rust()?;
513 let binding = payload_amount_binding_from_py_dict(amount_binding)?;
514
515 let result =
516 verify_compliance_proof_auto_with_amount_binding(proof_bytes, &rust_inputs, &binding);
517
518 match result {
519 Ok(verification) => Ok(VerificationResult {
520 valid: verification.valid,
521 verification_time_ms: verification.verification_time_ms,
522 error: verification.error,
523 policy_id: verification.policy_id,
524 policy_limit: verification.policy_limit,
525 }),
526 Err(e) => Err(verifier_error_to_py(e)),
527 }
528}
529
530#[pyfunction]
532pub fn verify_agent_authorization(
533 proof_bytes: &[u8],
534 public_inputs: &CompliancePublicInputs,
535 witness_commitment: Vec<u64>,
536 receipt: &Bound<'_, PyDict>,
537) -> PyResult<VerificationResult> {
538 if witness_commitment.len() != 4 {
539 return Err(PyValueError::new_err(format!(
540 "Witness commitment must have exactly 4 elements, got {}",
541 witness_commitment.len()
542 )));
543 }
544 let commitment: [u64; 4] = [
545 witness_commitment[0],
546 witness_commitment[1],
547 witness_commitment[2],
548 witness_commitment[3],
549 ];
550 let rust_inputs = bind_public_inputs_to_commitment(public_inputs, &commitment)?;
551 let receipt = receipt_from_py_dict(receipt)?;
552 let binding = rust_inputs
553 .payload_amount_binding(receipt.amount)
554 .map_err(|e| verifier_error_to_py(VerifierError::PublicInputMismatch(format!("{e}"))))?;
555
556 let result = verify_agent_authorization_proof_auto_with_amount_binding(
557 proof_bytes,
558 &rust_inputs,
559 &binding,
560 &receipt,
561 );
562
563 match result {
564 Ok(verification) => Ok(VerificationResult {
565 valid: verification.valid,
566 verification_time_ms: verification.verification_time_ms,
567 error: verification.error,
568 policy_id: verification.policy_id,
569 policy_limit: verification.policy_limit,
570 }),
571 Err(e) => Err(verifier_error_to_py(e)),
572 }
573}
574
575#[pyfunction]
578pub fn verify_agent_authorization_with_amount_binding(
579 proof_bytes: &[u8],
580 public_inputs: &CompliancePublicInputs,
581 amount_binding: &Bound<'_, PyDict>,
582 receipt: &Bound<'_, PyDict>,
583) -> PyResult<VerificationResult> {
584 let rust_inputs = public_inputs.to_rust()?;
585 let binding = payload_amount_binding_from_py_dict(amount_binding)?;
586 let receipt = receipt_from_py_dict(receipt)?;
587
588 let result = verify_agent_authorization_proof_auto_with_amount_binding(
589 proof_bytes,
590 &rust_inputs,
591 &binding,
592 &receipt,
593 );
594
595 match result {
596 Ok(verification) => Ok(VerificationResult {
597 valid: verification.valid,
598 verification_time_ms: verification.verification_time_ms,
599 error: verification.error,
600 policy_id: verification.policy_id,
601 policy_limit: verification.policy_limit,
602 }),
603 Err(e) => Err(verifier_error_to_py(e)),
604 }
605}
606
607#[pyfunction]
622pub fn compute_policy_hash(policy_id: &str, policy_params: &Bound<'_, PyDict>) -> PyResult<String> {
623 let params_json = Python::with_gil(|py| {
625 let json = py.import("json")?;
626 let dumps = json.getattr("dumps")?;
627 dumps.call1((policy_params,))?.extract::<String>()
628 })?;
629
630 let params: serde_json::Value = serde_json::from_str(¶ms_json)
631 .map_err(|e| PyValueError::new_err(format!("Invalid policy_params JSON: {}", e)))?;
632
633 let hash = ves_stark_primitives::compute_policy_hash(policy_id, &PolicyParams(params))
634 .map_err(|e| PyRuntimeError::new_err(format!("Failed to compute policy hash: {}", e)))?;
635
636 Ok(hash.to_hex())
637}
638
639#[pyfunction]
641pub fn create_payload_amount_binding(
642 py: Python<'_>,
643 public_inputs: &CompliancePublicInputs,
644 amount: u64,
645) -> PyResult<PyObject> {
646 let rust_inputs = public_inputs.to_rust()?;
647
648 let binding = rust_inputs.payload_amount_binding(amount).map_err(|e| {
649 PyValueError::new_err(format!("Invalid payload amount binding inputs: {}", e))
650 })?;
651
652 let binding_json = serde_json::to_string(&binding).map_err(|e| {
653 PyRuntimeError::new_err(format!("Failed to serialize payload amount binding: {}", e))
654 })?;
655 let json = py.import("json")?;
656 let loads = json.getattr("loads")?;
657 Ok(loads.call1((binding_json,))?.into())
658}
659
/// Python module `ves_stark`.
///
/// Registers the policy, public-input, proof and verification-result classes,
/// plus the proving, verification and hashing/binding helper functions.
#[pymodule]
fn ves_stark(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Classes.
    m.add_class::<Policy>()?;
    m.add_class::<CompliancePublicInputs>()?;
    m.add_class::<ComplianceProof>()?;
    m.add_class::<VerificationResult>()?;
    // Proving and verification entry points.
    m.add_function(wrap_pyfunction!(prove, m)?)?;
    m.add_function(wrap_pyfunction!(verify, m)?)?;
    m.add_function(wrap_pyfunction!(verify_with_amount_binding, m)?)?;
    m.add_function(wrap_pyfunction!(verify_agent_authorization, m)?)?;
    m.add_function(wrap_pyfunction!(
        verify_agent_authorization_with_amount_binding,
        m
    )?)?;
    // Standalone helpers for computing hashes and amount bindings.
    m.add_function(wrap_pyfunction!(compute_policy_hash, m)?)?;
    m.add_function(wrap_pyfunction!(create_payload_amount_binding, m)?)?;
    Ok(())
}