1use arrow_ipc::reader::StreamReader;
8use arrow_schema::{DataType, Schema};
9use serde::{Deserialize, Serialize};
10use serde_json::{Value, json};
11use std::io::Cursor;
12use std::sync::Arc;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ColumnInfo {
17 pub name: String,
18 pub data_type: String,
19}
20
/// Chart family chosen for a schema by `detect_chart_type`.
///
/// A plain, fieldless enum, so it is `Copy` and fully `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ChartType {
    /// Schema has `open`, `high`, `low` and `close` columns.
    Candlestick,
    /// Numeric series over a temporal (timestamp/date) column.
    Line,
    /// Numeric values grouped by a string column.
    Bar,
    /// At least two numeric columns plotted against each other.
    Scatter,
    /// No chart is suggested; the data is only shown as a table.
    TableOnly,
}
30
31pub fn extract_columns(ipc_bytes: &[u8]) -> Vec<ColumnInfo> {
33 let schema = match read_schema(ipc_bytes) {
34 Some(s) => s,
35 None => return vec![],
36 };
37
38 schema
39 .fields()
40 .iter()
41 .map(|f| ColumnInfo {
42 name: f.name().clone(),
43 data_type: format_arrow_type(f.data_type()),
44 })
45 .collect()
46}
47
48pub fn detect_chart(ipc_bytes: &[u8]) -> Option<Value> {
50 if ipc_bytes.is_empty() {
51 return None;
52 }
53
54 let (schema, data) = read_schema_and_data(ipc_bytes)?;
55 let chart_type = detect_chart_type(&schema);
56
57 if chart_type == ChartType::TableOnly {
58 return None;
59 }
60
61 Some(build_echart_option(&chart_type, &schema, &data))
62}
63
64fn read_schema(ipc_bytes: &[u8]) -> Option<Arc<Schema>> {
66 let cursor = Cursor::new(ipc_bytes);
67 let reader = StreamReader::try_new(cursor, None).ok()?;
68 Some(reader.schema().clone())
69}
70
71fn read_schema_and_data(ipc_bytes: &[u8]) -> Option<(Arc<Schema>, Vec<Vec<Value>>)> {
73 let cursor = Cursor::new(ipc_bytes);
74 let reader = StreamReader::try_new(cursor, None).ok()?;
75 let schema = reader.schema().clone();
76 let num_cols = schema.fields().len();
77
78 let mut columns: Vec<Vec<Value>> = vec![vec![]; num_cols];
80
81 for batch_result in reader {
82 let batch = batch_result.ok()?;
83 for col_idx in 0..num_cols {
84 let array = batch.column(col_idx);
85 for row_idx in 0..batch.num_rows() {
86 let val = arrow_value_to_json(array, row_idx);
87 columns[col_idx].push(val);
88 }
89 }
90 }
91
92 Some((schema, columns))
93}
94
95fn detect_chart_type(schema: &Schema) -> ChartType {
97 let field_names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
98
99 let has_ohlc = ["open", "high", "low", "close"]
101 .iter()
102 .all(|name| field_names.iter().any(|f| f.eq_ignore_ascii_case(name)));
103
104 if has_ohlc {
105 return ChartType::Candlestick;
106 }
107
108 let mut has_timestamp = false;
110 let mut numeric_count = 0;
111 let mut string_count = 0;
112
113 for field in schema.fields() {
114 match field.data_type() {
115 DataType::Timestamp(_, _) | DataType::Date32 | DataType::Date64 => {
116 has_timestamp = true;
117 }
118 DataType::Float16
119 | DataType::Float32
120 | DataType::Float64
121 | DataType::Int8
122 | DataType::Int16
123 | DataType::Int32
124 | DataType::Int64
125 | DataType::UInt8
126 | DataType::UInt16
127 | DataType::UInt32
128 | DataType::UInt64 => {
129 numeric_count += 1;
130 }
131 DataType::Utf8 | DataType::LargeUtf8 => {
132 string_count += 1;
133 }
134 _ => {}
135 }
136 }
137
138 if has_timestamp && numeric_count >= 1 {
140 return ChartType::Line;
141 }
142
143 if string_count >= 1 && numeric_count >= 1 {
145 return ChartType::Bar;
146 }
147
148 if numeric_count >= 2 {
150 return ChartType::Scatter;
151 }
152
153 ChartType::TableOnly
154}
155
156fn build_echart_option(chart_type: &ChartType, schema: &Schema, columns: &[Vec<Value>]) -> Value {
158 match chart_type {
159 ChartType::Candlestick => build_candlestick(schema, columns),
160 ChartType::Line => build_line(schema, columns),
161 ChartType::Bar => build_bar(schema, columns),
162 ChartType::Scatter => build_scatter(schema, columns),
163 ChartType::TableOnly => json!(null),
164 }
165}
166
167fn build_candlestick(schema: &Schema, columns: &[Vec<Value>]) -> Value {
168 let find_col = |name: &str| -> Option<usize> {
169 schema
170 .fields()
171 .iter()
172 .position(|f| f.name().eq_ignore_ascii_case(name))
173 };
174
175 let open_idx = find_col("open").unwrap_or(0);
176 let close_idx = find_col("close").unwrap_or(1);
177 let low_idx = find_col("low").unwrap_or(2);
178 let high_idx = find_col("high").unwrap_or(3);
179
180 let x_idx = schema
182 .fields()
183 .iter()
184 .position(|f| {
185 matches!(
186 f.data_type(),
187 DataType::Timestamp(_, _) | DataType::Date32 | DataType::Date64
188 )
189 })
190 .or_else(|| find_col("timestamp"))
191 .or_else(|| find_col("date"));
192
193 let row_count = columns.first().map(|c| c.len()).unwrap_or(0);
194
195 let x_data: Vec<Value> = if let Some(xi) = x_idx {
196 columns[xi].clone()
197 } else {
198 (0..row_count).map(|i| json!(i)).collect()
199 };
200
201 let ohlc_data: Vec<Value> = (0..row_count)
203 .map(|i| {
204 json!([
205 columns[open_idx].get(i).unwrap_or(&json!(0)),
206 columns[close_idx].get(i).unwrap_or(&json!(0)),
207 columns[low_idx].get(i).unwrap_or(&json!(0)),
208 columns[high_idx].get(i).unwrap_or(&json!(0)),
209 ])
210 })
211 .collect();
212
213 json!({
214 "xAxis": {
215 "type": "category",
216 "data": x_data,
217 "axisLine": { "lineStyle": { "color": "#8392A5" } }
218 },
219 "yAxis": {
220 "scale": true,
221 "splitArea": { "show": true }
222 },
223 "series": [{
224 "type": "candlestick",
225 "data": ohlc_data,
226 "itemStyle": {
227 "color": "#26a69a",
228 "color0": "#ef5350",
229 "borderColor": "#26a69a",
230 "borderColor0": "#ef5350"
231 }
232 }],
233 "tooltip": { "trigger": "axis", "axisPointer": { "type": "cross" } },
234 "dataZoom": [
235 { "type": "inside", "start": 0, "end": 100 },
236 { "type": "slider", "start": 0, "end": 100 }
237 ],
238 "grid": { "left": "10%", "right": "10%", "bottom": "15%" }
239 })
240}
241
242fn build_line(schema: &Schema, columns: &[Vec<Value>]) -> Value {
243 let x_idx = schema
245 .fields()
246 .iter()
247 .position(|f| {
248 matches!(
249 f.data_type(),
250 DataType::Timestamp(_, _) | DataType::Date32 | DataType::Date64
251 )
252 })
253 .unwrap_or(0);
254
255 let row_count = columns.first().map(|c| c.len()).unwrap_or(0);
256 let x_data: Vec<Value> = columns.get(x_idx).cloned().unwrap_or_default();
257
258 let mut series = Vec::new();
260 for (i, field) in schema.fields().iter().enumerate() {
261 if i == x_idx {
262 continue;
263 }
264 if is_numeric_type(field.data_type()) {
265 let data: Vec<Value> = columns.get(i).cloned().unwrap_or_default();
266 series.push(json!({
267 "name": field.name(),
268 "type": "line",
269 "data": data,
270 "sampling": "lttb",
271 "smooth": false,
272 "symbol": if row_count > 100 { "none" } else { "circle" },
273 }));
274 }
275 }
276
277 json!({
278 "xAxis": {
279 "type": "category",
280 "data": x_data,
281 "axisLine": { "lineStyle": { "color": "#8392A5" } }
282 },
283 "yAxis": { "type": "value", "scale": true },
284 "series": series,
285 "tooltip": { "trigger": "axis" },
286 "legend": { "show": series.len() > 1 },
287 "dataZoom": [
288 { "type": "inside", "start": 0, "end": 100 },
289 { "type": "slider", "start": 0, "end": 100 }
290 ],
291 "grid": { "left": "10%", "right": "10%", "bottom": "15%" }
292 })
293}
294
295fn build_bar(schema: &Schema, columns: &[Vec<Value>]) -> Value {
296 let cat_idx = schema
298 .fields()
299 .iter()
300 .position(|f| matches!(f.data_type(), DataType::Utf8 | DataType::LargeUtf8))
301 .unwrap_or(0);
302
303 let categories: Vec<Value> = columns.get(cat_idx).cloned().unwrap_or_default();
304
305 let mut series = Vec::new();
306 for (i, field) in schema.fields().iter().enumerate() {
307 if i == cat_idx {
308 continue;
309 }
310 if is_numeric_type(field.data_type()) {
311 let data: Vec<Value> = columns.get(i).cloned().unwrap_or_default();
312 series.push(json!({
313 "name": field.name(),
314 "type": "bar",
315 "data": data,
316 }));
317 }
318 }
319
320 json!({
321 "xAxis": { "type": "category", "data": categories },
322 "yAxis": { "type": "value" },
323 "series": series,
324 "tooltip": { "trigger": "axis" },
325 "legend": { "show": series.len() > 1 },
326 "grid": { "left": "10%", "right": "10%", "bottom": "10%" }
327 })
328}
329
330fn build_scatter(schema: &Schema, columns: &[Vec<Value>]) -> Value {
331 let numeric_indices: Vec<usize> = schema
333 .fields()
334 .iter()
335 .enumerate()
336 .filter(|(_, f)| is_numeric_type(f.data_type()))
337 .map(|(i, _)| i)
338 .collect();
339
340 let x_idx = numeric_indices.first().copied().unwrap_or(0);
341 let y_idx = numeric_indices.get(1).copied().unwrap_or(1);
342
343 let row_count = columns.first().map(|c| c.len()).unwrap_or(0);
344 let scatter_data: Vec<Value> = (0..row_count)
345 .map(|i| {
346 json!([
347 columns
348 .get(x_idx)
349 .and_then(|c| c.get(i))
350 .unwrap_or(&json!(0)),
351 columns
352 .get(y_idx)
353 .and_then(|c| c.get(i))
354 .unwrap_or(&json!(0)),
355 ])
356 })
357 .collect();
358
359 let x_name = schema
360 .fields()
361 .get(x_idx)
362 .map(|f| f.name().as_str())
363 .unwrap_or("x");
364 let y_name = schema
365 .fields()
366 .get(y_idx)
367 .map(|f| f.name().as_str())
368 .unwrap_or("y");
369
370 json!({
371 "xAxis": { "type": "value", "name": x_name, "scale": true },
372 "yAxis": { "type": "value", "name": y_name, "scale": true },
373 "series": [{
374 "type": "scatter",
375 "data": scatter_data,
376 "symbolSize": 5,
377 }],
378 "tooltip": { "trigger": "item" },
379 "grid": { "left": "10%", "right": "10%", "bottom": "10%" }
380 })
381}
382
383fn is_numeric_type(dt: &DataType) -> bool {
388 matches!(
389 dt,
390 DataType::Float16
391 | DataType::Float32
392 | DataType::Float64
393 | DataType::Int8
394 | DataType::Int16
395 | DataType::Int32
396 | DataType::Int64
397 | DataType::UInt8
398 | DataType::UInt16
399 | DataType::UInt32
400 | DataType::UInt64
401 )
402}
403
404fn format_arrow_type(dt: &DataType) -> String {
405 match dt {
406 DataType::Float32 | DataType::Float64 | DataType::Float16 => "Number".to_string(),
407 DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
408 "Integer".to_string()
409 }
410 DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
411 "Integer".to_string()
412 }
413 DataType::Utf8 | DataType::LargeUtf8 => "String".to_string(),
414 DataType::Boolean => "Bool".to_string(),
415 DataType::Timestamp(_, _) | DataType::Date32 | DataType::Date64 => "Timestamp".to_string(),
416 other => format!("{:?}", other),
417 }
418}
419
420fn arrow_value_to_json(array: &dyn arrow_array::Array, idx: usize) -> Value {
422 use arrow_array::*;
423
424 if array.is_null(idx) {
425 return Value::Null;
426 }
427
428 if let Some(a) = array.as_any().downcast_ref::<Float64Array>() {
429 return json!(a.value(idx));
430 }
431 if let Some(a) = array.as_any().downcast_ref::<Float32Array>() {
432 return json!(a.value(idx) as f64);
433 }
434 if let Some(a) = array.as_any().downcast_ref::<Int64Array>() {
435 return json!(a.value(idx));
436 }
437 if let Some(a) = array.as_any().downcast_ref::<Int32Array>() {
438 return json!(a.value(idx));
439 }
440 if let Some(a) = array.as_any().downcast_ref::<UInt64Array>() {
441 return json!(a.value(idx));
442 }
443 if let Some(a) = array.as_any().downcast_ref::<UInt32Array>() {
444 return json!(a.value(idx));
445 }
446 if let Some(a) = array.as_any().downcast_ref::<StringArray>() {
447 return json!(a.value(idx));
448 }
449 if let Some(a) = array.as_any().downcast_ref::<BooleanArray>() {
450 return json!(a.value(idx));
451 }
452 if let Some(a) = array.as_any().downcast_ref::<TimestampMillisecondArray>() {
453 return json!(a.value(idx));
454 }
455 if let Some(a) = array.as_any().downcast_ref::<TimestampMicrosecondArray>() {
456 return json!(a.value(idx) / 1000); }
458 if let Some(a) = array.as_any().downcast_ref::<TimestampNanosecondArray>() {
459 return json!(a.value(idx) / 1_000_000); }
461 if let Some(a) = array.as_any().downcast_ref::<Date32Array>() {
462 return json!(a.value(idx));
463 }
464 if let Some(a) = array.as_any().downcast_ref::<Date64Array>() {
465 return json!(a.value(idx));
466 }
467
468 json!(null)
470}
471
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray};
    use arrow_ipc::writer::StreamWriter;
    use arrow_schema::{Field, TimeUnit};

    /// Encodes `batch` as a complete Arrow IPC stream.
    fn to_ipc(batch: &RecordBatch) -> Vec<u8> {
        let mut buf = Vec::new();
        {
            let mut writer =
                StreamWriter::try_new(&mut buf, &batch.schema()).expect("create IPC writer");
            writer.write(batch).expect("write record batch");
            writer.finish().expect("finish IPC stream");
        }
        buf
    }

    /// Two rows: a `category` string column and a `count` Int64 column.
    fn category_count_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("category", DataType::Utf8, false),
            Field::new("count", DataType::Int64, false),
        ]));
        let columns: Vec<ArrayRef> = vec![
            Arc::new(StringArray::from(vec!["a", "b"])),
            Arc::new(Int64Array::from(vec![3, 5])),
        ];
        RecordBatch::try_new(schema, columns).expect("valid record batch")
    }

    fn millis() -> DataType {
        DataType::Timestamp(TimeUnit::Millisecond, None)
    }

    #[test]
    fn test_detect_chart_type_ohlc() {
        let schema = Schema::new(vec![
            Field::new("timestamp", millis(), false),
            Field::new("open", DataType::Float64, false),
            Field::new("high", DataType::Float64, false),
            Field::new("low", DataType::Float64, false),
            Field::new("close", DataType::Float64, false),
            Field::new("volume", DataType::Float64, false),
        ]);
        assert_eq!(detect_chart_type(&schema), ChartType::Candlestick);
    }

    #[test]
    fn test_detect_chart_type_line() {
        let schema = Schema::new(vec![
            Field::new("time", millis(), false),
            Field::new("value", DataType::Float64, false),
        ]);
        assert_eq!(detect_chart_type(&schema), ChartType::Line);
    }

    #[test]
    fn test_detect_chart_type_bar() {
        let schema = Schema::new(vec![
            Field::new("category", DataType::Utf8, false),
            Field::new("count", DataType::Int64, false),
        ]);
        assert_eq!(detect_chart_type(&schema), ChartType::Bar);
    }

    #[test]
    fn test_detect_chart_type_scatter() {
        let schema = Schema::new(vec![
            Field::new("x", DataType::Float64, false),
            Field::new("y", DataType::Float64, false),
        ]);
        assert_eq!(detect_chart_type(&schema), ChartType::Scatter);
    }

    #[test]
    fn test_detect_chart_type_table_only() {
        let schema = Schema::new(vec![Field::new("name", DataType::Utf8, false)]);
        assert_eq!(detect_chart_type(&schema), ChartType::TableOnly);
    }

    #[test]
    fn test_extract_columns_empty() {
        let cols = extract_columns(&[]);
        assert!(cols.is_empty());
    }

    #[test]
    fn test_extract_columns_round_trip() {
        let cols = extract_columns(&to_ipc(&category_count_batch()));
        let summary: Vec<(&str, &str)> = cols
            .iter()
            .map(|c| (c.name.as_str(), c.data_type.as_str()))
            .collect();
        assert_eq!(summary, vec![("category", "String"), ("count", "Integer")]);
    }

    #[test]
    fn test_detect_chart_empty() {
        assert!(detect_chart(&[]).is_none());
    }

    #[test]
    fn test_detect_chart_bar_round_trip() {
        let option = detect_chart(&to_ipc(&category_count_batch())).expect("bar chart expected");
        assert_eq!(option["series"][0]["type"], "bar");
        assert_eq!(option["xAxis"]["data"], json!(["a", "b"]));
        assert_eq!(option["series"][0]["data"], json!([3, 5]));
    }

    #[test]
    fn test_detect_chart_table_only_round_trip() {
        let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, false)]));
        let columns: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec!["x", "y"]))];
        let batch = RecordBatch::try_new(schema, columns).expect("valid record batch");
        assert!(detect_chart(&to_ipc(&batch)).is_none());
    }

    #[test]
    fn test_format_arrow_type() {
        assert_eq!(format_arrow_type(&DataType::Float64), "Number");
        assert_eq!(format_arrow_type(&DataType::Int64), "Integer");
        assert_eq!(format_arrow_type(&DataType::UInt8), "Integer");
        assert_eq!(format_arrow_type(&DataType::Utf8), "String");
        assert_eq!(format_arrow_type(&DataType::Boolean), "Bool");
        assert_eq!(format_arrow_type(&DataType::Date32), "Timestamp");
    }
}