Skip to main content

sciforge_hub/tools/
visualization.rs

//! SVG chart generation: line, scatter, bar, histogram, and heatmap.
//!
//! Each chart function accepts a [`ChartConfig`] that customizes dimensions,
//! margins, title, axis labels, background, grid lines, and font size.
5
/// Hex colors handed out, in order, to series and bars.
const PALETTE: [&str; 10] = [
    "#2196F3",
    "#FF5722",
    "#4CAF50",
    "#FFC107",
    "#9C27B0",
    "#00BCD4",
    "#E91E63",
    "#8BC34A",
    "#FF9800",
    "#607D8B",
];

/// Returns the palette color for index `i`, wrapping around once the
/// index runs past the last entry.
fn palette(i: usize) -> &'static str {
    let slot = i % PALETTE.len();
    PALETTE[slot]
}
14
/// Escapes text so it can be embedded in SVG/XML character data or in a
/// double-quoted attribute value.
///
/// Replaces `&`, `<`, `>` and `"` with their predefined XML entities; every
/// other character is copied through unchanged. The input is scanned once,
/// so an ampersand introduced by one replacement is never escaped again.
fn escape_xml(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            other => out.push(other),
        }
    }
    out
}
21
/// Rendering options shared by the SVG chart functions in this module.
#[derive(Clone, Debug)]
pub struct ChartConfig {
    /// Overall width of the SVG canvas, in pixels.
    pub width: f64,
    /// Overall height of the SVG canvas, in pixels.
    pub height: f64,
    /// Space reserved above the plot area, in pixels.
    pub margin_top: f64,
    /// Space reserved to the right of the plot area, in pixels.
    pub margin_right: f64,
    /// Space reserved below the plot area, in pixels.
    pub margin_bottom: f64,
    /// Space reserved to the left of the plot area, in pixels.
    pub margin_left: f64,
    /// Title text; an empty string means no title is drawn.
    pub title: String,
    /// Label for the x-axis; an empty string means no label is drawn.
    pub x_label: String,
    /// Label for the y-axis; an empty string means no label is drawn.
    pub y_label: String,
    /// Fill color of the background rectangle (CSS color).
    pub background: String,
    /// When `true`, grid lines are drawn at every axis tick.
    pub grid: bool,
    /// Base font size in pixels; tick and legend text is derived from it.
    pub font_size: f64,
}

impl Default for ChartConfig {
    /// An 800x500 white canvas with grid lines enabled, 14px text, and no
    /// title or axis labels.
    fn default() -> Self {
        ChartConfig {
            title: String::new(),
            x_label: String::new(),
            y_label: String::new(),
            background: String::from("#ffffff"),
            width: 800.0,
            height: 500.0,
            margin_top: 50.0,
            margin_right: 30.0,
            margin_bottom: 60.0,
            margin_left: 70.0,
            font_size: 14.0,
            grid: true,
        }
    }
}

impl ChartConfig {
    /// Width of the plot area: canvas width minus the left and right margins.
    fn plot_w(&self) -> f64 {
        let without_left = self.width - self.margin_left;
        without_left - self.margin_right
    }

