Skip to main content

vector_ta/indicators/dispatch/
cpu_single.rs

1use super::{
2    compute_cpu_batch, IndicatorBatchOutput, IndicatorBatchRequest, IndicatorComputeOutput,
3    IndicatorComputeRequest, IndicatorDataRef, IndicatorDispatchError, IndicatorParamSet,
4    IndicatorSeries,
5};
6use crate::indicators::pattern_recognition::{
7    pattern_recognition_with_kernel, PatternRecognitionData, PatternRecognitionError,
8    PatternRecognitionInput,
9};
10use crate::indicators::registry::{get_indicator, IndicatorInfo, IndicatorInputKind};
11
12pub fn compute_cpu(
13    req: IndicatorComputeRequest<'_>,
14) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
15    let info = get_indicator(req.indicator_id).ok_or_else(|| {
16        IndicatorDispatchError::UnknownIndicator {
17            id: req.indicator_id.to_string(),
18        }
19    })?;
20
21    if info.id.eq_ignore_ascii_case("pattern_recognition") {
22        return compute_pattern_recognition(req, info);
23    }
24
25    if !info.capabilities.supports_cpu_single {
26        return Err(IndicatorDispatchError::UnsupportedCapability {
27            indicator: info.id.to_string(),
28            capability: "cpu_single",
29        });
30    }
31
32    if !info.capabilities.supports_cpu_batch {
33        return Err(IndicatorDispatchError::UnsupportedCapability {
34            indicator: info.id.to_string(),
35            capability: "cpu_single",
36        });
37    }
38
39    let combos = [IndicatorParamSet { params: req.params }];
40    let out = compute_cpu_batch(IndicatorBatchRequest {
41        indicator_id: info.id,
42        output_id: req.output_id,
43        data: req.data,
44        combos: &combos,
45        kernel: req.kernel,
46    })?;
47    map_batch_output_to_compute(info.id, out)
48}
49
50fn compute_pattern_recognition(
51    req: IndicatorComputeRequest<'_>,
52    info: &IndicatorInfo,
53) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
54    if !info.capabilities.supports_cpu_single {
55        return Err(IndicatorDispatchError::UnsupportedCapability {
56            indicator: info.id.to_string(),
57            capability: "cpu_single",
58        });
59    }
60
61    if let Some(param) = req.params.first() {
62        return Err(IndicatorDispatchError::InvalidParam {
63            indicator: info.id.to_string(),
64            key: param.key.to_string(),
65            reason: "pattern_recognition does not accept parameters".to_string(),
66        });
67    }
68
69    let output_id = resolve_output_id(info, req.output_id)?;
70    let input = match req.data {
71        IndicatorDataRef::Candles { candles, .. } => PatternRecognitionInput::from_candles(
72            candles,
73            crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
74        ),
75        IndicatorDataRef::Ohlc {
76            open,
77            high,
78            low,
79            close,
80        } => PatternRecognitionInput::from_slices(
81            open,
82            high,
83            low,
84            close,
85            crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
86        ),
87        IndicatorDataRef::Ohlcv {
88            open,
89            high,
90            low,
91            close,
92            ..
93        } => PatternRecognitionInput::from_slices(
94            open,
95            high,
96            low,
97            close,
98            crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
99        ),
100        _ => {
101            return Err(IndicatorDispatchError::MissingRequiredInput {
102                indicator: info.id.to_string(),
103                input: IndicatorInputKind::Ohlc,
104            });
105        }
106    };
107
108    let out = pattern_recognition_with_kernel(&input, req.kernel)
109        .map_err(|e| map_pattern_error(info.id, e))?;
110
111    Ok(IndicatorComputeOutput {
112        output_id: output_id.to_string(),
113        series: IndicatorSeries::Bool(out.values_u8.into_iter().map(|v| v != 0).collect()),
114        warmup: out.warmup,
115        rows: out.rows,
116        cols: out.cols,
117        pattern_ids: Some(out.pattern_ids.into_iter().map(|id| id.to_string()).collect()),
118    })
119}
120
121fn resolve_output_id<'a>(
122    info: &'a IndicatorInfo,
123    requested: Option<&str>,
124) -> Result<&'a str, IndicatorDispatchError> {
125    if info.outputs.is_empty() {
126        return Err(IndicatorDispatchError::ComputeFailed {
127            indicator: info.id.to_string(),
128            details: "indicator has no registered outputs".to_string(),
129        });
130    }
131
132    if info.outputs.len() == 1 {
133        let only = info.outputs[0].id;
134        if let Some(req) = requested {
135            if !req.eq_ignore_ascii_case(only) {
136                return Err(IndicatorDispatchError::UnknownOutput {
137                    indicator: info.id.to_string(),
138                    output: req.to_string(),
139                });
140            }
141        }
142        return Ok(only);
143    }
144
145    let req = requested.ok_or_else(|| IndicatorDispatchError::InvalidParam {
146        indicator: info.id.to_string(),
147        key: "output_id".to_string(),
148        reason: "output_id is required for multi-output indicators".to_string(),
149    })?;
150
151    info.outputs
152        .iter()
153        .find(|o| o.id.eq_ignore_ascii_case(req))
154        .map(|o| o.id)
155        .ok_or_else(|| IndicatorDispatchError::UnknownOutput {
156            indicator: info.id.to_string(),
157            output: req.to_string(),
158        })
159}
160
161fn map_batch_output_to_compute(
162    indicator: &str,
163    out: IndicatorBatchOutput,
164) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
165    let series = if let Some(values) = out.values_f64 {
166        IndicatorSeries::F64(values)
167    } else if let Some(values) = out.values_i32 {
168        IndicatorSeries::I32(values)
169    } else if let Some(values) = out.values_bool {
170        IndicatorSeries::Bool(values)
171    } else {
172        return Err(IndicatorDispatchError::ComputeFailed {
173            indicator: indicator.to_string(),
174            details: "dispatcher returned no output series".to_string(),
175        });
176    };
177
178    Ok(IndicatorComputeOutput {
179        output_id: out.output_id,
180        series,
181        warmup: None,
182        rows: out.rows,
183        cols: out.cols,
184        pattern_ids: None,
185    })
186}
187
188fn map_pattern_error(indicator: &str, err: PatternRecognitionError) -> IndicatorDispatchError {
189    match err {
190        PatternRecognitionError::DataLengthMismatch {
191            open,
192            high,
193            low,
194            close,
195        } => IndicatorDispatchError::DataLengthMismatch {
196            details: format!("open={} high={} low={} close={}", open, high, low, close),
197        },
198        PatternRecognitionError::OutputLengthMismatch {
199            pattern_id,
200            expected,
201            got,
202        } => IndicatorDispatchError::ComputeFailed {
203            indicator: indicator.to_string(),
204            details: format!(
205                "pattern output mismatch for {}: expected {}, got {}",
206                pattern_id, expected, got
207            ),
208        },
209        PatternRecognitionError::Pattern(e) => IndicatorDispatchError::ComputeFailed {
210            indicator: indicator.to_string(),
211            details: e.to_string(),
212        },
213    }
214}
215
#[cfg(test)]
mod tests {
    // Tests for the CPU single-compute dispatcher: the dedicated
    // pattern_recognition path and the generic delegate-to-batch path.
    use super::*;
    use crate::indicators::dispatch::{compute_cpu_batch, ParamKV, ParamValue};
    use crate::indicators::pattern_recognition::list_patterns;
    use crate::utilities::enums::Kernel;

