1use aes_gcm::{Aes256Gcm, KeyInit as AesKeyInit, Nonce as AesNonce, aead::Aead as AesAead};
7use base64::Engine;
8use base64::engine::general_purpose::STANDARD as BASE64;
9use chacha20poly1305::{ChaCha20Poly1305, Nonce};
10use getrandom::getrandom;
11use serde::{Deserialize, Serialize};
12use serde_json::{Value, json};
13use thiserror::Error;
14
/// Number of random nonce bytes placed in front of every AES-256-GCM ciphertext.
const AES256_NONCE_LEN: usize = 12;
/// Number of random nonce bytes placed in front of every ChaCha20-Poly1305 ciphertext.
const CHACHA20_NONCE_LEN: usize = 12;
/// Upper bound, in bytes, on a complete frame including its header (10 MiB).
const MAX_FRAME_BYTES: usize = 10 << 20;
/// Largest payload that still fits in a frame once the 4-byte length prefix,
/// the message-type byte and the encryption-kind byte are accounted for.
const MAX_PAYLOAD_BYTES: usize = MAX_FRAME_BYTES - 6;
19
20#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
22#[repr(u8)]
23pub enum MessageType {
24 Api = 0x00,
26 Event = 0x01,
28}
29
30impl TryFrom<u8> for MessageType {
31 type Error = ProtocolError;
32
33 fn try_from(value: u8) -> Result<Self, Self::Error> {
34 match value {
35 0x00 => Ok(Self::Api),
36 0x01 => Ok(Self::Event),
37 _ => Err(ProtocolError::UnknownMessageType(value)),
38 }
39 }
40}
41
42#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
44#[repr(u8)]
45pub enum EncryptionKind {
46 None = 0x00,
48 ChaCha20 = 0x01,
50 Aes256 = 0x02,
52}
53
54impl TryFrom<u8> for EncryptionKind {
55 type Error = ProtocolError;
56
57 fn try_from(value: u8) -> Result<Self, Self::Error> {
58 match value {
59 0x00 => Ok(Self::None),
60 0x01 => Ok(Self::ChaCha20),
61 0x02 => Ok(Self::Aes256),
62 _ => Err(ProtocolError::UnknownEncryption(value)),
63 }
64 }
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct FileAttachment {
70 pub id: String,
72 pub name: String,
74 pub content_type: String,
76 pub encoding: String,
78 pub data: String,
80 pub size: usize,
82}
83
84impl FileAttachment {
85 pub fn inline_text(
87 id: impl Into<String>,
88 name: impl Into<String>,
89 content_type: impl Into<String>,
90 text: impl AsRef<str>,
91 ) -> Self {
92 Self::inline_bytes(id, name, content_type, text.as_ref().as_bytes().to_vec())
93 }
94
95 pub fn inline_bytes(
97 id: impl Into<String>,
98 name: impl Into<String>,
99 content_type: impl Into<String>,
100 bytes: Vec<u8>,
101 ) -> Self {
102 let size = bytes.len();
103 Self {
104 id: id.into(),
105 name: name.into(),
106 content_type: content_type.into(),
107 encoding: "base64".to_string(),
108 data: BASE64.encode(bytes),
109 size,
110 }
111 }
112
113 pub fn decode_bytes(&self) -> Result<Vec<u8>, ProtocolError> {
115 BASE64
116 .decode(self.data.as_bytes())
117 .map_err(|source| ProtocolError::InvalidAttachmentEncoding(source.to_string()))
118 }
119
120 pub fn param_ref(id: impl Into<String>) -> Value {
122 json!({ "$file": id.into() })
123 }
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct ErrorPayload {
129 pub code: String,
130 pub message: String,
131 pub status: u16,
132 #[serde(skip_serializing_if = "Option::is_none")]
133 pub details: Option<Value>,
134}
135
136#[derive(Debug, Clone, Serialize, Deserialize)]
138#[serde(tag = "kind", rename_all = "snake_case")]
139pub enum PacketBody {
140 ApiRequest {
142 request_id: String,
143 route: String,
144 params: Value,
145 attachments: Vec<FileAttachment>,
146 metadata: Value,
147 },
148 ApiResponse {
150 request_id: String,
151 ok: bool,
152 status: u16,
153 data: Value,
154 #[serde(skip_serializing_if = "Option::is_none")]
155 error: Option<ErrorPayload>,
156 metadata: Value,
157 },
158 EventEmit {
160 event_id: String,
161 name: String,
162 data: Value,
163 attachments: Vec<FileAttachment>,
164 metadata: Value,
165 expect_ack: bool,
166 },
167 EventAck {
169 event_id: String,
170 ok: bool,
171 receipt: Value,
172 #[serde(skip_serializing_if = "Option::is_none")]
173 error: Option<ErrorPayload>,
174 },
175}
176
177impl PacketBody {
178 pub fn message_type(&self) -> MessageType {
179 match self {
180 Self::ApiRequest { .. } | Self::ApiResponse { .. } => MessageType::Api,
181 Self::EventEmit { .. } | Self::EventAck { .. } => MessageType::Event,
182 }
183 }
184}
185
186#[derive(Debug, Clone, Serialize, Deserialize)]
188pub struct PacketEnvelope {
189 pub message_type: MessageType,
191 pub encryption: EncryptionKind,
193 pub body: PacketBody,
195}
196
197impl PacketEnvelope {
198 pub fn new(body: PacketBody) -> Self {
200 Self {
201 message_type: body.message_type(),
202 encryption: EncryptionKind::None,
203 body,
204 }
205 }
206
207 pub fn with_encryption(body: PacketBody, encryption: EncryptionKind) -> Self {
209 Self {
210 message_type: body.message_type(),
211 encryption,
212 body,
213 }
214 }
215}
216
/// Encoder/decoder for frames, holding the optional symmetric keys used for
/// the encrypted payload kinds.
///
/// `Debug` is implemented by hand so that key bytes never end up in logs or
/// panic messages.
#[derive(Clone, Default)]
pub struct FrameCodec {
    /// 256-bit key for AES-256-GCM payloads, if configured.
    aes256_key: Option<[u8; 32]>,
    /// 256-bit key for ChaCha20-Poly1305 payloads, if configured.
    chacha20_key: Option<[u8; 32]>,
}

impl std::fmt::Debug for FrameCodec {
    /// Reports only whether each key is configured, never its contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        fn presence(key: &Option<[u8; 32]>) -> &'static str {
            if key.is_some() {
                "<redacted>"
            } else {
                "<unset>"
            }
        }

        f.debug_struct("FrameCodec")
            .field(
                "aes256_key",
                &format_args!("{}", presence(&self.aes256_key)),
            )
            .field(
                "chacha20_key",
                &format_args!("{}", presence(&self.chacha20_key)),
            )
            .finish()
    }
}
223
224impl FrameCodec {
225 pub fn plaintext() -> Self {
227 Self::default()
228 }
229
230 pub fn with_chacha20_key(mut self, key: [u8; 32]) -> Self {
232 self.chacha20_key = Some(key);
233 self
234 }
235
236 pub fn with_aes256_key(mut self, key: [u8; 32]) -> Self {
238 self.aes256_key = Some(key);
239 self
240 }
241
242 pub fn encode(&self, packet: &PacketEnvelope) -> Result<Vec<u8>, ProtocolError> {
244 let payload = serde_json::to_vec(&packet.body)?;
245 let payload = match packet.encryption {
246 EncryptionKind::None => payload,
247 EncryptionKind::ChaCha20 => self.encrypt_chacha20(&payload)?,
248 EncryptionKind::Aes256 => self.encrypt_aes256(&payload)?,
249 };
250
251 if payload.len() > MAX_PAYLOAD_BYTES {
252 return Err(ProtocolError::PayloadTooLarge {
253 actual: payload.len(),
254 max: MAX_PAYLOAD_BYTES,
255 });
256 }
257
258 let frame_len = 2 + payload.len();
259 let mut frame = Vec::with_capacity(4 + frame_len);
260 frame.extend_from_slice(&(frame_len as u32).to_be_bytes());
261 frame.push(packet.message_type as u8);
262 frame.push(packet.encryption as u8);
263 frame.extend_from_slice(&payload);
264 Ok(frame)
265 }
266
267 pub fn decode(&self, frame: &[u8]) -> Result<PacketEnvelope, ProtocolError> {
269 if frame.len() < 6 {
270 return Err(ProtocolError::FrameTooShort);
271 }
272
273 let declared = u32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]) as usize;
274 let actual = frame.len() - 4;
275 if declared != actual {
276 return Err(ProtocolError::FrameLengthMismatch { declared, actual });
277 }
278
279 let payload_len = actual - 2;
280 if payload_len > MAX_PAYLOAD_BYTES {
281 return Err(ProtocolError::PayloadTooLarge {
282 actual: payload_len,
283 max: MAX_PAYLOAD_BYTES,
284 });
285 }
286
287 let message_type = MessageType::try_from(frame[4])?;
288 let encryption = EncryptionKind::try_from(frame[5])?;
289 let payload = match encryption {
290 EncryptionKind::None => frame[6..].to_vec(),
291 EncryptionKind::ChaCha20 => self.decrypt_chacha20(&frame[6..])?,
292 EncryptionKind::Aes256 => self.decrypt_aes256(&frame[6..])?,
293 };
294
295 let body: PacketBody = serde_json::from_slice(&payload)?;
296 if body.message_type() != message_type {
297 return Err(ProtocolError::MessageTypeMismatch);
298 }
299
300 Ok(PacketEnvelope {
301 message_type,
302 encryption,
303 body,
304 })
305 }
306
307 fn encrypt_chacha20(&self, payload: &[u8]) -> Result<Vec<u8>, ProtocolError> {
308 let key = self
309 .chacha20_key
310 .ok_or(ProtocolError::MissingEncryptionKey("chacha20"))?;
311 let cipher = ChaCha20Poly1305::new_from_slice(&key)
312 .map_err(|_| ProtocolError::InvalidEncryptionKey("chacha20"))?;
313 let mut nonce_bytes = [0_u8; CHACHA20_NONCE_LEN];
314 getrandom(&mut nonce_bytes).map_err(|source| ProtocolError::Random(source.to_string()))?;
315 let ciphertext = cipher
316 .encrypt(Nonce::from_slice(&nonce_bytes), payload)
317 .map_err(|_| ProtocolError::EncryptionFailed("chacha20"))?;
318
319 let mut encoded = Vec::with_capacity(CHACHA20_NONCE_LEN + ciphertext.len());
320 encoded.extend_from_slice(&nonce_bytes);
321 encoded.extend_from_slice(&ciphertext);
322 Ok(encoded)
323 }
324
325 fn decrypt_chacha20(&self, payload: &[u8]) -> Result<Vec<u8>, ProtocolError> {
326 if payload.len() < CHACHA20_NONCE_LEN {
327 return Err(ProtocolError::EncryptedPayloadTooShort {
328 algorithm: "chacha20",
329 expected_min: CHACHA20_NONCE_LEN,
330 actual: payload.len(),
331 });
332 }
333
334 let key = self
335 .chacha20_key
336 .ok_or(ProtocolError::MissingEncryptionKey("chacha20"))?;
337 let cipher = ChaCha20Poly1305::new_from_slice(&key)
338 .map_err(|_| ProtocolError::InvalidEncryptionKey("chacha20"))?;
339 let (nonce_bytes, ciphertext) = payload.split_at(CHACHA20_NONCE_LEN);
340 cipher
341 .decrypt(Nonce::from_slice(nonce_bytes), ciphertext)
342 .map_err(|_| ProtocolError::DecryptionFailed("chacha20"))
343 }
344
345 fn encrypt_aes256(&self, payload: &[u8]) -> Result<Vec<u8>, ProtocolError> {
346 let key = self
347 .aes256_key
348 .ok_or(ProtocolError::MissingEncryptionKey("aes256"))?;
349 let cipher = Aes256Gcm::new_from_slice(&key)
350 .map_err(|_| ProtocolError::InvalidEncryptionKey("aes256"))?;
351 let mut nonce_bytes = [0_u8; AES256_NONCE_LEN];
352 getrandom(&mut nonce_bytes).map_err(|source| ProtocolError::Random(source.to_string()))?;
353 let ciphertext = cipher
354 .encrypt(AesNonce::from_slice(&nonce_bytes), payload)
355 .map_err(|_| ProtocolError::EncryptionFailed("aes256"))?;
356
357 let mut encoded = Vec::with_capacity(AES256_NONCE_LEN + ciphertext.len());
358 encoded.extend_from_slice(&nonce_bytes);
359 encoded.extend_from_slice(&ciphertext);
360 Ok(encoded)
361 }
362
363 fn decrypt_aes256(&self, payload: &[u8]) -> Result<Vec<u8>, ProtocolError> {
364 if payload.len() < AES256_NONCE_LEN {
365 return Err(ProtocolError::EncryptedPayloadTooShort {
366 algorithm: "aes256",
367 expected_min: AES256_NONCE_LEN,
368 actual: payload.len(),
369 });
370 }
371
372 let key = self
373 .aes256_key
374 .ok_or(ProtocolError::MissingEncryptionKey("aes256"))?;
375 let cipher = Aes256Gcm::new_from_slice(&key)
376 .map_err(|_| ProtocolError::InvalidEncryptionKey("aes256"))?;
377 let (nonce_bytes, ciphertext) = payload.split_at(AES256_NONCE_LEN);
378 cipher
379 .decrypt(AesNonce::from_slice(nonce_bytes), ciphertext)
380 .map_err(|_| ProtocolError::DecryptionFailed("aes256"))
381 }
382}
383
384pub fn encode_frame(packet: &PacketEnvelope) -> Result<Vec<u8>, ProtocolError> {
386 FrameCodec::plaintext().encode(packet)
387}
388
389pub fn decode_frame(frame: &[u8]) -> Result<PacketEnvelope, ProtocolError> {
391 FrameCodec::plaintext().decode(frame)
392}
393
394#[derive(Debug, Error)]
396pub enum ProtocolError {
397 #[error("frame too short")]
398 FrameTooShort,
399 #[error("frame length mismatch: declared={declared}, actual={actual}")]
400 FrameLengthMismatch { declared: usize, actual: usize },
401 #[error("payload too large: actual={actual}, max={max}")]
402 PayloadTooLarge { actual: usize, max: usize },
403 #[error("unknown message type: {0:#x}")]
404 UnknownMessageType(u8),
405 #[error("unknown encryption kind: {0:#x}")]
406 UnknownEncryption(u8),
407 #[error("unsupported encryption kind: {0:#x}")]
408 UnsupportedEncryption(u8),
409 #[error("missing encryption key for {0}")]
410 MissingEncryptionKey(&'static str),
411 #[error("invalid encryption key for {0}")]
412 InvalidEncryptionKey(&'static str),
413 #[error("secure random generation failed: {0}")]
414 Random(String),
415 #[error(
416 "encrypted payload too short for {algorithm}: expected at least {expected_min}, actual={actual}"
417 )]
418 EncryptedPayloadTooShort {
419 algorithm: &'static str,
420 expected_min: usize,
421 actual: usize,
422 },
423 #[error("encryption failed for {0}")]
424 EncryptionFailed(&'static str),
425 #[error("decryption failed for {0}")]
426 DecryptionFailed(&'static str),
427 #[error("message type does not match packet body")]
428 MessageTypeMismatch,
429 #[error("invalid attachment encoding: {0}")]
430 InvalidAttachmentEncoding(String),
431 #[error("json error: {0}")]
432 Json(#[from] serde_json::Error),
433}
434
#[cfg(test)]
mod tests {
    use super::{
        EncryptionKind, FrameCodec, MAX_PAYLOAD_BYTES, MessageType, PacketBody, PacketEnvelope,
        ProtocolError, decode_frame, encode_frame,
    };
    use serde_json::json;

    const TEST_KEY: [u8; 32] = [0x11; 32];
    const OTHER_KEY: [u8; 32] = [0x22; 32];

    /// Small, valid API response shared by the encryption tests.
    fn sample_response() -> PacketBody {
        PacketBody::ApiResponse {
            request_id: "req-1".to_string(),
            ok: true,
            status: 200,
            data: json!({ "message": "encrypted" }),
            error: None,
            metadata: json!({}),
        }
    }

    /// Asserts that `body` is the response produced by `sample_response`.
    fn assert_sample_response(body: PacketBody) {
        match body {
            PacketBody::ApiResponse {
                request_id,
                ok,
                status,
                data,
                ..
            } => {
                assert_eq!(request_id, "req-1");
                assert!(ok);
                assert_eq!(status, 200);
                assert_eq!(data, json!({ "message": "encrypted" }));
            }
            other => panic!("unexpected body after roundtrip: {other:?}"),
        }
    }

    #[test]
    fn plaintext_helpers_still_work() {
        let packet = PacketEnvelope::new(PacketBody::EventAck {
            event_id: "evt-1".to_string(),
            ok: true,
            receipt: json!({ "ok": true }),
            error: None,
        });

        let encoded = encode_frame(&packet).expect("encode plaintext");
        let decoded = decode_frame(&encoded).expect("decode plaintext");
        assert!(matches!(decoded.encryption, EncryptionKind::None));
        assert_eq!(decoded.message_type, MessageType::Event);
        match decoded.body {
            PacketBody::EventAck { event_id, ok, .. } => {
                assert_eq!(event_id, "evt-1");
                assert!(ok);
            }
            other => panic!("unexpected body after roundtrip: {other:?}"),
        }
    }

    #[test]
    fn aes256_roundtrip_works() {
        let codec = FrameCodec::plaintext().with_aes256_key(TEST_KEY);
        let packet = PacketEnvelope::with_encryption(sample_response(), EncryptionKind::Aes256);

        let encoded = codec.encode(&packet).expect("encode aes256");
        let decoded = codec.decode(&encoded).expect("decode aes256");
        assert!(matches!(decoded.encryption, EncryptionKind::Aes256));
        assert_sample_response(decoded.body);
    }

    #[test]
    fn chacha20_roundtrip_works() {
        let codec = FrameCodec::plaintext().with_chacha20_key(TEST_KEY);
        let packet = PacketEnvelope::with_encryption(sample_response(), EncryptionKind::ChaCha20);

        let encoded = codec.encode(&packet).expect("encode chacha20");
        let decoded = codec.decode(&encoded).expect("decode chacha20");
        assert!(matches!(decoded.encryption, EncryptionKind::ChaCha20));
        assert_sample_response(decoded.body);
    }

    #[test]
    fn encrypting_without_key_is_rejected() {
        let packet = PacketEnvelope::with_encryption(sample_response(), EncryptionKind::Aes256);

        let error = FrameCodec::plaintext()
            .encode(&packet)
            .expect_err("encoding without a key should fail");
        assert!(matches!(
            error,
            ProtocolError::MissingEncryptionKey("aes256")
        ));
    }

    #[test]
    fn decrypting_with_wrong_key_fails() {
        let sender = FrameCodec::plaintext().with_chacha20_key(TEST_KEY);
        let receiver = FrameCodec::plaintext().with_chacha20_key(OTHER_KEY);
        let packet = PacketEnvelope::with_encryption(sample_response(), EncryptionKind::ChaCha20);

        let encoded = sender.encode(&packet).expect("encode chacha20");
        let error = receiver
            .decode(&encoded)
            .expect_err("decoding with the wrong key should fail");
        assert!(matches!(error, ProtocolError::DecryptionFailed("chacha20")));
    }

    #[test]
    fn encode_rejects_payloads_over_limit() {
        let codec = FrameCodec::plaintext();
        let packet = PacketEnvelope::new(PacketBody::ApiResponse {
            request_id: "req-oversize".to_string(),
            ok: true,
            status: 200,
            data: json!({ "blob": "a".repeat(10 * 1024 * 1024) }),
            error: None,
            metadata: json!({}),
        });

        let error = codec
            .encode(&packet)
            .expect_err("oversized payload should fail");
        assert!(matches!(error, ProtocolError::PayloadTooLarge { .. }));
    }

    #[test]
    fn decode_rejects_payloads_over_limit() {
        // Hand-build a frame whose length prefix is consistent but whose
        // payload is one byte over the limit.
        let payload = vec![0_u8; MAX_PAYLOAD_BYTES + 1];
        let frame_len = 2 + payload.len();
        let mut frame = Vec::with_capacity(4 + frame_len);
        frame.extend_from_slice(&(frame_len as u32).to_be_bytes());
        frame.push(MessageType::Api as u8);
        frame.push(EncryptionKind::None as u8);
        frame.extend_from_slice(&payload);

        let error = FrameCodec::plaintext()
            .decode(&frame)
            .expect_err("oversized payload should fail");
        assert!(matches!(error, ProtocolError::PayloadTooLarge { .. }));
    }
}