scirs2_text/huggingface_compat/pipelines/classification.rs

//! Text classification pipeline implementations
//!
//! This module provides text classification pipelines including
//! standard binary/multi-class classification and zero-shot classification.
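//!
//! # Examples
//!
//! A minimal sketch of both pipelines (marked `ignore` because the public
//! re-export paths of these types are assumptions):
//!
//! ```ignore
//! let classifier = TextClassificationPipeline::new();
//! let sentiment = classifier.predict("What a great day!").unwrap();
//!
//! let zero_shot = ZeroShotClassificationPipeline::new();
//! let topics = zero_shot
//!     .classify("Stocks rallied on Friday", &["finance", "weather"])
//!     .unwrap();
//! ```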

use super::ClassificationResult;
use crate::error::Result;

/// Text classification pipeline
#[derive(Debug)]
pub struct TextClassificationPipeline {
    /// Labels for classification
    #[allow(dead_code)]
    labels: Vec<String>,
}

impl Default for TextClassificationPipeline {
    fn default() -> Self {
        Self::new()
    }
}

impl TextClassificationPipeline {
    /// Create new text classification pipeline
    pub fn new() -> Self {
        Self {
            labels: vec!["NEGATIVE".to_string(), "POSITIVE".to_string()],
        }
    }

    /// Run classification on text
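    ///
    /// A minimal usage sketch (marked `ignore` because the exact public path
    /// of this type is an assumption):
    ///
    /// ```ignore
    /// let pipeline = TextClassificationPipeline::new();
    /// let results = pipeline.predict("I really enjoyed this movie!").unwrap();
    /// // Two complementary results; the predicted label comes first.
    /// assert_eq!(results.len(), 2);
    /// ```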
    pub fn predict(&self, text: &str) -> Result<Vec<ClassificationResult>> {
        // Use the existing sentiment analysis functionality for more realistic predictions
        use crate::sentiment::{LexiconSentimentAnalyzer, Sentiment, SentimentLexicon};

        let analyzer = LexiconSentimentAnalyzer::new(SentimentLexicon::with_basiclexicon());
        let sentiment_result = analyzer.analyze(text)?;

        // Convert sentiment result to classification format
        let (label, confidence) = match sentiment_result.sentiment {
            Sentiment::Positive => ("POSITIVE", sentiment_result.confidence),
            Sentiment::Negative => ("NEGATIVE", sentiment_result.confidence),
            Sentiment::Neutral => {
                // Neutral has no direct binary label, so break the tie using the
                // positive/negative word ratios (ties go to POSITIVE).
                let positive_ratio = sentiment_result.word_counts.positive_words as f64
                    / (sentiment_result.word_counts.total_words as f64).max(1.0);
                let negative_ratio = sentiment_result.word_counts.negative_words as f64
                    / (sentiment_result.word_counts.total_words as f64).max(1.0);

                if positive_ratio >= negative_ratio {
                    ("POSITIVE", 0.5 + (positive_ratio - negative_ratio) / 2.0)
                } else {
                    ("NEGATIVE", 0.5 + (negative_ratio - positive_ratio) / 2.0)
                }
            }
        };

        // Also report the complementary label; the two scores sum to 1.0
        let alternative_label = if label == "POSITIVE" {
            "NEGATIVE"
        } else {
            "POSITIVE"
        };
        let alternative_confidence = 1.0 - confidence;

        Ok(vec![
            ClassificationResult {
                label: label.to_string(),
                score: confidence,
            },
            ClassificationResult {
                label: alternative_label.to_string(),
                score: alternative_confidence,
            },
        ])
    }
}

/// Zero-shot classification pipeline
#[derive(Debug)]
pub struct ZeroShotClassificationPipeline {
    /// Hypothesis template
    hypothesis_template: String,
}

impl Default for ZeroShotClassificationPipeline {
    fn default() -> Self {
        Self::new()
    }
}

impl ZeroShotClassificationPipeline {
    /// Create new zero-shot classification pipeline
    pub fn new() -> Self {
        Self {
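            // Same default hypothesis template as Hugging Face's
            // zero-shot classification pipeline.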
            hypothesis_template: "This example is {}.".to_string(),
        }
    }

    /// Classify text against multiple labels
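    ///
    /// A minimal usage sketch (marked `ignore` because the exact public path
    /// of this type is an assumption):
    ///
    /// ```ignore
    /// let pipeline = ZeroShotClassificationPipeline::new();
    /// let results = pipeline
    ///     .classify("The home side scored in extra time", &["sports", "politics"])
    ///     .unwrap();
    /// // Results are sorted by score, best label first.
    /// assert!(results[0].score >= results[1].score);
    /// ```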
    pub fn classify(
        &self,
        text: &str,
        candidate_labels: &[&str],
    ) -> Result<Vec<ClassificationResult>> {
        let mut results = Vec::new();

        // Approximate zero-shot classification via bag-of-words cosine
        // similarity between the text and each hypothesis, plus a keyword bonus
        use crate::distance::cosine_similarity;
        use crate::tokenize::WhitespaceTokenizer;
        use crate::vectorize::{CountVectorizer, Vectorizer};

        let tokenizer = WhitespaceTokenizer::new();
        let mut vectorizer = CountVectorizer::with_tokenizer(Box::new(tokenizer), false);

        // Build a small corpus: the input text followed by one hypothesis per label
        let mut corpus = vec![text];
        let hypotheses: Vec<String> = candidate_labels
            .iter()
            .map(|label| self.hypothesis_template.replace("{}", label))
            .collect();
        corpus.extend(hypotheses.iter().map(|h| h.as_str()));

        // Vectorize the corpus
        if let Ok(vectors) = vectorizer.fit_transform(&corpus) {
            let text_vector = vectors.row(0);

            for (i, &label) in candidate_labels.iter().enumerate() {
                let hypothesis_vector = vectors.row(i + 1);

                // Calculate cosine similarity between text and hypothesis
                let similarity = cosine_similarity(text_vector, hypothesis_vector).unwrap_or(0.0);

                // Boost the score when the label literally appears in the text
                let text_lower = text.to_lowercase();
                let label_lower = label.to_lowercase();
                let keyword_bonus = if text_lower.contains(&label_lower) {
                    0.2
                } else {
                    0.0
                };

                let score = (similarity + keyword_bonus).clamp(0.0, 1.0);

                results.push(ClassificationResult {
                    label: label.to_string(),
                    score,
                });
            }
        } else {
            // Fallback to simple keyword matching if vectorization fails
            for &label in candidate_labels {
                let text_lower = text.to_lowercase();
                let label_lower = label.to_lowercase();

                let score = if text_lower.contains(&label_lower) {
                    0.8
                } else {
                    // Jaccard similarity of the word sets (|intersection| / |union|)
                    let text_words: std::collections::HashSet<_> =
                        text_lower.split_whitespace().collect();
                    let label_words: std::collections::HashSet<_> =
                        label_lower.split_whitespace().collect();
                    let common_words = text_words.intersection(&label_words).count();
                    let total_words = text_words.union(&label_words).count();

                    if total_words > 0 {
                        common_words as f64 / total_words as f64
                    } else {
                        0.1
                    }
                };

                results.push(ClassificationResult {
                    label: label.to_string(),
                    score,
                });
            }
        }

        // Sort by score descending; total_cmp avoids a panic if a score is NaN
        results.sort_by(|a, b| b.score.total_cmp(&a.score));

        Ok(results)
    }

    /// Set hypothesis template
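    ///
    /// The template must contain a `{}` placeholder, which is substituted
    /// with each candidate label (e.g. `"This text is about {}."`).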
    pub fn set_hypothesis_template(&mut self, template: String) {
        self.hypothesis_template = template;
    }
}