surface_lib/models/svi/
svi_calibrator.rs

1// src/models/svi/svi_calibrator.rs
2
3//! SVI model calibrator implementation
4//!
5//! This module implements the calibrator for the SVI (Stochastic Volatility Inspired) model.
6//! The calibrator follows the same pattern as other models in the codebase, implementing
7//! the ModelCalibrator trait and providing methods for parameter optimization.
8
9use crate::calibration::config::OptimizationConfig;
10use crate::calibration::types::{MarketDataRow, ModelCalibrator, PricingResult};
11use crate::model_params::{ModelParams, SviModelParams};
12use crate::models::svi::svi_model::{SVIParams, SVISlice};
13use crate::models::utils::{log_moneyness, price_option, OptionPricingResult};
14use anyhow::{anyhow, Result};
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17
/// Structure to hold parameter bounds for the SVI model calibration
///
/// Every field is a `(lower, upper)` pair for one SVI parameter. The calibrator
/// flattens the fields into a vector in the order `[a, b, rho, m, sigma]`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SVIParamBounds {
    /// Vertical shift parameter bounds (controls ATM variance level)
    pub a: (f64, f64),
    /// Slope factor bounds (controls overall variance level)
    pub b: (f64, f64),
    /// Asymmetry parameter bounds (skew, must be in (-1, 1))
    pub rho: (f64, f64),
    /// Horizontal shift parameter bounds (ATM location)
    pub m: (f64, f64),
    /// Curvature parameter bounds (controls smile curvature, must be > 0)
    pub sigma: (f64, f64),
}
32
33impl Default for SVIParamBounds {
34    fn default() -> Self {
35        Self {
36            a: (-0.5, 0.5),
37            b: (0.01, 2.0),
38            rho: (-0.99, -0.01), // Restrict to negative for BTC skew
39            m: (-1.0, 1.0),
40            sigma: (0.01, 2.0),
41        }
42    }
43}
44
45impl From<&[(f64, f64)]> for SVIParamBounds {
46    fn from(bounds: &[(f64, f64)]) -> Self {
47        if bounds.len() != 5 {
48            return Self::default();
49        }
50        Self {
51            a: bounds[0],
52            b: bounds[1],
53            rho: bounds[2],
54            m: bounds[3],
55            sigma: bounds[4],
56        }
57    }
58}
59
/// Calibrator for the SVI model with 5 parameters per expiry:
/// [a, b, rho, m, sigma]
///
/// One instance handles exactly one expiration; the constructor rejects input
/// data that spans several expirations.
#[derive(Debug, Clone)]
pub struct SVIModelCalibrator {
    /// Store only the single expiration (timestamp, years_to_exp).
    /// The years value is the mean `years_to_exp` of the rows passed to `new`.
    expiration: (i64, f64),
    /// (lower, upper) bounds for the single slice's parameters, ordered
    /// [a, b, rho, m, sigma] (length 5). `expand_bounds_if_needed` may widen
    /// them in place.
    param_bounds: Vec<(f64, f64)>,

    /// Model-specific parameters (e.g. ATM boost)
    params: SviModelParams,

