swink_agent_eval/
audit.rs1use serde::{Deserialize, Serialize};
4use sha2::{Digest, Sha256};
5
6use crate::types::Invocation;
7
/// An [`Invocation`] paired with SHA-256 digests that make later edits to its turns detectable.
///
/// Build one with [`AuditedInvocation::from_invocation`] and check it with
/// [`AuditedInvocation::verify`]. Only `invocation.turns` is covered by the digests; the other
/// `Invocation` fields are carried along but not hashed.
///
/// The digests are plain (unkeyed) SHA-256. They reveal accidental corruption or edits that did
/// not also rewrite the hashes, but they are not a signature: anyone able to modify the record
/// can recompute them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditedInvocation {
    /// The recorded invocation whose turns are protected by the digests below.
    pub invocation: Invocation,
    /// Lowercase hex SHA-256 of each turn's `serde_json` serialization, in turn order
    /// (one entry per element of `invocation.turns`).
    pub turn_hashes: Vec<String>,
    /// Lowercase hex SHA-256 of all `turn_hashes` concatenated in order with no separator.
    pub chain_hash: String,
}
25
26impl AuditedInvocation {
27 #[must_use]
29 pub fn from_invocation(invocation: Invocation) -> Self {
30 let turn_hashes: Vec<String> = invocation
31 .turns
32 .iter()
33 .map(|turn| {
34 let json = serde_json::to_string(turn).expect("TurnRecord is serializable");
35 hex_sha256(json.as_bytes())
36 })
37 .collect();
38
39 let chain_hash = compute_chain_hash(&turn_hashes);
40
41 Self {
42 invocation,
43 turn_hashes,
44 chain_hash,
45 }
46 }
47
48 #[must_use]
50 pub fn verify(&self) -> bool {
51 if self.turn_hashes.len() != self.invocation.turns.len() {
52 return false;
53 }
54
55 for (turn, stored_hash) in self.invocation.turns.iter().zip(&self.turn_hashes) {
56 let json = serde_json::to_string(turn).expect("TurnRecord is serializable");
57 let computed = hex_sha256(json.as_bytes());
58 if &computed != stored_hash {
59 return false;
60 }
61 }
62
63 let computed_chain = compute_chain_hash(&self.turn_hashes);
64 computed_chain == self.chain_hash
65 }
66}
67
68fn hex_sha256(data: &[u8]) -> String {
69 let mut hasher = Sha256::new();
70 hasher.update(data);
71 let hash = hasher.finalize();
72 let mut out = String::with_capacity(hash.len() * 2);
73 for byte in hash {
74 use std::fmt::Write as _;
75 let _ = write!(&mut out, "{byte:02x}");
76 }
77 out
78}
79
80fn compute_chain_hash(turn_hashes: &[String]) -> String {
81 let concatenated: String = turn_hashes.concat();
82 hex_sha256(concatenated.as_bytes())
83}
84
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use swink_agent::{AssistantMessage, Cost, ModelSpec, StopReason, Usage};

    use super::*;
    use crate::types::TurnRecord;

    /// Builds an `Invocation` with `num_turns` empty assistant turns of 10 ms each.
    fn minimal_invocation(num_turns: usize) -> Invocation {
        let turns = (0..num_turns)
            .map(|i| TurnRecord {
                turn_index: i,
                assistant_message: AssistantMessage {
                    content: vec![],
                    provider: "test".to_string(),
                    model_id: "test-model".to_string(),
                    usage: Usage::default(),
                    cost: Cost::default(),
                    stop_reason: StopReason::Stop,
                    error_message: None,
                    error_kind: None,
                    timestamp: 0,
                    cache_hint: None,
                },
                tool_calls: vec![],
                tool_results: vec![],
                duration: Duration::from_millis(10),
            })
            .collect();

        Invocation {
            turns,
            total_usage: Usage::default(),
            total_cost: Cost::default(),
            total_duration: Duration::from_millis(10 * num_turns as u64),
            final_response: None,
            stop_reason: StopReason::Stop,
            model: ModelSpec::new("test", "test-model"),
        }
    }

    #[test]
    fn roundtrip_verify() {
        let inv = minimal_invocation(3);
        let audited = AuditedInvocation::from_invocation(inv);

        assert!(audited.verify());
        assert_eq!(audited.turn_hashes.len(), 3);
        for hash in &audited.turn_hashes {
            assert_eq!(hash.len(), 64);
        }
        assert_eq!(audited.chain_hash.len(), 64);
    }

    #[test]
    fn tampered_turn_hash_fails_verify() {
        let inv = minimal_invocation(2);
        let mut audited = AuditedInvocation::from_invocation(inv);

        audited.turn_hashes[0] = "0".repeat(64);

        assert!(!audited.verify());
    }

    // The point of the audit: editing the recorded turn itself must be detected.
    #[test]
    fn tampered_turn_content_fails_verify() {
        let mut audited = AuditedInvocation::from_invocation(minimal_invocation(2));

        audited.invocation.turns[1].turn_index = 42;

        assert!(!audited.verify());
    }

    #[test]
    fn tampered_chain_hash_fails_verify() {
        let mut audited = AuditedInvocation::from_invocation(minimal_invocation(2));

        audited.chain_hash = "0".repeat(64);

        assert!(!audited.verify());
    }

    #[test]
    fn dropped_turn_fails_verify() {
        let mut audited = AuditedInvocation::from_invocation(minimal_invocation(2));

        audited.invocation.turns.pop();

        assert!(!audited.verify());
    }

    // An audit is typically persisted and checked later; it must still verify after a
    // serialize/deserialize cycle.
    #[test]
    fn verify_survives_json_roundtrip() {
        let audited = AuditedInvocation::from_invocation(minimal_invocation(3));

        let json = serde_json::to_string(&audited).expect("audit serializes");
        let restored: AuditedInvocation = serde_json::from_str(&json).expect("audit deserializes");

        assert!(restored.verify());
        assert_eq!(restored.turn_hashes, audited.turn_hashes);
        assert_eq!(restored.chain_hash, audited.chain_hash);
    }

    #[test]
    fn empty_invocation() {
        let inv = minimal_invocation(0);
        let audited = AuditedInvocation::from_invocation(inv);

        assert!(audited.verify());
        assert!(audited.turn_hashes.is_empty());
        assert_eq!(audited.chain_hash.len(), 64);
    }
}