Skip to main content

thru_base/
tn_state_proof.rs

1//! State proof structures and utilities
2//!
3//! Rust equivalent of the C tn_state_proof.h structures
4
5use crate::txn_lib::{
6    TN_STATE_PROOF_TYPE_CREATION, TN_STATE_PROOF_TYPE_EXISTING, TN_STATE_PROOF_TYPE_UPDATING,
7    TnHash, TnPubkey,
8};
9
/// Upper bound on how many keys a single state proof may cover.
pub const TN_STATE_PROOF_KEYS_MAX: usize = 256;
12
13/// State proof type enumeration
14#[derive(Debug, Clone, Copy, PartialEq, Eq)]
15pub enum StateProofType {
16    Existing = 0,
17    Updating = 1,
18    Creation = 2,
19}
20
21impl StateProofType {
22    /// Convert from u64 value
23    pub fn from_u64(value: u64) -> Option<Self> {
24        match value {
25            TN_STATE_PROOF_TYPE_EXISTING => Some(Self::Existing),
26            TN_STATE_PROOF_TYPE_UPDATING => Some(Self::Updating),
27            TN_STATE_PROOF_TYPE_CREATION => Some(Self::Creation),
28            _ => None,
29        }
30    }
31
32    /// Convert to u64 value
33    pub fn to_u64(self) -> u64 {
34        match self {
35            Self::Existing => TN_STATE_PROOF_TYPE_EXISTING,
36            Self::Updating => TN_STATE_PROOF_TYPE_UPDATING,
37            Self::Creation => TN_STATE_PROOF_TYPE_CREATION,
38        }
39    }
40}
41
42/// State proof header
43#[derive(Debug, Clone)]
44pub struct StateProofHeader {
45    pub proof_type: StateProofType,
46    pub slot: u64,
47    pub path_bitset: TnHash,
48}
49
50impl StateProofHeader {
51    /// Create a new state proof header
52    pub fn new(proof_type: StateProofType, slot: u64, path_bitset: TnHash) -> Self {
53        Self {
54            proof_type,
55            slot,
56            path_bitset,
57        }
58    }
59
60    /// Encode type and slot into type_slot field
61    pub fn encode_type_slot(&self) -> u64 {
62        self.slot | (self.proof_type.to_u64() << 62)
63    }
64
65    /// Decode type_slot field into type and slot
66    pub fn decode_type_slot(type_slot: u64) -> (StateProofType, u64) {
67        let proof_type =
68            StateProofType::from_u64((type_slot >> 62) & 0x3).unwrap_or(StateProofType::Existing);
69        let slot = type_slot & 0x3FFFFFFFFFFFFFFF; // Extract low 62 bits
70        (proof_type, slot)
71    }
72
73    /// Serialize header to bytes
74    pub fn to_wire(&self) -> Vec<u8> {
75        let mut result = Vec::with_capacity(40);
76        result.extend_from_slice(&self.encode_type_slot().to_le_bytes());
77        result.extend_from_slice(&self.path_bitset);
78        result
79    }
80
81    /// Deserialize header from bytes
82    pub fn from_wire(bytes: &[u8]) -> Option<Self> {
83        if bytes.len() < 40 {
84            return None;
85        }
86
87        let type_slot = u64::from_le_bytes([
88            bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
89        ]);
90
91        let (proof_type, slot) = Self::decode_type_slot(type_slot);
92
93        let mut path_bitset = [0u8; 32];
94        path_bitset.copy_from_slice(&bytes[8..40]);
95
96        Some(Self {
97            proof_type,
98            slot,
99            path_bitset,
100        })
101    }
102}
103
104/// State proof body variants
105#[derive(Debug, Clone)]
106pub enum StateProofBody {
107    /// For existing entries - just sibling hashes
108    Existing { sibling_hashes: Vec<TnHash> },
109    /// For updating entries - existing leaf hash + sibling hashes
110    Updating {
111        existing_leaf_hash: TnHash,
112        sibling_hashes: Vec<TnHash>,
113    },
114    /// For creation entries - existing leaf pubkey and hash + sibling hashes
115    Creation {
116        existing_leaf_pubkey: TnPubkey,
117        existing_leaf_hash: TnHash,
118        sibling_hashes: Vec<TnHash>,
119    },
120}
121
122impl StateProofBody {
123    /// Get the number of sibling hashes
124    pub fn sibling_hash_count(&self) -> usize {
125        match self {
126            StateProofBody::Existing { sibling_hashes } => sibling_hashes.len(),
127            StateProofBody::Updating { sibling_hashes, .. } => sibling_hashes.len(),
128            StateProofBody::Creation { sibling_hashes, .. } => sibling_hashes.len(),
129        }
130    }
131
132    /// Calculate the number of hashes this body contains (for footprint calculation)
133    pub fn hash_count(&self) -> usize {
134        match self {
135            StateProofBody::Existing { sibling_hashes } => sibling_hashes.len(),
136            StateProofBody::Updating { sibling_hashes, .. } => 1 + sibling_hashes.len(),
137            StateProofBody::Creation { sibling_hashes, .. } => 2 + sibling_hashes.len(),
138        }
139    }
140
141    /// Serialize body to bytes
142    pub fn to_wire(&self) -> Vec<u8> {
143        let mut result = Vec::new();
144
145        match self {
146            StateProofBody::Existing { sibling_hashes } => {
147                for hash in sibling_hashes {
148                    result.extend_from_slice(hash);
149                }
150            }
151            StateProofBody::Updating {
152                existing_leaf_hash,
153                sibling_hashes,
154            } => {
155                result.extend_from_slice(existing_leaf_hash);
156                for hash in sibling_hashes {
157                    result.extend_from_slice(hash);
158                }
159            }
160            StateProofBody::Creation {
161                existing_leaf_pubkey,
162                existing_leaf_hash,
163                sibling_hashes,
164            } => {
165                result.extend_from_slice(existing_leaf_pubkey);
166                result.extend_from_slice(existing_leaf_hash);
167                for hash in sibling_hashes {
168                    result.extend_from_slice(hash);
169                }
170            }
171        }
172
173        result
174    }
175
176    /// Deserialize body from bytes given the proof type and sibling hash count
177    pub fn from_wire(
178        bytes: &[u8],
179        proof_type: StateProofType,
180        sibling_hash_count: usize,
181    ) -> Option<Self> {
182        let mut offset = 0;
183
184        match proof_type {
185            StateProofType::Existing => {
186                if bytes.len() < sibling_hash_count * 32 {
187                    return None;
188                }
189
190                let mut sibling_hashes = Vec::with_capacity(sibling_hash_count);
191                for _ in 0..sibling_hash_count {
192                    let mut hash = [0u8; 32];
193                    hash.copy_from_slice(&bytes[offset..offset + 32]);
194                    sibling_hashes.push(hash);
195                    offset += 32;
196                }
197
198                Some(StateProofBody::Existing { sibling_hashes })
199            }
200            StateProofType::Updating => {
201                if bytes.len() < 32 + sibling_hash_count * 32 {
202                    return None;
203                }
204
205                let mut existing_leaf_hash = [0u8; 32];
206                existing_leaf_hash.copy_from_slice(&bytes[offset..offset + 32]);
207                offset += 32;
208
209                let mut sibling_hashes = Vec::with_capacity(sibling_hash_count);
210                for _ in 0..sibling_hash_count {
211                    let mut hash = [0u8; 32];
212                    hash.copy_from_slice(&bytes[offset..offset + 32]);
213                    sibling_hashes.push(hash);
214                    offset += 32;
215                }
216
217                Some(StateProofBody::Updating {
218                    existing_leaf_hash,
219                    sibling_hashes,
220                })
221            }
222            StateProofType::Creation => {
223                if bytes.len() < 64 + sibling_hash_count * 32 {
224                    return None;
225                }
226
227                let mut existing_leaf_pubkey = [0u8; 32];
228                existing_leaf_pubkey.copy_from_slice(&bytes[offset..offset + 32]);
229                offset += 32;
230
231                let mut existing_leaf_hash = [0u8; 32];
232                existing_leaf_hash.copy_from_slice(&bytes[offset..offset + 32]);
233                offset += 32;
234
235                let mut sibling_hashes = Vec::with_capacity(sibling_hash_count);
236                for _ in 0..sibling_hash_count {
237                    let mut hash = [0u8; 32];
238                    hash.copy_from_slice(&bytes[offset..offset + 32]);
239                    sibling_hashes.push(hash);
240                    offset += 32;
241                }
242
243                Some(StateProofBody::Creation {
244                    existing_leaf_pubkey,
245                    existing_leaf_hash,
246                    sibling_hashes,
247                })
248            }
249        }
250    }
251}
252
253/// Complete state proof structure
254#[derive(Debug, Clone)]
255pub struct StateProof {
256    pub header: StateProofHeader,
257    pub body: StateProofBody,
258}
259
260impl StateProof {
261    /// Create a new state proof
262    pub fn new(header: StateProofHeader, body: StateProofBody) -> Self {
263        Self { header, body }
264    }
265
266    /// Create a zeroed creation state proof
267    pub fn zero_creation(slot: u64) -> Self {
268        let header = StateProofHeader::new(StateProofType::Creation, slot, [0u8; 32]);
269        let body = StateProofBody::Creation {
270            existing_leaf_pubkey: [0u8; 32],
271            existing_leaf_hash: [0u8; 32],
272            sibling_hashes: vec![],
273        };
274        Self { header, body }
275    }
276
277    /// Create an existing state proof
278    pub fn existing(slot: u64, path_bitset: TnHash, sibling_hashes: Vec<TnHash>) -> Self {
279        let header = StateProofHeader::new(StateProofType::Existing, slot, path_bitset);
280        let body = StateProofBody::Existing { sibling_hashes };
281        Self { header, body }
282    }
283
284    /// Create an updating state proof
285    pub fn updating(
286        slot: u64,
287        path_bitset: TnHash,
288        existing_leaf_hash: TnHash,
289        sibling_hashes: Vec<TnHash>,
290    ) -> Self {
291        let header = StateProofHeader::new(StateProofType::Updating, slot, path_bitset);
292        let body = StateProofBody::Updating {
293            existing_leaf_hash,
294            sibling_hashes,
295        };
296        Self { header, body }
297    }
298
299    /// Create a creation state proof
300    pub fn creation(
301        slot: u64,
302        path_bitset: TnHash,
303        existing_leaf_pubkey: TnPubkey,
304        existing_leaf_hash: TnHash,
305        sibling_hashes: Vec<TnHash>,
306    ) -> Self {
307        let header = StateProofHeader::new(StateProofType::Creation, slot, path_bitset);
308        let body = StateProofBody::Creation {
309            existing_leaf_pubkey,
310            existing_leaf_hash,
311            sibling_hashes,
312        };
313        Self { header, body }
314    }
315
316    /// Calculate the footprint (size in bytes) when serialized
317    pub fn footprint(&self) -> usize {
318        // Header is always 40 bytes (8 bytes type_slot + 32 bytes path_bitset)
319        let header_size = 40;
320        // Body size is number of hashes * 32 bytes per hash
321        let body_size = self.body.hash_count() * 32;
322        header_size + body_size
323    }
324
325    /// Calculate footprint from proof type and sibling hash count
326    pub fn footprint_from_counts(proof_type: StateProofType, sibling_hash_count: usize) -> usize {
327        let header_size = 40;
328        let body_hash_count = match proof_type {
329            StateProofType::Existing => sibling_hash_count,
330            StateProofType::Updating => 1 + sibling_hash_count,
331            StateProofType::Creation => 2 + sibling_hash_count,
332        };
333        header_size + body_hash_count * 32
334    }
335
336    /// Calculate footprint from path bitset (count set bits for sibling hashes)
337    pub fn footprint_from_header(header: &StateProofHeader) -> usize {
338        let sibling_hash_count = count_set_bits(&header.path_bitset);
339        Self::footprint_from_counts(header.proof_type, sibling_hash_count)
340    }
341
342    /// Serialize to wire format
343    pub fn to_wire(&self) -> Vec<u8> {
344        let mut result = self.header.to_wire();
345        result.extend_from_slice(&self.body.to_wire());
346        result
347    }
348
349    /// Deserialize from wire format
350    pub fn from_wire(bytes: &[u8]) -> Option<Self> {
351        if bytes.len() < 40 {
352            return None;
353        }
354
355        let header = StateProofHeader::from_wire(&bytes[0..40])?;
356        let sibling_hash_count = count_set_bits(&header.path_bitset);
357
358        let body_bytes = &bytes[40..];
359        let body = StateProofBody::from_wire(body_bytes, header.proof_type, sibling_hash_count)?;
360
361        Some(Self { header, body })
362    }
363
364    /// Get the proof type
365    pub fn proof_type(&self) -> StateProofType {
366        self.header.proof_type
367    }
368
369    /// Get the slot
370    pub fn slot(&self) -> u64 {
371        self.header.slot
372    }
373
374    /// Get the path bitset
375    pub fn path_bitset(&self) -> &TnHash {
376        &self.header.path_bitset
377    }
378}
379
380/// Count the number of set bits in a hash (used for calculating sibling hash count)
381fn count_set_bits(hash: &TnHash) -> usize {
382    let mut count = 0;
383    for i in 0..4 {
384        let start = i * 8;
385        let word = u64::from_le_bytes([
386            hash[start],
387            hash[start + 1],
388            hash[start + 2],
389            hash[start + 3],
390            hash[start + 4],
391            hash[start + 5],
392            hash[start + 6],
393            hash[start + 7],
394        ]);
395        count += word.count_ones() as usize;
396    }
397    count
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_state_proof_type_conversion() {
        assert_eq!(
            StateProofType::Existing.to_u64(),
            TN_STATE_PROOF_TYPE_EXISTING
        );
        assert_eq!(
            StateProofType::Updating.to_u64(),
            TN_STATE_PROOF_TYPE_UPDATING
        );
        assert_eq!(
            StateProofType::Creation.to_u64(),
            TN_STATE_PROOF_TYPE_CREATION
        );

        assert_eq!(
            StateProofType::from_u64(TN_STATE_PROOF_TYPE_EXISTING),
            Some(StateProofType::Existing)
        );
        assert_eq!(
            StateProofType::from_u64(TN_STATE_PROOF_TYPE_UPDATING),
            Some(StateProofType::Updating)
        );
        assert_eq!(
            StateProofType::from_u64(TN_STATE_PROOF_TYPE_CREATION),
            Some(StateProofType::Creation)
        );
        assert_eq!(StateProofType::from_u64(999), None);
    }

    #[test]
    fn test_header_type_slot_encoding() {
        let header = StateProofHeader::new(StateProofType::Creation, 0x1FFFFFFFFFFFFFFF, [0u8; 32]);
        let encoded = header.encode_type_slot();
        let (decoded_type, decoded_slot) = StateProofHeader::decode_type_slot(encoded);

        assert_eq!(decoded_type, StateProofType::Creation);
        assert_eq!(decoded_slot, 0x1FFFFFFFFFFFFFFF);
    }

    #[test]
    fn test_header_serialization() {
        let path_bitset = [1u8; 32];
        let header = StateProofHeader::new(StateProofType::Updating, 12345, path_bitset);

        let serialized = header.to_wire();
        assert_eq!(serialized.len(), 40);

        let deserialized = StateProofHeader::from_wire(&serialized).unwrap();
        assert_eq!(deserialized.proof_type, StateProofType::Updating);
        assert_eq!(deserialized.slot, 12345);
        assert_eq!(deserialized.path_bitset, path_bitset);
    }

    #[test]
    fn test_existing_proof_serialization() {
        let sibling_hashes = vec![[1u8; 32], [2u8; 32]];
        // Create a path_bitset with 2 bits set to match the 2 sibling hashes
        let mut path_bitset = [0u8; 32];
        path_bitset[0] = 0b11; // Set first 2 bits
        let proof = StateProof::existing(100, path_bitset, sibling_hashes.clone());

        assert_eq!(proof.footprint(), 40 + 2 * 32); // header + 2 hashes

        let serialized = proof.to_wire();
        let deserialized = StateProof::from_wire(&serialized).unwrap();

        assert_eq!(deserialized.proof_type(), StateProofType::Existing);
        assert_eq!(deserialized.slot(), 100);

        if let StateProofBody::Existing {
            sibling_hashes: deser_hashes,
        } = deserialized.body
        {
            assert_eq!(deser_hashes, sibling_hashes);
        } else {
            panic!("Expected Existing proof body");
        }
    }

    #[test]
    fn test_updating_proof_serialization() {
        let existing_leaf_hash = [6u8; 32];
        let sibling_hashes = vec![[7u8; 32], [8u8; 32], [9u8; 32]];

        // Three bits set, spread over different bytes, to match the 3 sibling hashes
        let mut path_bitset = [0u8; 32];
        path_bitset[0] = 0b1;
        path_bitset[15] = 0b1000;
        path_bitset[31] = 0b1000_0000;

        let proof = StateProof::updating(
            300,
            path_bitset,
            existing_leaf_hash,
            sibling_hashes.clone(),
        );

        assert_eq!(proof.footprint(), 40 + 4 * 32); // header + leaf hash + 3 siblings
        assert_eq!(
            StateProof::footprint_from_header(&proof.header),
            proof.footprint()
        );

        let serialized = proof.to_wire();
        assert_eq!(serialized.len(), proof.footprint());

        let deserialized = StateProof::from_wire(&serialized).unwrap();
        assert_eq!(deserialized.proof_type(), StateProofType::Updating);
        assert_eq!(deserialized.slot(), 300);
        assert_eq!(deserialized.path_bitset(), &path_bitset);

        if let StateProofBody::Updating {
            existing_leaf_hash: deser_hash,
            sibling_hashes: deser_hashes,
        } = deserialized.body
        {
            assert_eq!(deser_hash, existing_leaf_hash);
            assert_eq!(deser_hashes, sibling_hashes);
        } else {
            panic!("Expected Updating proof body");
        }
    }

    #[test]
    fn test_creation_proof_serialization() {
        let existing_leaf_pubkey = [3u8; 32];
        let existing_leaf_hash = [4u8; 32];
        let sibling_hashes = vec![[5u8; 32]];

        // Create a path_bitset with 1 bit set to match the 1 sibling hash
        let mut path_bitset = [0u8; 32];
        path_bitset[0] = 0b1; // Set first bit

        let proof = StateProof::creation(
            200,
            path_bitset,
            existing_leaf_pubkey,
            existing_leaf_hash,
            sibling_hashes.clone(),
        );

        assert_eq!(proof.footprint(), 40 + 3 * 32); // header + pubkey + hash + 1 sibling

        let serialized = proof.to_wire();
        let deserialized = StateProof::from_wire(&serialized).unwrap();

        assert_eq!(deserialized.proof_type(), StateProofType::Creation);
        assert_eq!(deserialized.slot(), 200);

        if let StateProofBody::Creation {
            existing_leaf_pubkey: deser_pubkey,
            existing_leaf_hash: deser_hash,
            sibling_hashes: deser_hashes,
        } = deserialized.body
        {
            assert_eq!(deser_pubkey, existing_leaf_pubkey);
            assert_eq!(deser_hash, existing_leaf_hash);
            assert_eq!(deser_hashes, sibling_hashes);
        } else {
            panic!("Expected Creation proof body");
        }
    }

    #[test]
    fn test_zero_creation_round_trip() {
        let proof = StateProof::zero_creation(7);
        assert_eq!(proof.footprint(), 40 + 2 * 32); // header + pubkey + hash, no siblings

        let deserialized = StateProof::from_wire(&proof.to_wire()).unwrap();
        assert_eq!(deserialized.proof_type(), StateProofType::Creation);
        assert_eq!(deserialized.slot(), 7);
        assert_eq!(deserialized.body.sibling_hash_count(), 0);
        assert_eq!(deserialized.body.hash_count(), 2);
    }

    #[test]
    fn test_truncated_input_is_rejected() {
        assert!(StateProofHeader::from_wire(&[0u8; 39]).is_none());
        assert!(StateProof::from_wire(&[0u8; 39]).is_none());

        let mut path_bitset = [0u8; 32];
        path_bitset[0] = 0b11;
        let proof = StateProof::existing(1, path_bitset, vec![[1u8; 32], [2u8; 32]]);
        let serialized = proof.to_wire();

        // Dropping a single byte of the last sibling hash must make decoding fail
        assert!(StateProof::from_wire(&serialized[..serialized.len() - 1]).is_none());
        assert!(
            StateProofBody::from_wire(&serialized[40..72], StateProofType::Existing, 2).is_none()
        );
    }

    #[test]
    fn test_count_set_bits() {
        let mut hash = [0u8; 32];
        assert_eq!(count_set_bits(&hash), 0);

        hash[0] = 0b10101010; // 4 bits set
        hash[1] = 0b11110000; // 4 bits set
        assert_eq!(count_set_bits(&hash), 8);
    }

    #[test]
    fn test_footprint_calculation() {
        assert_eq!(
            StateProof::footprint_from_counts(StateProofType::Existing, 5),
            40 + 5 * 32
        );
        assert_eq!(
            StateProof::footprint_from_counts(StateProofType::Updating, 3),
            40 + 4 * 32
        );
        assert_eq!(
            StateProof::footprint_from_counts(StateProofType::Creation, 2),
            40 + 4 * 32
        );
    }
}