Skip to main content

scirs2_neural/models/architectures/
bert.rs

1//! BERT implementation
2//!
3//! BERT (Bidirectional Encoder Representations from Transformers) is a transformer-based
4//! model designed to pretrain deep bidirectional representations from unlabeled text.
5//! Reference: "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding", Devlin et al. (2018)
6//! <https://arxiv.org/abs/1810.04805>
7
8use crate::error::{NeuralError, Result};
9use crate::layers::{
10    Dense, Dropout, Embedding, EmbeddingConfig, Layer, LayerNorm, MultiHeadAttention,
11};
12use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
13use scirs2_core::numeric::{Float, FromPrimitive, NumAssign, ToPrimitive};
14use scirs2_core::random::SeedableRng;
15use scirs2_core::simd_ops::SimdUnifiedOps;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::fmt::Debug;
19
/// Configuration for a BERT model.
///
/// Collects the hyper-parameters read by the constructors in this module
/// (embeddings, self-attention, feed-forward, encoder and pooler).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BertConfig {
    /// Vocabulary size; used as the number of rows of the word-embedding table.
    pub vocab_size: usize,
    /// Maximum sequence length; used as the number of rows of the
    /// position-embedding table.
    pub max_position_embeddings: usize,
    /// Hidden size: the embedding dimension and the width of every encoder layer.
    pub hidden_size: usize,
    /// Number of hidden layers stacked in the encoder.
    pub num_hidden_layers: usize,
    /// Number of attention heads; the per-head dimension is derived as
    /// `hidden_size / num_attention_heads`.
    pub num_attention_heads: usize,
    /// Intermediate size in feed-forward networks (width of the inner dense layer).
    pub intermediate_size: usize,
    /// Hidden activation function name.
    // NOTE(review): the feed-forward block visible in this file always applies
    // GELU and does not read this field — confirm whether it is used elsewhere.
    pub hidden_act: String,
    /// Hidden dropout probability; passed to the dropout layers of the
    /// embeddings, the attention output and the feed-forward block.
    pub hidden_dropout_prob: f64,
    /// Attention dropout probability; passed to the attention configuration.
    pub attention_probs_dropout_prob: f64,
    /// Type vocabulary size (usually 2 for sentence pair tasks); number of
    /// rows of the token-type embedding table.
    pub type_vocab_size: usize,
    /// Layer norm epsilon handed to every `LayerNorm` built from this config.
    pub layer_norm_eps: f64,
    /// Initializer range.
    // NOTE(review): not read by the constructors visible in this part of the
    // file — presumably kept for parity with reference configs; confirm.
    pub initializer_range: f64,
}
48
49impl BertConfig {
50    /// Create a BERT-Base configuration
51    pub fn bert_base_uncased() -> Self {
52        Self {
53            vocab_size: 30522,
54            max_position_embeddings: 512,
55            hidden_size: 768,
56            num_hidden_layers: 12,
57            num_attention_heads: 12,
58            intermediate_size: 3072,
59            hidden_act: "gelu".to_string(),
60            hidden_dropout_prob: 0.1,
61            attention_probs_dropout_prob: 0.1,
62            type_vocab_size: 2,
63            layer_norm_eps: 1e-12,
64            initializer_range: 0.02,
65        }
66    }
67
68    /// Create a BERT-Large configuration
69    pub fn bert_large_uncased() -> Self {
70        Self {
71            vocab_size: 30522,
72            max_position_embeddings: 512,
73            hidden_size: 1024,
74            num_hidden_layers: 24,
75            num_attention_heads: 16,
76            intermediate_size: 4096,
77            hidden_act: "gelu".to_string(),
78            hidden_dropout_prob: 0.1,
79            attention_probs_dropout_prob: 0.1,
80            type_vocab_size: 2,
81            layer_norm_eps: 1e-12,
82            initializer_range: 0.02,
83        }
84    }
85
86    /// Create a custom BERT configuration
87    pub fn custom(
88        vocab_size: usize,
89        hidden_size: usize,
90        num_hidden_layers: usize,
91        num_attention_heads: usize,
92    ) -> Self {
93        Self {
94            vocab_size,
95            max_position_embeddings: 512,
96            hidden_size,
97            num_hidden_layers,
98            num_attention_heads,
99            intermediate_size: hidden_size * 4,
100            hidden_act: "gelu".to_string(),
101            hidden_dropout_prob: 0.1,
102            attention_probs_dropout_prob: 0.1,
103            type_vocab_size: 2,
104            layer_norm_eps: 1e-12,
105            initializer_range: 0.02,
106        }
107    }
108}
109
110/// BERT embeddings combining token, position, and token type embeddings
111struct BertEmbeddings<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static>
112where
113    F: SimdUnifiedOps,
114{
115    /// Token embeddings
116    word_embeddings: Embedding<F>,
117    /// Position embeddings
118    position_embeddings: Embedding<F>,
119    /// Token type embeddings
120    token_type_embeddings: Embedding<F>,
121    /// Layer normalization
122    layer_norm: LayerNorm<F>,
123    /// Dropout
124    dropout: Dropout<F>,
125}
126
127impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
128    for BertEmbeddings<F>
129{
130    fn clone(&self) -> Self {
131        Self {
132            word_embeddings: self.word_embeddings.clone(),
133            position_embeddings: self.position_embeddings.clone(),
134            token_type_embeddings: self.token_type_embeddings.clone(),
135            layer_norm: self.layer_norm.clone(),
136            dropout: self.dropout.clone(),
137        }
138    }
139}
140
141impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
142    BertEmbeddings<F>
143{
144    /// Create BERT embeddings
145    pub fn new(config: &BertConfig) -> Result<Self> {
146        let word_embeddings = Embedding::new(EmbeddingConfig {
147            num_embeddings: config.vocab_size,
148            embedding_dim: config.hidden_size,
149            padding_idx: None,
150            max_norm: None,
151            norm_type: 2.0,
152            scale_grad_by_freq: false,
153        })?;
154
155        let position_embeddings = Embedding::new(EmbeddingConfig {
156            num_embeddings: config.max_position_embeddings,
157            embedding_dim: config.hidden_size,
158            padding_idx: None,
159            max_norm: None,
160            norm_type: 2.0,
161            scale_grad_by_freq: false,
162        })?;
163
164        let token_type_embeddings = Embedding::new(EmbeddingConfig {
165            num_embeddings: config.type_vocab_size,
166            embedding_dim: config.hidden_size,
167            padding_idx: None,
168            max_norm: None,
169            norm_type: 2.0,
170            scale_grad_by_freq: false,
171        })?;
172
173        let mut rng4 = scirs2_core::random::rngs::SmallRng::from_seed([45; 32]);
174        let layer_norm = LayerNorm::new(config.hidden_size, config.layer_norm_eps, &mut rng4)?;
175
176        let mut rng5 = scirs2_core::random::rngs::SmallRng::from_seed([46; 32]);
177        let dropout = Dropout::new(config.hidden_dropout_prob, &mut rng5)?;
178
179        Ok(Self {
180            word_embeddings,
181            position_embeddings,
182            token_type_embeddings,
183            layer_norm,
184            dropout,
185        })
186    }
187}
188
189impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
190    for BertEmbeddings<F>
191{
192    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
193        // Input should be of shape [batch_size, seq_len] and contain token IDs
194        let shape = input.shape();
195        if shape.len() != 2 {
196            return Err(NeuralError::InferenceError(format!(
197                "Expected input shape [batch_size, seq_len], got {:?}",
198                shape
199            )));
200        }
201
202        let batch_size = shape[0];
203        let seq_len = shape[1];
204
205        // Get word embeddings
206        let inputs_embeds = self.word_embeddings.forward(input)?;
207
208        // Create position IDs
209        let mut position_ids = Array::zeros(IxDyn(&[batch_size, seq_len]));
210        for b in 0..batch_size {
211            for s in 0..seq_len {
212                position_ids[[b, s]] = F::from(s).expect("Failed to convert to float");
213            }
214        }
215
216        // Get position embeddings
217        let position_embeds = self.position_embeddings.forward(&position_ids)?;
218
219        // Create token type IDs (all zeros for single sequence)
220        let token_type_ids = Array::zeros(IxDyn(&[batch_size, seq_len]));
221
222        // Get token type embeddings
223        let token_type_embeds = self.token_type_embeddings.forward(&token_type_ids)?;
224
225        // Combine embeddings
226        let embeddings = &inputs_embeds + &position_embeds + &token_type_embeds;
227
228        // Apply layer normalization
229        let embeddings = self.layer_norm.forward(&embeddings)?;
230
231        // Apply dropout
232        let embeddings = self.dropout.forward(&embeddings)?;
233
234        Ok(embeddings)
235    }
236
237    fn backward(
238        &self,
239        _input: &Array<F, IxDyn>,
240        grad_output: &Array<F, IxDyn>,
241    ) -> Result<Array<F, IxDyn>> {
242        Ok(grad_output.clone())
243    }
244
245    fn update(&mut self, learning_rate: F) -> Result<()> {
246        self.word_embeddings.update(learning_rate)?;
247        self.position_embeddings.update(learning_rate)?;
248        self.token_type_embeddings.update(learning_rate)?;
249        self.layer_norm.update(learning_rate)?;
250        Ok(())
251    }
252
253    fn as_any(&self) -> &dyn std::any::Any {
254        self
255    }
256
257    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
258        self
259    }
260}
261
262/// BERT self-attention layer
263struct BertSelfAttention<
264    F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign,
265> {
266    /// Multi-head attention layer
267    attention: MultiHeadAttention<F>,
268    /// Output dropout
269    dropout: Dropout<F>,
270}
271
272impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
273    for BertSelfAttention<F>
274{
275    fn clone(&self) -> Self {
276        Self {
277            attention: self.attention.clone(),
278            dropout: self.dropout.clone(),
279        }
280    }
281}
282
283impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
284    BertSelfAttention<F>
285{
286    /// Create BERT self-attention layer
287    pub fn new(config: &BertConfig) -> Result<Self> {
288        let head_dim = config.hidden_size / config.num_attention_heads;
289        let attn_config = crate::layers::AttentionConfig {
290            num_heads: config.num_attention_heads,
291            head_dim,
292            dropout_prob: config.attention_probs_dropout_prob,
293            causal: false,
294            scale: None,
295        };
296
297        let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([47; 32]);
298        let attention = MultiHeadAttention::new(config.hidden_size, attn_config, &mut rng)?;
299
300        let mut rng2 = scirs2_core::random::rngs::SmallRng::from_seed([48; 32]);
301        let dropout = Dropout::new(config.hidden_dropout_prob, &mut rng2)?;
302
303        Ok(Self { attention, dropout })
304    }
305}
306
307impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
308    for BertSelfAttention<F>
309{
310    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
311        let attention_output = self.attention.forward(input)?;
312        let attention_output = self.dropout.forward(&attention_output)?;
313        Ok(attention_output)
314    }
315
316    fn backward(
317        &self,
318        _input: &Array<F, IxDyn>,
319        grad_output: &Array<F, IxDyn>,
320    ) -> Result<Array<F, IxDyn>> {
321        Ok(grad_output.clone())
322    }
323
324    fn update(&mut self, learning_rate: F) -> Result<()> {
325        self.attention.update(learning_rate)?;
326        Ok(())
327    }
328
329    fn as_any(&self) -> &dyn std::any::Any {
330        self
331    }
332
333    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
334        self
335    }
336}
337
338/// BERT feed-forward network (intermediate + output)
339struct BertFeedForward<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static>
340where
341    F: SimdUnifiedOps,
342{
343    /// Intermediate dense layer
344    intermediate_dense: Dense<F>,
345    /// Output dense layer
346    output_dense: Dense<F>,
347    /// Layer normalization
348    layer_norm: LayerNorm<F>,
349    /// Dropout
350    dropout: Dropout<F>,
351}
352
353impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
354    for BertFeedForward<F>
355{
356    fn clone(&self) -> Self {
357        Self {
358            intermediate_dense: self.intermediate_dense.clone(),
359            output_dense: self.output_dense.clone(),
360            layer_norm: self.layer_norm.clone(),
361            dropout: self.dropout.clone(),
362        }
363    }
364}
365
366impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
367    BertFeedForward<F>
368{
369    /// Create BERT feed-forward layer
370    pub fn new(config: &BertConfig) -> Result<Self> {
371        let mut rng1 = scirs2_core::random::rngs::SmallRng::from_seed([49; 32]);
372        let intermediate_dense = Dense::new(
373            config.hidden_size,
374            config.intermediate_size,
375            None,
376            &mut rng1,
377        )?;
378
379        let mut rng2 = scirs2_core::random::rngs::SmallRng::from_seed([50; 32]);
380        let output_dense = Dense::new(
381            config.intermediate_size,
382            config.hidden_size,
383            None,
384            &mut rng2,
385        )?;
386
387        let mut rng3 = scirs2_core::random::rngs::SmallRng::from_seed([51; 32]);
388        let layer_norm = LayerNorm::new(config.hidden_size, config.layer_norm_eps, &mut rng3)?;
389
390        let mut rng4 = scirs2_core::random::rngs::SmallRng::from_seed([52; 32]);
391        let dropout = Dropout::new(config.hidden_dropout_prob, &mut rng4)?;
392
393        Ok(Self {
394            intermediate_dense,
395            output_dense,
396            layer_norm,
397            dropout,
398        })
399    }
400}
401
402impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
403    for BertFeedForward<F>
404{
405    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
406        // Intermediate layer with GELU activation
407        let hidden = self.intermediate_dense.forward(input)?;
408        let hidden = hidden.mapv(|v: F| {
409            // GELU approximation
410            let x3 = v * v * v;
411            v * F::from(0.5).expect("Failed to convert constant to float")
412                * (F::one()
413                    + (v + F::from(0.044715).expect("Failed to convert constant to float") * x3)
414                        .tanh())
415        });
416
417        // Output layer
418        let output = self.output_dense.forward(&hidden)?;
419        let output = self.dropout.forward(&output)?;
420
421        // Add residual and layer norm
422        let output = input + &output;
423        let output = self.layer_norm.forward(&output)?;
424
425        Ok(output)
426    }
427
428    fn backward(
429        &self,
430        _input: &Array<F, IxDyn>,
431        grad_output: &Array<F, IxDyn>,
432    ) -> Result<Array<F, IxDyn>> {
433        Ok(grad_output.clone())
434    }
435
436    fn update(&mut self, learning_rate: F) -> Result<()> {
437        self.intermediate_dense.update(learning_rate)?;
438        self.output_dense.update(learning_rate)?;
439        self.layer_norm.update(learning_rate)?;
440        Ok(())
441    }
442
443    fn as_any(&self) -> &dyn std::any::Any {
444        self
445    }
446
447    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
448        self
449    }
450}
451
452/// BERT layer (attention + feed-forward)
453struct BertLayer<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign> {
454    /// Self attention
455    attention: BertSelfAttention<F>,
456    /// Attention output layer norm
457    attention_layer_norm: LayerNorm<F>,
458    /// Feed-forward network
459    feed_forward: BertFeedForward<F>,
460}
461
462impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
463    for BertLayer<F>
464{
465    fn clone(&self) -> Self {
466        Self {
467            attention: self.attention.clone(),
468            attention_layer_norm: self.attention_layer_norm.clone(),
469            feed_forward: self.feed_forward.clone(),
470        }
471    }
472}
473
474impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
475    BertLayer<F>
476{
477    /// Create BERT layer
478    pub fn new(config: &BertConfig) -> Result<Self> {
479        let attention = BertSelfAttention::new(config)?;
480
481        let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([53; 32]);
482        let attention_layer_norm =
483            LayerNorm::new(config.hidden_size, config.layer_norm_eps, &mut rng)?;
484
485        let feed_forward = BertFeedForward::new(config)?;
486
487        Ok(Self {
488            attention,
489            attention_layer_norm,
490            feed_forward,
491        })
492    }
493}
494
495impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
496    for BertLayer<F>
497{
498    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
499        // Self-attention with residual and layer norm
500        let attention_output = self.attention.forward(input)?;
501        let attention_output = input + &attention_output;
502        let attention_output = self.attention_layer_norm.forward(&attention_output)?;
503
504        // Feed-forward with residual and layer norm
505        let layer_output = self.feed_forward.forward(&attention_output)?;
506
507        Ok(layer_output)
508    }
509
510    fn backward(
511        &self,
512        _input: &Array<F, IxDyn>,
513        grad_output: &Array<F, IxDyn>,
514    ) -> Result<Array<F, IxDyn>> {
515        Ok(grad_output.clone())
516    }
517
518    fn update(&mut self, learning_rate: F) -> Result<()> {
519        self.attention.update(learning_rate)?;
520        self.attention_layer_norm.update(learning_rate)?;
521        self.feed_forward.update(learning_rate)?;
522        Ok(())
523    }
524
525    fn as_any(&self) -> &dyn std::any::Any {
526        self
527    }
528
529    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
530        self
531    }
532}
533
534/// BERT encoder
535struct BertEncoder<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign> {
536    /// BERT layers
537    layers: Vec<BertLayer<F>>,
538}
539
540impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
541    for BertEncoder<F>
542{
543    fn clone(&self) -> Self {
544        Self {
545            layers: self.layers.clone(),
546        }
547    }
548}
549
550impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
551    BertEncoder<F>
552{
553    /// Create BERT encoder
554    pub fn new(config: &BertConfig) -> Result<Self> {
555        let mut layers = Vec::with_capacity(config.num_hidden_layers);
556        for _ in 0..config.num_hidden_layers {
557            layers.push(BertLayer::new(config)?);
558        }
559
560        Ok(Self { layers })
561    }
562}
563
564impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
565    for BertEncoder<F>
566{
567    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
568        let mut hidden_states = input.clone();
569        for layer in &self.layers {
570            hidden_states = layer.forward(&hidden_states)?;
571        }
572        Ok(hidden_states)
573    }
574
575    fn backward(
576        &self,
577        _input: &Array<F, IxDyn>,
578        grad_output: &Array<F, IxDyn>,
579    ) -> Result<Array<F, IxDyn>> {
580        Ok(grad_output.clone())
581    }
582
583    fn update(&mut self, learning_rate: F) -> Result<()> {
584        for layer in &mut self.layers {
585            layer.update(learning_rate)?;
586        }
587        Ok(())
588    }
589
590    fn as_any(&self) -> &dyn std::any::Any {
591        self
592    }
593
594    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
595        self
596    }
597}
598
599/// BERT pooler (for classification tasks)
600struct BertPooler<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> {
601    /// Dense layer
602    dense: Dense<F>,
603}
604
605impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
606    for BertPooler<F>
607{
608    fn clone(&self) -> Self {
609        Self {
610            dense: self.dense.clone(),
611        }
612    }
613}
614
615impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
616    BertPooler<F>
617{
618    /// Create BERT pooler
619    pub fn new(config: &BertConfig) -> Result<Self> {
620        let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([54; 32]);
621        let dense = Dense::new(config.hidden_size, config.hidden_size, None, &mut rng)?;
622
623        Ok(Self { dense })
624    }
625}
626
627impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
628    for BertPooler<F>
629{
630    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
631        // Take the first token ([CLS]) representation
632        let shape = input.shape();
633        if shape.len() != 3 {
634            return Err(NeuralError::InferenceError(format!(
635                "Expected input shape [batch_size, seq_len, hidden_size], got {:?}",
636                shape
637            )));
638        }
639
640        let batch_size = shape[0];
641        let hidden_size = shape[2];
642
643        // Extract [CLS] token (first token)
644        let mut cls_tokens = Array::zeros(IxDyn(&[batch_size, hidden_size]));
645        for b in 0..batch_size {
646            for i in 0..hidden_size {
647                cls_tokens[[b, i]] = input[[b, 0, i]];
648            }
649        }
650
651        // Apply dense layer
652        let pooled_output = self.dense.forward(&cls_tokens)?;
653
654        // Apply tanh activation
655        let pooled_output = pooled_output.mapv(|x: F| x.tanh());
656
657        Ok(pooled_output)
658    }
659
660    fn backward(
661        &self,
662        _input: &Array<F, IxDyn>,
663        grad_output: &Array<F, IxDyn>,
664    ) -> Result<Array<F, IxDyn>> {
665        Ok(grad_output.clone())
666    }
667
668    fn update(&mut self, learning_rate: F) -> Result<()> {
669        self.dense.update(learning_rate)?;
670        Ok(())
671    }
672
673    fn as_any(&self) -> &dyn std::any::Any {
674        self
675    }
676
677    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
678        self
679    }
680}
681
682/// BERT model implementation
683pub struct BertModel<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign> {
684    /// Embeddings layer
685    embeddings: BertEmbeddings<F>,
686    /// Encoder
687    encoder: BertEncoder<F>,
688    /// Pooler
689    pooler: BertPooler<F>,
690    /// Model configuration
691    config: BertConfig,
692}
693
694impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
695    for BertModel<F>
696{
697    fn clone(&self) -> Self {
698        Self {
699            embeddings: self.embeddings.clone(),
700            encoder: self.encoder.clone(),
701            pooler: self.pooler.clone(),
702            config: self.config.clone(),
703        }
704    }
705}
706
707impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
708    BertModel<F>
709{
710    /// Create a new BERT model
711    pub fn new(config: BertConfig) -> Result<Self> {
712        let embeddings = BertEmbeddings::new(&config)?;
713        let encoder = BertEncoder::new(&config)?;
714        let pooler = BertPooler::new(&config)?;
715
716        Ok(Self {
717            embeddings,
718            encoder,
719            pooler,
720            config,
721        })
722    }
723
724    /// Create a BERT-Base-Uncased model
725    pub fn bert_base_uncased() -> Result<Self> {
726        let config = BertConfig::bert_base_uncased();
727        Self::new(config)
728    }
729
730    /// Create a BERT-Large-Uncased model
731    pub fn bert_large_uncased() -> Result<Self> {
732        let config = BertConfig::bert_large_uncased();
733        Self::new(config)
734    }
735
736    /// Create a custom BERT model
737    pub fn custom(
738        vocab_size: usize,
739        hidden_size: usize,
740        num_hidden_layers: usize,
741        num_attention_heads: usize,
742    ) -> Result<Self> {
743        let config = BertConfig::custom(
744            vocab_size,
745            hidden_size,
746            num_hidden_layers,
747            num_attention_heads,
748        );
749        Self::new(config)
750    }
751
752    /// Get sequence output (last layer hidden states)
753    pub fn get_sequence_output(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
754        let embedding_output = self.embeddings.forward(input)?;
755        let sequence_output = self.encoder.forward(&embedding_output)?;
756        Ok(sequence_output)
757    }
758
759    /// Get pooled output (for classification tasks)
760    pub fn get_pooled_output(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
761        let sequence_output = self.get_sequence_output(input)?;
762        let pooled_output = self.pooler.forward(&sequence_output)?;
763        Ok(pooled_output)
764    }
765
766    /// Get the model configuration
767    pub fn config(&self) -> &BertConfig {
768        &self.config
769    }
770}
771
772impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
773    for BertModel<F>
774{
775    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
776        // By default, return the full sequence output
777        self.get_sequence_output(input)
778    }
779
780    fn backward(
781        &self,
782        _input: &Array<F, IxDyn>,
783        grad_output: &Array<F, IxDyn>,
784    ) -> Result<Array<F, IxDyn>> {
785        Ok(grad_output.clone())
786    }
787
788    fn update(&mut self, learning_rate: F) -> Result<()> {
789        self.embeddings.update(learning_rate)?;
790        self.encoder.update(learning_rate)?;
791        self.pooler.update(learning_rate)?;
792        Ok(())
793    }
794
795    fn as_any(&self) -> &dyn std::any::Any {
796        self
797    }
798
799    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
800        self
801    }
802}
803
804impl<
805        F: Float
806            + Debug
807            + ScalarOperand
808            + Send
809            + Sync
810            + SimdUnifiedOps
811            + NumAssign
812            + ToPrimitive
813            + FromPrimitive
814            + 'static,
815    > BertModel<F>
816{
817    /// Extract all named parameters in HuggingFace-compatible format.
818    ///
819    /// Parameter names mirror the official HuggingFace BERT naming:
820    /// - `embeddings.word_embeddings.weight`
821    /// - `embeddings.LayerNorm.weight`, `embeddings.LayerNorm.bias`
822    /// - `encoder.layer.N.attention.self.query.weight`
823    /// - `encoder.layer.N.attention.output.dense.weight`
824    /// - `encoder.layer.N.intermediate.dense.weight`
825    /// - `encoder.layer.N.output.dense.weight`
826    /// - `pooler.dense.weight`, `pooler.dense.bias`
827    pub fn extract_named_params(&self) -> Result<Vec<(String, Array<F, IxDyn>)>> {
828        let mut result = Vec::new();
829
830        // Embeddings
831        for p in self.embeddings.word_embeddings.params().iter() {
832            result.push(("embeddings.word_embeddings.weight".to_string(), p.clone()));
833        }
834        for p in self.embeddings.position_embeddings.params().iter() {
835            result.push((
836                "embeddings.position_embeddings.weight".to_string(),
837                p.clone(),
838            ));
839        }
840        for p in self.embeddings.token_type_embeddings.params().iter() {
841            result.push((
842                "embeddings.token_type_embeddings.weight".to_string(),
843                p.clone(),
844            ));
845        }
846        let ln_params = self.embeddings.layer_norm.params();
847        if !ln_params.is_empty() {
848            result.push((
849                "embeddings.LayerNorm.weight".to_string(),
850                ln_params[0].clone(),
851            ));
852        }
853        if ln_params.len() >= 2 {
854            result.push((
855                "embeddings.LayerNorm.bias".to_string(),
856                ln_params[1].clone(),
857            ));
858        }
859
860        // Encoder layers
861        for (layer_idx, bert_layer) in self.encoder.layers.iter().enumerate() {
862            let prefix = format!("encoder.layer.{layer_idx}");
863
864            // Self-attention: MultiHeadAttention has 4 params: w_query, w_key, w_value, w_output
865            let attn_params = bert_layer.attention.attention.params();
866            if attn_params.len() >= 4 {
867                result.push((
868                    format!("{prefix}.attention.self.query.weight"),
869                    attn_params[0].clone(),
870                ));
871                result.push((
872                    format!("{prefix}.attention.self.key.weight"),
873                    attn_params[1].clone(),
874                ));
875                result.push((
876                    format!("{prefix}.attention.self.value.weight"),
877                    attn_params[2].clone(),
878                ));
879                result.push((
880                    format!("{prefix}.attention.output.dense.weight"),
881                    attn_params[3].clone(),
882                ));
883            } else if attn_params.len() == 3 {
884                result.push((
885                    format!("{prefix}.attention.self.query.weight"),
886                    attn_params[0].clone(),
887                ));
888                result.push((
889                    format!("{prefix}.attention.self.key.weight"),
890                    attn_params[1].clone(),
891                ));
892                result.push((
893                    format!("{prefix}.attention.self.value.weight"),
894                    attn_params[2].clone(),
895                ));
896            }
897
898            // Attention output layer norm
899            let attn_ln_params = bert_layer.attention_layer_norm.params();
900            if !attn_ln_params.is_empty() {
901                result.push((
902                    format!("{prefix}.attention.output.LayerNorm.weight"),
903                    attn_ln_params[0].clone(),
904                ));
905            }
906            if attn_ln_params.len() >= 2 {
907                result.push((
908                    format!("{prefix}.attention.output.LayerNorm.bias"),
909                    attn_ln_params[1].clone(),
910                ));
911            }
912
913            // Feed-forward intermediate dense
914            let ff_inter_params = bert_layer.feed_forward.intermediate_dense.params();
915            if !ff_inter_params.is_empty() {
916                result.push((
917                    format!("{prefix}.intermediate.dense.weight"),
918                    ff_inter_params[0].clone(),
919                ));
920            }
921            if ff_inter_params.len() >= 2 {
922                result.push((
923                    format!("{prefix}.intermediate.dense.bias"),
924                    ff_inter_params[1].clone(),
925                ));
926            }
927
928            // Feed-forward output dense
929            let ff_out_params = bert_layer.feed_forward.output_dense.params();
930            if !ff_out_params.is_empty() {
931                result.push((
932                    format!("{prefix}.output.dense.weight"),
933                    ff_out_params[0].clone(),
934                ));
935            }
936            if ff_out_params.len() >= 2 {
937                result.push((
938                    format!("{prefix}.output.dense.bias"),
939                    ff_out_params[1].clone(),
940                ));
941            }
942
943            // Feed-forward layer norm
944            let ff_ln_params = bert_layer.feed_forward.layer_norm.params();
945            if !ff_ln_params.is_empty() {
946                result.push((
947                    format!("{prefix}.output.LayerNorm.weight"),
948                    ff_ln_params[0].clone(),
949                ));
950            }
951            if ff_ln_params.len() >= 2 {
952                result.push((
953                    format!("{prefix}.output.LayerNorm.bias"),
954                    ff_ln_params[1].clone(),
955                ));
956            }
957        }
958
959        // Pooler
960        let pooler_params = self.pooler.dense.params();
961        if !pooler_params.is_empty() {
962            result.push(("pooler.dense.weight".to_string(), pooler_params[0].clone()));
963        }
964        if pooler_params.len() >= 2 {
965            result.push(("pooler.dense.bias".to_string(), pooler_params[1].clone()));
966        }
967
968        Ok(result)
969    }
970
971    /// Load named parameters from a map (by name).
972    ///
973    /// Unknown parameter names are silently ignored, enabling graceful
974    /// forward/backward compatibility between model versions.
975    pub fn load_named_params(
976        &mut self,
977        params_map: &HashMap<String, Array<F, IxDyn>>,
978    ) -> Result<()> {
979        // Embeddings
980        if let Some(p) = params_map.get("embeddings.word_embeddings.weight") {
981            self.embeddings
982                .word_embeddings
983                .set_params(std::slice::from_ref(p))?;
984        }
985        if let Some(p) = params_map.get("embeddings.position_embeddings.weight") {
986            self.embeddings
987                .position_embeddings
988                .set_params(std::slice::from_ref(p))?;
989        }
990        if let Some(p) = params_map.get("embeddings.token_type_embeddings.weight") {
991            self.embeddings
992                .token_type_embeddings
993                .set_params(std::slice::from_ref(p))?;
994        }
995        {
996            let mut ln_ps = Vec::new();
997            if let Some(p) = params_map.get("embeddings.LayerNorm.weight") {
998                ln_ps.push(p.clone());
999            }
1000            if let Some(p) = params_map.get("embeddings.LayerNorm.bias") {
1001                ln_ps.push(p.clone());
1002            }
1003            if !ln_ps.is_empty() {
1004                self.embeddings.layer_norm.set_params(&ln_ps)?;
1005            }
1006        }
1007
1008        // Encoder layers
1009        for (layer_idx, bert_layer) in self.encoder.layers.iter_mut().enumerate() {
1010            let prefix = format!("encoder.layer.{layer_idx}");
1011
1012            // Self-attention weights
1013            let mut attn_ps = Vec::new();
1014            if let Some(p) = params_map.get(&format!("{prefix}.attention.self.query.weight")) {
1015                attn_ps.push(p.clone());
1016            }
1017            if let Some(p) = params_map.get(&format!("{prefix}.attention.self.key.weight")) {
1018                attn_ps.push(p.clone());
1019            }
1020            if let Some(p) = params_map.get(&format!("{prefix}.attention.self.value.weight")) {
1021                attn_ps.push(p.clone());
1022            }
1023            if let Some(p) = params_map.get(&format!("{prefix}.attention.output.dense.weight")) {
1024                attn_ps.push(p.clone());
1025            }
1026            if !attn_ps.is_empty() {
1027                bert_layer.attention.attention.set_params(&attn_ps)?;
1028            }
1029
1030            // Attention output layer norm
1031            {
1032                let mut ln_ps = Vec::new();
1033                if let Some(p) =
1034                    params_map.get(&format!("{prefix}.attention.output.LayerNorm.weight"))
1035                {
1036                    ln_ps.push(p.clone());
1037                }
1038                if let Some(p) =
1039                    params_map.get(&format!("{prefix}.attention.output.LayerNorm.bias"))
1040                {
1041                    ln_ps.push(p.clone());
1042                }
1043                if !ln_ps.is_empty() {
1044                    bert_layer.attention_layer_norm.set_params(&ln_ps)?;
1045                }
1046            }
1047
1048            // Feed-forward intermediate dense
1049            {
1050                let mut ff_ps = Vec::new();
1051                if let Some(p) = params_map.get(&format!("{prefix}.intermediate.dense.weight")) {
1052                    ff_ps.push(p.clone());
1053                }
1054                if let Some(p) = params_map.get(&format!("{prefix}.intermediate.dense.bias")) {
1055                    ff_ps.push(p.clone());
1056                }
1057                if !ff_ps.is_empty() {
1058                    bert_layer
1059                        .feed_forward
1060                        .intermediate_dense
1061                        .set_params(&ff_ps)?;
1062                }
1063            }
1064
1065            // Feed-forward output dense
1066            {
1067                let mut ff_ps = Vec::new();
1068                if let Some(p) = params_map.get(&format!("{prefix}.output.dense.weight")) {
1069                    ff_ps.push(p.clone());
1070                }
1071                if let Some(p) = params_map.get(&format!("{prefix}.output.dense.bias")) {
1072                    ff_ps.push(p.clone());
1073                }
1074                if !ff_ps.is_empty() {
1075                    bert_layer.feed_forward.output_dense.set_params(&ff_ps)?;
1076                }
1077            }
1078
1079            // Feed-forward layer norm
1080            {
1081                let mut ln_ps = Vec::new();
1082                if let Some(p) = params_map.get(&format!("{prefix}.output.LayerNorm.weight")) {
1083                    ln_ps.push(p.clone());
1084                }
1085                if let Some(p) = params_map.get(&format!("{prefix}.output.LayerNorm.bias")) {
1086                    ln_ps.push(p.clone());
1087                }
1088                if !ln_ps.is_empty() {
1089                    bert_layer.feed_forward.layer_norm.set_params(&ln_ps)?;
1090                }
1091            }
1092        }
1093
1094        // Pooler
1095        {
1096            let mut ps = Vec::new();
1097            if let Some(p) = params_map.get("pooler.dense.weight") {
1098                ps.push(p.clone());
1099            }
1100            if let Some(p) = params_map.get("pooler.dense.bias") {
1101                ps.push(p.clone());
1102            }
1103            if !ps.is_empty() {
1104                self.pooler.dense.set_params(&ps)?;
1105            }
1106        }
1107
1108        Ok(())
1109    }
1110}
1111
#[cfg(test)]
mod tests {
    use super::*;

