1use super::{
2 compute_cpu_batch, IndicatorBatchOutput, IndicatorBatchRequest, IndicatorDataRef,
3 IndicatorDispatchError, IndicatorParamSet, ParamKV, ParamValue,
4};
5use crate::indicators::dx::{dx_batch_with_kernel, DxBatchRange};
6use crate::indicators::mfi::{mfi_batch_with_kernel, MfiBatchRange};
7use crate::indicators::moving_averages::sma::{sma_batch_with_kernel, SmaBatchRange};
8use crate::indicators::registry::get_indicator;
9use crate::utilities::data_loader::source_type;
10use crate::utilities::enums::Kernel;
11
12#[cfg(feature = "cuda")]
13use super::{
14 compute_cuda, CudaOutputTarget, IndicatorCudaDataRef, IndicatorCudaOutput, IndicatorCudaRequest,
15};
16
/// Owned counterpart of `ParamValue`, letting a compiled call keep its
/// parameters without borrowing from the caller.
#[derive(Clone, Debug, PartialEq)]
enum OwnedParamValue {
    /// Integer parameter.
    Int(i64),
    /// Floating-point parameter.
    Float(f64),
    /// Boolean flag.
    Bool(bool),
    /// Enum-like option stored by its string name.
    EnumString(String),
}
24
25#[derive(Debug, Clone, PartialEq)]
26struct OwnedParamKV {
27 key: String,
28 value: OwnedParamValue,
29}
30
/// CPU execution route chosen once, at compile time.
///
/// The `*Period` variants go straight to the indicator's batch kernel with a
/// single period; `Generic` defers to the registry-driven CPU dispatcher.
#[derive(Clone, Debug, PartialEq)]
enum CpuCompiledPlan {
    /// Forward the stored parameters through `compute_cpu_batch`.
    Generic,
    /// `sma` with a pre-parsed period.
    SmaPeriod { period: usize },
    /// `mfi` with a pre-parsed period.
    MfiPeriod { period: usize },
    /// `dx` with a pre-parsed period.
    DxPeriod { period: usize },
}
38
39#[derive(Debug, Clone, PartialEq)]
40pub struct CompiledIndicatorCall {
41 indicator_id: String,
42 output_id: Option<String>,
43 params: Vec<OwnedParamKV>,
44 cpu_plan: CpuCompiledPlan,
45 prefer_cuda: bool,
46}
47
48impl CompiledIndicatorCall {
49 pub fn indicator_id(&self) -> &str {
50 &self.indicator_id
51 }
52
53 pub fn output_id(&self) -> Option<&str> {
54 self.output_id.as_deref()
55 }
56
57 pub fn prefer_cuda(&self) -> bool {
58 self.prefer_cuda
59 }
60
61 fn as_param_kv(&self) -> Vec<ParamKV<'_>> {
62 let mut out = Vec::with_capacity(self.params.len());
63 for p in &self.params {
64 let value = match &p.value {
65 OwnedParamValue::Int(v) => ParamValue::Int(*v),
66 OwnedParamValue::Float(v) => ParamValue::Float(*v),
67 OwnedParamValue::Bool(v) => ParamValue::Bool(*v),
68 OwnedParamValue::EnumString(v) => ParamValue::EnumString(v.as_str()),
69 };
70 out.push(ParamKV {
71 key: p.key.as_str(),
72 value,
73 });
74 }
75 out
76 }
77}
78
79fn parse_usize_param_value(
80 indicator: &str,
81 key: &str,
82 value: ParamValue<'_>,
83) -> Result<usize, IndicatorDispatchError> {
84 match value {
85 ParamValue::Int(v) => {
86 if v < 0 {
87 return Err(IndicatorDispatchError::InvalidParam {
88 indicator: indicator.to_string(),
89 key: key.to_string(),
90 reason: "expected integer >= 0".to_string(),
91 });
92 }
93 Ok(v as usize)
94 }
95 ParamValue::Float(v) => {
96 if !v.is_finite() {
97 return Err(IndicatorDispatchError::InvalidParam {
98 indicator: indicator.to_string(),
99 key: key.to_string(),
100 reason: "expected finite number".to_string(),
101 });
102 }
103 if v < 0.0 {
104 return Err(IndicatorDispatchError::InvalidParam {
105 indicator: indicator.to_string(),
106 key: key.to_string(),
107 reason: "expected number >= 0".to_string(),
108 });
109 }
110 let rounded = v.round();
111 if (v - rounded).abs() > 1e-9 {
112 return Err(IndicatorDispatchError::InvalidParam {
113 indicator: indicator.to_string(),
114 key: key.to_string(),
115 reason: "expected whole number".to_string(),
116 });
117 }
118 Ok(rounded as usize)
119 }
120 _ => Err(IndicatorDispatchError::InvalidParam {
121 indicator: indicator.to_string(),
122 key: key.to_string(),
123 reason: "expected Int or Float".to_string(),
124 }),
125 }
126}
127
128fn compile_period_only_plan(
129 indicator: &str,
130 selected_output: Option<&str>,
131 params: &[ParamKV<'_>],
132) -> Result<CpuCompiledPlan, IndicatorDispatchError> {
133 let supports_fast_route = indicator.eq_ignore_ascii_case("sma")
134 || indicator.eq_ignore_ascii_case("mfi")
135 || indicator.eq_ignore_ascii_case("dx");
136 if !supports_fast_route {
137 return Ok(CpuCompiledPlan::Generic);
138 }
139
140 let is_value = selected_output
141 .map(|out| out.eq_ignore_ascii_case("value"))
142 .unwrap_or(true);
143 if !is_value {
144 return Ok(CpuCompiledPlan::Generic);
145 }
146
147 let mut period: Option<usize> = None;
148 for p in params {
149 if p.key.eq_ignore_ascii_case("period") {
150 period = Some(parse_usize_param_value(indicator, "period", p.value)?);
151 } else {
152 return Ok(CpuCompiledPlan::Generic);
153 }
154 }
155 let period = period.unwrap_or(14);
156
157 if indicator.eq_ignore_ascii_case("sma") {
158 return Ok(CpuCompiledPlan::SmaPeriod { period });
159 }
160 if indicator.eq_ignore_ascii_case("mfi") {
161 return Ok(CpuCompiledPlan::MfiPeriod { period });
162 }
163 if indicator.eq_ignore_ascii_case("dx") {
164 return Ok(CpuCompiledPlan::DxPeriod { period });
165 }
166 Ok(CpuCompiledPlan::Generic)
167}
168
169pub fn compile_call(
170 indicator_id: &str,
171 output_id: Option<&str>,
172 params: &[ParamKV<'_>],
173 prefer_cuda: bool,
174) -> Result<CompiledIndicatorCall, IndicatorDispatchError> {
175 let info =
176 get_indicator(indicator_id).ok_or_else(|| IndicatorDispatchError::UnknownIndicator {
177 id: indicator_id.to_string(),
178 })?;
179
180 if info.outputs.len() > 1 && output_id.is_none() {
181 return Err(IndicatorDispatchError::InvalidParam {
182 indicator: info.id.to_string(),
183 key: "output_id".to_string(),
184 reason: "output_id is required for multi-output indicators".to_string(),
185 });
186 }
187
188 if let Some(out_id) = output_id {
189 let exists = info
190 .outputs
191 .iter()
192 .any(|out| out.id.eq_ignore_ascii_case(out_id));
193 if !exists {
194 return Err(IndicatorDispatchError::UnknownOutput {
195 indicator: info.id.to_string(),
196 output: out_id.to_string(),
197 });
198 }
199 }
200
201 if prefer_cuda && !info.capabilities.supports_cuda_batch {
202 return Err(IndicatorDispatchError::UnsupportedCapability {
203 indicator: info.id.to_string(),
204 capability: "cuda_batch",
205 });
206 }
207
208 let mut owned_params = Vec::with_capacity(params.len());
209 for param in params {
210 let value = match param.value {
211 ParamValue::Int(v) => OwnedParamValue::Int(v),
212 ParamValue::Float(v) => OwnedParamValue::Float(v),
213 ParamValue::Bool(v) => OwnedParamValue::Bool(v),
214 ParamValue::EnumString(v) => OwnedParamValue::EnumString(v.to_string()),
215 };
216 owned_params.push(OwnedParamKV {
217 key: param.key.to_string(),
218 value,
219 });
220 }
221
222 let selected_output = output_id.or_else(|| {
223 if info.outputs.len() == 1 {
224 Some(info.outputs[0].id)
225 } else {
226 None
227 }
228 });
229 let cpu_plan = compile_period_only_plan(info.id, selected_output, params)?;
230
231 Ok(CompiledIndicatorCall {
232 indicator_id: info.id.to_string(),
233 output_id: output_id.map(str::to_string),
234 params: owned_params,
235 cpu_plan,
236 prefer_cuda,
237 })
238}
239
240pub fn run_compiled_cpu(
241 call: &CompiledIndicatorCall,
242 data: IndicatorDataRef<'_>,
243 kernel: Kernel,
244) -> Result<IndicatorBatchOutput, IndicatorDispatchError> {
245 match call.cpu_plan {
246 CpuCompiledPlan::SmaPeriod { period } => {
247 let series = match data {
248 IndicatorDataRef::Slice { values } => values,
249 IndicatorDataRef::Candles { candles, source } => {
250 source_type(candles, source.unwrap_or("close"))
251 }
252 IndicatorDataRef::Ohlc { close, .. } => close,
253 IndicatorDataRef::Ohlcv { close, .. } => close,
254 IndicatorDataRef::CloseVolume { close, .. } => close,
255 IndicatorDataRef::HighLow { .. } => {
256 return Err(IndicatorDispatchError::MissingRequiredInput {
257 indicator: "sma".to_string(),
258 input: crate::indicators::registry::IndicatorInputKind::Slice,
259 });
260 }
261 };
262 let out = sma_batch_with_kernel(
263 series,
264 &SmaBatchRange {
265 period: (period, period, 0),
266 },
267 to_batch_kernel(kernel),
268 )
269 .map_err(|e| IndicatorDispatchError::ComputeFailed {
270 indicator: "sma".to_string(),
271 details: e.to_string(),
272 })?;
273 return Ok(f64_output(
274 call.output_id.as_deref().unwrap_or("value"),
275 out.rows,
276 out.cols,
277 out.values,
278 ));
279 }
280 CpuCompiledPlan::MfiPeriod { period } => {
281 let mut derived_typical_price: Option<Vec<f64>> = None;
282 let (typical_price, volume): (&[f64], &[f64]) = match data {
283 IndicatorDataRef::Candles { candles, source } => (
284 source_type(candles, source.unwrap_or("hlc3")),
285 candles.volume.as_slice(),
286 ),
287 IndicatorDataRef::Ohlcv {
288 open,
289 high,
290 low,
291 close,
292 volume,
293 } => {
294 ensure_same_len_5(
295 "mfi",
296 open.len(),
297 high.len(),
298 low.len(),
299 close.len(),
300 volume.len(),
301 )?;
302 derived_typical_price = Some(
303 high.iter()
304 .zip(low)
305 .zip(close)
306 .map(|((h, l), c)| (h + l + c) / 3.0)
307 .collect(),
308 );
309 (derived_typical_price.as_deref().unwrap_or(close), volume)
310 }
311 IndicatorDataRef::CloseVolume { close, volume } => {
312 ensure_same_len_2("mfi", close.len(), volume.len())?;
313 (close, volume)
314 }
315 _ => {
316 return Err(IndicatorDispatchError::MissingRequiredInput {
317 indicator: "mfi".to_string(),
318 input: crate::indicators::registry::IndicatorInputKind::CloseVolume,
319 });
320 }
321 };
322 let out = mfi_batch_with_kernel(
323 typical_price,
324 volume,
325 &MfiBatchRange {
326 period: (period, period, 0),
327 },
328 to_batch_kernel(kernel),
329 )
330 .map_err(|e| IndicatorDispatchError::ComputeFailed {
331 indicator: "mfi".to_string(),
332 details: e.to_string(),
333 })?;
334 return Ok(f64_output(
335 call.output_id.as_deref().unwrap_or("value"),
336 out.rows,
337 out.cols,
338 out.values,
339 ));
340 }
341 CpuCompiledPlan::DxPeriod { period } => {
342 let (high, low, close): (&[f64], &[f64], &[f64]) = match data {
343 IndicatorDataRef::Candles { candles, .. } => (
344 candles.high.as_slice(),
345 candles.low.as_slice(),
346 candles.close.as_slice(),
347 ),
348 IndicatorDataRef::Ohlc {
349 open,
350 high,
351 low,
352 close,
353 } => {
354 ensure_same_len_4("dx", open.len(), high.len(), low.len(), close.len())?;
355 (high, low, close)
356 }
357 IndicatorDataRef::Ohlcv {
358 open,
359 high,
360 low,
361 close,
362 volume,
363 } => {
364 ensure_same_len_5(
365 "dx",
366 open.len(),
367 high.len(),
368 low.len(),
369 close.len(),
370 volume.len(),
371 )?;
372 (high, low, close)
373 }
374 _ => {
375 return Err(IndicatorDispatchError::MissingRequiredInput {
376 indicator: "dx".to_string(),
377 input: crate::indicators::registry::IndicatorInputKind::Ohlc,
378 });
379 }
380 };
381 let out = dx_batch_with_kernel(
382 high,
383 low,
384 close,
385 &DxBatchRange {
386 period: (period, period, 0),
387 },
388 to_batch_kernel(kernel),
389 )
390 .map_err(|e| IndicatorDispatchError::ComputeFailed {
391 indicator: "dx".to_string(),
392 details: e.to_string(),
393 })?;
394 return Ok(f64_output(
395 call.output_id.as_deref().unwrap_or("value"),
396 out.rows,
397 out.cols,
398 out.values,
399 ));
400 }
401 CpuCompiledPlan::Generic => {}
402 }
403
404 let params = call.as_param_kv();
405 let combos = [IndicatorParamSet {
406 params: params.as_slice(),
407 }];
408 compute_cpu_batch(IndicatorBatchRequest {
409 indicator_id: call.indicator_id.as_str(),
410 output_id: call.output_id.as_deref(),
411 data,
412 combos: &combos,
413 kernel,
414 })
415}
416
417fn to_batch_kernel(kernel: Kernel) -> Kernel {
418 match kernel {
419 Kernel::Auto => Kernel::Auto,
420 Kernel::Scalar => Kernel::ScalarBatch,
421 Kernel::Avx2 => Kernel::Avx2Batch,
422 Kernel::Avx512 => Kernel::Avx512Batch,
423 other => other,
424 }
425}
426
427fn ensure_same_len_2(indicator: &str, a: usize, b: usize) -> Result<(), IndicatorDispatchError> {
428 if a == b {
429 return Ok(());
430 }
431 Err(IndicatorDispatchError::DataLengthMismatch {
432 details: format!("{indicator}: expected equal lengths, got {a} and {b}"),
433 })
434}
435
436fn ensure_same_len_4(
437 indicator: &str,
438 a: usize,
439 b: usize,
440 c: usize,
441 d: usize,
442) -> Result<(), IndicatorDispatchError> {
443 if a == b && b == c && c == d {
444 return Ok(());
445 }
446 Err(IndicatorDispatchError::DataLengthMismatch {
447 details: format!("{indicator}: expected equal lengths, got {a}, {b}, {c}, {d}"),
448 })
449}
450
451fn ensure_same_len_5(
452 indicator: &str,
453 a: usize,
454 b: usize,
455 c: usize,
456 d: usize,
457 e: usize,
458) -> Result<(), IndicatorDispatchError> {
459 if a == b && b == c && c == d && d == e {
460 return Ok(());
461 }
462 Err(IndicatorDispatchError::DataLengthMismatch {
463 details: format!("{indicator}: expected equal lengths, got {a}, {b}, {c}, {d}, {e}"),
464 })
465}
466
467fn f64_output(output_id: &str, rows: usize, cols: usize, values: Vec<f64>) -> IndicatorBatchOutput {
468 IndicatorBatchOutput {
469 output_id: output_id.to_string(),
470 rows,
471 cols,
472 values_f64: Some(values),
473 values_i32: None,
474 values_bool: None,
475 }
476}
477
/// Executes a compiled call through `compute_cuda`, using its stored indicator id,
/// output id and parameters.
///
/// `prefer_cuda` is not consulted here; the caller chooses this entry point.
///
/// # Errors
/// Returns whatever `compute_cuda` returns.
#[cfg(feature = "cuda")]
pub fn run_compiled_cuda(
    call: &CompiledIndicatorCall,
    data: IndicatorCudaDataRef<'_>,
    kernel: Kernel,
    target: CudaOutputTarget,
) -> Result<IndicatorCudaOutput, IndicatorDispatchError> {
    let borrowed_params = call.as_param_kv();
    let request = IndicatorCudaRequest {
        indicator_id: call.indicator_id.as_str(),
        output_id: call.output_id.as_deref(),
        data,
        params: borrowed_params.as_slice(),
        kernel,
        target,
    };
    compute_cuda(request)
}
495
#[cfg(test)]
mod tests {
    use super::*;
    use crate::indicators::dispatch::{compute_cpu_batch, IndicatorBatchRequest};

    /// Increasing series 1.0, 2.0, ..., 128.0.
    fn sample_series() -> Vec<f64> {
        (1..=128).map(|v| v as f64).collect()
    }

    /// 128 synthetic OHLC bars: open rises by 0.1 per bar, high/low are open +/- 1.0,
    /// and close is open + 0.25.
    fn sample_ohlc() -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
        let open: Vec<f64> = (0..128).map(|i| 100.0 + (i as f64 * 0.1)).collect();
        let high: Vec<f64> = open.iter().map(|v| v + 1.0).collect();
        let low: Vec<f64> = open.iter().map(|v| v - 1.0).collect();
        let close: Vec<f64> = open.iter().map(|v| v + 0.25).collect();
        (open, high, low, close)
    }

    /// Asserts that two series have equal length and match element-wise within
    /// 1e-12, treating positions where both values are NaN as equal.
    fn assert_series_close(a: &[f64], b: &[f64]) {
        assert_eq!(a.len(), b.len());
        for (i, (&x, &y)) in a.iter().zip(b).enumerate() {
            if x.is_nan() && y.is_nan() {
                continue;
            }
            assert!((x - y).abs() <= 1e-12, "mismatch at index {i}: {x} vs {y}");
        }
    }

    #[test]
    fn compile_rejects_unknown_indicator() {
        let err = compile_call("does_not_exist", Some("value"), &[], false).unwrap_err();
        match err {
            IndicatorDispatchError::UnknownIndicator { id } => assert_eq!(id, "does_not_exist"),
            other => panic!("expected UnknownIndicator, got {other:?}"),
        }
    }

    #[test]
    fn compile_validates_output_id() {
        let err = compile_call("sma", Some("hist"), &[], false).unwrap_err();
        match err {
            IndicatorDispatchError::UnknownOutput { indicator, output } => {
                assert_eq!(indicator, "sma");
                assert_eq!(output, "hist");
            }
            other => panic!("expected UnknownOutput, got {other:?}"),
        }
    }

    #[test]
    fn run_compiled_cpu_matches_direct_dispatch() {
        let data = sample_series();
        let params = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let call = compile_call("sma", Some("value"), &params, false).unwrap();
        let compiled = run_compiled_cpu(
            &call,
            IndicatorDataRef::Slice { values: &data },
            Kernel::Auto,
        )
        .unwrap();

        let combos = [IndicatorParamSet { params: &params }];
        let direct = compute_cpu_batch(IndicatorBatchRequest {
            indicator_id: "sma",
            output_id: Some("value"),
            data: IndicatorDataRef::Slice { values: &data },
            combos: &combos,
            kernel: Kernel::Auto,
        })
        .unwrap();
        assert_eq!(compiled.output_id, direct.output_id);
        assert_eq!(compiled.rows, direct.rows);
        assert_eq!(compiled.cols, direct.cols);
        assert_series_close(&compiled.values_f64.unwrap(), &direct.values_f64.unwrap());
    }

    #[test]
    fn compile_pre_resolves_sma_period_plan() {
        let params = [ParamKV {
            key: "period",
            value: ParamValue::Int(9),
        }];
        let call = compile_call("sma", Some("value"), &params, false).unwrap();
        assert!(matches!(
            call.cpu_plan,
            CpuCompiledPlan::SmaPeriod { period: 9 }
        ));
    }

    #[test]
    fn compile_falls_back_to_generic_when_params_are_not_period_only() {
        let params = [
            ParamKV {
                key: "period",
                value: ParamValue::Int(9),
            },
            ParamKV {
                key: "unused",
                value: ParamValue::Float(1.0),
            },
        ];
        let call = compile_call("mfi", Some("value"), &params, false).unwrap();
        assert!(matches!(call.cpu_plan, CpuCompiledPlan::Generic));
    }

    #[test]
    fn run_compiled_mfi_fast_plan_matches_dispatch() {
        let (open, high, low, close) = sample_ohlc();
        let volume: Vec<f64> = (0..close.len()).map(|i| 1000.0 + (i as f64)).collect();
        let params = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let call = compile_call("mfi", Some("value"), &params, false).unwrap();
        let compiled = run_compiled_cpu(
            &call,
            IndicatorDataRef::Ohlcv {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
                volume: &volume,
            },
            Kernel::Auto,
        )
        .unwrap();
        let combos = [IndicatorParamSet { params: &params }];
        let direct = compute_cpu_batch(IndicatorBatchRequest {
            indicator_id: "mfi",
            output_id: Some("value"),
            data: IndicatorDataRef::Ohlcv {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
                volume: &volume,
            },
            combos: &combos,
            kernel: Kernel::Auto,
        })
        .unwrap();
        assert_eq!(compiled.rows, direct.rows);
        assert_eq!(compiled.cols, direct.cols);
        assert_series_close(&compiled.values_f64.unwrap(), &direct.values_f64.unwrap());
    }

    #[test]
    fn run_compiled_dx_fast_plan_matches_dispatch() {
        let (open, high, low, close) = sample_ohlc();
        let params = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let call = compile_call("dx", Some("value"), &params, false).unwrap();
        let compiled = run_compiled_cpu(
            &call,
            IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            Kernel::Auto,
        )
        .unwrap();
        let combos = [IndicatorParamSet { params: &params }];
        let direct = compute_cpu_batch(IndicatorBatchRequest {
            indicator_id: "dx",
            output_id: Some("value"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            combos: &combos,
            kernel: Kernel::Auto,
        })
        .unwrap();
        assert_eq!(compiled.rows, direct.rows);
        assert_eq!(compiled.cols, direct.cols);
        assert_series_close(&compiled.values_f64.unwrap(), &direct.values_f64.unwrap());
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn compile_prefer_cuda_rejects_non_cuda_indicator() {
        let err = compile_call("adx", Some("value"), &[], true).unwrap_err();
        match err {
            IndicatorDispatchError::UnsupportedCapability {
                indicator,
                capability,
            } => {
                assert_eq!(indicator, "adx");
                assert_eq!(capability, "cuda_batch");
            }
            other => panic!("expected UnsupportedCapability, got {other:?}"),
        }
    }
}