1use super::{
2 compute_cpu_batch, IndicatorBatchOutput, IndicatorBatchRequest, IndicatorComputeOutput,
3 IndicatorComputeRequest, IndicatorDataRef, IndicatorDispatchError, IndicatorParamSet,
4 IndicatorSeries,
5};
6use crate::indicators::pattern_recognition::{
7 pattern_recognition_with_kernel, PatternRecognitionData, PatternRecognitionError,
8 PatternRecognitionInput,
9};
10use crate::indicators::registry::{get_indicator, IndicatorInfo, IndicatorInputKind};
11
12pub fn compute_cpu(
13 req: IndicatorComputeRequest<'_>,
14) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
15 let info = get_indicator(req.indicator_id).ok_or_else(|| {
16 IndicatorDispatchError::UnknownIndicator {
17 id: req.indicator_id.to_string(),
18 }
19 })?;
20
21 if info.id.eq_ignore_ascii_case("pattern_recognition") {
22 return compute_pattern_recognition(req, info);
23 }
24
25 if !info.capabilities.supports_cpu_single {
26 return Err(IndicatorDispatchError::UnsupportedCapability {
27 indicator: info.id.to_string(),
28 capability: "cpu_single",
29 });
30 }
31
32 if !info.capabilities.supports_cpu_batch {
33 return Err(IndicatorDispatchError::UnsupportedCapability {
34 indicator: info.id.to_string(),
35 capability: "cpu_single",
36 });
37 }
38
39 let combos = [IndicatorParamSet { params: req.params }];
40 let out = compute_cpu_batch(IndicatorBatchRequest {
41 indicator_id: info.id,
42 output_id: req.output_id,
43 data: req.data,
44 combos: &combos,
45 kernel: req.kernel,
46 })?;
47 map_batch_output_to_compute(info.id, out)
48}
49
50fn compute_pattern_recognition(
51 req: IndicatorComputeRequest<'_>,
52 info: &IndicatorInfo,
53) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
54 if !info.capabilities.supports_cpu_single {
55 return Err(IndicatorDispatchError::UnsupportedCapability {
56 indicator: info.id.to_string(),
57 capability: "cpu_single",
58 });
59 }
60
61 if let Some(param) = req.params.first() {
62 return Err(IndicatorDispatchError::InvalidParam {
63 indicator: info.id.to_string(),
64 key: param.key.to_string(),
65 reason: "pattern_recognition does not accept parameters".to_string(),
66 });
67 }
68
69 let output_id = resolve_output_id(info, req.output_id)?;
70 let input = match req.data {
71 IndicatorDataRef::Candles { candles, .. } => PatternRecognitionInput::from_candles(
72 candles,
73 crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
74 ),
75 IndicatorDataRef::Ohlc {
76 open,
77 high,
78 low,
79 close,
80 } => PatternRecognitionInput::from_slices(
81 open,
82 high,
83 low,
84 close,
85 crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
86 ),
87 IndicatorDataRef::Ohlcv {
88 open,
89 high,
90 low,
91 close,
92 ..
93 } => PatternRecognitionInput::from_slices(
94 open,
95 high,
96 low,
97 close,
98 crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
99 ),
100 _ => {
101 return Err(IndicatorDispatchError::MissingRequiredInput {
102 indicator: info.id.to_string(),
103 input: IndicatorInputKind::Ohlc,
104 });
105 }
106 };
107
108 let out = pattern_recognition_with_kernel(&input, req.kernel)
109 .map_err(|e| map_pattern_error(info.id, e))?;
110
111 Ok(IndicatorComputeOutput {
112 output_id: output_id.to_string(),
113 series: IndicatorSeries::Bool(out.values_u8.into_iter().map(|v| v != 0).collect()),
114 warmup: out.warmup,
115 rows: out.rows,
116 cols: out.cols,
117 pattern_ids: Some(out.pattern_ids.into_iter().map(|id| id.to_string()).collect()),
118 })
119}
120
121fn resolve_output_id<'a>(
122 info: &'a IndicatorInfo,
123 requested: Option<&str>,
124) -> Result<&'a str, IndicatorDispatchError> {
125 if info.outputs.is_empty() {
126 return Err(IndicatorDispatchError::ComputeFailed {
127 indicator: info.id.to_string(),
128 details: "indicator has no registered outputs".to_string(),
129 });
130 }
131
132 if info.outputs.len() == 1 {
133 let only = info.outputs[0].id;
134 if let Some(req) = requested {
135 if !req.eq_ignore_ascii_case(only) {
136 return Err(IndicatorDispatchError::UnknownOutput {
137 indicator: info.id.to_string(),
138 output: req.to_string(),
139 });
140 }
141 }
142 return Ok(only);
143 }
144
145 let req = requested.ok_or_else(|| IndicatorDispatchError::InvalidParam {
146 indicator: info.id.to_string(),
147 key: "output_id".to_string(),
148 reason: "output_id is required for multi-output indicators".to_string(),
149 })?;
150
151 info.outputs
152 .iter()
153 .find(|o| o.id.eq_ignore_ascii_case(req))
154 .map(|o| o.id)
155 .ok_or_else(|| IndicatorDispatchError::UnknownOutput {
156 indicator: info.id.to_string(),
157 output: req.to_string(),
158 })
159}
160
161fn map_batch_output_to_compute(
162 indicator: &str,
163 out: IndicatorBatchOutput,
164) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
165 let series = if let Some(values) = out.values_f64 {
166 IndicatorSeries::F64(values)
167 } else if let Some(values) = out.values_i32 {
168 IndicatorSeries::I32(values)
169 } else if let Some(values) = out.values_bool {
170 IndicatorSeries::Bool(values)
171 } else {
172 return Err(IndicatorDispatchError::ComputeFailed {
173 indicator: indicator.to_string(),
174 details: "dispatcher returned no output series".to_string(),
175 });
176 };
177
178 Ok(IndicatorComputeOutput {
179 output_id: out.output_id,
180 series,
181 warmup: None,
182 rows: out.rows,
183 cols: out.cols,
184 pattern_ids: None,
185 })
186}
187
188fn map_pattern_error(indicator: &str, err: PatternRecognitionError) -> IndicatorDispatchError {
189 match err {
190 PatternRecognitionError::DataLengthMismatch {
191 open,
192 high,
193 low,
194 close,
195 } => IndicatorDispatchError::DataLengthMismatch {
196 details: format!("open={} high={} low={} close={}", open, high, low, close),
197 },
198 PatternRecognitionError::OutputLengthMismatch {
199 pattern_id,
200 expected,
201 got,
202 } => IndicatorDispatchError::ComputeFailed {
203 indicator: indicator.to_string(),
204 details: format!(
205 "pattern output mismatch for {}: expected {}, got {}",
206 pattern_id, expected, got
207 ),
208 },
209 PatternRecognitionError::Pattern(e) => IndicatorDispatchError::ComputeFailed {
210 indicator: indicator.to_string(),
211 details: e.to_string(),
212 },
213 }
214}
215
#[cfg(test)]
mod tests {
    // Unit tests for the CPU compute dispatcher.
    // Fix: `&params` had been mangled into `¶ms` (HTML-entity corruption of `&para`),
    // which does not compile.
    use super::*;
    use crate::indicators::dispatch::{compute_cpu_batch, ParamKV, ParamValue};
    use crate::indicators::pattern_recognition::list_patterns;
    use crate::utilities::enums::Kernel;

