Skip to main content

vector_ta/indicators/dispatch/
cpu_single.rs

1use super::{
2    compute_cpu_batch, IndicatorBatchOutput, IndicatorBatchRequest, IndicatorComputeOutput,
3    IndicatorComputeRequest, IndicatorDataRef, IndicatorDispatchError, IndicatorParamSet,
4    IndicatorSeries,
5};
6use crate::indicators::pattern_recognition::{
7    pattern_recognition_with_kernel, PatternRecognitionData, PatternRecognitionError,
8    PatternRecognitionInput,
9};
10use crate::indicators::registry::{get_indicator, IndicatorInfo, IndicatorInputKind};
11
12pub fn compute_cpu(
13    req: IndicatorComputeRequest<'_>,
14) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
15    let info = get_indicator(req.indicator_id);
16
17    if let Some(info) = info {
18        if info.id.eq_ignore_ascii_case("pattern_recognition") {
19            return compute_pattern_recognition(req, info);
20        }
21    }
22
23    let indicator_id = info.map_or(req.indicator_id, |info| info.id);
24
25    let combos = [IndicatorParamSet { params: req.params }];
26    let out = match compute_cpu_batch(IndicatorBatchRequest {
27        indicator_id,
28        output_id: req.output_id,
29        data: req.data,
30        combos: &combos,
31        kernel: req.kernel,
32    }) {
33        Err(IndicatorDispatchError::UnsupportedCapability { .. }) if info.is_none() => {
34            return Err(IndicatorDispatchError::UnknownIndicator {
35                id: req.indicator_id.to_string(),
36            });
37        }
38        other => other?,
39    };
40    map_batch_output_to_compute(indicator_id, out)
41}
42
43fn compute_pattern_recognition(
44    req: IndicatorComputeRequest<'_>,
45    info: &IndicatorInfo,
46) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
47    if !info.capabilities.supports_cpu_single {
48        return Err(IndicatorDispatchError::UnsupportedCapability {
49            indicator: info.id.to_string(),
50            capability: "cpu_single",
51        });
52    }
53
54    if let Some(param) = req.params.first() {
55        return Err(IndicatorDispatchError::InvalidParam {
56            indicator: info.id.to_string(),
57            key: param.key.to_string(),
58            reason: "pattern_recognition does not accept parameters".to_string(),
59        });
60    }
61
62    let output_id = resolve_output_id(info, req.output_id)?;
63    let input = match req.data {
64        IndicatorDataRef::Candles { candles, .. } => PatternRecognitionInput::from_candles(
65            candles,
66            crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
67        ),
68        IndicatorDataRef::Ohlc {
69            open,
70            high,
71            low,
72            close,
73        } => PatternRecognitionInput::from_slices(
74            open,
75            high,
76            low,
77            close,
78            crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
79        ),
80        IndicatorDataRef::Ohlcv {
81            open,
82            high,
83            low,
84            close,
85            ..
86        } => PatternRecognitionInput::from_slices(
87            open,
88            high,
89            low,
90            close,
91            crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
92        ),
93        _ => {
94            return Err(IndicatorDispatchError::MissingRequiredInput {
95                indicator: info.id.to_string(),
96                input: IndicatorInputKind::Ohlc,
97            });
98        }
99    };
100
101    let out = pattern_recognition_with_kernel(&input, req.kernel)
102        .map_err(|e| map_pattern_error(info.id, e))?;
103
104    Ok(IndicatorComputeOutput {
105        output_id: output_id.to_string(),
106        series: IndicatorSeries::Bool(out.values_u8.into_iter().map(|v| v != 0).collect()),
107        warmup: out.warmup,
108        rows: out.rows,
109        cols: out.cols,
110        pattern_ids: Some(
111            out.pattern_ids
112                .into_iter()
113                .map(|id| id.to_string())
114                .collect(),
115        ),
116    })
117}
118
/// Reduces an output identifier to a canonical comparison token.
///
/// Only ASCII letters and digits are kept, and letters are lower-cased; every
/// other character (underscores, dashes, spaces, non-ASCII) is dropped. The
/// resulting token `values` is folded into `value`.
fn normalize_output_token(value: &str) -> String {
    let token: String = value
        .chars()
        .filter(char::is_ascii_alphanumeric)
        .map(|ch| ch.to_ascii_lowercase())
        .collect();

    // Treat the plural spelling as an alias of the singular one.
    if token == "values" {
        String::from("value")
    } else {
        token
    }
}
132
133fn output_id_matches(candidate: &str, requested: &str) -> bool {
134    candidate.eq_ignore_ascii_case(requested)
135        || normalize_output_token(candidate) == normalize_output_token(requested)
136}
137
138fn resolve_output_id<'a>(
139    info: &'a IndicatorInfo,
140    requested: Option<&str>,
141) -> Result<&'a str, IndicatorDispatchError> {
142    if info.outputs.is_empty() {
143        return Err(IndicatorDispatchError::ComputeFailed {
144            indicator: info.id.to_string(),
145            details: "indicator has no registered outputs".to_string(),
146        });
147    }
148
149    if info.outputs.len() == 1 {
150        let only = info.outputs[0].id;
151        if let Some(req) = requested {
152            if !output_id_matches(only, req) {
153                return Err(IndicatorDispatchError::UnknownOutput {
154                    indicator: info.id.to_string(),
155                    output: req.to_string(),
156                });
157            }
158        }
159        return Ok(only);
160    }
161
162    let req = requested.ok_or_else(|| IndicatorDispatchError::InvalidParam {
163        indicator: info.id.to_string(),
164        key: "output_id".to_string(),
165        reason: "output_id is required for multi-output indicators".to_string(),
166    })?;
167
168    info.outputs
169        .iter()
170        .find(|o| output_id_matches(o.id, req))
171        .map(|o| o.id)
172        .ok_or_else(|| IndicatorDispatchError::UnknownOutput {
173            indicator: info.id.to_string(),
174            output: req.to_string(),
175        })
176}
177
178fn map_batch_output_to_compute(
179    indicator: &str,
180    out: IndicatorBatchOutput,
181) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
182    let series = if let Some(values) = out.values_f64 {
183        IndicatorSeries::F64(values)
184    } else if let Some(values) = out.values_i32 {
185        IndicatorSeries::I32(values)
186    } else if let Some(values) = out.values_bool {
187        IndicatorSeries::Bool(values)
188    } else {
189        return Err(IndicatorDispatchError::ComputeFailed {
190            indicator: indicator.to_string(),
191            details: "dispatcher returned no output series".to_string(),
192        });
193    };
194
195    Ok(IndicatorComputeOutput {
196        output_id: out.output_id,
197        series,
198        warmup: None,
199        rows: out.rows,
200        cols: out.cols,
201        pattern_ids: None,
202    })
203}
204
205fn map_pattern_error(indicator: &str, err: PatternRecognitionError) -> IndicatorDispatchError {
206    match err {
207        PatternRecognitionError::DataLengthMismatch {
208            open,
209            high,
210            low,
211            close,
212        } => IndicatorDispatchError::DataLengthMismatch {
213            details: format!("open={} high={} low={} close={}", open, high, low, close),
214        },
215        PatternRecognitionError::OutputLengthMismatch {
216            pattern_id,
217            expected,
218            got,
219        } => IndicatorDispatchError::ComputeFailed {
220            indicator: indicator.to_string(),
221            details: format!(
222                "pattern output mismatch for {}: expected {}, got {}",
223                pattern_id, expected, got
224            ),
225        },
226        PatternRecognitionError::Pattern(e) => IndicatorDispatchError::ComputeFailed {
227            indicator: indicator.to_string(),
228            details: e.to_string(),
229        },
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::indicators::dispatch::{compute_cpu_batch, ParamKV, ParamValue};
    use crate::indicators::pattern_recognition::list_patterns;
    use crate::utilities::enums::Kernel;

    /// Builds a deterministic series of `len` values oscillating around 100.
    fn sample_series(len: usize) -> Vec<f64> {
        let mut series = Vec::with_capacity(len);
        for idx in 0..len {
            let t = idx as f64;
            series.push(100.0 + (t * 0.01).sin() + (t * 0.0005).cos());
        }
        series
    }

    /// Builds `(open, high, low, close)` where high/low/close are the open
    /// series shifted by +1.0, -1.0 and +0.25 respectively.
    fn sample_ohlc(len: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
        let open = sample_series(len);
        let shifted =
            |delta: f64| -> Vec<f64> { open.iter().map(|price| price + delta).collect() };
        let high = shifted(1.0);
        let low = shifted(-1.0);
        let close = shifted(0.25);
        (open, high, low, close)
    }

    /// A valid OHLC request yields a boolean patterns-by-bars matrix.
    #[test]
    fn compute_cpu_pattern_recognition_returns_matrix() {
        let (open, high, low, close) = sample_ohlc(192);
        let data = IndicatorDataRef::Ohlc {
            open: &open,
            high: &high,
            low: &low,
            close: &close,
        };
        let out = compute_cpu(IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data,
            params: &[],
            kernel: Kernel::Auto,
        })
        .unwrap();

        assert_eq!(out.output_id, "matrix");
        assert_eq!(out.rows, list_patterns().len());
        assert_eq!(out.cols, close.len());

        let expected_cells = out.rows * out.cols;
        match out.series {
            IndicatorSeries::Bool(cells) => assert_eq!(cells.len(), expected_cells),
            other => panic!("expected Bool matrix series, got {:?}", other),
        }

        let pattern_ids = out.pattern_ids.unwrap();
        assert_eq!(pattern_ids.len(), out.rows);
    }

    /// A plain slice is not an acceptable input shape.
    #[test]
    fn compute_cpu_pattern_recognition_rejects_missing_input_shape() {
        let values = sample_series(64);
        let request = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Slice { values: &values },
            params: &[],
            kernel: Kernel::Auto,
        };

        match compute_cpu(request).unwrap_err() {
            IndicatorDispatchError::MissingRequiredInput { indicator, input } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(input, IndicatorInputKind::Ohlc);
            }
            other => panic!("expected MissingRequiredInput, got {:?}", other),
        }
    }

    /// Requesting an output other than the registered one is an error.
    #[test]
    fn compute_cpu_pattern_recognition_rejects_unknown_output() {
        let (open, high, low, close) = sample_ohlc(64);
        let data = IndicatorDataRef::Ohlc {
            open: &open,
            high: &high,
            low: &low,
            close: &close,
        };
        let request = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("value"),
            data,
            params: &[],
            kernel: Kernel::Auto,
        };

        match compute_cpu(request).unwrap_err() {
            IndicatorDispatchError::UnknownOutput { indicator, output } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(output, "value");
            }
            other => panic!("expected UnknownOutput, got {:?}", other),
        }
    }

    /// Supplying any parameter is rejected, naming the offending key.
    #[test]
    fn compute_cpu_pattern_recognition_rejects_params() {
        let (open, high, low, close) = sample_ohlc(64);
        let extra_params = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let data = IndicatorDataRef::Ohlc {
            open: &open,
            high: &high,
            low: &low,
            close: &close,
        };
        let request = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data,
            params: &extra_params,
            kernel: Kernel::Auto,
        };

        match compute_cpu(request).unwrap_err() {
            IndicatorDispatchError::InvalidParam {
                indicator,
                key,
                reason,
            } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(key, "period");
                assert!(reason.contains("does not accept parameters"));
            }
            other => panic!("expected InvalidParam, got {:?}", other),
        }
    }

    /// The batch dispatcher refuses pattern recognition outright.
    #[test]
    fn pattern_recognition_batch_mode_is_explicitly_unsupported() {
        let (open, high, low, close) = sample_ohlc(96);
        let single_combo = [IndicatorParamSet { params: &[] }];
        let data = IndicatorDataRef::Ohlc {
            open: &open,
            high: &high,
            low: &low,
            close: &close,
        };
        let batch_error = compute_cpu_batch(IndicatorBatchRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data,
            combos: &single_combo,
            kernel: Kernel::Auto,
        })
        .unwrap_err();

        match batch_error {
            IndicatorDispatchError::UnsupportedCapability {
                indicator,
                capability,
            } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(capability, "cpu_batch");
            }
            other => panic!("expected UnsupportedCapability, got {:?}", other),
        }
    }

    /// A regular indicator goes through the batch path and returns one row.
    #[test]
    fn compute_cpu_for_sma_delegates_to_batch_dispatch() {
        let values = sample_series(200);
        let sma_params = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let out = compute_cpu(IndicatorComputeRequest {
            indicator_id: "sma",
            output_id: Some("value"),
            data: IndicatorDataRef::Slice { values: &values },
            params: &sma_params,
            kernel: Kernel::Auto,
        })
        .unwrap();

        assert_eq!(out.output_id, "value");
        assert_eq!(out.rows, 1);
        assert_eq!(out.cols, values.len());
        match out.series {
            IndicatorSeries::F64(points) => assert_eq!(points.len(), values.len()),
            other => panic!("expected F64 series, got {:?}", other),
        }
    }
}