//! Main chart API and configuration (`shape_viz_core/chart.rs`).

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};

use crate::data::ChartData;
use crate::error::{ChartError, Result};
use crate::layers::{Layer, LayerBuilder, LayerManager};
use crate::renderer::{GpuRenderer, RenderContext};
use crate::style::ChartStyle;
use crate::theme::ChartTheme;
use crate::viewport::{ChartBounds, Rect, Viewport};

13/// Configuration for chart creation and behavior
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ChartConfig {
16    /// Chart dimensions in pixels
17    pub width: u32,
18    pub height: u32,
19
20    /// Chart theme
21    pub theme: ChartTheme,
22
23    /// Chart styling parameters
24    pub style: ChartStyle,
25
26    /// Auto-fit data to viewport
27    pub auto_fit: bool,
28
29    /// Enable anti-aliasing
30    pub anti_aliasing: bool,
31
32    /// Maximum frames per second for animations
33    pub max_fps: f32,
34
35    /// Padding around chart content (percentage of viewport)
36    pub padding: f32,
37
38    /// Enable GPU acceleration
39    pub gpu_acceleration: bool,
40}
41
42impl Default for ChartConfig {
43    fn default() -> Self {
44        Self {
45            width: 800,
46            height: 600,
47            theme: ChartTheme::default(),
48            style: crate::style::ChartStyle::default(),
49            auto_fit: true,
50            anti_aliasing: true,
51            max_fps: 60.0,
52            padding: 0.02,
53            gpu_acceleration: true,
54        }
55    }
56}
57
58/// Main chart instance that manages rendering and data
59pub struct Chart {
60    config: ChartConfig,
61    renderer: Option<GpuRenderer>,
62    render_context: Option<RenderContext>,
63    layer_manager: LayerManager,
64    data: Option<ChartData>,
65    viewport: Viewport,
66    dirty: bool,
67}
68
69impl Chart {
70    /// Create a new chart with the given configuration
71    pub async fn new(config: ChartConfig) -> Result<Self> {
72        // Create viewport that represents the full rendering area
73        // The chart content area will be inset to make room for axes
74        let full_rect = Rect::new(0.0, 0.0, config.width as f32, config.height as f32);
75
76        let default_bounds = ChartBounds::new(
77            Utc::now() - chrono::Duration::hours(24),
78            Utc::now(),
79            0.0,
80            100.0,
81        )?;
82        let viewport = Viewport::new(full_rect, default_bounds, config.style.layout.clone());
83
84        // Initialize GPU renderer if enabled
85        let renderer = if config.gpu_acceleration {
86            Some(GpuRenderer::new_offscreen(config.width, config.height).await?)
87        } else {
88            None
89        };
90
91        // Create render context
92        let render_context = if let Some(ref renderer) = renderer {
93            let (device, queue) = renderer.device_and_queue();
94            Some(RenderContext::new(
95                device,
96                queue,
97                viewport.clone(),
98                config.theme.clone(),
99                config.style.clone(),
100            ))
101        } else {
102            None
103        };
104
105        Ok(Self {
106            config,
107            renderer,
108            render_context,
109            layer_manager: LayerManager::new(),
110            data: None,
111            viewport,
112            dirty: true,
113        })
114    }
115
116    /// Create a basic financial chart with default layers
117    pub async fn new_financial(config: ChartConfig) -> Result<Self> {
118        let mut chart = Self::new(config).await?;
119        chart.layer_manager = LayerManager::basic_financial_chart();
120        chart.dirty = true;
121        Ok(chart)
122    }
123
124    /// Set chart data
125    pub fn set_data(&mut self, data: ChartData) -> Result<()> {
126        // Auto-fit viewport to data if enabled
127        if self.config.auto_fit {
128            if let Some(time_range) = data.time_range() {
129                if let Some((min_price, max_price)) = data.y_bounds() {
130                    // Just use tight bounds without forcing to nice numbers
131                    let price_range = max_price - min_price;
132                    let padding = price_range * self.config.padding as f64;
133
134                    // Simple padded bounds
135                    let padded_min = min_price - padding;
136                    let padded_max = max_price + padding;
137
138                    let chart_bounds =
139                        ChartBounds::new(time_range.start, time_range.end, padded_min, padded_max)?;
140
141                    self.viewport.set_chart_bounds(chart_bounds);
142                }
143            }
144        }
145
146        self.data = Some(data);
147        self.dirty = true;
148        Ok(())
149    }
150
151    /// Get current chart data
152    pub fn data(&self) -> Option<&ChartData> {
153        self.data.as_ref()
154    }
155
156    /// Add a layer to the chart
157    pub fn add_layer(&mut self, layer: Box<dyn Layer>) {
158        self.layer_manager.add_layer(layer);
159        self.dirty = true;
160    }
161
162    /// Remove a layer by name
163    pub fn remove_layer(&mut self, name: &str) -> Option<Box<dyn Layer>> {
164        self.dirty = true;
165        self.layer_manager.remove_layer(name)
166    }
167
168    /// Get a mutable reference to a layer
169    pub fn get_layer_mut(&mut self, name: &str) -> Option<&mut (dyn Layer + '_)> {
170        self.layer_manager.get_layer_mut(name)
171    }
172
173    /// Enable or disable a layer
174    pub fn set_layer_enabled(&mut self, name: &str, enabled: bool) -> bool {
175        let result = self.layer_manager.set_layer_enabled(name, enabled);
176        if result {
177            self.dirty = true;
178        }
179        result
180    }
181
182    /// Get current viewport
183    pub fn viewport(&self) -> &Viewport {
184        &self.viewport
185    }
186
187    /// Set viewport bounds
188    pub fn set_viewport_bounds(&mut self, bounds: ChartBounds) {
189        self.viewport.set_chart_bounds(bounds);
190        self.dirty = true;
191    }
192
193    /// Pan the viewport by screen pixels
194    pub fn pan(&mut self, delta_x: f32, delta_y: f32) {
195        self.viewport.pan(glam::Vec2::new(delta_x, delta_y));
196        self.dirty = true;
197    }
198
199    /// Zoom the viewport around a center point
200    pub fn zoom(&mut self, center_x: f32, center_y: f32, zoom_factor: f32) {
201        self.viewport
202            .zoom(glam::Vec2::new(center_x, center_y), zoom_factor);
203        self.dirty = true;
204    }
205
206    /// Reset viewport to fit all data
207    pub fn fit_to_data(&mut self) -> Result<()> {
208        if let Some(ref data) = self.data {
209            if let Some(time_range) = data.time_range() {
210                if let Some((min_price, max_price)) = data.y_bounds() {
211                    let price_range = max_price - min_price;
212                    let price_padding = price_range * self.config.padding as f64;
213
214                    let chart_bounds = ChartBounds::new(
215                        time_range.start,
216                        time_range.end,
217                        min_price - price_padding,
218                        max_price + price_padding,
219                    )?;
220
221                    self.viewport.set_chart_bounds(chart_bounds);
222                    self.dirty = true;
223                }
224            }
225        }
226        Ok(())
227    }
228
229    /// Update chart configuration
230    pub fn set_config(&mut self, config: ChartConfig) {
231        self.viewport.set_layout_style(config.style.layout.clone());
232        self.config = config;
233        self.dirty = true;
234    }
235
236    /// Get current configuration
237    pub fn config(&self) -> &ChartConfig {
238        &self.config
239    }
240
241    /// Set chart theme
242    pub fn set_theme(&mut self, theme: ChartTheme) {
243        self.config.theme = theme;
244        self.dirty = true;
245    }
246
247    /// Check if chart needs to be re-rendered
248    pub fn needs_render(&self) -> bool {
249        self.dirty || self.layer_manager.needs_render()
250    }
251
252    /// Render the chart and return RGBA image data
253    pub async fn render(&mut self) -> Result<Vec<u8>> {
254        // Check if we have necessary components
255        let renderer = self
256            .renderer
257            .as_ref()
258            .ok_or_else(|| ChartError::internal("No renderer available"))?;
259
260        let render_context = self
261            .render_context
262            .as_mut()
263            .ok_or_else(|| ChartError::internal("No render context available"))?;
264
265        // Update render context
266        render_context.update(
267            self.viewport.clone(),
268            self.config.theme.clone(),
269            self.config.style.clone(),
270        );
271
272        // Update all layers if we have data
273        if let Some(ref data) = self.data {
274            self.layer_manager.update_all(
275                data,
276                &self.viewport,
277                &self.config.theme,
278                &self.config.style,
279            );
280        }
281
282        // Clear previous frame
283        render_context.clear();
284
285        // Execute GPU render
286        let image_data = renderer
287            .render(
288                render_context,
289                &self.layer_manager,
290                self.config.theme.colors.background,
291            )
292            .await?;
293
294        // Mark as clean
295        self.dirty = false;
296        self.layer_manager.mark_clean();
297
298        Ok(image_data)
299    }
300
301    /// Get chart dimensions
302    pub fn dimensions(&self) -> (u32, u32) {
303        (self.config.width, self.config.height)
304    }
305
306    /// Resize the chart
307    pub async fn resize(&mut self, width: u32, height: u32) -> Result<()> {
308        if width != self.config.width || height != self.config.height {
309            self.config.width = width;
310            self.config.height = height;
311
312            // Update viewport screen rect to full size
313            let full_rect = Rect::new(0.0, 0.0, width as f32, height as f32);
314            self.viewport.set_screen_rect(full_rect);
315
316            // Recreate renderer with new dimensions
317            if self.config.gpu_acceleration {
318                self.renderer = Some(GpuRenderer::new_offscreen(width, height).await?);
319
320                // Update render context
321                if let Some(ref renderer) = self.renderer {
322                    let (device, queue) = renderer.device_and_queue();
323                    self.render_context = Some(RenderContext::new(
324                        device,
325                        queue,
326                        self.viewport.clone(),
327                        self.config.theme.clone(),
328                        self.config.style.clone(),
329                    ));
330                }
331            }
332
333            self.dirty = true;
334        }
335        Ok(())
336    }
337
338    /// Convert screen coordinates to chart coordinates
339    pub fn screen_to_chart(&self, screen_x: f32, screen_y: f32) -> glam::Vec2 {
340        self.viewport
341            .screen_to_chart(glam::Vec2::new(screen_x, screen_y))
342    }
343
344    /// Convert chart coordinates to screen coordinates
345    pub fn chart_to_screen(&self, chart_x: f32, chart_y: f32) -> glam::Vec2 {
346        self.viewport
347            .chart_to_screen(glam::Vec2::new(chart_x, chart_y))
348    }
349
350    /// Find data point at screen coordinates
351    /// Returns: (index, timestamp, start, max, min, end, auxiliary)
352    pub fn hit_test(
353        &self,
354        screen_x: f32,
355        screen_y: f32,
356    ) -> Option<(usize, DateTime<Utc>, f64, f64, f64, f64, f64)> {
357        let data = self.data.as_ref()?;
358        let chart_pos = self.screen_to_chart(screen_x, screen_y);
359
360        // Convert chart X coordinate (timestamp) to find nearest data point
361        let index = data.main_series.find_index(chart_pos.x as f64)?;
362        let time_val = data.main_series.get_x(index);
363        let time = DateTime::from_timestamp(time_val as i64, 0)?;
364
365        // Use Y value for simple series
366        let val = data.main_series.get_y(index);
367        Some((index, time, val, val, val, val, 0.0))
368    }
369
370    /// Get visible time range as timestamps
371    pub fn visible_time_range(&self) -> (i64, i64) {
372        self.viewport.visible_time_range()
373    }
374
375    /// Get visible price range
376    pub fn visible_price_range(&self) -> (f64, f64) {
377        self.viewport.visible_price_range()
378    }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::{ChartData, RangeSeries, Series};
    use chrono::TimeZone;
    use std::any::Any;

    /// Whether `err` indicates that no graphics adapter is available on this machine.
    fn should_skip_gpu(err: &ChartError) -> bool {
        matches!(err, ChartError::Internal(message)
            if message.contains("No suitable graphics adapter"))
    }

    /// Unwrap the result of constructing a chart for the test named `test_name`.
    ///
    /// Returns `None` after logging a skip message when no graphics adapter is available,
    /// so GPU-less machines skip the test instead of failing it. Panics on any other error.
    fn chart_or_skip(result: Result<Chart>, test_name: &str) -> Option<Chart> {
        match result {
            Ok(chart) => Some(chart),
            Err(err) if should_skip_gpu(&err) => {
                eprintln!("skipping {test_name}: {err}");
                None
            }
            Err(err) => panic!("chart creation failed: {err}"),
        }
    }

    /// Mock range series for testing (no external dependencies).
    ///
    /// Produces `count` hourly points starting at 2024-01-01 09:00 UTC.
    #[derive(Debug, Clone)]
    struct MockRangeSeries {
        name: String,
        timestamps: Vec<i64>,
        ranges: Vec<(f64, f64, f64, f64)>, // (start, max, min, end)
        auxiliary: Vec<f64>,
    }

    impl MockRangeSeries {
        fn new(name: &str, count: usize) -> Self {
            let base_time = chrono::Utc
                .with_ymd_and_hms(2024, 1, 1, 9, 0, 0)
                .unwrap()
                .timestamp();
            let mut timestamps = Vec::with_capacity(count);
            let mut ranges = Vec::with_capacity(count);
            let mut auxiliary = Vec::with_capacity(count);

            // Each point opens where the previous one closed and moves along a sine wave.
            let mut last_end = 100.0;
            for i in 0..count {
                timestamps.push(base_time + (i as i64 * 3600)); // hourly
                let start = last_end;
                let movement = (i as f64 * 0.1).sin() * 5.0;
                let end = start + movement;
                let max = start.max(end) + 2.0;
                let min = start.min(end) - 2.0;
                ranges.push((start, max, min, end));
                auxiliary.push(1000.0 + (i as f64 * 100.0));
                last_end = end;
            }

            Self {
                name: name.to_string(),
                timestamps,
                ranges,
                auxiliary,
            }
        }
    }

    impl Series for MockRangeSeries {
        fn name(&self) -> &str {
            &self.name
        }

        fn len(&self) -> usize {
            self.timestamps.len()
        }

        fn get_x(&self, index: usize) -> f64 {
            self.timestamps[index] as f64
        }

        fn get_y(&self, index: usize) -> f64 {
            self.ranges[index].3 // end value
        }

        fn get_x_range(&self) -> (f64, f64) {
            if self.timestamps.is_empty() {
                return (0.0, 1.0);
            }
            (
                self.timestamps[0] as f64,
                self.timestamps[self.timestamps.len() - 1] as f64,
            )
        }

        fn get_y_range(&self, _x_min: f64, _x_max: f64) -> (f64, f64) {
            let mut min = f64::INFINITY;
            let mut max = f64::NEG_INFINITY;
            for (_, hi, lo, _) in &self.ranges {
                min = min.min(*lo);
                max = max.max(*hi);
            }
            if min.is_infinite() {
                (0.0, 100.0)
            } else {
                (min, max)
            }
        }

        fn find_index(&self, x: f64) -> Option<usize> {
            let target = x as i64;
            match self.timestamps.binary_search(&target) {
                Ok(idx) => Some(idx),
                Err(idx) => Some(idx.min(self.len().saturating_sub(1))),
            }
        }

        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    impl RangeSeries for MockRangeSeries {
        fn get_range(&self, index: usize) -> (f64, f64, f64, f64) {
            self.ranges[index]
        }

        fn get_auxiliary(&self, index: usize) -> Option<f64> {
            Some(self.auxiliary[index])
        }
    }

    #[tokio::test]
    async fn test_chart_creation() {
        let result = Chart::new(ChartConfig::default()).await;
        let chart = match chart_or_skip(result, "test_chart_creation") {
            Some(chart) => chart,
            None => return,
        };

        assert_eq!(chart.dimensions(), (800, 600));
        assert!(chart.needs_render()); // Should be dirty initially
    }

    #[tokio::test]
    async fn test_chart_with_data() {
        let result = Chart::new(ChartConfig::default()).await;
        let mut chart = match chart_or_skip(result, "test_chart_with_data") {
            Some(chart) => chart,
            None => return,
        };

        let series = MockRangeSeries::new("TEST", 10);
        let chart_data = ChartData::new(Box::new(series));
        assert!(chart.set_data(chart_data).is_ok());

        assert_eq!(chart.data().map(|data| data.main_series.len()), Some(10));
        assert!(chart.needs_render());
    }

    #[tokio::test]
    async fn test_chart_viewport_operations() {
        let result = Chart::new(ChartConfig::default()).await;
        let mut chart = match chart_or_skip(result, "test_chart_viewport_operations") {
            Some(chart) => chart,
            None => return,
        };

        let original_bounds = chart.viewport().chart_bounds;

        chart.pan(10.0, 20.0);
        assert!(chart.needs_render());

        chart.zoom(400.0, 300.0, 2.0);
        assert!(chart.needs_render());

        // Panning and zooming must have moved the visible window.
        assert_ne!(
            chart.viewport().chart_bounds.time_start,
            original_bounds.time_start
        );
    }

    #[test]
    fn test_config_defaults() {
        let config = ChartConfig::default();
        assert_eq!(config.width, 800);
        assert_eq!(config.height, 600);
        assert!(config.auto_fit);
        assert!(config.anti_aliasing);
        assert!(config.gpu_acceleration);
        assert_eq!(config.max_fps, 60.0);
        assert_eq!(config.padding, 0.02);
    }

    #[tokio::test]
    async fn test_full_chart_render() {
        let result = Chart::new_financial(ChartConfig::default()).await;
        let mut chart = match chart_or_skip(result, "test_full_chart_render") {
            Some(chart) => chart,
            None => return,
        };

        let series = MockRangeSeries::new("TEST/USD", 50);
        let chart_data = ChartData::new(Box::new(series));
        chart.set_data(chart_data).unwrap();

        assert!(chart.needs_render());
        assert!(chart.data().is_some());

        let image_data = chart.render().await.unwrap();

        // One RGBA pixel (4 bytes) per output pixel.
        let (width, height) = chart.dimensions();
        assert_eq!(image_data.len(), width as usize * height as usize * 4);

        // Not just a blank (all-zero) buffer.
        let has_content = image_data.iter().any(|&byte| byte != 0);
        assert!(has_content, "Chart should contain rendered pixels");
    }
}