Skip to main content

shadowforge_lib/domain/adaptive/
mod.rs

1//! Adversarial embedding optimisation, camera model profile matching,
2//! compression-survivable embedding.
3//!
4//! Pure domain logic — no I/O, no file system, no async runtime.
5
6use rand::RngExt as _;
7use rand::SeedableRng;
8use rand_chacha::ChaCha8Rng;
9
10use crate::domain::analysis::pair_delta_chi_square_score;
11use crate::domain::ports::CoverProfile;
12
// ─── BinMask ─────────────────────────────────────────────────────────────────

/// Occupancy mask for 2-D FFT bins, built from a [`CoverProfile`].
///
/// A `true` entry at `(row, col)` means that bin is occupied by a known
/// external signal (e.g. an AI generator's watermark carrier) and **must
/// not** be used for payload embedding.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so masks can be logged, cached
/// and compared in tests like the other public domain types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BinMask {
    /// Number of columns in the bin grid.
    width: u32,
    /// Number of rows in the bin grid.
    height: u32,
    /// Row-major flat array; index = `row * width + col`.
    occupied: Vec<bool>,
}
26
27impl BinMask {
28    /// Build a bin-occupancy mask from `profile` at the given resolution.
29    ///
30    /// - `CoverProfile::Camera` → all-zeros (no protected bins).
31    /// - `CoverProfile::AiGenerator` → marks strong carrier bins
32    ///   (`coherence >= 0.90`).
33    #[must_use]
34    pub fn build(profile: &CoverProfile, width: u32, height: u32) -> Self {
35        let len = (width as usize).strict_mul(height as usize);
36        let mut occupied = vec![false; len];
37
38        if let CoverProfile::AiGenerator(p) = profile
39            && let Some(bins) = p.carrier_bins_for(width, height)
40        {
41            for bin in bins.iter().filter(|b| b.is_strong()) {
42                let (row, col) = bin.freq;
43                if row < height && col < width {
44                    let idx = (row as usize)
45                        .strict_mul(width as usize)
46                        .strict_add(col as usize);
47                    #[expect(
48                        clippy::indexing_slicing,
49                        reason = "idx < len is guaranteed by the row/col range check above"
50                    )]
51                    {
52                        occupied[idx] = true;
53                    }
54                }
55            }
56        }
57
58        Self {
59            width,
60            height,
61            occupied,
62        }
63    }
64
65    /// Return `true` if the bin at `(row, col)` is marked as occupied.
66    #[must_use]
67    pub fn is_occupied(&self, row: u32, col: u32) -> bool {
68        if row >= self.height || col >= self.width {
69            return false;
70        }
71        let idx = (row as usize)
72            .strict_mul(self.width as usize)
73            .strict_add(col as usize);
74        self.occupied.get(idx).copied().unwrap_or(false)
75    }
76
77    /// Return the number of occupied bins.
78    #[must_use]
79    pub fn occupied_count(&self) -> usize {
80        self.occupied.iter().filter(|&&b| b).count()
81    }
82
83    /// Total number of bins in the mask.
84    #[must_use]
85    pub const fn total_bins(&self) -> usize {
86        self.occupied.len()
87    }
88}
89
90// ─── Cost function ───────────────────────────────────────────────────────────
91
92/// Per-bit-position distortion cost for the adaptive permutation search.
93///
94/// Returns `f64::INFINITY` for positions that map to occupied FFT bins.
95/// Returns a value in `1.0..=2.0` for safe bins — higher near
96/// moderate-coherence bins as a soft margin.
97#[must_use]
98pub fn cost_at(bit_position: usize, total_positions: usize, mask: &BinMask) -> f64 {
99    if total_positions == 0 {
100        return f64::INFINITY;
101    }
102
103    let Ok(width) = usize::try_from(mask.width.max(1)) else {
104        return f64::INFINITY;
105    };
106    let col_usize = bit_position % width;
107    let row_usize = bit_position / width;
108    let Ok(col) = u32::try_from(col_usize) else {
109        return f64::INFINITY;
110    };
111    let Ok(row) = u32::try_from(row_usize) else {
112        return f64::INFINITY;
113    };
114
115    if mask.is_occupied(row, col) {
116        return f64::INFINITY;
117    }
118
119    // Soft margin: positions near the boundary of occupied regions get a
120    // slightly higher cost (up to 2.0).  We use the fractional position
121    // within the image as a proxy for proximity to occupied areas.
122    let bit_position_f = u32::try_from(bit_position)
123        .ok()
124        .map_or_else(|| f64::from(u32::MAX), f64::from);
125    let total_positions_f = u32::try_from(total_positions)
126        .ok()
127        .map_or_else(|| f64::from(u32::MAX), f64::from);
128    let fraction = bit_position_f / total_positions_f;
129    1.0 + fraction.min(1.0)
130}
131
// ─── Permutation ─────────────────────────────────────────────────────────────

