rust_bert/pipelines/
sequence_classification.rs

1// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
2// Copyright 2019-2020 Guillaume Becquin
3// Copyright 2020 Maarten van Gompel
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//     http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13//! # Sequence classification pipeline (e.g. Sentiment Analysis)
14//! More generic sequence classification pipeline, works with multiple models (Bert, Roberta)
15//!
16//! ```no_run
17//! use rust_bert::pipelines::sequence_classification::SequenceClassificationConfig;
18//! use rust_bert::resources::{RemoteResource};
19//! use rust_bert::distilbert::{DistilBertModelResources, DistilBertVocabResources, DistilBertConfigResources};
20//! use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
21//! use rust_bert::pipelines::common::ModelType;
22//! # fn main() -> anyhow::Result<()> {
23//!
24//! //Load a configuration
25//! use rust_bert::pipelines::common::ModelResource;
26//! let config = SequenceClassificationConfig::new(ModelType::DistilBert,
27//!    ModelResource::Torch(Box::new(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2))),
//!    RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2),
//!    RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2),
30//!    None, // Merge resources
31//!    true, //lowercase
32//!    None, //strip_accents
33//!    None, //add_prefix_space
34//! );
35//!
36//! //Create the model
37//! let sequence_classification_model = SequenceClassificationModel::new(config)?;
38//!
39//! let input = [
40//!     "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
41//!     "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
42//!     "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
43//! ];
44//! let output = sequence_classification_model.predict(&input);
45//! # Ok(())
46//! # }
47//! ```
48//! (Example courtesy of [IMDb](http://www.imdb.com))
49//!
50//! Output: \
51//! ```no_run
52//! # use rust_bert::pipelines::sequence_classification::Label;
53//! let output =
54//! [
55//!    Label { text: String::from("POSITIVE"), score: 0.9986, id: 1, sentence: 0},
56//!    Label { text: String::from("NEGATIVE"), score: 0.9985, id: 0, sentence: 1},
//!    Label { text: String::from("POSITIVE"), score: 0.9988, id: 1, sentence: 2},
58//! ]
59//! # ;
60//! ```
61use crate::albert::AlbertForSequenceClassification;
62use crate::bart::BartForSequenceClassification;
63use crate::bert::BertForSequenceClassification;
64use crate::common::error::RustBertError;
65use crate::deberta::DebertaForSequenceClassification;
66use crate::distilbert::DistilBertModelClassifier;
67use crate::fnet::FNetForSequenceClassification;
68use crate::longformer::LongformerForSequenceClassification;
69use crate::mobilebert::MobileBertForSequenceClassification;
70use crate::pipelines::common::{
71    cast_var_store, get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
72};
73use crate::reformer::ReformerForSequenceClassification;
74use crate::resources::ResourceProvider;
75use crate::roberta::RobertaForSequenceClassification;
76use crate::xlnet::XLNetForSequenceClassification;
77use serde::{Deserialize, Serialize};
78use std::collections::HashMap;
79use tch::nn::VarStore;
80use tch::{no_grad, Device, Kind, Tensor};
81
82use crate::deberta_v2::DebertaV2ForSequenceClassification;
83#[cfg(feature = "onnx")]
84use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
85#[cfg(feature = "remote")]
86use crate::{
87    distilbert::{DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources},
88    resources::RemoteResource,
89};
90
91#[derive(Debug, Serialize, Deserialize, Clone)]
92/// # Label generated by a `SequenceClassificationModel`
93pub struct Label {
94    /// Label String representation
95    pub text: String,
96    /// Confidence score
97    pub score: f64,
98    /// Label ID
99    pub id: i64,
100    /// Sentence index
101    #[serde(default)]
102    pub sentence: usize,
103}
104
105/// # Configuration for SequenceClassificationModel
106/// Contains information regarding the model to load and device to place the model on.
107pub struct SequenceClassificationConfig {
108    /// Model type
109    pub model_type: ModelType,
110    /// Model weights resource (default: pretrained BERT model on CoNLL)
111    pub model_resource: ModelResource,
112    /// Config resource (default: pretrained BERT model on CoNLL)
113    pub config_resource: Box<dyn ResourceProvider + Send>,
114    /// Vocab resource (default: pretrained BERT model on CoNLL)
115    pub vocab_resource: Box<dyn ResourceProvider + Send>,
116    /// Merges resource (default: None)
117    pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
118    /// Automatically lower case all input upon tokenization (assumes a lower-cased model)
119    pub lower_case: bool,
120    /// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models
121    pub strip_accents: Option<bool>,
122    /// Flag indicating if the tokenizer should add a white space before each tokenized input (needed for some Roberta models)
123    pub add_prefix_space: Option<bool>,
124    /// Device to place the model on (default: CUDA/GPU when available)
125    pub device: Device,
126    /// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
127    pub kind: Option<Kind>,
128}
129
130impl SequenceClassificationConfig {
131    /// Instantiate a new sequence classification configuration of the supplied type.
132    ///
133    /// # Arguments
134    ///
135    /// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
136    /// * model - The `ResourceProvider` pointing to the model to load (e.g.  model.ot)
137    /// * config - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
138    /// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g.  vocab.txt/vocab.json)
139    /// * vocab - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g.  merges.txt), needed only for Roberta.
140    /// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
141    pub fn new<RC, RV>(
142        model_type: ModelType,
143        model_resource: ModelResource,
144        config_resource: RC,
145        vocab_resource: RV,
146        merges_resource: Option<RV>,
147        lower_case: bool,
148        strip_accents: impl Into<Option<bool>>,
149        add_prefix_space: impl Into<Option<bool>>,
150    ) -> SequenceClassificationConfig
151    where
152        RC: ResourceProvider + Send + 'static,
153        RV: ResourceProvider + Send + 'static,
154    {
155        SequenceClassificationConfig {
156            model_type,
157            model_resource,
158            config_resource: Box::new(config_resource),
159            vocab_resource: Box::new(vocab_resource),
160            merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
161            lower_case,
162            strip_accents: strip_accents.into(),
163            add_prefix_space: add_prefix_space.into(),
164            device: Device::cuda_if_available(),
165            kind: None,
166        }
167    }
168}
169
#[cfg(feature = "remote")]
impl Default for SequenceClassificationConfig {
    /// Provides a default SST-2 sentiment analysis model (English).
    ///
    /// Uses the remote DistilBERT SST-2 weights, configuration and vocabulary,
    /// with lower-casing enabled and no merges file.
    fn default() -> Self {
        let weights = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
            DistilBertModelResources::DISTIL_BERT_SST2,
        )));
        let model_config =
            RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2);
        let vocabulary =
            RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2);

        Self::new(
            ModelType::DistilBert,
            weights,
            model_config,
            vocabulary,
            None, // merges
            true, // lower_case
            None, // strip_accents
            None, // add_prefix_space
        )
    }
}
188
189#[allow(clippy::large_enum_variant)]
190/// # Abstraction that holds one particular sequence classification model, for any of the supported models
191pub enum SequenceClassificationOption {
192    /// Bert for Sequence Classification
193    Bert(BertForSequenceClassification),
194    /// DeBERTa for Sequence Classification
195    Deberta(DebertaForSequenceClassification),
196    /// DeBERTa V2 for Sequence Classification
197    DebertaV2(DebertaV2ForSequenceClassification),
198    /// DistilBert for Sequence Classification
199    DistilBert(DistilBertModelClassifier),
200    /// MobileBert for Sequence Classification
201    MobileBert(MobileBertForSequenceClassification),
202    /// Roberta for Sequence Classification
203    Roberta(RobertaForSequenceClassification),
204    /// XLMRoberta for Sequence Classification
205    XLMRoberta(RobertaForSequenceClassification),
206    /// Albert for Sequence Classification
207    Albert(AlbertForSequenceClassification),
208    /// XLNet for Sequence Classification
209    XLNet(XLNetForSequenceClassification),
210    /// Bart for Sequence Classification
211    Bart(BartForSequenceClassification),
212    /// Reformer for Sequence Classification
213    Reformer(ReformerForSequenceClassification),
214    /// Longformer for Sequence Classification
215    Longformer(LongformerForSequenceClassification),
216    /// FNet for Sequence Classification
217    FNet(FNetForSequenceClassification),
218    /// ONNX Model for Sequence Classification
219    #[cfg(feature = "onnx")]
220    ONNX(ONNXEncoder),
221}
222
223impl SequenceClassificationOption {
224    /// Instantiate a new sequence classification model of the supplied type.
225    ///
226    /// # Arguments
227    ///
228    /// * `SequenceClassificationConfig` - Sequence classification pipeline configuration. The type of model created will be inferred from the
229    ///     `ModelResources` (Torch or ONNX) and `ModelType` (Architecture for Torch models) variants provided and
230    pub fn new(config: &SequenceClassificationConfig) -> Result<Self, RustBertError> {
231        match config.model_resource {
232            ModelResource::Torch(_) => Self::new_torch(config),
233            #[cfg(feature = "onnx")]
234            ModelResource::ONNX(_) => Self::new_onnx(config),
235        }
236    }
237
238    fn new_torch(config: &SequenceClassificationConfig) -> Result<Self, RustBertError> {
239        let device = config.device;
240        let weights_path = config.model_resource.get_torch_local_path()?;
241        let mut var_store = VarStore::new(device);
242        let model_config =
243            &ConfigOption::from_file(config.model_type, config.config_resource.get_local_path()?);
244        let model_type = config.model_type;
245        let model = match model_type {
246            ModelType::Bert => {
247                if let ConfigOption::Bert(config) = model_config {
248                    Ok(Self::Bert(
249                        BertForSequenceClassification::new(var_store.root(), config)?,
250                    ))
251                } else {
252                    Err(RustBertError::InvalidConfigurationError(
253                        "You can only supply a BertConfig for Bert!".to_string(),
254                    ))
255                }
256            }
257            ModelType::Deberta => {
258                if let ConfigOption::Deberta(config) = model_config {
259                    Ok(Self::Deberta(
260                        DebertaForSequenceClassification::new(var_store.root(), config)?,
261                    ))
262                } else {
263                    Err(RustBertError::InvalidConfigurationError(
264                        "You can only supply a DebertaConfig for DeBERTa!".to_string(),
265                    ))
266                }
267            }
268            ModelType::DebertaV2 => {
269                if let ConfigOption::DebertaV2(config) = model_config {
270                    Ok(Self::DebertaV2(
271                        DebertaV2ForSequenceClassification::new(var_store.root(), config)?,
272                    ))
273                } else {
274                    Err(RustBertError::InvalidConfigurationError(
275                        "You can only supply a DebertaV2Config for DeBERTa V2!".to_string(),
276                    ))
277                }
278            }
279            ModelType::DistilBert => {
280                if let ConfigOption::DistilBert(config) = model_config {
281                    Ok(Self::DistilBert(
282                        DistilBertModelClassifier::new(var_store.root(), config)?,
283                    ))
284                } else {
285                    Err(RustBertError::InvalidConfigurationError(
286                        "You can only supply a DistilBertConfig for DistilBert!".to_string(),
287                    ))
288                }
289            }
290            ModelType::MobileBert => {
291                if let ConfigOption::MobileBert(config) = model_config {
292                    Ok(Self::MobileBert(
293                        MobileBertForSequenceClassification::new(var_store.root(), config)?,
294                    ))
295                } else {
296                    Err(RustBertError::InvalidConfigurationError(
297                        "You can only supply a MobileBertConfig for MobileBert!".to_string(),
298                    ))
299                }
300            }
301            ModelType::Roberta => {
302                if let ConfigOption::Roberta(config) = model_config {
303                    Ok(Self::Roberta(
304                        RobertaForSequenceClassification::new(var_store.root(), config)?,
305                    ))
306                } else {
307                    Err(RustBertError::InvalidConfigurationError(
308                        "You can only supply a RobertaConfig for Roberta!".to_string(),
309                    ))
310                }
311            }
312            ModelType::XLMRoberta => {
313                if let ConfigOption::Roberta(config) = model_config {
314                    Ok(Self::XLMRoberta(
315                        RobertaForSequenceClassification::new(var_store.root(), config)?,
316                    ))
317                } else {
318                    Err(RustBertError::InvalidConfigurationError(
319                        "You can only supply a RobertaConfig for Roberta!".to_string(),
320                    ))
321                }
322            }
323            ModelType::Albert => {
324                if let ConfigOption::Albert(config) = model_config {
325                    Ok(Self::Albert(
326                        AlbertForSequenceClassification::new(var_store.root(), config)?,
327                    ))
328                } else {
329                    Err(RustBertError::InvalidConfigurationError(
330                        "You can only supply an AlbertConfig for Albert!".to_string(),
331                    ))
332                }
333            }
334            ModelType::XLNet => {
335                if let ConfigOption::XLNet(config) = model_config {
336                    Ok(Self::XLNet(
337                        XLNetForSequenceClassification::new(var_store.root(), config)?,
338                    ))
339                } else {
340                    Err(RustBertError::InvalidConfigurationError(
341                        "You can only supply an XLNetConfig for XLNet!".to_string(),
342                    ))
343                }
344            }
345            ModelType::Bart => {
346                if let ConfigOption::Bart(config) = model_config {
347                    Ok(Self::Bart(
348                        BartForSequenceClassification::new(var_store.root(), config)?,
349                    ))
350                } else {
351                    Err(RustBertError::InvalidConfigurationError(
352                        "You can only supply a BertConfig for Bert!".to_string(),
353                    ))
354                }
355            }
356            ModelType::Reformer => {
357                if let ConfigOption::Reformer(config) = model_config {
358                    Ok(Self::Reformer(
359                        ReformerForSequenceClassification::new(var_store.root(), config)?,
360                    ))
361                } else {
362                    Err(RustBertError::InvalidConfigurationError(
363                        "You can only supply a ReformerConfig for Reformer!".to_string(),
364                    ))
365                }
366            }
367            ModelType::Longformer => {
368                if let ConfigOption::Longformer(config) = model_config {
369                    Ok(Self::Longformer(
370                        LongformerForSequenceClassification::new(var_store.root(), config)?,
371                    ))
372                } else {
373                    Err(RustBertError::InvalidConfigurationError(
374                        "You can only supply a LongformerConfig for Longformer!".to_string(),
375                    ))
376                }
377            }
378            ModelType::FNet => {
379                if let ConfigOption::FNet(config) = model_config {
380                    Ok(Self::FNet(
381                        FNetForSequenceClassification::new(var_store.root(), config)?,
382                    ))
383                } else {
384                    Err(RustBertError::InvalidConfigurationError(
385                        "You can only supply a FNetConfig for FNet!".to_string(),
386                    ))
387                }
388            }
389            #[cfg(feature = "onnx")]
390            ModelType::ONNX => Err(RustBertError::InvalidConfigurationError(
391                "A `ModelType::ONNX` ModelType was provided in the configuration with `ModelResources::TORCH`, these are incompatible".to_string(),
392            )),
393            _ => Err(RustBertError::InvalidConfigurationError(format!(
394                "Sequence Classification not implemented for {model_type:?}!",
395            ))),
396        }?;
397        var_store.load(weights_path)?;
398        cast_var_store(&mut var_store, config.kind, device);
399        Ok(model)
400    }
401
402    #[cfg(feature = "onnx")]
403    pub fn new_onnx(config: &SequenceClassificationConfig) -> Result<Self, RustBertError> {
404        let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
405        let environment = onnx_config.get_environment()?;
406        let encoder_file = config
407            .model_resource
408            .get_onnx_local_paths()?
409            .encoder_path
410            .ok_or(RustBertError::InvalidConfigurationError(
411                "An encoder file must be provided for sequence classification ONNX models."
412                    .to_string(),
413            ))?;
414
415        Ok(Self::ONNX(ONNXEncoder::new(
416            encoder_file,
417            &environment,
418            &onnx_config,
419        )?))
420    }
421
422    /// Returns the `ModelType` for this SequenceClassificationOption
423    pub fn model_type(&self) -> ModelType {
424        match *self {
425            Self::Bert(_) => ModelType::Bert,
426            Self::Deberta(_) => ModelType::Deberta,
427            Self::DebertaV2(_) => ModelType::DebertaV2,
428            Self::Roberta(_) => ModelType::Roberta,
429            Self::XLMRoberta(_) => ModelType::Roberta,
430            Self::DistilBert(_) => ModelType::DistilBert,
431            Self::MobileBert(_) => ModelType::MobileBert,
432            Self::Albert(_) => ModelType::Albert,
433            Self::XLNet(_) => ModelType::XLNet,
434            Self::Bart(_) => ModelType::Bart,
435            Self::Reformer(_) => ModelType::Reformer,
436            Self::Longformer(_) => ModelType::Longformer,
437            Self::FNet(_) => ModelType::FNet,
438            #[cfg(feature = "onnx")]
439            Self::ONNX(_) => ModelType::ONNX,
440        }
441    }
442
443    /// Interface method to forward_t() of the particular models.
444    pub fn forward_t(
445        &self,
446        input_ids: Option<&Tensor>,
447        mask: Option<&Tensor>,
448        token_type_ids: Option<&Tensor>,
449        position_ids: Option<&Tensor>,
450        input_embeds: Option<&Tensor>,
451        train: bool,
452    ) -> Tensor {
453        match *self {
454            Self::Bart(ref model) => {
455                model
456                    .forward_t(
457                        input_ids.expect("`input_ids` must be provided for BART models"),
458                        mask,
459                        None,
460                        None,
461                        None,
462                        train,
463                    )
464                    .decoder_output
465            }
466            Self::Bert(ref model) => {
467                model
468                    .forward_t(
469                        input_ids,
470                        mask,
471                        token_type_ids,
472                        position_ids,
473                        input_embeds,
474                        train,
475                    )
476                    .logits
477            }
478            Self::Deberta(ref model) => {
479                model
480                    .forward_t(
481                        input_ids,
482                        mask,
483                        token_type_ids,
484                        position_ids,
485                        input_embeds,
486                        train,
487                    )
488                    .expect("Error in Deberta forward_t")
489                    .logits
490            }
491            Self::DebertaV2(ref model) => {
492                model
493                    .forward_t(
494                        input_ids,
495                        mask,
496                        token_type_ids,
497                        position_ids,
498                        input_embeds,
499                        train,
500                    )
501                    .expect("Error in Deberta V2 forward_t")
502                    .logits
503            }
504            Self::DistilBert(ref model) => {
505                model
506                    .forward_t(input_ids, mask, input_embeds, train)
507                    .expect("Error in distilbert forward_t")
508                    .logits
509            }
510            Self::MobileBert(ref model) => {
511                model
512                    .forward_t(input_ids, None, None, input_embeds, mask, train)
513                    .expect("Error in mobilebert forward_t")
514                    .logits
515            }
516            Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
517                model
518                    .forward_t(
519                        input_ids,
520                        mask,
521                        token_type_ids,
522                        position_ids,
523                        input_embeds,
524                        train,
525                    )
526                    .logits
527            }
528            Self::Albert(ref model) => {
529                model
530                    .forward_t(
531                        input_ids,
532                        mask,
533                        token_type_ids,
534                        position_ids,
535                        input_embeds,
536                        train,
537                    )
538                    .logits
539            }
540            Self::XLNet(ref model) => {
541                model
542                    .forward_t(
543                        input_ids,
544                        mask,
545                        None,
546                        None,
547                        None,
548                        token_type_ids,
549                        input_embeds,
550                        train,
551                    )
552                    .logits
553            }
554            Self::Reformer(ref model) => {
555                model
556                    .forward_t(input_ids, None, None, mask, None, train)
557                    .expect("Error in Reformer forward pass.")
558                    .logits
559            }
560            Self::Longformer(ref model) => {
561                model
562                    .forward_t(
563                        input_ids,
564                        mask,
565                        None,
566                        token_type_ids,
567                        position_ids,
568                        input_embeds,
569                        train,
570                    )
571                    .expect("Error in Longformer forward pass.")
572                    .logits
573            }
574            Self::FNet(ref model) => {
575                model
576                    .forward_t(input_ids, token_type_ids, position_ids, input_embeds, train)
577                    .expect("Error in FNet forward pass.")
578                    .logits
579            }
580            #[cfg(feature = "onnx")]
581            Self::ONNX(ref model) => {
582                let attention_mask = input_ids.unwrap().ones_like();
583                model
584                    .forward(
585                        input_ids,
586                        Some(&attention_mask),
587                        token_type_ids,
588                        position_ids,
589                        input_embeds,
590                    )
591                    .expect("Error in ONNX forward pass.")
592                    .logits
593                    .unwrap()
594            }
595        }
596    }
597}
598
599/// # SequenceClassificationModel for Classification (e.g. Sentiment Analysis)
600pub struct SequenceClassificationModel {
601    tokenizer: TokenizerOption,
602    sequence_classifier: SequenceClassificationOption,
603    label_mapping: HashMap<i64, String>,
604    device: Device,
605    max_length: usize,
606}
607
608impl SequenceClassificationModel {
609    /// Build a new `SequenceClassificationModel`
610    ///
611    /// # Arguments
612    ///
613    /// * `config` - `SequenceClassificationConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
614    ///
615    /// # Example
616    ///
617    /// ```no_run
618    /// # fn main() -> anyhow::Result<()> {
619    /// use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
620    ///
621    /// let model = SequenceClassificationModel::new(Default::default())?;
622    /// # Ok(())
623    /// # }
624    /// ```
625    pub fn new(
626        config: SequenceClassificationConfig,
627    ) -> Result<SequenceClassificationModel, RustBertError> {
628        let vocab_path = config.vocab_resource.get_local_path()?;
629        let merges_path = config
630            .merges_resource
631            .as_ref()
632            .map(|resource| resource.get_local_path())
633            .transpose()?;
634
635        let tokenizer = TokenizerOption::from_file(
636            config.model_type,
637            vocab_path.to_str().unwrap(),
638            merges_path.as_deref().map(|path| path.to_str().unwrap()),
639            config.lower_case,
640            config.strip_accents,
641            config.add_prefix_space,
642        )?;
643        Self::new_with_tokenizer(config, tokenizer)
644    }
645
646    /// Build a new `SequenceClassificationModel` with a provided tokenizer.
647    ///
648    /// # Arguments
649    ///
650    /// * `config` - `SequenceClassificationConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
651    /// * `tokenizer` - `TokenizerOption` tokenizer to use for sequence classification.
652    ///
653    /// # Example
654    ///
655    /// ```no_run
656    /// # fn main() -> anyhow::Result<()> {
657    /// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
658    /// use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
659    /// let tokenizer = TokenizerOption::from_file(
660    ///     ModelType::Bert,
661    ///     "path/to/vocab.txt",
662    ///     None,
663    ///     false,
664    ///     None,
665    ///     None,
666    /// )?;
667    /// let model = SequenceClassificationModel::new_with_tokenizer(Default::default(), tokenizer)?;
668    /// # Ok(())
669    /// # }
670    /// ```
671    pub fn new_with_tokenizer(
672        config: SequenceClassificationConfig,
673        tokenizer: TokenizerOption,
674    ) -> Result<SequenceClassificationModel, RustBertError> {
675        let config_path = config.config_resource.get_local_path()?;
676        let sequence_classifier = SequenceClassificationOption::new(&config)?;
677
678        let model_config = ConfigOption::from_file(config.model_type, config_path);
679        let max_length = model_config
680            .get_max_len()
681            .map(|v| v as usize)
682            .unwrap_or(usize::MAX);
683        let label_mapping = model_config.get_label_mapping().clone();
684        let device = get_device(config.model_resource, config.device);
685        Ok(SequenceClassificationModel {
686            tokenizer,
687            sequence_classifier,
688            label_mapping,
689            device,
690            max_length,
691        })
692    }
693
694    /// Get a reference to the model tokenizer.
695    pub fn get_tokenizer(&self) -> &TokenizerOption {
696        &self.tokenizer
697    }
698
699    /// Get a mutable reference to the model tokenizer.
700    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
701        &mut self.tokenizer
702    }
703    /// Classify texts
704    ///
705    /// # Arguments
706    ///
707    /// * `input` - `&[&str]` Array of texts to classify.
708    ///
709    /// # Returns
710    ///
711    /// * `Vec<Label>` containing labels for input texts
712    ///
713    /// # Example
714    ///
715    /// ```no_run
716    /// # fn main() -> anyhow::Result<()> {
717    /// # use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
718    ///
719    /// let sequence_classification_model =  SequenceClassificationModel::new(Default::default())?;
720    /// let input = [
721    ///     "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
722    ///     "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
723    ///     "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
724    /// ];
725    /// let output = sequence_classification_model.predict(&input);
726    /// # Ok(())
727    /// # }
728    /// ```
729    pub fn predict<'a, S>(&self, input: S) -> Vec<Label>
730    where
731        S: AsRef<[&'a str]>,
732    {
733        let (input_ids, token_type_ids) =
734            self.tokenizer
735                .tokenize_and_pad(input.as_ref(), self.max_length, self.device);
736        let output = no_grad(|| {
737            let output = self.sequence_classifier.forward_t(
738                Some(&input_ids),
739                None,
740                Some(&token_type_ids),
741                None,
742                None,
743                false,
744            );
745            output.softmax(-1, Kind::Float).detach().to(Device::Cpu)
746        });
747        let label_indices = output.as_ref().argmax(-1, true).squeeze_dim(1);
748        let scores = output
749            .gather(1, &label_indices.unsqueeze(-1), false)
750            .squeeze_dim(1);
751        let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
752        let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();
753
754        let mut labels: Vec<Label> = vec![];
755        for sentence_idx in 0..label_indices.len() {
756            let label_string = self
757                .label_mapping
758                .get(&label_indices[sentence_idx])
759                .unwrap()
760                .clone();
761            let label = Label {
762                text: label_string,
763                score: scores[sentence_idx],
764                id: label_indices[sentence_idx],
765                sentence: sentence_idx,
766            };
767            labels.push(label)
768        }
769        labels
770    }
771
772    /// Multi-label classification of texts
773    ///
774    /// # Arguments
775    ///
776    /// * `input` - `&[&str]` Array of texts to classify.
777    /// * `threshold` - `f64` threshold above which a label will be considered true by the classifier
778    ///
779    /// # Returns
780    ///
781    /// * `Vec<Vec<Label>>` containing a vector of true labels for each input text
782    ///
783    /// # Example
784    ///
785    /// ```no_run
786    /// # fn main() -> anyhow::Result<()> {
787    /// # use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
788    ///
789    /// let sequence_classification_model =  SequenceClassificationModel::new(Default::default())?;
790    /// let input = [
791    ///     "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
792    ///     "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
793    ///     "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
794    /// ];
795    /// let output = sequence_classification_model.predict_multilabel(&input, 0.5);
796    /// # Ok(())
797    /// # }
798    /// ```
799    pub fn predict_multilabel(
800        &self,
801        input: &[&str],
802        threshold: f64,
803    ) -> Result<Vec<Vec<Label>>, RustBertError> {
804        let (input_ids, token_type_ids) =
805            self.tokenizer
806                .tokenize_and_pad(input.as_ref(), self.max_length, self.device);
807        let output = no_grad(|| {
808            let output = self.sequence_classifier.forward_t(
809                Some(&input_ids),
810                None,
811                Some(&token_type_ids),
812                None,
813                None,
814                false,
815            );
816            output.sigmoid().detach().to(Device::Cpu)
817        });
818        let label_indices = output.as_ref().ge(threshold).nonzero();
819
820        let mut labels: Vec<Vec<Label>> = vec![];
821        let mut sequence_labels: Vec<Label> = vec![];
822
823        for sentence_idx in 0..label_indices.size()[0] {
824            let label_index_tensor = label_indices.get(sentence_idx);
825            let sentence_label = label_index_tensor
826                .iter::<i64>()
827                .unwrap()
828                .collect::<Vec<i64>>();
829            let (sentence, id) = (sentence_label[0], sentence_label[1]);
830            if sentence as usize > labels.len() {
831                labels.push(sequence_labels);
832                sequence_labels = vec![];
833            }
834            let score = output.double_value(sentence_label.as_slice());
835            let label_string = self.label_mapping.get(&id).unwrap().to_owned();
836            let label = Label {
837                text: label_string,
838                score,
839                id,
840                sentence: sentence as usize,
841            };
842            sequence_labels.push(label);
843        }
844        if !sequence_labels.is_empty() {
845            labels.push(sequence_labels);
846        }
847        Ok(labels)
848    }
849}
850
#[cfg(test)]
mod test {
    use super::*;

    /// Compile-time check that the value returned by `SequenceClassificationModel::new`
    /// can be boxed as `dyn Send`.
    #[test]
    #[ignore] // no need to run, compilation is enough to verify it is Send
    fn test() {
        let model = SequenceClassificationModel::new(SequenceClassificationConfig::default());
        let _: Box<dyn Send> = Box::new(model);
    }
}