Skip to main content

vector_ta/indicators/dispatch/
compiled.rs

1use super::{
2    compute_cpu_batch, IndicatorBatchOutput, IndicatorBatchRequest, IndicatorDataRef,
3    IndicatorDispatchError, IndicatorParamSet, ParamKV, ParamValue,
4};
5use crate::indicators::dx::{dx_batch_with_kernel, DxBatchRange};
6use crate::indicators::mfi::{mfi_batch_with_kernel, MfiBatchRange};
7use crate::indicators::moving_averages::sma::{sma_batch_with_kernel, SmaBatchRange};
8use crate::indicators::registry::get_indicator;
9use crate::utilities::data_loader::source_type;
10use crate::utilities::enums::Kernel;
11
12#[cfg(feature = "cuda")]
13use super::{
14    compute_cuda, CudaOutputTarget, IndicatorCudaDataRef, IndicatorCudaOutput, IndicatorCudaRequest,
15};
16
/// Owned counterpart of a borrowed `ParamValue`.
///
/// `compile_call` copies every incoming parameter into this form so that a
/// compiled call no longer borrows from the caller's parameter slice.
#[derive(Debug, Clone, PartialEq)]
enum OwnedParamValue {
    /// Signed integer parameter.
    Int(i64),
    /// Floating-point parameter.
    Float(f64),
    /// Boolean flag parameter.
    Bool(bool),
    /// Enum-like parameter carried by its string name.
    EnumString(String),
}
24
25#[derive(Debug, Clone, PartialEq)]
26struct OwnedParamKV {
27    key: String,
28    value: OwnedParamValue,
29}
30
/// CPU execution strategy resolved once, at compile time.
///
/// The `*Period` variants carry an already-parsed `period` and are executed by
/// calling the indicator's batch kernel directly with a single-period range;
/// `Generic` hands the stored parameters to `compute_cpu_batch`.
#[derive(Debug, Clone, PartialEq)]
enum CpuCompiledPlan {
    /// Defer to the generic dispatcher.
    Generic,
    /// Direct `sma_batch_with_kernel` call.
    SmaPeriod { period: usize },
    /// Direct `mfi_batch_with_kernel` call.
    MfiPeriod { period: usize },
    /// Direct `dx_batch_with_kernel` call.
    DxPeriod { period: usize },
}
38
39#[derive(Debug, Clone, PartialEq)]
40pub struct CompiledIndicatorCall {
41    indicator_id: String,
42    output_id: Option<String>,
43    params: Vec<OwnedParamKV>,
44    cpu_plan: CpuCompiledPlan,
45    prefer_cuda: bool,
46}
47
48impl CompiledIndicatorCall {
49    pub fn indicator_id(&self) -> &str {
50        &self.indicator_id
51    }
52
53    pub fn output_id(&self) -> Option<&str> {
54        self.output_id.as_deref()
55    }
56
57    pub fn prefer_cuda(&self) -> bool {
58        self.prefer_cuda
59    }
60
61    fn as_param_kv(&self) -> Vec<ParamKV<'_>> {
62        let mut out = Vec::with_capacity(self.params.len());
63        for p in &self.params {
64            let value = match &p.value {
65                OwnedParamValue::Int(v) => ParamValue::Int(*v),
66                OwnedParamValue::Float(v) => ParamValue::Float(*v),
67                OwnedParamValue::Bool(v) => ParamValue::Bool(*v),
68                OwnedParamValue::EnumString(v) => ParamValue::EnumString(v.as_str()),
69            };
70            out.push(ParamKV {
71                key: p.key.as_str(),
72                value,
73            });
74        }
75        out
76    }
77}
78
79fn parse_usize_param_value(
80    indicator: &str,
81    key: &str,
82    value: ParamValue<'_>,
83) -> Result<usize, IndicatorDispatchError> {
84    match value {
85        ParamValue::Int(v) => {
86            if v < 0 {
87                return Err(IndicatorDispatchError::InvalidParam {
88                    indicator: indicator.to_string(),
89                    key: key.to_string(),
90                    reason: "expected integer >= 0".to_string(),
91                });
92            }
93            Ok(v as usize)
94        }
95        ParamValue::Float(v) => {
96            if !v.is_finite() {
97                return Err(IndicatorDispatchError::InvalidParam {
98                    indicator: indicator.to_string(),
99                    key: key.to_string(),
100                    reason: "expected finite number".to_string(),
101                });
102            }
103            if v < 0.0 {
104                return Err(IndicatorDispatchError::InvalidParam {
105                    indicator: indicator.to_string(),
106                    key: key.to_string(),
107                    reason: "expected number >= 0".to_string(),
108                });
109            }
110            let rounded = v.round();
111            if (v - rounded).abs() > 1e-9 {
112                return Err(IndicatorDispatchError::InvalidParam {
113                    indicator: indicator.to_string(),
114                    key: key.to_string(),
115                    reason: "expected whole number".to_string(),
116                });
117            }
118            Ok(rounded as usize)
119        }
120        _ => Err(IndicatorDispatchError::InvalidParam {
121            indicator: indicator.to_string(),
122            key: key.to_string(),
123            reason: "expected Int or Float".to_string(),
124        }),
125    }
126}
127
128fn compile_period_only_plan(
129    indicator: &str,
130    selected_output: Option<&str>,
131    params: &[ParamKV<'_>],
132) -> Result<CpuCompiledPlan, IndicatorDispatchError> {
133    let supports_fast_route = indicator.eq_ignore_ascii_case("sma")
134        || indicator.eq_ignore_ascii_case("mfi")
135        || indicator.eq_ignore_ascii_case("dx");
136    if !supports_fast_route {
137        return Ok(CpuCompiledPlan::Generic);
138    }
139
140    let is_value = selected_output
141        .map(|out| out.eq_ignore_ascii_case("value"))
142        .unwrap_or(true);
143    if !is_value {
144        return Ok(CpuCompiledPlan::Generic);
145    }
146
147    let mut period: Option<usize> = None;
148    for p in params {
149        if p.key.eq_ignore_ascii_case("period") {
150            period = Some(parse_usize_param_value(indicator, "period", p.value)?);
151        } else {
152            return Ok(CpuCompiledPlan::Generic);
153        }
154    }
155    let period = period.unwrap_or(14);
156
157    if indicator.eq_ignore_ascii_case("sma") {
158        return Ok(CpuCompiledPlan::SmaPeriod { period });
159    }
160    if indicator.eq_ignore_ascii_case("mfi") {
161        return Ok(CpuCompiledPlan::MfiPeriod { period });
162    }
163    if indicator.eq_ignore_ascii_case("dx") {
164        return Ok(CpuCompiledPlan::DxPeriod { period });
165    }
166    Ok(CpuCompiledPlan::Generic)
167}
168
169pub fn compile_call(
170    indicator_id: &str,
171    output_id: Option<&str>,
172    params: &[ParamKV<'_>],
173    prefer_cuda: bool,
174) -> Result<CompiledIndicatorCall, IndicatorDispatchError> {
175    let info =
176        get_indicator(indicator_id).ok_or_else(|| IndicatorDispatchError::UnknownIndicator {
177            id: indicator_id.to_string(),
178        })?;
179
180    if info.outputs.len() > 1 && output_id.is_none() {
181        return Err(IndicatorDispatchError::InvalidParam {
182            indicator: info.id.to_string(),
183            key: "output_id".to_string(),
184            reason: "output_id is required for multi-output indicators".to_string(),
185        });
186    }
187
188    if let Some(out_id) = output_id {
189        let exists = info
190            .outputs
191            .iter()
192            .any(|out| out.id.eq_ignore_ascii_case(out_id));
193        if !exists {
194            return Err(IndicatorDispatchError::UnknownOutput {
195                indicator: info.id.to_string(),
196                output: out_id.to_string(),
197            });
198        }
199    }
200
201    if prefer_cuda && !info.capabilities.supports_cuda_batch {
202        return Err(IndicatorDispatchError::UnsupportedCapability {
203            indicator: info.id.to_string(),
204            capability: "cuda_batch",
205        });
206    }
207
208    let mut owned_params = Vec::with_capacity(params.len());
209    for param in params {
210        let value = match param.value {
211            ParamValue::Int(v) => OwnedParamValue::Int(v),
212            ParamValue::Float(v) => OwnedParamValue::Float(v),
213            ParamValue::Bool(v) => OwnedParamValue::Bool(v),
214            ParamValue::EnumString(v) => OwnedParamValue::EnumString(v.to_string()),
215        };
216        owned_params.push(OwnedParamKV {
217            key: param.key.to_string(),
218            value,
219        });
220    }
221
222    let selected_output = output_id.or_else(|| {
223        if info.outputs.len() == 1 {
224            Some(info.outputs[0].id)
225        } else {
226            None
227        }
228    });
229    let cpu_plan = compile_period_only_plan(info.id, selected_output, params)?;
230
231    Ok(CompiledIndicatorCall {
232        indicator_id: info.id.to_string(),
233        output_id: output_id.map(str::to_string),
234        params: owned_params,
235        cpu_plan,
236        prefer_cuda,
237    })
238}
239
240pub fn run_compiled_cpu(
241    call: &CompiledIndicatorCall,
242    data: IndicatorDataRef<'_>,
243    kernel: Kernel,
244) -> Result<IndicatorBatchOutput, IndicatorDispatchError> {
245    match call.cpu_plan {
246        CpuCompiledPlan::SmaPeriod { period } => {
247            let series = match data {
248                IndicatorDataRef::Slice { values } => values,
249                IndicatorDataRef::Candles { candles, source } => {
250                    source_type(candles, source.unwrap_or("close"))
251                }
252                IndicatorDataRef::Ohlc { close, .. } => close,
253                IndicatorDataRef::Ohlcv { close, .. } => close,
254                IndicatorDataRef::CloseVolume { close, .. } => close,
255                IndicatorDataRef::HighLow { .. } => {
256                    return Err(IndicatorDispatchError::MissingRequiredInput {
257                        indicator: "sma".to_string(),
258                        input: crate::indicators::registry::IndicatorInputKind::Slice,
259                    });
260                }
261            };
262            let out = sma_batch_with_kernel(
263                series,
264                &SmaBatchRange {
265                    period: (period, period, 0),
266                },
267                to_batch_kernel(kernel),
268            )
269            .map_err(|e| IndicatorDispatchError::ComputeFailed {
270                indicator: "sma".to_string(),
271                details: e.to_string(),
272            })?;
273            return Ok(f64_output(
274                call.output_id.as_deref().unwrap_or("value"),
275                out.rows,
276                out.cols,
277                out.values,
278            ));
279        }
280        CpuCompiledPlan::MfiPeriod { period } => {
281            let mut derived_typical_price: Option<Vec<f64>> = None;
282            let (typical_price, volume): (&[f64], &[f64]) = match data {
283                IndicatorDataRef::Candles { candles, source } => (
284                    source_type(candles, source.unwrap_or("hlc3")),
285                    candles.volume.as_slice(),
286                ),
287                IndicatorDataRef::Ohlcv {
288                    open,
289                    high,
290                    low,
291                    close,
292                    volume,
293                } => {
294                    ensure_same_len_5(
295                        "mfi",
296                        open.len(),
297                        high.len(),
298                        low.len(),
299                        close.len(),
300                        volume.len(),
301                    )?;
302                    derived_typical_price = Some(
303                        high.iter()
304                            .zip(low)
305                            .zip(close)
306                            .map(|((h, l), c)| (h + l + c) / 3.0)
307                            .collect(),
308                    );
309                    (derived_typical_price.as_deref().unwrap_or(close), volume)
310                }
311                IndicatorDataRef::CloseVolume { close, volume } => {
312                    ensure_same_len_2("mfi", close.len(), volume.len())?;
313                    (close, volume)
314                }
315                _ => {
316                    return Err(IndicatorDispatchError::MissingRequiredInput {
317                        indicator: "mfi".to_string(),
318                        input: crate::indicators::registry::IndicatorInputKind::CloseVolume,
319                    });
320                }
321            };
322            let out = mfi_batch_with_kernel(
323                typical_price,
324                volume,
325                &MfiBatchRange {
326                    period: (period, period, 0),
327                },
328                to_batch_kernel(kernel),
329            )
330            .map_err(|e| IndicatorDispatchError::ComputeFailed {
331                indicator: "mfi".to_string(),
332                details: e.to_string(),
333            })?;
334            return Ok(f64_output(
335                call.output_id.as_deref().unwrap_or("value"),
336                out.rows,
337                out.cols,
338                out.values,
339            ));
340        }
341        CpuCompiledPlan::DxPeriod { period } => {
342            let (high, low, close): (&[f64], &[f64], &[f64]) = match data {
343                IndicatorDataRef::Candles { candles, .. } => (
344                    candles.high.as_slice(),
345                    candles.low.as_slice(),
346                    candles.close.as_slice(),
347                ),
348                IndicatorDataRef::Ohlc {
349                    open,
350                    high,
351                    low,
352                    close,
353                } => {
354                    ensure_same_len_4("dx", open.len(), high.len(), low.len(), close.len())?;
355                    (high, low, close)
356                }
357                IndicatorDataRef::Ohlcv {
358                    open,
359                    high,
360                    low,
361                    close,
362                    volume,
363                } => {
364                    ensure_same_len_5(
365                        "dx",
366                        open.len(),
367                        high.len(),
368                        low.len(),
369                        close.len(),
370                        volume.len(),
371                    )?;
372                    (high, low, close)
373                }
374                _ => {
375                    return Err(IndicatorDispatchError::MissingRequiredInput {
376                        indicator: "dx".to_string(),
377                        input: crate::indicators::registry::IndicatorInputKind::Ohlc,
378                    });
379                }
380            };
381            let out = dx_batch_with_kernel(
382                high,
383                low,
384                close,
385                &DxBatchRange {
386                    period: (period, period, 0),
387                },
388                to_batch_kernel(kernel),
389            )
390            .map_err(|e| IndicatorDispatchError::ComputeFailed {
391                indicator: "dx".to_string(),
392                details: e.to_string(),
393            })?;
394            return Ok(f64_output(
395                call.output_id.as_deref().unwrap_or("value"),
396                out.rows,
397                out.cols,
398                out.values,
399            ));
400        }
401        CpuCompiledPlan::Generic => {}
402    }
403
404    let params = call.as_param_kv();
405    let combos = [IndicatorParamSet {
406        params: params.as_slice(),
407    }];
408    compute_cpu_batch(IndicatorBatchRequest {
409        indicator_id: call.indicator_id.as_str(),
410        output_id: call.output_id.as_deref(),
411        data,
412        combos: &combos,
413        kernel,
414    })
415}
416
417fn to_batch_kernel(kernel: Kernel) -> Kernel {
418    match kernel {
419        Kernel::Auto => Kernel::Auto,
420        Kernel::Scalar => Kernel::ScalarBatch,
421        Kernel::Avx2 => Kernel::Avx2Batch,
422        Kernel::Avx512 => Kernel::Avx512Batch,
423        other => other,
424    }
425}
426
427fn ensure_same_len_2(indicator: &str, a: usize, b: usize) -> Result<(), IndicatorDispatchError> {
428    if a == b {
429        return Ok(());
430    }
431    Err(IndicatorDispatchError::DataLengthMismatch {
432        details: format!("{indicator}: expected equal lengths, got {a} and {b}"),
433    })
434}
435
436fn ensure_same_len_4(
437    indicator: &str,
438    a: usize,
439    b: usize,
440    c: usize,
441    d: usize,
442) -> Result<(), IndicatorDispatchError> {
443    if a == b && b == c && c == d {
444        return Ok(());
445    }
446    Err(IndicatorDispatchError::DataLengthMismatch {
447        details: format!("{indicator}: expected equal lengths, got {a}, {b}, {c}, {d}"),
448    })
449}
450
451fn ensure_same_len_5(
452    indicator: &str,
453    a: usize,
454    b: usize,
455    c: usize,
456    d: usize,
457    e: usize,
458) -> Result<(), IndicatorDispatchError> {
459    if a == b && b == c && c == d && d == e {
460        return Ok(());
461    }
462    Err(IndicatorDispatchError::DataLengthMismatch {
463        details: format!("{indicator}: expected equal lengths, got {a}, {b}, {c}, {d}, {e}"),
464    })
465}
466
467fn f64_output(output_id: &str, rows: usize, cols: usize, values: Vec<f64>) -> IndicatorBatchOutput {
468    IndicatorBatchOutput {
469        output_id: output_id.to_string(),
470        rows,
471        cols,
472        values_f64: Some(values),
473        values_i32: None,
474        values_bool: None,
475    }
476}
477
/// Executes a compiled call through the CUDA dispatcher.
///
/// The stored indicator id, output id and parameters are forwarded to
/// `compute_cuda` together with the caller's `data`, `kernel` and `target`.
#[cfg(feature = "cuda")]
pub fn run_compiled_cuda(
    call: &CompiledIndicatorCall,
    data: IndicatorCudaDataRef<'_>,
    kernel: Kernel,
    target: CudaOutputTarget,
) -> Result<IndicatorCudaOutput, IndicatorDispatchError> {
    let borrowed_params = call.as_param_kv();
    let request = IndicatorCudaRequest {
        indicator_id: call.indicator_id.as_str(),
        output_id: call.output_id.as_deref(),
        data,
        params: borrowed_params.as_slice(),
        kernel,
        target,
    };
    compute_cuda(request)
}
495
#[cfg(test)]
mod tests {
    use super::*;
    use crate::indicators::dispatch::{compute_cpu_batch, IndicatorBatchRequest};

    /// Ascending series `1.0, 2.0, ..., 128.0`.
    fn sample_series() -> Vec<f64> {
        (1_i32..=128).map(f64::from).collect()
    }

    /// 128 synthetic bars with these relationships to `open`:
    /// - `open` ramps up from 100.0 in steps of 0.1;
    /// - `high`/`low` sit one unit above/below `open`;
    /// - `close` is `open + 0.25`.
    fn sample_ohlc() -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
        let open: Vec<f64> = (0..128).map(|i| 100.0 + (i as f64 * 0.1)).collect();
        let derive = |f: fn(f64) -> f64| -> Vec<f64> { open.iter().copied().map(f).collect() };
        let high = derive(|o| o + 1.0);
        let low = derive(|o| o - 1.0);
        let close = derive(|o| o + 0.25);
        (open, high, low, close)
    }

    /// Parameter list holding a single integer `period`.
    fn period_param(period: i64) -> [ParamKV<'static>; 1] {
        [ParamKV {
            key: "period",
            value: ParamValue::Int(period),
        }]
    }

    /// Asserts that two batch outputs have the same shape and element-wise
    /// equal `f64` values (|diff| <= 1e-12, with NaN matching NaN).
    fn assert_same_values(compiled: IndicatorBatchOutput, direct: IndicatorBatchOutput) {
        assert_eq!(compiled.rows, direct.rows);
        assert_eq!(compiled.cols, direct.cols);
        let lhs = compiled.values_f64.unwrap();
        let rhs = direct.values_f64.unwrap();
        assert_eq!(lhs.len(), rhs.len());
        for (i, (x, y)) in lhs.iter().zip(rhs.iter()).enumerate() {
            if x.is_nan() && y.is_nan() {
                continue;
            }
            assert!((x - y).abs() <= 1e-12, "mismatch at index {i}: {x} vs {y}");
        }
    }

    #[test]
    fn compile_rejects_unknown_indicator() {
        let err = compile_call("does_not_exist", Some("value"), &[], false).unwrap_err();
        match err {
            IndicatorDispatchError::UnknownIndicator { id } => assert_eq!(id, "does_not_exist"),
            other => panic!("expected UnknownIndicator, got {other:?}"),
        }
    }

    #[test]
    fn compile_validates_output_id() {
        let err = compile_call("sma", Some("hist"), &[], false).unwrap_err();
        match err {
            IndicatorDispatchError::UnknownOutput { indicator, output } => {
                assert_eq!(indicator, "sma");
                assert_eq!(output, "hist");
            }
            other => panic!("expected UnknownOutput, got {other:?}"),
        }
    }

    #[test]
    fn run_compiled_cpu_matches_direct_dispatch() {
        let data = sample_series();
        let params = period_param(14);
        let slice = || IndicatorDataRef::Slice { values: &data };

        let call = compile_call("sma", Some("value"), &params, false).unwrap();
        let compiled = run_compiled_cpu(&call, slice(), Kernel::Auto).unwrap();

        let combos = [IndicatorParamSet { params: &params }];
        let direct = compute_cpu_batch(IndicatorBatchRequest {
            indicator_id: "sma",
            output_id: Some("value"),
            data: slice(),
            combos: &combos,
            kernel: Kernel::Auto,
        })
        .unwrap();

        assert_eq!(compiled.output_id, direct.output_id);
        assert_same_values(compiled, direct);
    }

    #[test]
    fn compile_pre_resolves_sma_period_plan() {
        let params = period_param(9);
        let call = compile_call("sma", Some("value"), &params, false).unwrap();
        assert_eq!(call.cpu_plan, CpuCompiledPlan::SmaPeriod { period: 9 });
    }

    #[test]
    fn compile_falls_back_to_generic_when_params_are_not_period_only() {
        let params = [
            ParamKV {
                key: "period",
                value: ParamValue::Int(9),
            },
            ParamKV {
                key: "unused",
                value: ParamValue::Float(1.0),
            },
        ];
        let call = compile_call("mfi", Some("value"), &params, false).unwrap();
        assert_eq!(call.cpu_plan, CpuCompiledPlan::Generic);
    }

    #[test]
    fn run_compiled_mfi_fast_plan_matches_dispatch() {
        let (open, high, low, close) = sample_ohlc();
        let volume: Vec<f64> = (0..close.len()).map(|i| 1000.0 + (i as f64)).collect();
        let params = period_param(14);
        let ohlcv = || IndicatorDataRef::Ohlcv {
            open: &open,
            high: &high,
            low: &low,
            close: &close,
            volume: &volume,
        };

        let call = compile_call("mfi", Some("value"), &params, false).unwrap();
        let compiled = run_compiled_cpu(&call, ohlcv(), Kernel::Auto).unwrap();

        let combos = [IndicatorParamSet { params: &params }];
        let direct = compute_cpu_batch(IndicatorBatchRequest {
            indicator_id: "mfi",
            output_id: Some("value"),
            data: ohlcv(),
            combos: &combos,
            kernel: Kernel::Auto,
        })
        .unwrap();

        assert_same_values(compiled, direct);
    }

    #[test]
    fn run_compiled_dx_fast_plan_matches_dispatch() {
        let (open, high, low, close) = sample_ohlc();
        let params = period_param(14);
        let ohlc = || IndicatorDataRef::Ohlc {
            open: &open,
            high: &high,
            low: &low,
            close: &close,
        };

        let call = compile_call("dx", Some("value"), &params, false).unwrap();
        let compiled = run_compiled_cpu(&call, ohlc(), Kernel::Auto).unwrap();

        let combos = [IndicatorParamSet { params: &params }];
        let direct = compute_cpu_batch(IndicatorBatchRequest {
            indicator_id: "dx",
            output_id: Some("value"),
            data: ohlc(),
            combos: &combos,
            kernel: Kernel::Auto,
        })
        .unwrap();

        assert_same_values(compiled, direct);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn compile_prefer_cuda_rejects_non_cuda_indicator() {
        let err = compile_call("adx", Some("value"), &[], true).unwrap_err();
        match err {
            IndicatorDispatchError::UnsupportedCapability {
                indicator,
                capability,
            } => {
                assert_eq!(indicator, "adx");
                assert_eq!(capability, "cuda_batch");
            }
            other => panic!("expected UnsupportedCapability, got {other:?}"),
        }
    }
}
717}