Skip to main content

sandbox_quant/backtest/
mod.rs

1use std::collections::HashMap;
2use std::fs::File;
3use std::io::{BufRead, BufReader};
4use std::path::{Path, PathBuf};
5
6use anyhow::{anyhow, Context, Result};
7use chrono::Utc;
8use rusqlite::{params, Connection};
9
10use crate::event::{MarketRegime, MarketRegimeSignal};
11use crate::model::candle::Candle;
12use crate::model::signal::Signal;
13use crate::predictor::{
14    build_predictor_models, default_predictor_specs, PredictorBaseConfig, PredictorModel,
15};
16use crate::runtime::alpha_portfolio::{decide_portfolio_action_from_alpha, RegimeDecisionConfig};
17use crate::runtime::predictor_eval::{
18    observe_predictor_eval_volatility, predictor_eval_scale, PredictorEvalVolState,
19};
20use crate::runtime::regime::{RegimeDetector, RegimeDetectorConfig};
21
// ---- Trading defaults -------------------------------------------------------

/// Default capital per fold in USDT: the starting equity and the notional a target ratio of
/// 1.0 maps to.
const DEFAULT_BACKTEST_TRADE_NOTIONAL_USDT: f64 = 100.0;
/// Default fee, as a fraction of fill notional.
const DEFAULT_BACKTEST_FEE_RATE: f64 = 0.001;
/// Default slippage applied to the bar close, in basis points.
const DEFAULT_BACKTEST_SLIPPAGE_BPS: f64 = 2.0;
/// Default minimum absolute normalized alpha needed before a decision is requested.
const DEFAULT_BACKTEST_MIN_SIGNAL_STRENGTH: f64 = 0.0;

// ---- Walk-forward layout defaults (in bars) -----------------------------------

/// Default number of warm-up (training) bars per fold.
const DEFAULT_TRAIN_WINDOW: usize = 420;
/// Default number of traded (test) bars per fold.
const DEFAULT_TEST_WINDOW: usize = 120;
/// Default number of bars left unused between a fold's train and test slices.
const DEFAULT_EMBARGO_WINDOW: usize = 8;
/// Default upper bound on the number of folds.
const DEFAULT_MAX_FOLDS: usize = 5;

// ---- Portfolio gates -----------------------------------------------------------

