Skip to main content

scirs2_signal/
modulation.rs

1//! Modulation and demodulation module for signal processing
2//!
3//! This module provides modulation/demodulation algorithms for communication signals:
4//! - **AM** (Amplitude Modulation) — DSB-SC, DSB-FC (conventional), and SSB
5//! - **FM** (Frequency Modulation) — analog FM with configurable deviation
6//! - **QAM** (Quadrature Amplitude Modulation) — 4-QAM, 16-QAM, 64-QAM, 256-QAM
7//!
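//! The AM and FM functions take and return real-valued sample vectors; the QAM functions
//! map bits to and from [`QamSymbol`] (I/Q) pairs and can up-convert symbols to a real
//! passband signal. Pure Rust, with no `unwrap()` in library code and snake_case naming.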
10
11use crate::error::{SignalError, SignalResult};
12use std::f64::consts::PI;
13
14// ---------------------------------------------------------------------------
15// AM modulation / demodulation
16// ---------------------------------------------------------------------------
17
/// Flavour of amplitude modulation applied by [`am_modulate`] and expected by
/// [`am_demodulate`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum AmMode {
    /// Double-Sideband Suppressed Carrier (DSB-SC): message times carrier, no carrier term.
    DsbSc,
    /// Double-Sideband Full Carrier (conventional AM).
    ///
    /// The payload is the modulation index `m`; keep `0 < m <= 1` to avoid
    /// over-modulation.
    DsbFc(f64),
    /// Single-Sideband, upper sideband only.
    SsbUpper,
    /// Single-Sideband, lower sideband only.
    SsbLower,
}
31
32/// Amplitude-modulate a baseband signal onto a carrier
33///
34/// # Arguments
35///
36/// * `signal` - Baseband (message) signal
37/// * `carrier_freq` - Carrier frequency in Hz
38/// * `sample_rate` - Sample rate in Hz
39/// * `mode` - AM modulation mode
40///
41/// # Returns
42///
43/// * Modulated signal (same length as input)
44pub fn am_modulate(
45    signal: &[f64],
46    carrier_freq: f64,
47    sample_rate: f64,
48    mode: AmMode,
49) -> SignalResult<Vec<f64>> {
50    validate_mod_params(signal, carrier_freq, sample_rate)?;
51
52    let n = signal.len();
53    let mut output = vec![0.0; n];
54    let omega_c = 2.0 * PI * carrier_freq / sample_rate;
55
56    match mode {
57        AmMode::DsbSc => {
58            for (i, out) in output.iter_mut().enumerate() {
59                *out = signal[i] * (omega_c * i as f64).cos();
60            }
61        }
62        AmMode::DsbFc(mod_index) => {
63            if mod_index <= 0.0 {
64                return Err(SignalError::ValueError(
65                    "Modulation index must be positive".to_string(),
66                ));
67            }
68            // Normalize signal to [-1, 1] for modulation index interpretation
69            let max_abs = signal.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
70            let scale = if max_abs > 1e-20 { 1.0 / max_abs } else { 1.0 };
71
72            for (i, out) in output.iter_mut().enumerate() {
73                let carrier = (omega_c * i as f64).cos();
74                *out = (1.0 + mod_index * signal[i] * scale) * carrier;
75            }
76        }
77        AmMode::SsbUpper | AmMode::SsbLower => {
78            // SSB uses Hilbert transform (approximation via 90-degree phase shift)
79            let hilbert = hilbert_transform_approx(signal)?;
80            let sign = if mode == AmMode::SsbUpper { -1.0 } else { 1.0 };
81            for (i, out) in output.iter_mut().enumerate() {
82                let cos_c = (omega_c * i as f64).cos();
83                let sin_c = (omega_c * i as f64).sin();
84                *out = signal[i] * cos_c + sign * hilbert[i] * sin_c;
85            }
86        }
87    }
88
89    Ok(output)
90}
91
92/// Demodulate an AM signal back to baseband
93///
94/// # Arguments
95///
96/// * `modulated` - AM modulated signal
97/// * `carrier_freq` - Carrier frequency in Hz
98/// * `sample_rate` - Sample rate in Hz
99/// * `mode` - AM mode used for modulation
100///
101/// # Returns
102///
103/// * Demodulated baseband signal
104pub fn am_demodulate(
105    modulated: &[f64],
106    carrier_freq: f64,
107    sample_rate: f64,
108    mode: AmMode,
109) -> SignalResult<Vec<f64>> {
110    validate_mod_params(modulated, carrier_freq, sample_rate)?;
111
112    let n = modulated.len();
113    let omega_c = 2.0 * PI * carrier_freq / sample_rate;
114
115    match mode {
116        AmMode::DsbSc => {
117            // Coherent detection: multiply by carrier then low-pass
118            let mut baseband = vec![0.0; n];
119            for (i, out) in baseband.iter_mut().enumerate() {
120                *out = 2.0 * modulated[i] * (omega_c * i as f64).cos();
121            }
122            // Simple moving-average low-pass filter
123            let cutoff_samples = (sample_rate / carrier_freq).ceil() as usize;
124            let filtered = moving_average_lowpass(&baseband, cutoff_samples.max(2));
125            Ok(filtered)
126        }
127        AmMode::DsbFc(_mod_index) => {
128            // Envelope detection (rectify + low-pass)
129            let envelope: Vec<f64> = modulated.iter().map(|x| x.abs()).collect();
130            let cutoff_samples = (sample_rate / carrier_freq).ceil() as usize;
131            let filtered = moving_average_lowpass(&envelope, cutoff_samples.max(2));
132            // Remove DC offset (the carrier component)
133            let mean: f64 = filtered.iter().sum::<f64>() / filtered.len() as f64;
134            Ok(filtered.iter().map(|x| x - mean).collect())
135        }
136        AmMode::SsbUpper | AmMode::SsbLower => {
137            // Coherent SSB demodulation
138            let mut baseband = vec![0.0; n];
139            for (i, out) in baseband.iter_mut().enumerate() {
140                *out = 2.0 * modulated[i] * (omega_c * i as f64).cos();
141            }
142            let cutoff_samples = (sample_rate / carrier_freq).ceil() as usize;
143            let filtered = moving_average_lowpass(&baseband, cutoff_samples.max(2));
144            Ok(filtered)
145        }
146    }
147}
148
149// ---------------------------------------------------------------------------
150// FM modulation / demodulation
151// ---------------------------------------------------------------------------
152
153/// Frequency-modulate a baseband signal
154///
155/// # Arguments
156///
157/// * `signal` - Baseband (message) signal
158/// * `carrier_freq` - Carrier frequency in Hz
159/// * `sample_rate` - Sample rate in Hz
160/// * `freq_deviation` - Maximum frequency deviation in Hz
161///
162/// # Returns
163///
164/// * FM modulated signal
165pub fn fm_modulate(
166    signal: &[f64],
167    carrier_freq: f64,
168    sample_rate: f64,
169    freq_deviation: f64,
170) -> SignalResult<Vec<f64>> {
171    validate_mod_params(signal, carrier_freq, sample_rate)?;
172    if freq_deviation <= 0.0 {
173        return Err(SignalError::ValueError(
174            "Frequency deviation must be positive".to_string(),
175        ));
176    }
177
178    let n = signal.len();
179    let omega_c = 2.0 * PI * carrier_freq / sample_rate;
180    let k_f = 2.0 * PI * freq_deviation / sample_rate;
181
182    // Cumulative integral of message signal
183    let mut phase_integral = 0.0;
184    let mut output = Vec::with_capacity(n);
185
186    for (i, &s) in signal.iter().enumerate() {
187        phase_integral += s;
188        let phase = omega_c * i as f64 + k_f * phase_integral;
189        output.push(phase.cos());
190    }
191
192    Ok(output)
193}
194
195/// Demodulate an FM signal back to baseband
196///
197/// Uses differentiation of instantaneous phase (arctangent discriminator).
198///
199/// # Arguments
200///
201/// * `modulated` - FM modulated signal
202/// * `sample_rate` - Sample rate in Hz
203/// * `freq_deviation` - Maximum frequency deviation used during modulation
204///
205/// # Returns
206///
207/// * Demodulated baseband signal
208pub fn fm_demodulate(
209    modulated: &[f64],
210    sample_rate: f64,
211    freq_deviation: f64,
212) -> SignalResult<Vec<f64>> {
213    if modulated.is_empty() {
214        return Err(SignalError::ValueError(
215            "Input signal must not be empty".to_string(),
216        ));
217    }
218    if sample_rate <= 0.0 {
219        return Err(SignalError::ValueError(
220            "Sample rate must be positive".to_string(),
221        ));
222    }
223    if freq_deviation <= 0.0 {
224        return Err(SignalError::ValueError(
225            "Frequency deviation must be positive".to_string(),
226        ));
227    }
228
229    let n = modulated.len();
230    if n < 2 {
231        return Ok(vec![0.0]);
232    }
233
234    // Compute analytic signal via Hilbert transform
235    let hilbert = hilbert_transform_approx(modulated)?;
236
237    // Compute instantaneous phase
238    let mut inst_phase = Vec::with_capacity(n);
239    for i in 0..n {
240        inst_phase.push(hilbert[i].atan2(modulated[i]));
241    }
242
243    // Unwrap phase
244    let unwrapped = unwrap_phase_vec(&inst_phase);
245
246    // Differentiate to get instantaneous frequency
247    let k_f = 2.0 * PI * freq_deviation / sample_rate;
248    let scale = if k_f.abs() > 1e-20 {
249        sample_rate / (2.0 * PI * freq_deviation)
250    } else {
251        1.0
252    };
253
254    let mut demodulated = Vec::with_capacity(n);
255    demodulated.push(0.0); // first sample
256    for i in 1..n {
257        let diff = unwrapped[i] - unwrapped[i - 1];
258        demodulated.push(diff * scale);
259    }
260
261    // Remove DC component
262    let mean: f64 = demodulated.iter().sum::<f64>() / demodulated.len() as f64;
263    Ok(demodulated.iter().map(|x| x - mean).collect())
264}
265
266// ---------------------------------------------------------------------------
267// QAM modulation
268// ---------------------------------------------------------------------------
269
/// QAM constellation order.
///
/// Every supported order is a square constellation: a `grid_dim() x grid_dim()` grid in
/// which `bits_per_symbol() / 2` bits select the level on each axis.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum QamOrder {
    /// 4-QAM (equivalent to QPSK)
    Qam4,
    /// 16-QAM
    Qam16,
    /// 64-QAM
    Qam64,
    /// 256-QAM
    Qam256,
}

impl QamOrder {
    /// Get the number of constellation points
    pub fn constellation_size(self) -> usize {
        match self {
            QamOrder::Qam4 => 4,
            QamOrder::Qam16 => 16,
            QamOrder::Qam64 => 64,
            QamOrder::Qam256 => 256,
        }
    }

    /// Get the number of bits per symbol (`log2` of the constellation size)
    pub fn bits_per_symbol(self) -> usize {
        match self {
            QamOrder::Qam4 => 2,
            QamOrder::Qam16 => 4,
            QamOrder::Qam64 => 6,
            QamOrder::Qam256 => 8,
        }
    }

    /// Get the grid dimension (sqrt of constellation size)
    fn grid_dim(self) -> usize {
        match self {
            QamOrder::Qam4 => 2,
            QamOrder::Qam16 => 4,
            QamOrder::Qam64 => 8,
            QamOrder::Qam256 => 16,
        }
    }
}

/// A QAM symbol (I + jQ)
#[derive(Debug, Clone, Copy)]
pub struct QamSymbol {
    /// In-phase component
    pub i: f64,
    /// Quadrature component
    pub q: f64,
}

/// Generate the Gray-coded QAM constellation map.
///
/// Entry `index` of the returned vector is the point for the symbol whose bits (MSB
/// first, as packed by `qam_modulate_bits`) equal `index`. The upper half of those bits
/// selects the quadrature (Q) level and the lower half selects the in-phase (I) level.
/// Each half is Gray-coded along its axis, so horizontally or vertically adjacent points
/// differ in exactly one bit. The most likely nearest-neighbour decision errors then cost
/// a single bit.
///
/// Levels lie on the odd-integer grid `-(m-1), ..., -1, 1, ..., m-1` on each axis
/// (`m = grid_dim`) and are then scaled to unit average power.
pub fn qam_constellation(order: QamOrder) -> Vec<QamSymbol> {
    /// Invert the binary-reflected Gray code: returns `p` such that `p ^ (p >> 1) == gray`.
    fn gray_to_binary(gray: usize) -> usize {
        let mut binary = gray;
        let mut shifted = gray >> 1;
        while shifted != 0 {
            binary ^= shifted;
            shifted >>= 1;
        }
        binary
    }

    let m = order.grid_dim();
    let axis_bits = order.bits_per_symbol() / 2;
    let axis_mask = m - 1;
    let offset = m as f64 - 1.0;

    let mut points: Vec<QamSymbol> = (0..order.constellation_size())
        .map(|index| {
            // Grid position of each axis = inverse Gray code of that axis's bit field.
            let row = gray_to_binary(index >> axis_bits);
            let col = gray_to_binary(index & axis_mask);
            QamSymbol {
                i: 2.0 * col as f64 - offset,
                q: 2.0 * row as f64 - offset,
            }
        })
        .collect();

    // Normalize to unit average power
    let avg_power: f64 =
        points.iter().map(|p| p.i * p.i + p.q * p.q).sum::<f64>() / points.len() as f64;
    let scale = if avg_power > 1e-20 {
        1.0 / avg_power.sqrt()
    } else {
        1.0
    };

    for p in &mut points {
        p.i *= scale;
        p.q *= scale;
    }

    points
}
357
358/// Map a bit sequence to QAM symbols
359///
360/// # Arguments
361///
362/// * `bits` - Input bit sequence (each element is 0 or 1)
363/// * `order` - QAM constellation order
364///
365/// # Returns
366///
367/// * Vector of QAM symbols
368pub fn qam_modulate_bits(bits: &[u8], order: QamOrder) -> SignalResult<Vec<QamSymbol>> {
369    let bps = order.bits_per_symbol();
370    if bits.len() % bps != 0 {
371        return Err(SignalError::ValueError(format!(
372            "Bit sequence length {} must be a multiple of {} for {:?}",
373            bits.len(),
374            bps,
375            order
376        )));
377    }
378
379    // Validate bits
380    if bits.iter().any(|&b| b > 1) {
381        return Err(SignalError::ValueError("Bits must be 0 or 1".to_string()));
382    }
383
384    let constellation = qam_constellation(order);
385    let mut symbols = Vec::with_capacity(bits.len() / bps);
386
387    for chunk in bits.chunks(bps) {
388        // Convert bit group to integer index
389        let mut index: usize = 0;
390        for &bit in chunk {
391            index = (index << 1) | (bit as usize);
392        }
393        if index >= constellation.len() {
394            return Err(SignalError::ValueError(format!(
395                "Symbol index {} out of range for constellation size {}",
396                index,
397                constellation.len()
398            )));
399        }
400        symbols.push(constellation[index]);
401    }
402
403    Ok(symbols)
404}
405
406/// Demodulate QAM symbols back to bits using minimum-distance hard decision
407///
408/// # Arguments
409///
410/// * `symbols` - Received QAM symbols (possibly noisy)
411/// * `order` - QAM constellation order
412///
413/// # Returns
414///
415/// * Demodulated bit sequence
416pub fn qam_demodulate_bits(symbols: &[QamSymbol], order: QamOrder) -> SignalResult<Vec<u8>> {
417    if symbols.is_empty() {
418        return Err(SignalError::ValueError(
419            "Symbol sequence must not be empty".to_string(),
420        ));
421    }
422
423    let constellation = qam_constellation(order);
424    let bps = order.bits_per_symbol();
425    let mut bits = Vec::with_capacity(symbols.len() * bps);
426
427    for sym in symbols {
428        // Find nearest constellation point (minimum Euclidean distance)
429        let mut best_idx = 0;
430        let mut best_dist = f64::MAX;
431        for (idx, point) in constellation.iter().enumerate() {
432            let di = sym.i - point.i;
433            let dq = sym.q - point.q;
434            let dist = di * di + dq * dq;
435            if dist < best_dist {
436                best_dist = dist;
437                best_idx = idx;
438            }
439        }
440
441        // Convert index back to bits
442        for bit_pos in (0..bps).rev() {
443            bits.push(((best_idx >> bit_pos) & 1) as u8);
444        }
445    }
446
447    Ok(bits)
448}
449
450/// Modulate QAM symbols onto a carrier for transmission
451///
452/// Produces a real-valued passband signal:
453///   s(t) = I(t) * cos(2*pi*fc*t) - Q(t) * sin(2*pi*fc*t)
454///
455/// Each symbol is held for `samples_per_symbol` samples.
456///
457/// # Arguments
458///
459/// * `symbols` - QAM symbol sequence
460/// * `carrier_freq` - Carrier frequency in Hz
461/// * `sample_rate` - Sample rate in Hz
462/// * `samples_per_symbol` - Number of samples per symbol period
463///
464/// # Returns
465///
466/// * Passband modulated signal
467pub fn qam_modulate_passband(
468    symbols: &[QamSymbol],
469    carrier_freq: f64,
470    sample_rate: f64,
471    samples_per_symbol: usize,
472) -> SignalResult<Vec<f64>> {
473    if symbols.is_empty() {
474        return Err(SignalError::ValueError(
475            "Symbol sequence must not be empty".to_string(),
476        ));
477    }
478    if carrier_freq <= 0.0 || sample_rate <= 0.0 {
479        return Err(SignalError::ValueError(
480            "Carrier frequency and sample rate must be positive".to_string(),
481        ));
482    }
483    if samples_per_symbol == 0 {
484        return Err(SignalError::ValueError(
485            "Samples per symbol must be positive".to_string(),
486        ));
487    }
488
489    let total_samples = symbols.len() * samples_per_symbol;
490    let omega_c = 2.0 * PI * carrier_freq / sample_rate;
491    let mut output = Vec::with_capacity(total_samples);
492
493    for (sym_idx, sym) in symbols.iter().enumerate() {
494        for k in 0..samples_per_symbol {
495            let t = (sym_idx * samples_per_symbol + k) as f64;
496            let sample = sym.i * (omega_c * t).cos() - sym.q * (omega_c * t).sin();
497            output.push(sample);
498        }
499    }
500
501    Ok(output)
502}
503
504// ---------------------------------------------------------------------------
505// Unified modulate / demodulate interface
506// ---------------------------------------------------------------------------
507
/// Scheme selector for the unified [`modulate`] / [`demodulate`] entry points.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ModulationMethod {
    /// Double-sideband suppressed-carrier AM (`AmMode::DsbSc`).
    Am,
    /// Full-carrier AM; the payload is the modulation index (`AmMode::DsbFc`).
    AmFc(f64),
    /// Frequency modulation; the payload is the maximum frequency deviation in Hz.
    Fm(f64),
}
518
519/// Unified modulation function
520///
521/// # Arguments
522///
523/// * `signal` - Baseband signal
524/// * `carrier_freq` - Carrier frequency in Hz
525/// * `sample_rate` - Sample rate in Hz
526/// * `method` - Modulation method
527///
528/// # Returns
529///
530/// * Modulated signal
531pub fn modulate(
532    signal: &[f64],
533    carrier_freq: f64,
534    sample_rate: f64,
535    method: ModulationMethod,
536) -> SignalResult<Vec<f64>> {
537    match method {
538        ModulationMethod::Am => am_modulate(signal, carrier_freq, sample_rate, AmMode::DsbSc),
539        ModulationMethod::AmFc(mod_index) => {
540            am_modulate(signal, carrier_freq, sample_rate, AmMode::DsbFc(mod_index))
541        }
542        ModulationMethod::Fm(deviation) => {
543            fm_modulate(signal, carrier_freq, sample_rate, deviation)
544        }
545    }
546}
547
548/// Unified demodulation function
549///
550/// # Arguments
551///
552/// * `modulated` - Modulated signal
553/// * `carrier_freq` - Carrier frequency in Hz
554/// * `sample_rate` - Sample rate in Hz
555/// * `method` - Modulation method used
556///
557/// # Returns
558///
559/// * Demodulated baseband signal
560pub fn demodulate(
561    modulated: &[f64],
562    carrier_freq: f64,
563    sample_rate: f64,
564    method: ModulationMethod,
565) -> SignalResult<Vec<f64>> {
566    match method {
567        ModulationMethod::Am => am_demodulate(modulated, carrier_freq, sample_rate, AmMode::DsbSc),
568        ModulationMethod::AmFc(mod_index) => am_demodulate(
569            modulated,
570            carrier_freq,
571            sample_rate,
572            AmMode::DsbFc(mod_index),
573        ),
574        ModulationMethod::Fm(deviation) => fm_demodulate(modulated, sample_rate, deviation),
575    }
576}
577
578// ---------------------------------------------------------------------------
579// Internal helpers
580// ---------------------------------------------------------------------------
581
582fn validate_mod_params(signal: &[f64], carrier_freq: f64, sample_rate: f64) -> SignalResult<()> {
583    if signal.is_empty() {
584        return Err(SignalError::ValueError(
585            "Input signal must not be empty".to_string(),
586        ));
587    }
588    if carrier_freq <= 0.0 {
589        return Err(SignalError::ValueError(
590            "Carrier frequency must be positive".to_string(),
591        ));
592    }
593    if sample_rate <= 0.0 {
594        return Err(SignalError::ValueError(
595            "Sample rate must be positive".to_string(),
596        ));
597    }
598    if carrier_freq >= sample_rate / 2.0 {
599        return Err(SignalError::ValueError(
600            "Carrier frequency must be below Nyquist frequency".to_string(),
601        ));
602    }
603    Ok(())
604}
605
/// Approximate Hilbert transform using FFT
/// Returns the imaginary part of the analytic signal
///
/// The analytic spectrum is built by keeping the DC bin (and, for even `n`, the Nyquist
/// bin) once, doubling the positive-frequency bins `1..(n + 1) / 2`, and zeroing the rest
/// (the negative frequencies). The inverse transform of that spectrum is the analytic
/// signal `x + j*H{x}`, so its imaginary part is the Hilbert transform of `signal`.
///
/// It is an approximation because an `n`-point DFT treats `signal` as one period of a
/// periodic sequence, so edge samples can show wrap-around artifacts.
///
/// Returns an empty vector for empty input.
///
/// # Errors
///
/// * `SignalError::ComputationError` if the forward FFT or the inverse FFT fails.
fn hilbert_transform_approx(signal: &[f64]) -> SignalResult<Vec<f64>> {
    let n = signal.len();
    if n == 0 {
        return Ok(Vec::new());
    }

    // FFT
    // NOTE(review): assumes scirs2_fft::fft returns `n` bins in standard order
    // (DC at index 0, positive frequencies next) — confirm against scirs2_fft docs.
    let spectrum = scirs2_fft::fft(signal, Some(n))
        .map_err(|e| SignalError::ComputationError(format!("FFT failed: {}", e)))?;

    // Create one-sided spectrum (Hilbert mask); bins not written below stay zero,
    // which discards the negative frequencies.
    let mut analytic_spectrum = vec![scirs2_core::numeric::Complex64::new(0.0, 0.0); n];
    analytic_spectrum[0] = spectrum[0]; // DC
    if n % 2 == 0 && n > 1 {
        analytic_spectrum[n / 2] = spectrum[n / 2]; // Nyquist
    }
    // Positive frequencies are doubled to compensate for the discarded negative half.
    for i in 1..((n + 1) / 2) {
        analytic_spectrum[i] = spectrum[i] * 2.0;
    }

    // IFFT
    let analytic = scirs2_fft::ifft(&analytic_spectrum, Some(n))
        .map_err(|e| SignalError::ComputationError(format!("IFFT failed: {}", e)))?;

    // Return imaginary part
    Ok(analytic.iter().map(|c| c.im).collect())
}
635
/// Causal moving-average low-pass filter.
///
/// Output sample `i` is the mean of the most recent `w` input samples ending at `i`
/// (`signal[i + 1 - w ..= i]`), where `w = min(window_size, signal.len())`. During
/// warm-up (`i + 1 < w`) the mean is taken over the `i + 1` samples seen so far, so the
/// output has the same length as the input and no zero-padding transient.
///
/// A `window_size` of 0, or an empty input, returns the input unchanged.
///
/// The running sum starts at zero and each sample is added exactly once. Pre-seeding the
/// sum with the first `w` samples would double-count `signal[1..w]`, skew the warm-up
/// outputs, and leave a permanent DC offset on every later output.
fn moving_average_lowpass(signal: &[f64], window_size: usize) -> Vec<f64> {
    let n = signal.len();
    if window_size == 0 || n == 0 {
        return signal.to_vec();
    }
    let w = window_size.min(n);
    let mut output = Vec::with_capacity(n);

    let mut sum = 0.0;
    for (i, &x) in signal.iter().enumerate() {
        sum += x;
        if i >= w {
            // Slide the window: drop the sample that just fell out of it.
            sum -= signal[i - w];
        }
        let count = (i + 1).min(w);
        output.push(sum / count as f64);
    }

    output
}
663
/// Unwrap a sequence of phase angles (radians) into a continuous track.
///
/// Each successive difference is folded into `(-pi, pi]` by adding or subtracting whole
/// turns of `2*pi`, then accumulated onto the previous unwrapped value. The first phase
/// passes through unchanged; an empty slice yields an empty vector.
fn unwrap_phase_vec(phases: &[f64]) -> Vec<f64> {
    let (&first, rest) = match phases.split_first() {
        Some(split) => split,
        None => return Vec::new(),
    };

    let mut track = Vec::with_capacity(phases.len());
    track.push(first);
    let mut level = first;

    for (&prev, &cur) in phases.iter().zip(rest) {
        let mut step = cur - prev;
        while step > PI {
            step -= 2.0 * PI;
        }
        while step <= -PI {
            step += 2.0 * PI;
        }
        level += step;
        track.push(level);
    }

    track
}
683
684// ---------------------------------------------------------------------------
685// Tests
686// ---------------------------------------------------------------------------
687
688#[cfg(test)]
689mod tests {
690    use super::*;
691    use approx::assert_relative_eq;
692
693    // ----- Validation tests -----
694
695    #[test]
696    fn test_am_modulate_validation() {
697        assert!(am_modulate(&[], 1000.0, 8000.0, AmMode::DsbSc).is_err());
698        assert!(am_modulate(&[1.0], 0.0, 8000.0, AmMode::DsbSc).is_err());
699        assert!(am_modulate(&[1.0], 1000.0, 0.0, AmMode::DsbSc).is_err());
700        assert!(am_modulate(&[1.0], 5000.0, 8000.0, AmMode::DsbSc).is_err()); // above Nyquist
701    }
702
703    #[test]
704    fn test_fm_modulate_validation() {
705        assert!(fm_modulate(&[], 1000.0, 8000.0, 75.0).is_err());
706        assert!(fm_modulate(&[1.0], 1000.0, 8000.0, 0.0).is_err()); // zero deviation
707    }
708
709    // ----- AM DSB-SC tests -----
710
711    #[test]
712    fn test_am_dsbsc_zero_signal() {
713        let signal = vec![0.0; 100];
714        let modulated =
715            am_modulate(&signal, 1000.0, 8000.0, AmMode::DsbSc).expect("AM modulation failed");
716        for &s in &modulated {
717            assert_relative_eq!(s, 0.0, epsilon = 1e-12);
718        }
719    }
720
721    #[test]
722    fn test_am_dsbsc_carrier_present() {
723        let sr = 8000.0;
724        let fc = 1000.0;
725        let n = 800;
726        let signal = vec![1.0; n]; // constant amplitude
727
728        let modulated = am_modulate(&signal, fc, sr, AmMode::DsbSc).expect("AM modulation failed");
729
730        // Should be a cosine at carrier frequency
731        for (i, &s) in modulated.iter().enumerate() {
732            let expected = (2.0 * PI * fc * i as f64 / sr).cos();
733            assert_relative_eq!(s, expected, epsilon = 1e-10);
734        }
735    }
736
737    #[test]
738    fn test_am_dsbsc_modulate_demodulate() {
739        let sr = 8000.0;
740        let fc = 2000.0;
741        let n = 400;
742        // Low-frequency message
743        let signal: Vec<f64> = (0..n)
744            .map(|i| (2.0 * PI * 100.0 * i as f64 / sr).sin())
745            .collect();
746
747        let modulated = am_modulate(&signal, fc, sr, AmMode::DsbSc).expect("AM modulation failed");
748        let demodulated =
749            am_demodulate(&modulated, fc, sr, AmMode::DsbSc).expect("AM demodulation failed");
750
751        assert_eq!(demodulated.len(), n);
752        // After demod + lowpass, the signal shape should be recovered (with some delay/attenuation)
753        assert!(demodulated.iter().any(|&x| x.abs() > 0.01));
754    }
755
756    #[test]
757    fn test_am_dsbsc_output_bounded() {
758        let sr = 8000.0;
759        let fc = 1500.0;
760        let signal: Vec<f64> = (0..200)
761            .map(|i| (2.0 * PI * 200.0 * i as f64 / sr).sin())
762            .collect();
763
764        let modulated = am_modulate(&signal, fc, sr, AmMode::DsbSc).expect("AM modulation failed");
765        // Modulated signal should be bounded by signal amplitude
766        let max_input = signal.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
767        for &s in &modulated {
768            assert!(s.abs() <= max_input + 1e-10);
769        }
770    }
771
772    #[test]
773    fn test_am_dsbsc_length_preserved() {
774        let signal = vec![0.5; 123];
775        let modulated =
776            am_modulate(&signal, 1000.0, 8000.0, AmMode::DsbSc).expect("AM modulation failed");
777        assert_eq!(modulated.len(), 123);
778    }
779
780    // ----- AM DSB-FC tests -----
781
782    #[test]
783    fn test_am_dsbfc_carrier_always_present() {
784        let sr = 8000.0;
785        let fc = 1000.0;
786        let signal = vec![0.0; 200]; // zero message
787
788        let modulated =
789            am_modulate(&signal, fc, sr, AmMode::DsbFc(0.5)).expect("AM-FC modulation failed");
790
791        // With zero message, output is just the carrier
792        for (i, &s) in modulated.iter().enumerate() {
793            let expected = (2.0 * PI * fc * i as f64 / sr).cos();
794            assert_relative_eq!(s, expected, epsilon = 1e-10);
795        }
796    }
797
798    #[test]
799    fn test_am_dsbfc_invalid_mod_index() {
800        let signal = vec![1.0; 100];
801        assert!(am_modulate(&signal, 1000.0, 8000.0, AmMode::DsbFc(0.0)).is_err());
802        assert!(am_modulate(&signal, 1000.0, 8000.0, AmMode::DsbFc(-0.5)).is_err());
803    }
804
805    #[test]
806    fn test_am_dsbfc_modulate_demodulate() {
807        let sr = 8000.0;
808        let fc = 2000.0;
809        let n = 400;
810        let signal: Vec<f64> = (0..n)
811            .map(|i| (2.0 * PI * 100.0 * i as f64 / sr).sin())
812            .collect();
813
814        let modulated =
815            am_modulate(&signal, fc, sr, AmMode::DsbFc(0.8)).expect("AM-FC modulation failed");
816        let demodulated = am_demodulate(&modulated, fc, sr, AmMode::DsbFc(0.8))
817            .expect("AM-FC demodulation failed");
818
819        assert_eq!(demodulated.len(), n);
820        // Should recover some signal shape
821        assert!(demodulated.iter().any(|&x| x.abs() > 0.001));
822    }
823
824    #[test]
825    fn test_am_dsbfc_envelope_positive() {
826        let sr = 8000.0;
827        let fc = 1500.0;
828        let signal: Vec<f64> = (0..200)
829            .map(|i| 0.5 * (2.0 * PI * 100.0 * i as f64 / sr).sin())
830            .collect();
831
832        let modulated =
833            am_modulate(&signal, fc, sr, AmMode::DsbFc(0.5)).expect("AM-FC modulation failed");
834
835        // Envelope = |1 + m*signal| should be positive for m <= 1 and |signal| <= 1
836        // This means the modulated signal shouldn't always be negative at carrier peaks
837        let has_positive = modulated.iter().any(|&x| x > 0.0);
838        assert!(has_positive);
839    }
840
841    #[test]
842    fn test_am_dsbfc_modulation_depth() {
843        let sr = 8000.0;
844        let fc = 2000.0;
845        let signal = vec![1.0; 200]; // constant
846
847        let low_mod = am_modulate(&signal, fc, sr, AmMode::DsbFc(0.3)).expect("Low mod AM failed");
848        let high_mod =
849            am_modulate(&signal, fc, sr, AmMode::DsbFc(0.9)).expect("High mod AM failed");
850
851        let low_max = low_mod.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
852        let high_max = high_mod.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
853
854        // Higher modulation index should produce larger peak amplitude
855        assert!(high_max > low_max);
856    }
857
858    // ----- AM SSB tests -----
859
860    #[test]
861    fn test_am_ssb_produces_output() {
862        let sr = 8000.0;
863        let fc = 2000.0;
864        let n = 256;
865        let signal: Vec<f64> = (0..n)
866            .map(|i| (2.0 * PI * 300.0 * i as f64 / sr).sin())
867            .collect();
868
869        let upper =
870            am_modulate(&signal, fc, sr, AmMode::SsbUpper).expect("SSB Upper modulation failed");
871        let lower =
872            am_modulate(&signal, fc, sr, AmMode::SsbLower).expect("SSB Lower modulation failed");
873
874        assert_eq!(upper.len(), n);
875        assert_eq!(lower.len(), n);
876        // Upper and lower sideband should be different
877        let diff: f64 = upper
878            .iter()
879            .zip(lower.iter())
880            .map(|(a, b)| (a - b).abs())
881            .sum();
882        assert!(diff > 0.1);
883    }
884
885    #[test]
886    fn test_am_ssb_demodulate() {
887        let sr = 8000.0;
888        let fc = 2000.0;
889        let n = 256;
890        let signal: Vec<f64> = (0..n)
891            .map(|i| (2.0 * PI * 300.0 * i as f64 / sr).sin())
892            .collect();
893
894        let modulated =
895            am_modulate(&signal, fc, sr, AmMode::SsbUpper).expect("SSB modulation failed");
896        let demodulated =
897            am_demodulate(&modulated, fc, sr, AmMode::SsbUpper).expect("SSB demodulation failed");
898
899        assert_eq!(demodulated.len(), n);
900        assert!(demodulated.iter().any(|&x| x.abs() > 0.001));
901    }
902
903    // ----- FM tests -----
904
905    #[test]
906    fn test_fm_modulate_constant_signal() {
907        let sr = 8000.0;
908        let fc = 1000.0;
909        let deviation = 75.0;
910        let signal = vec![0.0; 200]; // zero message = pure carrier
911
912        let modulated = fm_modulate(&signal, fc, sr, deviation).expect("FM modulation failed");
913
914        // With zero message, should be a pure cosine at carrier freq
915        for (i, &s) in modulated.iter().enumerate() {
916            let expected = (2.0 * PI * fc * i as f64 / sr).cos();
917            assert_relative_eq!(s, expected, epsilon = 1e-10);
918        }
919    }
920
921    #[test]
922    fn test_fm_modulate_output_bounded() {
923        let sr = 8000.0;
924        let fc = 1500.0;
925        let deviation = 200.0;
926        let signal: Vec<f64> = (0..200)
927            .map(|i| (2.0 * PI * 100.0 * i as f64 / sr).sin())
928            .collect();
929
930        let modulated = fm_modulate(&signal, fc, sr, deviation).expect("FM modulation failed");
931
932        // FM output is always a unit-amplitude cosine
933        for &s in &modulated {
934            assert!(s.abs() <= 1.0 + 1e-12);
935        }
936    }
937
938    #[test]
939    fn test_fm_modulate_demodulate_roundtrip() {
940        let sr = 8000.0;
941        let fc = 2000.0;
942        let deviation = 300.0;
943        let n = 500;
944        let message_freq = 100.0;
945
946        let signal: Vec<f64> = (0..n)
947            .map(|i| (2.0 * PI * message_freq * i as f64 / sr).sin())
948            .collect();
949
950        let modulated = fm_modulate(&signal, fc, sr, deviation).expect("FM modulation failed");
951        let demodulated = fm_demodulate(&modulated, sr, deviation).expect("FM demodulation failed");
952
953        assert_eq!(demodulated.len(), n);
954        // Demodulated signal should have some periodic content
955        let energy: f64 = demodulated.iter().map(|x| x * x).sum();
956        assert!(energy > 0.0);
957    }
958
959    #[test]
960    fn test_fm_modulate_length_preserved() {
961        let signal = vec![0.3; 77];
962        let modulated = fm_modulate(&signal, 1000.0, 8000.0, 75.0).expect("FM modulation failed");
963        assert_eq!(modulated.len(), 77);
964    }
965
966    #[test]
967    fn test_fm_demodulate_validation() {
968        assert!(fm_demodulate(&[], 8000.0, 75.0).is_err());
969        assert!(fm_demodulate(&[1.0], 0.0, 75.0).is_err());
970        assert!(fm_demodulate(&[1.0], 8000.0, 0.0).is_err());
971    }
972
973    #[test]
974    fn test_fm_different_deviations() {
975        let sr = 8000.0;
976        let fc = 1500.0;
977        let n = 200;
978        let signal: Vec<f64> = (0..n)
979            .map(|i| (2.0 * PI * 100.0 * i as f64 / sr).sin())
980            .collect();
981
982        let low_dev = fm_modulate(&signal, fc, sr, 50.0).expect("FM low dev failed");
983        let high_dev = fm_modulate(&signal, fc, sr, 500.0).expect("FM high dev failed");
984
985        // Different deviations should produce different signals
986        let diff: f64 = low_dev
987            .iter()
988            .zip(high_dev.iter())
989            .map(|(a, b)| (a - b).abs())
990            .sum();
991        assert!(diff > 0.1);
992    }
993
994    // ----- QAM tests -----
995
996    #[test]
997    fn test_qam_constellation_sizes() {
998        assert_eq!(qam_constellation(QamOrder::Qam4).len(), 4);
999        assert_eq!(qam_constellation(QamOrder::Qam16).len(), 16);
1000        assert_eq!(qam_constellation(QamOrder::Qam64).len(), 64);
1001        assert_eq!(qam_constellation(QamOrder::Qam256).len(), 256);
1002    }
1003
1004    #[test]
1005    fn test_qam_constellation_unit_power() {
1006        for order in &[QamOrder::Qam4, QamOrder::Qam16, QamOrder::Qam64] {
1007            let const_map = qam_constellation(*order);
1008            let avg_power: f64 = const_map.iter().map(|p| p.i * p.i + p.q * p.q).sum::<f64>()
1009                / const_map.len() as f64;
1010            assert_relative_eq!(avg_power, 1.0, epsilon = 1e-6);
1011        }
1012    }
1013
1014    #[test]
1015    fn test_qam_constellation_symmetric() {
1016        for order in &[QamOrder::Qam4, QamOrder::Qam16, QamOrder::Qam64] {
1017            let constellation = qam_constellation(*order);
1018            // Check that for every point (i, q) there exists (-i, -q)
1019            for p in &constellation {
1020                let has_conjugate = constellation
1021                    .iter()
1022                    .any(|q| (q.i + p.i).abs() < 1e-10 && (q.q + p.q).abs() < 1e-10);
1023                assert!(
1024                    has_conjugate,
1025                    "Missing conjugate point for ({}, {})",
1026                    p.i, p.q
1027                );
1028            }
1029        }
1030    }
1031
1032    #[test]
1033    fn test_qam4_modulate_demodulate() {
1034        let bits: Vec<u8> = vec![0, 0, 0, 1, 1, 0, 1, 1];
1035        let symbols = qam_modulate_bits(&bits, QamOrder::Qam4).expect("QAM4 modulation failed");
1036        assert_eq!(symbols.len(), 4); // 8 bits / 2 bps
1037
1038        let recovered =
1039            qam_demodulate_bits(&symbols, QamOrder::Qam4).expect("QAM4 demodulation failed");
1040        assert_eq!(recovered, bits);
1041    }
1042
1043    #[test]
1044    fn test_qam16_modulate_demodulate() {
1045        let bits: Vec<u8> = vec![0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0];
1046        let symbols = qam_modulate_bits(&bits, QamOrder::Qam16).expect("QAM16 modulation failed");
1047        assert_eq!(symbols.len(), 4); // 16 bits / 4 bps
1048
1049        let recovered =
1050            qam_demodulate_bits(&symbols, QamOrder::Qam16).expect("QAM16 demodulation failed");
1051        assert_eq!(recovered, bits);
1052    }
1053
1054    #[test]
1055    fn test_qam64_modulate_demodulate() {
1056        let bits: Vec<u8> = vec![0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
1057        let symbols = qam_modulate_bits(&bits, QamOrder::Qam64).expect("QAM64 modulation failed");
1058        assert_eq!(symbols.len(), 2); // 12 bits / 6 bps
1059
1060        let recovered =
1061            qam_demodulate_bits(&symbols, QamOrder::Qam64).expect("QAM64 demodulation failed");
1062        assert_eq!(recovered, bits);
1063    }
1064
1065    #[test]
1066    fn test_qam_modulate_invalid_bits() {
1067        // Wrong length for QAM16 (not multiple of 4)
1068        let bits: Vec<u8> = vec![0, 0, 0];
1069        assert!(qam_modulate_bits(&bits, QamOrder::Qam16).is_err());
1070
1071        // Invalid bit value
1072        let bits: Vec<u8> = vec![0, 0, 2, 0];
1073        assert!(qam_modulate_bits(&bits, QamOrder::Qam16).is_err());
1074    }
1075
1076    #[test]
1077    fn test_qam_demodulate_empty() {
1078        let symbols: Vec<QamSymbol> = vec![];
1079        assert!(qam_demodulate_bits(&symbols, QamOrder::Qam4).is_err());
1080    }
1081
1082    #[test]
1083    fn test_qam_passband_length() {
1084        let symbols = vec![QamSymbol { i: 1.0, q: 0.0 }, QamSymbol { i: 0.0, q: 1.0 }];
1085        let passband =
1086            qam_modulate_passband(&symbols, 1000.0, 8000.0, 10).expect("QAM passband failed");
1087        assert_eq!(passband.len(), 20); // 2 symbols * 10 sps
1088    }
1089
1090    #[test]
1091    fn test_qam_passband_validation() {
1092        let symbols = vec![QamSymbol { i: 1.0, q: 0.0 }];
1093        assert!(qam_modulate_passband(&symbols, 0.0, 8000.0, 10).is_err());
1094        assert!(qam_modulate_passband(&symbols, 1000.0, 0.0, 10).is_err());
1095        assert!(qam_modulate_passband(&symbols, 1000.0, 8000.0, 0).is_err());
1096        assert!(qam_modulate_passband(&[], 1000.0, 8000.0, 10).is_err());
1097    }
1098
1099    #[test]
1100    fn test_qam_passband_bounded() {
1101        let bits: Vec<u8> = vec![0, 1, 1, 0, 1, 1, 0, 0];
1102        let symbols = qam_modulate_bits(&bits, QamOrder::Qam4).expect("QAM4 mod failed");
1103        let passband =
1104            qam_modulate_passband(&symbols, 1000.0, 8000.0, 20).expect("QAM passband failed");
1105
1106        // All samples should be finite
1107        assert!(passband.iter().all(|x| x.is_finite()));
1108    }
1109
1110    // ----- Unified interface tests -----
1111
1112    #[test]
1113    fn test_unified_modulate_am() {
1114        let signal = vec![1.0; 100];
1115        let modulated =
1116            modulate(&signal, 1000.0, 8000.0, ModulationMethod::Am).expect("Unified AM mod failed");
1117        assert_eq!(modulated.len(), 100);
1118    }
1119
1120    #[test]
1121    fn test_unified_modulate_fm() {
1122        let signal = vec![0.5; 100];
1123        let modulated = modulate(&signal, 1000.0, 8000.0, ModulationMethod::Fm(200.0))
1124            .expect("Unified FM mod failed");
1125        assert_eq!(modulated.len(), 100);
1126    }
1127
1128    #[test]
1129    fn test_unified_demodulate_am() {
1130        let signal: Vec<f64> = (0..200)
1131            .map(|i| (2.0 * PI * 100.0 * i as f64 / 8000.0).sin())
1132            .collect();
1133        let modulated =
1134            modulate(&signal, 1500.0, 8000.0, ModulationMethod::Am).expect("AM mod failed");
1135        let demodulated =
1136            demodulate(&modulated, 1500.0, 8000.0, ModulationMethod::Am).expect("AM demod failed");
1137        assert_eq!(demodulated.len(), signal.len());
1138    }
1139
1140    #[test]
1141    fn test_unified_demodulate_fm() {
1142        let signal: Vec<f64> = (0..200)
1143            .map(|i| (2.0 * PI * 100.0 * i as f64 / 8000.0).sin())
1144            .collect();
1145        let modulated =
1146            modulate(&signal, 1500.0, 8000.0, ModulationMethod::Fm(150.0)).expect("FM mod failed");
1147        let demodulated = demodulate(&modulated, 1500.0, 8000.0, ModulationMethod::Fm(150.0))
1148            .expect("FM demod failed");
1149        assert_eq!(demodulated.len(), signal.len());
1150    }
1151
1152    // ----- Internal helper tests -----
1153
1154    #[test]
1155    fn test_hilbert_transform_approx() {
1156        let n = 128;
1157        let signal: Vec<f64> = (0..n)
1158            .map(|i| (2.0 * PI * 5.0 * i as f64 / n as f64).cos())
1159            .collect();
1160
1161        let hilbert = hilbert_transform_approx(&signal).expect("Hilbert failed");
1162        assert_eq!(hilbert.len(), n);
1163
1164        // For cos(wt), Hilbert transform should approximate sin(wt)
1165        let expected: Vec<f64> = (0..n)
1166            .map(|i| (2.0 * PI * 5.0 * i as f64 / n as f64).sin())
1167            .collect();
1168
1169        // Check correlation (should be high)
1170        let corr: f64 = hilbert
1171            .iter()
1172            .zip(expected.iter())
1173            .map(|(a, b)| a * b)
1174            .sum::<f64>();
1175        let norm_h: f64 = hilbert.iter().map(|x| x * x).sum::<f64>().sqrt();
1176        let norm_e: f64 = expected.iter().map(|x| x * x).sum::<f64>().sqrt();
1177        let normalized_corr = if norm_h * norm_e > 1e-20 {
1178            corr / (norm_h * norm_e)
1179        } else {
1180            0.0
1181        };
1182        assert!(
1183            normalized_corr > 0.9,
1184            "Hilbert correlation too low: {}",
1185            normalized_corr
1186        );
1187    }
1188
1189    #[test]
1190    fn test_moving_average_lowpass() {
1191        let signal = vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
1192        let filtered = moving_average_lowpass(&signal, 4);
1193        assert_eq!(filtered.len(), signal.len());
1194        // Filtered signal should be smoother (less variation)
1195        let raw_var: f64 = signal.windows(2).map(|w| (w[1] - w[0]).abs()).sum();
1196        let filt_var: f64 = filtered.windows(2).map(|w| (w[1] - w[0]).abs()).sum();
1197        assert!(filt_var < raw_var);
1198    }
1199
1200    #[test]
1201    fn test_unwrap_phase_vec() {
1202        let phases = vec![0.0, 1.0, 2.0, 3.0, -3.0, -2.0, -1.0, 0.0];
1203        let unwrapped = unwrap_phase_vec(&phases);
1204        assert_eq!(unwrapped.len(), phases.len());
1205        // All consecutive differences should be in (-pi, pi]
1206        for w in unwrapped.windows(2) {
1207            let diff = w[1] - w[0];
1208            assert!(diff > -PI && diff <= PI, "diff = {} out of range", diff);
1209        }
1210    }
1211}