// scirs2_integrate/specialized/finance/utils/calibration.rs

//! Model calibration utilities
//!
//! This module provides tools for calibrating financial models to market data, including
//! a bounded Nelder-Mead optimizer, loss functions, and parameter constraints.
//!
//! # Features
//! - Implied volatility surface storage and conversion to option prices
//! - Heston model parameter estimation from market prices (simplified effective-variance
//!   approximation, local Nelder-Mead optimization)
//! - Weighted least squares using inverse bid-ask spreads
//! - Loss functions (MSE, MAE, weighted MSE, relative error)
//! - Feller-condition penalty for parameter stability
//!
//! # Example
//! ```
//! use scirs2_integrate::specialized::finance::utils::calibration::{
//!     ImpliedVolatilitySurface, HestonCalibrator, LossFunction
//! };
//!
//! // Create IV surface from market quotes
//! let strikes = vec![90.0, 95.0, 100.0, 105.0, 110.0];
//! let maturities = vec![0.25, 0.5, 1.0];
//! let mut surface = ImpliedVolatilitySurface::new(100.0, 0.05, 0.0);
//!
//! // Add market data
//! for &t in &maturities {
//!     for &k in &strikes {
//!         surface.add_quote(k, t, 0.20, Some(0.01));
//!     }
//! }
//!
//! // Calibrate to Heston model
//! let calibrator = HestonCalibrator::new(surface, LossFunction::WeightedMSE);
//! // let result = calibrator.calibrate()?;
//! ```

use std::collections::HashMap;

use crate::error::{IntegrateError, IntegrateResult as Result};
use crate::specialized::finance::derivatives::vanilla::EuropeanOption;
use crate::specialized::finance::types::OptionType;

// ============================================================================
// Nelder-Mead Simplex Optimizer
// ============================================================================

