1use std::collections::HashMap;
2use std::fs::File;
3use std::io::{BufRead, BufReader};
4use std::path::{Path, PathBuf};
5
6use anyhow::{anyhow, Context, Result};
7use chrono::Utc;
8use rusqlite::{params, Connection};
9
10use crate::event::{MarketRegime, MarketRegimeSignal};
11use crate::model::candle::Candle;
12use crate::model::signal::Signal;
13use crate::predictor::{
14 build_predictor_models, default_predictor_specs, PredictorBaseConfig, PredictorModel,
15};
16use crate::runtime::alpha_portfolio::{decide_portfolio_action_from_alpha, RegimeDecisionConfig};
17use crate::runtime::predictor_eval::{
18 observe_predictor_eval_volatility, predictor_eval_scale, PredictorEvalVolState,
19};
20use crate::runtime::regime::{RegimeDetector, RegimeDetectorConfig};
21
/// Default `BacktestConfig::order_amount_usdt`: notional (USDT) of a full-size (ratio 1.0)
/// position, which the simulator also uses as each fold's starting equity.
const DEFAULT_BACKTEST_TRADE_NOTIONAL_USDT: f64 = 100.0;
/// Default proportional fee charged on each fill's notional (0.1%).
const DEFAULT_BACKTEST_FEE_RATE: f64 = 1e-3;
/// Default slippage in basis points (1 bp = 0.01%), applied against the trader on each fill.
const DEFAULT_BACKTEST_SLIPPAGE_BPS: f64 = 2.0;
/// Default minimum |normalized alpha| required before acting on a bar (0 = no filtering).
const DEFAULT_BACKTEST_MIN_SIGNAL_STRENGTH: f64 = 0.0;

/// Default number of bars used to warm up the predictors in each fold.
const DEFAULT_TRAIN_WINDOW: usize = 420;
/// Default number of bars traded in each fold.
const DEFAULT_TEST_WINDOW: usize = 120;
/// Default number of bars skipped between a fold's training and test ranges.
const DEFAULT_EMBARGO_WINDOW: usize = 8;
/// Default upper bound on the number of walk-forward folds.
const DEFAULT_MAX_FOLDS: usize = 5;

