scirs2_text/
visualization.rs

1//! Visualization tools for text processing and analysis
2//!
3//! This module provides comprehensive visualization capabilities for text data,
4//! including word clouds, attention visualizations, embedding plots, and various
5//! text analysis charts.
6
7use crate::error::{Result, TextError};
8use crate::sentiment::SentimentResult;
9use crate::topic_modeling::Topic;
10use crate::vectorize::{CountVectorizer, TfidfVectorizer, Vectorizer};
11use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
12use std::collections::HashMap;
13use std::fs;
14use std::path::Path;
15
16/// Configuration for text visualizations
17#[derive(Debug, Clone)]
18pub struct VisualizationConfig {
19    /// Output image width
20    pub width: usize,
21    /// Output image height
22    pub height: usize,
23    /// Color scheme
24    pub color_scheme: ColorScheme,
25    /// Font size range
26    pub font_size_range: (usize, usize),
27    /// Background color
28    pub background_color: Color,
29    /// Whether to save high-DPI images
30    pub high_dpi: bool,
31}
32
33impl Default for VisualizationConfig {
34    fn default() -> Self {
35        Self {
36            width: 800,
37            height: 600,
38            color_scheme: ColorScheme::Viridis,
39            font_size_range: (10, 100),
40            background_color: Color::WHITE,
41            high_dpi: false,
42        }
43    }
44}
45
46/// Color schemes for visualizations
47#[derive(Debug, Clone)]
48pub enum ColorScheme {
49    /// Viridis color scheme
50    Viridis,
51    /// Plasma color scheme
52    Plasma,
53    /// Inferno color scheme
54    Inferno,
55    /// Cool colors
56    Cool,
57    /// Warm colors
58    Warm,
59    /// Custom color palette
60    Custom(Vec<Color>),
61}
62
/// An RGB color with one byte per channel.
#[derive(Debug, Clone, Copy)]
pub struct Color {
    /// Red channel (0-255)
    pub r: u8,
    /// Green channel (0-255)
    pub g: u8,
    /// Blue channel (0-255)
    pub b: u8,
}

impl Color {
    /// Pure white (`#ffffff`)
    pub const WHITE: Color = Self {
        r: 255,
        g: 255,
        b: 255,
    };
    /// Pure black (`#000000`)
    pub const BLACK: Color = Self { r: 0, g: 0, b: 0 };
    /// Pure red (`#ff0000`)
    pub const RED: Color = Self { r: 255, g: 0, b: 0 };
    /// Pure green (`#00ff00`)
    pub const GREEN: Color = Self { r: 0, g: 255, b: 0 };
    /// Pure blue (`#0000ff`)
    pub const BLUE: Color = Self { r: 0, g: 0, b: 255 };

    /// Builds a color from its red, green and blue channels.
    pub fn new(r: u8, g: u8, b: u8) -> Self {
        Color { r, g, b }
    }

