1use super::{
2 compute_cpu_batch, IndicatorBatchOutput, IndicatorBatchRequest, IndicatorComputeOutput,
3 IndicatorComputeRequest, IndicatorDataRef, IndicatorDispatchError, IndicatorParamSet,
4 IndicatorSeries,
5};
6use crate::indicators::pattern_recognition::{
7 pattern_recognition_with_kernel, PatternRecognitionData, PatternRecognitionError,
8 PatternRecognitionInput,
9};
10use crate::indicators::registry::{get_indicator, IndicatorInfo, IndicatorInputKind};
11
12pub fn compute_cpu(
13 req: IndicatorComputeRequest<'_>,
14) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
15 let info = get_indicator(req.indicator_id);
16
17 if let Some(info) = info {
18 if info.id.eq_ignore_ascii_case("pattern_recognition") {
19 return compute_pattern_recognition(req, info);
20 }
21 }
22
23 let indicator_id = info.map_or(req.indicator_id, |info| info.id);
24
25 let combos = [IndicatorParamSet { params: req.params }];
26 let out = match compute_cpu_batch(IndicatorBatchRequest {
27 indicator_id,
28 output_id: req.output_id,
29 data: req.data,
30 combos: &combos,
31 kernel: req.kernel,
32 }) {
33 Err(IndicatorDispatchError::UnsupportedCapability { .. }) if info.is_none() => {
34 return Err(IndicatorDispatchError::UnknownIndicator {
35 id: req.indicator_id.to_string(),
36 });
37 }
38 other => other?,
39 };
40 map_batch_output_to_compute(indicator_id, out)
41}
42
43fn compute_pattern_recognition(
44 req: IndicatorComputeRequest<'_>,
45 info: &IndicatorInfo,
46) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
47 if !info.capabilities.supports_cpu_single {
48 return Err(IndicatorDispatchError::UnsupportedCapability {
49 indicator: info.id.to_string(),
50 capability: "cpu_single",
51 });
52 }
53
54 if let Some(param) = req.params.first() {
55 return Err(IndicatorDispatchError::InvalidParam {
56 indicator: info.id.to_string(),
57 key: param.key.to_string(),
58 reason: "pattern_recognition does not accept parameters".to_string(),
59 });
60 }
61
62 let output_id = resolve_output_id(info, req.output_id)?;
63 let input = match req.data {
64 IndicatorDataRef::Candles { candles, .. } => PatternRecognitionInput::from_candles(
65 candles,
66 crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
67 ),
68 IndicatorDataRef::Ohlc {
69 open,
70 high,
71 low,
72 close,
73 } => PatternRecognitionInput::from_slices(
74 open,
75 high,
76 low,
77 close,
78 crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
79 ),
80 IndicatorDataRef::Ohlcv {
81 open,
82 high,
83 low,
84 close,
85 ..
86 } => PatternRecognitionInput::from_slices(
87 open,
88 high,
89 low,
90 close,
91 crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
92 ),
93 _ => {
94 return Err(IndicatorDispatchError::MissingRequiredInput {
95 indicator: info.id.to_string(),
96 input: IndicatorInputKind::Ohlc,
97 });
98 }
99 };
100
101 let out = pattern_recognition_with_kernel(&input, req.kernel)
102 .map_err(|e| map_pattern_error(info.id, e))?;
103
104 Ok(IndicatorComputeOutput {
105 output_id: output_id.to_string(),
106 series: IndicatorSeries::Bool(out.values_u8.into_iter().map(|v| v != 0).collect()),
107 warmup: out.warmup,
108 rows: out.rows,
109 cols: out.cols,
110 pattern_ids: Some(
111 out.pattern_ids
112 .into_iter()
113 .map(|id| id.to_string())
114 .collect(),
115 ),
116 })
117}
118
119fn resolve_output_id<'a>(
120 info: &'a IndicatorInfo,
121 requested: Option<&str>,
122) -> Result<&'a str, IndicatorDispatchError> {
123 if info.outputs.is_empty() {
124 return Err(IndicatorDispatchError::ComputeFailed {
125 indicator: info.id.to_string(),
126 details: "indicator has no registered outputs".to_string(),
127 });
128 }
129
130 if info.outputs.len() == 1 {
131 let only = info.outputs[0].id;
132 if let Some(req) = requested {
133 if !req.eq_ignore_ascii_case(only) {
134 return Err(IndicatorDispatchError::UnknownOutput {
135 indicator: info.id.to_string(),
136 output: req.to_string(),
137 });
138 }
139 }
140 return Ok(only);
141 }
142
143 let req = requested.ok_or_else(|| IndicatorDispatchError::InvalidParam {
144 indicator: info.id.to_string(),
145 key: "output_id".to_string(),
146 reason: "output_id is required for multi-output indicators".to_string(),
147 })?;
148
149 info.outputs
150 .iter()
151 .find(|o| o.id.eq_ignore_ascii_case(req))
152 .map(|o| o.id)
153 .ok_or_else(|| IndicatorDispatchError::UnknownOutput {
154 indicator: info.id.to_string(),
155 output: req.to_string(),
156 })
157}
158
159fn map_batch_output_to_compute(
160 indicator: &str,
161 out: IndicatorBatchOutput,
162) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
163 let series = if let Some(values) = out.values_f64 {
164 IndicatorSeries::F64(values)
165 } else if let Some(values) = out.values_i32 {
166 IndicatorSeries::I32(values)
167 } else if let Some(values) = out.values_bool {
168 IndicatorSeries::Bool(values)
169 } else {
170 return Err(IndicatorDispatchError::ComputeFailed {
171 indicator: indicator.to_string(),
172 details: "dispatcher returned no output series".to_string(),
173 });
174 };
175
176 Ok(IndicatorComputeOutput {
177 output_id: out.output_id,
178 series,
179 warmup: None,
180 rows: out.rows,
181 cols: out.cols,
182 pattern_ids: None,
183 })
184}
185
186fn map_pattern_error(indicator: &str, err: PatternRecognitionError) -> IndicatorDispatchError {
187 match err {
188 PatternRecognitionError::DataLengthMismatch {
189 open,
190 high,
191 low,
192 close,
193 } => IndicatorDispatchError::DataLengthMismatch {
194 details: format!("open={} high={} low={} close={}", open, high, low, close),
195 },
196 PatternRecognitionError::OutputLengthMismatch {
197 pattern_id,
198 expected,
199 got,
200 } => IndicatorDispatchError::ComputeFailed {
201 indicator: indicator.to_string(),
202 details: format!(
203 "pattern output mismatch for {}: expected {}, got {}",
204 pattern_id, expected, got
205 ),
206 },
207 PatternRecognitionError::Pattern(e) => IndicatorDispatchError::ComputeFailed {
208 indicator: indicator.to_string(),
209 details: e.to_string(),
210 },
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::indicators::dispatch::{compute_cpu_batch, ParamKV, ParamValue};
    use crate::indicators::pattern_recognition::list_patterns;
    use crate::utilities::enums::Kernel;

    /// Smooth, deterministic price-like series used as the base of every fixture.
    fn sample_series(len: usize) -> Vec<f64> {
        (0..len)
            .map(|i| 100.0 + ((i as f64) * 0.01).sin() + ((i as f64) * 0.0005).cos())
            .collect()
    }

    /// OHLC fixture derived from `sample_series`: high/low sit one unit above/below
    /// the open, and close sits 0.25 above it.
    fn sample_ohlc(len: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
        let open = sample_series(len);
        let high: Vec<f64> = open.iter().map(|v| v + 1.0).collect();
        let low: Vec<f64> = open.iter().map(|v| v - 1.0).collect();
        let close: Vec<f64> = open.iter().map(|v| v + 0.25).collect();
        (open, high, low, close)
    }

    #[test]
    fn compute_cpu_pattern_recognition_returns_matrix() {
        let (open, high, low, close) = sample_ohlc(192);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &[],
            kernel: Kernel::Auto,
        };
        let out = compute_cpu(req).unwrap();
        assert_eq!(out.output_id, "matrix");
        assert_eq!(out.rows, list_patterns().len());
        assert_eq!(out.cols, close.len());
        match out.series {
            IndicatorSeries::Bool(v) => assert_eq!(v.len(), out.rows * out.cols),
            other => panic!("expected Bool matrix series, got {:?}", other),
        }
        let ids = out.pattern_ids.unwrap();
        assert_eq!(ids.len(), out.rows);
    }

    #[test]
    fn compute_cpu_pattern_recognition_output_id_is_case_insensitive() {
        // The requested id is matched ignoring ASCII case, but the registry's
        // spelling is what gets reported back.
        let (open, high, low, close) = sample_ohlc(64);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("MATRIX"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &[],
            kernel: Kernel::Auto,
        };
        let out = compute_cpu(req).unwrap();
        assert_eq!(out.output_id, "matrix");
    }

    #[test]
    fn compute_cpu_pattern_recognition_rejects_missing_input_shape() {
        let series = sample_series(64);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Slice { values: &series },
            params: &[],
            kernel: Kernel::Auto,
        };
        let err = compute_cpu(req).unwrap_err();
        match err {
            IndicatorDispatchError::MissingRequiredInput { indicator, input } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(input, IndicatorInputKind::Ohlc);
            }
            other => panic!("expected MissingRequiredInput, got {:?}", other),
        }
    }

    #[test]
    fn compute_cpu_pattern_recognition_rejects_unknown_output() {
        let (open, high, low, close) = sample_ohlc(64);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("value"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &[],
            kernel: Kernel::Auto,
        };
        let err = compute_cpu(req).unwrap_err();
        match err {
            IndicatorDispatchError::UnknownOutput { indicator, output } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(output, "value");
            }
            other => panic!("expected UnknownOutput, got {:?}", other),
        }
    }

    #[test]
    fn compute_cpu_pattern_recognition_rejects_params() {
        let (open, high, low, close) = sample_ohlc(64);
        let params = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            // Was mis-encoded as `¶ms` (decoded `&para` entity), which does not compile.
            params: &params,
            kernel: Kernel::Auto,
        };
        let err = compute_cpu(req).unwrap_err();
        match err {
            IndicatorDispatchError::InvalidParam {
                indicator,
                key,
                reason,
            } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(key, "period");
                assert!(reason.contains("does not accept parameters"));
            }
            other => panic!("expected InvalidParam, got {:?}", other),
        }
    }

    #[test]
    fn pattern_recognition_batch_mode_is_explicitly_unsupported() {
        let (open, high, low, close) = sample_ohlc(96);
        let combos = [IndicatorParamSet { params: &[] }];
        let err = compute_cpu_batch(IndicatorBatchRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            combos: &combos,
            kernel: Kernel::Auto,
        })
        .unwrap_err();
        match err {
            IndicatorDispatchError::UnsupportedCapability {
                indicator,
                capability,
            } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(capability, "cpu_batch");
            }
            other => panic!("expected UnsupportedCapability, got {:?}", other),
        }
    }

    #[test]
    fn compute_cpu_for_sma_delegates_to_batch_dispatch() {
        let series = sample_series(200);
        let params = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let req = IndicatorComputeRequest {
            indicator_id: "sma",
            output_id: Some("value"),
            data: IndicatorDataRef::Slice { values: &series },
            // Was mis-encoded as `¶ms` (decoded `&para` entity), which does not compile.
            params: &params,
            kernel: Kernel::Auto,
        };
        let out = compute_cpu(req).unwrap();
        assert_eq!(out.output_id, "value");
        assert_eq!(out.rows, 1);
        assert_eq!(out.cols, series.len());
        match out.series {
            IndicatorSeries::F64(v) => assert_eq!(v.len(), series.len()),
            other => panic!("expected F64 series, got {:?}", other),
        }
    }
}