Skip to main content

vector_ta/indicators/dispatch/
cpu_single.rs

1use super::{
2    compute_cpu_batch, IndicatorBatchOutput, IndicatorBatchRequest, IndicatorComputeOutput,
3    IndicatorComputeRequest, IndicatorDataRef, IndicatorDispatchError, IndicatorParamSet,
4    IndicatorSeries,
5};
6use crate::indicators::pattern_recognition::{
7    pattern_recognition_with_kernel, PatternRecognitionData, PatternRecognitionError,
8    PatternRecognitionInput,
9};
10use crate::indicators::registry::{get_indicator, IndicatorInfo, IndicatorInputKind};
11
12pub fn compute_cpu(
13    req: IndicatorComputeRequest<'_>,
14) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
15    let info = get_indicator(req.indicator_id);
16
17    if let Some(info) = info {
18        if info.id.eq_ignore_ascii_case("pattern_recognition") {
19            return compute_pattern_recognition(req, info);
20        }
21    }
22
23    let indicator_id = info.map_or(req.indicator_id, |info| info.id);
24
25    let combos = [IndicatorParamSet { params: req.params }];
26    let out = match compute_cpu_batch(IndicatorBatchRequest {
27        indicator_id,
28        output_id: req.output_id,
29        data: req.data,
30        combos: &combos,
31        kernel: req.kernel,
32    }) {
33        Err(IndicatorDispatchError::UnsupportedCapability { .. }) if info.is_none() => {
34            return Err(IndicatorDispatchError::UnknownIndicator {
35                id: req.indicator_id.to_string(),
36            });
37        }
38        other => other?,
39    };
40    map_batch_output_to_compute(indicator_id, out)
41}
42
43fn compute_pattern_recognition(
44    req: IndicatorComputeRequest<'_>,
45    info: &IndicatorInfo,
46) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
47    if !info.capabilities.supports_cpu_single {
48        return Err(IndicatorDispatchError::UnsupportedCapability {
49            indicator: info.id.to_string(),
50            capability: "cpu_single",
51        });
52    }
53
54    if let Some(param) = req.params.first() {
55        return Err(IndicatorDispatchError::InvalidParam {
56            indicator: info.id.to_string(),
57            key: param.key.to_string(),
58            reason: "pattern_recognition does not accept parameters".to_string(),
59        });
60    }
61
62    let output_id = resolve_output_id(info, req.output_id)?;
63    let input = match req.data {
64        IndicatorDataRef::Candles { candles, .. } => PatternRecognitionInput::from_candles(
65            candles,
66            crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
67        ),
68        IndicatorDataRef::Ohlc {
69            open,
70            high,
71            low,
72            close,
73        } => PatternRecognitionInput::from_slices(
74            open,
75            high,
76            low,
77            close,
78            crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
79        ),
80        IndicatorDataRef::Ohlcv {
81            open,
82            high,
83            low,
84            close,
85            ..
86        } => PatternRecognitionInput::from_slices(
87            open,
88            high,
89            low,
90            close,
91            crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
92        ),
93        _ => {
94            return Err(IndicatorDispatchError::MissingRequiredInput {
95                indicator: info.id.to_string(),
96                input: IndicatorInputKind::Ohlc,
97            });
98        }
99    };
100
101    let out = pattern_recognition_with_kernel(&input, req.kernel)
102        .map_err(|e| map_pattern_error(info.id, e))?;
103
104    Ok(IndicatorComputeOutput {
105        output_id: output_id.to_string(),
106        series: IndicatorSeries::Bool(out.values_u8.into_iter().map(|v| v != 0).collect()),
107        warmup: out.warmup,
108        rows: out.rows,
109        cols: out.cols,
110        pattern_ids: Some(
111            out.pattern_ids
112                .into_iter()
113                .map(|id| id.to_string())
114                .collect(),
115        ),
116    })
117}
118
119fn resolve_output_id<'a>(
120    info: &'a IndicatorInfo,
121    requested: Option<&str>,
122) -> Result<&'a str, IndicatorDispatchError> {
123    if info.outputs.is_empty() {
124        return Err(IndicatorDispatchError::ComputeFailed {
125            indicator: info.id.to_string(),
126            details: "indicator has no registered outputs".to_string(),
127        });
128    }
129
130    if info.outputs.len() == 1 {
131        let only = info.outputs[0].id;
132        if let Some(req) = requested {
133            if !req.eq_ignore_ascii_case(only) {
134                return Err(IndicatorDispatchError::UnknownOutput {
135                    indicator: info.id.to_string(),
136                    output: req.to_string(),
137                });
138            }
139        }
140        return Ok(only);
141    }
142
143    let req = requested.ok_or_else(|| IndicatorDispatchError::InvalidParam {
144        indicator: info.id.to_string(),
145        key: "output_id".to_string(),
146        reason: "output_id is required for multi-output indicators".to_string(),
147    })?;
148
149    info.outputs
150        .iter()
151        .find(|o| o.id.eq_ignore_ascii_case(req))
152        .map(|o| o.id)
153        .ok_or_else(|| IndicatorDispatchError::UnknownOutput {
154            indicator: info.id.to_string(),
155            output: req.to_string(),
156        })
157}
158
159fn map_batch_output_to_compute(
160    indicator: &str,
161    out: IndicatorBatchOutput,
162) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
163    let series = if let Some(values) = out.values_f64 {
164        IndicatorSeries::F64(values)
165    } else if let Some(values) = out.values_i32 {
166        IndicatorSeries::I32(values)
167    } else if let Some(values) = out.values_bool {
168        IndicatorSeries::Bool(values)
169    } else {
170        return Err(IndicatorDispatchError::ComputeFailed {
171            indicator: indicator.to_string(),
172            details: "dispatcher returned no output series".to_string(),
173        });
174    };
175
176    Ok(IndicatorComputeOutput {
177        output_id: out.output_id,
178        series,
179        warmup: None,
180        rows: out.rows,
181        cols: out.cols,
182        pattern_ids: None,
183    })
184}
185
186fn map_pattern_error(indicator: &str, err: PatternRecognitionError) -> IndicatorDispatchError {
187    match err {
188        PatternRecognitionError::DataLengthMismatch {
189            open,
190            high,
191            low,
192            close,
193        } => IndicatorDispatchError::DataLengthMismatch {
194            details: format!("open={} high={} low={} close={}", open, high, low, close),
195        },
196        PatternRecognitionError::OutputLengthMismatch {
197            pattern_id,
198            expected,
199            got,
200        } => IndicatorDispatchError::ComputeFailed {
201            indicator: indicator.to_string(),
202            details: format!(
203                "pattern output mismatch for {}: expected {}, got {}",
204                pattern_id, expected, got
205            ),
206        },
207        PatternRecognitionError::Pattern(e) => IndicatorDispatchError::ComputeFailed {
208            indicator: indicator.to_string(),
209            details: e.to_string(),
210        },
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::indicators::dispatch::{compute_cpu_batch, ParamKV, ParamValue};
    use crate::indicators::pattern_recognition::list_patterns;
    use crate::utilities::enums::Kernel;

    /// Smooth synthetic series of `len` points: `100 + sin(0.01 i) + cos(0.0005 i)`.
    fn sample_series(len: usize) -> Vec<f64> {
        let mut values = Vec::with_capacity(len);
        for i in 0..len {
            let x = i as f64;
            values.push(100.0 + (x * 0.01).sin() + (x * 0.0005).cos());
        }
        values
    }

    /// OHLC fixture built on `sample_series`.
    ///
    /// `high`/`low` sit 1.0 above/below `open`, and `close` is `open + 0.25`.
    fn sample_ohlc(len: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
        let open = sample_series(len);
        let shifted = |delta: f64| -> Vec<f64> { open.iter().map(|v| v + delta).collect() };
        let high = shifted(1.0);
        let low = shifted(-1.0);
        let close = shifted(0.25);
        (open, high, low, close)
    }

    #[test]
    fn compute_cpu_pattern_recognition_returns_matrix() {
        let (open, high, low, close) = sample_ohlc(192);
        let data = IndicatorDataRef::Ohlc {
            open: &open,
            high: &high,
            low: &low,
            close: &close,
        };
        let result = compute_cpu(IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data,
            params: &[],
            kernel: Kernel::Auto,
        })
        .unwrap();

        assert_eq!(result.output_id, "matrix");
        assert_eq!(result.rows, list_patterns().len());
        assert_eq!(result.cols, close.len());

        let expected_cells = result.rows * result.cols;
        match result.series {
            IndicatorSeries::Bool(cells) => assert_eq!(cells.len(), expected_cells),
            other => panic!("expected Bool matrix series, got {:?}", other),
        }

        let ids = result.pattern_ids.unwrap();
        assert_eq!(ids.len(), result.rows);
    }

    #[test]
    fn compute_cpu_pattern_recognition_rejects_missing_input_shape() {
        let series = sample_series(64);
        let error = compute_cpu(IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Slice { values: &series },
            params: &[],
            kernel: Kernel::Auto,
        })
        .unwrap_err();

        match error {
            IndicatorDispatchError::MissingRequiredInput { indicator, input } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(input, IndicatorInputKind::Ohlc);
            }
            other => panic!("expected MissingRequiredInput, got {:?}", other),
        }
    }

    #[test]
    fn compute_cpu_pattern_recognition_rejects_unknown_output() {
        let (open, high, low, close) = sample_ohlc(64);
        let data = IndicatorDataRef::Ohlc {
            open: &open,
            high: &high,
            low: &low,
            close: &close,
        };
        let error = compute_cpu(IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("value"),
            data,
            params: &[],
            kernel: Kernel::Auto,
        })
        .unwrap_err();

        match error {
            IndicatorDispatchError::UnknownOutput { indicator, output } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(output, "value");
            }
            other => panic!("expected UnknownOutput, got {:?}", other),
        }
    }

    #[test]
    fn compute_cpu_pattern_recognition_rejects_params() {
        let (open, high, low, close) = sample_ohlc(64);
        let period = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let data = IndicatorDataRef::Ohlc {
            open: &open,
            high: &high,
            low: &low,
            close: &close,
        };
        let error = compute_cpu(IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data,
            params: &period,
            kernel: Kernel::Auto,
        })
        .unwrap_err();

        match error {
            IndicatorDispatchError::InvalidParam {
                indicator,
                key,
                reason,
            } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(key, "period");
                assert!(reason.contains("does not accept parameters"));
            }
            other => panic!("expected InvalidParam, got {:?}", other),
        }
    }

    #[test]
    fn pattern_recognition_batch_mode_is_explicitly_unsupported() {
        let (open, high, low, close) = sample_ohlc(96);
        let grid = [IndicatorParamSet { params: &[] }];
        let data = IndicatorDataRef::Ohlc {
            open: &open,
            high: &high,
            low: &low,
            close: &close,
        };
        let error = compute_cpu_batch(IndicatorBatchRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data,
            combos: &grid,
            kernel: Kernel::Auto,
        })
        .unwrap_err();

        match error {
            IndicatorDispatchError::UnsupportedCapability {
                indicator,
                capability,
            } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(capability, "cpu_batch");
            }
            other => panic!("expected UnsupportedCapability, got {:?}", other),
        }
    }

    #[test]
    fn compute_cpu_for_sma_delegates_to_batch_dispatch() {
        let series = sample_series(200);
        let period = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let result = compute_cpu(IndicatorComputeRequest {
            indicator_id: "sma",
            output_id: Some("value"),
            data: IndicatorDataRef::Slice { values: &series },
            params: &period,
            kernel: Kernel::Auto,
        })
        .unwrap();

        assert_eq!(result.output_id, "value");
        assert_eq!(result.rows, 1);
        assert_eq!(result.cols, series.len());
        match result.series {
            IndicatorSeries::F64(values) => assert_eq!(values.len(), series.len()),
            other => panic!("expected F64 series, got {:?}", other),
        }
    }
}
393}