45/// Result from Nelder-Mead optimization
46#[derive(Debug, Clone)]
47struct OptimizationResult {
48    /// Optimized parameters
49    parameters: Vec<f64>,
50    /// Final objective function value
51    final_value: f64,
52    /// Number of iterations
53    iterations: usize,
54    /// Whether optimization converged
55    converged: bool,
56}
57
58/// Nelder-Mead simplex optimizer for unconstrained optimization
59///
60/// This is a derivative-free optimization method that maintains a simplex
61/// of n+1 points in n-dimensional space and iteratively improves it.
62struct NelderMeadOptimizer {
63    /// Initial guess
64    initial_guess: Vec<f64>,
65    /// Parameter bounds (min, max)
66    bounds: Vec<(f64, f64)>,
67    /// Maximum iterations
68    max_iterations: usize,
69    /// Convergence tolerance
70    tolerance: f64,
71    /// Reflection coefficient (alpha)
72    alpha: f64,
73    /// Expansion coefficient (gamma)
74    gamma: f64,
75    /// Contraction coefficient (rho)
76    rho: f64,
77    /// Shrink coefficient (sigma)
78    sigma: f64,
79}
80
81impl NelderMeadOptimizer {
82    /// Create a new Nelder-Mead optimizer
83    fn new(
84        initial_guess: Vec<f64>,
85        bounds: Vec<(f64, f64)>,
86        max_iterations: usize,
87        tolerance: f64,
88    ) -> Self {
89        Self {
90            initial_guess,
91            bounds,
92            max_iterations,
93            tolerance,
94            alpha: 1.0, // Reflection
95            gamma: 2.0, // Expansion
96            rho: 0.5,   // Contraction
97            sigma: 0.5, // Shrink
98        }
99    }
100
101    /// Optimize the objective function
102    fn optimize<F>(&self, objective: &F) -> Result<OptimizationResult>
103    where
104        F: Fn(&[f64]) -> f64,
105    {
106        let n = self.initial_guess.len();
107
108        // Initialize simplex (n+1 points)
109        let mut simplex = self.initialize_simplex(n);
110
111        // Evaluate all simplex points
112        let mut values: Vec<f64> = simplex.iter().map(|x| objective(x)).collect();
113
114        let mut iterations = 0;
115
116        while iterations < self.max_iterations {
117            // Sort simplex by objective values
118            let mut indices: Vec<usize> = (0..simplex.len()).collect();
119            indices.sort_by(|&i, &j| values[i].partial_cmp(&values[j]).unwrap());
120
121            let best_idx = indices[0];
122            let worst_idx = indices[n];
123            let second_worst_idx = indices[n - 1];
124
125            // Check convergence: standard deviation of function values
126            let mean_val: f64 = values.iter().sum::<f64>() / values.len() as f64;
127            let std_dev = (values.iter().map(|v| (v - mean_val).powi(2)).sum::<f64>()
128                / values.len() as f64)
129                .sqrt();
130
131            if std_dev < self.tolerance {
132                return Ok(OptimizationResult {
133                    parameters: simplex[best_idx].clone(),
134                    final_value: values[best_idx],
135                    iterations,
136                    converged: true,
137                });
138            }
139
140            // Calculate centroid (excluding worst point)
141            let centroid = self.calculate_centroid(&simplex, worst_idx);
142
143            // Reflection
144            let reflected = self.reflect(&simplex[worst_idx], &centroid, self.alpha);
145            let reflected = self.project_to_bounds(&reflected);
146            let reflected_value = objective(&reflected);
147
148            if reflected_value < values[second_worst_idx] && reflected_value >= values[best_idx] {
149                // Accept reflection
150                simplex[worst_idx] = reflected;
151                values[worst_idx] = reflected_value;
152            } else if reflected_value < values[best_idx] {
153                // Try expansion
154                let expanded =
155                    self.reflect(&simplex[worst_idx], &centroid, self.alpha * self.gamma);
156                let expanded = self.project_to_bounds(&expanded);
157                let expanded_value = objective(&expanded);
158
159                if expanded_value < reflected_value {
160                    simplex[worst_idx] = expanded;
161                    values[worst_idx] = expanded_value;
162                } else {
163                    simplex[worst_idx] = reflected;
164                    values[worst_idx] = reflected_value;
165                }
166            } else {
167                // Try contraction
168                let contracted = if reflected_value < values[worst_idx] {
169                    // Outside contraction
170                    self.contract(&reflected, &centroid, self.rho)
171                } else {
172                    // Inside contraction
173                    self.contract(&simplex[worst_idx], &centroid, self.rho)
174                };
175                let contracted = self.project_to_bounds(&contracted);
176                let contracted_value = objective(&contracted);
177
178                if contracted_value < values[worst_idx].min(reflected_value) {
179                    simplex[worst_idx] = contracted;
180                    values[worst_idx] = contracted_value;
181                } else {
182                    // Shrink entire simplex toward best point
183                    for i in 0..simplex.len() {
184                        if i != best_idx {
185                            simplex[i] = self.shrink(&simplex[i], &simplex[best_idx], self.sigma);
186                            simplex[i] = self.project_to_bounds(&simplex[i]);
187                            values[i] = objective(&simplex[i]);
188                        }
189                    }
190                }
191            }
192
193            iterations += 1;
194        }
195
196        // Max iterations reached
197        let mut indices: Vec<usize> = (0..simplex.len()).collect();
198        indices.sort_by(|&i, &j| values[i].partial_cmp(&values[j]).unwrap());
199        let best_idx = indices[0];
200
201        Ok(OptimizationResult {
202            parameters: simplex[best_idx].clone(),
203            final_value: values[best_idx],
204            iterations,
205            converged: false,
206        })
207    }
208
209    /// Initialize simplex around initial guess
210    fn initialize_simplex(&self, n: usize) -> Vec<Vec<f64>> {
211        let mut simplex = Vec::with_capacity(n + 1);
212
213        // First point is the initial guess
214        simplex.push(self.initial_guess.clone());
215
216        // Create n additional points by perturbing each dimension
217        for i in 0..n {
218            let mut point = self.initial_guess.clone();
219            let delta = if self.initial_guess[i].abs() > 1e-10 {
220                0.05 * self.initial_guess[i] // 5% perturbation
221            } else {
222                0.00025 // Small absolute perturbation for near-zero values
223            };
224            point[i] += delta;
225            point = self.project_to_bounds(&point);
226            simplex.push(point);
227        }
228
229        simplex
230    }
231
232    /// Calculate centroid of simplex excluding specified point
233    fn calculate_centroid(&self, simplex: &[Vec<f64>], exclude_idx: usize) -> Vec<f64> {
234        let n = simplex[0].len();
235        let mut centroid = vec![0.0; n];
236
237        for (i, point) in simplex.iter().enumerate() {
238            if i != exclude_idx {
239                for (j, &val) in point.iter().enumerate() {
240                    centroid[j] += val;
241                }
242            }
243        }
244
245        let count = simplex.len() - 1;
246        for val in &mut centroid {
247            *val /= count as f64;
248        }
249
250        centroid
251    }
252
253    /// Reflect point through centroid
254    fn reflect(&self, point: &[f64], centroid: &[f64], coeff: f64) -> Vec<f64> {
255        point
256            .iter()
257            .zip(centroid.iter())
258            .map(|(&p, &c)| c + coeff * (c - p))
259            .collect()
260    }
261
262    /// Contract point toward centroid
263    fn contract(&self, point: &[f64], centroid: &[f64], coeff: f64) -> Vec<f64> {
264        point
265            .iter()
266            .zip(centroid.iter())
267            .map(|(&p, &c)| c + coeff * (p - c))
268            .collect()
269    }
270
271    /// Shrink point toward best point
272    fn shrink(&self, point: &[f64], best: &[f64], coeff: f64) -> Vec<f64> {
273        point
274            .iter()
275            .zip(best.iter())
276            .map(|(&p, &b)| b + coeff * (p - b))
277            .collect()
278    }
279
280    /// Project point to within bounds
281    fn project_to_bounds(&self, point: &[f64]) -> Vec<f64> {
282        point
283            .iter()
284            .zip(self.bounds.iter())
285            .map(|(&val, &(min, max))| val.max(min).min(max))
286            .collect()
287    }
288}
289
// ============================================================================
// Calibration Data Structures
// ============================================================================