    /// Height of the plot area: canvas height minus the top and bottom margins.
    fn plot_h(&self) -> f64 {
        let without_top = self.height - self.margin_top;
        without_top - self.margin_bottom
    }
}
78
/// Picks "nice" tick positions (multiples of 1, 2, 5 or 10 times a power of
/// ten) covering `[min, max]`, aiming for roughly `target_count` intervals.
///
/// * If `min` and `max` are closer than `1e-15`, the single tick `min` is
///   returned.
/// * If no finite positive step can be derived (non-finite bounds,
///   `max < min`, or `target_count == 0`), an empty vector is returned.
///
/// Ticks are computed as `lo + k * step` rather than by repeatedly adding
/// `step`: when `step` is smaller than the floating-point spacing around the
/// tick values, repeated addition never advances and would loop forever.
/// Multiplying by the index also avoids accumulating rounding error.
fn nice_ticks(min: f64, max: f64, target_count: usize) -> Vec<f64> {
    if (max - min).abs() < 1e-15 {
        return vec![min];
    }
    let range = max - min;
    let rough_step = range / target_count as f64;
    // Largest power of ten not exceeding the rough step.
    let mag = 10f64.powf(rough_step.log10().floor());
    let frac = rough_step / mag;
    let nice = if frac <= 1.5 {
        1.0
    } else if frac <= 3.5 {
        2.0
    } else if frac <= 7.5 {
        5.0
    } else {
        10.0
    };
    let step = nice * mag;
    if !step.is_finite() || step <= 0.0 {
        return Vec::new();
    }

    // First multiple of `step` at or below `min`.
    let lo = (min / step).floor() * step;
    // 1% of a step of slack so ticks sitting on the bounds survive rounding.
    let tol = step * 0.01;
    (0u64..)
        .map(|k| lo + k as f64 * step)
        .take_while(|&v| v <= max + tol)
        .filter(|&v| v >= min - tol)
        .collect()
}
108
/// Formats a tick value as a short axis label.
///
/// Magnitudes of at least `1e6`, and non-zero magnitudes below `0.01`, use
/// scientific notation with two decimals. Other whole numbers are printed
/// without decimals, and everything else with two decimals.
fn format_tick(v: f64) -> String {
    let magnitude = v.abs();
    let needs_exponent = magnitude >= 1e6 || (v != 0.0 && magnitude < 0.01);
    if needs_exponent {
        return format!("{:.2e}", v);
    }

    let is_whole = v.floor() == v;
    match is_whole {
        true => format!("{:.0}", v),
        false => format!("{:.2}", v),
    }
}
118
/// Computes the padded `(low, high)` axis range for `data`.
///
/// * NaN values are ignored (`f64::min`/`f64::max` skip them).
/// * With no usable values (empty slice or all NaN) the range is `(0.0, 1.0)`,
///   instead of the inverted infinite range the bare fold would produce.
/// * If all values are (almost) equal, the range is widened by 1 on each side.
/// * Otherwise 5% of the data span is added on each side.
fn data_range(data: &[f64]) -> (f64, f64) {
    let (min, max) = data
        .iter()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
            (lo.min(v), hi.max(v))
        });

    // The fold seeds are still in place: nothing usable was seen.
    if min > max {
        return (0.0, 1.0);
    }

    if (max - min).abs() < 1e-15 {
        (min - 1.0, max + 1.0)
    } else {
        let pad = (max - min) * 0.05;
        (min - pad, max + pad)
    }
}
129
130fn svg_header(cfg: &ChartConfig) -> String {
131    format!(
132        "<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 {} {}\" width=\"{}\" height=\"{}\">\n\
133         <rect width=\"100%\" height=\"100%\" fill=\"{}\"/>\n",
134        cfg.width, cfg.height, cfg.width, cfg.height, cfg.background,
135    )
136}
137
138fn svg_title(cfg: &ChartConfig) -> String {
139    if cfg.title.is_empty() {
140        return String::new();
141    }
142    format!(
143        "<text x=\"{}\" y=\"{}\" text-anchor=\"middle\" font-size=\"{}\" font-weight=\"bold\">{}</text>\n",
144        cfg.width / 2.0,
145        cfg.margin_top / 2.0 + 5.0,
146        cfg.font_size + 2.0,
147        escape_xml(&cfg.title),
148    )
149}
150
151fn svg_axes(cfg: &ChartConfig, x_min: f64, x_max: f64, y_min: f64, y_max: f64) -> String {
152    let pw = cfg.plot_w();
153    let ph = cfg.plot_h();
154    let ml = cfg.margin_left;
155    let mt = cfg.margin_top;
156    let mut s = String::new();
157
158    s.push_str(&format!(
159        "<line x1=\"{}\" y1=\"{}\" x2=\"{}\" y2=\"{}\" stroke=\"#333\" stroke-width=\"1.5\"/>\n",
160        ml,
161        mt,
162        ml,
163        mt + ph,
164    ));
165    s.push_str(&format!(
166        "<line x1=\"{}\" y1=\"{}\" x2=\"{}\" y2=\"{}\" stroke=\"#333\" stroke-width=\"1.5\"/>\n",
167        ml,
168        mt + ph,
169        ml + pw,
170        mt + ph,
171    ));
172
173    let xt = nice_ticks(x_min, x_max, 6);
174    for &v in &xt {
175        let x = ml + (v - x_min) / (x_max - x_min) * pw;
176        s.push_str(&format!(
177            "<text x=\"{}\" y=\"{}\" text-anchor=\"middle\" font-size=\"{}\">{}</text>\n",
178            x,
179            mt + ph + 20.0,
180            cfg.font_size - 2.0,
181            format_tick(v),
182        ));
183        if cfg.grid {
184            s.push_str(&format!(
185                "<line x1=\"{}\" y1=\"{}\" x2=\"{}\" y2=\"{}\" stroke=\"#ddd\" stroke-width=\"0.5\"/>\n",
186                x, mt, x, mt + ph,
187            ));
188        }
189    }
190
191    let yt = nice_ticks(y_min, y_max, 5);
192    for &v in &yt {
193        let y = mt + ph - (v - y_min) / (y_max - y_min) * ph;
194        s.push_str(&format!(
195            "<text x=\"{}\" y=\"{}\" text-anchor=\"end\" font-size=\"{}\">{}</text>\n",
196            ml - 8.0,
197            y + 4.0,
198            cfg.font_size - 2.0,
199            format_tick(v),
200        ));
201        if cfg.grid {
202            s.push_str(&format!(
203                "<line x1=\"{}\" y1=\"{}\" x2=\"{}\" y2=\"{}\" stroke=\"#ddd\" stroke-width=\"0.5\"/>\n",
204                ml, y, ml + pw, y,
205            ));
206        }
207    }
208
209    if !cfg.x_label.is_empty() {
210        s.push_str(&format!(
211            "<text x=\"{}\" y=\"{}\" text-anchor=\"middle\" font-size=\"{}\">{}</text>\n",
212            ml + pw / 2.0,
213            cfg.height - 10.0,
214            cfg.font_size,
215            escape_xml(&cfg.x_label),
216        ));
217    }
218    if !cfg.y_label.is_empty() {
219        s.push_str(&format!(
220            "<text x=\"{}\" y=\"{}\" text-anchor=\"middle\" font-size=\"{}\" transform=\"rotate(-90,{},{})\">{}</text>\n",
221            15.0, mt + ph / 2.0, cfg.font_size, 15.0, mt + ph / 2.0, escape_xml(&cfg.y_label),
222        ));
223    }
224
225    s
226}
227
/// One named sequence of `(x, y)` points, as consumed by [`line_chart`].
#[derive(Clone, Debug)]
pub struct Series {
    /// Label shown next to this series' color swatch in the legend.
    pub name: String,
    /// Horizontal coordinate of each point.
    pub x: Vec<f64>,
    /// Vertical coordinate of each point, paired by index with `x`.
    pub y: Vec<f64>,
}
238
239/// Generates a multi-series SVG line chart.
240pub fn line_chart(series: &[Series], cfg: &ChartConfig) -> String {
241    let all_x: Vec<f64> = series.iter().flat_map(|s| s.x.iter().copied()).collect();
242    let all_y: Vec<f64> = series.iter().flat_map(|s| s.y.iter().copied()).collect();
243    if all_x.is_empty() {
244        return String::from("<svg/>");
245    }
246
247    let (x_min, x_max) = data_range(&all_x);
248    let (y_min, y_max) = data_range(&all_y);
249    let pw = cfg.plot_w();
250    let ph = cfg.plot_h();
251    let ml = cfg.margin_left;
252    let mt = cfg.margin_top;
253
254    let mut svg = svg_header(cfg);
255    svg.push_str(&svg_title(cfg));
256    svg.push_str(&svg_axes(cfg, x_min, x_max, y_min, y_max));
257
258    for (si, s) in series.iter().enumerate() {
259        let color = palette(si);
260        let mut path = String::new();
261        for (i, (&xi, &yi)) in s.x.iter().zip(s.y.iter()).enumerate() {
262            let px = ml + (xi - x_min) / (x_max - x_min) * pw;
263            let py = mt + ph - (yi - y_min) / (y_max - y_min) * ph;
264            if i == 0 {
265                path.push_str(&format!("M{:.2},{:.2}", px, py));
266            } else {
267                path.push_str(&format!(" L{:.2},{:.2}", px, py));
268            }
269        }
270        svg.push_str(&format!(
271            "<path d=\"{}\" fill=\"none\" stroke=\"{}\" stroke-width=\"2\"/>\n",
272            path, color,
273        ));
274    }
275
276    if series.len() > 1 {
277        for (si, s) in series.iter().enumerate() {
278            let lx = ml + pw - 120.0;
279            let ly = mt + 20.0 + si as f64 * 20.0;
280            svg.push_str(&format!(
281                "<rect x=\"{}\" y=\"{}\" width=\"12\" height=\"12\" fill=\"{}\"/>\n",
282                lx,
283                ly - 10.0,
284                palette(si),
285            ));
286            svg.push_str(&format!(
287                "<text x=\"{}\" y=\"{}\" font-size=\"{}\">{}</text>\n",
288                lx + 18.0,
289                ly,
290                cfg.font_size - 2.0,
291                escape_xml(&s.name),
292            ));
293        }
294    }
295
296    svg.push_str("</svg>");
297    svg
298}
299
300/// Generates an SVG scatter plot.
301pub fn scatter_plot(x: &[f64], y: &[f64], cfg: &ChartConfig) -> String {
302    if x.is_empty() {
303        return String::from("<svg/>");
304    }
305    let (x_min, x_max) = data_range(x);
306    let (y_min, y_max) = data_range(y);
307    let pw = cfg.plot_w();
308    let ph = cfg.plot_h();
309    let ml = cfg.margin_left;
310    let mt = cfg.margin_top;
311
312    let mut svg = svg_header(cfg);
313    svg.push_str(&svg_title(cfg));
314    svg.push_str(&svg_axes(cfg, x_min, x_max, y_min, y_max));
315
316    for (&xi, &yi) in x.iter().zip(y.iter()) {
317        let px = ml + (xi - x_min) / (x_max - x_min) * pw;
318        let py = mt + ph - (yi - y_min) / (y_max - y_min) * ph;
319        svg.push_str(&format!(
320            "<circle cx=\"{:.2}\" cy=\"{:.2}\" r=\"3\" fill=\"{}\" opacity=\"0.7\"/>\n",
321            px, py, PALETTE[0],
322        ));
323    }
324
325    svg.push_str("</svg>");
326    svg
327}
328
329/// Generates an SVG bar chart from labels and values.
330pub fn bar_chart(labels: &[&str], values: &[f64], cfg: &ChartConfig) -> String {
331    if labels.is_empty() {
332        return String::from("<svg/>");
333    }
334    let n = labels.len();
335    let pw = cfg.plot_w();
336    let ph = cfg.plot_h();
337    let ml = cfg.margin_left;
338    let mt = cfg.margin_top;
339
340    let y_max = values.iter().copied().fold(0.0f64, f64::max) * 1.1;
341    let y_min = 0.0f64;
342    let bar_w = pw / n as f64 * 0.7;
343    let gap = pw / n as f64 * 0.3;
344
345    let mut svg = svg_header(cfg);
346    svg.push_str(&svg_title(cfg));
347    svg.push_str(&svg_axes(cfg, 0.0, n as f64, y_min, y_max));
348
349    for (i, (&label, &val)) in labels.iter().zip(values.iter()).enumerate() {
350        let x = ml + i as f64 * (bar_w + gap) + gap / 2.0;
351        let h = if y_max > 0.0 { val / y_max * ph } else { 0.0 };
352        let y = mt + ph - h;
353        svg.push_str(&format!(
354            "<rect x=\"{:.2}\" y=\"{:.2}\" width=\"{:.2}\" height=\"{:.2}\" fill=\"{}\" rx=\"2\"/>\n",
355            x, y, bar_w, h, palette(i),
356        ));
357        svg.push_str(&format!(
358            "<text x=\"{:.2}\" y=\"{}\" text-anchor=\"middle\" font-size=\"{}\">{}</text>\n",
359            x + bar_w / 2.0,
360            mt + ph + 18.0,
361            cfg.font_size - 3.0,
362            escape_xml(label),
363        ));
364    }
365
366    svg.push_str("</svg>");
367    svg
368}
369
370/// Generates an SVG histogram by binning the input data.
371pub fn histogram(data: &[f64], bins: usize, cfg: &ChartConfig) -> String {
372    if data.is_empty() || bins == 0 {
373        return String::from("<svg/>");
374    }
375    let (d_min, d_max) = data_range(data);
376    let bin_w = (d_max - d_min) / bins as f64;
377
378    let mut counts = vec![0usize; bins];
379    for &v in data {
380        let idx = ((v - d_min) / bin_w).floor() as usize;
381        let idx = idx.min(bins - 1);
382        counts[idx] += 1;
383    }
384
385    let max_count = *counts.iter().max().unwrap_or(&1);
386    let pw = cfg.plot_w();
387    let ph = cfg.plot_h();
388    let ml = cfg.margin_left;
389    let mt = cfg.margin_top;
390    let bar_px = pw / bins as f64;
391
392    let mut svg = svg_header(cfg);
393    svg.push_str(&svg_title(cfg));
394    svg.push_str(&svg_axes(cfg, d_min, d_max, 0.0, max_count as f64));
395
396    for (i, &c) in counts.iter().enumerate() {
397        let x = ml + i as f64 * bar_px;
398        let h = if max_count > 0 {
399            c as f64 / max_count as f64 * ph
400        } else {
401            0.0
402        };
403        let y = mt + ph - h;
404        svg.push_str(&format!(
405            "<rect x=\"{:.2}\" y=\"{:.2}\" width=\"{:.2}\" height=\"{:.2}\" fill=\"{}\" stroke=\"#fff\" stroke-width=\"0.5\"/>\n",
406            x, y, bar_px, h, PALETTE[0],
407        ));
408    }
409
410    svg.push_str("</svg>");
411    svg
412}
413
414/// Generates an SVG heatmap from a 2D matrix.
415pub fn heatmap(matrix: &[Vec<f64>], cfg: &ChartConfig) -> String {
416    if matrix.is_empty() {
417        return String::from("<svg/>");
418    }
419    let rows = matrix.len();
420    let cols = matrix[0].len();
421    let pw = cfg.plot_w();
422    let ph = cfg.plot_h();
423    let ml = cfg.margin_left;
424    let mt = cfg.margin_top;
425    let cell_w = pw / cols as f64;
426    let cell_h = ph / rows as f64;
427
428    let all: Vec<f64> = matrix.iter().flat_map(|r| r.iter().copied()).collect();
429    let v_min = all.iter().copied().fold(f64::INFINITY, f64::min);
430    let v_max = all.iter().copied().fold(f64::NEG_INFINITY, f64::max);
431    let range = if (v_max - v_min).abs() < 1e-15 {
432        1.0
433    } else {
434        v_max - v_min
435    };
436
437    let mut svg = svg_header(cfg);
438    svg.push_str(&svg_title(cfg));
439
440    for (r, row) in matrix.iter().enumerate() {
441        for (c, &val) in row.iter().enumerate() {
442            let t = ((val - v_min) / range).clamp(0.0, 1.0);
443            let red = (255.0 * t) as u8;
444            let blue = (255.0 * (1.0 - t)) as u8;
445            let x = ml + c as f64 * cell_w;
446            let y = mt + r as f64 * cell_h;
447            svg.push_str(&format!(
448                "<rect x=\"{:.1}\" y=\"{:.1}\" width=\"{:.1}\" height=\"{:.1}\" fill=\"rgb({},0,{})\"/>\n",
449                x, y, cell_w, cell_h, red, blue,
450            ));
451        }
452    }
453
454    svg.push_str("</svg>");
455    svg
456}