1use super::{
2 compute_cpu_batch, IndicatorBatchOutput, IndicatorBatchRequest, IndicatorComputeOutput,
3 IndicatorComputeRequest, IndicatorDataRef, IndicatorDispatchError, IndicatorParamSet,
4 IndicatorSeries,
5};
6use crate::indicators::pattern_recognition::{
7 pattern_recognition_with_kernel, PatternRecognitionData, PatternRecognitionError,
8 PatternRecognitionInput,
9};
10use crate::indicators::registry::{get_indicator, IndicatorInfo, IndicatorInputKind};
11
12pub fn compute_cpu(
13 req: IndicatorComputeRequest<'_>,
14) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
15 let info = get_indicator(req.indicator_id);
16
17 if let Some(info) = info {
18 if info.id.eq_ignore_ascii_case("pattern_recognition") {
19 return compute_pattern_recognition(req, info);
20 }
21 }
22
23 let indicator_id = info.map_or(req.indicator_id, |info| info.id);
24
25 let combos = [IndicatorParamSet { params: req.params }];
26 let out = match compute_cpu_batch(IndicatorBatchRequest {
27 indicator_id,
28 output_id: req.output_id,
29 data: req.data,
30 combos: &combos,
31 kernel: req.kernel,
32 }) {
33 Err(IndicatorDispatchError::UnsupportedCapability { .. }) if info.is_none() => {
34 return Err(IndicatorDispatchError::UnknownIndicator {
35 id: req.indicator_id.to_string(),
36 });
37 }
38 other => other?,
39 };
40 map_batch_output_to_compute(indicator_id, out)
41}
42
43fn compute_pattern_recognition(
44 req: IndicatorComputeRequest<'_>,
45 info: &IndicatorInfo,
46) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
47 if !info.capabilities.supports_cpu_single {
48 return Err(IndicatorDispatchError::UnsupportedCapability {
49 indicator: info.id.to_string(),
50 capability: "cpu_single",
51 });
52 }
53
54 if let Some(param) = req.params.first() {
55 return Err(IndicatorDispatchError::InvalidParam {
56 indicator: info.id.to_string(),
57 key: param.key.to_string(),
58 reason: "pattern_recognition does not accept parameters".to_string(),
59 });
60 }
61
62 let output_id = resolve_output_id(info, req.output_id)?;
63 let input = match req.data {
64 IndicatorDataRef::Candles { candles, .. } => PatternRecognitionInput::from_candles(
65 candles,
66 crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
67 ),
68 IndicatorDataRef::Ohlc {
69 open,
70 high,
71 low,
72 close,
73 } => PatternRecognitionInput::from_slices(
74 open,
75 high,
76 low,
77 close,
78 crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
79 ),
80 IndicatorDataRef::Ohlcv {
81 open,
82 high,
83 low,
84 close,
85 ..
86 } => PatternRecognitionInput::from_slices(
87 open,
88 high,
89 low,
90 close,
91 crate::indicators::pattern_recognition::PatternRecognitionParams::default(),
92 ),
93 _ => {
94 return Err(IndicatorDispatchError::MissingRequiredInput {
95 indicator: info.id.to_string(),
96 input: IndicatorInputKind::Ohlc,
97 });
98 }
99 };
100
101 let out = pattern_recognition_with_kernel(&input, req.kernel)
102 .map_err(|e| map_pattern_error(info.id, e))?;
103
104 Ok(IndicatorComputeOutput {
105 output_id: output_id.to_string(),
106 series: IndicatorSeries::Bool(out.values_u8.into_iter().map(|v| v != 0).collect()),
107 warmup: out.warmup,
108 rows: out.rows,
109 cols: out.cols,
110 pattern_ids: Some(
111 out.pattern_ids
112 .into_iter()
113 .map(|id| id.to_string())
114 .collect(),
115 ),
116 })
117}
118
/// Reduces an output identifier to a canonical comparison token.
///
/// Only ASCII alphanumeric characters are kept, each lowercased; every other
/// character (separators, whitespace, non-ASCII) is dropped. The resulting
/// token `values` is folded into `value`.
fn normalize_output_token(value: &str) -> String {
    let token: String = value
        .chars()
        .filter(char::is_ascii_alphanumeric)
        .map(|ch| ch.to_ascii_lowercase())
        .collect();

    match token.as_str() {
        "values" => String::from("value"),
        _ => token,
    }
}
132
133fn output_id_matches(candidate: &str, requested: &str) -> bool {
134 candidate.eq_ignore_ascii_case(requested)
135 || normalize_output_token(candidate) == normalize_output_token(requested)
136}
137
138fn resolve_output_id<'a>(
139 info: &'a IndicatorInfo,
140 requested: Option<&str>,
141) -> Result<&'a str, IndicatorDispatchError> {
142 if info.outputs.is_empty() {
143 return Err(IndicatorDispatchError::ComputeFailed {
144 indicator: info.id.to_string(),
145 details: "indicator has no registered outputs".to_string(),
146 });
147 }
148
149 if info.outputs.len() == 1 {
150 let only = info.outputs[0].id;
151 if let Some(req) = requested {
152 if !output_id_matches(only, req) {
153 return Err(IndicatorDispatchError::UnknownOutput {
154 indicator: info.id.to_string(),
155 output: req.to_string(),
156 });
157 }
158 }
159 return Ok(only);
160 }
161
162 let req = requested.ok_or_else(|| IndicatorDispatchError::InvalidParam {
163 indicator: info.id.to_string(),
164 key: "output_id".to_string(),
165 reason: "output_id is required for multi-output indicators".to_string(),
166 })?;
167
168 info.outputs
169 .iter()
170 .find(|o| output_id_matches(o.id, req))
171 .map(|o| o.id)
172 .ok_or_else(|| IndicatorDispatchError::UnknownOutput {
173 indicator: info.id.to_string(),
174 output: req.to_string(),
175 })
176}
177
178fn map_batch_output_to_compute(
179 indicator: &str,
180 out: IndicatorBatchOutput,
181) -> Result<IndicatorComputeOutput, IndicatorDispatchError> {
182 let series = if let Some(values) = out.values_f64 {
183 IndicatorSeries::F64(values)
184 } else if let Some(values) = out.values_i32 {
185 IndicatorSeries::I32(values)
186 } else if let Some(values) = out.values_bool {
187 IndicatorSeries::Bool(values)
188 } else {
189 return Err(IndicatorDispatchError::ComputeFailed {
190 indicator: indicator.to_string(),
191 details: "dispatcher returned no output series".to_string(),
192 });
193 };
194
195 Ok(IndicatorComputeOutput {
196 output_id: out.output_id,
197 series,
198 warmup: None,
199 rows: out.rows,
200 cols: out.cols,
201 pattern_ids: None,
202 })
203}
204
205fn map_pattern_error(indicator: &str, err: PatternRecognitionError) -> IndicatorDispatchError {
206 match err {
207 PatternRecognitionError::DataLengthMismatch {
208 open,
209 high,
210 low,
211 close,
212 } => IndicatorDispatchError::DataLengthMismatch {
213 details: format!("open={} high={} low={} close={}", open, high, low, close),
214 },
215 PatternRecognitionError::OutputLengthMismatch {
216 pattern_id,
217 expected,
218 got,
219 } => IndicatorDispatchError::ComputeFailed {
220 indicator: indicator.to_string(),
221 details: format!(
222 "pattern output mismatch for {}: expected {}, got {}",
223 pattern_id, expected, got
224 ),
225 },
226 PatternRecognitionError::Pattern(e) => IndicatorDispatchError::ComputeFailed {
227 indicator: indicator.to_string(),
228 details: e.to_string(),
229 },
230 }
231}
232
#[cfg(test)]
mod tests {
    //! Unit tests for the single-compute CPU dispatch path: the dedicated
    //! pattern-recognition route and delegation to the batch dispatcher.

