use crate::albert::AlbertForSequenceClassification;
use crate::bart::BartForSequenceClassification;
use crate::bert::BertForSequenceClassification;
use crate::common::error::RustBertError;
use crate::deberta::DebertaForSequenceClassification;
use crate::distilbert::DistilBertModelClassifier;
use crate::fnet::FNetForSequenceClassification;
use crate::longformer::LongformerForSequenceClassification;
use crate::mobilebert::MobileBertForSequenceClassification;
use crate::pipelines::common::{
    cast_var_store, get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
};
use crate::reformer::ReformerForSequenceClassification;
use crate::resources::ResourceProvider;
use crate::roberta::RobertaForSequenceClassification;
use crate::xlnet::XLNetForSequenceClassification;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tch::nn::VarStore;
use tch::{no_grad, Device, Kind, Tensor};

use crate::deberta_v2::DebertaV2ForSequenceClassification;
#[cfg(feature = "onnx")]
use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
#[cfg(feature = "remote")]
use crate::{
    distilbert::{DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources},
    resources::RemoteResource,
};

/// Label generated by a sequence classification model.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Label {
    /// Label String representation
    pub text: String,
    /// Confidence score
    pub score: f64,
    /// Label ID
    pub id: i64,
    /// Index of the input sentence this label refers to
    #[serde(default)]
    pub sentence: usize,
}

/// Configuration for a sequence classification pipeline: model resources, tokenizer options and device placement.
pub struct SequenceClassificationConfig {
    /// Model type
    pub model_type: ModelType,
    /// Model weights resource (Torch or ONNX)
    pub model_resource: ModelResource,
    /// Config resource
    pub config_resource: Box<dyn ResourceProvider + Send>,
    /// Vocab resource
    pub vocab_resource: Box<dyn ResourceProvider + Send>,
    /// Merges resource (only required for tokenizers relying on a merges file)
    pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
    /// Flag indicating if the tokenizer should lower-case the input
    pub lower_case: bool,
    /// Optional flag indicating if the tokenizer should strip accents
    pub strip_accents: Option<bool>,
    /// Optional flag indicating if the tokenizer should add a prefix space
    pub add_prefix_space: Option<bool>,
    /// Device to place the model on
    pub device: Device,
    /// Model weights precision. If not provided, the precision of the saved weights is used
    pub kind: Option<Kind>,
}

impl SequenceClassificationConfig {
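    /// Instantiate a new sequence classification configuration of the supplied type.
    ///
    /// # Arguments
    ///
    /// * `model_type` - `ModelType` indicating the model architecture to load (e.g. DistilBert, Bert, ...)
    /// * `model_resource` - `ModelResource` pointing to the model weights (Torch or ONNX)
    /// * `config_resource` - `ResourceProvider` pointing to the model configuration
    /// * `vocab_resource` - `ResourceProvider` pointing to the tokenizer vocabulary
    /// * `merges_resource` - Optional `ResourceProvider` pointing to the tokenizer merges (only required for BPE-based tokenizers)
    /// * `lower_case` - flag indicating if the tokenizer should lower-case the input
    /// * `strip_accents` - optional flag indicating if the tokenizer should strip accents
    /// * `add_prefix_space` - optional flag indicating if the tokenizer should add a prefix space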
    pub fn new<RC, RV>(
        model_type: ModelType,
        model_resource: ModelResource,
        config_resource: RC,
        vocab_resource: RV,
        merges_resource: Option<RV>,
        lower_case: bool,
        strip_accents: impl Into<Option<bool>>,
        add_prefix_space: impl Into<Option<bool>>,
    ) -> SequenceClassificationConfig
    where
        RC: ResourceProvider + Send + 'static,
        RV: ResourceProvider + Send + 'static,
    {
        SequenceClassificationConfig {
            model_type,
            model_resource,
            config_resource: Box::new(config_resource),
            vocab_resource: Box::new(vocab_resource),
            merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
            lower_case,
            strip_accents: strip_accents.into(),
            add_prefix_space: add_prefix_space.into(),
            device: Device::cuda_if_available(),
            kind: None,
        }
    }
}

#[cfg(feature = "remote")]
impl Default for SequenceClassificationConfig {
    /// Provides a default sentiment analysis configuration (DistilBERT fine-tuned on SST-2).
    fn default() -> SequenceClassificationConfig {
        SequenceClassificationConfig::new(
            ModelType::DistilBert,
            ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
                DistilBertModelResources::DISTIL_BERT_SST2,
            ))),
            RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2),
            RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2),
            None,
            true,
            None,
            None,
        )
    }
}

/// Abstraction holding one particular sequence classification model, for any of the supported base architectures.
#[allow(clippy::large_enum_variant)]
pub enum SequenceClassificationOption {
    /// Bert for Sequence Classification
    Bert(BertForSequenceClassification),
    /// DeBERTa for Sequence Classification
    Deberta(DebertaForSequenceClassification),
    /// DeBERTa V2 for Sequence Classification
    DebertaV2(DebertaV2ForSequenceClassification),
    /// DistilBert for Sequence Classification
    DistilBert(DistilBertModelClassifier),
    /// MobileBert for Sequence Classification
    MobileBert(MobileBertForSequenceClassification),
    /// Roberta for Sequence Classification
    Roberta(RobertaForSequenceClassification),
    /// XLM-Roberta for Sequence Classification
    XLMRoberta(RobertaForSequenceClassification),
    /// Albert for Sequence Classification
    Albert(AlbertForSequenceClassification),
    /// XLNet for Sequence Classification
    XLNet(XLNetForSequenceClassification),
    /// Bart for Sequence Classification
    Bart(BartForSequenceClassification),
    /// Reformer for Sequence Classification
    Reformer(ReformerForSequenceClassification),
    /// Longformer for Sequence Classification
    Longformer(LongformerForSequenceClassification),
    /// FNet for Sequence Classification
    FNet(FNetForSequenceClassification),
    /// ONNX model for Sequence Classification
    #[cfg(feature = "onnx")]
    ONNX(ONNXEncoder),
}

impl SequenceClassificationOption {
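    /// Instantiate a new sequence classification model of the supplied type.
    ///
    /// # Arguments
    ///
    /// * `config` - `SequenceClassificationConfig` containing the model type, resources and device placement.
    ///   Dispatches to a Torch or ONNX backend depending on the `model_resource` variant.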
    pub fn new(config: &SequenceClassificationConfig) -> Result<Self, RustBertError> {
        match config.model_resource {
            ModelResource::Torch(_) => Self::new_torch(config),
            #[cfg(feature = "onnx")]
            ModelResource::ONNX(_) => Self::new_onnx(config),
        }
    }

    fn new_torch(config: &SequenceClassificationConfig) -> Result<Self, RustBertError> {
        let device = config.device;
        let weights_path = config.model_resource.get_torch_local_path()?;
        let mut var_store = VarStore::new(device);
        let model_config =
            &ConfigOption::from_file(config.model_type, config.config_resource.get_local_path()?);
        let model_type = config.model_type;
        let model = match model_type {
            ModelType::Bert => {
                if let ConfigOption::Bert(config) = model_config {
                    Ok(Self::Bert(
                        BertForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a BertConfig for Bert!".to_string(),
                    ))
                }
            }
            ModelType::Deberta => {
                if let ConfigOption::Deberta(config) = model_config {
                    Ok(Self::Deberta(
                        DebertaForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a DebertaConfig for DeBERTa!".to_string(),
                    ))
                }
            }
            ModelType::DebertaV2 => {
                if let ConfigOption::DebertaV2(config) = model_config {
                    Ok(Self::DebertaV2(
                        DebertaV2ForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a DebertaV2Config for DeBERTa V2!".to_string(),
                    ))
                }
            }
            ModelType::DistilBert => {
                if let ConfigOption::DistilBert(config) = model_config {
                    Ok(Self::DistilBert(
                        DistilBertModelClassifier::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a DistilBertConfig for DistilBert!".to_string(),
                    ))
                }
            }
            ModelType::MobileBert => {
                if let ConfigOption::MobileBert(config) = model_config {
                    Ok(Self::MobileBert(
                        MobileBertForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a MobileBertConfig for MobileBert!".to_string(),
                    ))
                }
            }
            ModelType::Roberta => {
                if let ConfigOption::Roberta(config) = model_config {
                    Ok(Self::Roberta(
                        RobertaForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a RobertaConfig for Roberta!".to_string(),
                    ))
                }
            }
            ModelType::XLMRoberta => {
                if let ConfigOption::Roberta(config) = model_config {
                    Ok(Self::XLMRoberta(
                        RobertaForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a RobertaConfig for XLM-RoBERTa!".to_string(),
                    ))
                }
            }
            ModelType::Albert => {
                if let ConfigOption::Albert(config) = model_config {
                    Ok(Self::Albert(
                        AlbertForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply an AlbertConfig for Albert!".to_string(),
                    ))
                }
            }
            ModelType::XLNet => {
                if let ConfigOption::XLNet(config) = model_config {
                    Ok(Self::XLNet(
                        XLNetForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply an XLNetConfig for XLNet!".to_string(),
                    ))
                }
            }
            ModelType::Bart => {
                if let ConfigOption::Bart(config) = model_config {
                    Ok(Self::Bart(
                        BartForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a BartConfig for Bart!".to_string(),
                    ))
                }
            }
            ModelType::Reformer => {
                if let ConfigOption::Reformer(config) = model_config {
                    Ok(Self::Reformer(
                        ReformerForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a ReformerConfig for Reformer!".to_string(),
                    ))
                }
            }
            ModelType::Longformer => {
                if let ConfigOption::Longformer(config) = model_config {
                    Ok(Self::Longformer(
                        LongformerForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a LongformerConfig for Longformer!".to_string(),
                    ))
                }
            }
            ModelType::FNet => {
                if let ConfigOption::FNet(config) = model_config {
                    Ok(Self::FNet(
                        FNetForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a FNetConfig for FNet!".to_string(),
                    ))
                }
            }
            #[cfg(feature = "onnx")]
            ModelType::ONNX => Err(RustBertError::InvalidConfigurationError(
                "A `ModelType::ONNX` ModelType was provided in the configuration with `ModelResources::TORCH`, these are incompatible".to_string(),
            )),
            _ => Err(RustBertError::InvalidConfigurationError(format!(
                "Sequence Classification not implemented for {model_type:?}!",
            ))),
        }?;
        var_store.load(weights_path)?;
        cast_var_store(&mut var_store, config.kind, device);
        Ok(model)
    }

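    /// Instantiate a new ONNX sequence classification model.
    ///
    /// # Arguments
    ///
    /// * `config` - `SequenceClassificationConfig` whose `model_resource` points to ONNX model files (an encoder file is required).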
    #[cfg(feature = "onnx")]
    pub fn new_onnx(config: &SequenceClassificationConfig) -> Result<Self, RustBertError> {
        let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
        let environment = onnx_config.get_environment()?;
        let encoder_file = config
            .model_resource
            .get_onnx_local_paths()?
            .encoder_path
            .ok_or(RustBertError::InvalidConfigurationError(
                "An encoder file must be provided for sequence classification ONNX models."
                    .to_string(),
            ))?;

        Ok(Self::ONNX(ONNXEncoder::new(
            encoder_file,
            &environment,
            &onnx_config,
        )?))
    }

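    /// Returns the `ModelType` of the wrapped model.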
    pub fn model_type(&self) -> ModelType {
        match *self {
            Self::Bert(_) => ModelType::Bert,
            Self::Deberta(_) => ModelType::Deberta,
            Self::DebertaV2(_) => ModelType::DebertaV2,
            Self::Roberta(_) => ModelType::Roberta,
            Self::XLMRoberta(_) => ModelType::Roberta,
            Self::DistilBert(_) => ModelType::DistilBert,
            Self::MobileBert(_) => ModelType::MobileBert,
            Self::Albert(_) => ModelType::Albert,
            Self::XLNet(_) => ModelType::XLNet,
            Self::Bart(_) => ModelType::Bart,
            Self::Reformer(_) => ModelType::Reformer,
            Self::Longformer(_) => ModelType::Longformer,
            Self::FNet(_) => ModelType::FNet,
            #[cfg(feature = "onnx")]
            Self::ONNX(_) => ModelType::ONNX,
        }
    }

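    /// Interface method to the forward pass of the wrapped model, returning the classification logits.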
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Tensor {
        match *self {
            Self::Bart(ref model) => {
                model
                    .forward_t(
                        input_ids.expect("`input_ids` must be provided for BART models"),
                        mask,
                        None,
                        None,
                        None,
                        train,
                    )
                    .decoder_output
            }
            Self::Bert(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .logits
            }
            Self::Deberta(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .expect("Error in Deberta forward_t")
                    .logits
            }
            Self::DebertaV2(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .expect("Error in Deberta V2 forward_t")
                    .logits
            }
            Self::DistilBert(ref model) => {
                model
                    .forward_t(input_ids, mask, input_embeds, train)
                    .expect("Error in distilbert forward_t")
                    .logits
            }
            Self::MobileBert(ref model) => {
                model
                    .forward_t(input_ids, None, None, input_embeds, mask, train)
                    .expect("Error in mobilebert forward_t")
                    .logits
            }
            Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .logits
            }
            Self::Albert(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .logits
            }
            Self::XLNet(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        None,
                        None,
                        None,
                        token_type_ids,
                        input_embeds,
                        train,
                    )
                    .logits
            }
            Self::Reformer(ref model) => {
                model
                    .forward_t(input_ids, None, None, mask, None, train)
                    .expect("Error in Reformer forward pass.")
                    .logits
            }
            Self::Longformer(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        None,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .expect("Error in Longformer forward pass.")
                    .logits
            }
            Self::FNet(ref model) => {
                model
                    .forward_t(input_ids, token_type_ids, position_ids, input_embeds, train)
                    .expect("Error in FNet forward pass.")
                    .logits
            }
            #[cfg(feature = "onnx")]
            Self::ONNX(ref model) => {
                // ONNX encoders require an attention mask; default to all ones over the provided input ids.
                let attention_mask = input_ids
                    .expect("`input_ids` must be provided for ONNX models")
                    .ones_like();
                model
                    .forward(
                        input_ids,
                        Some(&attention_mask),
                        token_type_ids,
                        position_ids,
                        input_embeds,
                    )
                    .expect("Error in ONNX forward pass.")
                    .logits
                    .unwrap()
            }
        }
    }
}

/// Pipeline for sequence classification (e.g. sentiment analysis).
pub struct SequenceClassificationModel {
    tokenizer: TokenizerOption,
    sequence_classifier: SequenceClassificationOption,
    label_mapping: HashMap<i64, String>,
    device: Device,
    max_length: usize,
}

impl SequenceClassificationModel {
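    /// Build a new `SequenceClassificationModel`.
    ///
    /// # Arguments
    ///
    /// * `config` - `SequenceClassificationConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the `remote` feature is enabled (so the default DistilBERT
    /// SST-2 resources can be fetched) and `anyhow` is available for error handling:
    ///
    /// ```no_run
    /// use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
    ///
    /// fn main() -> anyhow::Result<()> {
    ///     // Downloads (or reuses cached) default resources and loads the model.
    ///     let model = SequenceClassificationModel::new(Default::default())?;
    ///     Ok(())
    /// }
    /// ```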
    pub fn new(
        config: SequenceClassificationConfig,
    ) -> Result<SequenceClassificationModel, RustBertError> {
        let vocab_path = config.vocab_resource.get_local_path()?;
        let merges_path = config
            .merges_resource
            .as_ref()
            .map(|resource| resource.get_local_path())
            .transpose()?;

        let tokenizer = TokenizerOption::from_file(
            config.model_type,
            vocab_path.to_str().unwrap(),
            merges_path.as_deref().map(|path| path.to_str().unwrap()),
            config.lower_case,
            config.strip_accents,
            config.add_prefix_space,
        )?;
        Self::new_with_tokenizer(config, tokenizer)
    }

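    /// Build a new `SequenceClassificationModel` with a provided tokenizer.
    ///
    /// # Arguments
    ///
    /// * `config` - `SequenceClassificationConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
    /// * `tokenizer` - `TokenizerOption` tokenizer to use for the pipeline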
    pub fn new_with_tokenizer(
        config: SequenceClassificationConfig,
        tokenizer: TokenizerOption,
    ) -> Result<SequenceClassificationModel, RustBertError> {
        let config_path = config.config_resource.get_local_path()?;
        let sequence_classifier = SequenceClassificationOption::new(&config)?;

        let model_config = ConfigOption::from_file(config.model_type, config_path);
        let max_length = model_config
            .get_max_len()
            .map(|v| v as usize)
            .unwrap_or(usize::MAX);
        let label_mapping = model_config.get_label_mapping().clone();
        let device = get_device(config.model_resource, config.device);
        Ok(SequenceClassificationModel {
            tokenizer,
            sequence_classifier,
            label_mapping,
            device,
            max_length,
        })
    }

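    /// Get a reference to the model tokenizer.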
    pub fn get_tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }

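    /// Get a mutable reference to the model tokenizer.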
    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        &mut self.tokenizer
    }
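
    /// Classify a batch of texts.
    ///
    /// # Arguments
    ///
    /// * `input` - slice of texts to classify
    ///
    /// # Returns
    ///
    /// * `Vec<Label>` containing the highest-scoring label for each input text
    ///
    /// # Example
    ///
    /// A hedged sketch using the default (remote) SST-2 sentiment resources; it assumes the
    /// `remote` feature and `anyhow` are available:
    ///
    /// ```no_run
    /// use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
    ///
    /// fn main() -> anyhow::Result<()> {
    ///     let model = SequenceClassificationModel::new(Default::default())?;
    ///     let input = [
    ///         "Probably my all-time favorite movie.",
    ///         "This was a waste of my time.",
    ///     ];
    ///     // One Label per input sentence, with the softmax score of the predicted class.
    ///     for label in model.predict(&input) {
    ///         println!("sentence {}: {} ({:.3})", label.sentence, label.text, label.score);
    ///     }
    ///     Ok(())
    /// }
    /// ```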
    pub fn predict<'a, S>(&self, input: S) -> Vec<Label>
    where
        S: AsRef<[&'a str]>,
    {
        let (input_ids, token_type_ids) =
            self.tokenizer
                .tokenize_and_pad(input.as_ref(), self.max_length, self.device);
        let output = no_grad(|| {
            let output = self.sequence_classifier.forward_t(
                Some(&input_ids),
                None,
                Some(&token_type_ids),
                None,
                None,
                false,
            );
            output.softmax(-1, Kind::Float).detach().to(Device::Cpu)
        });
        let label_indices = output.as_ref().argmax(-1, true).squeeze_dim(1);
        let scores = output
            .gather(1, &label_indices.unsqueeze(-1), false)
            .squeeze_dim(1);
        let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
        let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();

        let mut labels: Vec<Label> = vec![];
        for sentence_idx in 0..label_indices.len() {
            let label_string = self
                .label_mapping
                .get(&label_indices[sentence_idx])
                .unwrap()
                .clone();
            let label = Label {
                text: label_string,
                score: scores[sentence_idx],
                id: label_indices[sentence_idx],
                sentence: sentence_idx,
            };
            labels.push(label)
        }
        labels
    }

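    /// Multi-label classification of a batch of texts.
    ///
    /// # Arguments
    ///
    /// * `input` - slice of texts to classify
    /// * `threshold` - minimum sigmoid activation for a label to be included in the output
    ///
    /// # Returns
    ///
    /// * `Vec<Vec<Label>>` containing, for each input text, all labels activated above the threshold
    ///
    /// # Example
    ///
    /// A hedged sketch; the default resources point to a single-label sentiment model, so this is
    /// illustrative only and a genuinely multi-label checkpoint would be supplied via a custom
    /// `SequenceClassificationConfig`:
    ///
    /// ```no_run
    /// use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
    ///
    /// fn main() -> anyhow::Result<()> {
    ///     let model = SequenceClassificationModel::new(Default::default())?;
    ///     let input = ["I love the product, but the delivery was very slow."];
    ///     // Keep every label whose sigmoid activation exceeds 0.5.
    ///     let labels = model.predict_multilabel(&input, 0.5)?;
    ///     println!("{labels:?}");
    ///     Ok(())
    /// }
    /// ```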
    pub fn predict_multilabel(
        &self,
        input: &[&str],
        threshold: f64,
    ) -> Result<Vec<Vec<Label>>, RustBertError> {
        let (input_ids, token_type_ids) =
            self.tokenizer
                .tokenize_and_pad(input.as_ref(), self.max_length, self.device);
        let output = no_grad(|| {
            let output = self.sequence_classifier.forward_t(
                Some(&input_ids),
                None,
                Some(&token_type_ids),
                None,
                None,
                false,
            );
            output.sigmoid().detach().to(Device::Cpu)
        });
        let label_indices = output.as_ref().ge(threshold).nonzero();

        let mut labels: Vec<Vec<Label>> = vec![];
        let mut sequence_labels: Vec<Label> = vec![];

        for sentence_idx in 0..label_indices.size()[0] {
            let label_index_tensor = label_indices.get(sentence_idx);
            let sentence_label = label_index_tensor
                .iter::<i64>()
                .unwrap()
                .collect::<Vec<i64>>();
            let (sentence, id) = (sentence_label[0], sentence_label[1]);
            if sentence as usize > labels.len() {
                labels.push(sequence_labels);
                sequence_labels = vec![];
            }
            let score = output.double_value(sentence_label.as_slice());
            let label_string = self.label_mapping.get(&id).unwrap().to_owned();
            let label = Label {
                text: label_string,
                score,
                id,
                sentence: sentence as usize,
            };
            sequence_labels.push(label);
        }
        if !sequence_labels.is_empty() {
            labels.push(sequence_labels);
        }
        Ok(labels)
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    #[ignore] // no need to run: compilation is enough to verify the model is Send
    fn test() {
        let config = SequenceClassificationConfig::default();
        let _: Box<dyn Send> = Box::new(SequenceClassificationModel::new(config));
    }
}