/// A position permutation, as produced by [`permutation_search`].
///
/// `map[original_position] = new_position`: applying the permutation moves
/// the element at index `i` to index `map[i]`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Permutation {
    map: Vec<usize>,
}

impl Permutation {
    /// Identity permutation — no reordering.
    #[must_use]
    pub fn identity(len: usize) -> Self {
        Self {
            map: (0..len).collect(),
        }
    }

    /// Apply the permutation to `data` in place: `data[map[i]] <- old data[i]`.
    ///
    /// Only the first `n = min(data.len(), map.len())` elements take part.
    /// Entries whose target `map[i]` is `>= n` are skipped. Any position in
    /// that prefix that no entry targets keeps its previous value. Bytes past
    /// `n` are untouched.
    ///
    /// A single snapshot of the prefix is taken first (one `O(n)` allocation),
    /// so every element is read from its pre-permutation value regardless of
    /// the cycle structure of `map`.
    pub fn apply(&self, data: &mut [u8]) {
        let n = data.len().min(self.map.len());
        // `n <= data.len()`, so this split cannot panic.
        let (head, _) = data.split_at_mut(n);
        let source = head.to_vec();

        // `zip` stops after `n` entries because `source.len() == n`.
        for (&new_position, &value) in self.map.iter().zip(&source) {
            // `get_mut` returns `None` for targets outside the prefix, which
            // are skipped.
            if let Some(slot) = head.get_mut(new_position) {
                *slot = value;
            }
        }
    }

    /// Return the inverse permutation. Applying `self` and then
    /// `self.inverse()` restores the original data.
    ///
    /// Assumes `map` is a bijection on `0..len`, which holds for
    /// [`Permutation::identity`] and for any sequence of swaps applied to it.
    /// Out-of-range entries are ignored. A slot that no entry maps to
    /// defaults to `0`.
    #[must_use]
    pub fn inverse(&self) -> Self {
        let mut inv = vec![0usize; self.map.len()];
        for (orig, &new_pos) in self.map.iter().enumerate() {
            if let Some(slot) = inv.get_mut(new_pos) {
                *slot = orig;
            }
        }
        Self { map: inv }
    }

    /// Raw permutation map (`original_position -> new_position`).
    #[must_use]
    pub fn as_slice(&self) -> &[usize] {
        &self.map
    }
}
200
// ─── SearchConfig ────────────────────────────────────────────────────────────

/// Tuning knobs for [`permutation_search`].
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Upper bound on the number of candidate swaps the search evaluates.
    pub max_iterations: u32,
    /// Early-exit threshold: the search stops once its best score is at or
    /// below this value. NOTE(review): described as a chi-square score in
    /// dB — confirm the unit against `pair_delta_chi_square_score`.
    pub target_db: f64,
}