    use super::*;
    use crate::indicators::dispatch::{compute_cpu_batch, ParamKV, ParamValue};
    use crate::indicators::pattern_recognition::list_patterns;
    use crate::utilities::enums::Kernel;

    /// Builds a deterministic series of `len` values: 100.0 plus a sine and a
    /// cosine term of the index.
    fn sample_series(len: usize) -> Vec<f64> {
        (0..len)
            .map(|i| 100.0 + ((i as f64) * 0.01).sin() + ((i as f64) * 0.0005).cos())
            .collect()
    }

    /// Builds `(open, high, low, close)` series of equal length, where high,
    /// low and close are fixed offsets (+1.0, -1.0, +0.25) from open.
    fn sample_ohlc(len: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
        let open = sample_series(len);
        let high: Vec<f64> = open.iter().map(|v| v + 1.0).collect();
        let low: Vec<f64> = open.iter().map(|v| v - 1.0).collect();
        let close: Vec<f64> = open.iter().map(|v| v + 0.25).collect();
        (open, high, low, close)
    }

    /// Pattern recognition yields a boolean matrix with one row per listed
    /// pattern, one column per bar, and one pattern id per row.
    #[test]
    fn compute_cpu_pattern_recognition_returns_matrix() {
        let (open, high, low, close) = sample_ohlc(192);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &[],
            kernel: Kernel::Auto,
        };
        let out = compute_cpu(req).unwrap();
        assert_eq!(out.output_id, "matrix");
        assert_eq!(out.rows, list_patterns().len());
        assert_eq!(out.cols, close.len());
        match out.series {
            IndicatorSeries::Bool(v) => assert_eq!(v.len(), out.rows * out.cols),
            other => panic!("expected Bool matrix series, got {:?}", other),
        }
        let ids = out.pattern_ids.unwrap();
        assert_eq!(ids.len(), out.rows);
    }

    /// A plain slice is not an accepted input shape for pattern recognition.
    #[test]
    fn compute_cpu_pattern_recognition_rejects_missing_input_shape() {
        let series = sample_series(64);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Slice { values: &series },
            params: &[],
            kernel: Kernel::Auto,
        };
        let err = compute_cpu(req).unwrap_err();
        match err {
            IndicatorDispatchError::MissingRequiredInput { indicator, input } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(input, IndicatorInputKind::Ohlc);
            }
            other => panic!("expected MissingRequiredInput, got {:?}", other),
        }
    }

    /// Requesting an output id other than the registered one is rejected.
    #[test]
    fn compute_cpu_pattern_recognition_rejects_unknown_output() {
        let (open, high, low, close) = sample_ohlc(64);
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("value"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &[],
            kernel: Kernel::Auto,
        };
        let err = compute_cpu(req).unwrap_err();
        match err {
            IndicatorDispatchError::UnknownOutput { indicator, output } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(output, "value");
            }
            other => panic!("expected UnknownOutput, got {:?}", other),
        }
    }

    /// Supplying any parameter to pattern recognition is an error that names
    /// the offending key.
    #[test]
    fn compute_cpu_pattern_recognition_rejects_params() {
        let (open, high, low, close) = sample_ohlc(64);
        let params = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let req = IndicatorComputeRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &params,
            kernel: Kernel::Auto,
        };
        let err = compute_cpu(req).unwrap_err();
        match err {
            IndicatorDispatchError::InvalidParam {
                indicator,
                key,
                reason,
            } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(key, "period");
                assert!(reason.contains("does not accept parameters"));
            }
            other => panic!("expected InvalidParam, got {:?}", other),
        }
    }

    /// The batch dispatcher reports pattern recognition as lacking the
    /// `cpu_batch` capability.
    #[test]
    fn pattern_recognition_batch_mode_is_explicitly_unsupported() {
        let (open, high, low, close) = sample_ohlc(96);
        let combos = [IndicatorParamSet { params: &[] }];
        let err = compute_cpu_batch(IndicatorBatchRequest {
            indicator_id: "pattern_recognition",
            output_id: Some("matrix"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            combos: &combos,
            kernel: Kernel::Auto,
        })
        .unwrap_err();
        match err {
            IndicatorDispatchError::UnsupportedCapability {
                indicator,
                capability,
            } => {
                assert_eq!(indicator, "pattern_recognition");
                assert_eq!(capability, "cpu_batch");
            }
            other => panic!("expected UnsupportedCapability, got {:?}", other),
        }
    }

    /// A regular indicator (SMA) goes through the batch dispatcher and comes
    /// back as a single-row f64 series as long as the input.
    #[test]
    fn compute_cpu_for_sma_delegates_to_batch_dispatch() {
        let series = sample_series(200);
        let params = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let req = IndicatorComputeRequest {
            indicator_id: "sma",
            output_id: Some("value"),
            data: IndicatorDataRef::Slice { values: &series },
            params: &params,
            kernel: Kernel::Auto,
        };
        let out = compute_cpu(req).unwrap();
        assert_eq!(out.output_id, "value");
        assert_eq!(out.rows, 1);
        assert_eq!(out.cols, series.len());
        match out.series {
            IndicatorSeries::F64(v) => assert_eq!(v.len(), series.len()),
            other => panic!("expected F64 series, got {:?}", other),
        }
    }
}