bert_example/
bert_example.rs1use ndarray::{Array, IxDyn};
2use scirs2_neural::layers::Layer;
3use scirs2_neural::models::{BertConfig, BertModel};
4
5fn main() -> Result<(), Box<dyn std::error::Error>> {
6 println!("BERT Model Example");
7
8 println!("Creating a small BERT model...");
10
11 let config = BertConfig::custom(
12 10000, 128, 2, 2, );
17
18 let model = BertModel::<f32>::new(config)?;
19
20 let input = Array::from_shape_fn(
23 IxDyn(&[2, 16]),
24 |_| rand::random::<f32>() * 100.0, );
26
27 println!("Input shape: {:?}", input.shape());
28
29 let sequence_output = model.forward(&input)?;
31
32 println!("Sequence output shape: {:?}", sequence_output.shape());
33
34 let pooled_output = model.get_pooled_output(&input)?;
36
37 println!("Pooled output shape: {:?}", pooled_output.shape());
38
39 println!("\nCreating a BERT-Base model...");
41
42 let bert_base = BertModel::<f32>::bert_base_uncased()?;
43
44 let base_input = Array::from_shape_fn(
46 IxDyn(&[1, 64]),
47 |_| rand::random::<f32>() * 1000.0, );
49
50 println!("BERT-Base input shape: {:?}", base_input.shape());
51
52 let base_pooled_output = bert_base.get_pooled_output(&base_input)?;
54
55 println!(
56 "BERT-Base pooled output shape: {:?}",
57 base_pooled_output.shape()
58 );
59 println!(
60 "BERT-Base hidden dimension: {}",
61 base_pooled_output.shape()[1]
62 );
63
64 println!("\nBERT example completed successfully!");
65
66 Ok(())
67}