/// Threshold handed to `effective_signal` when converting a portfolio intent into a
/// buy/sell/hold signal (presumably the minimum ratio change worth rebalancing).
const PORTFOLIO_REBALANCE_MIN_DELTA: f64 = 0.05;
/// Decisions whose target position ratio is below this value are skipped.
const PORTFOLIO_MIN_ENTRY_RATIO: f64 = 0.15;
32
33#[derive(Debug, Clone)]
34pub struct BacktestConfig {
35    pub symbol: String,
36    pub bars_csv: PathBuf,
37    pub strategy_db_path: PathBuf,
38    pub order_db_path: PathBuf,
39    pub order_amount_usdt: f64,
40    pub fee_rate: f64,
41    pub slippage_bps: f64,
42    pub train_window: usize,
43    pub test_window: usize,
44    pub embargo_window: usize,
45    pub max_folds: usize,
46    pub min_signal_abs: f64,
47    pub regime_gate_enabled: bool,
48    pub predictor_ewma_alpha_mean: f64,
49    pub predictor_ewma_alpha_var: f64,
50    pub predictor_min_sigma: f64,
51    pub predictor_mu: f64,
52    pub predictor_sigma: f64,
53    pub run_seed: u64,
54}
55
56impl Default for BacktestConfig {
57    fn default() -> Self {
58        Self {
59            symbol: "BTCUSDT".to_string(),
60            bars_csv: PathBuf::from("data/demo_market_backlog.csv"),
61            strategy_db_path: PathBuf::from("data/backtest_strategy.sqlite"),
62            order_db_path: PathBuf::from("data/backtest_orders.sqlite"),
63            order_amount_usdt: DEFAULT_BACKTEST_TRADE_NOTIONAL_USDT,
64            fee_rate: DEFAULT_BACKTEST_FEE_RATE,
65            slippage_bps: DEFAULT_BACKTEST_SLIPPAGE_BPS,
66            train_window: DEFAULT_TRAIN_WINDOW,
67            test_window: DEFAULT_TEST_WINDOW,
68            embargo_window: DEFAULT_EMBARGO_WINDOW,
69            max_folds: DEFAULT_MAX_FOLDS,
70            min_signal_abs: DEFAULT_BACKTEST_MIN_SIGNAL_STRENGTH,
71            regime_gate_enabled: false,
72            predictor_ewma_alpha_mean: 0.18,
73            predictor_ewma_alpha_var: 0.18,
74            predictor_min_sigma: 0.001,
75            predictor_mu: 0.0,
76            predictor_sigma: 0.01,
77            run_seed: 0,
78        }
79    }
80}
81
/// One walk-forward fold, as half-open bar-index ranges into the candle series.
///
/// Training uses `train_start..train_end` and trading uses `test_start..test_end`; any gap
/// between `train_end` and `test_start` is the embargo.
#[derive(Clone, Debug)]
pub struct WalkWindow {
    /// Zero-based fold number.
    pub fold: usize,
    /// First training bar (inclusive).
    pub train_start: usize,
    /// End of the training slice (exclusive).
    pub train_end: usize,
    /// First traded bar (inclusive).
    pub test_start: usize,
    /// End of the traded slice (exclusive).
    pub test_end: usize,
}
90
/// Performance figures for one fold, or a run-level aggregate of folds.
#[derive(Clone, Debug)]
pub struct BacktestMetrics {
    /// PnL of closed quantity in USDT, net of buy and sell fees.
    pub realized_pnl_usdt: f64,
    /// Sum of all fees charged, in USDT.
    pub total_fees_usdt: f64,
    /// Number of buy and sell fills.
    pub trade_count: u64,
    /// Sell fills whose realized PnL was `>= 0`.
    pub win_count: u64,
    /// Sell fills whose realized PnL was negative.
    pub lose_count: u64,
    /// Largest peak-to-trough equity drop, as a fraction of the peak.
    pub max_drawdown: f64,
    /// Mean over population standard deviation of per-bar equity returns (not annualized).
    pub sharpe_like: f64,
    /// Equity in USDT at the end of the evaluated period.
    pub end_equity_usdt: f64,
}
102
103#[derive(Debug, Clone)]
104pub struct BacktestFoldResult {
105    pub fold: usize,
106    pub train_bars: usize,
107    pub test_bars: usize,
108    pub metrics: BacktestMetrics,
109    pub train_start_timestamp_ms: u64,
110    pub train_end_timestamp_ms: u64,
111    pub start_timestamp_ms: u64,
112    pub end_timestamp_ms: u64,
113}
114
115#[derive(Debug, Clone)]
116pub struct BacktestResult {
117    pub run_id: String,
118    pub symbol: String,
119    pub total_bars: usize,
120    pub folds: Vec<BacktestFoldResult>,
121    pub metrics: BacktestMetrics,
122    pub run_started_ms: u64,
123    pub run_finished_ms: u64,
124}
125
/// One simulated fill, as written to the order ledger database.
#[derive(Clone, Debug)]
pub struct BacktestOrderLedgerRow {
    /// Run this fill belongs to.
    pub run_id: String,
    /// Fold this fill belongs to.
    pub fold: usize,
    /// Sequence number of the fill within the run.
    pub order_index: u64,
    /// Origin tag of the order.
    pub source: String,
    /// Index of the bar the fill happened on.
    pub bar_idx: usize,
    /// Bar `close_time` at which the fill happened (ms).
    pub timestamp_ms: u64,
    /// Order side as text.
    pub side: String,
    /// Position ratio the decision targeted.
    pub target_ratio: f64,
    /// Position ratio before the fill.
    pub current_ratio: f64,
    /// Filled quantity (always positive).
    pub qty: f64,
    /// Fill price including slippage.
    pub price: f64,
    /// Fee charged for the fill, in USDT.
    pub fee_usdt: f64,
    /// PnL realized by the fill, in USDT.
    pub pnl_realized_usdt: f64,
    /// Reason reported by the portfolio decision.
    pub reason: String,
}
143
144#[derive(Debug, Clone)]
145pub struct CandleFeed {
146    pub symbol: String,
147    pub interval_ms: u64,
148    pub bars: Vec<Candle>,
149}
150
151#[derive(Debug)]
152struct BacktestPosition {
153    qty: f64,
154    entry_price: f64,
155    cost_quote: f64,
156    realized_pnl: f64,
157    unrealized_pnl: f64,
158}
159
160impl BacktestPosition {
161    fn new() -> Self {
162        Self {
163            qty: 0.0,
164            entry_price: 0.0,
165            cost_quote: 0.0,
166            realized_pnl: 0.0,
167            unrealized_pnl: 0.0,
168        }
169    }
170
171    fn is_flat(&self) -> bool {
172        self.qty <= f64::EPSILON
173    }
174
175    fn notional(&self, price: f64) -> f64 {
176        self.qty * price
177    }
178
179    fn apply_fill(&mut self, side: Signal, qty: f64, price: f64, fee: f64) -> f64 {
180        match side {
181            Signal::Buy => {
182                if qty <= f64::EPSILON {
183                    return 0.0;
184                }
185
186                self.qty += qty;
187                self.cost_quote += qty * price + fee;
188                self.entry_price = self.cost_quote / self.qty.max(f64::EPSILON);
189                0.0
190            }
191            Signal::Sell => {
192                if self.qty <= f64::EPSILON || qty <= f64::EPSILON {
193                    return 0.0;
194                }
195
196                let close_qty = qty.min(self.qty);
197                let avg_cost = self.cost_quote / self.qty.max(f64::EPSILON);
198                let pnl = close_qty * price - close_qty * avg_cost - fee;
199                self.realized_pnl += pnl;
200                self.qty -= close_qty;
201                self.cost_quote -= close_qty * avg_cost;
202
203                if self.qty <= f64::EPSILON {
204                    self.qty = 0.0;
205                    self.entry_price = 0.0;
206                    self.cost_quote = 0.0;
207                } else {
208                    self.entry_price = self.cost_quote / self.qty.max(f64::EPSILON);
209                }
210                pnl
211            }
212            Signal::Hold => 0.0,
213        }
214    }
215
216    fn update_unrealized(&mut self, price: f64) {
217        if self.is_flat() {
218            self.unrealized_pnl = 0.0;
219        } else {
220            self.unrealized_pnl = (price - self.entry_price) * self.qty;
221        }
222    }
223
224    fn total_equity(&mut self, close: f64) -> f64 {
225        self.update_unrealized(close);
226        self.realized_pnl + self.unrealized_pnl
227    }
228}
229
230pub fn infer_interval_ms(bars: &[Candle]) -> u64 {
231    if bars.len() < 2 {
232        return 60_000;
233    }
234    let mut diffs: Vec<u64> = bars
235        .windows(2)
236        .filter_map(|w| w[1].open_time.checked_sub(w[0].open_time))
237        .filter(|d| *d > 0)
238        .collect();
239    if diffs.is_empty() {
240        return 60_000;
241    }
242    diffs.sort_unstable();
243    diffs[diffs.len() / 2]
244}
245
246pub fn parse_candle_csv(symbol: &str, path: &Path) -> Result<CandleFeed> {
247    let file = File::open(path).with_context(|| format!("open candle csv: {}", path.display()))?;
248    let reader = BufReader::new(file);
249
250    let mut bars = Vec::new();
251    let mut idx_ts = 0usize;
252    let mut idx_open = 1usize;
253    let mut idx_high = 2usize;
254    let mut idx_low = 3usize;
255    let mut idx_close = 4usize;
256    let mut line_no = 0usize;
257    let mut has_header = false;
258
259    let resolve_idx = |header: &[String], names: &[&str]| -> Option<usize> {
260        names
261            .iter()
262            .find_map(|name| header.iter().position(|h| h == name))
263    };
264
265    for raw in reader.lines() {
266        let line = raw?;
267        let trimmed = line.trim();
268        if trimmed.is_empty() {
269            line_no += 1;
270            continue;
271        }
272
273        let fields: Vec<&str> = trimmed.split(',').map(|v| v.trim()).collect();
274
275        if !has_header {
276            let is_header = fields.first().is_some_and(|f| f.parse::<f64>().is_err());
277
278            if is_header {
279                let h = fields
280                    .iter()
281                    .map(|v| v.to_ascii_lowercase())
282                    .collect::<Vec<_>>();
283                idx_ts = resolve_idx(&h, &["open_time", "timestamp_ms", "time"])
284                    .or_else(|| resolve_idx(&h, &["ts"]))
285                    .unwrap_or(0);
286                idx_open = resolve_idx(&h, &["open"]).unwrap_or(1);
287                idx_high = resolve_idx(&h, &["high"]).unwrap_or(2);
288                idx_low = resolve_idx(&h, &["low"]).unwrap_or(3);
289                idx_close = resolve_idx(&h, &["close"]).unwrap_or(4);
290                has_header = true;
291                line_no += 1;
292                continue;
293            }
294            has_header = true;
295        }
296
297        if fields.len() <= idx_close {
298            return Err(anyhow!("invalid row #{}: too few fields", line_no + 1));
299        }
300
301        let open_time = parse_u64(fields[idx_ts], line_no)?;
302        let open = parse_f64(fields[idx_open], "open", line_no)?;
303        let high = parse_f64(fields[idx_high], "high", line_no)?;
304        let low = parse_f64(fields[idx_low], "low", line_no)?;
305        let close = parse_f64(fields[idx_close], "close", line_no)?;
306        if open <= 0.0 || high <= 0.0 || low <= 0.0 || close <= 0.0 {
307            return Err(anyhow!("invalid row #{}: non-positive price", line_no + 1));
308        }
309
310        bars.push(Candle {
311            open,
312            high,
313            low,
314            close,
315            open_time,
316            close_time: 0,
317        });
318        line_no += 1;
319    }
320
321    if bars.len() < 2 {
322        return Err(anyhow!("need at least 2 valid rows, got {}", bars.len()));
323    }
324
325    let interval_ms = infer_interval_ms(&bars);
326    for i in 0..bars.len().saturating_sub(1) {
327        bars[i].close_time = bars[i + 1].open_time;
328    }
329    let last_close_time = bars
330        .last()
331        .map(|c| c.open_time + interval_ms)
332        .unwrap_or_default();
333    if let Some(last) = bars.last_mut() {
334        last.close_time = last_close_time;
335    }
336
337    Ok(CandleFeed {
338        symbol: symbol.to_string(),
339        interval_ms,
340        bars,
341    })
342}
343
344fn parse_u64(value: &str, line_no: usize) -> Result<u64> {
345    value
346        .parse::<u64>()
347        .or_else(|_| value.parse::<f64>().map(|v| v.max(0.0) as u64))
348        .with_context(|| format!("line {}: parse timestamp", line_no + 1))
349}
350
351fn parse_f64(value: &str, field: &str, line_no: usize) -> Result<f64> {
352    value
353        .parse::<f64>()
354        .with_context(|| format!("line {}: parse {} {}", line_no + 1, field, value))
355}
356
357pub fn build_walk_forward_windows(total_bars: usize, cfg: &BacktestConfig) -> Vec<WalkWindow> {
358    if total_bars < cfg.train_window + cfg.test_window + cfg.embargo_window
359        || cfg.train_window == 0
360        || cfg.test_window == 0
361    {
362        return Vec::new();
363    }
364
365    let mut folds = Vec::new();
366    let mut train_start = 0usize;
367
368    for fold_idx in 0..cfg.max_folds {
369        let train_end = train_start + cfg.train_window;
370        let test_start = train_end + cfg.embargo_window;
371        let test_end = test_start + cfg.test_window;
372
373        if test_end > total_bars {
374            break;
375        }
376
377        folds.push(WalkWindow {
378            fold: fold_idx,
379            train_start,
380            train_end,
381            test_start,
382            test_end,
383        });
384
385        train_start = test_end;
386    }
387
388    folds
389}
390
391fn default_regime_signal(now_ms: u64) -> MarketRegimeSignal {
392    MarketRegimeSignal {
393        regime: MarketRegime::TrendUp,
394        confidence: 1.0,
395        ema_fast: 0.0,
396        ema_slow: 0.0,
397        vol_ratio: 0.0,
398        slope: 0.0,
399        updated_at_ms: now_ms,
400    }
401}
402
403fn build_predictor_specs(cfg: &BacktestConfig) -> Vec<(String, crate::predictor::PredictorConfig)> {
404    let base_cfg = PredictorBaseConfig {
405        alpha_mean: cfg.predictor_ewma_alpha_mean,
406        alpha_var: cfg.predictor_ewma_alpha_var,
407        min_sigma: cfg.predictor_min_sigma,
408    };
409    default_predictor_specs(base_cfg)
410}
411
412pub fn run_walk_forward_backtest(
413    cfg: &BacktestConfig,
414    feed: &CandleFeed,
415) -> Result<BacktestResult> {
416    let required = cfg.train_window + cfg.test_window + cfg.embargo_window;
417    if feed.bars.len() < required {
418        return Err(anyhow!(
419            "insufficient bars: got {}, need at least {}",
420            feed.bars.len(),
421            required
422        ));
423    }
424
425    let run_started_ms = Utc::now().timestamp_millis() as u64;
426    let run_id = format!("bt-{}-{}", cfg.symbol, run_started_ms);
427    let windows = build_walk_forward_windows(feed.bars.len(), cfg);
428    if windows.is_empty() {
429        return Err(anyhow!("no walk-forward folds available"));
430    }
431
432    ensure_strategy_db(&cfg.strategy_db_path)?;
433    ensure_order_db(&cfg.order_db_path)?;
434
435    let mut order_rows: Vec<BacktestOrderLedgerRow> = Vec::new();
436    let mut fold_results: Vec<BacktestFoldResult> = Vec::new();
437    let mut run_fold_metrics = Vec::new();
438    let mut next_order_index = 0u64;
439
440    for window in windows.iter() {
441        let mut models: HashMap<String, PredictorModel> =
442            build_predictor_models(&build_predictor_specs(cfg));
443        let mut vol_state = PredictorEvalVolState::default();
444        let mut regime_detector = RegimeDetector::new(RegimeDetectorConfig::default());
445        let mut position = BacktestPosition::new();
446
447        let mut fold_fees = 0.0;
448        let mut win_count = 0u64;
449        let mut lose_count = 0u64;
450        let mut trade_count = 0u64;
451        let mut fold_orders = Vec::new();
452        let mut fold_equity_curve = Vec::new();
453        let mut prev_equity = cfg.order_amount_usdt;
454
455        for idx in window.train_start..window.train_end {
456            let close = feed.bars[idx].close;
457            if close > f64::EPSILON {
458                for m in models.values_mut() {
459                    m.observe_price(&cfg.symbol, close);
460                }
461                observe_predictor_eval_volatility(
462                    &mut vol_state,
463                    close,
464                    cfg.predictor_ewma_alpha_var,
465                );
466            }
467        }
468
469        fold_equity_curve.push(prev_equity);
470        for idx in window.test_start..window.test_end {
471            let bar = &feed.bars[idx];
472            let close = bar.close;
473            let now_ms = bar.close_time;
474
475            if close <= f64::EPSILON {
476                fold_equity_curve.push(prev_equity);
477                continue;
478            }
479
480            let regime = if cfg.regime_gate_enabled {
481                regime_detector.update(close, now_ms)
482            } else {
483                default_regime_signal(now_ms)
484            };
485
486            for m in models.values_mut() {
487                m.observe_price(&cfg.symbol, close);
488            }
489            observe_predictor_eval_volatility(&mut vol_state, close, cfg.predictor_ewma_alpha_var);
490            let norm_scale = predictor_eval_scale(&vol_state, cfg.predictor_min_sigma)
491                .max(cfg.predictor_min_sigma);
492
493            let mut selected = ("".to_string(), 0.0f64);
494            for (name, model) in models.iter() {
495                let pred = model.estimate_base(
496                    &cfg.symbol,
497                    cfg.predictor_mu,
498                    cfg.predictor_sigma.max(cfg.predictor_min_sigma),
499                );
500                if pred.sigma <= 0.0 {
501                    continue;
502                }
503                let normalized_alpha = pred.mu / norm_scale;
504                if normalized_alpha.abs() > selected.1.abs() {
505                    selected = (name.clone(), normalized_alpha);
506                }
507            }
508
509            if selected.0.is_empty() || selected.1.abs() < cfg.min_signal_abs {
510                fold_equity_curve.push(prev_equity);
511                continue;
512            }
513
514            let alpha_mu = selected.1;
515            let current_ratio = (position.notional(close)
516                / cfg.order_amount_usdt.max(f64::EPSILON))
517            .clamp(0.0, 1.0);
518            let decision = decide_portfolio_action_from_alpha(
519                &cfg.symbol,
520                now_ms,
521                position.is_flat(),
522                alpha_mu,
523                cfg.order_amount_usdt,
524                regime,
525                RegimeDecisionConfig {
526                    enabled: cfg.regime_gate_enabled,
527                    confidence_min: 0.0,
528                    entry_multiplier_trend_up: 1.0,
529                    entry_multiplier_range: 1.0,
530                    entry_multiplier_trend_down: 1.0,
531                    entry_multiplier_unknown: 1.0,
532                    hold_multiplier_trend_up: 1.0,
533                    hold_multiplier_range: 1.0,
534                    hold_multiplier_trend_down: 1.0,
535                    hold_multiplier_unknown: 1.0,
536                },
537            );
538
539            if decision.target_position_ratio < PORTFOLIO_MIN_ENTRY_RATIO {
540                fold_equity_curve.push(prev_equity);
541                continue;
542            }
543
544            let intent = decision.to_intent("bt", cfg.order_amount_usdt, current_ratio);
545            let signal = intent.effective_signal(PORTFOLIO_REBALANCE_MIN_DELTA);
546            if signal == Signal::Hold {
547                fold_equity_curve.push(prev_equity);
548                continue;
549            }
550
551            let target_qty =
552                (cfg.order_amount_usdt * decision.target_position_ratio) / close.max(f64::EPSILON);
553            let current_qty = position.qty;
554            let delta_qty = target_qty - current_qty;
555            let fill_qty = delta_qty.abs();
556            if fill_qty <= f64::EPSILON {
557                fold_equity_curve.push(prev_equity);
558                continue;
559            }
560
561            let slippage = cfg.slippage_bps / 10_000.0;
562            let fill_price = match signal {
563                Signal::Buy => close * (1.0 + slippage),
564                Signal::Sell => close * (1.0 - slippage),
565                Signal::Hold => close,
566            };
567            let notional = fill_qty * fill_price;
568            let fee = notional * cfg.fee_rate;
569            let pnl_realized = position.apply_fill(signal, fill_qty, fill_price, fee);
570
571            fold_fees += fee;
572            if signal == Signal::Buy {
573                trade_count += 1;
574            } else if signal == Signal::Sell {
575                trade_count += 1;
576                if pnl_realized >= 0.0 {
577                    win_count += 1;
578                } else {
579                    lose_count += 1;
580                }
581            }
582
583            fold_orders.push(BacktestOrderLedgerRow {
584                run_id: run_id.clone(),
585                fold: window.fold,
586                order_index: next_order_index,
587                source: "bt".to_string(),
588                bar_idx: idx,
589                timestamp_ms: now_ms,
590                side: match signal {
591                    Signal::Buy => "BUY".to_string(),
592                    Signal::Sell => "SELL".to_string(),
593                    Signal::Hold => "HOLD".to_string(),
594                },
595                target_ratio: decision.target_position_ratio,
596                current_ratio,
597                qty: fill_qty,
598                price: fill_price,
599                fee_usdt: fee,
600                pnl_realized_usdt: pnl_realized,
601                reason: decision.reason.to_string(),
602            });
603            next_order_index = next_order_index.saturating_add(1);
604
605            position.update_unrealized(close);
606            let equity = cfg.order_amount_usdt + position.total_equity(close);
607            fold_equity_curve.push(equity);
608            prev_equity = equity;
609        }
610
611        let returns: Vec<f64> = fold_equity_curve
612            .windows(2)
613            .filter_map(|w| {
614                if w[0].abs() <= f64::EPSILON {
615                    None
616                } else {
617                    Some((w[1] - w[0]) / w[0])
618                }
619            })
620            .collect();
621
622        let fold_metrics = BacktestMetrics {
623            realized_pnl_usdt: position.realized_pnl,
624            total_fees_usdt: fold_fees,
625            trade_count,
626            win_count,
627            lose_count,
628            max_drawdown: max_drawdown(&fold_equity_curve),
629            sharpe_like: sharpe_like(&returns),
630            end_equity_usdt: fold_equity_curve
631                .last()
632                .copied()
633                .unwrap_or(cfg.order_amount_usdt),
634        };
635
636        fold_results.push(BacktestFoldResult {
637            fold: window.fold,
638            train_bars: window.train_end - window.train_start,
639            test_bars: window.test_end - window.test_start,
640            metrics: fold_metrics.clone(),
641            train_start_timestamp_ms: feed.bars[window.train_start].open_time,
642            train_end_timestamp_ms: feed.bars[window.train_end - 1].open_time,
643            start_timestamp_ms: feed.bars[window.test_start].open_time,
644            end_timestamp_ms: feed.bars[window.test_end - 1].open_time,
645        });
646        run_fold_metrics.push(fold_metrics);
647        order_rows.extend(fold_orders);
648    }
649
650    persist_run_meta(
651        &cfg.strategy_db_path,
652        &run_id,
653        cfg,
654        &feed.symbol,
655        run_started_ms,
656        Utc::now().timestamp_millis() as u64,
657        &fold_results,
658    )?;
659    persist_fold_results(&cfg.strategy_db_path, &run_id, &fold_results)?;
660    persist_order_ledger(&cfg.order_db_path, &order_rows)?;
661
662    let total = summarize_metrics(&run_fold_metrics);
663    Ok(BacktestResult {
664        run_id,
665        symbol: feed.symbol.clone(),
666        total_bars: feed.bars.len(),
667        folds: fold_results,
668        metrics: total,
669        run_started_ms,
670        run_finished_ms: Utc::now().timestamp_millis() as u64,
671    })
672}
673
674pub fn parse_backtest_args(args: &[String]) -> Result<BacktestConfig> {
675    let mut cfg = BacktestConfig::default();
676    let mut i = 0usize;
677    let mut bars_set = false;
678
679    while i < args.len() {
680        match args[i].as_str() {
681            "--symbol" if i + 1 < args.len() => {
682                cfg.symbol = args[i + 1].to_string();
683                i += 2;
684            }
685            "--bars" if i + 1 < args.len() => {
686                cfg.bars_csv = PathBuf::from(&args[i + 1]);
687                bars_set = true;
688                i += 2;
689            }
690            "--strategy-db" if i + 1 < args.len() => {
691                cfg.strategy_db_path = PathBuf::from(&args[i + 1]);
692                i += 2;
693            }
694            "--order-db" if i + 1 < args.len() => {
695                cfg.order_db_path = PathBuf::from(&args[i + 1]);
696                i += 2;
697            }
698            "--order-usdt" if i + 1 < args.len() => {
699                cfg.order_amount_usdt = args[i + 1].parse()?;
700                i += 2;
701            }
702            "--fee-rate" if i + 1 < args.len() => {
703                cfg.fee_rate = args[i + 1].parse()?;
704                i += 2;
705            }
706            "--slippage-bps" if i + 1 < args.len() => {
707                cfg.slippage_bps = args[i + 1].parse()?;
708                i += 2;
709            }
710            "--train-bars" if i + 1 < args.len() => {
711                cfg.train_window = args[i + 1].parse()?;
712                i += 2;
713            }
714            "--test-bars" if i + 1 < args.len() => {
715                cfg.test_window = args[i + 1].parse()?;
716                i += 2;
717            }
718            "--embargo-bars" if i + 1 < args.len() => {
719                cfg.embargo_window = args[i + 1].parse()?;
720                i += 2;
721            }
722            "--max-folds" if i + 1 < args.len() => {
723                cfg.max_folds = args[i + 1].parse()?;
724                i += 2;
725            }
726            "--min-signal" if i + 1 < args.len() => {
727                cfg.min_signal_abs = args[i + 1].parse()?;
728                i += 2;
729            }
730            "--regime-gate" => {
731                cfg.regime_gate_enabled = true;
732                i += 1;
733            }
734            "--help" => {
735                return Err(anyhow!(print_backtest_usage()));
736            }
737            _ => {
738                return Err(anyhow!("unknown arg '{}'", args[i]));
739            }
740        }
741    }
742
743    if !cfg.bars_csv.exists() {
744        if !bars_set {
745            return Err(anyhow!("--bars is required"));
746        }
747        return Err(anyhow!("bars file not found: {}", cfg.bars_csv.display()));
748    }
749
750    if cfg.train_window == 0 || cfg.test_window == 0 {
751        return Err(anyhow!("train/test windows must be > 0"));
752    }
753    if cfg.max_folds == 0 {
754        return Err(anyhow!("max-folds must be > 0"));
755    }
756
757    Ok(cfg)
758}
759
/// Returns the multi-line usage text for the backtest command (no trailing newline).
pub fn print_backtest_usage() -> String {
    const USAGE: &str = concat!(
        "USAGE: backtest [--symbol SYMBOL] --bars FILE [--strategy-db PATH] [--order-db PATH]",
        "\n",
        "               [--order-usdt AMOUNT] [--fee-rate RATE] [--slippage-bps BPS]",
        "\n",
        "               [--train-bars N] [--test-bars N] [--embargo-bars N]",
        "\n",
        "               [--max-folds N] [--min-signal N] [--regime-gate]",
    );
    USAGE.to_string()
}
769
770fn summarize_metrics(folds: &[BacktestMetrics]) -> BacktestMetrics {
771    let mut realized = 0.0;
772    let mut fees = 0.0;
773    let mut trade_count = 0u64;
774    let mut win_count = 0u64;
775    let mut lose_count = 0u64;
776    let mut max_dd = 0.0;
777    let mut sharpe_sum = 0.0;
778    let mut sharpe_count = 0u64;
779    let mut end_equity = 0.0;
780
781    for fold in folds {
782        realized += fold.realized_pnl_usdt;
783        fees += fold.total_fees_usdt;
784        trade_count += fold.trade_count;
785        win_count += fold.win_count;
786        lose_count += fold.lose_count;
787        if fold.max_drawdown > max_dd {
788            max_dd = fold.max_drawdown;
789        }
790        if fold.sharpe_like.is_finite() {
791            sharpe_sum += fold.sharpe_like;
792            sharpe_count += 1;
793        }
794        if folds.last().is_some_and(|last| std::ptr::eq(last, fold)) {
795            end_equity = fold.end_equity_usdt;
796        }
797    }
798
799    BacktestMetrics {
800        realized_pnl_usdt: realized,
801        total_fees_usdt: fees,
802        trade_count,
803        win_count,
804        lose_count,
805        max_drawdown: max_dd,
806        sharpe_like: if sharpe_count == 0 {
807            0.0
808        } else {
809            sharpe_sum / sharpe_count as f64
810        },
811        end_equity_usdt: end_equity,
812    }
813}
814
/// Largest peak-to-trough decline of `equity`, as a fraction of the running peak.
///
/// Only positive peaks are considered. Returns 0 for an empty, non-decreasing or never-positive
/// curve.
fn max_drawdown(equity: &[f64]) -> f64 {
    let (_peak, worst) = equity.iter().fold(
        (f64::NEG_INFINITY, 0.0_f64),
        |(peak, worst), &value| {
            let peak = if value > peak { value } else { peak };
            if !(peak > 0.0 && value < peak) {
                return (peak, worst);
            }
            let drawdown = (peak - value) / peak;
            (peak, if drawdown > worst { drawdown } else { worst })
        },
    );
    worst
}
832
/// Ratio of the mean return to the population standard deviation of `returns`.
///
/// Not annualized and without a risk-free rate. Returns 0 for fewer than two returns or a
/// standard deviation at or below `f64::EPSILON`.
fn sharpe_like(returns: &[f64]) -> f64 {
    let n = returns.len();
    if n < 2 {
        return 0.0;
    }
    let count = n as f64;
    let mean = returns.iter().sum::<f64>() / count;
    let variance = returns
        .iter()
        .map(|r| (r - mean).powi(2))
        .sum::<f64>()
        / count;
    match variance.sqrt() {
        sd if sd <= f64::EPSILON => 0.0,
        sd => mean / sd,
    }
}
846
847fn ensure_strategy_db(path: &Path) -> Result<()> {
848    if let Some(parent) = path.parent() {
849        std::fs::create_dir_all(parent)?;
850    }
851
852    let conn = Connection::open(path)?;
853    conn.execute_batch(
854        r#"
855        CREATE TABLE IF NOT EXISTS backtest_runs (
856            run_id TEXT PRIMARY KEY,
857            symbol TEXT NOT NULL,
858            started_at_ms INTEGER NOT NULL,
859            finished_at_ms INTEGER NOT NULL,
860            config_json TEXT NOT NULL,
861            total_bars INTEGER NOT NULL,
862            folds INTEGER NOT NULL,
863            created_at_ms INTEGER NOT NULL
864        );
865
866        CREATE TABLE IF NOT EXISTS backtest_fold_results (
867            run_id TEXT NOT NULL,
868            fold_idx INTEGER NOT NULL,
869            train_bars INTEGER NOT NULL,
870            test_bars INTEGER NOT NULL,
871            realized_pnl_usdt REAL NOT NULL,
872            total_fees_usdt REAL NOT NULL,
873            trade_count INTEGER NOT NULL,
874            win_count INTEGER NOT NULL,
875            lose_count INTEGER NOT NULL,
876            max_drawdown REAL NOT NULL,
877            sharpe_like REAL NOT NULL,
878            end_equity_usdt REAL NOT NULL,
879            train_start_ms INTEGER NOT NULL,
880            train_end_ms INTEGER NOT NULL,
881            test_start_ms INTEGER NOT NULL,
882            test_end_ms INTEGER NOT NULL,
883            PRIMARY KEY(run_id, fold_idx)
884        );
885        "#,
886    )?;
887    Ok(())
888}
889
890fn ensure_order_db(path: &Path) -> Result<()> {
891    if let Some(parent) = path.parent() {
892        std::fs::create_dir_all(parent)?;
893    }
894
895    let conn = Connection::open(path)?;
896    conn.execute_batch(
897        r#"
898        CREATE TABLE IF NOT EXISTS backtest_orders (
899            run_id TEXT NOT NULL,
900            fold INTEGER NOT NULL,
901            order_index INTEGER NOT NULL,
902            source TEXT NOT NULL,
903            bar_idx INTEGER NOT NULL,
904            timestamp_ms INTEGER NOT NULL,
905            side TEXT NOT NULL,
906            target_ratio REAL NOT NULL,
907            current_ratio REAL NOT NULL,
908            qty REAL NOT NULL,
909            price REAL NOT NULL,
910            fee_usdt REAL NOT NULL,
911            pnl_realized_usdt REAL NOT NULL,
912            reason TEXT NOT NULL,
913            created_at_ms INTEGER NOT NULL,
914            PRIMARY KEY(run_id, fold, order_index)
915        );
916        "#,
917    )?;
918    Ok(())
919}
920
921fn persist_run_meta(
922    path: &Path,
923    run_id: &str,
924    cfg: &BacktestConfig,
925    symbol: &str,
926    started_ms: u64,
927    finished_ms: u64,
928    folds: &[BacktestFoldResult],
929) -> Result<()> {
930    let conn = Connection::open(path)?;
931    let cfg_json = serde_json::json!({
932        "symbol": cfg.symbol,
933        "bars_csv": cfg.bars_csv,
934        "order_amount_usdt": cfg.order_amount_usdt,
935        "fee_rate": cfg.fee_rate,
936        "slippage_bps": cfg.slippage_bps,
937        "train_window": cfg.train_window,
938        "test_window": cfg.test_window,
939        "embargo_window": cfg.embargo_window,
940        "max_folds": cfg.max_folds,
941        "min_signal_abs": cfg.min_signal_abs,
942        "regime_gate_enabled": cfg.regime_gate_enabled,
943    });
944
945    conn.execute(
946        "INSERT OR REPLACE INTO backtest_runs (
947            run_id, symbol, started_at_ms, finished_at_ms, config_json, total_bars, folds, created_at_ms
948        ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
949        params![
950            run_id,
951            symbol,
952            started_ms as i64,
953            finished_ms as i64,
954            cfg_json.to_string(),
955            folds.iter().map(|f| f.train_bars + f.test_bars).sum::<usize>() as i64,
956            folds.len() as i64,
957            Utc::now().timestamp_millis() as i64,
958        ],
959    )?;
960    Ok(())
961}
962
963fn persist_fold_results(path: &Path, run_id: &str, folds: &[BacktestFoldResult]) -> Result<()> {
964    let mut conn = Connection::open(path)?;
965    let tx = conn.transaction()?;
966    for fold in folds {
967        tx.execute(
968            "INSERT INTO backtest_fold_results (
969                run_id, fold_idx, train_bars, test_bars, realized_pnl_usdt, total_fees_usdt,
970                trade_count, win_count, lose_count, max_drawdown, sharpe_like, end_equity_usdt,
971                train_start_ms, train_end_ms, test_start_ms, test_end_ms
972            ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16)",
973            params![
974                run_id,
975                fold.fold as i64,
976                fold.train_bars as i64,
977                fold.test_bars as i64,
978                fold.metrics.realized_pnl_usdt,
979                fold.metrics.total_fees_usdt,
980                fold.metrics.trade_count as i64,
981                fold.metrics.win_count as i64,
982                fold.metrics.lose_count as i64,
983                fold.metrics.max_drawdown,
984                fold.metrics.sharpe_like,
985                fold.metrics.end_equity_usdt,
986                fold.train_start_timestamp_ms as i64,
987                fold.train_end_timestamp_ms as i64,
988                fold.start_timestamp_ms as i64,
989                fold.end_timestamp_ms as i64,
990            ],
991        )?;
992    }
993    tx.commit()?;
994    Ok(())
995}
996
997fn persist_order_ledger(path: &Path, rows: &[BacktestOrderLedgerRow]) -> Result<()> {
998    let mut conn = Connection::open(path)?;
999    let tx = conn.transaction()?;
1000    for row in rows {
1001        tx.execute(
1002            "INSERT OR REPLACE INTO backtest_orders (
1003                run_id, fold, order_index, source, bar_idx, timestamp_ms, side,
1004                target_ratio, current_ratio, qty, price, fee_usdt, pnl_realized_usdt, reason, created_at_ms
1005            ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15)",
1006            params![
1007                row.run_id,
1008                row.fold as i64,
1009                row.order_index as i64,
1010                row.source,
1011                row.bar_idx as i64,
1012                row.timestamp_ms as i64,
1013                row.side,
1014                row.target_ratio,
1015                row.current_ratio,
1016                row.qty,
1017                row.price,
1018                row.fee_usdt,
1019                row.pnl_realized_usdt,
1020                row.reason,
1021                Utc::now().timestamp_millis() as i64,
1022            ],
1023        )?;
1024    }
1025    tx.commit()?;
1026    Ok(())
1027}