Skip to main content

vector_ta/indicators/dispatch/cpu_single.rs

1use super::{
2    compute_cpu_batch, IndicatorBatchOutput, IndicatorBatchRequest, IndicatorComputeOutput,
3    IndicatorComputeRequest, IndicatorDataRef, IndicatorDispatchError, IndicatorParamSet,
4    IndicatorSeries,
5};
6use crate::indicators::pattern_recognition::{
7    pattern_recognition_with_kernel, PatternRecognitionData, PatternRecognitionError,
8    PatternRecognitionInput,
9};
10use crate::indicators::registry::{get_indicator, IndicatorInfo, IndicatorInputKind};
11
12pub fn compute_cpu(
13    req: IndicatorComputeRequest<'_>,
14) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
15    let info = get_indicator(req.indicator_id).ok_or_else(|| {
16        IndicatorDispatchError::UnknownIndicator {
17            id: req.indicator_id.to_string(),
18        }
19    })?;
20
21    if info.id.eq_ignore_ascii_case("pattern_recognition") {
22        return compute_pattern_recognition(req, info);
23    }
24
25    if !info.capabilities.supports_cpu_single {
26        return Err(IndicatorDispatchError::UnsupportedCapability {
27            indicator: info.id.to_string(),
28            capability: "cpu_single",
29        });
30    }
31
32    if !info.capabilities.supports_cpu_batch {
33        return Err(IndicatorDispatchError::UnsupportedCapability {
34            indicator: info.id.to_string(),
35            capability: "cpu_single",
36        });
37    }
38
39    let combos = [IndicatorParamSet { params: req.params }];
40    let out = compute_cpu_batch(IndicatorBatchRequest {
41        indicator_id: info.id,
42        output_id: req.output_id,
43        data: req.data,
44        combos: &combos,
45        kernel: req.kernel,
46    })?;
47    map_batch_output_to_compute(info.id, out)
48}
49
50fn compute_pattern_recognition(
51    req: IndicatorComputeRequest<'_>,
52    info: &IndicatorInfo,
53) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
54    if !info.capabilities.supports_cpu_single {
55        return Err(IndicatorDispatchError::UnsupportedCapability {
56            indicator: info.id.to_string(),
57            capability: "cpu_single",
58        });
59    }
60
61    if let Some(param) = req.params.first() {
62        return Err(IndicatorDispatchError::InvalidParam {
63            indicator: info.id.to_string(),
64            key: param.key.to_string(),
65            reason: "pattern_recognition does not accept parameters".to_string(),
66        });
67    }
68
69    let output_id = resolve_output_id(info, req.output_id)?;
70    let input = match req.data {
71        IndicatorDataRef::Candles { candles, .. } => PatternRecognitionInput::from_candles(
72            candles,
73            crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
74        ),
75        IndicatorDataRef::Ohlc {
76            open,
77            high,
78            low,
79            close,
80        } => PatternRecognitionInput::from_slices(
81            open,
82            high,
83            low,
84            close,
85            crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
86        ),
87        IndicatorDataRef::Ohlcv {
88            open,
89            high,
90            low,
91            close,
92            ..
93        } => PatternRecognitionInput::from_slices(
94            open,
95            high,
96            low,
97            close,
98            crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
99        ),
100        _ => {
101            return Err(IndicatorDispatchError::MissingRequiredInput {
102                indicator: info.id.to_string(),
103                input: IndicatorInputKind::Ohlc,
104            });
105        }
106    };
107
108    let out = pattern_recognition_with_kernel(&input, req.kernel)
109        .map_err(|e| map_pattern_error(info.id, e))?;
110
111    Ok(IndicatorComputeOutput {
112        output_id: output_id.to_string(),
113        series: IndicatorSeries::Bool(out.values_u8.into_iter().map(|v| v != 0).collect()),
114        warmup: out.warmup,
115        rows: out.rows,
116        cols: out.cols,
117        pattern_ids: Some(
118            out.pattern_ids
119                .into_iter()
120                .map(|id| id.to_string())
121                .collect(),
122        ),
123    })
124}
125
126fn resolve_output_id<'a>(
127    info: &'a IndicatorInfo,
128    requested: Option<&str>,
129) -> Result<&'a str, IndicatorDispatchError> {
130    if info.outputs.is_empty() {
131        return Err(IndicatorDispatchError::ComputeFailed {
132            indicator: info.id.to_string(),
133            details: "indicator has no registered outputs".to_string(),
134        });
135    }
136
137    if info.outputs.len() == 1 {
138        let only = info.outputs[0].id;
139        if let Some(req) = requested {
140            if !req.eq_ignore_ascii_case(only) {
141                return Err(IndicatorDispatchError::UnknownOutput {
142                    indicator: info.id.to_string(),
143                    output: req.to_string(),
144                });
145            }
146        }
147        return Ok(only);
148    }
149
150    let req = requested.ok_or_else(|| IndicatorDispatchError::InvalidParam {
151        indicator: info.id.to_string(),
152        key: "output_id".to_string(),
153        reason: "output_id is required for multi-output indicators".to_string(),
154    })?;
155
156    info.outputs
157        .iter()
158        .find(|o| o.id.eq_ignore_ascii_case(req))
159        .map(|o| o.id)
160        .ok_or_else(|| IndicatorDispatchError::UnknownOutput {
161            indicator: info.id.to_string(),
162            output: req.to_string(),
163        })
164}
165
166fn map_batch_output_to_compute(
167    indicator: &str,
168    out: IndicatorBatchOutput,
169) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
170    let series = if let Some(values) = out.values_f64 {
171        IndicatorSeries::F64(values)
172    } else if let Some(values) = out.values_i32 {
173        IndicatorSeries::I32(values)
174    } else if let Some(values) = out.values_bool {
175        IndicatorSeries::Bool(values)
176    } else {
177        return Err(IndicatorDispatchError::ComputeFailed {
178            indicator: indicator.to_string(),
179            details: "dispatcher returned no output series".to_string(),
180        });
181    };
182
183    Ok(IndicatorComputeOutput {
184        output_id: out.output_id,
185        series,
186        warmup: None,
187        rows: out.rows,
188        cols: out.cols,
189        pattern_ids: None,
190    })
191}
192
193fn map_pattern_error(indicator: &str, err: PatternRecognitionError) -> IndicatorDispatchError {
194    match err {
195        PatternRecognitionError::DataLengthMismatch {
196            open,
197            high,
198            low,
199            close,
200        } => IndicatorDispatchError::DataLengthMismatch {
201            details: format!("open={} high={} low={} close={}", open, high, low, close),
202        },
203        PatternRecognitionError::OutputLengthMismatch {
204            pattern_id,
205            expected,
206            got,
207        } => IndicatorDispatchError::ComputeFailed {
208            indicator: indicator.to_string(),
209            details: format!(
210                "pattern output mismatch for {}: expected {}, got {}",
211                pattern_id, expected, got
212            ),
213        },
214        PatternRecognitionError::Pattern(e) => IndicatorDispatchError::ComputeFailed {
215            indicator: indicator.to_string(),
216            details: e.to_string(),
217        },
218    }
219}
220
/// Unit tests for the single-computation CPU dispatch path (`compute_cpu`).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::indicators::dispatch::{compute_cpu_batch, ParamKV, ParamValue};
    use crate::indicators::pattern_recognition::list_patterns;
    use crate::utilities::enums::Kernel;

    /// Deterministic synthetic series of `len` values in roughly [98, 102].
    /// Each value is 100.0 plus a faster sine and a slower cosine of the index.
    fn sample_series(len: usize) -> Vec<f64> {
        (0..len)
            .map(|i| 100.0 + ((i as f64) * 0.01).sin() + ((i as f64) * 0.0005).cos())
            .collect()
    }

    /// Builds `(open, high, low, close)` slices of length `len` from `sample_series`:
    /// high = open + 1.0, low = open - 1.0, close = open + 0.25.
    fn sample_ohlc(len: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
        let open = sample_series(len);
        let high: Vec<f64> = open.iter().map(|v| v + 1.0).collect();
        let low: Vec<f64> = open.iter().map(|v| v - 1.0).collect();
        let close: Vec<f64> = open.iter().map(|v| v + 0.25).collect();
        (open, high, low, close)
    }

    /// A valid OHLC request returns a `Bool` matrix with one row per listed pattern,
    /// one column per bar, and exactly one pattern id per row.
    #[test]
    fn compute_cpu_pattern_recognition_returns_matrix() {
        let (open, high, low, close) = sample_ohlc(192);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &[],
            kernel: Kernel::Auto,
        };
        let out = compute_cpu(req).unwrap();
        assert_eq!(out.output_id, "matrix");
        assert_eq!(out.rows, list_patterns().len());
        assert_eq!(out.cols, close.len());
        match out.series {
            IndicatorSeries::Bool(v) => assert_eq!(v.len(), out.rows * out.cols),
            other => panic!("expected Bool matrix series, got {:?}", other),
        }
        let ids = out.pattern_ids.unwrap();
        assert_eq!(ids.len(), out.rows);
    }

    /// Single-series (`Slice`) data is rejected with `MissingRequiredInput(Ohlc)`.
    #[test]
    fn compute_cpu_pattern_recognition_rejects_missing_input_shape() {
        let series = sample_series(64);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Slice { values: &series },
            params: &[],
            kernel: Kernel::Auto,
        };
        let err = compute_cpu(req).unwrap_err();
        match err {
            IndicatorDispatchError::MissingRequiredInput { indicator, input } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(input, IndicatorInputKind::Ohlc);
            }
            other => panic!("expected MissingRequiredInput, got {:?}", other),
        }
    }

    /// An output name the indicator does not register (`"value"`) yields
    /// `UnknownOutput` echoing the requested name.
    #[test]
    fn compute_cpu_pattern_recognition_rejects_unknown_output() {
        let (open, high, low, close) = sample_ohlc(64);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("value"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &[],
            kernel: Kernel::Auto,
        };
        let err = compute_cpu(req).unwrap_err();
        match err {
            IndicatorDispatchError::UnknownOutput { indicator, output } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(output, "value");
            }
            other => panic!("expected UnknownOutput, got {:?}", other),
        }
    }

    /// Supplying any parameter is rejected with `InvalidParam` naming that key.
    #[test]
    fn compute_cpu_pattern_recognition_rejects_params() {
        let (open, high, low, close) = sample_ohlc(64);
        let params = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &params,
            kernel: Kernel::Auto,
        };
        let err = compute_cpu(req).unwrap_err();
        match err {
            IndicatorDispatchError::InvalidParam {
                indicator,
                key,
                reason,
            } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(key, "period");
                assert!(reason.contains("does not accept parameters"));
            }
            other => panic!("expected InvalidParam, got {:?}", other),
        }
    }

    /// The batch dispatcher refuses pattern_recognition and reports the
    /// `cpu_batch` capability as unsupported.
    #[test]
    fn pattern_recognition_batch_mode_is_explicitly_unsupported() {
        let (open, high, low, close) = sample_ohlc(96);
        let combos = [IndicatorParamSet { params: &[] }];
        let err = compute_cpu_batch(IndicatorBatchRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            combos: &combos,
            kernel: Kernel::Auto,
        })
        .unwrap_err();
        match err {
            IndicatorDispatchError::UnsupportedCapability {
                indicator,
                capability,
            } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(capability, "cpu_batch");
            }
            other => panic!("expected UnsupportedCapability, got {:?}", other),
        }
    }

    /// A regular indicator (`sma`) is computed through the batch dispatcher and
    /// comes back as a single-row `F64` series spanning the whole input.
    #[test]
    fn compute_cpu_for_sma_delegates_to_batch_dispatch() {
        let series = sample_series(200);
        let params = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let req = IndicatorComputeRequest {
            indicator_id: "sma",
            output_id: Some("value"),
            data: IndicatorDataRef::Slice { values: &series },
            params: &params,
            kernel: Kernel::Auto,
        };
        let out = compute_cpu(req).unwrap();
        assert_eq!(out.output_id, "value");
        assert_eq!(out.rows, 1);
        assert_eq!(out.cols, series.len());
        match out.series {
            IndicatorSeries::F64(v) => assert_eq!(v.len(), series.len()),
            other => panic!("expected F64 series, got {:?}", other),
        }
    }
}