scirs2_text/huggingface_compat/pipelines/
classification.rs1use super::ClassificationResult;
7use crate::error::Result;
8
/// Hugging Face–style binary text-classification (sentiment) pipeline.
///
/// Construct it with `TextClassificationPipeline::new()` or `Default::default()`.
#[derive(Debug, Clone)]
pub struct TextClassificationPipeline {
    /// Label set of the pipeline (`["NEGATIVE", "POSITIVE"]` when built by `new`).
    // `predict` currently emits the label strings as literals instead of reading
    // this field, so the compiler would otherwise flag it as never read.
    #[allow(dead_code)]
    labels: Vec<String>,
}
16
17impl Default for TextClassificationPipeline {
18 fn default() -> Self {
19 Self::new()
20 }
21}
22
23impl TextClassificationPipeline {
24 pub fn new() -> Self {
26 Self {
27 labels: vec!["NEGATIVE".to_string(), "POSITIVE".to_string()],
28 }
29 }
30
31 pub fn predict(&self, text: &str) -> Result<Vec<ClassificationResult>> {
33 use crate::sentiment::{LexiconSentimentAnalyzer, Sentiment, SentimentLexicon};
35
36 let analyzer = LexiconSentimentAnalyzer::new(SentimentLexicon::with_basiclexicon());
37 let sentiment_result = analyzer.analyze(text)?;
38
39 let (label, confidence) = match sentiment_result.sentiment {
41 Sentiment::Positive => ("POSITIVE", sentiment_result.confidence),
42 Sentiment::Negative => ("NEGATIVE", sentiment_result.confidence),
43 Sentiment::Neutral => {
44 let positive_ratio = sentiment_result.word_counts.positive_words as f64
46 / (sentiment_result.word_counts.total_words as f64).max(1.0);
47 let negative_ratio = sentiment_result.word_counts.negative_words as f64
48 / (sentiment_result.word_counts.total_words as f64).max(1.0);
49
50 if positive_ratio >= negative_ratio {
51 ("POSITIVE", 0.5 + (positive_ratio - negative_ratio) / 2.0)
52 } else {
53 ("NEGATIVE", 0.5 + (negative_ratio - positive_ratio) / 2.0)
54 }
55 }
56 };
57
58 let alternative_label = if label == "POSITIVE" {
60 "NEGATIVE"
61 } else {
62 "POSITIVE"
63 };
64 let alternative_confidence = 1.0 - confidence;
65
66 Ok(vec![
67 ClassificationResult {
68 label: label.to_string(),
69 score: confidence,
70 },
71 ClassificationResult {
72 label: alternative_label.to_string(),
73 score: alternative_confidence,
74 },
75 ])
76 }
77}
78
/// Hugging Face–style zero-shot classification pipeline.
///
/// Each candidate label is scored against the input text through a
/// hypothesis sentence built from `hypothesis_template`.
#[derive(Debug, Clone)]
pub struct ZeroShotClassificationPipeline {
    /// Hypothesis sentence in which every `{}` is replaced by a candidate label.
    hypothesis_template: String,
}
85
86impl Default for ZeroShotClassificationPipeline {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92impl ZeroShotClassificationPipeline {
93 pub fn new() -> Self {
95 Self {
96 hypothesis_template: "This example is {}.".to_string(),
97 }
98 }
99
100 pub fn classify(
102 &self,
103 text: &str,
104 candidate_labels: &[&str],
105 ) -> Result<Vec<ClassificationResult>> {
106 let mut results = Vec::new();
107
108 use crate::distance::cosine_similarity;
110 use crate::tokenize::WhitespaceTokenizer;
111 use crate::vectorize::{CountVectorizer, Vectorizer};
112
113 let tokenizer = WhitespaceTokenizer::new();
114 let mut vectorizer = CountVectorizer::with_tokenizer(Box::new(tokenizer), false);
115
116 let mut corpus = vec![text];
118 let hypotheses: Vec<String> = candidate_labels
119 .iter()
120 .map(|label| self.hypothesis_template.replace("{}", label))
121 .collect();
122 corpus.extend(hypotheses.iter().map(|h| h.as_str()));
123
124 if let Ok(vectors) = vectorizer.fit_transform(&corpus) {
126 let text_vector = vectors.row(0);
127
128 for (i, &label) in candidate_labels.iter().enumerate() {
129 let hypothesis_vector = vectors.row(i + 1);
130
131 let similarity = cosine_similarity(text_vector, hypothesis_vector).unwrap_or(0.0);
133
134 let text_lower = text.to_lowercase();
136 let label_lower = label.to_lowercase();
137 let keyword_bonus = if text_lower.contains(&label_lower) {
138 0.2
139 } else {
140 0.0
141 };
142
143 let score = (similarity + keyword_bonus).clamp(0.0, 1.0);
144
145 results.push(ClassificationResult {
146 label: label.to_string(),
147 score,
148 });
149 }
150 } else {
151 for &label in candidate_labels {
153 let text_lower = text.to_lowercase();
154 let label_lower = label.to_lowercase();
155
156 let score = if text_lower.contains(&label_lower) {
157 0.8
158 } else {
159 let text_words: std::collections::HashSet<_> =
161 text_lower.split_whitespace().collect();
162 let label_words: std::collections::HashSet<_> =
163 label_lower.split_whitespace().collect();
164 let common_words = text_words.intersection(&label_words).count();
165 let total_words = text_words.union(&label_words).count();
166
167 if total_words > 0 {
168 common_words as f64 / total_words as f64
169 } else {
170 0.1
171 }
172 };
173
174 results.push(ClassificationResult {
175 label: label.to_string(),
176 score,
177 });
178 }
179 }
180
181 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
183
184 Ok(results)
185 }
186
187 pub fn set_hypothesis_template(&mut self, template: String) {
189 self.hypothesis_template = template;
190 }
191}