1use crate::error::{Result, TextError};
8use crate::sentiment::SentimentResult;
9use crate::topic_modeling::Topic;
10use crate::vectorize::{CountVectorizer, TfidfVectorizer, Vectorizer};
11use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
12use std::collections::HashMap;
13use std::fs;
14use std::path::Path;
15
16#[derive(Debug, Clone)]
18pub struct VisualizationConfig {
19 pub width: usize,
21 pub height: usize,
23 pub color_scheme: ColorScheme,
25 pub font_size_range: (usize, usize),
27 pub background_color: Color,
29 pub high_dpi: bool,
31}
32
33impl Default for VisualizationConfig {
34 fn default() -> Self {
35 Self {
36 width: 800,
37 height: 600,
38 color_scheme: ColorScheme::Viridis,
39 font_size_range: (10, 100),
40 background_color: Color::WHITE,
41 high_dpi: false,
42 }
43 }
44}
45
46#[derive(Debug, Clone)]
48pub enum ColorScheme {
49 Viridis,
51 Plasma,
53 Inferno,
55 Cool,
57 Warm,
59 Custom(Vec<Color>),
61}
62
/// A 24-bit RGB colour.
///
/// Comparable and hashable so colours can be asserted on in tests and used as
/// map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Color {
    /// Red channel.
    pub r: u8,
    /// Green channel.
    pub g: u8,
    /// Blue channel.
    pub b: u8,
}

impl Color {
    /// Pure white (`#ffffff`).
    pub const WHITE: Color = Color::new(255, 255, 255);
    /// Pure black (`#000000`).
    pub const BLACK: Color = Color::new(0, 0, 0);
    /// Pure red (`#ff0000`).
    pub const RED: Color = Color::new(255, 0, 0);
    /// Pure green (`#00ff00`).
    pub const GREEN: Color = Color::new(0, 255, 0);
    /// Pure blue (`#0000ff`).
    pub const BLUE: Color = Color::new(0, 0, 255);

    /// Creates a colour from its red, green and blue channels.
    ///
    /// `const` so that palettes can be declared as constants.
    pub const fn new(r: u8, g: u8, b: u8) -> Self {
        Self { r, g, b }
    }

