// bert_example/
// bert_example.rs

1use ndarray::{Array, IxDyn};
2use scirs2_neural::layers::Layer;
3use scirs2_neural::models::{BertConfig, BertModel};
4
5fn main() -> Result<(), Box<dyn std::error::Error>> {
6    println!("BERT Model Example");
7
8    // Create a small BERT model for demonstration
9    println!("Creating a small BERT model...");
10
11    let config = BertConfig::custom(
12        10000, // vocab_size
13        128,   // hidden_size
14        2,     // num_hidden_layers
15        2,     // num_attention_heads
16    );
17
18    let model = BertModel::<f32>::new(config)?;
19
20    // Create dummy input (batch_size=2, seq_len=16)
21    // Input tensor contains token IDs
22    let input = Array::from_shape_fn(
23        IxDyn(&[2, 16]),
24        |_| rand::random::<f32>() * 100.0, // Random token IDs between 0 and 100
25    );
26
27    println!("Input shape: {:?}", input.shape());
28
29    // Get sequence output (hidden states)
30    let sequence_output = model.forward(&input)?;
31
32    println!("Sequence output shape: {:?}", sequence_output.shape());
33
34    // Get pooled output (for classification tasks)
35    let pooled_output = model.get_pooled_output(&input)?;
36
37    println!("Pooled output shape: {:?}", pooled_output.shape());
38
39    // Let's create a BERT-Base model
40    println!("\nCreating a BERT-Base model...");
41
42    let bert_base = BertModel::<f32>::bert_base_uncased()?;
43
44    // Create dummy input for a longer sequence
45    let base_input = Array::from_shape_fn(
46        IxDyn(&[1, 64]),
47        |_| rand::random::<f32>() * 1000.0, // Random token IDs
48    );
49
50    println!("BERT-Base input shape: {:?}", base_input.shape());
51
52    // Forward pass to get pooled output
53    let base_pooled_output = bert_base.get_pooled_output(&base_input)?;
54
55    println!(
56        "BERT-Base pooled output shape: {:?}",
57        base_pooled_output.shape()
58    );
59    println!(
60        "BERT-Base hidden dimension: {}",
61        base_pooled_output.shape()[1]
62    );
63
64    println!("\nBERT example completed successfully!");
65
66    Ok(())
67}