1use crate::error::{NeuralError, Result};
9use crate::layers::{
10 Dense, Dropout, Embedding, EmbeddingConfig, Layer, LayerNorm, MultiHeadAttention,
11};
12use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
13use scirs2_core::numeric::{Float, FromPrimitive, NumAssign, ToPrimitive};
14use scirs2_core::random::SeedableRng;
15use scirs2_core::simd_ops::SimdUnifiedOps;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::fmt::Debug;
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct BertConfig {
23 pub vocab_size: usize,
25 pub max_position_embeddings: usize,
27 pub hidden_size: usize,
29 pub num_hidden_layers: usize,
31 pub num_attention_heads: usize,
33 pub intermediate_size: usize,
35 pub hidden_act: String,
37 pub hidden_dropout_prob: f64,
39 pub attention_probs_dropout_prob: f64,
41 pub type_vocab_size: usize,
43 pub layer_norm_eps: f64,
45 pub initializer_range: f64,
47}
48
49impl BertConfig {
50 pub fn bert_base_uncased() -> Self {
52 Self {
53 vocab_size: 30522,
54 max_position_embeddings: 512,
55 hidden_size: 768,
56 num_hidden_layers: 12,
57 num_attention_heads: 12,
58 intermediate_size: 3072,
59 hidden_act: "gelu".to_string(),
60 hidden_dropout_prob: 0.1,
61 attention_probs_dropout_prob: 0.1,
62 type_vocab_size: 2,
63 layer_norm_eps: 1e-12,
64 initializer_range: 0.02,
65 }
66 }
67
68 pub fn bert_large_uncased() -> Self {
70 Self {
71 vocab_size: 30522,
72 max_position_embeddings: 512,
73 hidden_size: 1024,
74 num_hidden_layers: 24,
75 num_attention_heads: 16,
76 intermediate_size: 4096,
77 hidden_act: "gelu".to_string(),
78 hidden_dropout_prob: 0.1,
79 attention_probs_dropout_prob: 0.1,
80 type_vocab_size: 2,
81 layer_norm_eps: 1e-12,
82 initializer_range: 0.02,
83 }
84 }
85
86 pub fn custom(
88 vocab_size: usize,
89 hidden_size: usize,
90 num_hidden_layers: usize,
91 num_attention_heads: usize,
92 ) -> Self {
93 Self {
94 vocab_size,
95 max_position_embeddings: 512,
96 hidden_size,
97 num_hidden_layers,
98 num_attention_heads,
99 intermediate_size: hidden_size * 4,
100 hidden_act: "gelu".to_string(),
101 hidden_dropout_prob: 0.1,
102 attention_probs_dropout_prob: 0.1,
103 type_vocab_size: 2,
104 layer_norm_eps: 1e-12,
105 initializer_range: 0.02,
106 }
107 }
108}
109
110struct BertEmbeddings<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static>
112where
113 F: SimdUnifiedOps,
114{
115 word_embeddings: Embedding<F>,
117 position_embeddings: Embedding<F>,
119 token_type_embeddings: Embedding<F>,
121 layer_norm: LayerNorm<F>,
123 dropout: Dropout<F>,
125}
126
127impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
128 for BertEmbeddings<F>
129{
130 fn clone(&self) -> Self {
131 Self {
132 word_embeddings: self.word_embeddings.clone(),
133 position_embeddings: self.position_embeddings.clone(),
134 token_type_embeddings: self.token_type_embeddings.clone(),
135 layer_norm: self.layer_norm.clone(),
136 dropout: self.dropout.clone(),
137 }
138 }
139}
140
141impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
142 BertEmbeddings<F>
143{
144 pub fn new(config: &BertConfig) -> Result<Self> {
146 let word_embeddings = Embedding::new(EmbeddingConfig {
147 num_embeddings: config.vocab_size,
148 embedding_dim: config.hidden_size,
149 padding_idx: None,
150 max_norm: None,
151 norm_type: 2.0,
152 scale_grad_by_freq: false,
153 })?;
154
155 let position_embeddings = Embedding::new(EmbeddingConfig {
156 num_embeddings: config.max_position_embeddings,
157 embedding_dim: config.hidden_size,
158 padding_idx: None,
159 max_norm: None,
160 norm_type: 2.0,
161 scale_grad_by_freq: false,
162 })?;
163
164 let token_type_embeddings = Embedding::new(EmbeddingConfig {
165 num_embeddings: config.type_vocab_size,
166 embedding_dim: config.hidden_size,
167 padding_idx: None,
168 max_norm: None,
169 norm_type: 2.0,
170 scale_grad_by_freq: false,
171 })?;
172
173 let mut rng4 = scirs2_core::random::rngs::SmallRng::from_seed([45; 32]);
174 let layer_norm = LayerNorm::new(config.hidden_size, config.layer_norm_eps, &mut rng4)?;
175
176 let mut rng5 = scirs2_core::random::rngs::SmallRng::from_seed([46; 32]);
177 let dropout = Dropout::new(config.hidden_dropout_prob, &mut rng5)?;
178
179 Ok(Self {
180 word_embeddings,
181 position_embeddings,
182 token_type_embeddings,
183 layer_norm,
184 dropout,
185 })
186 }
187}
188
189impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
190 for BertEmbeddings<F>
191{
192 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
193 let shape = input.shape();
195 if shape.len() != 2 {
196 return Err(NeuralError::InferenceError(format!(
197 "Expected input shape [batch_size, seq_len], got {:?}",
198 shape
199 )));
200 }
201
202 let batch_size = shape[0];
203 let seq_len = shape[1];
204
205 let inputs_embeds = self.word_embeddings.forward(input)?;
207
208 let mut position_ids = Array::zeros(IxDyn(&[batch_size, seq_len]));
210 for b in 0..batch_size {
211 for s in 0..seq_len {
212 position_ids[[b, s]] = F::from(s).expect("Failed to convert to float");
213 }
214 }
215
216 let position_embeds = self.position_embeddings.forward(&position_ids)?;
218
219 let token_type_ids = Array::zeros(IxDyn(&[batch_size, seq_len]));
221
222 let token_type_embeds = self.token_type_embeddings.forward(&token_type_ids)?;
224
225 let embeddings = &inputs_embeds + &position_embeds + &token_type_embeds;
227
228 let embeddings = self.layer_norm.forward(&embeddings)?;
230
231 let embeddings = self.dropout.forward(&embeddings)?;
233
234 Ok(embeddings)
235 }
236
237 fn backward(
238 &self,
239 _input: &Array<F, IxDyn>,
240 grad_output: &Array<F, IxDyn>,
241 ) -> Result<Array<F, IxDyn>> {
242 Ok(grad_output.clone())
243 }
244
245 fn update(&mut self, learning_rate: F) -> Result<()> {
246 self.word_embeddings.update(learning_rate)?;
247 self.position_embeddings.update(learning_rate)?;
248 self.token_type_embeddings.update(learning_rate)?;
249 self.layer_norm.update(learning_rate)?;
250 Ok(())
251 }
252
253 fn as_any(&self) -> &dyn std::any::Any {
254 self
255 }
256
257 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
258 self
259 }
260}
261
262struct BertSelfAttention<
264 F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign,
265> {
266 attention: MultiHeadAttention<F>,
268 dropout: Dropout<F>,
270}
271
272impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
273 for BertSelfAttention<F>
274{
275 fn clone(&self) -> Self {
276 Self {
277 attention: self.attention.clone(),
278 dropout: self.dropout.clone(),
279 }
280 }
281}
282
283impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
284 BertSelfAttention<F>
285{
286 pub fn new(config: &BertConfig) -> Result<Self> {
288 let head_dim = config.hidden_size / config.num_attention_heads;
289 let attn_config = crate::layers::AttentionConfig {
290 num_heads: config.num_attention_heads,
291 head_dim,
292 dropout_prob: config.attention_probs_dropout_prob,
293 causal: false,
294 scale: None,
295 };
296
297 let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([47; 32]);
298 let attention = MultiHeadAttention::new(config.hidden_size, attn_config, &mut rng)?;
299
300 let mut rng2 = scirs2_core::random::rngs::SmallRng::from_seed([48; 32]);
301 let dropout = Dropout::new(config.hidden_dropout_prob, &mut rng2)?;
302
303 Ok(Self { attention, dropout })
304 }
305}
306
307impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
308 for BertSelfAttention<F>
309{
310 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
311 let attention_output = self.attention.forward(input)?;
312 let attention_output = self.dropout.forward(&attention_output)?;
313 Ok(attention_output)
314 }
315
316 fn backward(
317 &self,
318 _input: &Array<F, IxDyn>,
319 grad_output: &Array<F, IxDyn>,
320 ) -> Result<Array<F, IxDyn>> {
321 Ok(grad_output.clone())
322 }
323
324 fn update(&mut self, learning_rate: F) -> Result<()> {
325 self.attention.update(learning_rate)?;
326 Ok(())
327 }
328
329 fn as_any(&self) -> &dyn std::any::Any {
330 self
331 }
332
333 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
334 self
335 }
336}
337
338struct BertFeedForward<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static>
340where
341 F: SimdUnifiedOps,
342{
343 intermediate_dense: Dense<F>,
345 output_dense: Dense<F>,
347 layer_norm: LayerNorm<F>,
349 dropout: Dropout<F>,
351}
352
353impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
354 for BertFeedForward<F>
355{
356 fn clone(&self) -> Self {
357 Self {
358 intermediate_dense: self.intermediate_dense.clone(),
359 output_dense: self.output_dense.clone(),
360 layer_norm: self.layer_norm.clone(),
361 dropout: self.dropout.clone(),
362 }
363 }
364}
365
366impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
367 BertFeedForward<F>
368{
369 pub fn new(config: &BertConfig) -> Result<Self> {
371 let mut rng1 = scirs2_core::random::rngs::SmallRng::from_seed([49; 32]);
372 let intermediate_dense = Dense::new(
373 config.hidden_size,
374 config.intermediate_size,
375 None,
376 &mut rng1,
377 )?;
378
379 let mut rng2 = scirs2_core::random::rngs::SmallRng::from_seed([50; 32]);
380 let output_dense = Dense::new(
381 config.intermediate_size,
382 config.hidden_size,
383 None,
384 &mut rng2,
385 )?;
386
387 let mut rng3 = scirs2_core::random::rngs::SmallRng::from_seed([51; 32]);
388 let layer_norm = LayerNorm::new(config.hidden_size, config.layer_norm_eps, &mut rng3)?;
389
390 let mut rng4 = scirs2_core::random::rngs::SmallRng::from_seed([52; 32]);
391 let dropout = Dropout::new(config.hidden_dropout_prob, &mut rng4)?;
392
393 Ok(Self {
394 intermediate_dense,
395 output_dense,
396 layer_norm,
397 dropout,
398 })
399 }
400}
401
402impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
403 for BertFeedForward<F>
404{
405 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
406 let hidden = self.intermediate_dense.forward(input)?;
408 let hidden = hidden.mapv(|v: F| {
409 let x3 = v * v * v;
411 v * F::from(0.5).expect("Failed to convert constant to float")
412 * (F::one()
413 + (v + F::from(0.044715).expect("Failed to convert constant to float") * x3)
414 .tanh())
415 });
416
417 let output = self.output_dense.forward(&hidden)?;
419 let output = self.dropout.forward(&output)?;
420
421 let output = input + &output;
423 let output = self.layer_norm.forward(&output)?;
424
425 Ok(output)
426 }
427
428 fn backward(
429 &self,
430 _input: &Array<F, IxDyn>,
431 grad_output: &Array<F, IxDyn>,
432 ) -> Result<Array<F, IxDyn>> {
433 Ok(grad_output.clone())
434 }
435
436 fn update(&mut self, learning_rate: F) -> Result<()> {
437 self.intermediate_dense.update(learning_rate)?;
438 self.output_dense.update(learning_rate)?;
439 self.layer_norm.update(learning_rate)?;
440 Ok(())
441 }
442
443 fn as_any(&self) -> &dyn std::any::Any {
444 self
445 }
446
447 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
448 self
449 }
450}
451
452struct BertLayer<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign> {
454 attention: BertSelfAttention<F>,
456 attention_layer_norm: LayerNorm<F>,
458 feed_forward: BertFeedForward<F>,
460}
461
462impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
463 for BertLayer<F>
464{
465 fn clone(&self) -> Self {
466 Self {
467 attention: self.attention.clone(),
468 attention_layer_norm: self.attention_layer_norm.clone(),
469 feed_forward: self.feed_forward.clone(),
470 }
471 }
472}
473
474impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
475 BertLayer<F>
476{
477 pub fn new(config: &BertConfig) -> Result<Self> {
479 let attention = BertSelfAttention::new(config)?;
480
481 let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([53; 32]);
482 let attention_layer_norm =
483 LayerNorm::new(config.hidden_size, config.layer_norm_eps, &mut rng)?;
484
485 let feed_forward = BertFeedForward::new(config)?;
486
487 Ok(Self {
488 attention,
489 attention_layer_norm,
490 feed_forward,
491 })
492 }
493}
494
495impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
496 for BertLayer<F>
497{
498 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
499 let attention_output = self.attention.forward(input)?;
501 let attention_output = input + &attention_output;
502 let attention_output = self.attention_layer_norm.forward(&attention_output)?;
503
504 let layer_output = self.feed_forward.forward(&attention_output)?;
506
507 Ok(layer_output)
508 }
509
510 fn backward(
511 &self,
512 _input: &Array<F, IxDyn>,
513 grad_output: &Array<F, IxDyn>,
514 ) -> Result<Array<F, IxDyn>> {
515 Ok(grad_output.clone())
516 }
517
518 fn update(&mut self, learning_rate: F) -> Result<()> {
519 self.attention.update(learning_rate)?;
520 self.attention_layer_norm.update(learning_rate)?;
521 self.feed_forward.update(learning_rate)?;
522 Ok(())
523 }
524
525 fn as_any(&self) -> &dyn std::any::Any {
526 self
527 }
528
529 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
530 self
531 }
532}
533
534struct BertEncoder<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign> {
536 layers: Vec<BertLayer<F>>,
538}
539
540impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
541 for BertEncoder<F>
542{
543 fn clone(&self) -> Self {
544 Self {
545 layers: self.layers.clone(),
546 }
547 }
548}
549
550impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
551 BertEncoder<F>
552{
553 pub fn new(config: &BertConfig) -> Result<Self> {
555 let mut layers = Vec::with_capacity(config.num_hidden_layers);
556 for _ in 0..config.num_hidden_layers {
557 layers.push(BertLayer::new(config)?);
558 }
559
560 Ok(Self { layers })
561 }
562}
563
564impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
565 for BertEncoder<F>
566{
567 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
568 let mut hidden_states = input.clone();
569 for layer in &self.layers {
570 hidden_states = layer.forward(&hidden_states)?;
571 }
572 Ok(hidden_states)
573 }
574
575 fn backward(
576 &self,
577 _input: &Array<F, IxDyn>,
578 grad_output: &Array<F, IxDyn>,
579 ) -> Result<Array<F, IxDyn>> {
580 Ok(grad_output.clone())
581 }
582
583 fn update(&mut self, learning_rate: F) -> Result<()> {
584 for layer in &mut self.layers {
585 layer.update(learning_rate)?;
586 }
587 Ok(())
588 }
589
590 fn as_any(&self) -> &dyn std::any::Any {
591 self
592 }
593
594 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
595 self
596 }
597}
598
599struct BertPooler<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> {
601 dense: Dense<F>,
603}
604
605impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
606 for BertPooler<F>
607{
608 fn clone(&self) -> Self {
609 Self {
610 dense: self.dense.clone(),
611 }
612 }
613}
614
615impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
616 BertPooler<F>
617{
618 pub fn new(config: &BertConfig) -> Result<Self> {
620 let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([54; 32]);
621 let dense = Dense::new(config.hidden_size, config.hidden_size, None, &mut rng)?;
622
623 Ok(Self { dense })
624 }
625}
626
627impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
628 for BertPooler<F>
629{
630 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
631 let shape = input.shape();
633 if shape.len() != 3 {
634 return Err(NeuralError::InferenceError(format!(
635 "Expected input shape [batch_size, seq_len, hidden_size], got {:?}",
636 shape
637 )));
638 }
639
640 let batch_size = shape[0];
641 let hidden_size = shape[2];
642
643 let mut cls_tokens = Array::zeros(IxDyn(&[batch_size, hidden_size]));
645 for b in 0..batch_size {
646 for i in 0..hidden_size {
647 cls_tokens[[b, i]] = input[[b, 0, i]];
648 }
649 }
650
651 let pooled_output = self.dense.forward(&cls_tokens)?;
653
654 let pooled_output = pooled_output.mapv(|x: F| x.tanh());
656
657 Ok(pooled_output)
658 }
659
660 fn backward(
661 &self,
662 _input: &Array<F, IxDyn>,
663 grad_output: &Array<F, IxDyn>,
664 ) -> Result<Array<F, IxDyn>> {
665 Ok(grad_output.clone())
666 }
667
668 fn update(&mut self, learning_rate: F) -> Result<()> {
669 self.dense.update(learning_rate)?;
670 Ok(())
671 }
672
673 fn as_any(&self) -> &dyn std::any::Any {
674 self
675 }
676
677 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
678 self
679 }
680}
681
682pub struct BertModel<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign> {
684 embeddings: BertEmbeddings<F>,
686 encoder: BertEncoder<F>,
688 pooler: BertPooler<F>,
690 config: BertConfig,
692}
693
694impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
695 for BertModel<F>
696{
697 fn clone(&self) -> Self {
698 Self {
699 embeddings: self.embeddings.clone(),
700 encoder: self.encoder.clone(),
701 pooler: self.pooler.clone(),
702 config: self.config.clone(),
703 }
704 }
705}
706
707impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
708 BertModel<F>
709{
710 pub fn new(config: BertConfig) -> Result<Self> {
712 let embeddings = BertEmbeddings::new(&config)?;
713 let encoder = BertEncoder::new(&config)?;
714 let pooler = BertPooler::new(&config)?;
715
716 Ok(Self {
717 embeddings,
718 encoder,
719 pooler,
720 config,
721 })
722 }
723
724 pub fn bert_base_uncased() -> Result<Self> {
726 let config = BertConfig::bert_base_uncased();
727 Self::new(config)
728 }
729
730 pub fn bert_large_uncased() -> Result<Self> {
732 let config = BertConfig::bert_large_uncased();
733 Self::new(config)
734 }
735
736 pub fn custom(
738 vocab_size: usize,
739 hidden_size: usize,
740 num_hidden_layers: usize,
741 num_attention_heads: usize,
742 ) -> Result<Self> {
743 let config = BertConfig::custom(
744 vocab_size,
745 hidden_size,
746 num_hidden_layers,
747 num_attention_heads,
748 );
749 Self::new(config)
750 }
751
752 pub fn get_sequence_output(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
754 let embedding_output = self.embeddings.forward(input)?;
755 let sequence_output = self.encoder.forward(&embedding_output)?;
756 Ok(sequence_output)
757 }
758
759 pub fn get_pooled_output(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
761 let sequence_output = self.get_sequence_output(input)?;
762 let pooled_output = self.pooler.forward(&sequence_output)?;
763 Ok(pooled_output)
764 }
765
766 pub fn config(&self) -> &BertConfig {
768 &self.config
769 }
770}
771
772impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
773 for BertModel<F>
774{
775 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
776 self.get_sequence_output(input)
778 }
779
780 fn backward(
781 &self,
782 _input: &Array<F, IxDyn>,
783 grad_output: &Array<F, IxDyn>,
784 ) -> Result<Array<F, IxDyn>> {
785 Ok(grad_output.clone())
786 }
787
788 fn update(&mut self, learning_rate: F) -> Result<()> {
789 self.embeddings.update(learning_rate)?;
790 self.encoder.update(learning_rate)?;
791 self.pooler.update(learning_rate)?;
792 Ok(())
793 }
794
795 fn as_any(&self) -> &dyn std::any::Any {
796 self
797 }
798
799 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
800 self
801 }
802}
803
804impl<
805 F: Float
806 + Debug
807 + ScalarOperand
808 + Send
809 + Sync
810 + SimdUnifiedOps
811 + NumAssign
812 + ToPrimitive
813 + FromPrimitive
814 + 'static,
815 > BertModel<F>
816{
817 pub fn extract_named_params(&self) -> Result<Vec<(String, Array<F, IxDyn>)>> {
828 let mut result = Vec::new();
829
830 for p in self.embeddings.word_embeddings.params().iter() {
832 result.push(("embeddings.word_embeddings.weight".to_string(), p.clone()));
833 }
834 for p in self.embeddings.position_embeddings.params().iter() {
835 result.push((
836 "embeddings.position_embeddings.weight".to_string(),
837 p.clone(),
838 ));
839 }
840 for p in self.embeddings.token_type_embeddings.params().iter() {
841 result.push((
842 "embeddings.token_type_embeddings.weight".to_string(),
843 p.clone(),
844 ));
845 }
846 let ln_params = self.embeddings.layer_norm.params();
847 if !ln_params.is_empty() {
848 result.push((
849 "embeddings.LayerNorm.weight".to_string(),
850 ln_params[0].clone(),
851 ));
852 }
853 if ln_params.len() >= 2 {
854 result.push((
855 "embeddings.LayerNorm.bias".to_string(),
856 ln_params[1].clone(),
857 ));
858 }
859
860 for (layer_idx, bert_layer) in self.encoder.layers.iter().enumerate() {
862 let prefix = format!("encoder.layer.{layer_idx}");
863
864 let attn_params = bert_layer.attention.attention.params();
866 if attn_params.len() >= 4 {
867 result.push((
868 format!("{prefix}.attention.self.query.weight"),
869 attn_params[0].clone(),
870 ));
871 result.push((
872 format!("{prefix}.attention.self.key.weight"),
873 attn_params[1].clone(),
874 ));
875 result.push((
876 format!("{prefix}.attention.self.value.weight"),
877 attn_params[2].clone(),
878 ));
879 result.push((
880 format!("{prefix}.attention.output.dense.weight"),
881 attn_params[3].clone(),
882 ));
883 } else if attn_params.len() == 3 {
884 result.push((
885 format!("{prefix}.attention.self.query.weight"),
886 attn_params[0].clone(),
887 ));
888 result.push((
889 format!("{prefix}.attention.self.key.weight"),
890 attn_params[1].clone(),
891 ));
892 result.push((
893 format!("{prefix}.attention.self.value.weight"),
894 attn_params[2].clone(),
895 ));
896 }
897
898 let attn_ln_params = bert_layer.attention_layer_norm.params();
900 if !attn_ln_params.is_empty() {
901 result.push((
902 format!("{prefix}.attention.output.LayerNorm.weight"),
903 attn_ln_params[0].clone(),
904 ));
905 }
906 if attn_ln_params.len() >= 2 {
907 result.push((
908 format!("{prefix}.attention.output.LayerNorm.bias"),
909 attn_ln_params[1].clone(),
910 ));
911 }
912
913 let ff_inter_params = bert_layer.feed_forward.intermediate_dense.params();
915 if !ff_inter_params.is_empty() {
916 result.push((
917 format!("{prefix}.intermediate.dense.weight"),
918 ff_inter_params[0].clone(),
919 ));
920 }
921 if ff_inter_params.len() >= 2 {
922 result.push((
923 format!("{prefix}.intermediate.dense.bias"),
924 ff_inter_params[1].clone(),
925 ));
926 }
927
928 let ff_out_params = bert_layer.feed_forward.output_dense.params();
930 if !ff_out_params.is_empty() {
931 result.push((
932 format!("{prefix}.output.dense.weight"),
933 ff_out_params[0].clone(),
934 ));
935 }
936 if ff_out_params.len() >= 2 {
937 result.push((
938 format!("{prefix}.output.dense.bias"),
939 ff_out_params[1].clone(),
940 ));
941 }
942
943 let ff_ln_params = bert_layer.feed_forward.layer_norm.params();
945 if !ff_ln_params.is_empty() {
946 result.push((
947 format!("{prefix}.output.LayerNorm.weight"),
948 ff_ln_params[0].clone(),
949 ));
950 }
951 if ff_ln_params.len() >= 2 {
952 result.push((
953 format!("{prefix}.output.LayerNorm.bias"),
954 ff_ln_params[1].clone(),
955 ));
956 }
957 }
958
959 let pooler_params = self.pooler.dense.params();
961 if !pooler_params.is_empty() {
962 result.push(("pooler.dense.weight".to_string(), pooler_params[0].clone()));
963 }
964 if pooler_params.len() >= 2 {
965 result.push(("pooler.dense.bias".to_string(), pooler_params[1].clone()));
966 }
967
968 Ok(result)
969 }
970
971 pub fn load_named_params(
976 &mut self,
977 params_map: &HashMap<String, Array<F, IxDyn>>,
978 ) -> Result<()> {
979 if let Some(p) = params_map.get("embeddings.word_embeddings.weight") {
981 self.embeddings
982 .word_embeddings
983 .set_params(std::slice::from_ref(p))?;
984 }
985 if let Some(p) = params_map.get("embeddings.position_embeddings.weight") {
986 self.embeddings
987 .position_embeddings
988 .set_params(std::slice::from_ref(p))?;
989 }
990 if let Some(p) = params_map.get("embeddings.token_type_embeddings.weight") {
991 self.embeddings
992 .token_type_embeddings
993 .set_params(std::slice::from_ref(p))?;
994 }
995 {
996 let mut ln_ps = Vec::new();
997 if let Some(p) = params_map.get("embeddings.LayerNorm.weight") {
998 ln_ps.push(p.clone());
999 }
1000 if let Some(p) = params_map.get("embeddings.LayerNorm.bias") {
1001 ln_ps.push(p.clone());
1002 }
1003 if !ln_ps.is_empty() {
1004 self.embeddings.layer_norm.set_params(&ln_ps)?;
1005 }
1006 }
1007
1008 for (layer_idx, bert_layer) in self.encoder.layers.iter_mut().enumerate() {
1010 let prefix = format!("encoder.layer.{layer_idx}");
1011
1012 let mut attn_ps = Vec::new();
1014 if let Some(p) = params_map.get(&format!("{prefix}.attention.self.query.weight")) {
1015 attn_ps.push(p.clone());
1016 }
1017 if let Some(p) = params_map.get(&format!("{prefix}.attention.self.key.weight")) {
1018 attn_ps.push(p.clone());
1019 }
1020 if let Some(p) = params_map.get(&format!("{prefix}.attention.self.value.weight")) {
1021 attn_ps.push(p.clone());
1022 }
1023 if let Some(p) = params_map.get(&format!("{prefix}.attention.output.dense.weight")) {
1024 attn_ps.push(p.clone());
1025 }
1026 if !attn_ps.is_empty() {
1027 bert_layer.attention.attention.set_params(&attn_ps)?;
1028 }
1029
1030 {
1032 let mut ln_ps = Vec::new();
1033 if let Some(p) =
1034 params_map.get(&format!("{prefix}.attention.output.LayerNorm.weight"))
1035 {
1036 ln_ps.push(p.clone());
1037 }
1038 if let Some(p) =
1039 params_map.get(&format!("{prefix}.attention.output.LayerNorm.bias"))
1040 {
1041 ln_ps.push(p.clone());
1042 }
1043 if !ln_ps.is_empty() {
1044 bert_layer.attention_layer_norm.set_params(&ln_ps)?;
1045 }
1046 }
1047
1048 {
1050 let mut ff_ps = Vec::new();
1051 if let Some(p) = params_map.get(&format!("{prefix}.intermediate.dense.weight")) {
1052 ff_ps.push(p.clone());
1053 }
1054 if let Some(p) = params_map.get(&format!("{prefix}.intermediate.dense.bias")) {
1055 ff_ps.push(p.clone());
1056 }
1057 if !ff_ps.is_empty() {
1058 bert_layer
1059 .feed_forward
1060 .intermediate_dense
1061 .set_params(&ff_ps)?;
1062 }
1063 }
1064
1065 {
1067 let mut ff_ps = Vec::new();
1068 if let Some(p) = params_map.get(&format!("{prefix}.output.dense.weight")) {
1069 ff_ps.push(p.clone());
1070 }
1071 if let Some(p) = params_map.get(&format!("{prefix}.output.dense.bias")) {
1072 ff_ps.push(p.clone());
1073 }
1074 if !ff_ps.is_empty() {
1075 bert_layer.feed_forward.output_dense.set_params(&ff_ps)?;
1076 }
1077 }
1078
1079 {
1081 let mut ln_ps = Vec::new();
1082 if let Some(p) = params_map.get(&format!("{prefix}.output.LayerNorm.weight")) {
1083 ln_ps.push(p.clone());
1084 }
1085 if let Some(p) = params_map.get(&format!("{prefix}.output.LayerNorm.bias")) {
1086 ln_ps.push(p.clone());
1087 }
1088 if !ln_ps.is_empty() {
1089 bert_layer.feed_forward.layer_norm.set_params(&ln_ps)?;
1090 }
1091 }
1092 }
1093
1094 {
1096 let mut ps = Vec::new();
1097 if let Some(p) = params_map.get("pooler.dense.weight") {
1098 ps.push(p.clone());
1099 }
1100 if let Some(p) = params_map.get("pooler.dense.bias") {
1101 ps.push(p.clone());
1102 }
1103 if !ps.is_empty() {
1104 self.pooler.dense.set_params(&ps)?;
1105 }
1106 }
1107
1108 Ok(())
1109 }
1110}
1111
#[cfg(test)]
mod tests {
    use super::*;

    /// BERT-base preset: vocabulary, width, depth and head count.
    #[test]
    fn test_bert_config_base() {
        let cfg = BertConfig::bert_base_uncased();
        assert_eq!(
            (
                cfg.vocab_size,
                cfg.hidden_size,
                cfg.num_hidden_layers,
                cfg.num_attention_heads
            ),
            (30522, 768, 12, 12)
        );
    }

    /// BERT-large preset: width, depth and head count.
    #[test]
    fn test_bert_config_large() {
        let cfg = BertConfig::bert_large_uncased();
        assert_eq!(
            (cfg.hidden_size, cfg.num_hidden_layers, cfg.num_attention_heads),
            (1024, 24, 16)
        );
    }

    /// Custom preset: explicit sizes are kept and the feed-forward width is 4x hidden.
    #[test]
    fn test_bert_config_custom() {
        let cfg = BertConfig::custom(10000, 256, 4, 4);
        assert_eq!(
            (
                cfg.vocab_size,
                cfg.hidden_size,
                cfg.num_hidden_layers,
                cfg.num_attention_heads
            ),
            (10000, 256, 4, 4)
        );
        assert_eq!(cfg.intermediate_size, 1024);
    }
}