    /// Formats the colour as a lowercase `#rrggbb` string.
    pub fn to_hex(&self) -> String {
        format!("#{:02x}{:02x}{:02x}", self.r, self.g, self.b)
    }
}
100
101pub struct WordCloud {
103 word_frequencies: HashMap<String, f64>,
105 config: VisualizationConfig,
107}
108
109impl WordCloud {
110 pub fn fromtext(text: &str, config: VisualizationConfig) -> Result<Self> {
112 let mut vectorizer = CountVectorizer::new(false);
113 let documents = vec![text];
114 let matrix = vectorizer.fit_transform(&documents)?;
115
116 let vocabulary_map = vectorizer.vocabulary_map();
117 let mut word_frequencies = HashMap::new();
118
119 for (word, &idx) in vocabulary_map.iter() {
121 if let Some(count) = vectorizer.get_feature_count(&matrix, 0, idx) {
122 if count > 0.0 {
123 word_frequencies.insert(word.clone(), count);
124 }
125 }
126 }
127
128 Ok(Self {
129 word_frequencies,
130 config,
131 })
132 }
133
134 pub fn from_tfidf(
136 vectorizer: &TfidfVectorizer,
137 matrix: &Array2<f64>,
138 document_index: usize,
139 ) -> Result<Self> {
140 let vocabulary_map = vectorizer.vocabulary_map();
141 let mut word_frequencies = HashMap::new();
142
143 for (word, &idx) in vocabulary_map.iter() {
145 if let Some(score) = vectorizer.get_feature_score(matrix, document_index, idx) {
146 if score > 0.0 {
147 word_frequencies.insert(word.clone(), score);
148 }
149 }
150 }
151
152 Ok(Self {
153 word_frequencies,
154 config: VisualizationConfig::default(),
155 })
156 }
157
158 pub fn from_frequencies(
160 frequencies: HashMap<String, f64>,
161 config: VisualizationConfig,
162 ) -> Self {
163 Self {
164 word_frequencies: frequencies,
165 config,
166 }
167 }
168
169 pub fn to_svg(&self) -> Result<String> {
171 let mut svg = String::new();
172
173 svg.push_str(&format!(
175 r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">"#,
176 self.config.width, self.config.height
177 ));
178
179 svg.push_str(&format!(
181 r#"<rect width="100%" height="100%" fill="{}" />"#,
182 self.config.background_color.to_hex()
183 ));
184
185 let mut sorted_words: Vec<_> = self.word_frequencies.iter().collect();
187 sorted_words.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
188
189 let max_words = 50;
191 let top_words: Vec<_> = sorted_words.into_iter().take(max_words).collect();
192
193 if top_words.is_empty() {
194 return Ok(svg + "</svg>");
195 }
196
197 let max_freq = top_words[0].1;
199 let min_freq = top_words.last().map(|x| *x.1).unwrap_or(*max_freq);
200 let freq_range = max_freq - min_freq;
201
202 let colors = self.generate_colors(top_words.len());
204
205 let cols = (top_words.len() as f64).sqrt().ceil() as usize;
207 let rows = top_words.len().div_ceil(cols);
208 let cell_width = self.config.width / cols;
209 let cell_height = self.config.height / rows;
210
211 for (i, (word, &freq)) in top_words.iter().enumerate() {
212 let row = i / cols;
213 let col = i % cols;
214
215 let x = col * cell_width + cell_width / 2;
217 let y = row * cell_height + cell_height / 2;
218
219 let font_size = if freq_range > 0.0 {
221 let normalized = (freq - min_freq) / freq_range;
222 self.config.font_size_range.0
223 + (normalized
224 * (self.config.font_size_range.1 - self.config.font_size_range.0) as f64)
225 as usize
226 } else {
227 (self.config.font_size_range.0 + self.config.font_size_range.1) / 2
228 };
229
230 svg.push_str(&format!(
232 r#"<text x="{}" y="{}" font-family="Arial, sans-serif" font-size="{}"
233 fill="{}" text-anchor="middle" dominant-baseline="middle">{}</text>"#,
234 x,
235 y,
236 font_size,
237 colors[i % colors.len()].to_hex(),
238 word
239 ));
240 }
241
242 svg.push_str("</svg>");
243 Ok(svg)
244 }
245
246 fn generate_colors(&self, count: usize) -> Vec<Color> {
248 match &self.config.color_scheme {
249 ColorScheme::Viridis => self.generate_viridis_colors(count),
250 ColorScheme::Plasma => self.generate_plasma_colors(count),
251 ColorScheme::Inferno => self.generate_inferno_colors(count),
252 ColorScheme::Cool => self.generate_cool_colors(count),
253 ColorScheme::Warm => self.generate_warm_colors(count),
254 ColorScheme::Custom(colors) => colors.clone(),
255 }
256 }
257
258 fn generate_viridis_colors(&self, count: usize) -> Vec<Color> {
260 let mut colors = Vec::new();
261 for i in 0..count {
262 let t = i as f64 / (count - 1).max(1) as f64;
263 let r = (68.0 + t * (253.0 - 68.0)) as u8;
265 let g = (1.0 + t * (231.0 - 1.0)) as u8;
266 let b = (84.0 + t * (37.0 - 84.0)) as u8;
267 colors.push(Color::new(r, g, b));
268 }
269 colors
270 }
271
272 fn generate_plasma_colors(&self, count: usize) -> Vec<Color> {
274 let mut colors = Vec::new();
275 for i in 0..count {
276 let t = i as f64 / (count - 1).max(1) as f64;
277 let r = (13.0 + t * (240.0 - 13.0)) as u8;
279 let g = (8.0 + t * (249.0 - 8.0)) as u8;
280 let b = (135.0 + t * (33.0 - 135.0)) as u8;
281 colors.push(Color::new(r, g, b));
282 }
283 colors
284 }
285
286 fn generate_inferno_colors(&self, count: usize) -> Vec<Color> {
288 let mut colors = Vec::new();
289 for i in 0..count {
290 let t = i as f64 / (count - 1).max(1) as f64;
291 let r = (0.0 + t * (252.0 - 0.0)) as u8;
293 let g = (0.0 + t * (255.0 - 0.0)) as u8;
294 let b = (4.0 + t * (164.0 - 4.0)) as u8;
295 colors.push(Color::new(r, g, b));
296 }
297 colors
298 }
299
300 fn generate_cool_colors(&self, count: usize) -> Vec<Color> {
302 let mut colors = Vec::new();
303 for i in 0..count {
304 let t = i as f64 / (count - 1).max(1) as f64;
305 let r = (0.0 + t * (255.0 - 0.0)) as u8;
306 let g = (255.0 - t * (255.0 - 0.0)) as u8;
307 let b = 255;
308 colors.push(Color::new(r, g, b));
309 }
310 colors
311 }
312
313 fn generate_warm_colors(&self, count: usize) -> Vec<Color> {
315 let mut colors = Vec::new();
316 for i in 0..count {
317 let t = i as f64 / (count - 1).max(1) as f64;
318 let r = 255;
319 let g = (255.0 - t * (255.0 - 0.0)) as u8;
320 let b = (0.0 + t * (255.0 - 0.0)) as u8;
321 colors.push(Color::new(r, g, b));
322 }
323 colors
324 }
325
326 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
328 let svg_content = self.to_svg()?;
329 fs::write(path, svg_content)
330 .map_err(|e| TextError::IoError(format!("Failed to save word cloud: {e}")))?;
331 Ok(())
332 }
333}
334
335pub struct AttentionVisualizer {
337 config: VisualizationConfig,
338}
339
340impl AttentionVisualizer {
341 pub fn new(config: VisualizationConfig) -> Self {
343 Self { config }
344 }
345
346 pub fn attention_heatmap(
348 &self,
349 attentionweights: ArrayView2<f64>,
350 source_tokens: &[String],
351 target_tokens: &[String],
352 ) -> Result<String> {
353 let (n_target, n_source) = attentionweights.dim();
354
355 if source_tokens.len() != n_source || target_tokens.len() != n_target {
356 return Err(TextError::InvalidInput(
357 "Token count doesn't match attention matrix dimensions".to_string(),
358 ));
359 }
360
361 let mut svg = String::new();
362
363 let cell_width = self.config.width / (n_source + 1);
365 let cell_height = self.config.height / (n_target + 1);
366 let _matrix_width = n_source * cell_width;
367 let matrix_height = n_target * cell_height;
368
369 svg.push_str(&format!(
371 r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">"#,
372 self.config.width + 100,
373 self.config.height + 100
374 ));
375
376 svg.push_str(&format!(
378 r#"<rect width="100%" height="100%" fill="{}" />"#,
379 self.config.background_color.to_hex()
380 ));
381
382 let mut min_val = f64::INFINITY;
384 let mut max_val = f64::NEG_INFINITY;
385 for &val in attentionweights.iter() {
386 min_val = min_val.min(val);
387 max_val = max_val.max(val);
388 }
389
390 for i in 0..n_target {
392 for j in 0..n_source {
393 let attention = attentionweights[[i, j]];
394 let normalized = if max_val > min_val {
395 (attention - min_val) / (max_val - min_val)
396 } else {
397 0.5
398 };
399
400 let intensity = (normalized * 255.0) as u8;
402 let color = Color::new(255 - intensity, 255 - intensity, 255);
403
404 let x = 50 + j * cell_width;
405 let y = 50 + i * cell_height;
406
407 svg.push_str(&format!(
408 r#"<rect x="{}" y="{}" width="{}" height="{}" fill="{}" stroke="black" stroke-width="1" />"#,
409 x, y, cell_width, cell_height, color.to_hex()
410 ));
411
412 svg.push_str(&format!(
414 r#"<text x="{}" y="{}" font-family="Arial, sans-serif" font-size="10"
415 text-anchor="middle" dominant-baseline="middle">{:.2}</text>"#,
416 x + cell_width / 2,
417 y + cell_height / 2,
418 attention
419 ));
420 }
421 }
422
423 for (j, token) in source_tokens.iter().enumerate() {
425 let x = 50 + j * cell_width + cell_width / 2;
426 let y = 50 + matrix_height + 20;
427
428 svg.push_str(&format!(
429 r#"<text x="{x}" y="{y}" font-family="Arial, sans-serif" font-size="12"
430 text-anchor="middle" dominant-baseline="middle" transform="rotate(-45 {x} {y})">{token}</text>"#
431 ));
432 }
433
434 for (i, token) in target_tokens.iter().enumerate() {
436 let x = 30;
437 let y = 50 + i * cell_height + cell_height / 2;
438
439 svg.push_str(&format!(
440 r#"<text x="{x}" y="{y}" font-family="Arial, sans-serif" font-size="12"
441 text-anchor="end" dominant-baseline="middle">{token}</text>"#
442 ));
443 }
444
445 svg.push_str("</svg>");
446 Ok(svg)
447 }
448
449 pub fn save_attention_heatmap<P: AsRef<Path>>(
451 &self,
452 attentionweights: ArrayView2<f64>,
453 source_tokens: &[String],
454 target_tokens: &[String],
455 path: P,
456 ) -> Result<()> {
457 let svg_content = self.attention_heatmap(attentionweights, source_tokens, target_tokens)?;
458 fs::write(path, svg_content)
459 .map_err(|e| TextError::IoError(format!("Failed to save attention heatmap: {e}")))?;
460 Ok(())
461 }
462}
463
464pub struct EmbeddingVisualizer {
466 config: VisualizationConfig,
467}
468
469impl EmbeddingVisualizer {
470 pub fn new(config: VisualizationConfig) -> Self {
472 Self { config }
473 }
474
475 pub fn visualize_embeddings(
477 &self,
478 word_vectors: &HashMap<String, Array1<f64>>,
479 words_to_plot: Option<&[String]>,
480 ) -> Result<String> {
481 let words: Vec<String> = if let Some(selected_words) = words_to_plot {
482 selected_words.to_vec()
483 } else {
484 word_vectors.keys().take(100).cloned().collect()
485 };
486
487 if words.is_empty() {
488 return Err(TextError::InvalidInput("No words to visualize".to_string()));
489 }
490
491 let mut embeddings = Vec::new();
493 let mut valid_words = Vec::new();
494
495 for word in &words {
496 if let Some(vector) = word_vectors.get(word) {
497 embeddings.push(vector.clone());
498 valid_words.push(word.clone());
499 }
500 }
501
502 if embeddings.is_empty() {
503 return Err(TextError::InvalidInput(
504 "No valid embeddings found".to_string(),
505 ));
506 }
507
508 let projected_points = self.simple_pca_2d(&embeddings)?;
510
511 let mut svg = String::new();
513
514 svg.push_str(&format!(
516 r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">"#,
517 self.config.width, self.config.height
518 ));
519
520 svg.push_str(&format!(
522 r#"<rect width="100%" height="100%" fill="{}" />"#,
523 self.config.background_color.to_hex()
524 ));
525
526 let mut min_x = f64::INFINITY;
528 let mut max_x = f64::NEG_INFINITY;
529 let mut min_y = f64::INFINITY;
530 let mut max_y = f64::NEG_INFINITY;
531
532 for &(x, y) in &projected_points {
533 min_x = min_x.min(x);
534 max_x = max_x.max(x);
535 min_y = min_y.min(y);
536 max_y = max_y.max(y);
537 }
538
539 let margin = 50.0;
540 let plot_width = self.config.width as f64 - 2.0 * margin;
541 let plot_height = self.config.height as f64 - 2.0 * margin;
542
543 let colors = self.generate_colors(valid_words.len());
545
546 for (i, ((x, y), word)) in projected_points.iter().zip(&valid_words).enumerate() {
548 let scaled_x = margin + (x - min_x) / (max_x - min_x) * plot_width;
550 let scaled_y = margin + (y - min_y) / (max_y - min_y) * plot_height;
551
552 svg.push_str(&format!(
554 r#"<circle cx="{}" cy="{}" r="3" fill="{}" />"#,
555 scaled_x,
556 scaled_y,
557 colors[i % colors.len()].to_hex()
558 ));
559
560 svg.push_str(&format!(
562 r#"<text x="{}" y="{}" font-family="Arial, sans-serif" font-size="10"
563 text-anchor="start" dominant-baseline="middle">{}</text>"#,
564 scaled_x + 5.0,
565 scaled_y,
566 word
567 ));
568 }
569
570 svg.push_str("</svg>");
571 Ok(svg)
572 }
573
574 fn simple_pca_2d(&self, embeddings: &[Array1<f64>]) -> Result<Vec<(f64, f64)>> {
576 if embeddings.is_empty() {
577 return Ok(Vec::new());
578 }
579
580 let n_samples = embeddings.len();
581 let n_features = embeddings[0].len();
582
583 let mut data_matrix = Array2::zeros((n_samples, n_features));
585 for (i, embedding) in embeddings.iter().enumerate() {
586 data_matrix.row_mut(i).assign(embedding);
587 }
588
589 let mean = data_matrix.mean_axis(Axis(0)).unwrap();
591 for mut row in data_matrix.rows_mut() {
592 row -= &mean;
593 }
594
595 let _cov_matrix = data_matrix.t().dot(&data_matrix) / (n_samples - 1) as f64;
597
598 let mut pc1 = Array1::zeros(n_features);
601 let mut pc2 = Array1::zeros(n_features);
602
603 for i in 0..n_features {
605 pc1[i] = (i as f64).sin();
606 pc2[i] = (i as f64).cos();
607 }
608
609 pc1 /= pc1.dot(&pc1).sqrt();
611 pc2 /= pc2.dot(&pc2).sqrt();
612
613 let mut projected = Vec::new();
615 for row in data_matrix.rows() {
616 let x = row.dot(&pc1);
617 let y = row.dot(&pc2);
618 projected.push((x, y));
619 }
620
621 Ok(projected)
622 }
623
624 fn generate_colors(&self, count: usize) -> Vec<Color> {
626 let mut colors = Vec::new();
627 for i in 0..count {
628 let hue = (i as f64 / count as f64) * 360.0;
629 let color = self.hsv_to_rgb(hue, 0.8, 0.9);
630 colors.push(color);
631 }
632 colors
633 }
634
635 fn hsv_to_rgb(&self, h: f64, s: f64, v: f64) -> Color {
637 let c = v * s;
638 let x = c * (1.0 - ((h / 60.0) % 2.0 - 1.0).abs());
639 let m = v - c;
640
641 let (r_prime, g_prime, b_prime) = match h as i32 / 60 {
642 0 => (c, x, 0.0),
643 1 => (x, c, 0.0),
644 2 => (0.0, c, x),
645 3 => (0.0, x, c),
646 4 => (x, 0.0, c),
647 _ => (c, 0.0, x),
648 };
649
650 Color::new(
651 ((r_prime + m) * 255.0) as u8,
652 ((g_prime + m) * 255.0) as u8,
653 ((b_prime + m) * 255.0) as u8,
654 )
655 }
656
657 pub fn save_embeddings<P: AsRef<Path>>(
659 &self,
660 word_vectors: &HashMap<String, Array1<f64>>,
661 words_to_plot: Option<&[String]>,
662 path: P,
663 ) -> Result<()> {
664 let svg_content = self.visualize_embeddings(word_vectors, words_to_plot)?;
665 fs::write(path, svg_content).map_err(|e| {
666 TextError::IoError(format!("Failed to save embedding visualization: {e}"))
667 })?;
668 Ok(())
669 }
670}
671
672pub struct SentimentVisualizer {
674 config: VisualizationConfig,
675}
676
677impl SentimentVisualizer {
678 pub fn new(config: VisualizationConfig) -> Self {
680 Self { config }
681 }
682
683 pub fn sentiment_distribution(
685 &self,
686 sentiment_results: &[SentimentResult],
687 labels: &[String],
688 ) -> Result<String> {
689 if sentiment_results.len() != labels.len() {
690 return Err(TextError::InvalidInput(
691 "Number of sentiment _results must match number of labels".to_string(),
692 ));
693 }
694
695 let mut positive_count = 0;
697 let mut negative_count = 0;
698 let mut neutral_count = 0;
699
700 for result in sentiment_results {
701 match result.sentiment {
702 crate::sentiment::Sentiment::Positive => positive_count += 1,
703 crate::sentiment::Sentiment::Negative => negative_count += 1,
704 crate::sentiment::Sentiment::Neutral => neutral_count += 1,
705 }
706 }
707
708 let total = sentiment_results.len();
709 if total == 0 {
710 return Err(TextError::InvalidInput(
711 "No sentiment data to visualize".to_string(),
712 ));
713 }
714
715 let mut svg = String::new();
716
717 svg.push_str(&format!(
719 r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">"#,
720 self.config.width, self.config.height
721 ));
722
723 svg.push_str(&format!(
725 r#"<rect width="100%" height="100%" fill="{}" />"#,
726 self.config.background_color.to_hex()
727 ));
728
729 let center_x = self.config.width as f64 / 2.0;
731 let center_y = self.config.height as f64 / 2.0;
732 let radius = (self.config.width.min(self.config.height) as f64 / 2.0 - 50.0).max(50.0);
733
734 let positive_angle = (positive_count as f64 / total as f64) * 2.0 * std::f64::consts::PI;
735 let negative_angle = (negative_count as f64 / total as f64) * 2.0 * std::f64::consts::PI;
736 let neutral_angle = (neutral_count as f64 / total as f64) * 2.0 * std::f64::consts::PI;
737
738 let mut start_angle = 0.0;
739
740 if positive_count > 0 {
742 let end_angle = start_angle + positive_angle;
743 svg.push_str(&self.create_pie_segment(
744 center_x,
745 center_y,
746 radius,
747 start_angle,
748 end_angle,
749 Color::GREEN,
750 "Positive",
751 positive_count,
752 total,
753 ));
754 start_angle = end_angle;
755 }
756
757 if negative_count > 0 {
759 let end_angle = start_angle + negative_angle;
760 svg.push_str(&self.create_pie_segment(
761 center_x,
762 center_y,
763 radius,
764 start_angle,
765 end_angle,
766 Color::RED,
767 "Negative",
768 negative_count,
769 total,
770 ));
771 start_angle = end_angle;
772 }
773
774 if neutral_count > 0 {
776 let end_angle = start_angle + neutral_angle;
777 svg.push_str(&self.create_pie_segment(
778 center_x,
779 center_y,
780 radius,
781 start_angle,
782 end_angle,
783 Color::new(128, 128, 128),
784 "Neutral",
785 neutral_count,
786 total,
787 ));
788 }
789
790 svg.push_str(r#"<text x="20" y="30" font-family="Arial, sans-serif" font-size="16" font-weight="bold">Sentiment Distribution</text>"#);
792
793 let legend_y = 60;
794 svg.push_str(&format!(
795 r#"<rect x="20" y="{}" width="15" height="15" fill="{}" />"#,
796 legend_y,
797 Color::GREEN.to_hex()
798 ));
799 svg.push_str(&format!(
800 r#"<text x="40" y="{}" font-family="Arial, sans-serif" font-size="12">Positive: {} ({:.1}%)</text>"#,
801 legend_y + 12, positive_count, (positive_count as f64 / total as f64) * 100.0
802 ));
803
804 svg.push_str(&format!(
805 r#"<rect x="20" y="{}" width="15" height="15" fill="{}" />"#,
806 legend_y + 25,
807 Color::RED.to_hex()
808 ));
809 svg.push_str(&format!(
810 r#"<text x="40" y="{}" font-family="Arial, sans-serif" font-size="12">Negative: {} ({:.1}%)</text>"#,
811 legend_y + 37, negative_count, (negative_count as f64 / total as f64) * 100.0
812 ));
813
814 svg.push_str(&format!(
815 r#"<rect x="20" y="{}" width="15" height="15" fill="{}" />"#,
816 legend_y + 50,
817 Color::new(128, 128, 128).to_hex()
818 ));
819 svg.push_str(&format!(
820 r#"<text x="40" y="{}" font-family="Arial, sans-serif" font-size="12">Neutral: {} ({:.1}%)</text>"#,
821 legend_y + 62, neutral_count, (neutral_count as f64 / total as f64) * 100.0
822 ));
823
824 svg.push_str("</svg>");
825 Ok(svg)
826 }
827
828 fn create_pie_segment(
830 &self,
831 center_x: f64,
832 center_y: f64,
833 radius: f64,
834 start_angle: f64,
835 end_angle: f64,
836 color: Color,
837 label: &str,
838 _count: usize,
839 total: usize,
840 ) -> String {
841 let x1 = center_x + radius * start_angle.cos();
842 let y1 = center_y + radius * start_angle.sin();
843 let x2 = center_x + radius * end_angle.cos();
844 let y2 = center_y + radius * end_angle.sin();
845
846 let large_arc = if end_angle - start_angle > std::f64::consts::PI {
847 1
848 } else {
849 0
850 };
851
852 format!(
853 r#"<path d="M {} {} L {} {} A {} {} 0 {} 1 {} {} Z" fill="{}" stroke="white" stroke-width="2" />"#,
854 center_x,
855 center_y,
856 x1,
857 y1,
858 radius,
859 radius,
860 large_arc,
861 x2,
862 y2,
863 color.to_hex()
864 )
865 }
866
867 pub fn save_sentiment_distribution<P: AsRef<Path>>(
869 &self,
870 sentiment_results: &[SentimentResult],
871 labels: &[String],
872 path: P,
873 ) -> Result<()> {
874 let svg_content = self.sentiment_distribution(sentiment_results, labels)?;
875 fs::write(path, svg_content).map_err(|e| {
876 TextError::IoError(format!("Failed to save sentiment visualization: {e}"))
877 })?;
878 Ok(())
879 }
880}
881
882pub struct TopicVisualizer {
884 config: VisualizationConfig,
885}
886
887impl TopicVisualizer {
888 pub fn new(config: VisualizationConfig) -> Self {
890 Self { config }
891 }
892
893 pub fn topic_words_chart(&self, topics: &[Topic], topn: usize) -> Result<String> {
895 if topics.is_empty() {
896 return Err(TextError::InvalidInput(
897 "No topics to visualize".to_string(),
898 ));
899 }
900
901 let mut svg = String::new();
902
903 let chart_width = self.config.width;
905 let chart_height = self.config.height;
906 let margin = 50;
907 let topic_height = (chart_height - 2 * margin) / topics.len();
908
909 svg.push_str(&format!(
911 r#"<svg width="{chart_width}" height="{chart_height}" xmlns="http://www.w3.org/2000/svg">"#
912 ));
913
914 svg.push_str(&format!(
916 r#"<rect width="100%" height="100%" fill="{}" />"#,
917 self.config.background_color.to_hex()
918 ));
919
920 svg.push_str(&format!(
922 r#"<text x="{}" y="30" font-family="Arial, sans-serif" font-size="18" font-weight="bold" text-anchor="middle">Topic Word Distributions</text>"#,
923 chart_width / 2
924 ));
925
926 let colors = self.generate_topic_colors(topics.len());
928
929 for (topic_idx, topic) in topics.iter().enumerate() {
930 let y_offset = margin + topic_idx * topic_height;
931
932 svg.push_str(&format!(
934 r#"<text x="20" y="{}" font-family="Arial, sans-serif" font-size="14" font-weight="bold">Topic {}</text>"#,
935 y_offset + 20, topic_idx
936 ));
937
938 let mut topic_words: Vec<_> = topic.top_words.iter().collect();
940 topic_words.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
941 let top_words: Vec<_> = topic_words.into_iter().take(topn).collect();
942
943 if !top_words.is_empty() {
944 let max_prob = top_words[0].1;
945 let bar_area_width = chart_width - 200;
946
947 for (word_idx, (word, prob)) in top_words.iter().enumerate() {
949 let bar_y = y_offset + 30 + word_idx * 15;
950 let bar_width = (*prob / max_prob * bar_area_width as f64) as usize;
951
952 svg.push_str(&format!(
954 r#"<rect x="120" y="{}" width="{}" height="12" fill="{}" />"#,
955 bar_y,
956 bar_width,
957 colors[topic_idx % colors.len()].to_hex()
958 ));
959
960 svg.push_str(&format!(
962 r#"<text x="115" y="{}" font-family="Arial, sans-serif" font-size="10" text-anchor="end">{}</text>"#,
963 bar_y + 9, word
964 ));
965
966 svg.push_str(&format!(
968 r#"<text x="{}" y="{}" font-family="Arial, sans-serif" font-size="10">{:.3}</text>"#,
969 125 + bar_width, bar_y + 9, *prob
970 ));
971 }
972 }
973 }
974
975 svg.push_str("</svg>");
976 Ok(svg)
977 }
978
979 fn generate_topic_colors(&self, count: usize) -> Vec<Color> {
981 let mut colors = Vec::new();
982 for i in 0..count {
983 let hue = (i as f64 / count as f64) * 360.0;
984 let color = self.hsv_to_rgb(hue, 0.7, 0.8);
985 colors.push(color);
986 }
987 colors
988 }
989
990 fn hsv_to_rgb(&self, h: f64, s: f64, v: f64) -> Color {
992 let c = v * s;
993 let x = c * (1.0 - ((h / 60.0) % 2.0 - 1.0).abs());
994 let m = v - c;
995
996 let (r_prime, g_prime, b_prime) = match h as i32 / 60 {
997 0 => (c, x, 0.0),
998 1 => (x, c, 0.0),
999 2 => (0.0, c, x),
1000 3 => (0.0, x, c),
1001 4 => (x, 0.0, c),
1002 _ => (c, 0.0, x),
1003 };
1004
1005 Color::new(
1006 ((r_prime + m) * 255.0) as u8,
1007 ((g_prime + m) * 255.0) as u8,
1008 ((b_prime + m) * 255.0) as u8,
1009 )
1010 }
1011
1012 pub fn save_topic_words<P: AsRef<Path>>(
1014 &self,
1015 topics: &[Topic],
1016 topn: usize,
1017 path: P,
1018 ) -> Result<()> {
1019 let svg_content = self.topic_words_chart(topics, topn)?;
1020 fs::write(path, svg_content)
1021 .map_err(|e| TextError::IoError(format!("Failed to save topic visualization: {e}")))?;
1022 Ok(())
1023 }
1024}
1025
1026pub struct TextAnalyticsDashboard {
1028 config: VisualizationConfig,
1029}
1030
1031impl TextAnalyticsDashboard {
1032 pub fn new(config: VisualizationConfig) -> Self {
1034 Self { config }
1035 }
1036
1037 pub fn generate_dashboard(
1039 &self,
1040 text_data: &[String],
1041 sentiment_results: &[SentimentResult],
1042 topics: &[Topic],
1043 word_frequencies: &HashMap<String, f64>,
1044 ) -> Result<String> {
1045 let mut html = String::new();
1046
1047 html.push_str(r#"<!DOCTYPE html>
1049<html>
1050<head>
1051 <title>Text Analytics Dashboard</title>
1052 <style>
1053 body { font-family: Arial, sans-serif; margin: 20px; background-color: #f5f5f5; }
1054 .dashboard { display: grid; grid-template-columns: 1fr 1fr; gap: 20px; }
1055 .widget { background: white; padding: 20px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
1056 .widget h3 { margin-top: 0; color: #333; }
1057 .full-width { grid-column: 1 / -1; }
1058 .stats { display: flex; justify-content: space-around; }
1059 .stat { text-align: center; }
1060 .stat-value { font-size: 2em; font-weight: bold; color: #2196F3; }
1061 .stat-label { color: #666; }
1062 </style>
1063</head>
1064<body>
1065 <h1>Text Analytics Dashboard</h1>
1066 <div class="dashboard">
1067"#);
1068
1069 html.push_str(
1071 r#"<div class="widget">
1072 <h3>Text Statistics</h3>
1073 <div class="stats">
1074"#,
1075 );
1076
1077 let total_docs = text_data.len();
1078 let total_words: usize = text_data
1079 .iter()
1080 .map(|text| text.split_whitespace().count())
1081 .sum();
1082 let avg_words = if total_docs > 0 {
1083 total_words / total_docs
1084 } else {
1085 0
1086 };
1087 let unique_words = word_frequencies.len();
1088
1089 html.push_str(&format!(
1090 r#"
1091 <div class="stat">
1092 <div class="stat-value">{total_docs}</div>
1093 <div class="stat-label">Documents</div>
1094 </div>
1095 <div class="stat">
1096 <div class="stat-value">{total_words}</div>
1097 <div class="stat-label">Total Words</div>
1098 </div>
1099 <div class="stat">
1100 <div class="stat-value">{avg_words}</div>
1101 <div class="stat-label">Avg Words/Doc</div>
1102 </div>
1103 <div class="stat">
1104 <div class="stat-value">{unique_words}</div>
1105 <div class="stat-label">Unique Words</div>
1106 </div>
1107"#
1108 ));
1109
1110 html.push_str("</div></div>");
1111
1112 let word_cloud = WordCloud::from_frequencies(word_frequencies.clone(), self.config.clone());
1114 let word_cloud_svg = word_cloud.to_svg()?;
1115
1116 html.push_str(&format!(
1117 r#"<div class="widget">
1118 <h3>Word Cloud</h3>
1119 {word_cloud_svg}
1120 </div>"#
1121 ));
1122
1123 let sentiment_viz = SentimentVisualizer::new(self.config.clone());
1125 let labels: Vec<String> = (0..sentiment_results.len())
1126 .map(|i| {
1127 let doc_num = i + 1;
1128 format!("Doc {doc_num}")
1129 })
1130 .collect();
1131 let sentiment_svg = sentiment_viz.sentiment_distribution(sentiment_results, &labels)?;
1132
1133 html.push_str(&format!(
1134 r#"<div class="widget">
1135 <h3>Sentiment Distribution</h3>
1136 {sentiment_svg}
1137 </div>"#
1138 ));
1139
1140 if !topics.is_empty() {
1142 let topic_viz = TopicVisualizer::new(self.config.clone());
1143 let topic_svg = topic_viz.topic_words_chart(topics, 5)?;
1144
1145 html.push_str(&format!(
1146 r#"<div class="widget full-width">
1147 <h3>Topic Analysis</h3>
1148 {topic_svg}
1149 </div>"#
1150 ));
1151 }
1152
1153 html.push_str(
1155 r#"
1156 </div>
1157</body>
1158</html>"#,
1159 );
1160
1161 Ok(html)
1162 }
1163
1164 pub fn save_dashboard<P: AsRef<Path>>(
1166 &self,
1167 text_data: &[String],
1168 sentiment_results: &[SentimentResult],
1169 topics: &[Topic],
1170 word_frequencies: &HashMap<String, f64>,
1171 path: P,
1172 ) -> Result<()> {
1173 let html_content =
1174 self.generate_dashboard(text_data, sentiment_results, topics, word_frequencies)?;
1175 fs::write(path, html_content)
1176 .map_err(|e| TextError::IoError(format!("Failed to save dashboard: {e}")))?;
1177 Ok(())
1178 }
1179}
1180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sentiment::{Sentiment, SentimentResult};
    use std::collections::HashMap;

    /// A word cloud built from explicit frequencies renders every word.
    #[test]
    fn test_word_cloud_creation() {
        let frequencies: HashMap<String, f64> = [("hello", 10.0), ("world", 8.0), ("test", 5.0)]
            .iter()
            .map(|&(word, weight)| (word.to_string(), weight))
            .collect();

        let cloud = WordCloud::from_frequencies(frequencies, VisualizationConfig::default());
        let svg = cloud.to_svg().unwrap();

        assert!(svg.contains("<svg"));
        assert!(svg.contains("hello"));
        assert!(svg.contains("world"));
        assert!(svg.contains("test"));
    }

    /// The sentiment chart mentions every sentiment class.
    #[test]
    fn test_sentiment_visualization() {
        let make_result = |sentiment, confidence, score| SentimentResult {
            sentiment,
            confidence,
            score,
            word_counts: crate::sentiment::SentimentWordCounts::default(),
        };
        let sentiment_results = vec![
            make_result(Sentiment::Positive, 0.8, 0.8),
            make_result(Sentiment::Negative, 0.7, -0.7),
            make_result(Sentiment::Neutral, 0.6, 0.0),
        ];
        let labels: Vec<String> = ["Doc1", "Doc2", "Doc3"]
            .iter()
            .map(|label| label.to_string())
            .collect();

        let viz = SentimentVisualizer::new(VisualizationConfig::default());
        let svg = viz
            .sentiment_distribution(&sentiment_results, &labels)
            .unwrap();

        assert!(svg.contains("<svg"));
        assert!(svg.contains("Positive"));
        assert!(svg.contains("Negative"));
        assert!(svg.contains("Neutral"));
    }

    /// Gradient generators return exactly the requested number of colours.
    #[test]
    fn test_color_generation() {
        let cloud = WordCloud::from_frequencies(HashMap::new(), VisualizationConfig::default());

        let viridis = cloud.generate_viridis_colors(5);
        assert_eq!(viridis.len(), 5);

        let plasma = cloud.generate_plasma_colors(3);
        assert_eq!(plasma.len(), 3);
    }

    /// Pure red and pure green convert to the expected RGB channels.
    #[test]
    fn test_hsv_to_rgb_conversion() {
        let viz = EmbeddingVisualizer::new(VisualizationConfig::default());

        let red = viz.hsv_to_rgb(0.0, 1.0, 1.0);
        assert_eq!(red.r, 255);
        assert_eq!(red.g, 0);
        assert_eq!(red.b, 0);

        let green = viz.hsv_to_rgb(120.0, 1.0, 1.0);
        assert_eq!(green.r, 0);
        assert_eq!(green.g, 255);
        assert_eq!(green.b, 0);
    }
}