1use crate::albert::AlbertForTokenClassification;
116use crate::bert::BertForTokenClassification;
117use crate::common::error::RustBertError;
118use crate::deberta::DebertaForTokenClassification;
119use crate::distilbert::DistilBertForTokenClassification;
120use crate::electra::ElectraForTokenClassification;
121use crate::fnet::FNetForTokenClassification;
122use crate::longformer::LongformerForTokenClassification;
123use crate::mobilebert::MobileBertForTokenClassification;
124use crate::pipelines::common::{
125 cast_var_store, get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
126};
127use crate::resources::ResourceProvider;
128use crate::roberta::RobertaForTokenClassification;
129use crate::xlnet::XLNetForTokenClassification;
130use ordered_float::OrderedFloat;
131use rust_tokenizers::{
132 ConsolidatableTokens, ConsolidatedTokenIterator, Mask, Offset, TokenIdsWithOffsets, TokenTrait,
133 TokenizedInput,
134};
135use serde::{Deserialize, Serialize};
136use std::cmp::min;
137use std::collections::HashMap;
138use tch::nn::VarStore;
139use tch::{no_grad, Device, Kind, Tensor};
140
141use crate::deberta_v2::DebertaV2ForTokenClassification;
142#[cfg(feature = "onnx")]
143use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
144#[cfg(feature = "remote")]
145use crate::{
146 bert::{BertConfigResources, BertModelResources, BertVocabResources},
147 resources::RemoteResource,
148};
149
/// # Token generated by a `TokenClassificationModel`
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Token {
    /// String representation of the token
    pub text: String,
    /// Confidence score for the predicted label
    pub score: f64,
    /// Predicted label name
    pub label: String,
    /// Predicted label index in the model's label mapping
    pub label_index: i64,
    /// Sentence (batch row) index the token belongs to
    pub sentence: usize,
    /// Position of the token within the tokenized sequence
    pub index: u16,
    /// Position of the word the token belongs to (sub-tokens share a word index)
    pub word_index: u16,
    /// Character offset of the token in the original input (`None` for special/generated tokens)
    pub offset: Option<Offset>,
    /// Token mask category (e.g. special token, continuation sub-token)
    pub mask: Mask,
}
172
impl TokenTrait for Token {
    /// Returns the character offset of the token in the original input.
    fn offset(&self) -> Option<Offset> {
        self.offset
    }

    /// Returns the token mask category.
    fn mask(&self) -> Mask {
        self.mask
    }

    /// Returns the token text as a string slice.
    fn as_str(&self) -> &str {
        self.text.as_str()
    }
}
186
impl ConsolidatableTokens<Token> for Vec<Token> {
    /// Returns an iterator yielding groups of consecutive sub-tokens to be consolidated
    /// into word-level tokens.
    fn iter_consolidate_tokens(&self) -> ConsolidatedTokenIterator<Token> {
        ConsolidatedTokenIterator::new(self)
    }
}
192
/// # Tokenized input span ready for processing by the model
#[derive(Debug)]
struct InputFeature {
    /// Token ids
    input_ids: Vec<i64>,
    /// Character offsets of each token in the original input
    offsets: Vec<Option<Offset>>,
    /// Per-token mask categories (special token, continuation sub-token, ...)
    mask: Vec<Mask>,
    /// Segment (token type) ids
    token_type_ids: Vec<i64>,
    /// Flags marking which tokens of this span are authoritative: when a long input is
    /// split into overlapping spans, only one span's prediction is kept per token
    reference_feature: Vec<bool>,
    /// Index of the input example this feature was generated from
    example_index: usize,
}
208
/// Function turning a group of sub-tokens into a single `(label index, label)` pair.
type LabelAggregationFunction = Box<fn(&[Token]) -> (i64, String)>;