    /// Optional previous solution for temporal regularization
    prev_solution: Option<Vec<f64>>,
    /// Weight of the squared-distance penalty towards `prev_solution` in the
    /// objective. The setter stores negative inputs as 0; a value of 0 disables
    /// the penalty.
    temporal_reg_lambda: f64,
}
76
77impl SVIModelCalibrator {
78    /// Constructor from market data and configuration parameters.
79    pub fn new(
80        data: &[MarketDataRow],
81        param_bounds_opt: Option<SVIParamBounds>,
82        model_params: Option<Box<dyn ModelParams>>, // new optional parameters
83    ) -> Result<Self> {
84        // Group data by expiration to ensure single expiry requirement
85        let mut grouped = HashMap::<i64, Vec<f64>>::new();
86        for r in data {
87            grouped
88                .entry(r.expiration)
89                .or_default()
90                .push(r.years_to_exp);
91        }
92
93        // Ensure exactly one expiration is present
94        if grouped.len() != 1 {
95            return Err(anyhow!(
96                "SVIModelCalibrator requires data for exactly one expiration, but found {}. Expirations: {:?}", 
97                grouped.len(), grouped.keys().collect::<Vec<_>>()
98            ));
99        }
100
101        // Get the single expiration timestamp and calculate average time
102        let (single_exp_ts, times) = grouped.into_iter().next().unwrap();
103        let single_avg_t = times.iter().copied().sum::<f64>() / times.len() as f64;
104        let expiration = (single_exp_ts, single_avg_t);
105
106        let bounds = param_bounds_opt.unwrap_or_default();
107
108        /*
109        // Auto-adjust bounds based on time to expiry (adaptive bounds placeholder)
110        let bounds = param_bounds.unwrap_or_else(|| {
111            let days = single_avg_t * 365.0;
112            if days < 1.0 {
113                // For intraday options (< 1 day), allow very tight parameters
114                SVIParamBounds {
115                    a: (-0.1, 0.1),
116                    b: (0.001, 0.5),
117                    rho: (-0.99, -0.01),
118                    m: (-0.5, 0.5),
119                    sigma: (0.001, 0.1),
120                }
121            } else if days < 3.0 {
122                // Very short-term (1-3 days): keep m near ATM to avoid extreme shifts
123                SVIParamBounds {
124                    a: (-0.5, 1.0),
125                    b: (0.01, 5.0),
126                    rho: (-0.999, -0.01),
127                    m: (-0.3, 0.3),
128                    sigma: (0.01, 2.0),
129                }
130            } else if days < 7.0 {
131                // Short-term (< 1 week) – moderate m range
132                SVIParamBounds {
133                    a: (-0.5, 1.0),
134                    b: (0.01, 5.0),
135                    rho: (-0.999, -0.01),
136                    m: (-1.0, 1.0),
137                    sigma: (0.01, 2.0),
138                }
139            } else if days < 30.0 {
140                // For medium-term options (< 1 month)
141                SVIParamBounds {
142                    a: (-0.5, 0.8),
143                    b: (0.01, 3.0),
144                    rho: (-0.99, -0.01),
145                    m: (-1.5, 1.5),
146                    sigma: (0.03, 1.0),
147                }
148            } else {
149                // For longer-term options, use default bounds
150                SVIParamBounds::default()
151            }
152        });
153        */
154
155        // Fill parameter bounds vector from the struct (5 parameters: a, b, rho, m, sigma)
156        let param_bounds = vec![bounds.a, bounds.b, bounds.rho, bounds.m, bounds.sigma];
157
158        // Note: relaxed_bounds removed, using param_bounds directly
159
160        // Resolve model-specific parameters (default if not supplied or type mismatch)
161        let params = if let Some(mp) = model_params {
162            mp.as_any()
163                .downcast_ref::<SviModelParams>()
164                .cloned()
165                .unwrap_or_default()
166        } else {
167            SviModelParams::default()
168        };
169
170        Ok(Self {
171            expiration,
172            param_bounds,
173            params,
174            prev_solution: None,
175            temporal_reg_lambda: 0.0,
176        })
177    }
178
179    pub fn set_prev_solution(&mut self, prev_sol: Vec<f64>) {
180        if prev_sol.len() == self.param_count() {
181            self.prev_solution = Some(prev_sol);
182        }
183    }
184
185    pub fn set_temporal_reg_lambda(&mut self, lambda: f64) {
186        self.temporal_reg_lambda = lambda.max(0.0);
187    }
188}
189
190impl ModelCalibrator for SVIModelCalibrator {
191    fn model_name(&self) -> &str {
192        "svi"
193    }
194
195    fn param_count(&self) -> usize {
196        self.param_bounds.len() // Should be 5
197    }
198
199    fn param_bounds(&self) -> &[(f64, f64)] {
200        &self.param_bounds
201    }
202
203    /// Evaluate objective function using vega-weighted RMSE on total variance with
204    /// an additional exponential ATM weighting.
205    /// x is the parameter vector [a, b, rho, m, sigma].
206    fn evaluate_objective(&self, x: &[f64], data: &[MarketDataRow]) -> f64 {
207        assert_eq!(
208            x.len(),
209            5,
210            "Input parameter vector length must be 5 for SVI model"
211        );
212
213        let (exp_ts, t) = self.expiration;
214
215        // 1. Build the SVI slice from the candidate parameters ----------------------------
216        let params = match SVIParams::new(t, x[0], x[1], x[2], x[3], x[4]) {
217            Ok(p) => p,
218            Err(_) => return 1.0e12, // Reject invalid parameter sets outright
219        };
220        let slice = SVISlice::new(params);
221
222        // 2. Weighted error computation ----------------------------------------------------
223        let mut weighted_error_sum = 0.0;
224        let mut weight_sum = 0.0;
225        let mut valid_points = 0u32;
226
227        for row in data {
228            if row.expiration != exp_ts {
229                continue; // Keep only this slice's points
230            }
231
232            let k = log_moneyness(row.strike_price, row.underlying_price);
233            let model_iv = slice.implied_vol(k);
234            let market_iv_dec = row.market_iv; // already in decimal form
235
236            // Skip points with non-positive IVs
237            if model_iv <= 0.0 || market_iv_dec <= 0.0 {
238                continue;
239            }
240
241            // Total variance (w = σ² · t) difference – preferred over raw IV diff for
242            // short-dated options where IV is highly non-linear in the parameters.
243            let model_w = model_iv * model_iv * t;
244            let market_w = market_iv_dec * market_iv_dec * t;
245            let diff = model_w - market_w;
246            let squared_error = diff * diff;
247
248            // --- Weighting scheme --------------------------------------------------------
249            // 1. Vega weighting (optional)
250            let vega_weight = if self.params.use_vega_weighting {
251                if row.vega > 0.0 {
252                    row.vega
253                } else {
254                    1.0
255                }
256            } else {
257                1.0
258            };
259            // 2. ATM emphasis – exponential decay as |k| grows.
260            let atm_weight = (-self.params.atm_boost_factor * k.abs()).exp();
261            let weight = vega_weight * atm_weight;
262
263            weighted_error_sum += weight * squared_error;
264            weight_sum += weight;
265            valid_points += 1;
266        }
267
268        if valid_points == 0 || weight_sum <= 1e-12 {
269            return 1.0e12; // Fail-safe if no usable points
270        }
271
272        // Weighted root-mean-squared error on total variance
273        let mut obj = (weighted_error_sum / weight_sum).sqrt();
274
275        // -----------------------------------------------------------------------------------
276        // Optional temporal regularisation on raw parameters
277        // -----------------------------------------------------------------------------------
278        if let (Some(prev), lambda) = (&self.prev_solution, self.temporal_reg_lambda) {
279            if lambda > 0.0 && prev.len() == x.len() {
280                let penalty: f64 = x
281                    .iter()
282                    .zip(prev.iter())
283                    .map(|(v, p)| (v - p).powi(2))
284                    .sum::<f64>()
285                    * lambda;
286                obj += penalty;
287            }
288        }
289        obj
290    }
291
292    // Note: create_param_map removed as param_map is no longer returned from calibration API
293
294    fn price_options(
295        &self,
296        market_data: &[MarketDataRow],
297        best_params: &[f64],
298        config: &OptimizationConfig,
299    ) -> Vec<PricingResult> {
300        assert_eq!(best_params.len(), 5, "Expected 5 parameters for SVI model");
301        let (exp_ts, t) = self.expiration;
302
303        // Extract parameters
304        let a = best_params[0];
305        let b = best_params[1];
306        let rho = best_params[2];
307        let m = best_params[3];
308        let sigma = best_params[4];
309
310        let final_params = match SVIParams::new(t, a, b, rho, m, sigma) {
311            Ok(params) => params,
312            Err(e) => {
313                eprintln!(
314                    "Error creating final SVIParams for pricing: {}. Using fallback parameters.",
315                    e
316                );
317                SVIParams::new(0.1, 0.04, 0.2, -0.3, 0.0, 0.2).unwrap() // Fallback
318            }
319        };
320        let final_slice = SVISlice::new(final_params);
321
322        let r = config.fixed_params.r;
323        let q = config.fixed_params.q;
324        let mut results = Vec::with_capacity(market_data.len());
325
326        for row in market_data {
327            // Filter data for the single expiration this calibrator handles
328            if row.expiration == exp_ts {
329                let t_row = row.years_to_exp;
330                let underlying = row.underlying_price;
331                let strike = row.strike_price;
332
333                // Price the option using SVI model
334                let pricing_result = if underlying > 1e-8 {
335                    price_option(
336                        &row.option_type,
337                        strike,
338                        underlying,
339                        r,
340                        q,
341                        t_row,
342                        &final_slice,
343                    )
344                } else {
345                    Ok(OptionPricingResult {
346                        price: 0.0,
347                        model_iv: 0.0,
348                    })
349                };
350
351                let (model_price, model_iv) = match pricing_result {
352                    Ok(pr) => (pr.price, pr.model_iv),
353                    Err(e) => {
354                        eprintln!(
355                            "Error pricing option (exp={}, strike={}): {}",
356                            exp_ts, strike, e
357                        );
358                        (0.0, 0.0)
359                    }
360                };
361
362                results.push(PricingResult {
363                    option_type: row.option_type.clone(),
364                    strike_price: row.strike_price,
365                    underlying_price: row.underlying_price,
366                    years_to_exp: row.years_to_exp,
367                    model_price,
368                    model_iv,
369                });
370            }
371        }
372
373        results.sort_by(|a, b| a.strike_price.partial_cmp(&b.strike_price).unwrap());
374        results
375    }
376
377    fn param_names(&self) -> Vec<&str> {
378        vec!["a", "b", "rho", "m", "sigma"]
379    }
380
381    fn as_any(&self) -> &dyn std::any::Any {
382        self
383    }
384
385    fn set_prev_solution(&mut self, prev_solution: Vec<f64>) {
386        self.set_prev_solution(prev_solution);
387    }
388
389    fn set_temporal_reg_lambda(&mut self, lambda: f64) {
390        self.set_temporal_reg_lambda(lambda);
391    }
392
393    // Note: relaxed methods removed as they were redundant
394
395    fn expand_bounds_if_needed(
396        &mut self,
397        params: &[f64],
398        proximity_threshold: f64,
399        expansion_factor: f64,
400    ) -> bool {
401        let mut adjusted = false;
402        for (bounds, param) in self.param_bounds.iter_mut().zip(params.iter()) {
403            let range = bounds.1 - bounds.0;
404            let lower_thresh = bounds.0 + range * proximity_threshold;
405            let upper_thresh = bounds.1 - range * proximity_threshold;
406            if *param <= lower_thresh {
407                let expansion = range * expansion_factor;
408                bounds.0 -= expansion;
409                adjusted = true;
410            }
411            if *param >= upper_thresh {
412                let expansion = range * expansion_factor;
413                bounds.1 += expansion;
414                adjusted = true;
415            }
416            // Note: bounds already updated in-place above
417        }
418        adjusted
419    }
420}