Skip to main content

shadowforge_lib/adapters/
adaptive.rs

1//! Adaptive embedding adapters: cover-profile matching, STC-inspired
2//! optimisation, and platform compression simulation.
3//!
4//! I/O is allowed here; domain logic lives in `domain/adaptive`.
5
6use std::io::Cursor;
7use std::sync::LazyLock;
8
9use bytes::Bytes;
10use rustfft::FftPlanner;
11use rustfft::num_complex::Complex;
12use serde::Deserialize;
13
14use crate::domain::adaptive::{BinMask, SearchConfig, permutation_search};
15use crate::domain::errors::AdaptiveError;
16use crate::domain::ports::{
17    AdaptiveOptimiser, AiGenProfile, CameraProfile, CarrierBin, CompressionSimulator, CoverProfile,
18    CoverProfileMatcher,
19};
20use crate::domain::types::{
21    AiWatermarkAssessment, Capacity, CoverMedia, CoverMediaKind, PlatformProfile, StegoTechnique,
22};
23
24// ─── Built-in AI codebook ────────────────────────────────────────────────────
25
26/// Deserialisation wrapper for `ai_profiles.json`.
27#[derive(Deserialize)]
28struct ProfileCodebook {
29    profiles: Vec<AiGenProfile>,
30}
31
32// ─── CoverProfileMatcherImpl ─────────────────────────────────────────────────
33
34/// Concrete cover-profile matcher.
35///
36/// Loaded from a JSON codebook at construction time.  The built-in codebook
37/// includes the Gemini watermark profile.
38pub struct CoverProfileMatcherImpl {
39    ai_profiles: Vec<AiGenProfile>,
40    camera_profiles: Vec<CameraProfile>,
41}
42
43struct AiProfileMatch<'a> {
44    profile: &'a AiGenProfile,
45    matched_strong_bins: usize,
46    total_strong_bins: usize,
47    /// 1.0 for exact resolution + native scale; stacked penalty otherwise.
48    confidence_multiplier: f64,
49    /// Mean `|cos(obs_phase − expected_phase)|` over matched strong G-channel bins.
50    phase_consistency: f64,
51    /// Weighted fraction of matched bins where R and B channels also agree.
52    cross_validation_score: f64,
53}
54
55impl CoverProfileMatcherImpl {
56    /// Parse an AI profile codebook from a JSON string.
57    ///
58    /// # Errors
59    /// Returns [`AdaptiveError::ProfileMatchFailed`] if the JSON is malformed.
60    pub fn from_codebook(json: &str) -> Result<Self, AdaptiveError> {
61        let book: ProfileCodebook =
62            serde_json::from_str(json).map_err(|e| AdaptiveError::ProfileMatchFailed {
63                reason: format!("invalid codebook JSON: {e}"),
64            })?;
65        Ok(Self {
66            ai_profiles: book.profiles,
67            camera_profiles: Vec::new(),
68        })
69    }
70
71    /// Build using the built-in `ai_profiles.json` codebook bundled at
72    /// compile time.
73    ///
74    /// # Panics
75    ///
76    /// Never panics in production — the embedded JSON is validated at
77    /// compile-time test level.
78    #[must_use]
79    pub fn with_built_in() -> Self {
80        static BUILT_IN: LazyLock<Vec<AiGenProfile>> = LazyLock::new(|| {
81            let raw = include_str!("ai_profiles.json");
82            match serde_json::from_str::<ProfileCodebook>(raw) {
83                Ok(book) => book.profiles,
84                Err(e) => {
85                    tracing::error!(
86                        "built-in AI profile codebook is malformed — \
87                         adaptive matching is disabled: {e}"
88                    );
89                    Vec::new()
90                }
91            }
92        });
93        Self {
94            ai_profiles: BUILT_IN.clone(),
95            camera_profiles: Vec::new(),
96        }
97    }
98
99    const fn ai_detection_supported(kind: CoverMediaKind) -> bool {
100        matches!(
101            kind,
102            CoverMediaKind::PngImage
103                | CoverMediaKind::BmpImage
104                | CoverMediaKind::JpegImage
105                | CoverMediaKind::GifImage
106        )
107    }
108
109    fn detection_threshold(total_strong_bins: usize) -> usize {
110        total_strong_bins.saturating_sub(1).max(1)
111    }
112
113    #[expect(
114        clippy::similar_names,
115        reason = "R/G/B channel triples are intentionally symmetric; names like fft_r_half/fft_b_half reflect the domain"
116    )]
117    #[expect(
118        clippy::too_many_lines,
119        reason = "multi-channel pyramid build + multi-scale phase loop; splitting would obscure the data flow more than it helps"
120    )]
121    fn best_ai_profile_match(&self, cover: &CoverMedia) -> Option<AiProfileMatch<'_>> {
122        let width = cover
123            .metadata
124            .get("width")
125            .and_then(|v| v.parse::<u32>().ok())
126            .unwrap_or(0);
127        let height = cover
128            .metadata
129            .get("height")
130            .and_then(|v| v.parse::<u32>().ok())
131            .unwrap_or(0);
132
133        if width == 0 || height == 0 {
134            return None;
135        }
136
137        // Build multi-channel image pyramid (native → ½ → ¼).
138        // G is the primary carrier; R and B are used for cross-validation.
139        let pixels_g_nat = extract_channel_f32(&cover.data, width, height, 1);
140        if pixels_g_nat.len() < 4 {
141            return None;
142        }
143        let pixels_r_nat = extract_channel_f32(&cover.data, width, height, 0);
144        let pixels_b_nat = extract_channel_f32(&cover.data, width, height, 2);
145
146        let w = width as usize;
147        let h = height as usize;
148        let hw = width.saturating_div(2) as usize;
149        let hh = height.saturating_div(2) as usize;
150        let pixels_g_half = downsample_2x(&pixels_g_nat, w, h);
151        let pixels_r_half = downsample_2x(&pixels_r_nat, w, h);
152        let pixels_b_half = downsample_2x(&pixels_b_nat, w, h);
153        let pixels_g_qtr = downsample_2x(&pixels_g_half, hw, hh);
154        let pixels_r_qtr = downsample_2x(&pixels_r_half, hw, hh);
155        let pixels_b_qtr = downsample_2x(&pixels_b_half, hw, hh);
156
157        let fft_g_nat = fft_1d(
158            &pixels_g_nat,
159            pixels_g_nat.len().next_power_of_two().min(MAX_FFT_LEN),
160        );
161        let fft_r_nat = compute_fft_or_empty(&pixels_r_nat);
162        let fft_b_nat = compute_fft_or_empty(&pixels_b_nat);
163        let fft_g_half = compute_fft_or_empty(&pixels_g_half);
164        let fft_r_half = compute_fft_or_empty(&pixels_r_half);
165        let fft_b_half = compute_fft_or_empty(&pixels_b_half);
166        let fft_g_qtr = compute_fft_or_empty(&pixels_g_qtr);
167        let fft_r_qtr = compute_fft_or_empty(&pixels_r_qtr);
168        let fft_b_qtr = compute_fft_or_empty(&pixels_b_qtr);
169
170        self.ai_profiles
171            .iter()
172            .filter_map(|profile| {
173                // Exact resolution match first; fall back to nearest if absent.
174                let (bins, confidence_multiplier) =
175                    if let Some(exact) = profile.carrier_bins_for(width, height) {
176                        (exact, 1.0_f64)
177                    } else {
178                        nearest_resolution_bins(profile, width, height)?
179                    };
180
181                let total_strong_bins = bins.iter().filter(|b| b.is_strong()).count();
182                if total_strong_bins == 0 {
183                    return None;
184                }
185
186                // Run phase-coherence at native, ½, and ¼ scales; take the best.
187                let half_w = width.saturating_div(2);
188                let quarter_w = width.saturating_div(4);
189                let (best_detail, scale_penalty) = [
190                    Some((&fft_g_nat, &fft_r_nat, &fft_b_nat, width, 0u32, 1.0_f64)),
191                    if fft_g_half.is_empty() {
192                        None
193                    } else {
194                        Some((&fft_g_half, &fft_r_half, &fft_b_half, half_w, 1, 0.85))
195                    },
196                    if fft_g_qtr.is_empty() {
197                        None
198                    } else {
199                        Some((&fft_g_qtr, &fft_r_qtr, &fft_b_qtr, quarter_w, 2, 0.75))
200                    },
201                ]
202                .into_iter()
203                .flatten()
204                .map(|(fg, fr, fb, sw, shift, penalty)| {
205                    let detail = phase_match_detail_at_scale(
206                        fg,
207                        fr,
208                        fb,
209                        sw,
210                        bins,
211                        shift,
212                        profile.channel_weights,
213                    );
214                    (detail, penalty)
215                })
216                .max_by_key(|(detail, _)| detail.matched_strong)?;
217
218                Some(AiProfileMatch {
219                    profile,
220                    matched_strong_bins: best_detail.matched_strong,
221                    total_strong_bins: best_detail.total_strong,
222                    confidence_multiplier: confidence_multiplier * scale_penalty,
223                    phase_consistency: best_detail.phase_consistency,
224                    cross_validation_score: best_detail.cross_validation,
225                })
226            })
227            .max_by(|left, right| {
228                // Prefer candidates that clear the detection threshold first,
229                // then rank by match ratio (cross-multiply to stay in integers),
230                // then by raw matched count as a final tiebreaker.
231                let left_detected =
232                    left.matched_strong_bins >= Self::detection_threshold(left.total_strong_bins);
233                let right_detected =
234                    right.matched_strong_bins >= Self::detection_threshold(right.total_strong_bins);
235                left_detected
236                    .cmp(&right_detected)
237                    .then_with(|| {
238                        let lscore =
239                            (left.matched_strong_bins as u128) * (right.total_strong_bins as u128);
240                        let rscore =
241                            (right.matched_strong_bins as u128) * (left.total_strong_bins as u128);
242                        lscore.cmp(&rscore)
243                    })
244                    .then_with(|| left.matched_strong_bins.cmp(&right.matched_strong_bins))
245            })
246    }
247
248    /// Assess whether a raster cover still matches a known AI watermark profile.
249    #[must_use]
250    pub fn assess_ai_watermark(&self, cover: &CoverMedia) -> Option<AiWatermarkAssessment> {
251        if !Self::ai_detection_supported(cover.kind) {
252            return None;
253        }
254
255        let Some(best_match) = self.best_ai_profile_match(cover) else {
256            return Some(AiWatermarkAssessment {
257                detected: false,
258                model_id: None,
259                confidence: 0.0,
260                matched_strong_bins: 0,
261                total_strong_bins: 0,
262            });
263        };
264
265        let confidence = best_match.phase_consistency
266            * best_match.cross_validation_score
267            * best_match.confidence_multiplier;
268        let detected = best_match.matched_strong_bins
269            >= Self::detection_threshold(best_match.total_strong_bins);
270
271        Some(AiWatermarkAssessment {
272            detected,
273            model_id: detected.then(|| best_match.profile.model_id.clone()),
274            confidence,
275            matched_strong_bins: best_match.matched_strong_bins,
276            total_strong_bins: best_match.total_strong_bins,
277        })
278    }
279}
280
281impl CoverProfileMatcher for CoverProfileMatcherImpl {
282    fn profile_for(&self, cover: &CoverMedia) -> Option<CoverProfile> {
283        if let Some(best_match) = self.best_ai_profile_match(cover)
284            && best_match.matched_strong_bins
285                >= Self::detection_threshold(best_match.total_strong_bins)
286        {
287            return Some(CoverProfile::AiGenerator(best_match.profile.clone()));
288        }
289
290        // Fallback: camera profile if any are loaded.
291        self.camera_profiles
292            .first()
293            .cloned()
294            .map(CoverProfile::Camera)
295    }
296
297    fn apply_profile(
298        &self,
299        cover: CoverMedia,
300        _profile: &CoverProfile,
301    ) -> Result<CoverMedia, AdaptiveError> {
302        // For AI profiles: no modification — the optimiser avoids carrier bins.
303        // For camera profiles: a real implementation would adjust quant tables
304        // via the JPEG encoder; stub returns cover unchanged.
305        Ok(cover)
306    }
307}
308
309// ─── AdaptiveOptimiserImpl ───────────────────────────────────────────────────
310
311/// Concrete adversarial optimiser.
312///
313/// Uses `permutation_search` from `domain/adaptive` to find a byte-reordering
314/// that minimises chi-square detectability.
315pub struct AdaptiveOptimiserImpl {
316    matcher: CoverProfileMatcherImpl,
317    config: SearchConfig,
318}
319
320impl AdaptiveOptimiserImpl {
321    /// Create with an explicit codebook and search configuration.
322    ///
323    /// # Errors
324    /// Returns [`AdaptiveError::ProfileMatchFailed`] if the codebook is invalid.
325    pub fn from_codebook(codebook_json: &str, config: SearchConfig) -> Result<Self, AdaptiveError> {
326        Ok(Self {
327            matcher: CoverProfileMatcherImpl::from_codebook(codebook_json)?,
328            config,
329        })
330    }
331
332    /// Create with the built-in Gemini codebook and default search config.
333    #[must_use]
334    pub fn with_built_in() -> Self {
335        Self {
336            matcher: CoverProfileMatcherImpl::with_built_in(),
337            config: SearchConfig::default(),
338        }
339    }
340}
341
342impl AdaptiveOptimiser for AdaptiveOptimiserImpl {
343    fn optimise(
344        &self,
345        mut stego: CoverMedia,
346        _original: &CoverMedia,
347        target_db: f64,
348    ) -> Result<CoverMedia, AdaptiveError> {
349        let width = stego
350            .metadata
351            .get("width")
352            .and_then(|v| v.parse::<u32>().ok())
353            .unwrap_or(1);
354        let height = stego
355            .metadata
356            .get("height")
357            .and_then(|v| v.parse::<u32>().ok())
358            .unwrap_or(1);
359
360        let profile = self.matcher.profile_for(&stego);
361        let fallback_profile = CoverProfile::Camera(CameraProfile {
362            quantisation_table: [0u16; 64],
363            noise_floor_db: -80.0,
364            model_id: "fallback".to_string(),
365        });
366        let mask = BinMask::build(profile.as_ref().unwrap_or(&fallback_profile), width, height);
367
368        let config = SearchConfig {
369            max_iterations: self.config.max_iterations,
370            target_db,
371        };
372
373        // Derive deterministic seed from first 8 bytes of the stego data.
374        let seed = stego.data.get(..8).map_or_else(
375            || 0,
376            |bytes| {
377                let mut seed_bytes = [0u8; 8];
378                seed_bytes.copy_from_slice(bytes);
379                u64::from_le_bytes(seed_bytes)
380            },
381        );
382
383        let mut data = stego.data.to_vec();
384        let perm = permutation_search(&data, &mask, &config, seed);
385        perm.apply(&mut data);
386        stego.data = Bytes::from(data);
387        Ok(stego)
388    }
389}
390
391// ─── CompressionSimulatorImpl ────────────────────────────────────────────────
392
393/// Platform-specific quality and chroma settings.
394#[derive(Debug, Clone, Copy)]
395struct PlatformSettings {
396    jpeg_quality: u8,
397}
398
399impl PlatformSettings {
400    const fn for_platform(platform: &PlatformProfile) -> Self {
401        match platform {
402            PlatformProfile::Instagram => Self { jpeg_quality: 82 },
403            PlatformProfile::Twitter => Self { jpeg_quality: 75 },
404            PlatformProfile::WhatsApp | PlatformProfile::Imgur => Self { jpeg_quality: 85 },
405            PlatformProfile::Telegram => Self { jpeg_quality: 95 },
406            PlatformProfile::Custom { quality, .. } => Self {
407                jpeg_quality: *quality,
408            },
409        }
410    }
411}
412
/// Platform compression simulator.
///
/// Uses the `image` crate to round-trip pixels through an in-memory JPEG
/// encode and decode. Nothing touches the filesystem; the encoded bytes are
/// written through a `Cursor` over a `Vec`.
pub struct CompressionSimulatorImpl;
416
417impl CompressionSimulator for CompressionSimulatorImpl {
418    fn simulate(
419        &self,
420        cover: CoverMedia,
421        platform: &PlatformProfile,
422    ) -> Result<CoverMedia, AdaptiveError> {
423        let settings = PlatformSettings::for_platform(platform);
424        let quality = settings.jpeg_quality;
425
426        let width = cover
427            .metadata
428            .get("width")
429            .and_then(|v| v.parse::<u32>().ok());
430        let height = cover
431            .metadata
432            .get("height")
433            .and_then(|v| v.parse::<u32>().ok());
434
435        // We can only JPEG-compress actual image data.  If we lack dimensions
436        // or the cover isn't an image type, return unchanged.
437        let (Some(w), Some(h)) = (width, height) else {
438            return Ok(cover);
439        };
440
441        if !matches!(
442            cover.kind,
443            CoverMediaKind::PngImage
444                | CoverMediaKind::JpegImage
445                | CoverMediaKind::BmpImage
446                | CoverMediaKind::GifImage
447        ) {
448            return Ok(cover);
449        }
450
451        // Treat raw data as RGBA8 pixels.
452        let pixels = cover.data.to_vec();
453        let expected_len = (w as usize).saturating_mul(h as usize).saturating_mul(3);
454        if pixels.len() < expected_len {
455            return Ok(cover);
456        }
457
458        // Encode as JPEG into memory.
459        let mut encoded: Vec<u8> = Vec::new();
460        {
461            let mut cursor = Cursor::new(&mut encoded);
462            let mut jpeg_encoder =
463                image::codecs::jpeg::JpegEncoder::new_with_quality(&mut cursor, quality);
464            image::ImageBuffer::<image::Rgb<u8>, _>::from_raw(
465                w,
466                h,
467                pixels.get(..expected_len).unwrap_or(&[]),
468            )
469            .ok_or_else(|| AdaptiveError::CompressionSimFailed {
470                reason: "invalid pixel dimensions".to_string(),
471            })
472            .and_then(|buf| {
473                jpeg_encoder
474                    .encode(buf.as_raw(), w, h, image::ExtendedColorType::Rgb8)
475                    .map_err(|e| AdaptiveError::CompressionSimFailed {
476                        reason: format!("JPEG encode failed: {e}"),
477                    })
478            })?;
479        }
480
481        // Decode back.
482        let decoded = image::load_from_memory_with_format(&encoded, image::ImageFormat::Jpeg)
483            .map_err(|e| AdaptiveError::CompressionSimFailed {
484                reason: format!("JPEG decode failed: {e}"),
485            })?;
486        let rgb = decoded.to_rgb8();
487        let mut out_meta = cover.metadata;
488        out_meta.insert("width".to_string(), w.to_string());
489        out_meta.insert("height".to_string(), h.to_string());
490
491        Ok(CoverMedia {
492            kind: CoverMediaKind::JpegImage,
493            data: Bytes::from(rgb.into_raw()),
494            metadata: out_meta,
495        })
496    }
497
498    fn survivable_capacity(
499        &self,
500        cover: &CoverMedia,
501        platform: &PlatformProfile,
502    ) -> Result<Capacity, AdaptiveError> {
503        let total_bytes = cover.data.len() as u64;
504        let basis_points: u64 = match platform {
505            PlatformProfile::Instagram | PlatformProfile::Custom { .. } => 4000,
506            PlatformProfile::Twitter => 3000,
507            PlatformProfile::WhatsApp | PlatformProfile::Imgur => 4500,
508            PlatformProfile::Telegram => 7000,
509        };
510        let survivable = total_bytes.saturating_mul(basis_points).div_euclid(10_000);
511        Ok(Capacity {
512            bytes: survivable,
513            technique: StegoTechnique::LsbImage,
514        })
515    }
516}
517
518// ─── Dependency factory ───────────────────────────────────────────────────────
519
520/// Construct the three built-in adaptive adapters that together form the
521/// default profile-hardening dependency set.
522///
523/// Call sites should keep the returned values alive for as long as an
524/// [`crate::application::services::AdaptiveProfileDeps`] that borrows them
525/// is in scope.
526///
527/// Placing this factory in the adapter layer keeps the interface layer free
528/// from knowing *which* concrete types implement the adaptive port traits.
529#[must_use]
530pub fn build_adaptive_profile_deps() -> (
531    CoverProfileMatcherImpl,
532    AdaptiveOptimiserImpl,
533    CompressionSimulatorImpl,
534) {
535    (
536        CoverProfileMatcherImpl::with_built_in(),
537        AdaptiveOptimiserImpl::with_built_in(),
538        CompressionSimulatorImpl,
539    )
540}
541
542// ─── Helpers ─────────────────────────────────────────────────────────────────
543
/// Extract one colour channel from flat interleaved 8-bit pixel data as `f32`.
///
/// `channel` is a 0-based index: 0=R, 1=G, 2=B (3=A for RGBA8). The layout
/// is inferred from the buffer length relative to `width × height`:
///
/// * at least 4 bytes per pixel → RGBA8, stride 4;
/// * at least 3 bytes per pixel → RGB8, stride 3. The channel index is
///   clamped to `0..=2`, so 3 yields B.
/// * otherwise every byte is treated as a grey sample and `channel` is
///   ignored.
///
/// In the RGBA8/RGB8 cases every complete pixel in `data` is emitted, so
/// trailing bytes beyond `width × height` pixels are included.
///
/// Takes a plain byte slice, so both `&[u8]` and `&Bytes` (via deref
/// coercion) are accepted.
fn extract_channel_f32(data: &[u8], width: u32, height: u32, channel: usize) -> Vec<f32> {
    let npix = (width as usize).saturating_mul(height as usize);
    let (stride, idx) = if data.len() >= npix.saturating_mul(4) {
        (4, channel)
    } else if data.len() >= npix.saturating_mul(3) {
        (3, channel.min(2))
    } else {
        return data.iter().map(|&b| f32::from(b)).collect();
    };
    data.chunks_exact(stride)
        .map(|px| f32::from(px.get(idx).copied().unwrap_or(0)))
        .collect()
}
566
/// Normalise a raw phase difference into the half-open interval `[−π, π)`.
///
/// This removes the branch-cut discontinuity at `±π`, so a phase just
/// below `π` and one just above `−π` compare as close. An input of exactly
/// `π` maps to `−π`.
#[inline]
fn wrap_phase(raw: f64) -> f64 {
    (raw + std::f64::consts::PI).rem_euclid(std::f64::consts::TAU) - std::f64::consts::PI
}
577
578fn compute_fft_or_empty(pixels: &[f32]) -> Vec<Complex<f32>> {
579    if pixels.is_empty() {
580        Vec::new()
581    } else {
582        fft_1d(pixels, pixels.len().next_power_of_two().min(MAX_FFT_LEN))
583    }
584}
585
586fn fft_1d(samples: &[f32], fft_len: usize) -> Vec<Complex<f32>> {
587    let mut input: Vec<Complex<f32>> = samples.iter().map(|&x| Complex::new(x, 0.0)).collect();
588    input.resize(fft_len, Complex::new(0.0, 0.0));
589    let mut planner = FftPlanner::<f32>::new();
590    let fft = planner.plan_fft_forward(fft_len);
591    fft.process(&mut input);
592    input
593}
594
/// Parse a `"WxH"` codebook resolution key (e.g. `"1024x768"`) into
/// `(width, height)`.
///
/// The key is split at the first lowercase `x`. Returns `None` if there is
/// no `x` or if either side is not a valid `u32`.
fn parse_resolution_key(key: &str) -> Option<(u32, u32)> {
    let sep = key.find('x')?;
    let width = key.get(..sep)?.parse().ok()?;
    let height = key.get(sep + 1..)?.parse().ok()?;
    Some((width, height))
}
600
/// Confidence multiplier used when the codebook has no entry for the cover's
/// exact resolution and the detector falls back to the nearest available one.
const FALLBACK_CONFIDENCE_MULTIPLIER: f64 = 0.65;