/// # Strategy used to aggregate sub-token labels into a word-level label
pub enum LabelAggregationOption {
    /// Use the label of the first sub-token
    First,
    /// Use the label of the last sub-token
    Last,
    /// Use the most frequent sub-token label (ties broken by highest score)
    Mode,
    /// Use a user-provided aggregation function
    Custom(LabelAggregationFunction),
}
223
/// # Configuration for a `TokenClassificationModel`
pub struct TokenClassificationConfig {
    /// Model type
    pub model_type: ModelType,
    /// Model weights resource
    pub model_resource: ModelResource,
    /// Model configuration resource
    pub config_resource: Box<dyn ResourceProvider + Send>,
    /// Vocabulary resource
    pub vocab_resource: Box<dyn ResourceProvider + Send>,
    /// Merges resource (required by some tokenizers, e.g. BPE-based — otherwise `None`)
    pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
    /// Automatically lower-case all input upon tokenization
    pub lower_case: bool,
    /// Flag indicating if the tokenizer should strip accents (tokenizer-dependent)
    pub strip_accents: Option<bool>,
    /// Flag indicating if the tokenizer should add a prefix space (tokenizer-dependent)
    pub add_prefix_space: Option<bool>,
    /// Device to place the model on
    pub device: Device,
    /// Optional weights precision to cast the model to (`None` keeps the stored kind)
    pub kind: Option<Kind>,
    /// Strategy used to aggregate sub-token labels into word-level labels
    pub label_aggregation_function: LabelAggregationOption,
    /// Batch size used for predictions
    pub batch_size: usize,
}
252
253impl TokenClassificationConfig {
254 pub fn new<RC, RV>(
265 model_type: ModelType,
266 model_resource: ModelResource,
267 config_resource: RC,
268 vocab_resource: RV,
269 merges_resource: Option<RV>,
270 lower_case: bool,
271 strip_accents: impl Into<Option<bool>>,
272 add_prefix_space: impl Into<Option<bool>>,
273 label_aggregation_function: LabelAggregationOption,
274 ) -> TokenClassificationConfig
275 where
276 RC: ResourceProvider + Send + 'static,
277 RV: ResourceProvider + Send + 'static,
278 {
279 TokenClassificationConfig {
280 model_type,
281 model_resource,
282 config_resource: Box::new(config_resource),
283 vocab_resource: Box::new(vocab_resource),
284 merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
285 lower_case,
286 strip_accents: strip_accents.into(),
287 add_prefix_space: add_prefix_space.into(),
288 device: Device::cuda_if_available(),
289 kind: None,
290 label_aggregation_function,
291 batch_size: 64,
292 }
293 }
294}
295
#[cfg(feature = "remote")]
impl Default for TokenClassificationConfig {
    /// Builds a default configuration backed by the remote BERT NER resources,
    /// with first-sub-token label aggregation and no input lower-casing.
    fn default() -> TokenClassificationConfig {
        let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
            BertModelResources::BERT_NER,
        )));
        let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT_NER);
        let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT_NER);
        TokenClassificationConfig::new(
            ModelType::Bert,
            model_resource,
            config_resource,
            vocab_resource,
            None,
            false,
            None,
            None,
            LabelAggregationOption::First,
        )
    }
}
315
#[allow(clippy::large_enum_variant)]
/// # Abstraction holding one particular token classification model, for any of the supported architectures
pub enum TokenClassificationOption {
    /// Bert for Token Classification
    Bert(BertForTokenClassification),
    /// DeBERTa for Token Classification
    Deberta(DebertaForTokenClassification),
    /// DeBERTa V2 for Token Classification
    DebertaV2(DebertaV2ForTokenClassification),
    /// DistilBert for Token Classification
    DistilBert(DistilBertForTokenClassification),
    /// MobileBert for Token Classification
    MobileBert(MobileBertForTokenClassification),
    /// Roberta for Token Classification
    Roberta(RobertaForTokenClassification),
    /// XLM-Roberta for Token Classification (shares the Roberta architecture)
    XLMRoberta(RobertaForTokenClassification),
    /// Electra for Token Classification
    Electra(ElectraForTokenClassification),
    /// Albert for Token Classification
    Albert(AlbertForTokenClassification),
    /// XLNet for Token Classification
    XLNet(XLNetForTokenClassification),
    /// Longformer for Token Classification
    Longformer(LongformerForTokenClassification),
    /// FNet for Token Classification
    FNet(FNetForTokenClassification),
    /// ONNX model for Token Classification
    #[cfg(feature = "onnx")]
    ONNX(ONNXEncoder),
}
347
348impl TokenClassificationOption {
349 pub fn new(config: &TokenClassificationConfig) -> Result<Self, RustBertError> {
356 match config.model_resource {
357 ModelResource::Torch(_) => Self::new_torch(config),
358 #[cfg(feature = "onnx")]
359 ModelResource::ONNX(_) => Self::new_onnx(config),
360 }
361 }
362
363 fn new_torch(config: &TokenClassificationConfig) -> Result<Self, RustBertError> {
364 let device = config.device;
365 let weights_path = config.model_resource.get_torch_local_path()?;
366 let mut var_store = VarStore::new(device);
367 let model_config =
368 &ConfigOption::from_file(config.model_type, config.config_resource.get_local_path()?);
369 let model_type = config.model_type;
370 let model = match model_type {
371 ModelType::Bert => {
372 if let ConfigOption::Bert(config) = model_config {
373 Ok(Self::Bert(
374 BertForTokenClassification::new(var_store.root(), config)?,
375 ))
376 } else {
377 Err(RustBertError::InvalidConfigurationError(
378 "You can only supply a BertConfig for Bert!".to_string(),
379 ))
380 }
381 }
382 ModelType::Deberta => {
383 if let ConfigOption::Deberta(config) = model_config {
384 Ok(Self::Deberta(
385 DebertaForTokenClassification::new(var_store.root(), config)?,
386 ))
387 } else {
388 Err(RustBertError::InvalidConfigurationError(
389 "You can only supply a DebertaConfig for DeBERTa!".to_string(),
390 ))
391 }
392 }
393 ModelType::DebertaV2 => {
394 if let ConfigOption::DebertaV2(config) = model_config {
395 Ok(Self::DebertaV2(
396 DebertaV2ForTokenClassification::new(var_store.root(), config)?,
397 ))
398 } else {
399 Err(RustBertError::InvalidConfigurationError(
400 "You can only supply a DebertaConfig for DeBERTa V2!".to_string(),
401 ))
402 }
403 }
404 ModelType::DistilBert => {
405 if let ConfigOption::DistilBert(config) = model_config {
406 Ok(Self::DistilBert(
407 DistilBertForTokenClassification::new(var_store.root(), config)?,
408 ))
409 } else {
410 Err(RustBertError::InvalidConfigurationError(
411 "You can only supply a DistilBertConfig for DistilBert!".to_string(),
412 ))
413 }
414 }
415 ModelType::MobileBert => {
416 if let ConfigOption::MobileBert(config) = model_config {
417 Ok(Self::MobileBert(
418 MobileBertForTokenClassification::new(var_store.root(), config)?,
419 ))
420 } else {
421 Err(RustBertError::InvalidConfigurationError(
422 "You can only supply a MobileBertConfig for MobileBert!".to_string(),
423 ))
424 }
425 }
426 ModelType::Roberta => {
427 if let ConfigOption::Roberta(config) = model_config {
428 Ok(Self::Roberta(
429 RobertaForTokenClassification::new(var_store.root(), config)?,
430 ))
431 } else {
432 Err(RustBertError::InvalidConfigurationError(
433 "You can only supply a RobertaConfig for Roberta!".to_string(),
434 ))
435 }
436 }
437 ModelType::XLMRoberta => {
438 if let ConfigOption::Roberta(config) = model_config {
439 Ok(Self::XLMRoberta(
440 RobertaForTokenClassification::new(var_store.root(), config)?,
441 ))
442 } else {
443 Err(RustBertError::InvalidConfigurationError(
444 "You can only supply a RobertaConfig for XLMRoberta!".to_string(),
445 ))
446 }
447 }
448 ModelType::Electra => {
449 if let ConfigOption::Electra(config) = model_config {
450 Ok(Self::Electra(
451 ElectraForTokenClassification::new(var_store.root(), config)?,
452 ))
453 } else {
454 Err(RustBertError::InvalidConfigurationError(
455 "You can only supply a BertConfig for Roberta!".to_string(),
456 ))
457 }
458 }
459 ModelType::Albert => {
460 if let ConfigOption::Albert(config) = model_config {
461 Ok(Self::Albert(
462 AlbertForTokenClassification::new(var_store.root(), config)?,
463 ))
464 } else {
465 Err(RustBertError::InvalidConfigurationError(
466 "You can only supply an AlbertConfig for Albert!".to_string(),
467 ))
468 }
469 }
470 ModelType::XLNet => {
471 if let ConfigOption::XLNet(config) = model_config {
472 Ok(Self::XLNet(
473 XLNetForTokenClassification::new(var_store.root(), config)?,
474 ))
475 } else {
476 Err(RustBertError::InvalidConfigurationError(
477 "You can only supply an AlbertConfig for Albert!".to_string(),
478 ))
479 }
480 }
481 ModelType::Longformer => {
482 if let ConfigOption::Longformer(config) = model_config {
483 Ok(Self::Longformer(
484 LongformerForTokenClassification::new(var_store.root(), config)?,
485 ))
486 } else {
487 Err(RustBertError::InvalidConfigurationError(
488 "You can only supply a LongformerConfig for Longformer!".to_string(),
489 ))
490 }
491 }
492 ModelType::FNet => {
493 if let ConfigOption::FNet(config) = model_config {
494 Ok(Self::FNet(
495 FNetForTokenClassification::new(var_store.root(), config)?,
496 ))
497 } else {
498 Err(RustBertError::InvalidConfigurationError(
499 "You can only supply an FNetConfig for FNet!".to_string(),
500 ))
501 }
502 }
503 #[cfg(feature = "onnx")]
504 ModelType::ONNX => Err(RustBertError::InvalidConfigurationError(
505 "A `ModelType::ONNX` ModelType was provided in the configuration with `ModelResources::TORCH`, these are incompatible".to_string(),
506 )),
507 _ => Err(RustBertError::InvalidConfigurationError(format!(
508 "Token classification not implemented for {model_type:?}!"
509 ))),
510 }?;
511 var_store.load(weights_path)?;
512 cast_var_store(&mut var_store, config.kind, device);
513 Ok(model)
514 }
515
516 #[cfg(feature = "onnx")]
517 pub fn new_onnx(config: &TokenClassificationConfig) -> Result<Self, RustBertError> {
518 let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
519 let environment = onnx_config.get_environment()?;
520 let encoder_file = config
521 .model_resource
522 .get_onnx_local_paths()?
523 .encoder_path
524 .ok_or(RustBertError::InvalidConfigurationError(
525 "An encoder file must be provided for token classification ONNX models."
526 .to_string(),
527 ))?;
528
529 Ok(Self::ONNX(ONNXEncoder::new(
530 encoder_file,
531 &environment,
532 &onnx_config,
533 )?))
534 }
535
536 pub fn model_type(&self) -> ModelType {
538 match *self {
539 Self::Bert(_) => ModelType::Bert,
540 Self::Deberta(_) => ModelType::Deberta,
541 Self::DebertaV2(_) => ModelType::DebertaV2,
542 Self::Roberta(_) => ModelType::Roberta,
543 Self::XLMRoberta(_) => ModelType::XLMRoberta,
544 Self::DistilBert(_) => ModelType::DistilBert,
545 Self::MobileBert(_) => ModelType::MobileBert,
546 Self::Electra(_) => ModelType::Electra,
547 Self::Albert(_) => ModelType::Albert,
548 Self::XLNet(_) => ModelType::XLNet,
549 Self::Longformer(_) => ModelType::Longformer,
550 Self::FNet(_) => ModelType::FNet,
551 #[cfg(feature = "onnx")]
552 Self::ONNX(_) => ModelType::ONNX,
553 }
554 }
555
556 fn forward_t(
557 &self,
558 input_ids: Option<&Tensor>,
559 mask: Option<&Tensor>,
560 token_type_ids: Option<&Tensor>,
561 position_ids: Option<&Tensor>,
562 input_embeds: Option<&Tensor>,
563 train: bool,
564 ) -> Tensor {
565 match *self {
566 Self::Bert(ref model) => {
567 model
568 .forward_t(
569 input_ids,
570 mask,
571 token_type_ids,
572 position_ids,
573 input_embeds,
574 train,
575 )
576 .logits
577 }
578 Self::Deberta(ref model) => {
579 model
580 .forward_t(
581 input_ids,
582 mask,
583 token_type_ids,
584 position_ids,
585 input_embeds,
586 train,
587 )
588 .expect("Error in DeBERTa forward_t")
589 .logits
590 }
591 Self::DebertaV2(ref model) => {
592 model
593 .forward_t(
594 input_ids,
595 mask,
596 token_type_ids,
597 position_ids,
598 input_embeds,
599 train,
600 )
601 .expect("Error in DeBERTa V2 forward_t")
602 .logits
603 }
604 Self::DistilBert(ref model) => {
605 model
606 .forward_t(input_ids, mask, input_embeds, train)
607 .expect("Error in distilbert forward_t")
608 .logits
609 }
610 Self::MobileBert(ref model) => {
611 model
612 .forward_t(input_ids, None, None, input_embeds, mask, train)
613 .expect("Error in mobilebert forward_t")
614 .logits
615 }
616 Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
617 model
618 .forward_t(
619 input_ids,
620 mask,
621 token_type_ids,
622 position_ids,
623 input_embeds,
624 train,
625 )
626 .logits
627 }
628 Self::Electra(ref model) => {
629 model
630 .forward_t(
631 input_ids,
632 mask,
633 token_type_ids,
634 position_ids,
635 input_embeds,
636 train,
637 )
638 .logits
639 }
640 Self::Albert(ref model) => {
641 model
642 .forward_t(
643 input_ids,
644 mask,
645 token_type_ids,
646 position_ids,
647 input_embeds,
648 train,
649 )
650 .logits
651 }
652 Self::XLNet(ref model) => {
653 model
654 .forward_t(
655 input_ids,
656 mask,
657 None,
658 None,
659 None,
660 token_type_ids,
661 input_embeds,
662 train,
663 )
664 .logits
665 }
666 Self::Longformer(ref model) => {
667 model
668 .forward_t(
669 input_ids,
670 mask,
671 None,
672 token_type_ids,
673 position_ids,
674 input_embeds,
675 train,
676 )
677 .expect("Error in longformer forward_t")
678 .logits
679 }
680 Self::FNet(ref model) => {
681 model
682 .forward_t(input_ids, token_type_ids, position_ids, input_embeds, train)
683 .expect("Error in fnet forward_t")
684 .logits
685 }
686 #[cfg(feature = "onnx")]
687 Self::ONNX(ref model) => model
688 .forward(input_ids, mask, token_type_ids, position_ids, input_embeds)
689 .expect("Error in ONNX forward pass.")
690 .logits
691 .unwrap(),
692 }
693 }
694}
695
/// # TokenClassificationModel for token-level classification (e.g. NER, POS tagging)
pub struct TokenClassificationModel {
    /// Tokenizer matching the underlying model
    tokenizer: TokenizerOption,
    /// Underlying token classification model
    token_sequence_classifier: TokenClassificationOption,
    /// Mapping from label index to label name
    label_mapping: HashMap<i64, String>,
    /// Device the input tensors are placed on
    device: Device,
    /// Strategy used to aggregate sub-token labels into word-level labels
    label_aggregation_function: LabelAggregationOption,
    /// Maximum sequence length accepted by the model (usize::MAX if unspecified)
    max_length: usize,
    /// Batch size used for predictions
    batch_size: usize,
}
706
707impl TokenClassificationModel {
708 pub fn new(
725 config: TokenClassificationConfig,
726 ) -> Result<TokenClassificationModel, RustBertError> {
727 let vocab_path = config.vocab_resource.get_local_path()?;
728 let merges_path = config
729 .merges_resource
730 .as_ref()
731 .map(|resource| resource.get_local_path())
732 .transpose()?;
733
734 let tokenizer = TokenizerOption::from_file(
735 config.model_type,
736 vocab_path.to_str().unwrap(),
737 merges_path.as_deref().map(|path| path.to_str().unwrap()),
738 config.lower_case,
739 config.strip_accents,
740 config.add_prefix_space,
741 )?;
742 Self::new_with_tokenizer(config, tokenizer)
743 }
744
745 pub fn new_with_tokenizer(
771 config: TokenClassificationConfig,
772 tokenizer: TokenizerOption,
773 ) -> Result<TokenClassificationModel, RustBertError> {
774 let config_path = config.config_resource.get_local_path()?;
775 let token_sequence_classifier = TokenClassificationOption::new(&config)?;
776
777 let label_aggregation_function = config.label_aggregation_function;
778
779 let model_config = ConfigOption::from_file(config.model_type, config_path);
780 let max_length = model_config
781 .get_max_len()
782 .map(|v| v as usize)
783 .unwrap_or(usize::MAX);
784 let label_mapping = model_config.get_label_mapping().clone();
785 let batch_size = config.batch_size;
786 let device = get_device(config.model_resource, config.device);
787 Ok(TokenClassificationModel {
788 tokenizer,
789 token_sequence_classifier,
790 label_mapping,
791 device,
792 label_aggregation_function,
793 max_length,
794 batch_size,
795 })
796 }
797
798 pub fn get_tokenizer(&self) -> &TokenizerOption {
800 &self.tokenizer
801 }
802
803 pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
805 &mut self.tokenizer
806 }
807
808 fn generate_features<S>(&self, input: S, example_index: usize) -> Vec<InputFeature>
809 where
810 S: AsRef<str>,
811 {
812 let tokenized_input = self.tokenizer.tokenize_with_offsets(input.as_ref());
813 let encoded_input = TokenIdsWithOffsets {
814 ids: self
815 .tokenizer
816 .convert_tokens_to_ids(&tokenized_input.tokens),
817 offsets: tokenized_input.offsets,
818 reference_offsets: tokenized_input.reference_offsets,
819 masks: tokenized_input.masks,
820 };
821
822 let sequence_added_tokens = self
823 .tokenizer
824 .build_input_with_special_tokens(
825 TokenIdsWithOffsets {
826 ids: vec![],
827 offsets: vec![],
828 reference_offsets: vec![],
829 masks: vec![],
830 },
831 None,
832 )
833 .token_ids
834 .len();
835
836 let max_content_length = self.max_length - sequence_added_tokens;
837 let doc_stride = self.max_length / 4;
838
839 let mut spans: Vec<InputFeature> = vec![];
840 let mut start_token = 0_usize;
841 let total_length = encoded_input.ids.len();
842
843 while (spans.len() * doc_stride) < encoded_input.ids.len() {
844 let end_token = min(start_token + max_content_length, total_length);
845 let sub_encoded_input = TokenIdsWithOffsets {
846 ids: encoded_input.ids[start_token..end_token].to_vec(),
847 offsets: encoded_input.offsets[start_token..end_token].to_vec(),
848 reference_offsets: encoded_input.reference_offsets[start_token..end_token].to_vec(),
849 masks: encoded_input.masks[start_token..end_token].to_vec(),
850 };
851
852 let encoded_span = self
853 .tokenizer
854 .build_input_with_special_tokens(sub_encoded_input, None);
855
856 let reference_feature = self.get_reference_feature_flag(
857 start_token,
858 end_token,
859 total_length,
860 doc_stride,
861 &encoded_span,
862 );
863
864 let feature = InputFeature {
865 input_ids: encoded_span.token_ids,
866 offsets: encoded_span.token_offsets,
867 mask: encoded_span.mask,
868 token_type_ids: encoded_span
869 .segment_ids
870 .into_iter()
871 .map(|segment_id| segment_id as i64)
872 .collect(),
873 reference_feature,
874 example_index,
875 };
876 spans.push(feature);
877 if end_token == encoded_input.ids.len() {
878 break;
879 }
880 start_token = end_token - doc_stride;
881 }
882 spans
883 }
884
885 fn get_reference_feature_flag(
886 &self,
887 start_token: usize,
888 end_token: usize,
889 total_length: usize,
890 doc_stride: usize,
891 encoded_span: &TokenizedInput,
892 ) -> Vec<bool> {
893 let start_cutoff = if start_token > 0 {
895 let leading_special_tokens = {
896 let mut counter = 0;
897 let mut masks = encoded_span.mask.iter();
898 while masks.next().unwrap_or(&Mask::None) == &Mask::Special {
899 counter += 1;
900 }
901 counter
902 };
903 doc_stride / 2 + leading_special_tokens
904 } else {
905 0
906 };
907 let end_cutoff = if end_token < total_length {
908 let trailing_special_tokens = {
909 let mut counter = 0;
910 let mut masks = encoded_span.mask.iter().rev();
911 while masks.next().unwrap_or(&Mask::None) == &Mask::Special {
912 counter += 1;
913 }
914 counter
915 };
916 encoded_span.token_ids.len() - doc_stride / 2 - trailing_special_tokens
917 } else {
918 encoded_span.token_ids.len()
919 };
920 let mut reference_feature = vec![true; encoded_span.token_ids.len()];
921 reference_feature[..start_cutoff]
922 .iter_mut()
923 .for_each(|v| *v = false);
924 reference_feature[end_cutoff..]
925 .iter_mut()
926 .for_each(|v| *v = false);
927 reference_feature
928 }
929
930 pub fn predict<S>(
958 &self,
959 input: &[S],
960 consolidate_sub_tokens: bool,
961 return_special: bool,
962 ) -> Vec<Vec<Token>>
963 where
964 S: AsRef<str>,
965 {
966 let mut features: Vec<InputFeature> = input
967 .iter()
968 .enumerate()
969 .flat_map(|(example_index, example)| self.generate_features(example, example_index))
970 .collect();
971
972 let mut example_tokens_map: Vec<Vec<Token>> = vec![Vec::new(); input.len()];
973 let mut start = 0usize;
974 let len_features = features.len();
975
976 while start < len_features {
977 let end = start + min(len_features - start, self.batch_size);
978
979 no_grad(|| {
980 let batch_features = &mut features[start..end];
981 let (input_ids, attention_masks, token_type_ids) =
982 self.pad_features(batch_features);
983 let output = self.token_sequence_classifier.forward_t(
984 Some(&input_ids),
985 Some(&attention_masks),
986 Some(&token_type_ids),
987 None,
988 None,
989 false,
990 );
991 let score = output.exp()
992 / output
993 .exp()
994 .sum_dim_intlist([-1].as_slice(), true, Kind::Float);
995 let label_indices = score.argmax(-1, true);
996 for sentence_idx in 0..label_indices.size()[0] {
997 let labels = label_indices.get(sentence_idx);
998 let feature = &features[sentence_idx as usize];
999 let sentence_reference_flag = &feature.reference_feature;
1000 let original_chars = input[feature.example_index]
1001 .as_ref()
1002 .chars()
1003 .collect::<Vec<char>>();
1004 let mut word_idx: u16 = 0;
1005 for position_idx in sentence_reference_flag
1006 .iter()
1007 .enumerate()
1008 .filter(|(_, flag)| **flag)
1009 .map(|(pos, _)| pos)
1010 {
1011 let mask = feature.mask[position_idx];
1012 if (mask == Mask::Special) & (!return_special) {
1013 continue;
1014 }
1015 if !(mask == Mask::Continuation) {
1016 word_idx += 1;
1017 }
1018 let token = {
1019 self.decode_token(
1020 &original_chars,
1021 feature,
1022 &input_ids,
1023 &labels,
1024 &score,
1025 sentence_idx,
1026 position_idx as i64,
1027 word_idx,
1028 )
1029 };
1030 example_tokens_map[feature.example_index].push(token);
1031 }
1032 }
1033 });
1034 start = end;
1035 }
1036 let mut tokens = example_tokens_map;
1037
1038 if consolidate_sub_tokens {
1039 self.consolidate_tokens(&mut tokens, &self.label_aggregation_function);
1040 }
1041 tokens
1042 }
1043
1044 fn pad_features(&self, features: &mut [InputFeature]) -> (Tensor, Tensor, Tensor) {
1045 let max_len = features
1046 .iter()
1047 .map(|feature| feature.input_ids.len())
1048 .max()
1049 .unwrap();
1050
1051 let attention_masks = features
1052 .iter()
1053 .map(|feature| &feature.input_ids)
1054 .map(|input| {
1055 let mut attention_mask = Vec::with_capacity(max_len);
1056 attention_mask.resize(input.len(), 1i64);
1057 attention_mask.resize(max_len, 0i64);
1058 attention_mask
1059 })
1060 .map(|input| Tensor::from_slice(&(input)))
1061 .collect::<Vec<_>>();
1062
1063 let padding_index = self
1064 .tokenizer
1065 .get_pad_id()
1066 .expect("Only tokenizers with a padding index can be used for token classification");
1067 for feature in features.iter_mut() {
1068 feature.input_ids.resize(max_len, padding_index);
1069 feature.offsets.resize(max_len, None);
1070 feature
1071 .token_type_ids
1072 .resize(max_len, *feature.token_type_ids.last().unwrap_or(&0));
1073 feature.reference_feature.resize(max_len, false);
1074 }
1075
1076 let padded_input_ids = features
1077 .iter()
1078 .map(|input| Tensor::from_slice(input.input_ids.as_slice()))
1079 .collect::<Vec<_>>();
1080
1081 let padded_token_type_ids = features
1082 .iter()
1083 .map(|input| Tensor::from_slice(input.token_type_ids.as_slice()))
1084 .collect::<Vec<_>>();
1085
1086 let input_ids = Tensor::stack(&padded_input_ids, 0).to(self.device);
1087 let attention_masks = Tensor::stack(&attention_masks, 0).to(self.device);
1088 let token_type_ids = Tensor::stack(&padded_token_type_ids, 0).to(self.device);
1089 (input_ids, attention_masks, token_type_ids)
1090 }
1091
1092 fn decode_token(
1093 &self,
1094 original_sentence_chars: &[char],
1095 sentence_tokens: &InputFeature,
1096 input_tensor: &Tensor,
1097 labels: &Tensor,
1098 score: &Tensor,
1099 sentence_idx: i64,
1100 position_idx: i64,
1101 word_index: u16,
1102 ) -> Token {
1103 let label_id = labels.int64_value(&[position_idx]);
1104 let token_id = input_tensor.int64_value(&[sentence_idx, position_idx]);
1105
1106 let offsets = &sentence_tokens.offsets[position_idx as usize];
1107
1108 let text = match offsets {
1109 None => self.tokenizer.decode(&[token_id], false, false),
1110 Some(offsets) => {
1111 let (start_char, end_char) = (offsets.begin as usize, offsets.end as usize);
1112 let end_char = min(end_char, original_sentence_chars.len());
1113 let text = original_sentence_chars[start_char..end_char]
1114 .iter()
1115 .collect();
1116 text
1117 }
1118 };
1119
1120 Token {
1121 text,
1122 score: score.double_value(&[sentence_idx, position_idx, label_id]),
1123 label: self
1124 .label_mapping
1125 .get(&label_id)
1126 .expect("Index out of vocabulary bounds.")
1127 .to_owned(),
1128 label_index: label_id,
1129 sentence: sentence_idx as usize,
1130 index: position_idx as u16,
1131 word_index,
1132 offset: offsets.to_owned(),
1133 mask: sentence_tokens.mask[position_idx as usize],
1134 }
1135 }
1136
1137 fn consolidate_tokens(
1138 &self,
1139 tokens: &mut Vec<Vec<Token>>,
1140 label_aggregation_function: &LabelAggregationOption,
1141 ) {
1142 for sequence_tokens in tokens {
1143 let mut tokens_to_replace = vec![];
1144 let token_iter = sequence_tokens.iter_consolidate_tokens();
1145 let mut cursor = 0;
1146
1147 for sub_tokens in token_iter {
1148 if sub_tokens.len() > 1 {
1149 let (label_index, label) =
1150 self.consolidate_labels(sub_tokens, label_aggregation_function);
1151 let sentence = (sub_tokens[0]).sentence;
1152 let index = (sub_tokens[0]).index;
1153 let word_index = (sub_tokens[0]).word_index;
1154 let offset_start = sub_tokens
1155 .first()
1156 .unwrap()
1157 .offset
1158 .as_ref()
1159 .map(|offset| offset.begin);
1160 let offset_end = sub_tokens
1161 .last()
1162 .unwrap()
1163 .offset
1164 .as_ref()
1165 .map(|offset| offset.end);
1166 let offset = if let (Some(offset_start), Some(offset_end)) =
1167 (offset_start, offset_end)
1168 {
1169 Some(Offset::new(offset_start, offset_end))
1170 } else {
1171 None
1172 };
1173 let mut text = String::new();
1174 let mut score = 1f64;
1175 for current_sub_token in sub_tokens.iter() {
1176 text.push_str(current_sub_token.text.as_str());
1177 score *= if current_sub_token.label_index == label_index {
1178 current_sub_token.score
1179 } else {
1180 1.0 - current_sub_token.score
1181 };
1182 }
1183 let token = Token {
1184 text,
1185 score,
1186 label,
1187 label_index,
1188 sentence,
1189 index,
1190 word_index,
1191 offset,
1192 mask: Default::default(),
1193 };
1194 tokens_to_replace.push(((cursor, cursor + sub_tokens.len()), token));
1195 }
1196 cursor += sub_tokens.len();
1197 }
1198 for ((start, end), token) in tokens_to_replace.into_iter().rev() {
1199 sequence_tokens.splice(start..end, [token].iter().cloned());
1200 }
1201 }
1202 }
1203
1204 fn consolidate_labels(
1205 &self,
1206 tokens: &[Token],
1207 aggregation: &LabelAggregationOption,
1208 ) -> (i64, String) {
1209 match aggregation {
1210 LabelAggregationOption::First => {
1211 let token = tokens.first().unwrap();
1212 (token.label_index, token.label.clone())
1213 }
1214 LabelAggregationOption::Last => {
1215 let token = tokens.last().unwrap();
1216 (token.label_index, token.label.clone())
1217 }
1218 LabelAggregationOption::Mode => {
1219 let counts = tokens.iter().fold(HashMap::new(), |mut m, c| {
1220 let (ref mut count, ref mut score) = m
1221 .entry((c.label_index, c.label.as_str()))
1222 .or_insert((0, 0.0_f64));
1223 *count += 1;
1224 *score = score.max(c.score);
1225 m
1226 });
1227 counts
1228 .into_iter()
1229 .max_by_key(|&(_, (count, score))| (count, OrderedFloat(score)))
1230 .map(|((label_index, label), _)| (label_index, label.to_owned()))
1231 .unwrap()
1232 }
1233 LabelAggregationOption::Custom(function) => function(tokens),
1234 }
1235 }
1236}