    // Smooth, deterministic price-like series around 100.0.
    fn sample_series(len: usize) -> Vec<f64> {
        (0..len)
            .map(|i| 100.0 + ((i as f64) * 0.01).sin() + ((i as f64) * 0.0005).cos())
            .collect()
    }

    // OHLC columns derived from `sample_series` with fixed offsets.
    fn sample_ohlc(len: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
        let open = sample_series(len);
        let high: Vec<f64> = open.iter().map(|v| v + 1.0).collect();
        let low: Vec<f64> = open.iter().map(|v| v - 1.0).collect();
        let close: Vec<f64> = open.iter().map(|v| v + 0.25).collect();
        (open, high, low, close)
    }

    #[test]
    fn compute_cpu_pattern_recognition_returns_matrix() {
        let (open, high, low, close) = sample_ohlc(192);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &[],
            kernel: Kernel::Auto,
        };
        let out = compute_cpu(req).unwrap();
        assert_eq!(out.output_id, "matrix");
        assert_eq!(out.rows, list_patterns().len());
        assert_eq!(out.cols, close.len());
        match out.series {
            IndicatorSeries::Bool(v) => assert_eq!(v.len(), out.rows * out.cols),
            other => panic!("expected Bool matrix series, got {:?}", other),
        }
        let ids = out.pattern_ids.unwrap();
        assert_eq!(ids.len(), out.rows);
    }

    #[test]
    fn compute_cpu_pattern_recognition_matches_output_id_case_insensitively() {
        // Same fixture size as the matrix test above, so only the output-id
        // spelling differs.
        let (open, high, low, close) = sample_ohlc(192);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("MATRIX"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &[],
            kernel: Kernel::Auto,
        };
        let out = compute_cpu(req).unwrap();
        // The registry's spelling is returned, not the caller's.
        assert_eq!(out.output_id, "matrix");
    }

    #[test]
    fn compute_cpu_pattern_recognition_rejects_missing_input_shape() {
        let series = sample_series(64);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Slice { values: &series },
            params: &[],
            kernel: Kernel::Auto,
        };
        let err = compute_cpu(req).unwrap_err();
        match err {
            IndicatorDispatchError::MissingRequiredInput { indicator, input } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(input, IndicatorInputKind::Ohlc);
            }
            other => panic!("expected MissingRequiredInput, got {:?}", other),
        }
    }

    #[test]
    fn compute_cpu_pattern_recognition_rejects_unknown_output() {
        let (open, high, low, close) = sample_ohlc(64);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("value"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &[],
            kernel: Kernel::Auto,
        };
        let err = compute_cpu(req).unwrap_err();
        match err {
            IndicatorDispatchError::UnknownOutput { indicator, output } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(output, "value");
            }
            other => panic!("expected UnknownOutput, got {:?}", other),
        }
    }

    #[test]
    fn compute_cpu_pattern_recognition_rejects_params() {
        let (open, high, low, close) = sample_ohlc(64);
        let params = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &params,
            kernel: Kernel::Auto,
        };
        let err = compute_cpu(req).unwrap_err();
        match err {
            IndicatorDispatchError::InvalidParam {
                indicator,
                key,
                reason,
            } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(key, "period");
                assert!(reason.contains("does not accept parameters"));
            }
            other => panic!("expected InvalidParam, got {:?}", other),
        }
    }

    // The pattern matrix is only reachable through single compute; the batch
    // dispatcher must refuse it with a `cpu_batch` capability error.
    #[test]
    fn pattern_recognition_batch_mode_is_explicitly_unsupported() {
        let (open, high, low, close) = sample_ohlc(96);
        let combos = [IndicatorParamSet { params: &[] }];
        let err = compute_cpu_batch(IndicatorBatchRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            combos: &combos,
            kernel: Kernel::Auto,
        })
        .unwrap_err();
        match err {
            IndicatorDispatchError::UnsupportedCapability {
                indicator,
                capability,
            } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(capability, "cpu_batch");
            }
            other => panic!("expected UnsupportedCapability, got {:?}", other),
        }
    }

    // Generic indicators are served by a one-combo batch run, hence one row.
    #[test]
    fn compute_cpu_for_sma_delegates_to_batch_dispatch() {
        let series = sample_series(200);
        let params = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let req = IndicatorComputeRequest {
            indicator_id: "sma",
            output_id: Some("value"),
            data: IndicatorDataRef::Slice { values: &series },
            params: &params,
            kernel: Kernel::Auto,
        };
        let out = compute_cpu(req).unwrap();
        assert_eq!(out.output_id, "value");
        assert_eq!(out.rows, 1);
        assert_eq!(out.cols, series.len());
        match out.series {
            IndicatorSeries::F64(v) => assert_eq!(v.len(), series.len()),
            other => panic!("expected F64 series, got {:?}", other),
        }
    }
}