Skip to main content

shadowforge_lib/domain/corpus/
mod.rs

//! Corpus steganography: zero-modification cover selection via ANN search.
//!
//! Given a payload to embed and a corpus of images, this module finds
//! the image whose natural LSB bit pattern most closely matches the
//! payload — minimising or eliminating modifications needed.
6
7use bytes::Bytes;
8
9use crate::domain::types::CorpusEntry;
10
11/// Compute the Hamming distance between two byte slices of equal length.
12///
13/// Returns `None` if the slices differ in length.
14#[must_use]
15pub fn hamming_distance(a: &[u8], b: &[u8]) -> Option<u64> {
16    if a.len() != b.len() {
17        return None;
18    }
19    let mut dist: u64 = 0;
20    for (x, y) in a.iter().zip(b.iter()) {
21        dist = dist.strict_add(u64::from((x ^ y).count_ones()));
22    }
23    Some(dist)
24}
25
26/// Extract the LSB bit pattern from raw pixel bytes.
27///
28/// For each byte of pixel data, extracts the least significant bit and packs
29/// 8 bits into one output byte. The result length is `ceil(pixel_bytes.len() / 8)`.
30#[must_use]
31pub fn extract_lsb_pattern(pixel_bytes: &[u8]) -> Bytes {
32    let out_len = pixel_bytes.len().div_ceil(8);
33    let mut pattern = vec![0u8; out_len];
34
35    for (i, &byte) in pixel_bytes.iter().enumerate() {
36        let out_byte_idx = i / 8;
37        let bit_idx = 7 - (i % 8);
38        if byte & 1 == 1 {
39            // out_byte_idx = i/8 <= (pixel_bytes.len()-1)/8 < out_len
40            #[expect(
41                clippy::indexing_slicing,
42                reason = "out_byte_idx = i/8 < ceil(len/8) = out_len"
43            )]
44            {
45                pattern[out_byte_idx] |= 1 << bit_idx;
46            }
47        }
48    }
49
50    Bytes::from(pattern)
51}
52
53/// Expand a payload into a bit pattern of the same format as
54/// [`extract_lsb_pattern`] output (one bit per sample, packed into bytes).
55///
56/// This effectively returns the raw bytes padded to a target bit-count
57/// boundary. If `target_bits` is `None`, returns the payload bytes as-is.
58#[must_use]
59pub fn payload_to_bit_pattern(payload: &[u8], target_bits: Option<usize>) -> Bytes {
60    target_bits.map_or_else(
61        || Bytes::copy_from_slice(payload),
62        |target| {
63            let needed_bytes = target.div_ceil(8);
64            let mut result = Vec::with_capacity(needed_bytes);
65            result.extend_from_slice(payload);
66            result.resize(needed_bytes, 0);
67            Bytes::from(result)
68        },
69    )
70}
71
72/// Score a corpus entry's precomputed bit pattern against a payload pattern.
73///
74/// Returns the Hamming distance — lower is better. Returns `u64::MAX` if the
75/// patterns are incompatible in length.
76#[must_use]
77pub fn score_match(corpus_pattern: &[u8], payload_pattern: &[u8]) -> u64 {
78    // Compare only the overlapping prefix
79    let compare_len = corpus_pattern.len().min(payload_pattern.len());
80    if compare_len == 0 {
81        return u64::MAX;
82    }
83    match (
84        corpus_pattern.get(..compare_len),
85        payload_pattern.get(..compare_len),
86    ) {
87        (Some(a), Some(b)) => hamming_distance(a, b).unwrap_or(u64::MAX),
88        _ => u64::MAX,
89    }
90}
91
/// Determine if a Hamming distance counts as a "close enough" match.
///
/// Threshold: at most 5% of `total_bits` differ, i.e.
/// `distance * 20 <= total_bits`. Exactly 5% still qualifies.
///
/// Returns `false` when `total_bits` is zero. Also returns `false` when
/// `distance` is so large that `distance * 20` does not fit in a `u64`; such
/// a distance necessarily exceeds 5% of any `u64` bit count. This keeps the
/// `u64::MAX` sentinel returned by [`score_match`] from causing an overflow
/// panic here.
#[must_use]
pub const fn is_close_match(distance: u64, total_bits: u64) -> bool {
    if total_bits == 0 {
        return false;
    }
    // A `match` is used rather than `is_some_and` because closures cannot be
    // called in a `const fn`.
    match distance.checked_mul(20) {
        Some(scaled) => scaled <= total_bits,
        None => false,
    }
}
103
104/// Filter corpus entries by model ID and resolution.
105///
106/// Returns references to all entries whose `spectral_key` matches both
107/// `model_id` and `resolution`. Entries without a `spectral_key` are
108/// silently excluded.
109#[must_use]
110pub fn filter_by_model<'a>(
111    entries: &'a [CorpusEntry],
112    model_id: &str,
113    resolution: (u32, u32),
114) -> Vec<&'a CorpusEntry> {
115    entries
116        .iter()
117        .filter(|e| {
118            e.spectral_key
119                .as_ref()
120                .is_some_and(|k| k.model_id == model_id && k.resolution == resolution)
121        })
122        .collect()
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    // ─── hamming_distance tests ───────────────────────────────────────────────

    /// A slice compared with an identical slice differs in no bits.
    #[test]
    fn hamming_distance_identical() {
        let left = [0xAA, 0x55, 0xFF];
        let right = [0xAA, 0x55, 0xFF];
        let distance = hamming_distance(&left, &right);
        assert_eq!(distance, Some(0));
    }

    /// Flipping only the lowest bit yields a distance of one.
    #[test]
    fn hamming_distance_one_bit() {
        let distance = hamming_distance(&[0b0000_0000], &[0b0000_0001]);
        assert_eq!(distance, Some(1));
    }

    /// A byte and its complement differ in all eight bits.
    #[test]
    fn hamming_distance_all_bits() {
        let distance = hamming_distance(&[0x00], &[0xFF]);
        assert_eq!(distance, Some(8));
    }

    /// Slices of different lengths have no defined distance.
    #[test]
    fn hamming_distance_unequal_lengths() {
        let longer = [0x00, 0x00];
        let shorter = [0xFF];
        assert_eq!(hamming_distance(&longer, &shorter), None);
    }

    // ─── extract_lsb_pattern tests ────────────────────────────────────────────

    /// Eight samples with LSBs 1,0,1,0,1,0,1,0 pack MSB-first into 0xAA.
    #[test]
    fn extract_lsb_pattern_basic() {
        let samples = [0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00];
        let packed = extract_lsb_pattern(&samples);
        assert_eq!(packed.as_ref(), &[0xAA]);
    }

    /// Three samples with LSBs 1,1,0 fill the top bits of one byte: 0b1100_0000.
    #[test]
    fn extract_lsb_pattern_partial_byte() {
        let samples = [0x01, 0x03, 0x00];
        let packed = extract_lsb_pattern(&samples);
        assert_eq!(packed.len(), 1);
        assert_eq!(packed.as_ref(), &[0b1100_0000]);
    }

    // ─── payload_to_bit_pattern tests ─────────────────────────────────────────

    /// Without a target size the payload comes back unchanged.
    #[test]
    fn payload_to_bit_pattern_no_target() {
        let pattern = payload_to_bit_pattern(b"hello", None);
        assert_eq!(pattern.as_ref(), b"hello");
    }

    /// A 16-bit target pads a one-byte payload to two bytes with zeros.
    #[test]
    fn payload_to_bit_pattern_with_padding() {
        let pattern = payload_to_bit_pattern(b"\xFF", Some(16));
        assert_eq!(pattern.len(), 2);
        assert_eq!(pattern.as_ref(), &[0xFF, 0x00]);
    }

    // ─── score_match tests ────────────────────────────────────────────────────

    /// Identical patterns score a perfect zero.
    #[test]
    fn score_match_identical() {
        let corpus = [0xAA, 0x55];
        let payload = [0xAA, 0x55];
        assert_eq!(score_match(&corpus, &payload), 0);
    }

    /// With nothing to compare the sentinel score is returned.
    #[test]
    fn score_match_empty() {
        let corpus: [u8; 0] = [];
        let payload: [u8; 0] = [];
        assert_eq!(score_match(&corpus, &payload), u64::MAX);
    }

    // ─── is_close_match tests ─────────────────────────────────────────────────

    /// A zero distance over a non-empty pattern is always close.
    #[test]
    fn is_close_match_zero_distance() {
        let close = is_close_match(0, 100);
        assert!(close);
    }

    /// 5 differing bits out of 100 sits exactly on the 5% threshold.
    #[test]
    fn is_close_match_exactly_five_percent() {
        let close = is_close_match(5, 100);
        assert!(close);
    }

    /// 6 differing bits out of 100 is 6%, which is over the threshold.
    #[test]
    fn is_close_match_above_threshold() {
        let close = is_close_match(6, 100);
        assert!(!close);
    }

    /// A zero-bit comparison is never reported as close.
    #[test]
    fn is_close_match_zero_total() {
        let close = is_close_match(0, 0);
        assert!(!close);
    }

    // ─── filter_by_model tests ────────────────────────────────────────────────

    /// Build a minimal corpus entry; a spectral key is attached only when
    /// both `model_id` and `resolution` are provided.
    fn make_entry(model_id: Option<&str>, resolution: Option<(u32, u32)>) -> CorpusEntry {
        use crate::domain::types::{CoverMediaKind, SpectralKey};

        let spectral_key = match (model_id, resolution) {
            (Some(id), Some(res)) => Some(SpectralKey {
                model_id: id.to_string(),
                resolution: res,
            }),
            _ => None,
        };

        CorpusEntry {
            file_hash: [0u8; 32],
            path: "test.png".to_string(),
            cover_kind: CoverMediaKind::PngImage,
            precomputed_bit_pattern: Bytes::new(),
            spectral_key,
        }
    }

    /// Only the entry matching both model and resolution is returned.
    #[test]
    fn filter_by_model_returns_matching_entries() {
        let corpus = vec![
            make_entry(Some("gemini"), Some((1024, 1024))),
            make_entry(Some("gemini"), Some((512, 512))),
            make_entry(Some("other"), Some((1024, 1024))),
            make_entry(None, None),
        ];
        let matched = filter_by_model(&corpus, "gemini", (1024, 1024));
        assert_eq!(matched.len(), 1);
    }

    /// A resolution mismatch leaves nothing to return.
    #[test]
    fn filter_by_model_returns_empty_when_no_match() {
        let corpus = vec![
            make_entry(Some("gemini"), Some((512, 512))),
            make_entry(None, None),
        ];
        let matched = filter_by_model(&corpus, "gemini", (1024, 1024));
        assert!(matched.is_empty());
    }

    /// Entries without a spectral key are never selected.
    #[test]
    fn filter_by_model_excludes_no_key_entries() {
        let corpus = vec![make_entry(None, None), make_entry(None, None)];
        let matched = filter_by_model(&corpus, "gemini", (1024, 1024));
        assert!(matched.is_empty());
    }
}