Skip to main content

vcl_protocol/
packet.rs

1use sha2::{Sha256, Digest};
2use ed25519_dalek::{SigningKey, VerifyingKey, Signature, Signer, Verifier};
3use serde::{Serialize, Deserialize};
4use crate::error::VCLError;
5
6/// Type of a VCL packet.
7/// Determines how the connection layer routes the packet after decryption.
8#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
9pub enum PacketType {
10    Data,
11    Ping,
12    Pong,
13    KeyRotation,
14}
15
/// A single VCL protocol packet.
///
/// Field order matters: `serialize`/`deserialize` use bincode, which encodes
/// struct fields positionally in declaration order, so reordering fields changes
/// the wire format.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct VCLPacket {
    /// Protocol version; `new` / `new_typed` set this to 2.
    pub version: u8,
    /// How the connection layer should route this packet after decryption.
    pub packet_type: PacketType,
    /// Sequence number supplied by the caller at construction.
    pub sequence: u64,
    /// Hash linking this packet to its predecessor (presumably `compute_hash()`
    /// of the previous packet — confirm with the connection layer); compared
    /// byte-for-byte against the receiver's expectation by `validate_chain`.
    pub prev_hash: Vec<u8>,
    /// 24-byte nonce; presumably the payload-encryption nonce — confirm.
    pub nonce: [u8; 24],
    /// Packet body (presumably ciphertext, since routing happens after
    /// decryption — confirm).
    pub payload: Vec<u8>,
    /// Ed25519 signature bytes written by `sign` (64 bytes once signed);
    /// empty after `new` / `new_typed`.
    pub signature: Vec<u8>,
}
26
27impl VCLPacket {
28    /// Create a Data packet (default, used by send())
29    pub fn new(sequence: u64, prev_hash: Vec<u8>, payload: Vec<u8>, nonce: [u8; 24]) -> Self {
30        Self::new_typed(sequence, prev_hash, payload, nonce, PacketType::Data)
31    }
32
33    /// Create a packet of a specific type (used internally)
34    pub fn new_typed(
35        sequence: u64,
36        prev_hash: Vec<u8>,
37        payload: Vec<u8>,
38        nonce: [u8; 24],
39        packet_type: PacketType,
40    ) -> Self {
41        VCLPacket {
42            version: 2,
43            packet_type,
44            sequence,
45            prev_hash,
46            nonce,
47            payload,
48            signature: Vec::new(),
49        }
50    }
51
52    pub fn compute_hash(&self) -> Vec<u8> {
53        let mut hasher = Sha256::new();
54        hasher.update(self.version.to_be_bytes());
55        hasher.update(self.sequence.to_be_bytes());
56        hasher.update(&self.prev_hash);
57        hasher.update(&self.nonce);
58        hasher.update(&self.payload);
59        hasher.finalize().to_vec()
60    }
61
62    pub fn sign(&mut self, private_key: &[u8]) -> Result<(), VCLError> {
63        let key_bytes: &[u8; 32] = private_key
64            .try_into()
65            .map_err(|_| VCLError::InvalidKey("Private key must be 32 bytes".to_string()))?;
66        let signing_key = SigningKey::from_bytes(key_bytes);
67        let hash = self.compute_hash();
68        let signature: Signature = signing_key.sign(&hash);
69        self.signature = signature.to_bytes().to_vec();
70        Ok(())
71    }
72
73    pub fn verify(&self, public_key: &[u8]) -> Result<bool, VCLError> {
74        if self.signature.len() != 64 {
75            return Ok(false);
76        }
77        let key_bytes: &[u8; 32] = public_key
78            .try_into()
79            .map_err(|_| VCLError::InvalidKey("Public key must be 32 bytes".to_string()))?;
80        let verifying_key = VerifyingKey::from_bytes(key_bytes)
81            .map_err(|e| VCLError::InvalidKey(format!("Invalid public key: {}", e)))?;
82        let sig_bytes: &[u8; 64] = self.signature
83            .as_slice()
84            .try_into()
85            .map_err(|_| VCLError::InvalidKey("Signature must be 64 bytes".to_string()))?;
86        let signature = Signature::from_bytes(sig_bytes);
87        let hash = self.compute_hash();
88        Ok(verifying_key.verify(&hash, &signature).is_ok())
89    }
90
91    pub fn validate_chain(&self, expected_prev_hash: &[u8]) -> bool {
92        self.prev_hash == expected_prev_hash
93    }
94
95    pub fn serialize(&self) -> Vec<u8> {
96        bincode::serialize(self).unwrap()
97    }
98
99    pub fn deserialize(data: &[u8]) -> Result<Self, VCLError> {
100        bincode::deserialize(data).map_err(|e| VCLError::SerializationError(e.to_string()))
101    }
102}
103
#[cfg(test)]
mod tests {
    use crate::crypto::KeyPair;
    use super::*;

    /// Generates a key pair via `KeyPair::generate` for use in a single test.
    fn test_keypair() -> KeyPair {
        KeyPair::generate()
    }

    #[test]
    fn test_packet_new() {
        let packet = VCLPacket::new(1, vec![0; 32], b"test".to_vec(), [0; 24]);
        assert_eq!(packet.version, 2);
        assert_eq!(packet.sequence, 1);
        assert_eq!(packet.payload, b"test");
        assert_eq!(packet.packet_type, PacketType::Data);
        assert!(packet.signature.is_empty());
    }