294/// Market quote for an option
295#[derive(Debug, Clone)]
296pub struct OptionQuote {
297    /// Strike price
298    pub strike: f64,
299    /// Time to maturity (years)
300    pub maturity: f64,
301    /// Option type (call/put)
302    pub option_type: OptionType,
303    /// Market price
304    pub market_price: f64,
305    /// Bid-ask spread (optional, for weighting)
306    pub bid_ask_spread: Option<f64>,
307}
308
309impl OptionQuote {
310    /// Create a new option quote
311    pub fn new(
312        strike: f64,
313        maturity: f64,
314        option_type: OptionType,
315        market_price: f64,
316        bid_ask_spread: Option<f64>,
317    ) -> Result<Self> {
318        if strike <= 0.0 {
319            return Err(IntegrateError::ValueError(
320                "Strike must be positive".to_string(),
321            ));
322        }
323        if maturity <= 0.0 {
324            return Err(IntegrateError::ValueError(
325                "Maturity must be positive".to_string(),
326            ));
327        }
328        if market_price < 0.0 {
329            return Err(IntegrateError::ValueError(
330                "Market price cannot be negative".to_string(),
331            ));
332        }
333        if let Some(spread) = bid_ask_spread {
334            if spread < 0.0 {
335                return Err(IntegrateError::ValueError(
336                    "Bid-ask spread cannot be negative".to_string(),
337                ));
338            }
339        }
340
341        Ok(Self {
342            strike,
343            maturity,
344            option_type,
345            market_price,
346            bid_ask_spread,
347        })
348    }
349
350    /// Calculate weight based on bid-ask spread (tighter spread = higher weight)
351    pub fn weight(&self) -> f64 {
352        match self.bid_ask_spread {
353            Some(spread) if spread > 1e-8 => 1.0 / spread,
354            _ => 1.0,
355        }
356    }
357}
358
359/// Implied volatility surface
360pub struct ImpliedVolatilitySurface {
361    /// Spot price
362    spot: f64,
363    /// Risk-free rate
364    rate: f64,
365    /// Dividend yield
366    dividend: f64,
367    /// Market quotes indexed by (strike, maturity)
368    quotes: HashMap<(String, String), (f64, Option<f64>)>,
369}
370
371impl ImpliedVolatilitySurface {
372    /// Create a new implied volatility surface
373    pub fn new(spot: f64, rate: f64, dividend: f64) -> Self {
374        Self {
375            spot,
376            rate,
377            dividend,
378            quotes: HashMap::new(),
379        }
380    }
381
382    /// Add an implied volatility quote
383    pub fn add_quote(
384        &mut self,
385        strike: f64,
386        maturity: f64,
387        implied_vol: f64,
388        bid_ask_spread: Option<f64>,
389    ) {
390        let key = (format!("{:.4}", strike), format!("{:.4}", maturity));
391        self.quotes.insert(key, (implied_vol, bid_ask_spread));
392    }
393
394    /// Get implied volatility for a given strike and maturity
395    pub fn get_vol(&self, strike: f64, maturity: f64) -> Option<f64> {
396        let key = (format!("{:.4}", strike), format!("{:.4}", maturity));
397        self.quotes.get(&key).map(|(vol, _)| *vol)
398    }
399
400    /// Get all strikes
401    pub fn strikes(&self) -> Vec<f64> {
402        let mut strikes: Vec<f64> = self
403            .quotes
404            .keys()
405            .map(|(k, _)| k.parse::<f64>().unwrap_or(0.0))
406            .collect();
407        strikes.sort_by(|a, b| a.partial_cmp(b).unwrap());
408        strikes.dedup();
409        strikes
410    }
411
412    /// Get all maturities
413    pub fn maturities(&self) -> Vec<f64> {
414        let mut maturities: Vec<f64> = self
415            .quotes
416            .keys()
417            .map(|(_, t)| t.parse::<f64>().unwrap_or(0.0))
418            .collect();
419        maturities.sort_by(|a, b| a.partial_cmp(b).unwrap());
420        maturities.dedup();
421        maturities
422    }
423
424    /// Convert to option quotes (for calibration)
425    pub fn to_option_quotes(&self) -> Result<Vec<OptionQuote>> {
426        let mut quotes = Vec::new();
427
428        for ((strike_str, maturity_str), (vol, spread)) in &self.quotes {
429            let strike: f64 = strike_str
430                .parse()
431                .map_err(|_| IntegrateError::ValueError("Invalid strike format".to_string()))?;
432            let maturity: f64 = maturity_str
433                .parse()
434                .map_err(|_| IntegrateError::ValueError("Invalid maturity format".to_string()))?;
435
436            // Assume call options (can be extended)
437            let option = EuropeanOption::new(
438                self.spot,
439                strike,
440                self.rate,
441                self.dividend,
442                *vol,
443                maturity,
444                OptionType::Call,
445            );
446
447            let market_price = option.price();
448
449            quotes.push(OptionQuote::new(
450                strike,
451                maturity,
452                OptionType::Call,
453                market_price,
454                *spread,
455            )?);
456        }
457
458        Ok(quotes)
459    }
460
461    /// Number of quotes in the surface
462    pub fn size(&self) -> usize {
463        self.quotes.len()
464    }
465}
466
467/// Loss function types for calibration
468#[derive(Debug, Clone, Copy)]
469pub enum LossFunction {
470    /// Mean squared error
471    MSE,
472    /// Mean absolute error
473    MAE,
474    /// Weighted MSE (by inverse bid-ask spread)
475    WeightedMSE,
476    /// Relative error (percentage)
477    RelativeError,
478}
479
480impl LossFunction {
481    /// Calculate loss between model and market prices
482    pub fn calculate(&self, quotes: &[OptionQuote], model_prices: &[f64]) -> Result<f64> {
483        if quotes.len() != model_prices.len() {
484            return Err(IntegrateError::ValueError(
485                "Mismatched number of quotes and prices".to_string(),
486            ));
487        }
488
489        if quotes.is_empty() {
490            return Err(IntegrateError::ValueError("No quotes provided".to_string()));
491        }
492
493        match self {
494            LossFunction::MSE => {
495                let sum: f64 = quotes
496                    .iter()
497                    .zip(model_prices.iter())
498                    .map(|(q, &p)| (q.market_price - p).powi(2))
499                    .sum();
500                Ok(sum / quotes.len() as f64)
501            }
502            LossFunction::MAE => {
503                let sum: f64 = quotes
504                    .iter()
505                    .zip(model_prices.iter())
506                    .map(|(q, &p)| (q.market_price - p).abs())
507                    .sum();
508                Ok(sum / quotes.len() as f64)
509            }
510            LossFunction::WeightedMSE => {
511                let weighted_sum: f64 = quotes
512                    .iter()
513                    .zip(model_prices.iter())
514                    .map(|(q, &p)| {
515                        let error = (q.market_price - p).powi(2);
516                        error * q.weight()
517                    })
518                    .sum();
519
520                let total_weight: f64 = quotes.iter().map(|q| q.weight()).sum();
521                Ok(weighted_sum / total_weight)
522            }
523            LossFunction::RelativeError => {
524                let sum: f64 = quotes
525                    .iter()
526                    .zip(model_prices.iter())
527                    .map(|(q, &p)| {
528                        if q.market_price.abs() > 1e-8 {
529                            ((q.market_price - p) / q.market_price).abs()
530                        } else {
531                            (q.market_price - p).abs()
532                        }
533                    })
534                    .sum();
535                Ok(sum / quotes.len() as f64)
536            }
537        }
538    }
539}
540
/// Outcome of a model calibration: fitted parameters plus optimizer diagnostics.
#[derive(Debug, Clone, PartialEq)]
pub struct CalibrationResult {
    /// Calibrated parameters
    pub parameters: Vec<f64>,
    /// Parameter names, positionally matching `parameters`
    pub parameter_names: Vec<String>,
    /// Final loss value
    pub loss: f64,
    /// Number of iterations
    pub iterations: usize,
    /// Convergence status
    pub converged: bool,
}

