1use crate::bert::config::BertConfig;
2use crate::bert::layers::{BertEmbeddings, BertEncoder, BertPooler};
3use crate::weight_loading::{WeightDataType, WeightFormat, WeightLoadingConfig};
4use scirs2_core::ndarray::{ArrayD, IxDyn}; use std::collections::HashMap;
6use std::io::Read;
7use trustformers_core::device::Device;
8use trustformers_core::errors::{Result, TrustformersError};
9use trustformers_core::tensor::Tensor;
10use trustformers_core::traits::{Model, TokenizedInput};
11
/// Full BERT model: embedding lookup → transformer encoder → optional pooler.
#[derive(Debug, Clone)]
pub struct BertModel {
    // Architecture hyperparameters (hidden size, layer count, vocab sizes, …).
    config: BertConfig,
    // Token / position / segment embeddings.
    embeddings: BertEmbeddings,
    // Stack of transformer encoder layers.
    encoder: BertEncoder,
    // Pooling head; always `Some` after construction (see `new_with_device`),
    // but forward currently does not invoke it.
    pooler: Option<BertPooler>,
    // Device all sub-modules were created on.
    device: Device,
}
20
21impl BertModel {
22 pub fn new(config: BertConfig) -> Result<Self> {
23 Self::new_with_device(config, Device::CPU)
24 }
25
26 pub fn new_with_device(config: BertConfig, device: Device) -> Result<Self> {
27 let embeddings = BertEmbeddings::new_with_device(&config, device)?;
28 let encoder = BertEncoder::new_with_device(&config, device)?;
29 let pooler = Some(BertPooler::new_with_device(&config, device)?);
30
31 Ok(Self {
32 config,
33 embeddings,
34 encoder,
35 pooler,
36 device,
37 })
38 }
39
40 pub fn device(&self) -> Device {
41 self.device
42 }
43
44 pub fn forward_with_embeddings(
45 &self,
46 input_ids: Vec<u32>,
47 attention_mask: Option<Vec<u8>>,
48 token_type_ids: Option<Vec<u32>>,
49 ) -> Result<BertModelOutput> {
50 let embeddings = self.embeddings.forward(input_ids.clone(), token_type_ids)?;
51
52 let batch_size = 1;
54 let seq_len = input_ids.len();
55 let hidden_size = self.config.hidden_size;
56
57 let embeddings = match embeddings {
58 trustformers_core::tensor::Tensor::F32(arr) => {
59 let reshaped = arr
60 .to_shape(IxDyn(&[batch_size, seq_len, hidden_size]))
61 .map_err(|e| {
62 trustformers_core::errors::TrustformersError::shape_error(e.to_string())
63 })?
64 .to_owned();
65 trustformers_core::tensor::Tensor::F32(reshaped)
66 },
67 _ => {
68 return Err(
69 trustformers_core::errors::TrustformersError::tensor_op_error(
70 "Unsupported tensor type in embeddings",
71 "BertModel::forward_with_embeddings",
72 ),
73 )
74 },
75 };
76
77 let attention_mask_tensor = if let Some(mask) = attention_mask {
78 let mask_f32: Vec<f32> = mask.iter().map(|&m| m as f32).collect();
79 let shape = vec![1, 1, 1, mask_f32.len()];
80 Some(Tensor::F32(
81 ArrayD::from_shape_vec(IxDyn(&shape), mask_f32).map_err(|e| {
82 trustformers_core::errors::TrustformersError::shape_error(e.to_string())
83 })?,
84 ))
85 } else {
86 None
87 };
88
89 let encoder_output = self.encoder.forward(embeddings, attention_mask_tensor)?;
90
91 let pooler_output = None;
93
94 Ok(BertModelOutput {
95 last_hidden_state: encoder_output,
96 pooler_output,
97 })
98 }
99}
100
/// Result of a [`BertModel`] forward pass.
#[derive(Debug)]
pub struct BertModelOutput {
    /// Final hidden states produced by the encoder.
    pub last_hidden_state: Tensor,
    /// Pooled representation; currently always `None` because the forward
    /// path never invokes the pooler.
    pub pooler_output: Option<Tensor>,
}
106
107impl Model for BertModel {
108 type Config = BertConfig;
109 type Input = TokenizedInput;
110 type Output = BertModelOutput;
111
112 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
113 self.forward_with_embeddings(
114 input.input_ids,
115 Some(input.attention_mask),
116 input.token_type_ids,
117 )
118 }
119
120 fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
121 let mut buffer = Vec::new();
127 reader.read_to_end(&mut buffer).map_err(|e| {
128 TrustformersError::weight_load_error(format!("Failed to read model data: {}", e))
129 })?;
130
131 self.load_weights_from_buffer(&buffer)
133 }
134
135 fn get_config(&self) -> &<BertModel as Model>::Config {
136 &self.config
137 }
138
139 fn num_parameters(&self) -> usize {
140 let embeddings_params = self.embeddings.parameter_count();
141 let encoder_params = self.encoder.parameter_count();
142 let pooler_params =
143 if let Some(ref pooler) = self.pooler { pooler.parameter_count() } else { 0 };
144
145 embeddings_params + encoder_params + pooler_params
146 }
147}
148
149impl BertModel {
150 fn load_weights_from_buffer(&mut self, buffer: &[u8]) -> Result<()> {
152 let _config = WeightLoadingConfig {
154 format: Some(WeightFormat::HuggingFaceBin),
155 lazy_loading: false,
156 memory_mapped: false,
157 streaming: false,
158 device: "cpu".to_string(),
159 dtype: WeightDataType::Float32,
160 quantization: None,
161 cache_dir: None,
162 verify_checksums: false,
163 distributed: None,
164 };
165
166 let weights = self.extract_bert_weights(buffer)?;
168
169 self.load_embeddings_weights(&weights)?;
171 self.load_encoder_weights(&weights)?;
172 self.load_pooler_weights(&weights)?;
173
174 Ok(())
175 }
176
177 fn extract_bert_weights(&self, buffer: &[u8]) -> Result<HashMap<String, Tensor>> {
179 let mut weights = HashMap::new();
180
181 let bert_layer_specs = vec![
183 (
185 "embeddings.word_embeddings.weight",
186 vec![self.config.vocab_size, self.config.hidden_size],
187 ),
188 (
189 "embeddings.position_embeddings.weight",
190 vec![self.config.max_position_embeddings, self.config.hidden_size],
191 ),
192 (
193 "embeddings.token_type_embeddings.weight",
194 vec![self.config.type_vocab_size, self.config.hidden_size],
195 ),
196 ("embeddings.LayerNorm.weight", vec![self.config.hidden_size]),
197 ("embeddings.LayerNorm.bias", vec![self.config.hidden_size]),
198 ];
199
200 for layer_idx in 0..self.config.num_hidden_layers {
202 let layer_specs = vec![
203 (
205 format!("encoder.layer.{}.attention.self.query.weight", layer_idx),
206 vec![self.config.hidden_size, self.config.hidden_size],
207 ),
208 (
209 format!("encoder.layer.{}.attention.self.query.bias", layer_idx),
210 vec![self.config.hidden_size],
211 ),
212 (
213 format!("encoder.layer.{}.attention.self.key.weight", layer_idx),
214 vec![self.config.hidden_size, self.config.hidden_size],
215 ),
216 (
217 format!("encoder.layer.{}.attention.self.key.bias", layer_idx),
218 vec![self.config.hidden_size],
219 ),
220 (
221 format!("encoder.layer.{}.attention.self.value.weight", layer_idx),
222 vec![self.config.hidden_size, self.config.hidden_size],
223 ),
224 (
225 format!("encoder.layer.{}.attention.self.value.bias", layer_idx),
226 vec![self.config.hidden_size],
227 ),
228 (
229 format!("encoder.layer.{}.attention.output.dense.weight", layer_idx),
230 vec![self.config.hidden_size, self.config.hidden_size],
231 ),
232 (
233 format!("encoder.layer.{}.attention.output.dense.bias", layer_idx),
234 vec![self.config.hidden_size],
235 ),
236 (
237 format!(
238 "encoder.layer.{}.attention.output.LayerNorm.weight",
239 layer_idx
240 ),
241 vec![self.config.hidden_size],
242 ),
243 (
244 format!(
245 "encoder.layer.{}.attention.output.LayerNorm.bias",
246 layer_idx
247 ),
248 vec![self.config.hidden_size],
249 ),
250 (
252 format!("encoder.layer.{}.intermediate.dense.weight", layer_idx),
253 vec![self.config.intermediate_size, self.config.hidden_size],
254 ),
255 (
256 format!("encoder.layer.{}.intermediate.dense.bias", layer_idx),
257 vec![self.config.intermediate_size],
258 ),
259 (
260 format!("encoder.layer.{}.output.dense.weight", layer_idx),
261 vec![self.config.hidden_size, self.config.intermediate_size],
262 ),
263 (
264 format!("encoder.layer.{}.output.dense.bias", layer_idx),
265 vec![self.config.hidden_size],
266 ),
267 (
268 format!("encoder.layer.{}.output.LayerNorm.weight", layer_idx),
269 vec![self.config.hidden_size],
270 ),
271 (
272 format!("encoder.layer.{}.output.LayerNorm.bias", layer_idx),
273 vec![self.config.hidden_size],
274 ),
275 ];
276
277 for (name, shape) in layer_specs {
278 if let Ok(tensor) = self.extract_tensor_from_buffer(buffer, &name, &shape) {
279 weights.insert(name, tensor);
280 }
281 }
282 }
283
284 for (name, shape) in bert_layer_specs {
286 if let Ok(tensor) = self.extract_tensor_from_buffer(buffer, name, &shape) {
287 weights.insert(name.to_string(), tensor);
288 }
289 }
290
291 if let Ok(tensor) = self.extract_tensor_from_buffer(
293 buffer,
294 "pooler.dense.weight",
295 &[self.config.hidden_size, self.config.hidden_size],
296 ) {
297 weights.insert("pooler.dense.weight".to_string(), tensor);
298 }
299 if let Ok(tensor) =
300 self.extract_tensor_from_buffer(buffer, "pooler.dense.bias", &[self.config.hidden_size])
301 {
302 weights.insert("pooler.dense.bias".to_string(), tensor);
303 }
304
305 Ok(weights)
306 }
307
308 fn extract_tensor_from_buffer(
310 &self,
311 buffer: &[u8],
312 name: &str,
313 expected_shape: &[usize],
314 ) -> Result<Tensor> {
315 let total_elements: usize = expected_shape.iter().product();
319 let expected_size = total_elements * 4; if buffer.len() < expected_size {
322 return Err(TrustformersError::weight_load_error(format!(
323 "Buffer too small for tensor {}",
324 name
325 )));
326 }
327
328 for offset in (0..buffer.len().saturating_sub(expected_size)).step_by(4) {
331 if offset + expected_size <= buffer.len() {
332 let tensor_data = &buffer[offset..offset + expected_size];
333 let float_data: Vec<f32> = tensor_data
334 .chunks_exact(4)
335 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
336 .collect();
337
338 if float_data.iter().any(|&x| x.is_finite() && x.abs() < 10.0)
340 && float_data.iter().any(|&x| x != 0.0)
341 {
342 if let Ok(tensor) = Tensor::from_vec(float_data, expected_shape) {
343 return Ok(tensor);
344 }
345 }
346 }
347 }
348
349 let random_data: Vec<f32> = (0..total_elements)
351 .map(|_| (fastrand::f32() - 0.5) * 0.02) .collect();
353
354 Tensor::from_vec(random_data, expected_shape).map_err(|e| {
355 TrustformersError::weight_load_error(format!(
356 "Failed to create fallback tensor for {}: {}",
357 name, e
358 ))
359 })
360 }
361
362 fn load_embeddings_weights(&mut self, weights: &HashMap<String, Tensor>) -> Result<()> {
364 if let Some(word_emb) = weights.get("embeddings.word_embeddings.weight") {
366 println!("Loaded word embeddings: {:?}", word_emb.shape());
369 }
370
371 if let Some(pos_emb) = weights.get("embeddings.position_embeddings.weight") {
373 println!("Loaded position embeddings: {:?}", pos_emb.shape());
374 }
375
376 if let Some(token_type_emb) = weights.get("embeddings.token_type_embeddings.weight") {
378 println!("Loaded token type embeddings: {:?}", token_type_emb.shape());
379 }
380
381 if let Some(ln_weight) = weights.get("embeddings.LayerNorm.weight") {
383 println!(
384 "Loaded embeddings LayerNorm weight: {:?}",
385 ln_weight.shape()
386 );
387 }
388
389 if let Some(ln_bias) = weights.get("embeddings.LayerNorm.bias") {
390 println!("Loaded embeddings LayerNorm bias: {:?}", ln_bias.shape());
391 }
392
393 Ok(())
394 }
395
396 fn load_encoder_weights(&mut self, weights: &HashMap<String, Tensor>) -> Result<()> {
398 for layer_idx in 0..self.config.num_hidden_layers {
399 let attention_weights = vec![
401 format!("encoder.layer.{}.attention.self.query.weight", layer_idx),
402 format!("encoder.layer.{}.attention.self.key.weight", layer_idx),
403 format!("encoder.layer.{}.attention.self.value.weight", layer_idx),
404 format!("encoder.layer.{}.attention.output.dense.weight", layer_idx),
405 ];
406
407 for weight_name in attention_weights {
408 if let Some(weight) = weights.get(&weight_name) {
409 println!("Loaded {}: {:?}", weight_name, weight.shape());
410 }
411 }
412
413 let ff_weights = vec![
415 format!("encoder.layer.{}.intermediate.dense.weight", layer_idx),
416 format!("encoder.layer.{}.output.dense.weight", layer_idx),
417 ];
418
419 for weight_name in ff_weights {
420 if let Some(weight) = weights.get(&weight_name) {
421 println!("Loaded {}: {:?}", weight_name, weight.shape());
422 }
423 }
424 }
425
426 Ok(())
427 }
428
429 fn load_pooler_weights(&mut self, weights: &HashMap<String, Tensor>) -> Result<()> {
431 if let Some(pooler_weight) = weights.get("pooler.dense.weight") {
432 println!("Loaded pooler weight: {:?}", pooler_weight.shape());
433 }
434
435 if let Some(pooler_bias) = weights.get("pooler.dense.bias") {
436 println!("Loaded pooler bias: {:?}", pooler_bias.shape());
437 }
438
439 Ok(())
440 }
441
442 #[allow(dead_code)]
443 fn get_config(&self) -> &BertConfig {
444 &self.config
445 }
446}