    #[test]
    fn test_compute_hash() {
        let p1 = VCLPacket::new(1, vec![0; 32], b"A".to_vec(), [0; 24]);
        let p2 = VCLPacket::new(1, vec![0; 32], b"B".to_vec(), [0; 24]);
        assert_ne!(p1.compute_hash(), p2.compute_hash());
    }

    #[test]
    fn test_sign_verify() {
        let kp = test_keypair();
        let mut packet = VCLPacket::new(1, vec![0; 32], b"test".to_vec(), [0; 24]);
        packet.sign(&kp.private_key).unwrap();
        assert_eq!(packet.signature.len(), 64);
        assert!(packet.verify(&kp.public_key).unwrap());
    }

    #[test]
    fn test_verify_wrong_key_fails() {
        let kp1 = test_keypair();
        let kp2 = test_keypair();
        let mut packet = VCLPacket::new(1, vec![0; 32], b"test".to_vec(), [0; 24]);
        packet.sign(&kp1.private_key).unwrap();
        assert!(!packet.verify(&kp2.public_key).unwrap());
    }

    #[test]
    fn test_verify_unsigned_packet_is_false() {
        let kp = test_keypair();
        let packet = VCLPacket::new(1, vec![0; 32], b"test".to_vec(), [0; 24]);
        // No signature yet: reported as "not valid", not as an error.
        assert!(!packet.verify(&kp.public_key).unwrap());
    }

    #[test]
    fn test_verify_tampered_payload_fails() {
        let kp = test_keypair();
        let mut packet = VCLPacket::new(1, vec![0; 32], b"test".to_vec(), [0; 24]);
        packet.sign(&kp.private_key).unwrap();
        packet.payload.push(b'!');
        assert!(!packet.verify(&kp.public_key).unwrap());
    }

    #[test]
    fn test_verify_tampered_sequence_fails() {
        let kp = test_keypair();
        let mut packet = VCLPacket::new(1, vec![0; 32], b"test".to_vec(), [0; 24]);
        packet.sign(&kp.private_key).unwrap();
        packet.sequence += 1;
        assert!(!packet.verify(&kp.public_key).unwrap());
    }

    #[test]
    fn test_sign_rejects_wrong_key_length() {
        let mut packet = VCLPacket::new(1, vec![0; 32], b"test".to_vec(), [0; 24]);
        assert!(packet.sign(&[0u8; 31]).is_err());
        // A failed sign must not leave a partial signature behind.
        assert!(packet.signature.is_empty());
    }

    #[test]
    fn test_verify_rejects_wrong_public_key_length() {
        let kp = test_keypair();
        let mut packet = VCLPacket::new(1, vec![0; 32], b"test".to_vec(), [0; 24]);
        packet.sign(&kp.private_key).unwrap();
        assert!(packet.verify(&[0u8; 31]).is_err());
    }

    #[test]
    fn test_validate_chain() {
        let prev = vec![1, 2, 3];
        let packet = VCLPacket::new(1, prev.clone(), b"test".to_vec(), [0; 24]);
        assert!(packet.validate_chain(&prev));
        assert!(!packet.validate_chain(&[4, 5, 6]));
    }

    #[test]
    fn test_serialize_deserialize() {
        let original = VCLPacket::new(42, vec![9; 32], b"payload".to_vec(), [7; 24]);
        let bytes = original.serialize();
        let restored = VCLPacket::deserialize(&bytes).unwrap();
        assert_eq!(original.sequence, restored.sequence);
        assert_eq!(original.payload, restored.payload);
        assert_eq!(original.nonce, restored.nonce);
        assert_eq!(restored.packet_type, PacketType::Data);
    }

    #[test]
    fn test_signature_survives_serialization() {
        let kp = test_keypair();
        let mut original = VCLPacket::new_typed(
            7,
            vec![3; 32],
            b"rotate".to_vec(),
            [5; 24],
            PacketType::KeyRotation,
        );
        original.sign(&kp.private_key).unwrap();
        let restored = VCLPacket::deserialize(&original.serialize()).unwrap();
        assert_eq!(restored.packet_type, PacketType::KeyRotation);
        assert_eq!(restored.prev_hash, original.prev_hash);
        assert_eq!(restored.signature, original.signature);
        assert!(restored.verify(&kp.public_key).unwrap());
    }

    #[test]
    fn test_deserialize_rejects_truncated_input() {
        assert!(VCLPacket::deserialize(&[]).is_err());
        let bytes = VCLPacket::new(1, vec![0; 32], b"test".to_vec(), [0; 24]).serialize();
        assert!(VCLPacket::deserialize(&bytes[..bytes.len() - 1]).is_err());
    }

    #[test]
    fn test_packet_types() {
        let ping = VCLPacket::new_typed(0, vec![0; 32], vec![], [0; 24], PacketType::Ping);
        let pong = VCLPacket::new_typed(0, vec![0; 32], vec![], [0; 24], PacketType::Pong);
        let rot  = VCLPacket::new_typed(0, vec![0; 32], vec![], [0; 24], PacketType::KeyRotation);
        assert_eq!(ping.packet_type, PacketType::Ping);
        assert_eq!(pong.packet_type, PacketType::Pong);
        assert_eq!(rot.packet_type,  PacketType::KeyRotation);
    }
}