/// Minimum rebalance delta handed to `effective_signal`. NOTE(review): presumably the smallest
/// |target - current| position-ratio change that yields a Buy/Sell — confirm in the intent code.
const PORTFOLIO_REBALANCE_MIN_DELTA: f64 = 0.05;
/// Minimum target position ratio the backtest loop compares decisions against before trading.
const PORTFOLIO_MIN_ENTRY_RATIO: f64 = 0.15;
32
33#[derive(Debug, Clone)]
34pub struct BacktestConfig {
35 pub symbol: String,
36 pub bars_csv: PathBuf,
37 pub strategy_db_path: PathBuf,
38 pub order_db_path: PathBuf,
39 pub order_amount_usdt: f64,
40 pub fee_rate: f64,
41 pub slippage_bps: f64,
42 pub train_window: usize,
43 pub test_window: usize,
44 pub embargo_window: usize,
45 pub max_folds: usize,
46 pub min_signal_abs: f64,
47 pub regime_gate_enabled: bool,
48 pub predictor_ewma_alpha_mean: f64,
49 pub predictor_ewma_alpha_var: f64,
50 pub predictor_min_sigma: f64,
51 pub predictor_mu: f64,
52 pub predictor_sigma: f64,
53 pub run_seed: u64,
54}
55
56impl Default for BacktestConfig {
57 fn default() -> Self {
58 Self {
59 symbol: "BTCUSDT".to_string(),
60 bars_csv: PathBuf::from("data/demo_market_backlog.csv"),
61 strategy_db_path: PathBuf::from("data/backtest_strategy.sqlite"),
62 order_db_path: PathBuf::from("data/backtest_orders.sqlite"),
63 order_amount_usdt: DEFAULT_BACKTEST_TRADE_NOTIONAL_USDT,
64 fee_rate: DEFAULT_BACKTEST_FEE_RATE,
65 slippage_bps: DEFAULT_BACKTEST_SLIPPAGE_BPS,
66 train_window: DEFAULT_TRAIN_WINDOW,
67 test_window: DEFAULT_TEST_WINDOW,
68 embargo_window: DEFAULT_EMBARGO_WINDOW,
69 max_folds: DEFAULT_MAX_FOLDS,
70 min_signal_abs: DEFAULT_BACKTEST_MIN_SIGNAL_STRENGTH,
71 regime_gate_enabled: false,
72 predictor_ewma_alpha_mean: 0.18,
73 predictor_ewma_alpha_var: 0.18,
74 predictor_min_sigma: 0.001,
75 predictor_mu: 0.0,
76 predictor_sigma: 0.01,
77 run_seed: 0,
78 }
79 }
80}
81
/// Bar-index ranges for one walk-forward fold. Ranges are half-open (`start..end`) indices into
/// the bar series, as produced by `build_walk_forward_windows`.
#[derive(Debug, Clone)]
pub struct WalkWindow {
    /// Zero-based fold number.
    pub fold: usize,
    /// First training bar (inclusive).
    pub train_start: usize,
    /// End of the training range (exclusive).
    pub train_end: usize,
    /// First test bar (inclusive); bars between `train_end` and here are the embargo gap.
    pub test_start: usize,
    /// End of the test range (exclusive).
    pub test_end: usize,
}
90
/// Performance summary for one fold, or for a whole run when aggregated by `summarize_metrics`.
#[derive(Debug, Clone)]
pub struct BacktestMetrics {
    /// PnL realized by sells; buy fees are part of the cost basis and sell fees are subtracted.
    pub realized_pnl_usdt: f64,
    /// Sum of the fees charged on all fills.
    pub total_fees_usdt: f64,
    /// Number of executed buy and sell fills.
    pub trade_count: u64,
    /// Sell fills whose realized PnL was >= 0.
    pub win_count: u64,
    /// Sell fills whose realized PnL was < 0.
    pub lose_count: u64,
    /// Largest peak-to-trough equity decline as a fraction of the peak (0.25 = 25%).
    pub max_drawdown: f64,
    /// Mean / population standard deviation of per-bar equity returns; not annualized.
    pub sharpe_like: f64,
    /// Equity at the end of the test range (for an aggregate: the last fold's value).
    pub end_equity_usdt: f64,
}
102
/// Result of simulating one walk-forward fold.
#[derive(Debug, Clone)]
pub struct BacktestFoldResult {
    /// Fold number (matches `WalkWindow::fold`).
    pub fold: usize,
    /// Number of bars in the training (warm-up) range.
    pub train_bars: usize,
    /// Number of bars in the test range.
    pub test_bars: usize,
    /// Metrics computed over the test range.
    pub metrics: BacktestMetrics,
    /// `open_time` of the first training bar.
    pub train_start_timestamp_ms: u64,
    /// `open_time` of the last training bar (inclusive, not the exclusive range end).
    pub train_end_timestamp_ms: u64,
    /// `open_time` of the first test bar.
    pub start_timestamp_ms: u64,
    /// `open_time` of the last test bar (inclusive).
    pub end_timestamp_ms: u64,
}
114
/// Outcome of `run_walk_forward_backtest`.
#[derive(Debug, Clone)]
pub struct BacktestResult {
    /// Run identifier, formatted as `bt-{symbol}-{run_started_ms}`.
    pub run_id: String,
    /// Symbol of the candle feed that was backtested.
    pub symbol: String,
    /// Number of bars in the input feed (all of them, not only those covered by folds).
    pub total_bars: usize,
    /// Per-fold results in fold order.
    pub folds: Vec<BacktestFoldResult>,
    /// Aggregate metrics across folds.
    pub metrics: BacktestMetrics,
    /// Wall-clock start time, Unix epoch milliseconds.
    pub run_started_ms: u64,
    /// Wall-clock finish time, Unix epoch milliseconds.
    pub run_finished_ms: u64,
}
125
/// One simulated fill, persisted to the `backtest_orders` table.
#[derive(Debug, Clone)]
pub struct BacktestOrderLedgerRow {
    /// Run that produced the fill.
    pub run_id: String,
    /// Fold the fill belongs to.
    pub fold: usize,
    /// Sequence number; increases across the whole run (not reset per fold).
    pub order_index: u64,
    /// Source tag (the simulator writes `"bt"`).
    pub source: String,
    /// Index into the feed's bars of the bar that triggered the fill.
    pub bar_idx: usize,
    /// The triggering bar's `close_time`.
    pub timestamp_ms: u64,
    /// `"BUY"` or `"SELL"` (`"HOLD"` is representable but the simulator skips Hold signals).
    pub side: String,
    /// Target position as a fraction of `order_amount_usdt`.
    pub target_ratio: f64,
    /// Position notional / `order_amount_usdt` before the fill, clamped to [0, 1].
    pub current_ratio: f64,
    /// Filled quantity (base units, always positive).
    pub qty: f64,
    /// Fill price including slippage.
    pub price: f64,
    /// Fee charged on this fill.
    pub fee_usdt: f64,
    /// PnL realized by this fill (0 for buys).
    pub pnl_realized_usdt: f64,
    /// Reason string reported by the portfolio decision.
    pub reason: String,
}
143
/// A parsed candle series for one symbol.
#[derive(Debug, Clone)]
pub struct CandleFeed {
    /// Symbol the bars belong to.
    pub symbol: String,
    /// Bar spacing inferred as the median positive gap between consecutive `open_time`s.
    pub interval_ms: u64,
    /// Bars in file order.
    pub bars: Vec<Candle>,
}
150
151#[derive(Debug)]
152struct BacktestPosition {
153 qty: f64,
154 entry_price: f64,
155 cost_quote: f64,
156 realized_pnl: f64,
157 unrealized_pnl: f64,
158}
159
160impl BacktestPosition {
161 fn new() -> Self {
162 Self {
163 qty: 0.0,
164 entry_price: 0.0,
165 cost_quote: 0.0,
166 realized_pnl: 0.0,
167 unrealized_pnl: 0.0,
168 }
169 }
170
171 fn is_flat(&self) -> bool {
172 self.qty <= f64::EPSILON
173 }
174
175 fn notional(&self, price: f64) -> f64 {
176 self.qty * price
177 }
178
179 fn apply_fill(&mut self, side: Signal, qty: f64, price: f64, fee: f64) -> f64 {
180 match side {
181 Signal::Buy => {
182 if qty <= f64::EPSILON {
183 return 0.0;
184 }
185
186 self.qty += qty;
187 self.cost_quote += qty * price + fee;
188 self.entry_price = self.cost_quote / self.qty.max(f64::EPSILON);
189 0.0
190 }
191 Signal::Sell => {
192 if self.qty <= f64::EPSILON || qty <= f64::EPSILON {
193 return 0.0;
194 }
195
196 let close_qty = qty.min(self.qty);
197 let avg_cost = self.cost_quote / self.qty.max(f64::EPSILON);
198 let pnl = close_qty * price - close_qty * avg_cost - fee;
199 self.realized_pnl += pnl;
200 self.qty -= close_qty;
201 self.cost_quote -= close_qty * avg_cost;
202
203 if self.qty <= f64::EPSILON {
204 self.qty = 0.0;
205 self.entry_price = 0.0;
206 self.cost_quote = 0.0;
207 } else {
208 self.entry_price = self.cost_quote / self.qty.max(f64::EPSILON);
209 }
210 pnl
211 }
212 Signal::Hold => 0.0,
213 }
214 }
215
216 fn update_unrealized(&mut self, price: f64) {
217 if self.is_flat() {
218 self.unrealized_pnl = 0.0;
219 } else {
220 self.unrealized_pnl = (price - self.entry_price) * self.qty;
221 }
222 }
223
224 fn total_equity(&mut self, close: f64) -> f64 {
225 self.update_unrealized(close);
226 self.realized_pnl + self.unrealized_pnl
227 }
228}
229
230pub fn infer_interval_ms(bars: &[Candle]) -> u64 {
231 if bars.len() < 2 {
232 return 60_000;
233 }
234 let mut diffs: Vec<u64> = bars
235 .windows(2)
236 .filter_map(|w| w[1].open_time.checked_sub(w[0].open_time))
237 .filter(|d| *d > 0)
238 .collect();
239 if diffs.is_empty() {
240 return 60_000;
241 }
242 diffs.sort_unstable();
243 diffs[diffs.len() / 2]
244}
245
246pub fn parse_candle_csv(symbol: &str, path: &Path) -> Result<CandleFeed> {
247 let file = File::open(path).with_context(|| format!("open candle csv: {}", path.display()))?;
248 let reader = BufReader::new(file);
249
250 let mut bars = Vec::new();
251 let mut idx_ts = 0usize;
252 let mut idx_open = 1usize;
253 let mut idx_high = 2usize;
254 let mut idx_low = 3usize;
255 let mut idx_close = 4usize;
256 let mut line_no = 0usize;
257 let mut has_header = false;
258
259 let resolve_idx = |header: &[String], names: &[&str]| -> Option<usize> {
260 names
261 .iter()
262 .find_map(|name| header.iter().position(|h| h == name))
263 };
264
265 for raw in reader.lines() {
266 let line = raw?;
267 let trimmed = line.trim();
268 if trimmed.is_empty() {
269 line_no += 1;
270 continue;
271 }
272
273 let fields: Vec<&str> = trimmed.split(',').map(|v| v.trim()).collect();
274
275 if !has_header {
276 let is_header = fields.first().is_some_and(|f| f.parse::<f64>().is_err());
277
278 if is_header {
279 let h = fields
280 .iter()
281 .map(|v| v.to_ascii_lowercase())
282 .collect::<Vec<_>>();
283 idx_ts = resolve_idx(&h, &["open_time", "timestamp_ms", "time"])
284 .or_else(|| resolve_idx(&h, &["ts"]))
285 .unwrap_or(0);
286 idx_open = resolve_idx(&h, &["open"]).unwrap_or(1);
287 idx_high = resolve_idx(&h, &["high"]).unwrap_or(2);
288 idx_low = resolve_idx(&h, &["low"]).unwrap_or(3);
289 idx_close = resolve_idx(&h, &["close"]).unwrap_or(4);
290 has_header = true;
291 line_no += 1;
292 continue;
293 }
294 has_header = true;
295 }
296
297 if fields.len() <= idx_close {
298 return Err(anyhow!("invalid row #{}: too few fields", line_no + 1));
299 }
300
301 let open_time = parse_u64(fields[idx_ts], line_no)?;
302 let open = parse_f64(fields[idx_open], "open", line_no)?;
303 let high = parse_f64(fields[idx_high], "high", line_no)?;
304 let low = parse_f64(fields[idx_low], "low", line_no)?;
305 let close = parse_f64(fields[idx_close], "close", line_no)?;
306 if open <= 0.0 || high <= 0.0 || low <= 0.0 || close <= 0.0 {
307 return Err(anyhow!("invalid row #{}: non-positive price", line_no + 1));
308 }
309
310 bars.push(Candle {
311 open,
312 high,
313 low,
314 close,
315 open_time,
316 close_time: 0,
317 });
318 line_no += 1;
319 }
320
321 if bars.len() < 2 {
322 return Err(anyhow!("need at least 2 valid rows, got {}", bars.len()));
323 }
324
325 let interval_ms = infer_interval_ms(&bars);
326 for i in 0..bars.len().saturating_sub(1) {
327 bars[i].close_time = bars[i + 1].open_time;
328 }
329 let last_close_time = bars
330 .last()
331 .map(|c| c.open_time + interval_ms)
332 .unwrap_or_default();
333 if let Some(last) = bars.last_mut() {
334 last.close_time = last_close_time;
335 }
336
337 Ok(CandleFeed {
338 symbol: symbol.to_string(),
339 interval_ms,
340 bars,
341 })
342}
343
344fn parse_u64(value: &str, line_no: usize) -> Result<u64> {
345 value
346 .parse::<u64>()
347 .or_else(|_| value.parse::<f64>().map(|v| v.max(0.0) as u64))
348 .with_context(|| format!("line {}: parse timestamp", line_no + 1))
349}
350
351fn parse_f64(value: &str, field: &str, line_no: usize) -> Result<f64> {
352 value
353 .parse::<f64>()
354 .with_context(|| format!("line {}: parse {} {}", line_no + 1, field, value))
355}
356
357pub fn build_walk_forward_windows(total_bars: usize, cfg: &BacktestConfig) -> Vec<WalkWindow> {
358 if total_bars < cfg.train_window + cfg.test_window + cfg.embargo_window
359 || cfg.train_window == 0
360 || cfg.test_window == 0
361 {
362 return Vec::new();
363 }
364
365 let mut folds = Vec::new();
366 let mut train_start = 0usize;
367
368 for fold_idx in 0..cfg.max_folds {
369 let train_end = train_start + cfg.train_window;
370 let test_start = train_end + cfg.embargo_window;
371 let test_end = test_start + cfg.test_window;
372
373 if test_end > total_bars {
374 break;
375 }
376
377 folds.push(WalkWindow {
378 fold: fold_idx,
379 train_start,
380 train_end,
381 test_start,
382 test_end,
383 });
384
385 train_start = test_end;
386 }
387
388 folds
389}
390
391fn default_regime_signal(now_ms: u64) -> MarketRegimeSignal {
392 MarketRegimeSignal {
393 regime: MarketRegime::TrendUp,
394 confidence: 1.0,
395 ema_fast: 0.0,
396 ema_slow: 0.0,
397 vol_ratio: 0.0,
398 slope: 0.0,
399 updated_at_ms: now_ms,
400 }
401}
402
403fn build_predictor_specs(cfg: &BacktestConfig) -> Vec<(String, crate::predictor::PredictorConfig)> {
404 let base_cfg = PredictorBaseConfig {
405 alpha_mean: cfg.predictor_ewma_alpha_mean,
406 alpha_var: cfg.predictor_ewma_alpha_var,
407 min_sigma: cfg.predictor_min_sigma,
408 };
409 default_predictor_specs(base_cfg)
410}
411
412pub fn run_walk_forward_backtest(
413 cfg: &BacktestConfig,
414 feed: &CandleFeed,
415) -> Result<BacktestResult> {
416 let required = cfg.train_window + cfg.test_window + cfg.embargo_window;
417 if feed.bars.len() < required {
418 return Err(anyhow!(
419 "insufficient bars: got {}, need at least {}",
420 feed.bars.len(),
421 required
422 ));
423 }
424
425 let run_started_ms = Utc::now().timestamp_millis() as u64;
426 let run_id = format!("bt-{}-{}", cfg.symbol, run_started_ms);
427 let windows = build_walk_forward_windows(feed.bars.len(), cfg);
428 if windows.is_empty() {
429 return Err(anyhow!("no walk-forward folds available"));
430 }
431
432 ensure_strategy_db(&cfg.strategy_db_path)?;
433 ensure_order_db(&cfg.order_db_path)?;
434
435 let mut order_rows: Vec<BacktestOrderLedgerRow> = Vec::new();
436 let mut fold_results: Vec<BacktestFoldResult> = Vec::new();
437 let mut run_fold_metrics = Vec::new();
438 let mut next_order_index = 0u64;
439
440 for window in windows.iter() {
441 let mut models: HashMap<String, PredictorModel> =
442 build_predictor_models(&build_predictor_specs(cfg));
443 let mut vol_state = PredictorEvalVolState::default();
444 let mut regime_detector = RegimeDetector::new(RegimeDetectorConfig::default());
445 let mut position = BacktestPosition::new();
446
447 let mut fold_fees = 0.0;
448 let mut win_count = 0u64;
449 let mut lose_count = 0u64;
450 let mut trade_count = 0u64;
451 let mut fold_orders = Vec::new();
452 let mut fold_equity_curve = Vec::new();
453 let mut prev_equity = cfg.order_amount_usdt;
454
455 for idx in window.train_start..window.train_end {
456 let close = feed.bars[idx].close;
457 if close > f64::EPSILON {
458 for m in models.values_mut() {
459 m.observe_price(&cfg.symbol, close);
460 }
461 observe_predictor_eval_volatility(
462 &mut vol_state,
463 close,
464 cfg.predictor_ewma_alpha_var,
465 );
466 }
467 }
468
469 fold_equity_curve.push(prev_equity);
470 for idx in window.test_start..window.test_end {
471 let bar = &feed.bars[idx];
472 let close = bar.close;
473 let now_ms = bar.close_time;
474
475 if close <= f64::EPSILON {
476 fold_equity_curve.push(prev_equity);
477 continue;
478 }
479
480 let regime = if cfg.regime_gate_enabled {
481 regime_detector.update(close, now_ms)
482 } else {
483 default_regime_signal(now_ms)
484 };
485
486 for m in models.values_mut() {
487 m.observe_price(&cfg.symbol, close);
488 }
489 observe_predictor_eval_volatility(&mut vol_state, close, cfg.predictor_ewma_alpha_var);
490 let norm_scale = predictor_eval_scale(&vol_state, cfg.predictor_min_sigma)
491 .max(cfg.predictor_min_sigma);
492
493 let mut selected = ("".to_string(), 0.0f64);
494 for (name, model) in models.iter() {
495 let pred = model.estimate_base(
496 &cfg.symbol,
497 cfg.predictor_mu,
498 cfg.predictor_sigma.max(cfg.predictor_min_sigma),
499 );
500 if pred.sigma <= 0.0 {
501 continue;
502 }
503 let normalized_alpha = pred.mu / norm_scale;
504 if normalized_alpha.abs() > selected.1.abs() {
505 selected = (name.clone(), normalized_alpha);
506 }
507 }
508
509 if selected.0.is_empty() || selected.1.abs() < cfg.min_signal_abs {
510 fold_equity_curve.push(prev_equity);
511 continue;
512 }
513
514 let alpha_mu = selected.1;
515 let current_ratio = (position.notional(close)
516 / cfg.order_amount_usdt.max(f64::EPSILON))
517 .clamp(0.0, 1.0);
518 let decision = decide_portfolio_action_from_alpha(
519 &cfg.symbol,
520 now_ms,
521 position.is_flat(),
522 alpha_mu,
523 cfg.order_amount_usdt,
524 regime,
525 RegimeDecisionConfig {
526 enabled: cfg.regime_gate_enabled,
527 confidence_min: 0.0,
528 entry_multiplier_trend_up: 1.0,
529 entry_multiplier_range: 1.0,
530 entry_multiplier_trend_down: 1.0,
531 entry_multiplier_unknown: 1.0,
532 hold_multiplier_trend_up: 1.0,
533 hold_multiplier_range: 1.0,
534 hold_multiplier_trend_down: 1.0,
535 hold_multiplier_unknown: 1.0,
536 },
537 );
538
539 if decision.target_position_ratio < PORTFOLIO_MIN_ENTRY_RATIO {
540 fold_equity_curve.push(prev_equity);
541 continue;
542 }
543
544 let intent = decision.to_intent("bt", cfg.order_amount_usdt, current_ratio);
545 let signal = intent.effective_signal(PORTFOLIO_REBALANCE_MIN_DELTA);
546 if signal == Signal::Hold {
547 fold_equity_curve.push(prev_equity);
548 continue;
549 }
550
551 let target_qty =
552 (cfg.order_amount_usdt * decision.target_position_ratio) / close.max(f64::EPSILON);
553 let current_qty = position.qty;
554 let delta_qty = target_qty - current_qty;
555 let fill_qty = delta_qty.abs();
556 if fill_qty <= f64::EPSILON {
557 fold_equity_curve.push(prev_equity);
558 continue;
559 }
560
561 let slippage = cfg.slippage_bps / 10_000.0;
562 let fill_price = match signal {
563 Signal::Buy => close * (1.0 + slippage),
564 Signal::Sell => close * (1.0 - slippage),
565 Signal::Hold => close,
566 };
567 let notional = fill_qty * fill_price;
568 let fee = notional * cfg.fee_rate;
569 let pnl_realized = position.apply_fill(signal, fill_qty, fill_price, fee);
570
571 fold_fees += fee;
572 if signal == Signal::Buy {
573 trade_count += 1;
574 } else if signal == Signal::Sell {
575 trade_count += 1;
576 if pnl_realized >= 0.0 {
577 win_count += 1;
578 } else {
579 lose_count += 1;
580 }
581 }
582
583 fold_orders.push(BacktestOrderLedgerRow {
584 run_id: run_id.clone(),
585 fold: window.fold,
586 order_index: next_order_index,
587 source: "bt".to_string(),
588 bar_idx: idx,
589 timestamp_ms: now_ms,
590 side: match signal {
591 Signal::Buy => "BUY".to_string(),
592 Signal::Sell => "SELL".to_string(),
593 Signal::Hold => "HOLD".to_string(),
594 },
595 target_ratio: decision.target_position_ratio,
596 current_ratio,
597 qty: fill_qty,
598 price: fill_price,
599 fee_usdt: fee,
600 pnl_realized_usdt: pnl_realized,
601 reason: decision.reason.to_string(),
602 });
603 next_order_index = next_order_index.saturating_add(1);
604
605 position.update_unrealized(close);
606 let equity = cfg.order_amount_usdt + position.total_equity(close);
607 fold_equity_curve.push(equity);
608 prev_equity = equity;
609 }
610
611 let returns: Vec<f64> = fold_equity_curve
612 .windows(2)
613 .filter_map(|w| {
614 if w[0].abs() <= f64::EPSILON {
615 None
616 } else {
617 Some((w[1] - w[0]) / w[0])
618 }
619 })
620 .collect();
621
622 let fold_metrics = BacktestMetrics {
623 realized_pnl_usdt: position.realized_pnl,
624 total_fees_usdt: fold_fees,
625 trade_count,
626 win_count,
627 lose_count,
628 max_drawdown: max_drawdown(&fold_equity_curve),
629 sharpe_like: sharpe_like(&returns),
630 end_equity_usdt: fold_equity_curve
631 .last()
632 .copied()
633 .unwrap_or(cfg.order_amount_usdt),
634 };
635
636 fold_results.push(BacktestFoldResult {
637 fold: window.fold,
638 train_bars: window.train_end - window.train_start,
639 test_bars: window.test_end - window.test_start,
640 metrics: fold_metrics.clone(),
641 train_start_timestamp_ms: feed.bars[window.train_start].open_time,
642 train_end_timestamp_ms: feed.bars[window.train_end - 1].open_time,
643 start_timestamp_ms: feed.bars[window.test_start].open_time,
644 end_timestamp_ms: feed.bars[window.test_end - 1].open_time,
645 });
646 run_fold_metrics.push(fold_metrics);
647 order_rows.extend(fold_orders);
648 }
649
650 persist_run_meta(
651 &cfg.strategy_db_path,
652 &run_id,
653 cfg,
654 &feed.symbol,
655 run_started_ms,
656 Utc::now().timestamp_millis() as u64,
657 &fold_results,
658 )?;
659 persist_fold_results(&cfg.strategy_db_path, &run_id, &fold_results)?;
660 persist_order_ledger(&cfg.order_db_path, &order_rows)?;
661
662 let total = summarize_metrics(&run_fold_metrics);
663 Ok(BacktestResult {
664 run_id,
665 symbol: feed.symbol.clone(),
666 total_bars: feed.bars.len(),
667 folds: fold_results,
668 metrics: total,
669 run_started_ms,
670 run_finished_ms: Utc::now().timestamp_millis() as u64,
671 })
672}
673
674pub fn parse_backtest_args(args: &[String]) -> Result<BacktestConfig> {
675 let mut cfg = BacktestConfig::default();
676 let mut i = 0usize;
677 let mut bars_set = false;
678
679 while i < args.len() {
680 match args[i].as_str() {
681 "--symbol" if i + 1 < args.len() => {
682 cfg.symbol = args[i + 1].to_string();
683 i += 2;
684 }
685 "--bars" if i + 1 < args.len() => {
686 cfg.bars_csv = PathBuf::from(&args[i + 1]);
687 bars_set = true;
688 i += 2;
689 }
690 "--strategy-db" if i + 1 < args.len() => {
691 cfg.strategy_db_path = PathBuf::from(&args[i + 1]);
692 i += 2;
693 }
694 "--order-db" if i + 1 < args.len() => {
695 cfg.order_db_path = PathBuf::from(&args[i + 1]);
696 i += 2;
697 }
698 "--order-usdt" if i + 1 < args.len() => {
699 cfg.order_amount_usdt = args[i + 1].parse()?;
700 i += 2;
701 }
702 "--fee-rate" if i + 1 < args.len() => {
703 cfg.fee_rate = args[i + 1].parse()?;
704 i += 2;
705 }
706 "--slippage-bps" if i + 1 < args.len() => {
707 cfg.slippage_bps = args[i + 1].parse()?;
708 i += 2;
709 }
710 "--train-bars" if i + 1 < args.len() => {
711 cfg.train_window = args[i + 1].parse()?;
712 i += 2;
713 }
714 "--test-bars" if i + 1 < args.len() => {
715 cfg.test_window = args[i + 1].parse()?;
716 i += 2;
717 }
718 "--embargo-bars" if i + 1 < args.len() => {
719 cfg.embargo_window = args[i + 1].parse()?;
720 i += 2;
721 }
722 "--max-folds" if i + 1 < args.len() => {
723 cfg.max_folds = args[i + 1].parse()?;
724 i += 2;
725 }
726 "--min-signal" if i + 1 < args.len() => {
727 cfg.min_signal_abs = args[i + 1].parse()?;
728 i += 2;
729 }
730 "--regime-gate" => {
731 cfg.regime_gate_enabled = true;
732 i += 1;
733 }
734 "--help" => {
735 return Err(anyhow!(print_backtest_usage()));
736 }
737 _ => {
738 return Err(anyhow!("unknown arg '{}'", args[i]));
739 }
740 }
741 }
742
743 if !cfg.bars_csv.exists() {
744 if !bars_set {
745 return Err(anyhow!("--bars is required"));
746 }
747 return Err(anyhow!("bars file not found: {}", cfg.bars_csv.display()));
748 }
749
750 if cfg.train_window == 0 || cfg.test_window == 0 {
751 return Err(anyhow!("train/test windows must be > 0"));
752 }
753 if cfg.max_folds == 0 {
754 return Err(anyhow!("max-folds must be > 0"));
755 }
756
757 Ok(cfg)
758}
759
/// Returns the multi-line usage text for the backtest CLI (no trailing newline).
///
/// Despite the name, this only builds the text; callers decide where to print it.
pub fn print_backtest_usage() -> String {
    concat!(
        "USAGE: backtest [--symbol SYMBOL] --bars FILE [--strategy-db PATH] [--order-db PATH]\n",
        " [--order-usdt AMOUNT] [--fee-rate RATE] [--slippage-bps BPS]\n",
        " [--train-bars N] [--test-bars N] [--embargo-bars N]\n",
        " [--max-folds N] [--min-signal N] [--regime-gate]"
    )
    .to_string()
}
769
770fn summarize_metrics(folds: &[BacktestMetrics]) -> BacktestMetrics {
771 let mut realized = 0.0;
772 let mut fees = 0.0;
773 let mut trade_count = 0u64;
774 let mut win_count = 0u64;
775 let mut lose_count = 0u64;
776 let mut max_dd = 0.0;
777 let mut sharpe_sum = 0.0;
778 let mut sharpe_count = 0u64;
779 let mut end_equity = 0.0;
780
781 for fold in folds {
782 realized += fold.realized_pnl_usdt;
783 fees += fold.total_fees_usdt;
784 trade_count += fold.trade_count;
785 win_count += fold.win_count;
786 lose_count += fold.lose_count;
787 if fold.max_drawdown > max_dd {
788 max_dd = fold.max_drawdown;
789 }
790 if fold.sharpe_like.is_finite() {
791 sharpe_sum += fold.sharpe_like;
792 sharpe_count += 1;
793 }
794 if folds.last().is_some_and(|last| std::ptr::eq(last, fold)) {
795 end_equity = fold.end_equity_usdt;
796 }
797 }
798
799 BacktestMetrics {
800 realized_pnl_usdt: realized,
801 total_fees_usdt: fees,
802 trade_count,
803 win_count,
804 lose_count,
805 max_drawdown: max_dd,
806 sharpe_like: if sharpe_count == 0 {
807 0.0
808 } else {
809 sharpe_sum / sharpe_count as f64
810 },
811 end_equity_usdt: end_equity,
812 }
813}
814
/// Largest peak-to-trough decline of an equity curve, as a fraction of the running peak.
///
/// Declines are only measured while the running peak is positive. Returns 0.0 for an empty or
/// never-declining curve.
fn max_drawdown(equity: &[f64]) -> f64 {
    let (_, worst) = equity
        .iter()
        .fold((f64::NEG_INFINITY, 0.0_f64), |(peak, worst), &value| {
            let peak = if value > peak { value } else { peak };
            let decline = if peak > 0.0 && value < peak {
                (peak - value) / peak
            } else {
                0.0
            };
            (peak, if decline > worst { decline } else { worst })
        });
    worst
}
832
/// Mean return divided by the population standard deviation of `returns` (not annualized).
///
/// Returns 0.0 when there are fewer than two returns or the deviation is at most
/// `f64::EPSILON`.
fn sharpe_like(returns: &[f64]) -> f64 {
    let n = returns.len();
    if n < 2 {
        return 0.0;
    }
    let count = n as f64;
    let mean = returns.iter().sum::<f64>() / count;
    let variance = returns
        .iter()
        .map(|&r| (r - mean).powi(2))
        .sum::<f64>()
        / count;
    match variance.sqrt() {
        sd if sd <= f64::EPSILON => 0.0,
        sd => mean / sd,
    }
}
846
847fn ensure_strategy_db(path: &Path) -> Result<()> {
848 if let Some(parent) = path.parent() {
849 std::fs::create_dir_all(parent)?;
850 }
851
852 let conn = Connection::open(path)?;
853 conn.execute_batch(
854 r#"
855 CREATE TABLE IF NOT EXISTS backtest_runs (
856 run_id TEXT PRIMARY KEY,
857 symbol TEXT NOT NULL,
858 started_at_ms INTEGER NOT NULL,
859 finished_at_ms INTEGER NOT NULL,
860 config_json TEXT NOT NULL,
861 total_bars INTEGER NOT NULL,
862 folds INTEGER NOT NULL,
863 created_at_ms INTEGER NOT NULL
864 );
865
866 CREATE TABLE IF NOT EXISTS backtest_fold_results (
867 run_id TEXT NOT NULL,
868 fold_idx INTEGER NOT NULL,
869 train_bars INTEGER NOT NULL,
870 test_bars INTEGER NOT NULL,
871 realized_pnl_usdt REAL NOT NULL,
872 total_fees_usdt REAL NOT NULL,
873 trade_count INTEGER NOT NULL,
874 win_count INTEGER NOT NULL,
875 lose_count INTEGER NOT NULL,
876 max_drawdown REAL NOT NULL,
877 sharpe_like REAL NOT NULL,
878 end_equity_usdt REAL NOT NULL,
879 train_start_ms INTEGER NOT NULL,
880 train_end_ms INTEGER NOT NULL,
881 test_start_ms INTEGER NOT NULL,
882 test_end_ms INTEGER NOT NULL,
883 PRIMARY KEY(run_id, fold_idx)
884 );
885 "#,
886 )?;
887 Ok(())
888}
889
890fn ensure_order_db(path: &Path) -> Result<()> {
891 if let Some(parent) = path.parent() {
892 std::fs::create_dir_all(parent)?;
893 }
894
895 let conn = Connection::open(path)?;
896 conn.execute_batch(
897 r#"
898 CREATE TABLE IF NOT EXISTS backtest_orders (
899 run_id TEXT NOT NULL,
900 fold INTEGER NOT NULL,
901 order_index INTEGER NOT NULL,
902 source TEXT NOT NULL,
903 bar_idx INTEGER NOT NULL,
904 timestamp_ms INTEGER NOT NULL,
905 side TEXT NOT NULL,
906 target_ratio REAL NOT NULL,
907 current_ratio REAL NOT NULL,
908 qty REAL NOT NULL,
909 price REAL NOT NULL,
910 fee_usdt REAL NOT NULL,
911 pnl_realized_usdt REAL NOT NULL,
912 reason TEXT NOT NULL,
913 created_at_ms INTEGER NOT NULL,
914 PRIMARY KEY(run_id, fold, order_index)
915 );
916 "#,
917 )?;
918 Ok(())
919}
920
921fn persist_run_meta(
922 path: &Path,
923 run_id: &str,
924 cfg: &BacktestConfig,
925 symbol: &str,
926 started_ms: u64,
927 finished_ms: u64,
928 folds: &[BacktestFoldResult],
929) -> Result<()> {
930 let conn = Connection::open(path)?;
931 let cfg_json = serde_json::json!({
932 "symbol": cfg.symbol,
933 "bars_csv": cfg.bars_csv,
934 "order_amount_usdt": cfg.order_amount_usdt,
935 "fee_rate": cfg.fee_rate,
936 "slippage_bps": cfg.slippage_bps,
937 "train_window": cfg.train_window,
938 "test_window": cfg.test_window,
939 "embargo_window": cfg.embargo_window,
940 "max_folds": cfg.max_folds,
941 "min_signal_abs": cfg.min_signal_abs,
942 "regime_gate_enabled": cfg.regime_gate_enabled,
943 });
944
945 conn.execute(
946 "INSERT OR REPLACE INTO backtest_runs (
947 run_id, symbol, started_at_ms, finished_at_ms, config_json, total_bars, folds, created_at_ms
948 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
949 params![
950 run_id,
951 symbol,
952 started_ms as i64,
953 finished_ms as i64,
954 cfg_json.to_string(),
955 folds.iter().map(|f| f.train_bars + f.test_bars).sum::<usize>() as i64,
956 folds.len() as i64,
957 Utc::now().timestamp_millis() as i64,
958 ],
959 )?;
960 Ok(())
961}
962
963fn persist_fold_results(path: &Path, run_id: &str, folds: &[BacktestFoldResult]) -> Result<()> {
964 let mut conn = Connection::open(path)?;
965 let tx = conn.transaction()?;
966 for fold in folds {
967 tx.execute(
968 "INSERT INTO backtest_fold_results (
969 run_id, fold_idx, train_bars, test_bars, realized_pnl_usdt, total_fees_usdt,
970 trade_count, win_count, lose_count, max_drawdown, sharpe_like, end_equity_usdt,
971 train_start_ms, train_end_ms, test_start_ms, test_end_ms
972 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16)",
973 params![
974 run_id,
975 fold.fold as i64,
976 fold.train_bars as i64,
977 fold.test_bars as i64,
978 fold.metrics.realized_pnl_usdt,
979 fold.metrics.total_fees_usdt,
980 fold.metrics.trade_count as i64,
981 fold.metrics.win_count as i64,
982 fold.metrics.lose_count as i64,
983 fold.metrics.max_drawdown,
984 fold.metrics.sharpe_like,
985 fold.metrics.end_equity_usdt,
986 fold.train_start_timestamp_ms as i64,
987 fold.train_end_timestamp_ms as i64,
988 fold.start_timestamp_ms as i64,
989 fold.end_timestamp_ms as i64,
990 ],
991 )?;
992 }
993 tx.commit()?;
994 Ok(())
995}
996
997fn persist_order_ledger(path: &Path, rows: &[BacktestOrderLedgerRow]) -> Result<()> {
998 let mut conn = Connection::open(path)?;
999 let tx = conn.transaction()?;
1000 for row in rows {
1001 tx.execute(
1002 "INSERT OR REPLACE INTO backtest_orders (
1003 run_id, fold, order_index, source, bar_idx, timestamp_ms, side,
1004 target_ratio, current_ratio, qty, price, fee_usdt, pnl_realized_usdt, reason, created_at_ms
1005 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15)",
1006 params![
1007 row.run_id,
1008 row.fold as i64,
1009 row.order_index as i64,
1010 row.source,
1011 row.bar_idx as i64,
1012 row.timestamp_ms as i64,
1013 row.side,
1014 row.target_ratio,
1015 row.current_ratio,
1016 row.qty,
1017 row.price,
1018 row.fee_usdt,
1019 row.pnl_realized_usdt,
1020 row.reason,
1021 Utc::now().timestamp_millis() as i64,
1022 ],
1023 )?;
1024 }
1025 tx.commit()?;
1026 Ok(())
1027}