/// Upper bound on the 1-D FFT length used for watermark detection (2²⁰ samples).
///
/// This bounds CPU and memory cost for large images. Phase coherence is read
/// at the carrier-bin indices, which are typically far below this bound, so
/// truncating longer signals should not affect detection quality in practice.
const MAX_FFT_LEN: usize = 1_048_576;
611
612/// Return the bin slice from `profile` whose resolution is closest to
613/// `(width, height)` by pixel count and aspect ratio, together with
614/// [`FALLBACK_CONFIDENCE_MULTIPLIER`] to signal the approximation.
615///
616/// Proximity score: `|actual_pixels − candidate_pixels| / actual_pixels + |Δ AR|`.
617///
618/// Returns `None` only when the profile has no carrier bins at all.
619fn nearest_resolution_bins(
620    profile: &AiGenProfile,
621    width: u32,
622    height: u32,
623) -> Option<(&[CarrierBin], f64)> {
624    let actual_pixels = u64::from(width).saturating_mul(u64::from(height));
625    let actual_ar = f64::from(width) / f64::from(height.max(1));
626
627    profile
628        .carrier_map
629        .iter()
630        .filter_map(|(key, bins)| {
631            let (cw, ch) = parse_resolution_key(key)?;
632            let candidate_pixels = u64::from(cw).saturating_mul(u64::from(ch));
633            #[expect(
634                clippy::cast_precision_loss,
635                reason = "pixel counts fit in f64 mantissa for realistic image dimensions"
636            )]
637            let pixel_diff = {
638                let diff = actual_pixels.abs_diff(candidate_pixels) as f64;
639                let denom = actual_pixels.max(1) as f64;
640                diff / denom
641            };
642            let ar_diff = (actual_ar - f64::from(cw) / f64::from(ch.max(1))).abs();
643            Some((pixel_diff + ar_diff, bins.as_slice()))
644        })
645        .min_by(|(score_a, _), (score_b, _)| {
646            score_a
647                .partial_cmp(score_b)
648                .unwrap_or(std::cmp::Ordering::Equal)
649        })
650        .map(|(_, bins)| (bins, FALLBACK_CONFIDENCE_MULTIPLIER))
651}
652
/// Halve a flat, row-major single-channel `f32` image with a 2×2 box filter.
///
/// Each output sample is the mean of one 2×2 input block. An odd trailing row
/// or column is discarded. The result is `(width / 2) × (height / 2)` samples.
/// It is empty when `width < 2`, `height < 2`, or `pixels` holds fewer than
/// `width × height` samples.
fn downsample_2x(pixels: &[f32], width: usize, height: usize) -> Vec<f32> {
    let out_cols = width / 2;
    let out_rows = height / 2;
    let covers_frame = pixels.len() >= width.saturating_mul(height);
    if out_cols == 0 || out_rows == 0 || !covers_frame {
        return Vec::new();
    }

    // Checked lookup: after the length check every index is in bounds, but
    // this keeps the loop free of panicking indexing.
    let at = |index: usize| pixels.get(index).copied().unwrap_or(0.0);

    let mut out = Vec::with_capacity(out_cols.saturating_mul(out_rows));
    for oy in 0..out_rows {
        let top = oy.saturating_mul(2).saturating_mul(width);
        let bottom = top.saturating_add(width);
        for ox in 0..out_cols {
            let left = ox * 2;
            let block_sum = at(top + left) + at(top + left + 1) + at(bottom + left)
                + at(bottom + left + 1);
            out.push(block_sum / 4.0);
        }
    }
    out
}
678
/// Detailed result of a phase-coherence check at one pyramid scale.
///
/// Derives `Debug`/`PartialEq` so results can be logged and compared in
/// tests. It is `Copy` because it holds only scalars.
#[derive(Debug, Clone, Copy, PartialEq)]
struct ScaleMatchDetail {
    /// Strong bins whose G-channel phase matched within tolerance.
    matched_strong: usize,
    /// Strong bins examined at this scale.
    total_strong: usize,
    /// Mean `|cos(obs_phase − expected_phase)|` over matched strong G-channel bins.
    phase_consistency: f64,
    /// Weighted fraction of matched bins where R and B also show coherence.
    cross_validation: f64,
}
688
689/// Full phase-coherence measurement for `bins` at one pyramid scale.
690///
691/// G channel is the primary carrier; R and B are cross-validated using
692/// `channel_weights[0]` (R) and `channel_weights[2]` (B) from the profile's `channel_weights`.
693/// `scale_shift` = `log₂(native_width / scaled_width)`: 0 = native, 1 = ½, 2 = ¼.
694fn phase_match_detail_at_scale(
695    freq_g: &[Complex<f32>],
696    freq_r: &[Complex<f32>],
697    freq_b: &[Complex<f32>],
698    scaled_width: u32,
699    bins: &[CarrierBin],
700    scale_shift: u32,
701    channel_weights: [f64; 3],
702) -> ScaleMatchDetail {
703    let r_weight = channel_weights.first().copied().unwrap_or(0.0);
704    let b_weight = channel_weights.get(2).copied().unwrap_or(0.0);
705    let divisor = 1u32.wrapping_shl(scale_shift);
706    let total_strong = bins.iter().filter(|b| b.is_strong()).count();
707    let mut matched_strong = 0usize;
708    let mut phase_cos_sum = 0.0_f64;
709    let mut r_cross = 0usize;
710    let mut b_cross = 0usize;
711
712    for bin in bins.iter().filter(|b| b.is_strong()) {
713        let row = bin.freq.0.saturating_div(divisor);
714        let col = bin.freq.1.saturating_div(divisor);
715        let idx = (row as usize)
716            .saturating_mul(scaled_width as usize)
717            .saturating_add(col as usize);
718
719        let Some(carrier_g) = freq_g.get(idx) else {
720            continue;
721        };
722        let phase_diff = wrap_phase(f64::from(carrier_g.arg()) - bin.phase);
723        if phase_diff.abs() >= std::f64::consts::PI / 8.0 {
724            continue;
725        }
726        matched_strong += 1;
727        phase_cos_sum += phase_diff.cos().abs();
728        if freq_r.get(idx).is_some_and(|c| {
729            wrap_phase(f64::from(c.arg()) - bin.phase).abs() < std::f64::consts::PI / 8.0
730        }) {
731            r_cross += 1;
732        }
733        if freq_b.get(idx).is_some_and(|c| {
734            wrap_phase(f64::from(c.arg()) - bin.phase).abs() < std::f64::consts::PI / 8.0
735        }) {
736            b_cross += 1;
737        }
738    }
739
740    #[expect(
741        clippy::cast_precision_loss,
742        reason = "bin counts are small; f64 precision loss is negligible"
743    )]
744    let phase_consistency = if matched_strong == 0 {
745        0.0
746    } else {
747        phase_cos_sum / matched_strong as f64
748    };
749
750    let weight_sum = r_weight + b_weight;
751    #[expect(
752        clippy::cast_precision_loss,
753        reason = "bin counts are small; f64 precision loss is negligible"
754    )]
755    let cross_validation = if weight_sum < 1e-10 || matched_strong == 0 {
756        1.0 // No channel-weight data; do not penalise.
757    } else {
758        let ms = matched_strong as f64;
759        r_weight.mul_add(r_cross as f64 / ms, b_weight * (b_cross as f64 / ms)) / weight_sum
760    };
761
762    ScaleMatchDetail {
763        matched_strong,
764        total_strong,
765        phase_consistency,
766        cross_validation,
767    }
768}
769
770// ─── Tests ───────────────────────────────────────────────────────────────────
771
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Return type for tests that propagate setup/driver failures with `?`,
    /// so the underlying error is reported instead of a bare `is_ok()` failure.
    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Codebook with one profile (`test-ai`) whose only carrier is a strong
    /// DC bin registered at the 8×8 resolution.
    const DC_BIN_8X8_CODEBOOK: &str = r#"{"profiles":[{"model_id":"test-ai","channel_weights":[1.0,1.0,1.0],"carrier_map":{"8x8":[{"freq":[0,0],"phase":0.0,"coherence":1.0}]}}]}"#;

    /// Build a uniform cover of `w`×`h` pixels at 4 bytes per pixel, every
    /// byte set to 128, with `width`/`height` metadata populated.
    fn make_cover(kind: CoverMediaKind, w: u32, h: u32) -> CoverMedia {
        let n = (w as usize).saturating_mul(h as usize).saturating_mul(4);
        let metadata = HashMap::from([
            ("width".to_string(), w.to_string()),
            ("height".to_string(), h.to_string()),
        ]);
        CoverMedia {
            kind,
            data: Bytes::from(vec![128u8; n]),
            metadata,
        }
    }

    #[test]
    fn built_in_codebook_parses_without_error() {
        let matcher = CoverProfileMatcherImpl::with_built_in();
        assert!(!matcher.ai_profiles.is_empty());
        // The built-in codebook is expected to list the Gemini profile first.
        assert_eq!(
            matcher.ai_profiles.first().map(|p| p.model_id.as_str()),
            Some("gemini")
        );
    }

    #[test]
    fn from_codebook_returns_error_on_bad_json() {
        assert!(CoverProfileMatcherImpl::from_codebook("not json").is_err());
    }

    #[test]
    fn from_codebook_accepts_valid_json() {
        let json = r#"{"profiles":[{"model_id":"test","channel_weights":[1.0,1.0,1.0],"carrier_map":{}}]}"#;
        assert!(CoverProfileMatcherImpl::from_codebook(json).is_ok());
    }

    #[test]
    fn profile_for_returns_none_for_zero_dimensions() {
        let matcher = CoverProfileMatcherImpl::with_built_in();
        let cover = CoverMedia {
            kind: CoverMediaKind::PngImage,
            data: Bytes::from(vec![0u8; 16]),
            metadata: HashMap::new(), // no width/height
        };
        assert!(matcher.profile_for(&cover).is_none());
    }

    #[test]
    fn assess_ai_watermark_detects_matching_profile() -> TestResult {
        let matcher = CoverProfileMatcherImpl::from_codebook(DC_BIN_8X8_CODEBOOK)?;
        let cover = make_cover(CoverMediaKind::PngImage, 8, 8);

        let Some(assessment) = matcher.assess_ai_watermark(&cover) else {
            return Err("expected ai watermark assessment for matching cover".into());
        };
        assert!(assessment.detected);
        assert_eq!(assessment.model_id.as_deref(), Some("test-ai"));
        assert_eq!(assessment.matched_strong_bins, 1);
        assert_eq!(assessment.total_strong_bins, 1);
        Ok(())
    }

    #[test]
    fn phase_match_detail_perfect_phase_gives_consistency_one() {
        // Uniform input → DC component has arg ≈ 0; bin expects 0.0.
        // Both phase_consistency and cross_validation must be ≈ 1.0.
        let samples: Vec<f32> = vec![100.0; 64];
        let fft_len = samples.len().next_power_of_two();
        let freq = fft_1d(&samples, fft_len);
        let bins = vec![CarrierBin::new((0, 0), 0.0, 1.0)];
        let detail = phase_match_detail_at_scale(&freq, &freq, &freq, 8, &bins, 0, [1.0, 1.0, 1.0]);
        assert_eq!(detail.matched_strong, 1);
        assert!(
            (detail.phase_consistency - 1.0).abs() < 1e-5,
            "expected phase_consistency ≈ 1.0, got {}",
            detail.phase_consistency
        );
        assert!(
            (detail.cross_validation - 1.0).abs() < 1e-5,
            "expected cross_validation ≈ 1.0, got {}",
            detail.cross_validation
        );
    }

    #[test]
    fn confidence_is_in_unit_interval() -> TestResult {
        let matcher = CoverProfileMatcherImpl::from_codebook(DC_BIN_8X8_CODEBOOK)?;
        let cover = make_cover(CoverMediaKind::PngImage, 8, 8);

        let Some(assessment) = matcher.assess_ai_watermark(&cover) else {
            return Err("expected assessment for matching cover".into());
        };
        assert!(
            (0.0..=1.0).contains(&assessment.confidence),
            "confidence must be in [0.0, 1.0], got {}",
            assessment.confidence
        );
        Ok(())
    }

    #[test]
    fn downsample_2x_produces_box_filtered_output() {
        // 4×4 single-channel input; expect a row-major 2×2 box-filtered output.
        let input = vec![
            1.0_f32, 2.0, 3.0, 4.0, // row 0
            5.0, 6.0, 7.0, 8.0, // row 1
            9.0, 10.0, 11.0, 12.0, // row 2
            13.0, 14.0, 15.0, 16.0, // row 3
        ];
        let output = downsample_2x(&input, 4, 4);
        assert_eq!(output.len(), 4, "expected 2×2 output");
        // Each output value is the mean of one 2×2 input block:
        //   top-left     (1+2+5+6)/4     = 3.5
        //   top-right    (3+4+7+8)/4     = 5.5
        //   bottom-left  (9+10+13+14)/4  = 11.5
        //   bottom-right (11+12+15+16)/4 = 13.5
        let expected = [3.5_f32, 5.5, 11.5, 13.5];
        for (idx, (got, want)) in output.iter().zip(expected).enumerate() {
            assert!(
                (got - want).abs() < 1e-5,
                "block {idx}: expected {want}, got {got}"
            );
        }
    }

    #[test]
    fn resolution_fallback_uses_nearest_profile() -> TestResult {
        // Profile registers only "16x16"; querying at 8×8 must still produce a
        // result via the nearest-resolution fallback (with reduced confidence).
        let matcher = CoverProfileMatcherImpl::from_codebook(
            r#"{"profiles":[{"model_id":"fallback-test","channel_weights":[1.0,1.0,1.0],"carrier_map":{"16x16":[{"freq":[0,0],"phase":0.0,"coherence":1.0}]}}]}"#,
        )?;
        let cover = make_cover(CoverMediaKind::PngImage, 8, 8);

        let Some(assessment) = matcher.assess_ai_watermark(&cover) else {
            return Err("expected Some from fallback resolution match".into());
        };
        // The non-exact-resolution penalty keeps confidence below 1.0 even
        // for a perfect bin match.
        assert!(
            assessment.confidence < 1.0,
            "fallback should reduce confidence below 1.0, got {}",
            assessment.confidence
        );
        assert_eq!(assessment.total_strong_bins, 1);
        Ok(())
    }

    #[test]
    fn apply_profile_returns_cover_unchanged() -> TestResult {
        let matcher = CoverProfileMatcherImpl::with_built_in();
        let cover = make_cover(CoverMediaKind::PngImage, 8, 8);
        let profile = CoverProfile::Camera(CameraProfile {
            quantisation_table: [0u16; 64],
            noise_floor_db: -80.0,
            model_id: "test".to_string(),
        });
        let result = matcher
            .apply_profile(cover.clone(), &profile)
            .map_err(|e| format!("apply_profile failed: {e:?}"))?;
        assert_eq!(result.data, cover.data);
        Ok(())
    }

    #[test]
    fn adaptive_optimiser_built_in_runs_without_error() {
        let optimiser = AdaptiveOptimiserImpl::with_built_in();
        let cover = make_cover(CoverMediaKind::PngImage, 8, 8);
        let stego = make_cover(CoverMediaKind::PngImage, 8, 8);
        let result = optimiser.optimise(stego, &cover, -12.0);
        assert!(result.is_ok(), "optimise failed: {:?}", result.err());
    }

    #[test]
    fn adaptive_optimiser_preserves_data_length() -> TestResult {
        let optimiser = AdaptiveOptimiserImpl::with_built_in();
        let cover = make_cover(CoverMediaKind::PngImage, 4, 4);
        let stego = make_cover(CoverMediaKind::PngImage, 4, 4);
        let original_len = stego.data.len();
        let result = optimiser
            .optimise(stego, &cover, -12.0)
            .map_err(|e| format!("optimise failed: {e:?}"))?;
        assert_eq!(result.data.len(), original_len);
        Ok(())
    }

    #[test]
    fn compression_simulator_survivable_capacity() -> TestResult {
        let sim = CompressionSimulatorImpl;
        let cover = make_cover(CoverMediaKind::PngImage, 32, 32);
        let cap = sim
            .survivable_capacity(&cover, &PlatformProfile::Instagram)
            .map_err(|e| format!("survivable_capacity failed: {e:?}"))?;
        assert!(cap.bytes > 0);
        assert!(cap.bytes < cover.data.len() as u64);
        Ok(())
    }

    #[test]
    fn compression_simulator_non_image_returns_unchanged() -> TestResult {
        let sim = CompressionSimulatorImpl;
        let cover = CoverMedia {
            kind: CoverMediaKind::WavAudio,
            data: Bytes::from(vec![0u8; 1024]),
            metadata: HashMap::from([
                ("width".to_string(), "32".to_string()),
                ("height".to_string(), "32".to_string()),
            ]),
        };
        let result = sim
            .simulate(cover.clone(), &PlatformProfile::Twitter)
            .map_err(|e| format!("simulate failed: {e:?}"))?;
        assert_eq!(result.data, cover.data);
        Ok(())
    }

    #[test]
    fn platform_settings_telegram_highest_quality() -> TestResult {
        let telegram = PlatformSettings::for_platform(&PlatformProfile::Telegram);
        let twitter = PlatformSettings::for_platform(&PlatformProfile::Twitter);
        assert!(
            telegram.jpeg_quality > twitter.jpeg_quality,
            "Telegram should re-encode at a higher JPEG quality than Twitter"
        );

        // The higher re-encode quality should translate into more survivable capacity.
        let sim = CompressionSimulatorImpl;
        let cover = make_cover(CoverMediaKind::PngImage, 32, 32);
        let telegram_cap = sim
            .survivable_capacity(&cover, &PlatformProfile::Telegram)
            .map_err(|e| format!("Telegram survivable_capacity failed: {e:?}"))?;
        let twitter_cap = sim
            .survivable_capacity(&cover, &PlatformProfile::Twitter)
            .map_err(|e| format!("Twitter survivable_capacity failed: {e:?}"))?;
        assert!(telegram_cap.bytes > twitter_cap.bytes);
        Ok(())
    }

    #[test]
    fn build_adaptive_profile_deps_returns_functional_impls() {
        let (matcher, _optimiser, _compressor) = build_adaptive_profile_deps();
        // The matcher must have loaded the built-in codebook.
        assert!(!matcher.ai_profiles.is_empty());
    }
}
1035}