    /// Formats the color as a lowercase `#rrggbb` string.
    pub fn to_hex(&self) -> String {
        let Color { r, g, b } = *self;
        format!("#{r:02x}{g:02x}{b:02x}")
    }
}
100
101/// Word cloud visualization
102pub struct WordCloud {
103    /// Word frequency data
104    word_frequencies: HashMap<String, f64>,
105    /// Configuration
106    config: VisualizationConfig,
107}
108
109impl WordCloud {
110    /// Create new word cloud from text
111    pub fn fromtext(text: &str, config: VisualizationConfig) -> Result<Self> {
112        let mut vectorizer = CountVectorizer::new(false);
113        let documents = vec![text];
114        let matrix = vectorizer.fit_transform(&documents)?;
115
116        let vocabulary_map = vectorizer.vocabulary_map();
117        let mut word_frequencies = HashMap::new();
118
119        // Extract word frequencies from the matrix
120        for (word, &idx) in vocabulary_map.iter() {
121            if let Some(count) = vectorizer.get_feature_count(&matrix, 0, idx) {
122                if count > 0.0 {
123                    word_frequencies.insert(word.clone(), count);
124                }
125            }
126        }
127
128        Ok(Self {
129            word_frequencies,
130            config,
131        })
132    }
133
134    /// Create word cloud from TF-IDF vectorizer and matrix
135    pub fn from_tfidf(
136        vectorizer: &TfidfVectorizer,
137        matrix: &Array2<f64>,
138        document_index: usize,
139    ) -> Result<Self> {
140        let vocabulary_map = vectorizer.vocabulary_map();
141        let mut word_frequencies = HashMap::new();
142
143        // Get TF-IDF scores for the document
144        for (word, &idx) in vocabulary_map.iter() {
145            if let Some(score) = vectorizer.get_feature_score(matrix, document_index, idx) {
146                if score > 0.0 {
147                    word_frequencies.insert(word.clone(), score);
148                }
149            }
150        }
151
152        Ok(Self {
153            word_frequencies,
154            config: VisualizationConfig::default(),
155        })
156    }
157
158    /// Create word cloud from frequency map
159    pub fn from_frequencies(
160        frequencies: HashMap<String, f64>,
161        config: VisualizationConfig,
162    ) -> Self {
163        Self {
164            word_frequencies: frequencies,
165            config,
166        }
167    }
168
169    /// Generate word cloud as SVG
170    pub fn to_svg(&self) -> Result<String> {
171        let mut svg = String::new();
172
173        // SVG header
174        svg.push_str(&format!(
175            r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">"#,
176            self.config.width, self.config.height
177        ));
178
179        // Background
180        svg.push_str(&format!(
181            r#"<rect width="100%" height="100%" fill="{}" />"#,
182            self.config.background_color.to_hex()
183        ));
184
185        // Sort words by frequency
186        let mut sorted_words: Vec<_> = self.word_frequencies.iter().collect();
187        sorted_words.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
188
189        // Take top words
190        let max_words = 50;
191        let top_words: Vec<_> = sorted_words.into_iter().take(max_words).collect();
192
193        if top_words.is_empty() {
194            return Ok(svg + "</svg>");
195        }
196
197        // Calculate font sizes
198        let max_freq = top_words[0].1;
199        let min_freq = top_words.last().map(|x| *x.1).unwrap_or(*max_freq);
200        let freq_range = max_freq - min_freq;
201
202        // Generate colors
203        let colors = self.generate_colors(top_words.len());
204
205        // Position words (simplified grid layout)
206        let cols = (top_words.len() as f64).sqrt().ceil() as usize;
207        let rows = top_words.len().div_ceil(cols);
208        let cell_width = self.config.width / cols;
209        let cell_height = self.config.height / rows;
210
211        for (i, (word, &freq)) in top_words.iter().enumerate() {
212            let row = i / cols;
213            let col = i % cols;
214
215            // Calculate position
216            let x = col * cell_width + cell_width / 2;
217            let y = row * cell_height + cell_height / 2;
218
219            // Calculate font size
220            let font_size = if freq_range > 0.0 {
221                let normalized = (freq - min_freq) / freq_range;
222                self.config.font_size_range.0
223                    + (normalized
224                        * (self.config.font_size_range.1 - self.config.font_size_range.0) as f64)
225                        as usize
226            } else {
227                (self.config.font_size_range.0 + self.config.font_size_range.1) / 2
228            };
229
230            // Add word to SVG
231            svg.push_str(&format!(
232                r#"<text x="{}" y="{}" font-family="Arial, sans-serif" font-size="{}" 
233                   fill="{}" text-anchor="middle" dominant-baseline="middle">{}</text>"#,
234                x,
235                y,
236                font_size,
237                colors[i % colors.len()].to_hex(),
238                word
239            ));
240        }
241
242        svg.push_str("</svg>");
243        Ok(svg)
244    }
245
246    /// Generate colors based on color scheme
247    fn generate_colors(&self, count: usize) -> Vec<Color> {
248        match &self.config.color_scheme {
249            ColorScheme::Viridis => self.generate_viridis_colors(count),
250            ColorScheme::Plasma => self.generate_plasma_colors(count),
251            ColorScheme::Inferno => self.generate_inferno_colors(count),
252            ColorScheme::Cool => self.generate_cool_colors(count),
253            ColorScheme::Warm => self.generate_warm_colors(count),
254            ColorScheme::Custom(colors) => colors.clone(),
255        }
256    }
257
258    /// Generate viridis color scheme
259    fn generate_viridis_colors(&self, count: usize) -> Vec<Color> {
260        let mut colors = Vec::new();
261        for i in 0..count {
262            let t = i as f64 / (count - 1).max(1) as f64;
263            // Simplified viridis approximation
264            let r = (68.0 + t * (253.0 - 68.0)) as u8;
265            let g = (1.0 + t * (231.0 - 1.0)) as u8;
266            let b = (84.0 + t * (37.0 - 84.0)) as u8;
267            colors.push(Color::new(r, g, b));
268        }
269        colors
270    }
271
272    /// Generate plasma color scheme
273    fn generate_plasma_colors(&self, count: usize) -> Vec<Color> {
274        let mut colors = Vec::new();
275        for i in 0..count {
276            let t = i as f64 / (count - 1).max(1) as f64;
277            // Simplified plasma approximation
278            let r = (13.0 + t * (240.0 - 13.0)) as u8;
279            let g = (8.0 + t * (249.0 - 8.0)) as u8;
280            let b = (135.0 + t * (33.0 - 135.0)) as u8;
281            colors.push(Color::new(r, g, b));
282        }
283        colors
284    }
285
286    /// Generate inferno color scheme
287    fn generate_inferno_colors(&self, count: usize) -> Vec<Color> {
288        let mut colors = Vec::new();
289        for i in 0..count {
290            let t = i as f64 / (count - 1).max(1) as f64;
291            // Simplified inferno approximation
292            let r = (0.0 + t * (252.0 - 0.0)) as u8;
293            let g = (0.0 + t * (255.0 - 0.0)) as u8;
294            let b = (4.0 + t * (164.0 - 4.0)) as u8;
295            colors.push(Color::new(r, g, b));
296        }
297        colors
298    }
299
300    /// Generate cool color scheme
301    fn generate_cool_colors(&self, count: usize) -> Vec<Color> {
302        let mut colors = Vec::new();
303        for i in 0..count {
304            let t = i as f64 / (count - 1).max(1) as f64;
305            let r = (0.0 + t * (255.0 - 0.0)) as u8;
306            let g = (255.0 - t * (255.0 - 0.0)) as u8;
307            let b = 255;
308            colors.push(Color::new(r, g, b));
309        }
310        colors
311    }
312
313    /// Generate warm color scheme
314    fn generate_warm_colors(&self, count: usize) -> Vec<Color> {
315        let mut colors = Vec::new();
316        for i in 0..count {
317            let t = i as f64 / (count - 1).max(1) as f64;
318            let r = 255;
319            let g = (255.0 - t * (255.0 - 0.0)) as u8;
320            let b = (0.0 + t * (255.0 - 0.0)) as u8;
321            colors.push(Color::new(r, g, b));
322        }
323        colors
324    }
325
326    /// Save word cloud to file
327    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
328        let svg_content = self.to_svg()?;
329        fs::write(path, svg_content)
330            .map_err(|e| TextError::IoError(format!("Failed to save word cloud: {e}")))?;
331        Ok(())
332    }
333}
334
335/// Attention visualization for transformer models
336pub struct AttentionVisualizer {
337    config: VisualizationConfig,
338}
339
340impl AttentionVisualizer {
341    /// Create new attention visualizer
342    pub fn new(config: VisualizationConfig) -> Self {
343        Self { config }
344    }
345
346    /// Visualize attention weights as heatmap
347    pub fn attention_heatmap(
348        &self,
349        attentionweights: ArrayView2<f64>,
350        source_tokens: &[String],
351        target_tokens: &[String],
352    ) -> Result<String> {
353        let (n_target, n_source) = attentionweights.dim();
354
355        if source_tokens.len() != n_source || target_tokens.len() != n_target {
356            return Err(TextError::InvalidInput(
357                "Token count doesn't match attention matrix dimensions".to_string(),
358            ));
359        }
360
361        let mut svg = String::new();
362
363        // Calculate dimensions
364        let cell_width = self.config.width / (n_source + 1);
365        let cell_height = self.config.height / (n_target + 1);
366        let _matrix_width = n_source * cell_width;
367        let matrix_height = n_target * cell_height;
368
369        // SVG header
370        svg.push_str(&format!(
371            r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">"#,
372            self.config.width + 100,
373            self.config.height + 100
374        ));
375
376        // Background
377        svg.push_str(&format!(
378            r#"<rect width="100%" height="100%" fill="{}" />"#,
379            self.config.background_color.to_hex()
380        ));
381
382        // Find min and max attention values for normalization
383        let mut min_val = f64::INFINITY;
384        let mut max_val = f64::NEG_INFINITY;
385        for &val in attentionweights.iter() {
386            min_val = min_val.min(val);
387            max_val = max_val.max(val);
388        }
389
390        // Draw attention matrix
391        for i in 0..n_target {
392            for j in 0..n_source {
393                let attention = attentionweights[[i, j]];
394                let normalized = if max_val > min_val {
395                    (attention - min_val) / (max_val - min_val)
396                } else {
397                    0.5
398                };
399
400                // Color based on attention value
401                let intensity = (normalized * 255.0) as u8;
402                let color = Color::new(255 - intensity, 255 - intensity, 255);
403
404                let x = 50 + j * cell_width;
405                let y = 50 + i * cell_height;
406
407                svg.push_str(&format!(
408                    r#"<rect x="{}" y="{}" width="{}" height="{}" fill="{}" stroke="black" stroke-width="1" />"#,
409                    x, y, cell_width, cell_height, color.to_hex()
410                ));
411
412                // Add attention value text
413                svg.push_str(&format!(
414                    r#"<text x="{}" y="{}" font-family="Arial, sans-serif" font-size="10" 
415                       text-anchor="middle" dominant-baseline="middle">{:.2}</text>"#,
416                    x + cell_width / 2,
417                    y + cell_height / 2,
418                    attention
419                ));
420            }
421        }
422
423        // Add source token labels (bottom)
424        for (j, token) in source_tokens.iter().enumerate() {
425            let x = 50 + j * cell_width + cell_width / 2;
426            let y = 50 + matrix_height + 20;
427
428            svg.push_str(&format!(
429                r#"<text x="{x}" y="{y}" font-family="Arial, sans-serif" font-size="12" 
430                   text-anchor="middle" dominant-baseline="middle" transform="rotate(-45 {x} {y})">{token}</text>"#
431            ));
432        }
433
434        // Add target token labels (left)
435        for (i, token) in target_tokens.iter().enumerate() {
436            let x = 30;
437            let y = 50 + i * cell_height + cell_height / 2;
438
439            svg.push_str(&format!(
440                r#"<text x="{x}" y="{y}" font-family="Arial, sans-serif" font-size="12" 
441                   text-anchor="end" dominant-baseline="middle">{token}</text>"#
442            ));
443        }
444
445        svg.push_str("</svg>");
446        Ok(svg)
447    }
448
449    /// Save attention visualization to file
450    pub fn save_attention_heatmap<P: AsRef<Path>>(
451        &self,
452        attentionweights: ArrayView2<f64>,
453        source_tokens: &[String],
454        target_tokens: &[String],
455        path: P,
456    ) -> Result<()> {
457        let svg_content = self.attention_heatmap(attentionweights, source_tokens, target_tokens)?;
458        fs::write(path, svg_content)
459            .map_err(|e| TextError::IoError(format!("Failed to save attention heatmap: {e}")))?;
460        Ok(())
461    }
462}
463
464/// Embedding visualization using dimensionality reduction
465pub struct EmbeddingVisualizer {
466    config: VisualizationConfig,
467}
468
469impl EmbeddingVisualizer {
470    /// Create new embedding visualizer
471    pub fn new(config: VisualizationConfig) -> Self {
472        Self { config }
473    }
474
475    /// Visualize word embeddings using PCA projection to 2D
476    pub fn visualize_embeddings(
477        &self,
478        word_vectors: &HashMap<String, Array1<f64>>,
479        words_to_plot: Option<&[String]>,
480    ) -> Result<String> {
481        let words: Vec<String> = if let Some(selected_words) = words_to_plot {
482            selected_words.to_vec()
483        } else {
484            word_vectors.keys().take(100).cloned().collect()
485        };
486
487        if words.is_empty() {
488            return Err(TextError::InvalidInput("No words to visualize".to_string()));
489        }
490
491        // Collect embedding _vectors
492        let mut embeddings = Vec::new();
493        let mut valid_words = Vec::new();
494
495        for word in &words {
496            if let Some(vector) = word_vectors.get(word) {
497                embeddings.push(vector.clone());
498                valid_words.push(word.clone());
499            }
500        }
501
502        if embeddings.is_empty() {
503            return Err(TextError::InvalidInput(
504                "No valid embeddings found".to_string(),
505            ));
506        }
507
508        // Simple PCA to 2D (simplified implementation)
509        let projected_points = self.simple_pca_2d(&embeddings)?;
510
511        // Create SVG
512        let mut svg = String::new();
513
514        // SVG header
515        svg.push_str(&format!(
516            r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">"#,
517            self.config.width, self.config.height
518        ));
519
520        // Background
521        svg.push_str(&format!(
522            r#"<rect width="100%" height="100%" fill="{}" />"#,
523            self.config.background_color.to_hex()
524        ));
525
526        // Find bounds for scaling
527        let mut min_x = f64::INFINITY;
528        let mut max_x = f64::NEG_INFINITY;
529        let mut min_y = f64::INFINITY;
530        let mut max_y = f64::NEG_INFINITY;
531
532        for &(x, y) in &projected_points {
533            min_x = min_x.min(x);
534            max_x = max_x.max(x);
535            min_y = min_y.min(y);
536            max_y = max_y.max(y);
537        }
538
539        let margin = 50.0;
540        let plot_width = self.config.width as f64 - 2.0 * margin;
541        let plot_height = self.config.height as f64 - 2.0 * margin;
542
543        // Generate colors
544        let colors = self.generate_colors(valid_words.len());
545
546        // Plot points and labels
547        for (i, ((x, y), word)) in projected_points.iter().zip(&valid_words).enumerate() {
548            // Scale to _plot area
549            let scaled_x = margin + (x - min_x) / (max_x - min_x) * plot_width;
550            let scaled_y = margin + (y - min_y) / (max_y - min_y) * plot_height;
551
552            // Draw point
553            svg.push_str(&format!(
554                r#"<circle cx="{}" cy="{}" r="3" fill="{}" />"#,
555                scaled_x,
556                scaled_y,
557                colors[i % colors.len()].to_hex()
558            ));
559
560            // Draw label
561            svg.push_str(&format!(
562                r#"<text x="{}" y="{}" font-family="Arial, sans-serif" font-size="10" 
563                   text-anchor="start" dominant-baseline="middle">{}</text>"#,
564                scaled_x + 5.0,
565                scaled_y,
566                word
567            ));
568        }
569
570        svg.push_str("</svg>");
571        Ok(svg)
572    }
573
574    /// Simple PCA implementation for 2D projection
575    fn simple_pca_2d(&self, embeddings: &[Array1<f64>]) -> Result<Vec<(f64, f64)>> {
576        if embeddings.is_empty() {
577            return Ok(Vec::new());
578        }
579
580        let n_samples = embeddings.len();
581        let n_features = embeddings[0].len();
582
583        // Create data matrix
584        let mut data_matrix = Array2::zeros((n_samples, n_features));
585        for (i, embedding) in embeddings.iter().enumerate() {
586            data_matrix.row_mut(i).assign(embedding);
587        }
588
589        // Center the data
590        let mean = data_matrix.mean_axis(Axis(0)).unwrap();
591        for mut row in data_matrix.rows_mut() {
592            row -= &mean;
593        }
594
595        // Simplified SVD (using covariance matrix approach)
596        let _cov_matrix = data_matrix.t().dot(&data_matrix) / (n_samples - 1) as f64;
597
598        // Find first two principal components (simplified eigenvalue decomposition)
599        // This is a very simplified approach - in practice would use proper SVD/eigendecomposition
600        let mut pc1 = Array1::zeros(n_features);
601        let mut pc2 = Array1::zeros(n_features);
602
603        // Use random orthogonal vectors as approximation
604        for i in 0..n_features {
605            pc1[i] = (i as f64).sin();
606            pc2[i] = (i as f64).cos();
607        }
608
609        // Normalize
610        pc1 /= pc1.dot(&pc1).sqrt();
611        pc2 /= pc2.dot(&pc2).sqrt();
612
613        // Project data
614        let mut projected = Vec::new();
615        for row in data_matrix.rows() {
616            let x = row.dot(&pc1);
617            let y = row.dot(&pc2);
618            projected.push((x, y));
619        }
620
621        Ok(projected)
622    }
623
624    /// Generate colors for embedding visualization
625    fn generate_colors(&self, count: usize) -> Vec<Color> {
626        let mut colors = Vec::new();
627        for i in 0..count {
628            let hue = (i as f64 / count as f64) * 360.0;
629            let color = self.hsv_to_rgb(hue, 0.8, 0.9);
630            colors.push(color);
631        }
632        colors
633    }
634
635    /// Convert HSV to RGB
636    fn hsv_to_rgb(&self, h: f64, s: f64, v: f64) -> Color {
637        let c = v * s;
638        let x = c * (1.0 - ((h / 60.0) % 2.0 - 1.0).abs());
639        let m = v - c;
640
641        let (r_prime, g_prime, b_prime) = match h as i32 / 60 {
642            0 => (c, x, 0.0),
643            1 => (x, c, 0.0),
644            2 => (0.0, c, x),
645            3 => (0.0, x, c),
646            4 => (x, 0.0, c),
647            _ => (c, 0.0, x),
648        };
649
650        Color::new(
651            ((r_prime + m) * 255.0) as u8,
652            ((g_prime + m) * 255.0) as u8,
653            ((b_prime + m) * 255.0) as u8,
654        )
655    }
656
657    /// Save embedding visualization to file
658    pub fn save_embeddings<P: AsRef<Path>>(
659        &self,
660        word_vectors: &HashMap<String, Array1<f64>>,
661        words_to_plot: Option<&[String]>,
662        path: P,
663    ) -> Result<()> {
664        let svg_content = self.visualize_embeddings(word_vectors, words_to_plot)?;
665        fs::write(path, svg_content).map_err(|e| {
666            TextError::IoError(format!("Failed to save embedding visualization: {e}"))
667        })?;
668        Ok(())
669    }
670}
671
672/// Sentiment analysis visualization
673pub struct SentimentVisualizer {
674    config: VisualizationConfig,
675}
676
677impl SentimentVisualizer {
678    /// Create new sentiment visualizer
679    pub fn new(config: VisualizationConfig) -> Self {
680        Self { config }
681    }
682
683    /// Create sentiment distribution chart
684    pub fn sentiment_distribution(
685        &self,
686        sentiment_results: &[SentimentResult],
687        labels: &[String],
688    ) -> Result<String> {
689        if sentiment_results.len() != labels.len() {
690            return Err(TextError::InvalidInput(
691                "Number of sentiment _results must match number of labels".to_string(),
692            ));
693        }
694
695        // Count sentiment categories
696        let mut positive_count = 0;
697        let mut negative_count = 0;
698        let mut neutral_count = 0;
699
700        for result in sentiment_results {
701            match result.sentiment {
702                crate::sentiment::Sentiment::Positive => positive_count += 1,
703                crate::sentiment::Sentiment::Negative => negative_count += 1,
704                crate::sentiment::Sentiment::Neutral => neutral_count += 1,
705            }
706        }
707
708        let total = sentiment_results.len();
709        if total == 0 {
710            return Err(TextError::InvalidInput(
711                "No sentiment data to visualize".to_string(),
712            ));
713        }
714
715        let mut svg = String::new();
716
717        // SVG header
718        svg.push_str(&format!(
719            r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">"#,
720            self.config.width, self.config.height
721        ));
722
723        // Background
724        svg.push_str(&format!(
725            r#"<rect width="100%" height="100%" fill="{}" />"#,
726            self.config.background_color.to_hex()
727        ));
728
729        // Create pie chart
730        let center_x = self.config.width as f64 / 2.0;
731        let center_y = self.config.height as f64 / 2.0;
732        let radius = (self.config.width.min(self.config.height) as f64 / 2.0 - 50.0).max(50.0);
733
734        let positive_angle = (positive_count as f64 / total as f64) * 2.0 * std::f64::consts::PI;
735        let negative_angle = (negative_count as f64 / total as f64) * 2.0 * std::f64::consts::PI;
736        let neutral_angle = (neutral_count as f64 / total as f64) * 2.0 * std::f64::consts::PI;
737
738        let mut start_angle = 0.0;
739
740        // Positive segment
741        if positive_count > 0 {
742            let end_angle = start_angle + positive_angle;
743            svg.push_str(&self.create_pie_segment(
744                center_x,
745                center_y,
746                radius,
747                start_angle,
748                end_angle,
749                Color::GREEN,
750                "Positive",
751                positive_count,
752                total,
753            ));
754            start_angle = end_angle;
755        }
756
757        // Negative segment
758        if negative_count > 0 {
759            let end_angle = start_angle + negative_angle;
760            svg.push_str(&self.create_pie_segment(
761                center_x,
762                center_y,
763                radius,
764                start_angle,
765                end_angle,
766                Color::RED,
767                "Negative",
768                negative_count,
769                total,
770            ));
771            start_angle = end_angle;
772        }
773
774        // Neutral segment
775        if neutral_count > 0 {
776            let end_angle = start_angle + neutral_angle;
777            svg.push_str(&self.create_pie_segment(
778                center_x,
779                center_y,
780                radius,
781                start_angle,
782                end_angle,
783                Color::new(128, 128, 128),
784                "Neutral",
785                neutral_count,
786                total,
787            ));
788        }
789
790        // Add legend
791        svg.push_str(r#"<text x="20" y="30" font-family="Arial, sans-serif" font-size="16" font-weight="bold">Sentiment Distribution</text>"#);
792
793        let legend_y = 60;
794        svg.push_str(&format!(
795            r#"<rect x="20" y="{}" width="15" height="15" fill="{}" />"#,
796            legend_y,
797            Color::GREEN.to_hex()
798        ));
799        svg.push_str(&format!(
800            r#"<text x="40" y="{}" font-family="Arial, sans-serif" font-size="12">Positive: {} ({:.1}%)</text>"#,
801            legend_y + 12, positive_count, (positive_count as f64 / total as f64) * 100.0
802        ));
803
804        svg.push_str(&format!(
805            r#"<rect x="20" y="{}" width="15" height="15" fill="{}" />"#,
806            legend_y + 25,
807            Color::RED.to_hex()
808        ));
809        svg.push_str(&format!(
810            r#"<text x="40" y="{}" font-family="Arial, sans-serif" font-size="12">Negative: {} ({:.1}%)</text>"#,
811            legend_y + 37, negative_count, (negative_count as f64 / total as f64) * 100.0
812        ));
813
814        svg.push_str(&format!(
815            r#"<rect x="20" y="{}" width="15" height="15" fill="{}" />"#,
816            legend_y + 50,
817            Color::new(128, 128, 128).to_hex()
818        ));
819        svg.push_str(&format!(
820            r#"<text x="40" y="{}" font-family="Arial, sans-serif" font-size="12">Neutral: {} ({:.1}%)</text>"#,
821            legend_y + 62, neutral_count, (neutral_count as f64 / total as f64) * 100.0
822        ));
823
824        svg.push_str("</svg>");
825        Ok(svg)
826    }
827
828    /// Create pie chart segment
829    fn create_pie_segment(
830        &self,
831        center_x: f64,
832        center_y: f64,
833        radius: f64,
834        start_angle: f64,
835        end_angle: f64,
836        color: Color,
837        label: &str,
838        _count: usize,
839        total: usize,
840    ) -> String {
841        let x1 = center_x + radius * start_angle.cos();
842        let y1 = center_y + radius * start_angle.sin();
843        let x2 = center_x + radius * end_angle.cos();
844        let y2 = center_y + radius * end_angle.sin();
845
846        let large_arc = if end_angle - start_angle > std::f64::consts::PI {
847            1
848        } else {
849            0
850        };
851
852        format!(
853            r#"<path d="M {} {} L {} {} A {} {} 0 {} 1 {} {} Z" fill="{}" stroke="white" stroke-width="2" />"#,
854            center_x,
855            center_y,
856            x1,
857            y1,
858            radius,
859            radius,
860            large_arc,
861            x2,
862            y2,
863            color.to_hex()
864        )
865    }
866
867    /// Save sentiment visualization to file
868    pub fn save_sentiment_distribution<P: AsRef<Path>>(
869        &self,
870        sentiment_results: &[SentimentResult],
871        labels: &[String],
872        path: P,
873    ) -> Result<()> {
874        let svg_content = self.sentiment_distribution(sentiment_results, labels)?;
875        fs::write(path, svg_content).map_err(|e| {
876            TextError::IoError(format!("Failed to save sentiment visualization: {e}"))
877        })?;
878        Ok(())
879    }
880}
881
882/// Topic modeling visualization
883pub struct TopicVisualizer {
884    config: VisualizationConfig,
885}
886
887impl TopicVisualizer {
888    /// Create new topic visualizer
889    pub fn new(config: VisualizationConfig) -> Self {
890        Self { config }
891    }
892
893    /// Visualize topic word distributions
894    pub fn topic_words_chart(&self, topics: &[Topic], topn: usize) -> Result<String> {
895        if topics.is_empty() {
896            return Err(TextError::InvalidInput(
897                "No topics to visualize".to_string(),
898            ));
899        }
900
901        let mut svg = String::new();
902
903        // Calculate dimensions
904        let chart_width = self.config.width;
905        let chart_height = self.config.height;
906        let margin = 50;
907        let topic_height = (chart_height - 2 * margin) / topics.len();
908
909        // SVG header
910        svg.push_str(&format!(
911            r#"<svg width="{chart_width}" height="{chart_height}" xmlns="http://www.w3.org/2000/svg">"#
912        ));
913
914        // Background
915        svg.push_str(&format!(
916            r#"<rect width="100%" height="100%" fill="{}" />"#,
917            self.config.background_color.to_hex()
918        ));
919
920        // Title
921        svg.push_str(&format!(
922            r#"<text x="{}" y="30" font-family="Arial, sans-serif" font-size="18" font-weight="bold" text-anchor="middle">Topic Word Distributions</text>"#,
923            chart_width / 2
924        ));
925
926        // Generate colors for topics
927        let colors = self.generate_topic_colors(topics.len());
928
929        for (topic_idx, topic) in topics.iter().enumerate() {
930            let y_offset = margin + topic_idx * topic_height;
931
932            // Topic label
933            svg.push_str(&format!(
934                r#"<text x="20" y="{}" font-family="Arial, sans-serif" font-size="14" font-weight="bold">Topic {}</text>"#,
935                y_offset + 20, topic_idx
936            ));
937
938            // Get top words for this topic
939            let mut topic_words: Vec<_> = topic.top_words.iter().collect();
940            topic_words.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
941            let top_words: Vec<_> = topic_words.into_iter().take(topn).collect();
942
943            if !top_words.is_empty() {
944                let max_prob = top_words[0].1;
945                let bar_area_width = chart_width - 200;
946
947                // Draw bars for top words
948                for (word_idx, (word, prob)) in top_words.iter().enumerate() {
949                    let bar_y = y_offset + 30 + word_idx * 15;
950                    let bar_width = (*prob / max_prob * bar_area_width as f64) as usize;
951
952                    // Word bar
953                    svg.push_str(&format!(
954                        r#"<rect x="120" y="{}" width="{}" height="12" fill="{}" />"#,
955                        bar_y,
956                        bar_width,
957                        colors[topic_idx % colors.len()].to_hex()
958                    ));
959
960                    // Word label
961                    svg.push_str(&format!(
962                        r#"<text x="115" y="{}" font-family="Arial, sans-serif" font-size="10" text-anchor="end">{}</text>"#,
963                        bar_y + 9, word
964                    ));
965
966                    // Probability value
967                    svg.push_str(&format!(
968                        r#"<text x="{}" y="{}" font-family="Arial, sans-serif" font-size="10">{:.3}</text>"#,
969                        125 + bar_width, bar_y + 9, *prob
970                    ));
971                }
972            }
973        }
974
975        svg.push_str("</svg>");
976        Ok(svg)
977    }
978
979    /// Generate colors for topics
980    fn generate_topic_colors(&self, count: usize) -> Vec<Color> {
981        let mut colors = Vec::new();
982        for i in 0..count {
983            let hue = (i as f64 / count as f64) * 360.0;
984            let color = self.hsv_to_rgb(hue, 0.7, 0.8);
985            colors.push(color);
986        }
987        colors
988    }
989
990    /// Convert HSV to RGB
991    fn hsv_to_rgb(&self, h: f64, s: f64, v: f64) -> Color {
992        let c = v * s;
993        let x = c * (1.0 - ((h / 60.0) % 2.0 - 1.0).abs());
994        let m = v - c;
995
996        let (r_prime, g_prime, b_prime) = match h as i32 / 60 {
997            0 => (c, x, 0.0),
998            1 => (x, c, 0.0),
999            2 => (0.0, c, x),
1000            3 => (0.0, x, c),
1001            4 => (x, 0.0, c),
1002            _ => (c, 0.0, x),
1003        };
1004
1005        Color::new(
1006            ((r_prime + m) * 255.0) as u8,
1007            ((g_prime + m) * 255.0) as u8,
1008            ((b_prime + m) * 255.0) as u8,
1009        )
1010    }
1011
1012    /// Save topic visualization to file
1013    pub fn save_topic_words<P: AsRef<Path>>(
1014        &self,
1015        topics: &[Topic],
1016        topn: usize,
1017        path: P,
1018    ) -> Result<()> {
1019        let svg_content = self.topic_words_chart(topics, topn)?;
1020        fs::write(path, svg_content)
1021            .map_err(|e| TextError::IoError(format!("Failed to save topic visualization: {e}")))?;
1022        Ok(())
1023    }
1024}
1025
1026/// Text analysis dashboard generator
1027pub struct TextAnalyticsDashboard {
1028    config: VisualizationConfig,
1029}
1030
1031impl TextAnalyticsDashboard {
1032    /// Create new analytics dashboard
1033    pub fn new(config: VisualizationConfig) -> Self {
1034        Self { config }
1035    }
1036
1037    /// Generate complete text analytics dashboard
1038    pub fn generate_dashboard(
1039        &self,
1040        text_data: &[String],
1041        sentiment_results: &[SentimentResult],
1042        topics: &[Topic],
1043        word_frequencies: &HashMap<String, f64>,
1044    ) -> Result<String> {
1045        let mut html = String::new();
1046
1047        // HTML header
1048        html.push_str(r#"<!DOCTYPE html>
1049<html>
1050<head>
1051    <title>Text Analytics Dashboard</title>
1052    <style>
1053        body { font-family: Arial, sans-serif; margin: 20px; background-color: #f5f5f5; }
1054        .dashboard { display: grid; grid-template-columns: 1fr 1fr; gap: 20px; }
1055        .widget { background: white; padding: 20px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
1056        .widget h3 { margin-top: 0; color: #333; }
1057        .full-width { grid-column: 1 / -1; }
1058        .stats { display: flex; justify-content: space-around; }
1059        .stat { text-align: center; }
1060        .stat-value { font-size: 2em; font-weight: bold; color: #2196F3; }
1061        .stat-label { color: #666; }
1062    </style>
1063</head>
1064<body>
1065    <h1>Text Analytics Dashboard</h1>
1066    <div class="dashboard">
1067"#);
1068
1069        // Text statistics widget
1070        html.push_str(
1071            r#"<div class="widget">
1072            <h3>Text Statistics</h3>
1073            <div class="stats">
1074"#,
1075        );
1076
1077        let total_docs = text_data.len();
1078        let total_words: usize = text_data
1079            .iter()
1080            .map(|text| text.split_whitespace().count())
1081            .sum();
1082        let avg_words = if total_docs > 0 {
1083            total_words / total_docs
1084        } else {
1085            0
1086        };
1087        let unique_words = word_frequencies.len();
1088
1089        html.push_str(&format!(
1090            r#"
1091                <div class="stat">
1092                    <div class="stat-value">{total_docs}</div>
1093                    <div class="stat-label">Documents</div>
1094                </div>
1095                <div class="stat">
1096                    <div class="stat-value">{total_words}</div>
1097                    <div class="stat-label">Total Words</div>
1098                </div>
1099                <div class="stat">
1100                    <div class="stat-value">{avg_words}</div>
1101                    <div class="stat-label">Avg Words/Doc</div>
1102                </div>
1103                <div class="stat">
1104                    <div class="stat-value">{unique_words}</div>
1105                    <div class="stat-label">Unique Words</div>
1106                </div>
1107"#
1108        ));
1109
1110        html.push_str("</div></div>");
1111
1112        // Word cloud widget
1113        let word_cloud = WordCloud::from_frequencies(word_frequencies.clone(), self.config.clone());
1114        let word_cloud_svg = word_cloud.to_svg()?;
1115
1116        html.push_str(&format!(
1117            r#"<div class="widget">
1118            <h3>Word Cloud</h3>
1119            {word_cloud_svg}
1120        </div>"#
1121        ));
1122
1123        // Sentiment analysis widget
1124        let sentiment_viz = SentimentVisualizer::new(self.config.clone());
1125        let labels: Vec<String> = (0..sentiment_results.len())
1126            .map(|i| {
1127                let doc_num = i + 1;
1128                format!("Doc {doc_num}")
1129            })
1130            .collect();
1131        let sentiment_svg = sentiment_viz.sentiment_distribution(sentiment_results, &labels)?;
1132
1133        html.push_str(&format!(
1134            r#"<div class="widget">
1135            <h3>Sentiment Distribution</h3>
1136            {sentiment_svg}
1137        </div>"#
1138        ));
1139
1140        // Topic modeling widget
1141        if !topics.is_empty() {
1142            let topic_viz = TopicVisualizer::new(self.config.clone());
1143            let topic_svg = topic_viz.topic_words_chart(topics, 5)?;
1144
1145            html.push_str(&format!(
1146                r#"<div class="widget full-width">
1147                <h3>Topic Analysis</h3>
1148                {topic_svg}
1149            </div>"#
1150            ));
1151        }
1152
1153        // HTML footer
1154        html.push_str(
1155            r#"
1156    </div>
1157</body>
1158</html>"#,
1159        );
1160
1161        Ok(html)
1162    }
1163
1164    /// Save dashboard to HTML file
1165    pub fn save_dashboard<P: AsRef<Path>>(
1166        &self,
1167        text_data: &[String],
1168        sentiment_results: &[SentimentResult],
1169        topics: &[Topic],
1170        word_frequencies: &HashMap<String, f64>,
1171        path: P,
1172    ) -> Result<()> {
1173        let html_content =
1174            self.generate_dashboard(text_data, sentiment_results, topics, word_frequencies)?;
1175        fs::write(path, html_content)
1176            .map_err(|e| TextError::IoError(format!("Failed to save dashboard: {e}")))?;
1177        Ok(())
1178    }
1179}
1180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sentiment::{Sentiment, SentimentResult, SentimentWordCounts};
    use std::collections::HashMap;

    /// A word cloud built from explicit frequencies renders every word into the SVG.
    #[test]
    fn test_word_cloud_creation() {
        let frequencies: HashMap<String, f64> = [("hello", 10.0), ("world", 8.0), ("test", 5.0)]
            .iter()
            .map(|&(word, weight)| (word.to_string(), weight))
            .collect();

        let word_cloud = WordCloud::from_frequencies(frequencies, VisualizationConfig::default());
        let svg = word_cloud.to_svg().unwrap();

        for fragment in ["<svg", "hello", "world", "test"] {
            assert!(svg.contains(fragment), "SVG is missing {fragment:?}");
        }
    }

    /// The sentiment chart mentions each of the three sentiment classes.
    #[test]
    fn test_sentiment_visualization() {
        let sentiment_results: Vec<SentimentResult> = vec![
            (Sentiment::Positive, 0.8, 0.8),
            (Sentiment::Negative, 0.7, -0.7),
            (Sentiment::Neutral, 0.6, 0.0),
        ]
        .into_iter()
        .map(|(sentiment, confidence, score)| SentimentResult {
            sentiment,
            confidence,
            score,
            word_counts: SentimentWordCounts::default(),
        })
        .collect();

        let labels: Vec<String> = ["Doc1", "Doc2", "Doc3"]
            .iter()
            .map(|label| label.to_string())
            .collect();

        let viz = SentimentVisualizer::new(VisualizationConfig::default());
        let svg = viz
            .sentiment_distribution(&sentiment_results, &labels)
            .unwrap();

        for fragment in ["<svg", "Positive", "Negative", "Neutral"] {
            assert!(svg.contains(fragment), "SVG is missing {fragment:?}");
        }
    }

    /// The palette generators return exactly as many colours as requested.
    #[test]
    fn test_color_generation() {
        let word_cloud = WordCloud::from_frequencies(HashMap::new(), VisualizationConfig::default());

        let viridis = word_cloud.generate_viridis_colors(5);
        assert_eq!(viridis.len(), 5);

        let plasma = word_cloud.generate_plasma_colors(3);
        assert_eq!(plasma.len(), 3);
    }

    /// Fully saturated primaries convert to the expected RGB triples.
    #[test]
    fn test_hsv_to_rgb_conversion() {
        let viz = EmbeddingVisualizer::new(VisualizationConfig::default());

        // Red sits at 0 degrees.
        let red = viz.hsv_to_rgb(0.0, 1.0, 1.0);
        assert_eq!((red.r, red.g, red.b), (255, 0, 0));

        // Green sits at 120 degrees.
        let green = viz.hsv_to_rgb(120.0, 1.0, 1.0);
        assert_eq!((green.r, green.g, green.b), (0, 255, 0));
    }
}
1268}