rust_bert/pipelines/sentiment.rs
1// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
2// Copyright 2019 Guillaume Becquin
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6// http://www.apache.org/licenses/LICENSE-2.0
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13//! # Sentiment Analysis pipeline
14//! Predicts the binary sentiment for a sentence. By default, the dependencies for this
15//! model will be downloaded for a DistilBERT model finetuned on SST-2.
16//! Customized DistilBERT models can be loaded by overwriting the resources in the configuration.
17//! The dependencies will be downloaded to the user's home directory, under ~/.cache/.rustbert/distilbert-sst2
18//!
19//! ```no_run
20//! use rust_bert::pipelines::sentiment::SentimentModel;
21//!
22//! # fn main() -> anyhow::Result<()> {
23//! let sentiment_classifier = SentimentModel::new(Default::default())?;
24//! let input = [
25//! "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
26//! "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
27//! "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
28//! ];
29//! let output = sentiment_classifier.predict(&input);
30//! # Ok(())
31//! # }
32//! ```
33//! (Example courtesy of [IMDb](http://www.imdb.com))
34//!
35//! Output: \
36//! ```no_run
37//! # use rust_bert::pipelines::sentiment::Sentiment;
38//! # use rust_bert::pipelines::sentiment::SentimentPolarity::{Positive, Negative};
39//! # let output =
40//! [
41//! Sentiment {
42//! polarity: Positive,
43//! score: 0.998,
44//! },
45//! Sentiment {
46//! polarity: Negative,
47//! score: 0.992,
48//! },
49//! Sentiment {
50//! polarity: Positive,
51//! score: 0.999,
52//! },
53//! ]
54//! # ;
55//! ```
56
57use crate::common::error::RustBertError;
58use crate::pipelines::common::TokenizerOption;
59use crate::pipelines::sequence_classification::{
60 SequenceClassificationConfig, SequenceClassificationModel,
61};
62use serde::{Deserialize, Serialize};
63
64#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
65/// Enum with the possible sentiment polarities. Note that the pre-trained SST2 model does not include neutral sentiment.
66pub enum SentimentPolarity {
67 Positive,
68 Negative,
69}
70
71#[derive(Debug, Serialize, Deserialize)]
72/// Sentiment returned by the model.
73pub struct Sentiment {
74 /// Polarity of the sentiment
75 pub polarity: SentimentPolarity,
76 /// Confidence score
77 pub score: f64,
78}
79
80pub type SentimentConfig = SequenceClassificationConfig;
81
82/// # SentimentClassifier to perform sentiment analysis
83pub struct SentimentModel {
84 sequence_classification_model: SequenceClassificationModel,
85}
86
87impl SentimentModel {
88 /// Build a new `SentimentModel`
89 ///
90 /// # Arguments
91 ///
92 /// * `sentiment_config` - `SentimentConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
93 ///
94 /// # Example
95 ///
96 /// ```no_run
97 /// # fn main() -> anyhow::Result<()> {
98 /// use rust_bert::pipelines::sentiment::SentimentModel;
99 ///
100 /// let sentiment_model = SentimentModel::new(Default::default())?;
101 /// # Ok(())
102 /// # }
103 /// ```
104 pub fn new(sentiment_config: SentimentConfig) -> Result<SentimentModel, RustBertError> {
105 let sequence_classification_model = SequenceClassificationModel::new(sentiment_config)?;
106 Ok(SentimentModel {
107 sequence_classification_model,
108 })
109 }
110
111 /// Build a new `SentimentModel` with a provided tokenizer.
112 ///
113 /// # Arguments
114 ///
115 /// * `sentiment_config` - `SentimentConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
116 /// * `tokenizer` - `TokenizerOption` tokenizer to use for sentiment classification.
117 ///
118 /// # Example
119 ///
120 /// ```no_run
121 /// # fn main() -> anyhow::Result<()> {
122 /// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
123 /// use rust_bert::pipelines::sentiment::SentimentModel;
124 /// let tokenizer = TokenizerOption::from_file(
125 /// ModelType::Bert,
126 /// "path/to/vocab.txt",
127 /// None,
128 /// false,
129 /// None,
130 /// None,
131 /// )?;
132 /// let sentiment_model = SentimentModel::new_with_tokenizer(Default::default(), tokenizer)?;
133 /// # Ok(())
134 /// # }
135 /// ```
136 pub fn new_with_tokenizer(
137 sentiment_config: SentimentConfig,
138 tokenizer: TokenizerOption,
139 ) -> Result<SentimentModel, RustBertError> {
140 let sequence_classification_model =
141 SequenceClassificationModel::new_with_tokenizer(sentiment_config, tokenizer)?;
142 Ok(SentimentModel {
143 sequence_classification_model,
144 })
145 }
146
147 /// Get a reference to the model tokenizer.
148 pub fn get_tokenizer(&self) -> &TokenizerOption {
149 self.sequence_classification_model.get_tokenizer()
150 }
151
152 /// Get a mutable reference to the model tokenizer.
153 pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
154 self.sequence_classification_model.get_tokenizer_mut()
155 }
156
157 /// Extract sentiment form an array of text inputs
158 ///
159 /// # Arguments
160 ///
161 /// * `input` - `&[&str]` Array of texts to extract the sentiment from.
162 ///
163 /// # Returns
164 /// * `Vec<Sentiment>` Sentiments extracted from texts.
165 ///
166 /// # Example
167 ///
168 /// ```no_run
169 /// # fn main() -> anyhow::Result<()> {
170 /// use rust_bert::pipelines::sentiment::SentimentModel;
171 ///
172 /// let sentiment_classifier = SentimentModel::new(Default::default())?;
173 ///
174 /// let input = [
175 /// "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
176 /// "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
177 /// "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
178 /// ];
179 ///
180 /// let output = sentiment_classifier.predict(&input);
181 /// # Ok(())
182 /// # }
183 /// ```
184 pub fn predict<'a, S>(&self, input: S) -> Vec<Sentiment>
185 where
186 S: AsRef<[&'a str]>,
187 {
188 let labels = self.sequence_classification_model.predict(input);
189 let mut sentiments = Vec::with_capacity(labels.len());
190 for label in labels {
191 let polarity = if label.id == 1 {
192 SentimentPolarity::Positive
193 } else {
194 SentimentPolarity::Negative
195 };
196 sentiments.push(Sentiment {
197 polarity,
198 score: label.score,
199 })
200 }
201 sentiments
202 }
203}
#[cfg(test)]
mod test {
    use super::*;

    /// Compile-time check that the value returned by `SentimentModel::new` is `Send`.
    #[test]
    #[ignore] // no need to run, compilation is enough to verify it is Send
    fn test() {
        // Boxing as `dyn Send` only type-checks when the argument is `Send`.
        fn box_as_send<T: Send + 'static>(value: T) -> Box<dyn Send> {
            Box::new(value)
        }

        let _ = box_as_send(SentimentModel::new(SentimentConfig::default()));
    }
}
214}