impl CalibrationResult {
    /// Create a new calibration result
    pub fn new(
        parameters: Vec<f64>,
        parameter_names: Vec<String>,
        loss: f64,
        iterations: usize,
        converged: bool,
    ) -> Self {
        Self {
            parameters,
            parameter_names,
            loss,
            iterations,
            converged,
        }
    }

    /// Look up a calibrated parameter by name.
    ///
    /// Returns `None` if `name` is unknown, or if `parameters` has no entry at the name's
    /// position (possible because the fields are public and may have mismatched lengths).
    pub fn get_parameter(&self, name: &str) -> Option<f64> {
        let index = self.parameter_names.iter().position(|n| n == name)?;
        self.parameters.get(index).copied()
    }
}

583/// Heston model calibrator
584pub struct HestonCalibrator {
585    /// Market implied volatility surface
586    surface: ImpliedVolatilitySurface,
587    /// Loss function
588    loss_function: LossFunction,
589    /// Maximum iterations
590    max_iterations: usize,
591    /// Convergence tolerance
592    tolerance: f64,
593}
594
595impl HestonCalibrator {
596    /// Create a new Heston calibrator
597    pub fn new(surface: ImpliedVolatilitySurface, loss_function: LossFunction) -> Self {
598        Self {
599            surface,
600            loss_function,
601            max_iterations: 1000,
602            tolerance: 1e-6,
603        }
604    }
605
606    /// Set maximum iterations
607    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
608        self.max_iterations = max_iterations;
609        self
610    }
611
612    /// Set convergence tolerance
613    pub fn with_tolerance(mut self, tolerance: f64) -> Self {
614        self.tolerance = tolerance;
615        self
616    }
617
618    /// Calibrate Heston model parameters using Nelder-Mead optimization
619    ///
620    /// Returns: CalibrationResult with (kappa, theta, sigma, rho, v0)
621    ///
622    /// Parameter constraints:
623    /// - kappa > 0 (mean reversion speed)
624    /// - theta > 0 (long-run variance)
625    /// - sigma > 0 (vol of vol)
626    /// - -1 < rho < 1 (correlation)
627    /// - v0 > 0 (initial variance)
628    /// - Feller condition: 2*kappa*theta >= sigma^2 (for positivity)
629    pub fn calibrate(&self) -> Result<CalibrationResult> {
630        // Initial guess for Heston parameters (reasonable defaults)
631        let initial_guess = vec![
632            2.0,  // kappa: mean reversion speed
633            0.04, // theta: long-run variance (20% vol)
634            0.3,  // sigma: vol of vol
635            -0.5, // rho: correlation (typically negative)
636            0.04, // v0: initial variance
637        ];
638
639        // Parameter bounds: [min, max] for each parameter
640        let bounds = vec![
641            (0.01, 10.0),  // kappa
642            (0.001, 1.0),  // theta
643            (0.01, 2.0),   // sigma
644            (-0.99, 0.99), // rho
645            (0.001, 1.0),  // v0
646        ];
647
648        // Run Nelder-Mead optimization
649        let optimizer = NelderMeadOptimizer::new(
650            initial_guess.clone(),
651            bounds,
652            self.max_iterations,
653            self.tolerance,
654        );
655
656        let objective = |params: &[f64]| -> f64 { self.heston_objective(params).unwrap_or(1e10) };
657
658        let result = optimizer.optimize(&objective)?;
659
660        let parameter_names = vec![
661            "kappa".to_string(),
662            "theta".to_string(),
663            "sigma".to_string(),
664            "rho".to_string(),
665            "v0".to_string(),
666        ];
667
668        Ok(CalibrationResult::new(
669            result.parameters,
670            parameter_names,
671            result.final_value,
672            result.iterations,
673            result.converged,
674        ))
675    }
676
677    /// Objective function for Heston calibration
678    fn heston_objective(&self, params: &[f64]) -> Result<f64> {
679        if params.len() != 5 {
680            return Err(IntegrateError::ValueError(
681                "Heston calibration requires 5 parameters".to_string(),
682            ));
683        }
684
685        let kappa = params[0];
686        let theta = params[1];
687        let sigma = params[2];
688        let _rho = params[3]; // Not used in simplified model
689        let v0 = params[4];
690
691        // Check Feller condition (relaxed to warning)
692        let feller_condition = 2.0 * kappa * theta;
693        if feller_condition < sigma * sigma * 0.8 {
694            // Penalize but don't reject
695            return Ok(1e8);
696        }
697
698        // Convert surface to option quotes
699        let quotes = self.surface.to_option_quotes()?;
700        if quotes.is_empty() {
701            return Err(IntegrateError::ValueError(
702                "No market quotes available".to_string(),
703            ));
704        }
705
706        // Calculate model prices using Heston model (simplified approximation)
707        // Uses time-dependent effective volatility from mean-reverting variance
708        let mut model_prices = Vec::new();
709
710        for quote in &quotes {
711            // Effective variance at maturity: E[v_t] = theta + (v0 - theta)*exp(-kappa*t)
712            let effective_variance = theta + (v0 - theta) * (-kappa * quote.maturity).exp();
713            let effective_vol = effective_variance.sqrt();
714
715            // Price using Black-Scholes with effective vol
716            let price = match quote.option_type {
717                OptionType::Call => super::math::black_scholes_call(
718                    self.surface.spot,
719                    quote.strike,
720                    quote.maturity,
721                    self.surface.rate,
722                    effective_vol,
723                ),
724                OptionType::Put => super::math::black_scholes_put(
725                    self.surface.spot,
726                    quote.strike,
727                    quote.maturity,
728                    self.surface.rate,
729                    effective_vol,
730                ),
731            };
732
733            model_prices.push(price);
734        }
735
736        // Calculate loss
737        self.loss_function.calculate(&quotes, &model_prices)
738    }
739}
740
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_option_quote_creation() {
        let quote = OptionQuote::new(100.0, 1.0, OptionType::Call, 10.0, Some(0.5)).unwrap();
        assert_eq!(quote.strike, 100.0);
        assert_eq!(quote.maturity, 1.0);
        assert_eq!(quote.market_price, 10.0);
    }

    #[test]
    fn test_option_quote_weight() {
        let quote1 = OptionQuote::new(100.0, 1.0, OptionType::Call, 10.0, Some(0.5)).unwrap();
        let quote2 = OptionQuote::new(100.0, 1.0, OptionType::Call, 10.0, Some(0.1)).unwrap();

        // Tighter spread should have higher weight
        assert!(quote2.weight() > quote1.weight());
    }

    #[test]
    fn test_implied_vol_surface() {
        let mut surface = ImpliedVolatilitySurface::new(100.0, 0.05, 0.0);

        surface.add_quote(100.0, 1.0, 0.20, None);
        surface.add_quote(110.0, 1.0, 0.22, None);

        assert_eq!(surface.size(), 2);
        assert_eq!(surface.get_vol(100.0, 1.0), Some(0.20));
        assert_eq!(surface.get_vol(110.0, 1.0), Some(0.22));
    }

    #[test]
    fn test_surface_strikes_maturities() {
        let mut surface = ImpliedVolatilitySurface::new(100.0, 0.05, 0.0);

        surface.add_quote(90.0, 0.5, 0.18, None);
        surface.add_quote(100.0, 0.5, 0.20, None);
        surface.add_quote(110.0, 0.5, 0.22, None);
        surface.add_quote(100.0, 1.0, 0.21, None);

        let strikes = surface.strikes();
        let maturities = surface.maturities();

        assert_eq!(strikes.len(), 3);
        assert_eq!(maturities.len(), 2);
        assert!(strikes.contains(&100.0));
        assert!(maturities.contains(&0.5));
    }

    #[test]
    fn test_loss_function_mse() {
        let quotes = vec![
            OptionQuote::new(100.0, 1.0, OptionType::Call, 10.0, None).unwrap(),
            OptionQuote::new(100.0, 1.0, OptionType::Call, 12.0, None).unwrap(),
        ];
        let model_prices = vec![10.5, 11.5];

        let loss = LossFunction::MSE.calculate(&quotes, &model_prices).unwrap();
        let expected = ((10.0_f64 - 10.5).powi(2) + (12.0_f64 - 11.5).powi(2)) / 2.0;

        assert!((loss - expected).abs() < 1e-10);
    }

    #[test]
    fn test_loss_function_weighted_mse() {
        let quotes = vec![
            OptionQuote::new(100.0, 1.0, OptionType::Call, 10.0, Some(0.5)).unwrap(),
            OptionQuote::new(100.0, 1.0, OptionType::Call, 12.0, Some(0.1)).unwrap(),
        ];
        let model_prices = vec![10.5, 11.5];

        let loss = LossFunction::WeightedMSE
            .calculate(&quotes, &model_prices)
            .unwrap();

        // Second quote should contribute more due to tighter spread
        let w1 = 1.0 / 0.5;
        let w2 = 1.0 / 0.1;
        let expected =
            (w1 * (10.0_f64 - 10.5).powi(2) + w2 * (12.0_f64 - 11.5).powi(2)) / (w1 + w2);

        assert!((loss - expected).abs() < 1e-10);
    }

    #[test]
    fn test_loss_function_mismatched_lengths() {
        let quotes = vec![OptionQuote::new(100.0, 1.0, OptionType::Call, 10.0, None).unwrap()];
        assert!(LossFunction::MSE.calculate(&quotes, &[]).is_err());
        assert!(LossFunction::MAE.calculate(&[], &[]).is_err());
    }

    #[test]
    fn test_calibration_result() {
        let result = CalibrationResult::new(
            vec![2.0, 0.04, 0.3, -0.7, 0.04],
            vec![
                "kappa".to_string(),
                "theta".to_string(),
                "sigma".to_string(),
                "rho".to_string(),
                "v0".to_string(),
            ],
            0.01,
            100,
            true,
        );

        assert_eq!(result.get_parameter("kappa"), Some(2.0));
        assert_eq!(result.get_parameter("theta"), Some(0.04));
        assert_eq!(result.loss, 0.01);
        assert!(result.converged);
    }

    #[test]
    fn test_calibration_result_unknown_parameter() {
        let result = CalibrationResult::new(vec![1.0], vec!["kappa".to_string()], 0.0, 0, false);
        assert_eq!(result.get_parameter("theta"), None);
    }

    #[test]
    fn test_heston_calibrator_creation() {
        let surface = ImpliedVolatilitySurface::new(100.0, 0.05, 0.0);
        let calibrator = HestonCalibrator::new(surface, LossFunction::WeightedMSE)
            .with_max_iterations(500)
            .with_tolerance(1e-5);

        assert_eq!(calibrator.max_iterations, 500);
        assert_eq!(calibrator.tolerance, 1e-5);
    }

    #[test]
    fn test_surface_to_option_quotes() {
        let mut surface = ImpliedVolatilitySurface::new(100.0, 0.05, 0.0);
        surface.add_quote(100.0, 1.0, 0.20, Some(0.5));
        surface.add_quote(110.0, 1.0, 0.22, Some(0.3));

        let quotes = surface.to_option_quotes().unwrap();
        assert_eq!(quotes.len(), 2);
        assert!(quotes[0].market_price > 0.0);
    }

    #[test]
    fn test_heston_calibration_basic() {
        // Create a simple volatility surface
        let mut surface = ImpliedVolatilitySurface::new(100.0, 0.05, 0.0);

        // Add quotes with realistic implied vols
        surface.add_quote(90.0, 1.0, 0.25, Some(0.02)); // OTM put
        surface.add_quote(100.0, 1.0, 0.20, Some(0.01)); // ATM
        surface.add_quote(110.0, 1.0, 0.22, Some(0.02)); // OTM call

        // Calibrate with reduced iterations for testing
        let calibrator = HestonCalibrator::new(surface, LossFunction::WeightedMSE)
            .with_max_iterations(100)
            .with_tolerance(1e-4);

        let result = calibrator.calibrate().unwrap();

        // Check that parameters are in reasonable ranges
        assert!(result.get_parameter("kappa").unwrap() > 0.0);
        assert!(result.get_parameter("theta").unwrap() > 0.0);
        assert!(result.get_parameter("sigma").unwrap() > 0.0);
        assert!(result.get_parameter("rho").unwrap() > -1.0);
        assert!(result.get_parameter("rho").unwrap() < 1.0);
        assert!(result.get_parameter("v0").unwrap() > 0.0);

        // Check that loss is reasonable
        assert!(result.loss >= 0.0);
        assert!(result.loss < 1e6); // Not a penalty value

        // Check iterations
        assert!(result.iterations <= 100);
    }

    #[test]
    fn test_nelder_mead_rosenbrock() {
        // Test Nelder-Mead on Rosenbrock function: f(x,y) = (1-x)^2 + 100*(y-x^2)^2
        // Global minimum at (1, 1) with f = 0
        let rosenbrock = |params: &[f64]| -> f64 {
            let x = params[0];
            let y = params[1];
            (1.0 - x).powi(2) + 100.0 * (y - x * x).powi(2)
        };

        let optimizer =
            NelderMeadOptimizer::new(vec![0.0, 0.0], vec![(-5.0, 5.0), (-5.0, 5.0)], 1000, 1e-6);

        let result = optimizer.optimize(&rosenbrock).unwrap();

        // Should converge close to (1, 1)
        assert!((result.parameters[0] - 1.0).abs() < 0.1);
        assert!((result.parameters[1] - 1.0).abs() < 0.1);
        assert!(result.final_value < 1.0); // Near minimum
    }
}