impl Default for SearchConfig {
    /// 100 iterations with an early-exit target of `-12.0`.
    fn default() -> Self {
        let target_db = -12.0;
        let max_iterations = 100;
        SearchConfig {
            max_iterations,
            target_db,
        }
    }
}
220
221// ─── permutation_search ──────────────────────────────────────────────────────
222
223/// Find the lowest-detectability permutation within `config.max_iterations`.
224///
225/// `seed` must be derived from the crypto key — never use a fresh random seed
226/// (the receiver needs to reconstruct the same permutation for extraction).
227///
228/// Uses a random-restart hill-climb: each iteration proposes a swap of two
229/// positions and accepts it if it lowers the chi-square score.
230#[must_use]
231pub fn permutation_search(
232    stego_bytes: &[u8],
233    mask: &BinMask,
234    config: &SearchConfig,
235    seed: u64,
236) -> Permutation {
237    if stego_bytes.is_empty() || config.max_iterations == 0 {
238        return Permutation::identity(stego_bytes.len());
239    }
240
241    let n = stego_bytes.len();
242    let mut rng = ChaCha8Rng::seed_from_u64(seed);
243    let mut best_perm = Permutation::identity(n);
244    // Use pair-delta chi-square (order-sensitive) so that swapping bytes
245    // actually changes the score and the hill-climb can make progress.
246    let mut best_score = pair_delta_chi_square_score(stego_bytes);
247
248    // Collect safe (non-occupied) positions to limit candidate swaps.
249    let safe_positions: Vec<usize> = (0..n)
250        .filter(|&pos| cost_at(pos, n, mask).is_finite())
251        .collect();
252
253    if safe_positions.len() < 2 {
254        return best_perm;
255    }
256
257    let mut current_map = best_perm.map.clone();
258    let mut current_data = stego_bytes.to_vec();
259
260    for _ in 0..config.max_iterations {
261        // Pick two distinct safe positions.
262        let idx_a = rng.random_range(0..safe_positions.len());
263        let mut idx_b = rng.random_range(0..safe_positions.len());
264        while idx_b == idx_a {
265            idx_b = rng.random_range(0..safe_positions.len());
266        }
267        let (Some(&pos_a), Some(&pos_b)) = (safe_positions.get(idx_a), safe_positions.get(idx_b))
268        else {
269            continue;
270        };
271
272        // Tentatively swap in both the map and the data.
273        current_map.swap(pos_a, pos_b);
274        current_data.swap(pos_a, pos_b);
275
276        let score = pair_delta_chi_square_score(&current_data);
277        if score < best_score {
278            best_score = score;
279            best_perm = Permutation {
280                map: current_map.clone(),
281            };
282            if best_score <= config.target_db {
283                break;
284            }
285        } else {
286            // Revert.
287            current_map.swap(pos_a, pos_b);
288            current_data.swap(pos_a, pos_b);
289        }
290    }
291
292    best_perm
293}
294
// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::ports::{AiGenProfile, CameraProfile, CarrierBin, CoverProfile};
    use std::collections::HashMap;

    /// AI-generator profile with a 1024×1024 carrier map.
    ///
    /// It has three strong bins (coherence 1.0) and one weak bin
    /// (coherence 0.82).
    fn gemini_1024_profile() -> CoverProfile {
        let bins = vec![
            CarrierBin::new((9, 9), 0.0, 1.0),
            CarrierBin::new((5, 5), 0.0, 1.0),
            CarrierBin::new((10, 11), 0.0, 1.0),
            CarrierBin::new((13, 6), 0.0, 0.82), // below 0.90 — NOT strong
        ];
        let mut carrier_map = HashMap::new();
        carrier_map.insert("1024x1024".to_string(), bins);
        CoverProfile::AiGenerator(AiGenProfile {
            model_id: "gemini".to_string(),
            channel_weights: [0.85, 1.0, 0.70],
            carrier_map,
        })
    }

    /// Camera profile with an all-zero quantisation table.
    fn camera_profile(model_id: &str) -> CoverProfile {
        CoverProfile::Camera(CameraProfile {
            quantisation_table: [0u16; 64],
            noise_floor_db: -80.0,
            model_id: model_id.to_string(),
        })
    }

    /// Mask built from a camera profile, i.e. with no occupied bins.
    fn camera_mask(width: u32, height: u32) -> BinMask {
        BinMask::build(&camera_profile("test"), width, height)
    }

    #[test]
    fn camera_profile_yields_all_zeros_mask() {
        let mask = BinMask::build(&camera_profile("canon"), 64, 64);
        assert_eq!(mask.occupied_count(), 0);
    }

    #[test]
    fn ai_gen_profile_marks_strong_carrier_bins() {
        let profile = gemini_1024_profile();
        let mask = BinMask::build(&profile, 1024, 1024);
        // (9,9), (5,5), (10,11) are strong; (13,6) has coherence 0.82 — not marked
        assert!(mask.is_occupied(9, 9));
        assert!(mask.is_occupied(5, 5));
        assert!(mask.is_occupied(10, 11));
        assert!(!mask.is_occupied(13, 6)); // below 0.90
        assert!(!mask.is_occupied(100, 100));
        assert_eq!(mask.occupied_count(), 3);
    }

    #[test]
    fn bin_mask_out_of_range_is_not_occupied() {
        let mask = BinMask::build(&gemini_1024_profile(), 1024, 1024);
        assert_eq!(mask.total_bins(), 1024 * 1024);
        assert!(!mask.is_occupied(1024, 0));
        assert!(!mask.is_occupied(0, 1024));
        assert!(!mask.is_occupied(u32::MAX, u32::MAX));
    }

    #[test]
    fn cost_at_returns_infinity_for_occupied_bin() {
        let profile = gemini_1024_profile();
        let mask = BinMask::build(&profile, 1024, 1024);
        // Position for (row=9, col=9): index = 9 * 1024 + 9 = 9225
        let occupied_position = 9usize * 1024 + 9;
        let cost = cost_at(occupied_position, 1024 * 1024, &mask);
        assert!(cost.is_infinite(), "expected infinity for occupied bin");
    }

    #[test]
    fn cost_at_returns_finite_for_safe_bin() {
        let profile = gemini_1024_profile();
        let mask = BinMask::build(&profile, 1024, 1024);
        let safe_position = 500usize;
        let cost = cost_at(safe_position, 1024 * 1024, &mask);
        assert!(cost.is_finite());
        assert!(cost >= 1.0);
        assert!(cost <= 2.0);
    }

    #[test]
    fn cost_at_zero_total_positions_is_infinite() {
        let mask = camera_mask(4, 4);
        assert!(cost_at(0, 0, &mask).is_infinite());
    }

    #[test]
    fn permutation_zero_iterations_returns_identity() {
        let data = vec![1u8, 2, 3, 4, 5, 6];
        let mask = camera_mask(6, 1);
        let config = SearchConfig {
            max_iterations: 0,
            target_db: -12.0,
        };
        let perm = permutation_search(&data, &mask, &config, 42);
        assert_eq!(perm, Permutation::identity(6));
    }

    #[test]
    fn permutation_search_empty_input_returns_empty_identity() {
        let mask = camera_mask(1, 1);
        let perm = permutation_search(&[], &mask, &SearchConfig::default(), 7);
        assert_eq!(perm, Permutation::identity(0));
    }

    #[test]
    fn permutation_is_deterministic_same_seed() {
        let data: Vec<u8> = (0u8..64).collect();
        let mask = camera_mask(8, 8);
        let config = SearchConfig::default();
        let p1 = permutation_search(&data, &mask, &config, 12345);
        let p2 = permutation_search(&data, &mask, &config, 12345);
        assert_eq!(p1, p2);
    }

    #[test]
    fn permutation_inverse_round_trips() {
        let data: Vec<u8> = vec![10, 20, 30, 40, 50];
        let mask = camera_mask(5, 1);
        let config = SearchConfig::default();
        let perm = permutation_search(&data, &mask, &config, 99);
        let original = data.clone();
        let mut modified = data;
        perm.apply(&mut modified);
        perm.inverse().apply(&mut modified);
        assert_eq!(modified, original);
    }

    #[test]
    fn permutation_identity_apply_is_noop() {
        let original = vec![1u8, 2, 3, 4];
        let mut data = original.clone();
        let perm = Permutation::identity(4);
        perm.apply(&mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn permutation_apply_moves_element_to_mapped_position() {
        // map[original] = new: 0 -> 2, 1 -> 0, 2 -> 1 (a 3-cycle, not an
        // involution, so gather/scatter confusion would be caught here).
        let perm = Permutation {
            map: vec![2, 0, 1],
        };
        let mut data = vec![10u8, 20, 30];
        perm.apply(&mut data);
        assert_eq!(data, vec![20, 30, 10]);
        perm.inverse().apply(&mut data);
        assert_eq!(data, vec![10, 20, 30]);
    }

    #[test]
    fn permutation_search_may_improve_score() {
        // Build stego data with a known non-uniform pattern.
        // The permutation search may find a better ordering.
        let mut data: Vec<u8> = (0u8..=255u8).collect();
        data.extend_from_slice(&[0u8; 256]); // add 256 zeros → heavily skewed histogram
        let mask = camera_mask(16, 32);
        let config = SearchConfig {
            max_iterations: 100,
            target_db: -12.0,
        };
        let perm = permutation_search(&data, &mask, &config, 777);
        // The permutation must be the right size.
        assert_eq!(perm.as_slice().len(), data.len());
        // After applying and inverting we must recover the original.
        let original = data.clone();
        let mut applied = data;
        perm.apply(&mut applied);
        perm.inverse().apply(&mut applied);
        assert_eq!(applied, original);
    }
}