1use crate::AppStateError;
8use crate::decode::{Mutation, decode_record};
9use crate::hash::{HashState, generate_patch_mac};
10use crate::keys::ExpandedAppStateKeys;
11use log::{debug, trace};
12use serde::{Deserialize, Serialize};
13use waproto::whatsapp as wa;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct AppStateMutationMAC {
17 pub index_mac: Vec<u8>,
18 pub value_mac: Vec<u8>,
19}
20
21#[derive(Debug, Clone)]
23pub struct ProcessedSnapshot {
24 pub state: HashState,
26 pub mutations: Vec<Mutation>,
28 pub mutation_macs: Vec<AppStateMutationMAC>,
30}
31
32#[derive(Debug, Clone)]
34pub struct PatchProcessingResult {
35 pub state: HashState,
37 pub mutations: Vec<Mutation>,
39 pub added_macs: Vec<AppStateMutationMAC>,
41 pub removed_index_macs: Vec<Vec<u8>>,
43}
44
45pub fn process_snapshot<F>(
60 snapshot: &wa::SyncdSnapshot,
61 initial_state: &mut HashState,
62 mut get_keys: F,
63 validate_macs: bool,
64 collection_name: &str,
65) -> Result<ProcessedSnapshot, AppStateError>
66where
67 F: FnMut(&[u8]) -> Result<ExpandedAppStateKeys, AppStateError>,
68{
69 let version = snapshot
70 .version
71 .as_ref()
72 .and_then(|v| v.version)
73 .unwrap_or(0);
74 initial_state.version = version;
75
76 initial_state.update_hash_from_records(&snapshot.records);
78
79 debug!(
80 target: "AppState",
81 "Snapshot {} v{}: {} records, ltHash ends with ...{}",
82 collection_name,
83 version,
84 snapshot.records.len(),
85 hex::encode(&initial_state.hash[120..])
86 );
87
88 if validate_macs
90 && let (Some(mac_expected), Some(key_id)) = (
91 snapshot.mac.as_ref(),
92 snapshot.key_id.as_ref().and_then(|k| k.id.as_ref()),
93 )
94 {
95 let keys = get_keys(key_id)?;
96 let computed = initial_state.generate_snapshot_mac(collection_name, &keys.snapshot_mac);
97 trace!(
98 target: "AppState",
99 "Snapshot {} v{} MAC validation: computed={}, expected={}",
100 collection_name,
101 version,
102 hex::encode(&computed),
103 hex::encode(mac_expected)
104 );
105 if computed != *mac_expected {
106 return Err(AppStateError::SnapshotMACMismatch);
107 }
108 }
109
110 let mut mutations = Vec::with_capacity(snapshot.records.len());
112 let mut mutation_macs = Vec::with_capacity(snapshot.records.len());
113
114 for rec in &snapshot.records {
115 let key_id = rec
116 .key_id
117 .as_ref()
118 .and_then(|k| k.id.as_ref())
119 .ok_or(AppStateError::MissingKeyId)?;
120 let keys = get_keys(key_id)?;
121
122 let mutation = decode_record(
123 wa::syncd_mutation::SyncdOperation::Set,
124 rec,
125 &keys,
126 key_id,
127 validate_macs,
128 )?;
129
130 mutation_macs.push(AppStateMutationMAC {
131 index_mac: mutation.index_mac.clone(),
132 value_mac: mutation.value_mac.clone(),
133 });
134
135 mutations.push(mutation);
136 }
137
138 Ok(ProcessedSnapshot {
139 state: initial_state.clone(),
140 mutations,
141 mutation_macs,
142 })
143}
144
145pub fn process_patch<F, G>(
161 patch: &wa::SyncdPatch,
162 state: &mut HashState,
163 mut get_keys: F,
164 mut get_prev_value_mac: G,
165 validate_macs: bool,
166 collection_name: &str,
167) -> Result<PatchProcessingResult, AppStateError>
168where
169 F: FnMut(&[u8]) -> Result<ExpandedAppStateKeys, AppStateError>,
170 G: FnMut(&[u8]) -> Result<Option<Vec<u8>>, AppStateError>,
171{
172 let original_version = state.version;
177 let original_hash_is_empty = state.hash == [0u8; 128];
178 let had_no_prior_state = original_version == 0 && original_hash_is_empty;
179
180 state.version = patch.version.as_ref().and_then(|v| v.version).unwrap_or(0);
181
182 let (hash_update_result, result) = state.update_hash(&patch.mutations, |index_mac, idx| {
184 for prev in patch.mutations[..idx].iter().rev() {
186 if let Some(rec) = &prev.record
187 && let Some(ind) = &rec.index
188 && let Some(b) = &ind.blob
189 && b == index_mac
190 && let Some(val) = &rec.value
191 && let Some(vb) = &val.blob
192 && vb.len() >= 32
193 {
194 return Ok(Some(vb[vb.len() - 32..].to_vec()));
195 }
196 }
197 get_prev_value_mac(index_mac).map_err(|e| anyhow::anyhow!(e))
199 });
200 result.map_err(|_| AppStateError::MismatchingLTHash)?;
201
202 debug!(
203 target: "AppState",
204 "Patch {} v{}: {} mutations, ltHash ends with ...{}, hasMissingRemove={}",
205 collection_name,
206 state.version,
207 patch.mutations.len(),
208 hex::encode(&state.hash[120..]),
209 hash_update_result.has_missing_remove
210 );
211
212 if validate_macs && let Some(key_id) = patch.key_id.as_ref().and_then(|k| k.id.as_ref()) {
214 let keys = get_keys(key_id)?;
215 validate_patch_macs(
216 patch,
217 state,
218 &keys,
219 collection_name,
220 had_no_prior_state,
221 hash_update_result.has_missing_remove,
222 )?;
223 }
224
225 let mut mutations = Vec::with_capacity(patch.mutations.len());
227 let mut added_macs = Vec::new();
228 let mut removed_index_macs = Vec::new();
229
230 for m in &patch.mutations {
231 if let Some(rec) = &m.record {
232 let op = wa::syncd_mutation::SyncdOperation::try_from(m.operation.unwrap_or(0))
233 .unwrap_or(wa::syncd_mutation::SyncdOperation::Set);
234
235 let key_id = rec
236 .key_id
237 .as_ref()
238 .and_then(|k| k.id.as_ref())
239 .ok_or(AppStateError::MissingKeyId)?;
240 let keys = get_keys(key_id)?;
241
242 let mutation = decode_record(op, rec, &keys, key_id, validate_macs)?;
243
244 match op {
245 wa::syncd_mutation::SyncdOperation::Set => {
246 added_macs.push(AppStateMutationMAC {
247 index_mac: mutation.index_mac.clone(),
248 value_mac: mutation.value_mac.clone(),
249 });
250 }
251 wa::syncd_mutation::SyncdOperation::Remove => {
252 removed_index_macs.push(mutation.index_mac.clone());
253 }
254 }
255
256 mutations.push(mutation);
257 }
258 }
259
260 Ok(PatchProcessingResult {
261 state: state.clone(),
262 mutations,
263 added_macs,
264 removed_index_macs,
265 })
266}
267
268pub fn validate_patch_macs(
286 patch: &wa::SyncdPatch,
287 state: &HashState,
288 keys: &ExpandedAppStateKeys,
289 collection_name: &str,
290 had_no_prior_state: bool,
291 has_missing_remove: bool,
292) -> Result<(), AppStateError> {
293 if had_no_prior_state {
300 return Ok(());
301 }
302
303 if let Some(snap_mac) = patch.snapshot_mac.as_ref() {
304 let computed_snap = state.generate_snapshot_mac(collection_name, &keys.snapshot_mac);
305 trace!(
306 target: "AppState",
307 "Patch {} v{} snapshotMAC: computed={}, expected={}",
308 collection_name,
309 state.version,
310 hex::encode(&computed_snap),
311 hex::encode(snap_mac)
312 );
313 if computed_snap != *snap_mac {
314 if has_missing_remove {
317 log::warn!(
318 target: "AppState",
319 "Patch {} v{} snapshotMAC mismatch (expected due to hasMissingRemove=true), continuing",
320 collection_name,
321 state.version
322 );
323 } else {
325 debug!(
326 target: "AppState",
327 "Patch {} v{} snapshotMAC MISMATCH! ltHash=...{}",
328 collection_name,
329 state.version,
330 hex::encode(&state.hash[120..])
331 );
332 return Err(AppStateError::PatchSnapshotMACMismatch);
333 }
334 }
335 }
336
337 if let Some(patch_mac) = patch.patch_mac.as_ref() {
338 let version = patch.version.as_ref().and_then(|v| v.version).unwrap_or(0);
339 let computed_patch = generate_patch_mac(patch, collection_name, &keys.patch_mac, version);
340 if computed_patch != *patch_mac {
341 if has_missing_remove {
343 log::warn!(
344 target: "AppState",
345 "Patch {} v{} patchMAC mismatch (expected due to hasMissingRemove=true), continuing",
346 collection_name,
347 state.version
348 );
349 } else {
350 return Err(AppStateError::PatchMACMismatch);
351 }
352 }
353 }
354
355 Ok(())
356}
357
358pub fn validate_snapshot_mac(
362 snapshot: &wa::SyncdSnapshot,
363 state: &HashState,
364 keys: &ExpandedAppStateKeys,
365 collection_name: &str,
366) -> Result<(), AppStateError> {
367 if let Some(mac_expected) = snapshot.mac.as_ref() {
368 let computed = state.generate_snapshot_mac(collection_name, &keys.snapshot_mac);
369 if computed != *mac_expected {
370 return Err(AppStateError::SnapshotMACMismatch);
371 }
372 }
373 Ok(())
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hash::generate_content_mac;
    use crate::keys::expand_app_state_keys;
    use crate::lthash::WAPATCH_INTEGRITY;
    use prost::Message;
    use wacore_libsignal::crypto::aes_256_cbc_encrypt_into;
    use waproto::whatsapp::syncd_mutation::SyncdOperation;

    /// Master key used to derive the app-state keys for every fixture.
    const MASTER_KEY: [u8; 32] = [7u8; 32];
    /// Key id attached to every fixture record, snapshot and patch.
    const KEY_ID: &[u8] = b"test_key_id";

    /// `KeyId` message wrapping [`KEY_ID`].
    fn key_id_msg() -> Option<wa::KeyId> {
        Some(wa::KeyId {
            id: Some(KEY_ID.to_vec()),
        })
    }

    /// Returns the last 32 bytes of a record's value blob, i.e. its value MAC.
    fn trailing_value_mac(record: &wa::SyncdRecord) -> Vec<u8> {
        let blob = record
            .value
            .as_ref()
            .and_then(|v| v.blob.as_ref())
            .expect("test data should be valid");
        blob[blob.len() - 32..].to_vec()
    }

    /// Builds a record whose value encrypts a `SyncActionData` carrying `timestamp`.
    ///
    /// Value blob layout: IV || ciphertext || MAC, where the MAC is computed by
    /// `generate_content_mac` over `op`, IV || ciphertext and `key_id`.
    fn create_encrypted_record(
        op: SyncdOperation,
        index_mac: &[u8],
        keys: &ExpandedAppStateKeys,
        key_id: &[u8],
        timestamp: i64,
    ) -> wa::SyncdRecord {
        let plaintext = wa::SyncActionData {
            value: Some(wa::SyncActionValue {
                timestamp: Some(timestamp),
                ..Default::default()
            }),
            ..Default::default()
        }
        .encode_to_vec();

        // A fixed all-zero IV keeps the fixture deterministic.
        let mut blob = vec![0u8; 16];
        let mut ciphertext = Vec::new();
        aes_256_cbc_encrypt_into(&plaintext, &keys.value_encryption, &blob, &mut ciphertext)
            .expect("test data should be valid");
        blob.extend_from_slice(&ciphertext);

        let value_mac = generate_content_mac(op, &blob, key_id, &keys.value_mac);
        blob.extend_from_slice(&value_mac);

        wa::SyncdRecord {
            index: Some(wa::SyncdIndex {
                blob: Some(index_mac.to_vec()),
            }),
            value: Some(wa::SyncdValue { blob: Some(blob) }),
            key_id: Some(wa::KeyId {
                id: Some(key_id.to_vec()),
            }),
        }
    }

    #[test]
    fn test_process_snapshot_basic() {
        let keys = expand_app_state_keys(&MASTER_KEY);
        let index_mac = vec![1u8; 32];

        let snapshot = wa::SyncdSnapshot {
            version: Some(wa::SyncdVersion { version: Some(1) }),
            records: vec![create_encrypted_record(
                SyncdOperation::Set,
                &index_mac,
                &keys,
                KEY_ID,
                1234567890,
            )],
            key_id: key_id_msg(),
            ..Default::default()
        };

        let mut state = HashState::default();
        let result = process_snapshot(
            &snapshot,
            &mut state,
            |_: &[u8]| Ok(keys.clone()),
            false,
            "regular",
        )
        .expect("test data should be valid");

        assert_eq!(result.state.version, 1);
        assert_eq!(result.mutations.len(), 1);
        assert_eq!(result.mutation_macs.len(), 1);
        let timestamp = result.mutations[0]
            .action_value
            .as_ref()
            .and_then(|v| v.timestamp);
        assert_eq!(timestamp, Some(1234567890));
    }

    #[test]
    fn test_process_patch_basic() {
        let keys = expand_app_state_keys(&MASTER_KEY);
        let index_mac = vec![1u8; 32];

        let patch = wa::SyncdPatch {
            version: Some(wa::SyncdVersion { version: Some(2) }),
            mutations: vec![wa::SyncdMutation {
                operation: Some(SyncdOperation::Set as i32),
                record: Some(create_encrypted_record(
                    SyncdOperation::Set,
                    &index_mac,
                    &keys,
                    KEY_ID,
                    1234567890,
                )),
            }],
            key_id: key_id_msg(),
            ..Default::default()
        };

        let mut state = HashState::default();
        let result = process_patch(
            &patch,
            &mut state,
            |_: &[u8]| Ok(keys.clone()),
            |_: &[u8]| Ok(None),
            false,
            "regular",
        )
        .expect("test data should be valid");

        assert_eq!(result.state.version, 2);
        assert_eq!(result.mutations.len(), 1);
        assert_eq!(result.added_macs.len(), 1);
        assert!(result.removed_index_macs.is_empty());
    }

    #[test]
    fn test_process_patch_with_overwrite() {
        let keys = expand_app_state_keys(&MASTER_KEY);
        let index_mac = vec![1u8; 32];

        // Base state: a snapshot that sets `index_mac` with timestamp 1000.
        let initial_record =
            create_encrypted_record(SyncdOperation::Set, &index_mac, &keys, KEY_ID, 1000);
        let initial_value_mac = trailing_value_mac(&initial_record);
        let snapshot = wa::SyncdSnapshot {
            version: Some(wa::SyncdVersion { version: Some(1) }),
            records: vec![initial_record],
            key_id: key_id_msg(),
            ..Default::default()
        };

        let mut snapshot_state = HashState::default();
        let snapshot_result = process_snapshot(
            &snapshot,
            &mut snapshot_state,
            |_: &[u8]| Ok(keys.clone()),
            false,
            "regular",
        )
        .expect("test data should be valid");

        // Patch: overwrite the same index with timestamp 2000.
        let overwrite_record =
            create_encrypted_record(SyncdOperation::Set, &index_mac, &keys, KEY_ID, 2000);
        let new_value_mac = trailing_value_mac(&overwrite_record);
        let patch = wa::SyncdPatch {
            version: Some(wa::SyncdVersion { version: Some(2) }),
            mutations: vec![wa::SyncdMutation {
                operation: Some(SyncdOperation::Set as i32),
                record: Some(overwrite_record),
            }],
            key_id: key_id_msg(),
            ..Default::default()
        };

        // The stored previous value MAC is only known for `index_mac`.
        let mut patch_state = snapshot_result.state.clone();
        let result = process_patch(
            &patch,
            &mut patch_state,
            |_: &[u8]| Ok(keys.clone()),
            |idx: &[u8]| Ok((idx == index_mac.as_slice()).then(|| initial_value_mac.clone())),
            false,
            "regular",
        )
        .expect("test data should be valid");

        assert_eq!(result.state.version, 2);
        assert_eq!(result.mutations.len(), 1);
        let timestamp = result.mutations[0]
            .action_value
            .as_ref()
            .and_then(|v| v.timestamp);
        assert_eq!(timestamp, Some(2000));

        // Overwriting must subtract the old value MAC and add the new one.
        let expected_hash = WAPATCH_INTEGRITY.subtract_then_add(
            &snapshot_result.state.hash,
            &[initial_value_mac],
            &[new_value_mac],
        );
        assert_eq!(result.state.hash.as_slice(), expected_hash.as_slice());
    }
}