    /// The BERT-Base preset exposes the expected vocabulary and encoder sizes.
    #[test]
    fn test_bert_config_base() {
        let BertConfig {
            vocab_size,
            hidden_size,
            num_hidden_layers,
            num_attention_heads,
            ..
        } = BertConfig::bert_base_uncased();

        assert_eq!(vocab_size, 30522);
        assert_eq!(hidden_size, 768);
        assert_eq!(num_hidden_layers, 12);
        assert_eq!(num_attention_heads, 12);
    }

    /// The BERT-Large preset exposes the expected encoder sizes.
    #[test]
    fn test_bert_config_large() {
        let BertConfig {
            hidden_size,
            num_hidden_layers,
            num_attention_heads,
            ..
        } = BertConfig::bert_large_uncased();

        assert_eq!(hidden_size, 1024);
        assert_eq!(num_hidden_layers, 24);
        assert_eq!(num_attention_heads, 16);
    }

    /// A custom configuration keeps the values it was given and reports an
    /// intermediate size of four times the hidden size.
    #[test]
    fn test_bert_config_custom() {
        let BertConfig {
            vocab_size,
            hidden_size,
            num_hidden_layers,
            num_attention_heads,
            intermediate_size,
            ..
        } = BertConfig::custom(10000, 256, 4, 4);

        assert_eq!(vocab_size, 10000);
        assert_eq!(hidden_size, 256);
        assert_eq!(num_hidden_layers, 4);
        assert_eq!(num_attention_heads, 4);
        assert_eq!(intermediate_size, 1024); // 256 * 4
    }
}