rust_bert/pipelines/
sentiment.rs

1// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
2// Copyright 2019 Guillaume Becquin
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//     http://www.apache.org/licenses/LICENSE-2.0
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13//! # Sentiment Analysis pipeline
//! Predicts the binary sentiment for a sentence. By default, the dependencies for a
//! DistilBERT model fine-tuned on SST-2 will be downloaded.
//! Customized DistilBERT models can be loaded by overwriting the resources in the configuration.
//! The dependencies will be downloaded to the user's home directory, under ~/.cache/.rustbert/distilbert-sst2
18//!
19//! ```no_run
20//! use rust_bert::pipelines::sentiment::SentimentModel;
21//!
22//! # fn main() -> anyhow::Result<()> {
23//! let sentiment_classifier = SentimentModel::new(Default::default())?;
24//! let input = [
25//!     "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
26//!     "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
27//!     "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
28//! ];
29//! let output = sentiment_classifier.predict(&input);
30//! # Ok(())
31//! # }
32//! ```
33//! (Example courtesy of [IMDb](http://www.imdb.com))
34//!
35//! Output: \
36//! ```no_run
37//! # use rust_bert::pipelines::sentiment::Sentiment;
38//! # use rust_bert::pipelines::sentiment::SentimentPolarity::{Positive, Negative};
39//! # let output =
40//! [
41//!     Sentiment {
42//!         polarity: Positive,
43//!         score: 0.998,
44//!     },
45//!     Sentiment {
46//!         polarity: Negative,
47//!         score: 0.992,
48//!     },
49//!     Sentiment {
50//!         polarity: Positive,
51//!         score: 0.999,
52//!     },
53//! ]
54//! # ;
55//! ```
56
57use crate::common::error::RustBertError;
58use crate::pipelines::common::TokenizerOption;
59use crate::pipelines::sequence_classification::{
60    SequenceClassificationConfig, SequenceClassificationModel,
61};
62use serde::{Deserialize, Serialize};
63
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
/// Enum with the possible sentiment polarities. Note that the pre-trained SST2 model does not include neutral sentiment.
pub enum SentimentPolarity {
    /// Positive sentiment. `SentimentModel::predict` assigns this polarity when the predicted label id is `1`.
    Positive,
    /// Negative sentiment. `SentimentModel::predict` assigns this polarity for every label id other than `1`.
    Negative,
}
70
#[derive(Debug, Serialize, Deserialize)]
/// Sentiment returned by the model (one element of the vector produced by `SentimentModel::predict`).
pub struct Sentiment {
    /// Polarity of the sentiment
    pub polarity: SentimentPolarity,
    /// Confidence score, copied from the `score` of the label returned by the underlying sequence classification model
    pub score: f64,
}
79
/// Configuration accepted by the `SentimentModel` constructors; a type alias for `SequenceClassificationConfig`.
pub type SentimentConfig = SequenceClassificationConfig;
81
/// # SentimentModel to perform sentiment analysis
///
/// Wraps a `SequenceClassificationModel` and maps the labels it predicts to `Sentiment` values.
pub struct SentimentModel {
    // Underlying sequence classifier; all methods of `SentimentModel` delegate to it.
    sequence_classification_model: SequenceClassificationModel,
}
86
87impl SentimentModel {
88    /// Build a new `SentimentModel`
89    ///
90    /// # Arguments
91    ///
92    /// * `sentiment_config` - `SentimentConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
93    ///
94    /// # Example
95    ///
96    /// ```no_run
97    /// # fn main() -> anyhow::Result<()> {
98    /// use rust_bert::pipelines::sentiment::SentimentModel;
99    ///
100    /// let sentiment_model = SentimentModel::new(Default::default())?;
101    /// # Ok(())
102    /// # }
103    /// ```
104    pub fn new(sentiment_config: SentimentConfig) -> Result<SentimentModel, RustBertError> {
105        let sequence_classification_model = SequenceClassificationModel::new(sentiment_config)?;
106        Ok(SentimentModel {
107            sequence_classification_model,
108        })
109    }
110
111    /// Build a new `SentimentModel` with a provided tokenizer.
112    ///
113    /// # Arguments
114    ///
115    /// * `sentiment_config` - `SentimentConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
116    /// * `tokenizer` - `TokenizerOption` tokenizer to use for sentiment classification.
117    ///
118    /// # Example
119    ///
120    /// ```no_run
121    /// # fn main() -> anyhow::Result<()> {
122    /// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
123    /// use rust_bert::pipelines::sentiment::SentimentModel;
124    /// let tokenizer = TokenizerOption::from_file(
125    ///     ModelType::Bert,
126    ///     "path/to/vocab.txt",
127    ///     None,
128    ///     false,
129    ///     None,
130    ///     None,
131    /// )?;
132    /// let sentiment_model = SentimentModel::new_with_tokenizer(Default::default(), tokenizer)?;
133    /// # Ok(())
134    /// # }
135    /// ```
136    pub fn new_with_tokenizer(
137        sentiment_config: SentimentConfig,
138        tokenizer: TokenizerOption,
139    ) -> Result<SentimentModel, RustBertError> {
140        let sequence_classification_model =
141            SequenceClassificationModel::new_with_tokenizer(sentiment_config, tokenizer)?;
142        Ok(SentimentModel {
143            sequence_classification_model,
144        })
145    }
146
147    /// Get a reference to the model tokenizer.
148    pub fn get_tokenizer(&self) -> &TokenizerOption {
149        self.sequence_classification_model.get_tokenizer()
150    }
151
152    /// Get a mutable reference to the model tokenizer.
153    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
154        self.sequence_classification_model.get_tokenizer_mut()
155    }
156
157    /// Extract sentiment form an array of text inputs
158    ///
159    /// # Arguments
160    ///
161    /// * `input` - `&[&str]` Array of texts to extract the sentiment from.
162    ///
163    /// # Returns
164    /// * `Vec<Sentiment>` Sentiments extracted from texts.
165    ///
166    /// # Example
167    ///
168    /// ```no_run
169    /// # fn main() -> anyhow::Result<()> {
170    /// use rust_bert::pipelines::sentiment::SentimentModel;
171    ///
172    /// let sentiment_classifier =  SentimentModel::new(Default::default())?;
173    ///
174    /// let input = [
175    ///     "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
176    ///     "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
177    ///     "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
178    /// ];
179    ///
180    /// let output = sentiment_classifier.predict(&input);
181    /// # Ok(())
182    /// # }
183    /// ```
184    pub fn predict<'a, S>(&self, input: S) -> Vec<Sentiment>
185    where
186        S: AsRef<[&'a str]>,
187    {
188        let labels = self.sequence_classification_model.predict(input);
189        let mut sentiments = Vec::with_capacity(labels.len());
190        for label in labels {
191            let polarity = if label.id == 1 {
192                SentimentPolarity::Positive
193            } else {
194                SentimentPolarity::Negative
195            };
196            sentiments.push(Sentiment {
197                polarity,
198                score: label.score,
199            })
200        }
201        sentiments
202    }
203}
#[cfg(test)]
mod test {
    use super::*;

    /// Compile-time check that the value returned by `SentimentModel::new` is `Send`:
    /// coercing it into a `Box<dyn Send>` only type-checks if it is.
    #[test]
    #[ignore] // no need to run, compilation is enough to verify it is Send
    fn test() {
        let build_result = SentimentModel::new(SentimentConfig::default());
        let _sendable: Box<dyn Send> = Box::new(build_result);
    }
}