1use super::{
2 compute_cpu_batch, IndicatorBatchOutput, IndicatorBatchRequest, IndicatorComputeOutput,
3 IndicatorComputeRequest, IndicatorDataRef, IndicatorDispatchError, IndicatorParamSet,
4 IndicatorSeries,
5};
6use crate::indicators::pattern_recognition::{
7 pattern_recognition_with_kernel, PatternRecognitionData, PatternRecognitionError,
8 PatternRecognitionInput,
9};
10use crate::indicators::registry::{get_indicator, IndicatorInfo, IndicatorInputKind};
11
12pub fn compute_cpu(
13 req: IndicatorComputeRequest<'_>,
14) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
15 let info = get_indicator(req.indicator_id).ok_or_else(|| {
16 IndicatorDispatchError::UnknownIndicator {
17 id: req.indicator_id.to_string(),
18 }
19 })?;
20
21 if info.id.eq_ignore_ascii_case("pattern_recognition") {
22 return compute_pattern_recognition(req, info);
23 }
24
25 if !info.capabilities.supports_cpu_single {
26 return Err(IndicatorDispatchError::UnsupportedCapability {
27 indicator: info.id.to_string(),
28 capability: "cpu_single",
29 });
30 }
31
32 if !info.capabilities.supports_cpu_batch {
33 return Err(IndicatorDispatchError::UnsupportedCapability {
34 indicator: info.id.to_string(),
35 capability: "cpu_single",
36 });
37 }
38
39 let combos = [IndicatorParamSet { params: req.params }];
40 let out = compute_cpu_batch(IndicatorBatchRequest {
41 indicator_id: info.id,
42 output_id: req.output_id,
43 data: req.data,
44 combos: &combos,
45 kernel: req.kernel,
46 })?;
47 map_batch_output_to_compute(info.id, out)
48}
49
50fn compute_pattern_recognition(
51 req: IndicatorComputeRequest<'_>,
52 info: &IndicatorInfo,
53) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
54 if !info.capabilities.supports_cpu_single {
55 return Err(IndicatorDispatchError::UnsupportedCapability {
56 indicator: info.id.to_string(),
57 capability: "cpu_single",
58 });
59 }
60
61 if let Some(param) = req.params.first() {
62 return Err(IndicatorDispatchError::InvalidParam {
63 indicator: info.id.to_string(),
64 key: param.key.to_string(),
65 reason: "pattern_recognition does not accept parameters".to_string(),
66 });
67 }
68
69 let output_id = resolve_output_id(info, req.output_id)?;
70 let input = match req.data {
71 IndicatorDataRef::Candles { candles, .. } => PatternRecognitionInput::from_candles(
72 candles,
73 crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
74 ),
75 IndicatorDataRef::Ohlc {
76 open,
77 high,
78 low,
79 close,
80 } => PatternRecognitionInput::from_slices(
81 open,
82 high,
83 low,
84 close,
85 crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
86 ),
87 IndicatorDataRef::Ohlcv {
88 open,
89 high,
90 low,
91 close,
92 ..
93 } => PatternRecognitionInput::from_slices(
94 open,
95 high,
96 low,
97 close,
98 crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
99 ),
100 _ => {
101 return Err(IndicatorDispatchError::MissingRequiredInput {
102 indicator: info.id.to_string(),
103 input: IndicatorInputKind::Ohlc,
104 });
105 }
106 };
107
108 let out = pattern_recognition_with_kernel(&input, req.kernel)
109 .map_err(|e| map_pattern_error(info.id, e))?;
110
111 Ok(IndicatorComputeOutput {
112 output_id: output_id.to_string(),
113 series: IndicatorSeries::Bool(out.values_u8.into_iter().map(|v| v != 0).collect()),
114 warmup: out.warmup,
115 rows: out.rows,
116 cols: out.cols,
117 pattern_ids: Some(
118 out.pattern_ids
119 .into_iter()
120 .map(|id| id.to_string())
121 .collect(),
122 ),
123 })
124}
125
126fn resolve_output_id<'a>(
127 info: &'a IndicatorInfo,
128 requested: Option<&str>,
129) -> Result<&'a str, IndicatorDispatchError> {
130 if info.outputs.is_empty() {
131 return Err(IndicatorDispatchError::ComputeFailed {
132 indicator: info.id.to_string(),
133 details: "indicator has no registered outputs".to_string(),
134 });
135 }
136
137 if info.outputs.len() == 1 {
138 let only = info.outputs[0].id;
139 if let Some(req) = requested {
140 if !req.eq_ignore_ascii_case(only) {
141 return Err(IndicatorDispatchError::UnknownOutput {
142 indicator: info.id.to_string(),
143 output: req.to_string(),
144 });
145 }
146 }
147 return Ok(only);
148 }
149
150 let req = requested.ok_or_else(|| IndicatorDispatchError::InvalidParam {
151 indicator: info.id.to_string(),
152 key: "output_id".to_string(),
153 reason: "output_id is required for multi-output indicators".to_string(),
154 })?;
155
156 info.outputs
157 .iter()
158 .find(|o| o.id.eq_ignore_ascii_case(req))
159 .map(|o| o.id)
160 .ok_or_else(|| IndicatorDispatchError::UnknownOutput {
161 indicator: info.id.to_string(),
162 output: req.to_string(),
163 })
164}
165
166fn map_batch_output_to_compute(
167 indicator: &str,
168 out: IndicatorBatchOutput,
169) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
170 let series = if let Some(values) = out.values_f64 {
171 IndicatorSeries::F64(values)
172 } else if let Some(values) = out.values_i32 {
173 IndicatorSeries::I32(values)
174 } else if let Some(values) = out.values_bool {
175 IndicatorSeries::Bool(values)
176 } else {
177 return Err(IndicatorDispatchError::ComputeFailed {
178 indicator: indicator.to_string(),
179 details: "dispatcher returned no output series".to_string(),
180 });
181 };
182
183 Ok(IndicatorComputeOutput {
184 output_id: out.output_id,
185 series,
186 warmup: None,
187 rows: out.rows,
188 cols: out.cols,
189 pattern_ids: None,
190 })
191}
192
193fn map_pattern_error(indicator: &str, err: PatternRecognitionError) -> IndicatorDispatchError {
194 match err {
195 PatternRecognitionError::DataLengthMismatch {
196 open,
197 high,
198 low,
199 close,
200 } => IndicatorDispatchError::DataLengthMismatch {
201 details: format!("open={} high={} low={} close={}", open, high, low, close),
202 },
203 PatternRecognitionError::OutputLengthMismatch {
204 pattern_id,
205 expected,
206 got,
207 } => IndicatorDispatchError::ComputeFailed {
208 indicator: indicator.to_string(),
209 details: format!(
210 "pattern output mismatch for {}: expected {}, got {}",
211 pattern_id, expected, got
212 ),
213 },
214 PatternRecognitionError::Pattern(e) => IndicatorDispatchError::ComputeFailed {
215 indicator: indicator.to_string(),
216 details: e.to_string(),
217 },
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::indicators::dispatch::{compute_cpu_batch, ParamKV, ParamValue};
    use crate::indicators::pattern_recognition::list_patterns;
    use crate::utilities::enums::Kernel;

    /// Deterministic, smoothly varying synthetic price series of length `len`.
    fn sample_series(len: usize) -> Vec<f64> {
        (0..len)
            .map(|i| 100.0 + ((i as f64) * 0.01).sin() + ((i as f64) * 0.0005).cos())
            .collect()
    }

    /// Builds consistent OHLC columns from `sample_series`: high/low sit one unit
    /// above/below the open, and close sits a quarter unit above it.
    fn sample_ohlc(len: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
        let open = sample_series(len);
        let high: Vec<f64> = open.iter().map(|v| v + 1.0).collect();
        let low: Vec<f64> = open.iter().map(|v| v - 1.0).collect();
        let close: Vec<f64> = open.iter().map(|v| v + 0.25).collect();
        (open, high, low, close)
    }

    #[test]
    fn compute_cpu_pattern_recognition_returns_matrix() {
        let (open, high, low, close) = sample_ohlc(192);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &[],
            kernel: Kernel::Auto,
        };
        let out = compute_cpu(req).unwrap();
        assert_eq!(out.output_id, "matrix");
        assert_eq!(out.rows, list_patterns().len());
        assert_eq!(out.cols, close.len());
        match out.series {
            IndicatorSeries::Bool(v) => assert_eq!(v.len(), out.rows * out.cols),
            other => panic!("expected Bool matrix series, got {:?}", other),
        }
        let ids = out.pattern_ids.unwrap();
        assert_eq!(ids.len(), out.rows);
    }

    #[test]
    fn compute_cpu_pattern_recognition_rejects_missing_input_shape() {
        let series = sample_series(64);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Slice { values: &series },
            params: &[],
            kernel: Kernel::Auto,
        };
        let err = compute_cpu(req).unwrap_err();
        match err {
            IndicatorDispatchError::MissingRequiredInput { indicator, input } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(input, IndicatorInputKind::Ohlc);
            }
            other => panic!("expected MissingRequiredInput, got {:?}", other),
        }
    }

    #[test]
    fn compute_cpu_pattern_recognition_rejects_unknown_output() {
        let (open, high, low, close) = sample_ohlc(64);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("value"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &[],
            kernel: Kernel::Auto,
        };
        let err = compute_cpu(req).unwrap_err();
        match err {
            IndicatorDispatchError::UnknownOutput { indicator, output } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(output, "value");
            }
            other => panic!("expected UnknownOutput, got {:?}", other),
        }
    }

    #[test]
    fn compute_cpu_pattern_recognition_rejects_params() {
        let (open, high, low, close) = sample_ohlc(64);
        let params = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &params,
            kernel: Kernel::Auto,
        };
        let err = compute_cpu(req).unwrap_err();
        match err {
            IndicatorDispatchError::InvalidParam {
                indicator,
                key,
                reason,
            } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(key, "period");
                assert!(reason.contains("does not accept parameters"));
            }
            other => panic!("expected InvalidParam, got {:?}", other),
        }
    }

    #[test]
    fn pattern_recognition_batch_mode_is_explicitly_unsupported() {
        let (open, high, low, close) = sample_ohlc(96);
        let combos = [IndicatorParamSet { params: &[] }];
        let err = compute_cpu_batch(IndicatorBatchRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            combos: &combos,
            kernel: Kernel::Auto,
        })
        .unwrap_err();
        match err {
            IndicatorDispatchError::UnsupportedCapability {
                indicator,
                capability,
            } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(capability, "cpu_batch");
            }
            other => panic!("expected UnsupportedCapability, got {:?}", other),
        }
    }

    #[test]
    fn compute_cpu_for_sma_delegates_to_batch_dispatch() {
        let series = sample_series(200);
        let params = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let req = IndicatorComputeRequest {
            indicator_id: "sma",
            output_id: Some("value"),
            data: IndicatorDataRef::Slice { values: &series },
            params: &params,
            kernel: Kernel::Auto,
        };
        let out = compute_cpu(req).unwrap();
        assert_eq!(out.output_id, "value");
        assert_eq!(out.rows, 1);
        assert_eq!(out.cols, series.len());
        match out.series {
            IndicatorSeries::F64(v) => assert_eq!(v.len(), series.len()),
            other => panic!("expected F64 series, got {:?}", other),
        }
    }
}