    /// Smooth, deterministic synthetic price series of `len` points around 100.
    fn sample_series(len: usize) -> Vec<f64> {
        (0..len)
            .map(|i| 100.0 + ((i as f64) * 0.01).sin() + ((i as f64) * 0.0005).cos())
            .collect()
    }

    /// Builds consistent OHLC columns from `sample_series`:
    /// high = open + 1, low = open - 1, close = open + 0.25.
    fn sample_ohlc(len: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
        let open = sample_series(len);
        let high: Vec<f64> = open.iter().map(|v| v + 1.0).collect();
        let low: Vec<f64> = open.iter().map(|v| v - 1.0).collect();
        let close: Vec<f64> = open.iter().map(|v| v + 0.25).collect();
        (open, high, low, close)
    }

    #[test]
    fn compute_cpu_pattern_recognition_returns_matrix() {
        let (open, high, low, close) = sample_ohlc(192);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &[],
            kernel: Kernel::Auto,
        };
        let out = compute_cpu(req).unwrap();
        assert_eq!(out.output_id, "matrix");
        assert_eq!(out.rows, list_patterns().len());
        assert_eq!(out.cols, close.len());
        match out.series {
            IndicatorSeries::Bool(v) => assert_eq!(v.len(), out.rows * out.cols),
            other => panic!("expected Bool matrix series, got {:?}", other),
        }
        let ids = out.pattern_ids.unwrap();
        assert_eq!(ids.len(), out.rows);
    }

    #[test]
    fn compute_cpu_pattern_recognition_rejects_missing_input_shape() {
        let series = sample_series(64);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Slice { values: &series },
            params: &[],
            kernel: Kernel::Auto,
        };
        let err = compute_cpu(req).unwrap_err();
        match err {
            IndicatorDispatchError::MissingRequiredInput { indicator, input } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(input, IndicatorInputKind::Ohlc);
            }
            other => panic!("expected MissingRequiredInput, got {:?}", other),
        }
    }

    #[test]
    fn compute_cpu_pattern_recognition_rejects_unknown_output() {
        let (open, high, low, close) = sample_ohlc(64);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("value"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &[],
            kernel: Kernel::Auto,
        };
        let err = compute_cpu(req).unwrap_err();
        match err {
            IndicatorDispatchError::UnknownOutput { indicator, output } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(output, "value");
            }
            other => panic!("expected UnknownOutput, got {:?}", other),
        }
    }

    #[test]
    fn compute_cpu_pattern_recognition_rejects_params() {
        let (open, high, low, close) = sample_ohlc(64);
        let params = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &params,
            kernel: Kernel::Auto,
        };
        let err = compute_cpu(req).unwrap_err();
        match err {
            IndicatorDispatchError::InvalidParam {
                indicator,
                key,
                reason,
            } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(key, "period");
                assert!(reason.contains("does not accept parameters"));
            }
            other => panic!("expected InvalidParam, got {:?}", other),
        }
    }

    #[test]
    fn pattern_recognition_batch_mode_is_explicitly_unsupported() {
        let (open, high, low, close) = sample_ohlc(96);
        let combos = [IndicatorParamSet { params: &[] }];
        let err = compute_cpu_batch(IndicatorBatchRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            combos: &combos,
            kernel: Kernel::Auto,
        })
        .unwrap_err();
        match err {
            IndicatorDispatchError::UnsupportedCapability {
                indicator,
                capability,
            } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(capability, "cpu_batch");
            }
            other => panic!("expected UnsupportedCapability, got {:?}", other),
        }
    }

    #[test]
    fn compute_cpu_for_sma_delegates_to_batch_dispatch() {
        let series = sample_series(200);
        let params = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let req = IndicatorComputeRequest {
            indicator_id: "sma",
            output_id: Some("value"),
            data: IndicatorDataRef::Slice { values: &series },
            params: &params,
            kernel: Kernel::Auto,
        };
        let out = compute_cpu(req).unwrap();
        assert_eq!(out.output_id, "value");
        assert_eq!(out.rows, 1);
        assert_eq!(out.cols, series.len());
        match out.series {
            IndicatorSeries::F64(v) => assert_eq!(v.len(), series.len()),
            other => panic!("expected F64 series, got {:?}", other),
        }
    }
}