rust_bert/pipelines/
token_classification.rs

1// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
2// Copyright 2019-2020 Guillaume Becquin
3// Copyright 2020 Maarten van Gompel
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//     http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//! # Token classification pipeline (Named Entity Recognition, Part-of-Speech tagging)
//! Generic token classification pipeline supporting multiple architectures (e.g. Bert, Roberta, DistilBert, Electra, ...)
16//!
17//! ```no_run
18//! use rust_bert::pipelines::token_classification::{TokenClassificationModel,TokenClassificationConfig};
19//! use rust_bert::resources::RemoteResource;
20//! use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
21//! use rust_bert::pipelines::common::ModelType;
22//! # fn main() -> anyhow::Result<()> {
23//!
24//! use rust_bert::pipelines::common::ModelResource;
25//! //Load a configuration
26//! use rust_bert::pipelines::token_classification::LabelAggregationOption;
27//! let config = TokenClassificationConfig::new(
28//!    ModelType::Bert,
29//!    ModelResource::Torch(Box::new(RemoteResource::from_pretrained(BertModelResources::BERT_NER))),
//!    RemoteResource::from_pretrained(BertConfigResources::BERT_NER),
//!    RemoteResource::from_pretrained(BertVocabResources::BERT_NER),
32//!    None, //merges resource only relevant with ModelType::Roberta
33//!    false, //lowercase
34//!    None, //strip_accents
35//!    None, //add_prefix_space
36//!    LabelAggregationOption::Mode
37//! );
38//!
39//! //Create the model
40//! let token_classification_model = TokenClassificationModel::new(config)?;
41//!
42//! let input = [
//!     "My name is Amélie. I live in Paris.",
44//!     "Paris is a city in France."
45//! ];
46//! let output = token_classification_model.predict(&input, true, true); //ignore_first_label = true (only returns the NER parts, ignoring first label O)
47//! # Ok(())
48//! # }
49//! ```
50//! Output: \
51//! ```no_run
52//! # use rust_bert::pipelines::token_classification::Token;
53//! use rust_tokenizers::{Mask, Offset};
54//! # let output =
55//! [
56//!     Token {
57//!         text: String::from("[CLS]"),
58//!         score: 0.9995001554489136,
59//!         label: String::from("O"),
60//!         label_index: 0,
61//!         sentence: 0,
62//!         index: 0,
63//!         word_index: 0,
64//!         offset: None,
65//!         mask: Mask::Special,
66//!     },
67//!     Token {
68//!         text: String::from("My"),
69//!         score: 0.9980450868606567,
70//!         label: String::from("O"),
71//!         label_index: 0,
72//!         sentence: 0,
73//!         index: 1,
74//!         word_index: 1,
75//!         offset: Some(Offset { begin: 0, end: 2 }),
76//!         mask: Mask::None,
77//!     },
78//!     Token {
79//!         text: String::from("name"),
80//!         score: 0.9995062351226807,
81//!         label: String::from("O"),
82//!         label_index: 0,
83//!         sentence: 0,
84//!         index: 2,
85//!         word_index: 2,
86//!         offset: Some(Offset { begin: 3, end: 7 }),
87//!         mask: Mask::None,
88//!     },
89//!     Token {
90//!         text: String::from("is"),
91//!         score: 0.9997343420982361,
92//!         label: String::from("O"),
93//!         label_index: 0,
94//!         sentence: 0,
95//!         index: 3,
96//!         word_index: 3,
97//!         offset: Some(Offset { begin: 8, end: 10 }),
98//!         mask: Mask::None,
99//!     },
100//!     Token {
101//!         text: String::from("Amélie"),
102//!         score: 0.9913727683112525,
103//!         label: String::from("I-PER"),
104//!         label_index: 4,
105//!         sentence: 0,
106//!         index: 4,
107//!         word_index: 4,
108//!         offset: Some(Offset { begin: 11, end: 17 }),
109//!         mask: Mask::None,
110//!     }, // ...
111//! ]
112//! # ;
113//! ```
114
115use crate::albert::AlbertForTokenClassification;
116use crate::bert::BertForTokenClassification;
117use crate::common::error::RustBertError;
118use crate::deberta::DebertaForTokenClassification;
119use crate::distilbert::DistilBertForTokenClassification;
120use crate::electra::ElectraForTokenClassification;
121use crate::fnet::FNetForTokenClassification;
122use crate::longformer::LongformerForTokenClassification;
123use crate::mobilebert::MobileBertForTokenClassification;
124use crate::pipelines::common::{
125    cast_var_store, get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
126};
127use crate::resources::ResourceProvider;
128use crate::roberta::RobertaForTokenClassification;
129use crate::xlnet::XLNetForTokenClassification;
130use ordered_float::OrderedFloat;
131use rust_tokenizers::{
132    ConsolidatableTokens, ConsolidatedTokenIterator, Mask, Offset, TokenIdsWithOffsets, TokenTrait,
133    TokenizedInput,
134};
135use serde::{Deserialize, Serialize};
136use std::cmp::min;
137use std::collections::HashMap;
138use tch::nn::VarStore;
139use tch::{no_grad, Device, Kind, Tensor};
140
141use crate::deberta_v2::DebertaV2ForTokenClassification;
142#[cfg(feature = "onnx")]
143use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
144#[cfg(feature = "remote")]
145use crate::{
146    bert::{BertConfigResources, BertModelResources, BertVocabResources},
147    resources::RemoteResource,
148};
149
#[derive(Debug, Clone, Serialize, Deserialize)]
/// # Token generated by a `TokenClassificationModel`
///
/// One classified token, carrying its predicted label together with its position in the
/// input and its span in the original text.
pub struct Token {
    /// String representation of the Token
    pub text: String,
    /// Confidence score associated with the predicted `label`
    pub score: f64,
    /// Token label (e.g. ORG, LOC in case of NER)
    pub label: String,
    /// Index of `label` in the model label mapping
    pub label_index: i64,
    /// Index of the input sentence this token belongs to
    pub sentence: usize,
    /// Token position index within its sentence
    pub index: u16,
    /// Position index of the word this token belongs to
    pub word_index: u16,
    /// Span of the token in the original input (`None` for special tokens such as `[CLS]`)
    pub offset: Option<Offset>,
    /// Token category (e.g. `Mask::Special` for special tokens, `Mask::None` for regular ones)
    pub mask: Mask,
}
172
173impl TokenTrait for Token {
174    fn offset(&self) -> Option<Offset> {
175        self.offset
176    }
177
178    fn mask(&self) -> Mask {
179        self.mask
180    }
181
182    fn as_str(&self) -> &str {
183        self.text.as_str()
184    }
185}
186
187impl ConsolidatableTokens<Token> for Vec<Token> {
188    fn iter_consolidate_tokens(&self) -> ConsolidatedTokenIterator<Token> {
189        ConsolidatedTokenIterator::new(self)
190    }
191}
192
#[derive(Debug)]
/// Encoded model input built from (part of) one input example, together with per-token
/// metadata used to map predictions back to the original text.
///
/// NOTE(review): the `Vec` fields are presumably parallel (one entry per token) — confirm
/// against `generate_features`.
struct InputFeature {
    /// Encoded input ids
    input_ids: Vec<i64>,
    /// Offsets referencing the original string (one optional span per token)
    offsets: Vec<Option<Offset>>,
    /// Token category (mask) for each token
    mask: Vec<Mask>,
    /// Token type ids
    token_type_ids: Vec<i64>,
    /// per-token flag indicating if this feature carries the output label for this token
    /// (presumably used to pick one prediction when several features cover the same token)
    reference_feature: Vec<bool>,
    /// Reference example index (long inputs may be broken into multiple input features)
    example_index: usize,
}
208
209type LabelAggregationFunction = Box<fn(&[Token]) -> (i64, String)>;
210
211/// # Enum defining the label aggregation method for sub tokens
212/// Defines the behaviour for labels aggregation if the consolidation of sub-tokens is enabled.
213pub enum LabelAggregationOption {
214    /// The label of the first sub token is assigned to the entire token
215    First,
216    /// The label of the last sub token is assigned to the entire token
217    Last,
218    /// The most frequent sub- token is  assigned to the entire token
219    Mode,
220    /// The user can provide a function mapping a `&Vec<Token>` to a `(i64, String)` tuple corresponding to the label index, label String to return
221    Custom(LabelAggregationFunction),
222}
223
224/// # Configuration for TokenClassificationModel
225/// Contains information regarding the model to load and device to place the model on.
226pub struct TokenClassificationConfig {
227    /// Model type
228    pub model_type: ModelType,
229    /// Model weights resource (default: pretrained BERT model on CoNLL)
230    pub model_resource: ModelResource,
231    /// Config resource (default: pretrained BERT model on CoNLL)
232    pub config_resource: Box<dyn ResourceProvider + Send>,
233    /// Vocab resource (default: pretrained BERT model on CoNLL)
234    pub vocab_resource: Box<dyn ResourceProvider + Send>,
235    /// Merges resource (default: pretrained BERT model on CoNLL)
236    pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
237    /// Automatically lower case all input upon tokenization (assumes a lower-cased model)
238    pub lower_case: bool,
239    /// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models
240    pub strip_accents: Option<bool>,
241    /// Flag indicating if the tokenizer should add a white space before each tokenized input (needed for some Roberta models)
242    pub add_prefix_space: Option<bool>,
243    /// Device to place the model on (default: CUDA/GPU when available)
244    pub device: Device,
245    /// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
246    pub kind: Option<Kind>,
247    /// Sub-tokens aggregation method (default: `LabelAggregationOption::First`)
248    pub label_aggregation_function: LabelAggregationOption,
249    /// Batch size for predictions
250    pub batch_size: usize,
251}
252
253impl TokenClassificationConfig {
254    /// Instantiate a new token classification configuration of the supplied type.
255    ///
256    /// # Arguments
257    ///
258    /// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
259    /// * model - The `ResourceProvider` pointing to the model to load (e.g.  model.ot)
260    /// * config - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
261    /// * vocab - The `ResourceProvider` pointing to the tokenizers' vocabulary to load (e.g.  vocab.txt/vocab.json)
262    /// * vocab - An optional `ResourceProvider` pointing to the tokenizers' merge file to load (e.g.  merges.txt), needed only for Roberta.
263    /// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
264    pub fn new<RC, RV>(
265        model_type: ModelType,
266        model_resource: ModelResource,
267        config_resource: RC,
268        vocab_resource: RV,
269        merges_resource: Option<RV>,
270        lower_case: bool,
271        strip_accents: impl Into<Option<bool>>,
272        add_prefix_space: impl Into<Option<bool>>,
273        label_aggregation_function: LabelAggregationOption,
274    ) -> TokenClassificationConfig
275    where
276        RC: ResourceProvider + Send + 'static,
277        RV: ResourceProvider + Send + 'static,
278    {
279        TokenClassificationConfig {
280            model_type,
281            model_resource,
282            config_resource: Box::new(config_resource),
283            vocab_resource: Box::new(vocab_resource),
284            merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
285            lower_case,
286            strip_accents: strip_accents.into(),
287            add_prefix_space: add_prefix_space.into(),
288            device: Device::cuda_if_available(),
289            kind: None,
290            label_aggregation_function,
291            batch_size: 64,
292        }
293    }
294}
295
#[cfg(feature = "remote")]
impl Default for TokenClassificationConfig {
    /// Builds a configuration for the pretrained English CoNLL-2003 BERT NER model, fetched
    /// from remote resources, without lower-casing and with
    /// `LabelAggregationOption::First` label aggregation.
    fn default() -> TokenClassificationConfig {
        let weights = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
            BertModelResources::BERT_NER,
        )));
        let model_config = RemoteResource::from_pretrained(BertConfigResources::BERT_NER);
        let vocabulary = RemoteResource::from_pretrained(BertVocabResources::BERT_NER);

        Self::new(
            ModelType::Bert,
            weights,
            model_config,
            vocabulary,
            None,
            false,
            None,
            None,
            LabelAggregationOption::First,
        )
    }
}
315
316#[allow(clippy::large_enum_variant)]
317/// # Abstraction that holds one particular token sequence classifier model, for any of the supported models
318pub enum TokenClassificationOption {
319    /// Bert for Token Classification
320    Bert(BertForTokenClassification),
321    /// DeBERTa for Token Classification
322    Deberta(DebertaForTokenClassification),
323    /// DeBERTa V2 for Token Classification
324    DebertaV2(DebertaV2ForTokenClassification),
325    /// DistilBert for Token Classification
326    DistilBert(DistilBertForTokenClassification),
327    /// MobileBert for Token Classification
328    MobileBert(MobileBertForTokenClassification),
329    /// Roberta for Token Classification
330    Roberta(RobertaForTokenClassification),
331    /// XLM Roberta for Token Classification
332    XLMRoberta(RobertaForTokenClassification),
333    /// Electra for Token Classification
334    Electra(ElectraForTokenClassification),
335    /// Albert for Token Classification
336    Albert(AlbertForTokenClassification),
337    /// XLNet for Token Classification
338    XLNet(XLNetForTokenClassification),
339    /// Longformer for Token Classification
340    Longformer(LongformerForTokenClassification),
341    /// FNet for Token Classification
342    FNet(FNetForTokenClassification),
343    /// ONNX model for Token Classification
344    #[cfg(feature = "onnx")]
345    ONNX(ONNXEncoder),
346}
347
348impl TokenClassificationOption {
349    /// Instantiate a new sequence classification model of the supplied type.
350    ///
351    /// # Arguments
352    ///
353    /// * `TokenClassificationConfig` - Token classification pipeline configuration. The type of model created will be inferred from the
354    ///     `ModelResources` (Torch or ONNX) and `ModelType` (Architecture for Torch models) variants provided and
355    pub fn new(config: &TokenClassificationConfig) -> Result<Self, RustBertError> {
356        match config.model_resource {
357            ModelResource::Torch(_) => Self::new_torch(config),
358            #[cfg(feature = "onnx")]
359            ModelResource::ONNX(_) => Self::new_onnx(config),
360        }
361    }
362
363    fn new_torch(config: &TokenClassificationConfig) -> Result<Self, RustBertError> {
364        let device = config.device;
365        let weights_path = config.model_resource.get_torch_local_path()?;
366        let mut var_store = VarStore::new(device);
367        let model_config =
368            &ConfigOption::from_file(config.model_type, config.config_resource.get_local_path()?);
369        let model_type = config.model_type;
370        let model = match model_type {
371            ModelType::Bert => {
372                if let ConfigOption::Bert(config) = model_config {
373                    Ok(Self::Bert(
374                        BertForTokenClassification::new(var_store.root(), config)?,
375                    ))
376                } else {
377                    Err(RustBertError::InvalidConfigurationError(
378                        "You can only supply a BertConfig for Bert!".to_string(),
379                    ))
380                }
381            }
382            ModelType::Deberta => {
383                if let ConfigOption::Deberta(config) = model_config {
384                    Ok(Self::Deberta(
385                        DebertaForTokenClassification::new(var_store.root(), config)?,
386                    ))
387                } else {
388                    Err(RustBertError::InvalidConfigurationError(
389                        "You can only supply a DebertaConfig for DeBERTa!".to_string(),
390                    ))
391                }
392            }
393            ModelType::DebertaV2 => {
394                if let ConfigOption::DebertaV2(config) = model_config {
395                    Ok(Self::DebertaV2(
396                        DebertaV2ForTokenClassification::new(var_store.root(), config)?,
397                    ))
398                } else {
399                    Err(RustBertError::InvalidConfigurationError(
400                        "You can only supply a DebertaConfig for DeBERTa V2!".to_string(),
401                    ))
402                }
403            }
404            ModelType::DistilBert => {
405                if let ConfigOption::DistilBert(config) = model_config {
406                    Ok(Self::DistilBert(
407                        DistilBertForTokenClassification::new(var_store.root(), config)?,
408                    ))
409                } else {
410                    Err(RustBertError::InvalidConfigurationError(
411                        "You can only supply a DistilBertConfig for DistilBert!".to_string(),
412                    ))
413                }
414            }
415            ModelType::MobileBert => {
416                if let ConfigOption::MobileBert(config) = model_config {
417                    Ok(Self::MobileBert(
418                        MobileBertForTokenClassification::new(var_store.root(), config)?,
419                    ))
420                } else {
421                    Err(RustBertError::InvalidConfigurationError(
422                        "You can only supply a MobileBertConfig for MobileBert!".to_string(),
423                    ))
424                }
425            }
426            ModelType::Roberta => {
427                if let ConfigOption::Roberta(config) = model_config {
428                    Ok(Self::Roberta(
429                        RobertaForTokenClassification::new(var_store.root(), config)?,
430                    ))
431                } else {
432                    Err(RustBertError::InvalidConfigurationError(
433                        "You can only supply a RobertaConfig for Roberta!".to_string(),
434                    ))
435                }
436            }
437            ModelType::XLMRoberta => {
438                if let ConfigOption::Roberta(config) = model_config {
439                    Ok(Self::XLMRoberta(
440                        RobertaForTokenClassification::new(var_store.root(), config)?,
441                    ))
442                } else {
443                    Err(RustBertError::InvalidConfigurationError(
444                        "You can only supply a RobertaConfig for XLMRoberta!".to_string(),
445                    ))
446                }
447            }
448            ModelType::Electra => {
449                if let ConfigOption::Electra(config) = model_config {
450                    Ok(Self::Electra(
451                        ElectraForTokenClassification::new(var_store.root(), config)?,
452                    ))
453                } else {
454                    Err(RustBertError::InvalidConfigurationError(
455                        "You can only supply a BertConfig for Roberta!".to_string(),
456                    ))
457                }
458            }
459            ModelType::Albert => {
460                if let ConfigOption::Albert(config) = model_config {
461                    Ok(Self::Albert(
462                        AlbertForTokenClassification::new(var_store.root(), config)?,
463                    ))
464                } else {
465                    Err(RustBertError::InvalidConfigurationError(
466                        "You can only supply an AlbertConfig for Albert!".to_string(),
467                    ))
468                }
469            }
470            ModelType::XLNet => {
471                if let ConfigOption::XLNet(config) = model_config {
472                    Ok(Self::XLNet(
473                        XLNetForTokenClassification::new(var_store.root(), config)?,
474                    ))
475                } else {
476                    Err(RustBertError::InvalidConfigurationError(
477                        "You can only supply an AlbertConfig for Albert!".to_string(),
478                    ))
479                }
480            }
481            ModelType::Longformer => {
482                if let ConfigOption::Longformer(config) = model_config {
483                    Ok(Self::Longformer(
484                        LongformerForTokenClassification::new(var_store.root(), config)?,
485                    ))
486                } else {
487                    Err(RustBertError::InvalidConfigurationError(
488                        "You can only supply a LongformerConfig for Longformer!".to_string(),
489                    ))
490                }
491            }
492            ModelType::FNet => {
493                if let ConfigOption::FNet(config) = model_config {
494                    Ok(Self::FNet(
495                        FNetForTokenClassification::new(var_store.root(), config)?,
496                    ))
497                } else {
498                    Err(RustBertError::InvalidConfigurationError(
499                        "You can only supply an FNetConfig for FNet!".to_string(),
500                    ))
501                }
502            }
503            #[cfg(feature = "onnx")]
504            ModelType::ONNX => Err(RustBertError::InvalidConfigurationError(
505                "A `ModelType::ONNX` ModelType was provided in the configuration with `ModelResources::TORCH`, these are incompatible".to_string(),
506            )),
507            _ => Err(RustBertError::InvalidConfigurationError(format!(
508                "Token classification not implemented for {model_type:?}!"
509            ))),
510        }?;
511        var_store.load(weights_path)?;
512        cast_var_store(&mut var_store, config.kind, device);
513        Ok(model)
514    }
515
516    #[cfg(feature = "onnx")]
517    pub fn new_onnx(config: &TokenClassificationConfig) -> Result<Self, RustBertError> {
518        let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
519        let environment = onnx_config.get_environment()?;
520        let encoder_file = config
521            .model_resource
522            .get_onnx_local_paths()?
523            .encoder_path
524            .ok_or(RustBertError::InvalidConfigurationError(
525                "An encoder file must be provided for token classification ONNX models."
526                    .to_string(),
527            ))?;
528
529        Ok(Self::ONNX(ONNXEncoder::new(
530            encoder_file,
531            &environment,
532            &onnx_config,
533        )?))
534    }
535
536    /// Returns the `ModelType` for this TokenClassificationOption
537    pub fn model_type(&self) -> ModelType {
538        match *self {
539            Self::Bert(_) => ModelType::Bert,
540            Self::Deberta(_) => ModelType::Deberta,
541            Self::DebertaV2(_) => ModelType::DebertaV2,
542            Self::Roberta(_) => ModelType::Roberta,
543            Self::XLMRoberta(_) => ModelType::XLMRoberta,
544            Self::DistilBert(_) => ModelType::DistilBert,
545            Self::MobileBert(_) => ModelType::MobileBert,
546            Self::Electra(_) => ModelType::Electra,
547            Self::Albert(_) => ModelType::Albert,
548            Self::XLNet(_) => ModelType::XLNet,
549            Self::Longformer(_) => ModelType::Longformer,
550            Self::FNet(_) => ModelType::FNet,
551            #[cfg(feature = "onnx")]
552            Self::ONNX(_) => ModelType::ONNX,
553        }
554    }
555
556    fn forward_t(
557        &self,
558        input_ids: Option<&Tensor>,
559        mask: Option<&Tensor>,
560        token_type_ids: Option<&Tensor>,
561        position_ids: Option<&Tensor>,
562        input_embeds: Option<&Tensor>,
563        train: bool,
564    ) -> Tensor {
565        match *self {
566            Self::Bert(ref model) => {
567                model
568                    .forward_t(
569                        input_ids,
570                        mask,
571                        token_type_ids,
572                        position_ids,
573                        input_embeds,
574                        train,
575                    )
576                    .logits
577            }
578            Self::Deberta(ref model) => {
579                model
580                    .forward_t(
581                        input_ids,
582                        mask,
583                        token_type_ids,
584                        position_ids,
585                        input_embeds,
586                        train,
587                    )
588                    .expect("Error in DeBERTa forward_t")
589                    .logits
590            }
591            Self::DebertaV2(ref model) => {
592                model
593                    .forward_t(
594                        input_ids,
595                        mask,
596                        token_type_ids,
597                        position_ids,
598                        input_embeds,
599                        train,
600                    )
601                    .expect("Error in DeBERTa V2 forward_t")
602                    .logits
603            }
604            Self::DistilBert(ref model) => {
605                model
606                    .forward_t(input_ids, mask, input_embeds, train)
607                    .expect("Error in distilbert forward_t")
608                    .logits
609            }
610            Self::MobileBert(ref model) => {
611                model
612                    .forward_t(input_ids, None, None, input_embeds, mask, train)
613                    .expect("Error in mobilebert forward_t")
614                    .logits
615            }
616            Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
617                model
618                    .forward_t(
619                        input_ids,
620                        mask,
621                        token_type_ids,
622                        position_ids,
623                        input_embeds,
624                        train,
625                    )
626                    .logits
627            }
628            Self::Electra(ref model) => {
629                model
630                    .forward_t(
631                        input_ids,
632                        mask,
633                        token_type_ids,
634                        position_ids,
635                        input_embeds,
636                        train,
637                    )
638                    .logits
639            }
640            Self::Albert(ref model) => {
641                model
642                    .forward_t(
643                        input_ids,
644                        mask,
645                        token_type_ids,
646                        position_ids,
647                        input_embeds,
648                        train,
649                    )
650                    .logits
651            }
652            Self::XLNet(ref model) => {
653                model
654                    .forward_t(
655                        input_ids,
656                        mask,
657                        None,
658                        None,
659                        None,
660                        token_type_ids,
661                        input_embeds,
662                        train,
663                    )
664                    .logits
665            }
666            Self::Longformer(ref model) => {
667                model
668                    .forward_t(
669                        input_ids,
670                        mask,
671                        None,
672                        token_type_ids,
673                        position_ids,
674                        input_embeds,
675                        train,
676                    )
677                    .expect("Error in longformer forward_t")
678                    .logits
679            }
680            Self::FNet(ref model) => {
681                model
682                    .forward_t(input_ids, token_type_ids, position_ids, input_embeds, train)
683                    .expect("Error in fnet forward_t")
684                    .logits
685            }
686            #[cfg(feature = "onnx")]
687            Self::ONNX(ref model) => model
688                .forward(input_ids, mask, token_type_ids, position_ids, input_embeds)
689                .expect("Error in ONNX forward pass.")
690                .logits
691                .unwrap(),
692        }
693    }
694}
695
/// # TokenClassificationModel for Named Entity Recognition or Part-of-Speech tagging
///
/// Pairs a tokenizer with a token classification model and the label mapping read from the
/// model configuration. Build it with `TokenClassificationModel::new` or
/// `TokenClassificationModel::new_with_tokenizer`.
pub struct TokenClassificationModel {
    /// Tokenizer used to split and encode the inputs
    tokenizer: TokenizerOption,
    /// Underlying model (Torch or ONNX) producing the token classification logits
    token_sequence_classifier: TokenClassificationOption,
    /// Mapping from label index to label string, read from the model configuration
    label_mapping: HashMap<i64, String>,
    /// Device obtained from `get_device` for the configured model resource and device
    device: Device,
    /// Strategy used to assign a label to consolidated sub-tokens
    label_aggregation_function: LabelAggregationOption,
    /// Maximum input length read from the model configuration (`usize::MAX` if not specified)
    max_length: usize,
    /// Batch size for predictions
    batch_size: usize,
}
706
707impl TokenClassificationModel {
708    /// Build a new `TokenClassificationModel`
709    ///
710    /// # Arguments
711    ///
712    /// * `config` - `TokenClassificationConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
713    ///
714    /// # Example
715    ///
716    /// ```no_run
717    /// # fn main() -> anyhow::Result<()> {
718    /// use rust_bert::pipelines::token_classification::TokenClassificationModel;
719    ///
720    /// let model = TokenClassificationModel::new(Default::default())?;
721    /// # Ok(())
722    /// # }
723    /// ```
724    pub fn new(
725        config: TokenClassificationConfig,
726    ) -> Result<TokenClassificationModel, RustBertError> {
727        let vocab_path = config.vocab_resource.get_local_path()?;
728        let merges_path = config
729            .merges_resource
730            .as_ref()
731            .map(|resource| resource.get_local_path())
732            .transpose()?;
733
734        let tokenizer = TokenizerOption::from_file(
735            config.model_type,
736            vocab_path.to_str().unwrap(),
737            merges_path.as_deref().map(|path| path.to_str().unwrap()),
738            config.lower_case,
739            config.strip_accents,
740            config.add_prefix_space,
741        )?;
742        Self::new_with_tokenizer(config, tokenizer)
743    }
744
745    /// Build a new `TokenClassificationModel` with a provided tokenizer.
746    ///
747    /// # Arguments
748    ///
749    /// * `config` - `TokenClassificationConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
750    /// * `tokenizer` - `TokenizerOption` tokenizer to use for token classification
751    ///
752    /// # Example
753    ///
754    /// ```no_run
755    /// # fn main() -> anyhow::Result<()> {
756    /// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
757    /// use rust_bert::pipelines::token_classification::TokenClassificationModel;
758    /// let tokenizer = TokenizerOption::from_file(
759    ///     ModelType::Bert,
760    ///     "path/to/vocab.txt",
761    ///     None,
762    ///     false,
763    ///     None,
764    ///     None,
765    /// )?;
766    /// let model = TokenClassificationModel::new_with_tokenizer(Default::default(), tokenizer)?;
767    /// # Ok(())
768    /// # }
769    /// ```
770    pub fn new_with_tokenizer(
771        config: TokenClassificationConfig,
772        tokenizer: TokenizerOption,
773    ) -> Result<TokenClassificationModel, RustBertError> {
774        let config_path = config.config_resource.get_local_path()?;
775        let token_sequence_classifier = TokenClassificationOption::new(&config)?;
776
777        let label_aggregation_function = config.label_aggregation_function;
778
779        let model_config = ConfigOption::from_file(config.model_type, config_path);
780        let max_length = model_config
781            .get_max_len()
782            .map(|v| v as usize)
783            .unwrap_or(usize::MAX);
784        let label_mapping = model_config.get_label_mapping().clone();
785        let batch_size = config.batch_size;
786        let device = get_device(config.model_resource, config.device);
787        Ok(TokenClassificationModel {
788            tokenizer,
789            token_sequence_classifier,
790            label_mapping,
791            device,
792            label_aggregation_function,
793            max_length,
794            batch_size,
795        })
796    }
797
798    /// Get a reference to the model tokenizer.
799    pub fn get_tokenizer(&self) -> &TokenizerOption {
800        &self.tokenizer
801    }
802
803    /// Get a mutable reference to the model tokenizer.
804    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
805        &mut self.tokenizer
806    }
807
808    fn generate_features<S>(&self, input: S, example_index: usize) -> Vec<InputFeature>
809    where
810        S: AsRef<str>,
811    {
812        let tokenized_input = self.tokenizer.tokenize_with_offsets(input.as_ref());
813        let encoded_input = TokenIdsWithOffsets {
814            ids: self
815                .tokenizer
816                .convert_tokens_to_ids(&tokenized_input.tokens),
817            offsets: tokenized_input.offsets,
818            reference_offsets: tokenized_input.reference_offsets,
819            masks: tokenized_input.masks,
820        };
821
822        let sequence_added_tokens = self
823            .tokenizer
824            .build_input_with_special_tokens(
825                TokenIdsWithOffsets {
826                    ids: vec![],
827                    offsets: vec![],
828                    reference_offsets: vec![],
829                    masks: vec![],
830                },
831                None,
832            )
833            .token_ids
834            .len();
835
836        let max_content_length = self.max_length - sequence_added_tokens;
837        let doc_stride = self.max_length / 4;
838
839        let mut spans: Vec<InputFeature> = vec![];
840        let mut start_token = 0_usize;
841        let total_length = encoded_input.ids.len();
842
843        while (spans.len() * doc_stride) < encoded_input.ids.len() {
844            let end_token = min(start_token + max_content_length, total_length);
845            let sub_encoded_input = TokenIdsWithOffsets {
846                ids: encoded_input.ids[start_token..end_token].to_vec(),
847                offsets: encoded_input.offsets[start_token..end_token].to_vec(),
848                reference_offsets: encoded_input.reference_offsets[start_token..end_token].to_vec(),
849                masks: encoded_input.masks[start_token..end_token].to_vec(),
850            };
851
852            let encoded_span = self
853                .tokenizer
854                .build_input_with_special_tokens(sub_encoded_input, None);
855
856            let reference_feature = self.get_reference_feature_flag(
857                start_token,
858                end_token,
859                total_length,
860                doc_stride,
861                &encoded_span,
862            );
863
864            let feature = InputFeature {
865                input_ids: encoded_span.token_ids,
866                offsets: encoded_span.token_offsets,
867                mask: encoded_span.mask,
868                token_type_ids: encoded_span
869                    .segment_ids
870                    .into_iter()
871                    .map(|segment_id| segment_id as i64)
872                    .collect(),
873                reference_feature,
874                example_index,
875            };
876            spans.push(feature);
877            if end_token == encoded_input.ids.len() {
878                break;
879            }
880            start_token = end_token - doc_stride;
881        }
882        spans
883    }
884
885    fn get_reference_feature_flag(
886        &self,
887        start_token: usize,
888        end_token: usize,
889        total_length: usize,
890        doc_stride: usize,
891        encoded_span: &TokenizedInput,
892    ) -> Vec<bool> {
893        // set halfway through the doc_stride to be false if the feature is not the first/last
894        let start_cutoff = if start_token > 0 {
895            let leading_special_tokens = {
896                let mut counter = 0;
897                let mut masks = encoded_span.mask.iter();
898                while masks.next().unwrap_or(&Mask::None) == &Mask::Special {
899                    counter += 1;
900                }
901                counter
902            };
903            doc_stride / 2 + leading_special_tokens
904        } else {
905            0
906        };
907        let end_cutoff = if end_token < total_length {
908            let trailing_special_tokens = {
909                let mut counter = 0;
910                let mut masks = encoded_span.mask.iter().rev();
911                while masks.next().unwrap_or(&Mask::None) == &Mask::Special {
912                    counter += 1;
913                }
914                counter
915            };
916            encoded_span.token_ids.len() - doc_stride / 2 - trailing_special_tokens
917        } else {
918            encoded_span.token_ids.len()
919        };
920        let mut reference_feature = vec![true; encoded_span.token_ids.len()];
921        reference_feature[..start_cutoff]
922            .iter_mut()
923            .for_each(|v| *v = false);
924        reference_feature[end_cutoff..]
925            .iter_mut()
926            .for_each(|v| *v = false);
927        reference_feature
928    }
929
930    /// Classify tokens in a text sequence
931    ///
932    /// # Arguments
933    ///
934    /// * `input` - `&[&str]` Array of texts to extract entities from.
935    /// * `consolidate_subtokens` - bool flag indicating if subtokens should be consolidated at the token level
936    /// * `return_special` - bool flag indicating if labels for special tokens should be returned
937    ///
938    /// # Returns
939    ///
940    /// * `Vec<Vec<Token>>` containing Tokens with associated labels (for example POS tags) for each input provided
941    ///
942    /// # Example
943    ///
944    /// ```no_run
945    /// # fn main() -> anyhow::Result<()> {
946    /// # use rust_bert::pipelines::token_classification::TokenClassificationModel;
947    ///
948    /// let ner_model = TokenClassificationModel::new(Default::default())?;
949    /// let input = [
950    ///     "My name is Amy. I live in Paris.",
951    ///     "Paris is a city in France.",
952    /// ];
953    /// let output = ner_model.predict(&input, true, true);
954    /// # Ok(())
955    /// # }
956    /// ```
957    pub fn predict<S>(
958        &self,
959        input: &[S],
960        consolidate_sub_tokens: bool,
961        return_special: bool,
962    ) -> Vec<Vec<Token>>
963    where
964        S: AsRef<str>,
965    {
966        let mut features: Vec<InputFeature> = input
967            .iter()
968            .enumerate()
969            .flat_map(|(example_index, example)| self.generate_features(example, example_index))
970            .collect();
971
972        let mut example_tokens_map: Vec<Vec<Token>> = vec![Vec::new(); input.len()];
973        let mut start = 0usize;
974        let len_features = features.len();
975
976        while start < len_features {
977            let end = start + min(len_features - start, self.batch_size);
978
979            no_grad(|| {
980                let batch_features = &mut features[start..end];
981                let (input_ids, attention_masks, token_type_ids) =
982                    self.pad_features(batch_features);
983                let output = self.token_sequence_classifier.forward_t(
984                    Some(&input_ids),
985                    Some(&attention_masks),
986                    Some(&token_type_ids),
987                    None,
988                    None,
989                    false,
990                );
991                let score = output.exp()
992                    / output
993                        .exp()
994                        .sum_dim_intlist([-1].as_slice(), true, Kind::Float);
995                let label_indices = score.argmax(-1, true);
996                for sentence_idx in 0..label_indices.size()[0] {
997                    let labels = label_indices.get(sentence_idx);
998                    let feature = &features[sentence_idx as usize];
999                    let sentence_reference_flag = &feature.reference_feature;
1000                    let original_chars = input[feature.example_index]
1001                        .as_ref()
1002                        .chars()
1003                        .collect::<Vec<char>>();
1004                    let mut word_idx: u16 = 0;
1005                    for position_idx in sentence_reference_flag
1006                        .iter()
1007                        .enumerate()
1008                        .filter(|(_, flag)| **flag)
1009                        .map(|(pos, _)| pos)
1010                    {
1011                        let mask = feature.mask[position_idx];
1012                        if (mask == Mask::Special) & (!return_special) {
1013                            continue;
1014                        }
1015                        if !(mask == Mask::Continuation) {
1016                            word_idx += 1;
1017                        }
1018                        let token = {
1019                            self.decode_token(
1020                                &original_chars,
1021                                feature,
1022                                &input_ids,
1023                                &labels,
1024                                &score,
1025                                sentence_idx,
1026                                position_idx as i64,
1027                                word_idx,
1028                            )
1029                        };
1030                        example_tokens_map[feature.example_index].push(token);
1031                    }
1032                }
1033            });
1034            start = end;
1035        }
1036        let mut tokens = example_tokens_map;
1037
1038        if consolidate_sub_tokens {
1039            self.consolidate_tokens(&mut tokens, &self.label_aggregation_function);
1040        }
1041        tokens
1042    }
1043
1044    fn pad_features(&self, features: &mut [InputFeature]) -> (Tensor, Tensor, Tensor) {
1045        let max_len = features
1046            .iter()
1047            .map(|feature| feature.input_ids.len())
1048            .max()
1049            .unwrap();
1050
1051        let attention_masks = features
1052            .iter()
1053            .map(|feature| &feature.input_ids)
1054            .map(|input| {
1055                let mut attention_mask = Vec::with_capacity(max_len);
1056                attention_mask.resize(input.len(), 1i64);
1057                attention_mask.resize(max_len, 0i64);
1058                attention_mask
1059            })
1060            .map(|input| Tensor::from_slice(&(input)))
1061            .collect::<Vec<_>>();
1062
1063        let padding_index = self
1064            .tokenizer
1065            .get_pad_id()
1066            .expect("Only tokenizers with a padding index can be used for token classification");
1067        for feature in features.iter_mut() {
1068            feature.input_ids.resize(max_len, padding_index);
1069            feature.offsets.resize(max_len, None);
1070            feature
1071                .token_type_ids
1072                .resize(max_len, *feature.token_type_ids.last().unwrap_or(&0));
1073            feature.reference_feature.resize(max_len, false);
1074        }
1075
1076        let padded_input_ids = features
1077            .iter()
1078            .map(|input| Tensor::from_slice(input.input_ids.as_slice()))
1079            .collect::<Vec<_>>();
1080
1081        let padded_token_type_ids = features
1082            .iter()
1083            .map(|input| Tensor::from_slice(input.token_type_ids.as_slice()))
1084            .collect::<Vec<_>>();
1085
1086        let input_ids = Tensor::stack(&padded_input_ids, 0).to(self.device);
1087        let attention_masks = Tensor::stack(&attention_masks, 0).to(self.device);
1088        let token_type_ids = Tensor::stack(&padded_token_type_ids, 0).to(self.device);
1089        (input_ids, attention_masks, token_type_ids)
1090    }
1091
1092    fn decode_token(
1093        &self,
1094        original_sentence_chars: &[char],
1095        sentence_tokens: &InputFeature,
1096        input_tensor: &Tensor,
1097        labels: &Tensor,
1098        score: &Tensor,
1099        sentence_idx: i64,
1100        position_idx: i64,
1101        word_index: u16,
1102    ) -> Token {
1103        let label_id = labels.int64_value(&[position_idx]);
1104        let token_id = input_tensor.int64_value(&[sentence_idx, position_idx]);
1105
1106        let offsets = &sentence_tokens.offsets[position_idx as usize];
1107
1108        let text = match offsets {
1109            None => self.tokenizer.decode(&[token_id], false, false),
1110            Some(offsets) => {
1111                let (start_char, end_char) = (offsets.begin as usize, offsets.end as usize);
1112                let end_char = min(end_char, original_sentence_chars.len());
1113                let text = original_sentence_chars[start_char..end_char]
1114                    .iter()
1115                    .collect();
1116                text
1117            }
1118        };
1119
1120        Token {
1121            text,
1122            score: score.double_value(&[sentence_idx, position_idx, label_id]),
1123            label: self
1124                .label_mapping
1125                .get(&label_id)
1126                .expect("Index out of vocabulary bounds.")
1127                .to_owned(),
1128            label_index: label_id,
1129            sentence: sentence_idx as usize,
1130            index: position_idx as u16,
1131            word_index,
1132            offset: offsets.to_owned(),
1133            mask: sentence_tokens.mask[position_idx as usize],
1134        }
1135    }
1136
1137    fn consolidate_tokens(
1138        &self,
1139        tokens: &mut Vec<Vec<Token>>,
1140        label_aggregation_function: &LabelAggregationOption,
1141    ) {
1142        for sequence_tokens in tokens {
1143            let mut tokens_to_replace = vec![];
1144            let token_iter = sequence_tokens.iter_consolidate_tokens();
1145            let mut cursor = 0;
1146
1147            for sub_tokens in token_iter {
1148                if sub_tokens.len() > 1 {
1149                    let (label_index, label) =
1150                        self.consolidate_labels(sub_tokens, label_aggregation_function);
1151                    let sentence = (sub_tokens[0]).sentence;
1152                    let index = (sub_tokens[0]).index;
1153                    let word_index = (sub_tokens[0]).word_index;
1154                    let offset_start = sub_tokens
1155                        .first()
1156                        .unwrap()
1157                        .offset
1158                        .as_ref()
1159                        .map(|offset| offset.begin);
1160                    let offset_end = sub_tokens
1161                        .last()
1162                        .unwrap()
1163                        .offset
1164                        .as_ref()
1165                        .map(|offset| offset.end);
1166                    let offset = if let (Some(offset_start), Some(offset_end)) =
1167                        (offset_start, offset_end)
1168                    {
1169                        Some(Offset::new(offset_start, offset_end))
1170                    } else {
1171                        None
1172                    };
1173                    let mut text = String::new();
1174                    let mut score = 1f64;
1175                    for current_sub_token in sub_tokens.iter() {
1176                        text.push_str(current_sub_token.text.as_str());
1177                        score *= if current_sub_token.label_index == label_index {
1178                            current_sub_token.score
1179                        } else {
1180                            1.0 - current_sub_token.score
1181                        };
1182                    }
1183                    let token = Token {
1184                        text,
1185                        score,
1186                        label,
1187                        label_index,
1188                        sentence,
1189                        index,
1190                        word_index,
1191                        offset,
1192                        mask: Default::default(),
1193                    };
1194                    tokens_to_replace.push(((cursor, cursor + sub_tokens.len()), token));
1195                }
1196                cursor += sub_tokens.len();
1197            }
1198            for ((start, end), token) in tokens_to_replace.into_iter().rev() {
1199                sequence_tokens.splice(start..end, [token].iter().cloned());
1200            }
1201        }
1202    }
1203
1204    fn consolidate_labels(
1205        &self,
1206        tokens: &[Token],
1207        aggregation: &LabelAggregationOption,
1208    ) -> (i64, String) {
1209        match aggregation {
1210            LabelAggregationOption::First => {
1211                let token = tokens.first().unwrap();
1212                (token.label_index, token.label.clone())
1213            }
1214            LabelAggregationOption::Last => {
1215                let token = tokens.last().unwrap();
1216                (token.label_index, token.label.clone())
1217            }
1218            LabelAggregationOption::Mode => {
1219                let counts = tokens.iter().fold(HashMap::new(), |mut m, c| {
1220                    let (ref mut count, ref mut score) = m
1221                        .entry((c.label_index, c.label.as_str()))
1222                        .or_insert((0, 0.0_f64));
1223                    *count += 1;
1224                    *score = score.max(c.score);
1225                    m
1226                });
1227                counts
1228                    .into_iter()
1229                    .max_by_key(|&(_, (count, score))| (count, OrderedFloat(score)))
1230                    .map(|((label_index, label), _)| (label_index, label.to_owned()))
1231                    .unwrap()
1232            }
1233            LabelAggregationOption::Custom(function) => function(tokens),
1234        }
1235    }
1236}