trustformers_models/bert/
tasks.rs1#![allow(dead_code)]
2
3use crate::bert::config::BertConfig;
4use crate::bert::model::BertModel;
5use std::io::Read;
6use trustformers_core::device::Device;
7use trustformers_core::errors::Result;
8use trustformers_core::layers::Linear;
9use trustformers_core::tensor::Tensor;
10use trustformers_core::traits::{Layer, Model, TokenizedInput};
11
12#[derive(Debug, Clone)]
13#[allow(dead_code)]
14pub struct BertForSequenceClassification {
15 bert: BertModel,
16 classifier: Linear,
17 #[allow(dead_code)]
18 num_labels: usize,
19 device: Device,
20}
21
22impl BertForSequenceClassification {
23 pub fn new(config: BertConfig, num_labels: usize) -> Result<Self> {
24 Self::new_with_device(config, num_labels, Device::CPU)
25 }
26
27 pub fn new_with_device(config: BertConfig, num_labels: usize, device: Device) -> Result<Self> {
28 let bert = BertModel::new_with_device(config.clone(), device)?;
29 let classifier = Linear::new_with_device(config.hidden_size, num_labels, true, device);
30
31 Ok(Self {
32 bert,
33 classifier,
34 num_labels,
35 device,
36 })
37 }
38
39 pub fn device(&self) -> Device {
40 self.device
41 }
42}
43
44#[derive(Debug)]
45pub struct SequenceClassifierOutput {
46 pub logits: Tensor,
47 pub hidden_states: Option<Tensor>,
48}
49
50impl Model for BertForSequenceClassification {
51 type Config = BertConfig;
52 type Input = TokenizedInput;
53 type Output = SequenceClassifierOutput;
54
55 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
56 let bert_output = self.bert.forward(input)?;
57
58 let pooled_output = bert_output.pooler_output.ok_or_else(|| {
59 trustformers_core::errors::TrustformersError::model_error(
60 "BertForSequenceClassification requires pooler output".to_string(),
61 )
62 })?;
63
64 let logits = self.classifier.forward(pooled_output)?;
65
66 Ok(SequenceClassifierOutput {
67 logits,
68 hidden_states: Some(bert_output.last_hidden_state),
69 })
70 }
71
72 fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
73 self.bert.load_pretrained(reader)
74 }
75
76 fn get_config(&self) -> &Self::Config {
77 self.bert.get_config()
78 }
79
80 fn num_parameters(&self) -> usize {
81 self.bert.num_parameters() + self.classifier.parameter_count()
82 }
83}
84
85#[derive(Debug, Clone)]
86pub struct BertForMaskedLM {
87 bert: BertModel,
88 cls: BertLMHead,
89 device: Device,
90}
91
92impl BertForMaskedLM {
93 pub fn new(config: BertConfig) -> Result<Self> {
94 Self::new_with_device(config, Device::CPU)
95 }
96
97 pub fn new_with_device(config: BertConfig, device: Device) -> Result<Self> {
98 let bert = BertModel::new_with_device(config.clone(), device)?;
99 let cls = BertLMHead::new_with_device(&config, device)?;
100
101 Ok(Self { bert, cls, device })
102 }
103
104 pub fn device(&self) -> Device {
105 self.device
106 }
107}
108
109#[derive(Debug, Clone)]
110struct BertLMHead {
111 dense: Linear,
112 layer_norm: trustformers_core::layers::LayerNorm,
113 decoder: Linear,
114 device: Device,
115}
116
117impl BertLMHead {
118 fn new(config: &BertConfig) -> Result<Self> {
119 Self::new_with_device(config, Device::CPU)
120 }
121
122 fn new_with_device(config: &BertConfig, device: Device) -> Result<Self> {
123 Ok(Self {
124 dense: Linear::new_with_device(config.hidden_size, config.hidden_size, true, device),
125 layer_norm: trustformers_core::layers::LayerNorm::new_with_device(
126 vec![config.hidden_size],
127 config.layer_norm_eps,
128 device,
129 )?,
130 decoder: Linear::new_with_device(config.hidden_size, config.vocab_size, true, device),
131 device,
132 })
133 }
134
135 fn device(&self) -> Device {
136 self.device
137 }
138
139 fn forward(&self, hidden_states: Tensor) -> Result<Tensor> {
140 let hidden_states = self.dense.forward(hidden_states)?;
141 let hidden_states = trustformers_core::ops::activations::gelu(&hidden_states)?;
142 let hidden_states = self.layer_norm.forward(hidden_states)?;
143 self.decoder.forward(hidden_states)
144 }
145
146 fn parameter_count(&self) -> usize {
147 self.dense.parameter_count()
148 + self.layer_norm.parameter_count()
149 + self.decoder.parameter_count()
150 }
151}
152
153#[derive(Debug)]
154pub struct MaskedLMOutput {
155 pub logits: Tensor,
156 pub hidden_states: Option<Tensor>,
157}
158
159impl Model for BertForMaskedLM {
160 type Config = BertConfig;
161 type Input = TokenizedInput;
162 type Output = MaskedLMOutput;
163
164 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
165 let bert_output = self.bert.forward(input)?;
166 let prediction_scores = self.cls.forward(bert_output.last_hidden_state.clone())?;
167
168 Ok(MaskedLMOutput {
169 logits: prediction_scores,
170 hidden_states: Some(bert_output.last_hidden_state),
171 })
172 }
173
174 fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
175 self.bert.load_pretrained(reader)
176 }
177
178 fn get_config(&self) -> &Self::Config {
179 self.bert.get_config()
180 }
181
182 fn num_parameters(&self) -> usize {
183 self.bert.num_parameters() + self.cls.parameter_count()
184 }
185}
186
187#[derive(Debug, Clone)]
188pub struct BertForTokenClassification {
189 bert: BertModel,
190 classifier: Linear,
191 #[allow(dead_code)]
192 num_labels: usize,
193 device: Device,
194}
195
196impl BertForTokenClassification {
197 pub fn new(config: BertConfig, num_labels: usize) -> Result<Self> {
198 Self::new_with_device(config, num_labels, Device::CPU)
199 }
200
201 pub fn new_with_device(config: BertConfig, num_labels: usize, device: Device) -> Result<Self> {
202 let bert = BertModel::new_with_device(config.clone(), device)?;
203 let classifier = Linear::new_with_device(config.hidden_size, num_labels, true, device);
204
205 Ok(Self {
206 bert,
207 classifier,
208 num_labels,
209 device,
210 })
211 }
212
213 pub fn device(&self) -> Device {
214 self.device
215 }
216}
217
218#[derive(Debug)]
219pub struct TokenClassifierOutput {
220 pub logits: Tensor,
221 pub hidden_states: Option<Tensor>,
222}
223
224impl Model for BertForTokenClassification {
225 type Config = BertConfig;
226 type Input = TokenizedInput;
227 type Output = TokenClassifierOutput;
228
229 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
230 let bert_output = self.bert.forward(input)?;
231 let sequence_output = bert_output.last_hidden_state;
232
233 let logits = self.classifier.forward(sequence_output.clone())?;
234
235 Ok(TokenClassifierOutput {
236 logits,
237 hidden_states: Some(sequence_output),
238 })
239 }
240
241 fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
242 self.bert.load_pretrained(reader)
243 }
244
245 fn get_config(&self) -> &Self::Config {
246 self.bert.get_config()
247 }
248
249 fn num_parameters(&self) -> usize {
250 self.bert.num_parameters() + self.classifier.parameter_count()
251 }
252}
253
254#[derive(Debug, Clone)]
255pub struct BertForQuestionAnswering {
256 bert: BertModel,
257 qa_outputs: Linear,
258 device: Device,
259}
260
261impl BertForQuestionAnswering {
262 pub fn new(config: BertConfig) -> Result<Self> {
263 Self::new_with_device(config, Device::CPU)
264 }
265
266 pub fn new_with_device(config: BertConfig, device: Device) -> Result<Self> {
267 let bert = BertModel::new_with_device(config.clone(), device)?;
268 let qa_outputs = Linear::new_with_device(config.hidden_size, 2, true, device);
270
271 Ok(Self {
272 bert,
273 qa_outputs,
274 device,
275 })
276 }
277
278 pub fn device(&self) -> Device {
279 self.device
280 }
281}
282
283#[derive(Debug)]
284pub struct QuestionAnsweringOutput {
285 pub start_logits: Tensor,
286 pub end_logits: Tensor,
287 pub hidden_states: Option<Tensor>,
288}
289
290impl Model for BertForQuestionAnswering {
291 type Config = BertConfig;
292 type Input = TokenizedInput;
293 type Output = QuestionAnsweringOutput;
294
295 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
296 let bert_output = self.bert.forward(input)?;
297 let sequence_output = bert_output.last_hidden_state;
298
299 let logits = self.qa_outputs.forward(sequence_output.clone())?;
300
301 let split_logits = logits.split(logits.shape().len() - 1, 1)?;
303 if split_logits.len() != 2 {
304 return Err(trustformers_core::errors::TrustformersError::model_error(
305 "Expected 2 QA outputs (start and end), got different number".to_string(),
306 ));
307 }
308
309 let start_logits = split_logits[0].clone();
310 let end_logits = split_logits[1].clone();
311
312 Ok(QuestionAnsweringOutput {
313 start_logits,
314 end_logits,
315 hidden_states: Some(sequence_output),
316 })
317 }
318
319 fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
320 self.bert.load_pretrained(reader)
321 }
322
323 fn get_config(&self) -> &Self::Config {
324 self.bert.get_config()
325 }
326
327 fn num_parameters(&self) -> usize {
328 self.bert.num_parameters() + self.qa_outputs.parameter_count()
329 }
330}