1use crate::AppStateError;
8use crate::decode::{Mutation, decode_record};
9use crate::hash::{HashState, generate_patch_mac};
10use crate::keys::ExpandedAppStateKeys;
11use log::{debug, trace};
12use serde::{Deserialize, Serialize};
13use waproto::whatsapp as wa;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct AppStateMutationMAC {
17 pub index_mac: Vec<u8>,
18 pub value_mac: Vec<u8>,
19}
20
21#[derive(Debug, Clone)]
23pub struct ProcessedSnapshot {
24 pub state: HashState,
26 pub mutations: Vec<Mutation>,
28 pub mutation_macs: Vec<AppStateMutationMAC>,
30}
31
32#[derive(Debug, Clone)]
34pub struct PatchProcessingResult {
35 pub state: HashState,
37 pub mutations: Vec<Mutation>,
39 pub added_macs: Vec<AppStateMutationMAC>,
41 pub removed_index_macs: Vec<Vec<u8>>,
43}
44
45pub fn process_snapshot<F>(
60 snapshot: &wa::SyncdSnapshot,
61 initial_state: &mut HashState,
62 mut get_keys: F,
63 validate_macs: bool,
64 collection_name: &str,
65) -> Result<ProcessedSnapshot, AppStateError>
66where
67 F: FnMut(&[u8]) -> Result<ExpandedAppStateKeys, AppStateError>,
68{
69 let version = snapshot
70 .version
71 .as_ref()
72 .and_then(|v| v.version)
73 .unwrap_or(0);
74 initial_state.version = version;
75
76 initial_state.update_hash_from_records(&snapshot.records);
78
79 debug!(
80 target: "AppState",
81 "Snapshot {} v{}: {} records, ltHash ends with ...{}",
82 collection_name,
83 version,
84 snapshot.records.len(),
85 hex::encode(&initial_state.hash[120..])
86 );
87
88 if validate_macs
90 && let (Some(mac_expected), Some(key_id)) = (
91 snapshot.mac.as_ref(),
92 snapshot.key_id.as_ref().and_then(|k| k.id.as_ref()),
93 )
94 {
95 let keys = get_keys(key_id)?;
96 let computed = initial_state.generate_snapshot_mac(collection_name, &keys.snapshot_mac);
97 trace!(
98 target: "AppState",
99 "Snapshot {} v{} MAC validation: computed={}, expected={}",
100 collection_name,
101 version,
102 hex::encode(&computed),
103 hex::encode(mac_expected)
104 );
105 if computed != *mac_expected {
106 return Err(AppStateError::SnapshotMACMismatch);
107 }
108 }
109
110 let mut mutations = Vec::with_capacity(snapshot.records.len());
112 let mut mutation_macs = Vec::with_capacity(snapshot.records.len());
113
114 for rec in &snapshot.records {
115 let key_id = rec
116 .key_id
117 .as_ref()
118 .and_then(|k| k.id.as_ref())
119 .ok_or(AppStateError::MissingKeyId)?;
120 let keys = get_keys(key_id)?;
121
122 let mutation = decode_record(
123 wa::syncd_mutation::SyncdOperation::Set,
124 rec,
125 &keys,
126 key_id,
127 validate_macs,
128 )?;
129
130 mutation_macs.push(AppStateMutationMAC {
131 index_mac: mutation.index_mac.clone(),
132 value_mac: mutation.value_mac.clone(),
133 });
134
135 mutations.push(mutation);
136 }
137
138 Ok(ProcessedSnapshot {
139 state: initial_state.clone(),
140 mutations,
141 mutation_macs,
142 })
143}
144
145pub fn process_patch<F, G>(
161 patch: &wa::SyncdPatch,
162 state: &mut HashState,
163 mut get_keys: F,
164 mut get_prev_value_mac: G,
165 validate_macs: bool,
166 collection_name: &str,
167) -> Result<PatchProcessingResult, AppStateError>
168where
169 F: FnMut(&[u8]) -> Result<ExpandedAppStateKeys, AppStateError>,
170 G: FnMut(&[u8]) -> Result<Option<Vec<u8>>, AppStateError>,
171{
172 let original_version = state.version;
177 let original_hash_is_empty = state.hash == [0u8; 128];
178 let had_no_prior_state = original_version == 0 && original_hash_is_empty;
179
180 let patch_version = patch.version.as_ref().and_then(|v| v.version).unwrap_or(0);
181
182 let expected_version = original_version.saturating_add(1);
188 if !had_no_prior_state && patch_version != expected_version {
189 return Err(AppStateError::PatchVersionMismatch {
190 expected: expected_version,
191 got: patch_version,
192 });
193 }
194
195 state.version = patch_version;
196
197 let (hash_update_result, result) = state.update_hash(&patch.mutations, |index_mac, idx| {
199 for prev in patch.mutations[..idx].iter().rev() {
201 if let Some(rec) = &prev.record
202 && let Some(ind) = &rec.index
203 && let Some(b) = &ind.blob
204 && b == index_mac
205 && let Some(val) = &rec.value
206 && let Some(vb) = &val.blob
207 && vb.len() >= 32
208 {
209 return Ok(Some(vb[vb.len() - 32..].to_vec()));
210 }
211 }
212 get_prev_value_mac(index_mac).map_err(|e| anyhow::anyhow!(e))
214 });
215 result.map_err(|_| AppStateError::MismatchingLTHash)?;
216
217 debug!(
218 target: "AppState",
219 "Patch {} v{}: {} mutations, ltHash ends with ...{}, hasMissingRemove={}",
220 collection_name,
221 state.version,
222 patch.mutations.len(),
223 hex::encode(&state.hash[120..]),
224 hash_update_result.has_missing_remove
225 );
226
227 if validate_macs && let Some(key_id) = patch.key_id.as_ref().and_then(|k| k.id.as_ref()) {
229 let keys = get_keys(key_id)?;
230 validate_patch_macs(
231 patch,
232 state,
233 &keys,
234 collection_name,
235 had_no_prior_state,
236 hash_update_result.has_missing_remove,
237 )?;
238 }
239
240 let mut mutations = Vec::with_capacity(patch.mutations.len());
242 let mut added_macs = Vec::with_capacity(patch.mutations.len());
243 let mut removed_index_macs = Vec::with_capacity(patch.mutations.len());
244
245 for m in &patch.mutations {
246 if let Some(rec) = &m.record {
247 let op = wa::syncd_mutation::SyncdOperation::try_from(m.operation.unwrap_or(0))
248 .unwrap_or(wa::syncd_mutation::SyncdOperation::Set);
249
250 let key_id = rec
251 .key_id
252 .as_ref()
253 .and_then(|k| k.id.as_ref())
254 .ok_or(AppStateError::MissingKeyId)?;
255 let keys = get_keys(key_id)?;
256
257 let mutation = decode_record(op, rec, &keys, key_id, validate_macs)?;
258
259 match op {
260 wa::syncd_mutation::SyncdOperation::Set => {
261 added_macs.push(AppStateMutationMAC {
262 index_mac: mutation.index_mac.clone(),
263 value_mac: mutation.value_mac.clone(),
264 });
265 }
266 wa::syncd_mutation::SyncdOperation::Remove => {
267 removed_index_macs.push(mutation.index_mac.clone());
268 }
269 }
270
271 mutations.push(mutation);
272 }
273 }
274
275 Ok(PatchProcessingResult {
276 state: state.clone(),
277 mutations,
278 added_macs,
279 removed_index_macs,
280 })
281}
282
283pub fn validate_patch_macs(
301 patch: &wa::SyncdPatch,
302 state: &HashState,
303 keys: &ExpandedAppStateKeys,
304 collection_name: &str,
305 had_no_prior_state: bool,
306 has_missing_remove: bool,
307) -> Result<(), AppStateError> {
308 if had_no_prior_state {
315 return Ok(());
316 }
317
318 if let Some(snap_mac) = patch.snapshot_mac.as_ref() {
319 let computed_snap = state.generate_snapshot_mac(collection_name, &keys.snapshot_mac);
320 trace!(
321 target: "AppState",
322 "Patch {} v{} snapshotMAC: computed={}, expected={}",
323 collection_name,
324 state.version,
325 hex::encode(&computed_snap),
326 hex::encode(snap_mac)
327 );
328 if computed_snap != *snap_mac {
329 if has_missing_remove {
332 log::warn!(
333 target: "AppState",
334 "Patch {} v{} snapshotMAC mismatch (expected due to hasMissingRemove=true), continuing",
335 collection_name,
336 state.version
337 );
338 } else {
340 debug!(
341 target: "AppState",
342 "Patch {} v{} snapshotMAC MISMATCH! ltHash=...{}",
343 collection_name,
344 state.version,
345 hex::encode(&state.hash[120..])
346 );
347 return Err(AppStateError::PatchSnapshotMACMismatch);
348 }
349 }
350 }
351
352 if let Some(patch_mac) = patch.patch_mac.as_ref() {
353 let version = patch.version.as_ref().and_then(|v| v.version).unwrap_or(0);
354 let computed_patch = generate_patch_mac(patch, collection_name, &keys.patch_mac, version);
355 if computed_patch != *patch_mac {
356 if has_missing_remove {
358 log::warn!(
359 target: "AppState",
360 "Patch {} v{} patchMAC mismatch (expected due to hasMissingRemove=true), continuing",
361 collection_name,
362 state.version
363 );
364 } else {
365 return Err(AppStateError::PatchMACMismatch);
366 }
367 }
368 }
369
370 Ok(())
371}
372
373pub fn validate_snapshot_mac(
377 snapshot: &wa::SyncdSnapshot,
378 state: &HashState,
379 keys: &ExpandedAppStateKeys,
380 collection_name: &str,
381) -> Result<(), AppStateError> {
382 if let Some(mac_expected) = snapshot.mac.as_ref() {
383 let computed = state.generate_snapshot_mac(collection_name, &keys.snapshot_mac);
384 if computed != *mac_expected {
385 return Err(AppStateError::SnapshotMACMismatch);
386 }
387 }
388 Ok(())
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hash::generate_content_mac;
    use crate::keys::expand_app_state_keys;
    use crate::lthash::WAPATCH_INTEGRITY;
    use prost::Message;
    use wacore_libsignal::crypto::aes_256_cbc_encrypt_into;

    /// Key id attached to every record, snapshot and patch built by these tests.
    const KEY_ID: &[u8] = b"test_key_id";

    /// Expands the fixed master key shared by all tests.
    fn test_keys() -> ExpandedAppStateKeys {
        let master_key = [7u8; 32];
        expand_app_state_keys(&master_key)
    }

    /// Protobuf wrapper around [`KEY_ID`].
    fn test_key_id() -> Option<wa::KeyId> {
        Some(wa::KeyId {
            id: Some(KEY_ID.to_vec()),
        })
    }

    /// Builds a record whose value blob is `iv || ciphertext || value MAC`,
    /// where the plaintext is a `SyncActionData` carrying only `timestamp`.
    fn create_encrypted_record(
        op: wa::syncd_mutation::SyncdOperation,
        index_mac: &[u8],
        keys: &ExpandedAppStateKeys,
        key_id: &[u8],
        timestamp: i64,
    ) -> wa::SyncdRecord {
        let plaintext = wa::SyncActionData {
            value: Some(wa::SyncActionValue {
                timestamp: Some(timestamp),
                ..Default::default()
            }),
            ..Default::default()
        }
        .encode_to_vec();

        let iv = vec![0u8; 16];
        let mut ciphertext = Vec::new();
        aes_256_cbc_encrypt_into(&plaintext, &keys.value_encryption, &iv, &mut ciphertext)
            .expect("test data should be valid");

        // The MAC covers iv + ciphertext and is then appended to them.
        let mut blob = iv;
        blob.extend_from_slice(&ciphertext);
        let value_mac = generate_content_mac(op, &blob, key_id, &keys.value_mac);
        blob.extend_from_slice(&value_mac);

        wa::SyncdRecord {
            index: Some(wa::SyncdIndex {
                blob: Some(index_mac.to_vec()),
            }),
            value: Some(wa::SyncdValue { blob: Some(blob) }),
            key_id: Some(wa::KeyId {
                id: Some(key_id.to_vec()),
            }),
        }
    }

    /// Shorthand for an encrypted `Set` record keyed with [`KEY_ID`].
    fn set_record(
        keys: &ExpandedAppStateKeys,
        index_mac: &[u8],
        timestamp: i64,
    ) -> wa::SyncdRecord {
        create_encrypted_record(
            wa::syncd_mutation::SyncdOperation::Set,
            index_mac,
            keys,
            KEY_ID,
            timestamp,
        )
    }

    /// Snapshot at `version` holding exactly `record`.
    fn single_record_snapshot(version: u64, record: wa::SyncdRecord) -> wa::SyncdSnapshot {
        wa::SyncdSnapshot {
            version: Some(wa::SyncdVersion {
                version: Some(version),
            }),
            records: vec![record],
            key_id: test_key_id(),
            ..Default::default()
        }
    }

    /// Patch at `version` holding a single `Set` mutation for `record`.
    fn single_set_patch(version: u64, record: wa::SyncdRecord) -> wa::SyncdPatch {
        wa::SyncdPatch {
            version: Some(wa::SyncdVersion {
                version: Some(version),
            }),
            mutations: vec![wa::SyncdMutation {
                operation: Some(wa::syncd_mutation::SyncdOperation::Set as i32),
                record: Some(record),
            }],
            key_id: test_key_id(),
            ..Default::default()
        }
    }

    /// Previous-value lookup that never finds anything.
    fn no_prev_value(_: &[u8]) -> Result<Option<Vec<u8>>, AppStateError> {
        Ok(None)
    }

    /// Last 32 bytes of the record's value blob.
    fn trailing_value_mac(record: &wa::SyncdRecord) -> Vec<u8> {
        let blob = record
            .value
            .as_ref()
            .expect("test data should be valid")
            .blob
            .as_ref()
            .expect("test data should be valid");
        blob[blob.len() - 32..].to_vec()
    }

    /// Runs `process_patch` (MAC validation off, collection "regular") on a
    /// one-mutation `Set` patch at `patch_version` with no stored previous values.
    fn apply_single_set_patch(
        state: &mut HashState,
        index_mac: &[u8],
        patch_version: u64,
        timestamp: i64,
    ) -> Result<PatchProcessingResult, AppStateError> {
        let keys = test_keys();
        let patch = single_set_patch(patch_version, set_record(&keys, index_mac, timestamp));
        process_patch(
            &patch,
            state,
            |_: &[u8]| Ok(keys.clone()),
            no_prev_value,
            false,
            "regular",
        )
    }

    #[test]
    fn test_process_snapshot_basic() {
        let keys = test_keys();
        let snapshot = single_record_snapshot(1, set_record(&keys, &[1u8; 32], 1234567890));

        let mut state = HashState::default();
        let result = process_snapshot(
            &snapshot,
            &mut state,
            |_: &[u8]| Ok(keys.clone()),
            false,
            "regular",
        )
        .expect("test data should be valid");

        assert_eq!(result.state.version, 1);
        assert_eq!(result.mutations.len(), 1);
        assert_eq!(result.mutation_macs.len(), 1);
        let decoded_timestamp = result.mutations[0]
            .action_value
            .as_ref()
            .and_then(|v| v.timestamp);
        assert_eq!(decoded_timestamp, Some(1234567890));
    }

    #[test]
    fn test_process_patch_basic() {
        let mut state = HashState::default();
        let result = apply_single_set_patch(&mut state, &[1u8; 32], 2, 1234567890)
            .expect("test data should be valid");

        assert_eq!(result.state.version, 2);
        assert_eq!(result.mutations.len(), 1);
        assert_eq!(result.added_macs.len(), 1);
        assert!(result.removed_index_macs.is_empty());
    }

    #[test]
    fn test_process_patch_with_overwrite() {
        let keys = test_keys();
        let index_mac = vec![1u8; 32];

        // Version 1: a snapshot containing the initial value for the index.
        let initial_record = set_record(&keys, &index_mac, 1000);
        let initial_value_mac = trailing_value_mac(&initial_record);
        let snapshot = single_record_snapshot(1, initial_record);

        let mut snapshot_state = HashState::default();
        let snapshot_result = process_snapshot(
            &snapshot,
            &mut snapshot_state,
            |_: &[u8]| Ok(keys.clone()),
            false,
            "regular",
        )
        .expect("test data should be valid");

        // Version 2: a patch overwriting the same index with a new value.
        let overwrite_record = set_record(&keys, &index_mac, 2000);
        let new_value_mac = trailing_value_mac(&overwrite_record);
        let patch = single_set_patch(2, overwrite_record);

        // The store knows the previous value MAC for that one index only.
        let get_prev = |idx: &[u8]| {
            if idx == index_mac.as_slice() {
                Ok(Some(initial_value_mac.clone()))
            } else {
                Ok(None)
            }
        };

        let mut patch_state = snapshot_result.state.clone();
        let result = process_patch(
            &patch,
            &mut patch_state,
            |_: &[u8]| Ok(keys.clone()),
            get_prev,
            false,
            "regular",
        )
        .expect("test data should be valid");

        assert_eq!(result.state.version, 2);
        assert_eq!(result.mutations.len(), 1);
        let decoded_timestamp = result.mutations[0]
            .action_value
            .as_ref()
            .and_then(|v| v.timestamp);
        assert_eq!(decoded_timestamp, Some(2000));

        // The hash must equal: snapshot hash - old value MAC + new value MAC.
        let expected_hash = WAPATCH_INTEGRITY.subtract_then_add(
            &snapshot_result.state.hash,
            &[initial_value_mac],
            &[new_value_mac],
        );

        assert_eq!(result.state.hash.as_slice(), expected_hash.as_slice());
    }

    #[test]
    fn test_patch_version_rollback_rejected() {
        let mut state = HashState {
            version: 5,
            ..Default::default()
        };

        let err = apply_single_set_patch(&mut state, &[99u8; 32], 3, 5000)
            .expect_err("rollback patch should be rejected");

        assert!(
            matches!(
                err,
                AppStateError::PatchVersionMismatch {
                    expected: 6,
                    got: 3
                }
            ),
            "expected PatchVersionMismatch {{ expected: 6, got: 3 }}, got: {err:?}"
        );
    }

    #[test]
    fn test_patch_version_gap_rejected() {
        let mut state = HashState {
            version: 5,
            ..Default::default()
        };

        let err = apply_single_set_patch(&mut state, &[99u8; 32], 8, 6000)
            .expect_err("version gap should be rejected");

        assert!(
            matches!(
                err,
                AppStateError::PatchVersionMismatch {
                    expected: 6,
                    got: 8
                }
            ),
            "expected PatchVersionMismatch {{ expected: 6, got: 8 }}, got: {err:?}"
        );
    }

    #[test]
    fn test_patch_version_consecutive_accepted() {
        let mut state = HashState {
            version: 5,
            ..Default::default()
        };

        let result = apply_single_set_patch(&mut state, &[99u8; 32], 6, 7000)
            .expect("consecutive version should be accepted");
        assert_eq!(result.state.version, 6);
    }

    #[test]
    fn test_patch_version_check_skipped_when_no_prior_state() {
        // Default state: version 0 and an all-zero hash.
        let mut state = HashState::default();

        let result = apply_single_set_patch(&mut state, &[99u8; 32], 42, 8000)
            .expect("no-prior-state should skip version check");
        assert_eq!(result.state.version, 42);
    }
}