// trustformers_models/bert/tasks.rs
1#![allow(dead_code)]
2
3use crate::bert::config::BertConfig;
4use crate::bert::model::BertModel;
5use std::io::Read;
6use trustformers_core::device::Device;
7use trustformers_core::errors::Result;
8use trustformers_core::layers::Linear;
9use trustformers_core::tensor::Tensor;
10use trustformers_core::traits::{Layer, Model, TokenizedInput};
11
/// BERT with a linear classification head on the pooled `[CLS]` output,
/// for whole-sequence tasks (e.g. sentiment, NLI).
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct BertForSequenceClassification {
    /// Backbone encoder producing hidden states and a pooled output.
    bert: BertModel,
    /// Linear layer mapping `hidden_size` -> `num_labels`.
    classifier: Linear,
    /// Number of target classes; kept for introspection only.
    #[allow(dead_code)]
    num_labels: usize,
    /// Device all submodules were created on.
    device: Device,
}
21
22impl BertForSequenceClassification {
23    pub fn new(config: BertConfig, num_labels: usize) -> Result<Self> {
24        Self::new_with_device(config, num_labels, Device::CPU)
25    }
26
27    pub fn new_with_device(config: BertConfig, num_labels: usize, device: Device) -> Result<Self> {
28        let bert = BertModel::new_with_device(config.clone(), device)?;
29        let classifier = Linear::new_with_device(config.hidden_size, num_labels, true, device);
30
31        Ok(Self {
32            bert,
33            classifier,
34            num_labels,
35            device,
36        })
37    }
38
39    pub fn device(&self) -> Device {
40        self.device
41    }
42}
43
/// Output of [`BertForSequenceClassification::forward`].
#[derive(Debug)]
pub struct SequenceClassifierOutput {
    /// Per-class scores, one row per input sequence.
    pub logits: Tensor,
    /// Final encoder hidden states, if retained.
    pub hidden_states: Option<Tensor>,
}
49
50impl Model for BertForSequenceClassification {
51    type Config = BertConfig;
52    type Input = TokenizedInput;
53    type Output = SequenceClassifierOutput;
54
55    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
56        let bert_output = self.bert.forward(input)?;
57
58        let pooled_output = bert_output.pooler_output.ok_or_else(|| {
59            trustformers_core::errors::TrustformersError::model_error(
60                "BertForSequenceClassification requires pooler output".to_string(),
61            )
62        })?;
63
64        let logits = self.classifier.forward(pooled_output)?;
65
66        Ok(SequenceClassifierOutput {
67            logits,
68            hidden_states: Some(bert_output.last_hidden_state),
69        })
70    }
71
72    fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
73        self.bert.load_pretrained(reader)
74    }
75
76    fn get_config(&self) -> &Self::Config {
77        self.bert.get_config()
78    }
79
80    fn num_parameters(&self) -> usize {
81        self.bert.num_parameters() + self.classifier.parameter_count()
82    }
83}
84
/// BERT with a masked-language-modeling head on top of the encoder.
#[derive(Debug, Clone)]
pub struct BertForMaskedLM {
    /// Backbone encoder.
    bert: BertModel,
    /// LM head: dense -> GELU -> LayerNorm -> vocabulary decoder.
    cls: BertLMHead,
    /// Device all submodules were created on.
    device: Device,
}
91
92impl BertForMaskedLM {
93    pub fn new(config: BertConfig) -> Result<Self> {
94        Self::new_with_device(config, Device::CPU)
95    }
96
97    pub fn new_with_device(config: BertConfig, device: Device) -> Result<Self> {
98        let bert = BertModel::new_with_device(config.clone(), device)?;
99        let cls = BertLMHead::new_with_device(&config, device)?;
100
101        Ok(Self { bert, cls, device })
102    }
103
104    pub fn device(&self) -> Device {
105        self.device
106    }
107}
108
/// Prediction head for masked language modeling:
/// dense transform -> GELU -> LayerNorm -> linear decoder over the vocabulary.
#[derive(Debug, Clone)]
struct BertLMHead {
    /// `hidden_size -> hidden_size` transform applied before normalization.
    dense: Linear,
    /// LayerNorm over the transformed hidden states.
    layer_norm: trustformers_core::layers::LayerNorm,
    /// `hidden_size -> vocab_size` projection producing token logits.
    decoder: Linear,
    /// Device all sublayers were created on.
    device: Device,
}
116
117impl BertLMHead {
118    fn new(config: &BertConfig) -> Result<Self> {
119        Self::new_with_device(config, Device::CPU)
120    }
121
122    fn new_with_device(config: &BertConfig, device: Device) -> Result<Self> {
123        Ok(Self {
124            dense: Linear::new_with_device(config.hidden_size, config.hidden_size, true, device),
125            layer_norm: trustformers_core::layers::LayerNorm::new_with_device(
126                vec![config.hidden_size],
127                config.layer_norm_eps,
128                device,
129            )?,
130            decoder: Linear::new_with_device(config.hidden_size, config.vocab_size, true, device),
131            device,
132        })
133    }
134
135    fn device(&self) -> Device {
136        self.device
137    }
138
139    fn forward(&self, hidden_states: Tensor) -> Result<Tensor> {
140        let hidden_states = self.dense.forward(hidden_states)?;
141        let hidden_states = trustformers_core::ops::activations::gelu(&hidden_states)?;
142        let hidden_states = self.layer_norm.forward(hidden_states)?;
143        self.decoder.forward(hidden_states)
144    }
145
146    fn parameter_count(&self) -> usize {
147        self.dense.parameter_count()
148            + self.layer_norm.parameter_count()
149            + self.decoder.parameter_count()
150    }
151}
152
/// Output of [`BertForMaskedLM::forward`].
#[derive(Debug)]
pub struct MaskedLMOutput {
    /// Vocabulary scores for every token position.
    pub logits: Tensor,
    /// Final encoder hidden states, if retained.
    pub hidden_states: Option<Tensor>,
}
158
159impl Model for BertForMaskedLM {
160    type Config = BertConfig;
161    type Input = TokenizedInput;
162    type Output = MaskedLMOutput;
163
164    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
165        let bert_output = self.bert.forward(input)?;
166        let prediction_scores = self.cls.forward(bert_output.last_hidden_state.clone())?;
167
168        Ok(MaskedLMOutput {
169            logits: prediction_scores,
170            hidden_states: Some(bert_output.last_hidden_state),
171        })
172    }
173
174    fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
175        self.bert.load_pretrained(reader)
176    }
177
178    fn get_config(&self) -> &Self::Config {
179        self.bert.get_config()
180    }
181
182    fn num_parameters(&self) -> usize {
183        self.bert.num_parameters() + self.cls.parameter_count()
184    }
185}
186
/// BERT with a per-token linear classification head, for tagging tasks
/// (e.g. NER, POS).
#[derive(Debug, Clone)]
pub struct BertForTokenClassification {
    /// Backbone encoder.
    bert: BertModel,
    /// Linear layer mapping `hidden_size` -> `num_labels`, applied per token.
    classifier: Linear,
    /// Number of target classes; kept for introspection only.
    #[allow(dead_code)]
    num_labels: usize,
    /// Device all submodules were created on.
    device: Device,
}
195
196impl BertForTokenClassification {
197    pub fn new(config: BertConfig, num_labels: usize) -> Result<Self> {
198        Self::new_with_device(config, num_labels, Device::CPU)
199    }
200
201    pub fn new_with_device(config: BertConfig, num_labels: usize, device: Device) -> Result<Self> {
202        let bert = BertModel::new_with_device(config.clone(), device)?;
203        let classifier = Linear::new_with_device(config.hidden_size, num_labels, true, device);
204
205        Ok(Self {
206            bert,
207            classifier,
208            num_labels,
209            device,
210        })
211    }
212
213    pub fn device(&self) -> Device {
214        self.device
215    }
216}
217
/// Output of [`BertForTokenClassification::forward`].
#[derive(Debug)]
pub struct TokenClassifierOutput {
    /// Per-token, per-class scores.
    pub logits: Tensor,
    /// Final encoder hidden states, if retained.
    pub hidden_states: Option<Tensor>,
}
223
224impl Model for BertForTokenClassification {
225    type Config = BertConfig;
226    type Input = TokenizedInput;
227    type Output = TokenClassifierOutput;
228
229    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
230        let bert_output = self.bert.forward(input)?;
231        let sequence_output = bert_output.last_hidden_state;
232
233        let logits = self.classifier.forward(sequence_output.clone())?;
234
235        Ok(TokenClassifierOutput {
236            logits,
237            hidden_states: Some(sequence_output),
238        })
239    }
240
241    fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
242        self.bert.load_pretrained(reader)
243    }
244
245    fn get_config(&self) -> &Self::Config {
246        self.bert.get_config()
247    }
248
249    fn num_parameters(&self) -> usize {
250        self.bert.num_parameters() + self.classifier.parameter_count()
251    }
252}
253
/// BERT with a span-prediction head for extractive question answering:
/// two logits per token (answer start / answer end).
#[derive(Debug, Clone)]
pub struct BertForQuestionAnswering {
    /// Backbone encoder.
    bert: BertModel,
    /// Linear layer mapping `hidden_size` -> 2 (start and end scores).
    qa_outputs: Linear,
    /// Device all submodules were created on.
    device: Device,
}
260
261impl BertForQuestionAnswering {
262    pub fn new(config: BertConfig) -> Result<Self> {
263        Self::new_with_device(config, Device::CPU)
264    }
265
266    pub fn new_with_device(config: BertConfig, device: Device) -> Result<Self> {
267        let bert = BertModel::new_with_device(config.clone(), device)?;
268        // QA outputs has 2 classes: start and end positions
269        let qa_outputs = Linear::new_with_device(config.hidden_size, 2, true, device);
270
271        Ok(Self {
272            bert,
273            qa_outputs,
274            device,
275        })
276    }
277
278    pub fn device(&self) -> Device {
279        self.device
280    }
281}
282
/// Output of [`BertForQuestionAnswering::forward`].
#[derive(Debug)]
pub struct QuestionAnsweringOutput {
    /// Per-token scores for the answer-span start position.
    pub start_logits: Tensor,
    /// Per-token scores for the answer-span end position.
    pub end_logits: Tensor,
    /// Final encoder hidden states, if retained.
    pub hidden_states: Option<Tensor>,
}
289
290impl Model for BertForQuestionAnswering {
291    type Config = BertConfig;
292    type Input = TokenizedInput;
293    type Output = QuestionAnsweringOutput;
294
295    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
296        let bert_output = self.bert.forward(input)?;
297        let sequence_output = bert_output.last_hidden_state;
298
299        let logits = self.qa_outputs.forward(sequence_output.clone())?;
300
301        // Split logits into start and end logits along the last dimension (dimension with size 2)
302        let split_logits = logits.split(logits.shape().len() - 1, 1)?;
303        if split_logits.len() != 2 {
304            return Err(trustformers_core::errors::TrustformersError::model_error(
305                "Expected 2 QA outputs (start and end), got different number".to_string(),
306            ));
307        }
308
309        let start_logits = split_logits[0].clone();
310        let end_logits = split_logits[1].clone();
311
312        Ok(QuestionAnsweringOutput {
313            start_logits,
314            end_logits,
315            hidden_states: Some(sequence_output),
316        })
317    }
318
319    fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
320        self.bert.load_pretrained(reader)
321    }
322
323    fn get_config(&self) -> &Self::Config {
324        self.bert.get_config()
325    }
326
327    fn num_parameters(&self) -> usize {
328        self.bert.num_parameters() + self.qa_outputs.parameter_count()
329    }
330}