Struct Transformer

Source
pub struct Transformer<F: Float + Debug + Send + Sync> { /* private fields */ }

Complete transformer model with encoder and decoder

This implements the full transformer architecture from “Attention Is All You Need” (Vaswani et al., 2017), combining encoder and decoder stacks with positional encoding.

Implementations§

Source§

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Transformer<F>

Source

pub fn new<R: Rng>(config: TransformerConfig, rng: &mut R) -> Result<Self>

Create a new transformer model

§Arguments
  • config - Transformer configuration
  • rng - Random number generator for weight initialization
§Returns
  • A new transformer model
Examples found in repository
examples/transformer_example.rs (line 42)
13fn main() -> Result<(), Box<dyn std::error::Error>> {
14    println!("Transformer Model Example");
15    println!("========================");
16
17    // Create a seeded RNG for reproducibility
18    let mut rng = SmallRng::seed_from_u64(42);
19
20    // Create a small transformer configuration for demonstration
21    let config = TransformerConfig {
22        d_model: 64,                                           // Embedding dimension
23        n_encoder_layers: 2,                                   // Number of encoder layers
24        n_decoder_layers: 2,                                   // Number of decoder layers
25        n_heads: 4,                                            // Number of attention heads
26        d_ff: 128,       // Feed-forward network hidden dimension
27        max_seq_len: 50, // Maximum sequence length
28        dropout: 0.1,    // Dropout rate
29        pos_encoding_type: PositionalEncodingType::Sinusoidal, // Positional encoding type
30        epsilon: 1e-5,   // Small constant for layer normalization
31    };
32
33    println!("Creating transformer model with config:");
34    println!("  - d_model: {}", config.d_model);
35    println!("  - n_encoder_layers: {}", config.n_encoder_layers);
36    println!("  - n_decoder_layers: {}", config.n_decoder_layers);
37    println!("  - n_heads: {}", config.n_heads);
38    println!("  - d_ff: {}", config.d_ff);
39    println!("  - max_seq_len: {}", config.max_seq_len);
40
41    // Create the transformer model
42    let transformer = Transformer::<f64>::new(config, &mut rng)?;
43
44    // Create sample inputs
45    // In a real application, these would be token embeddings
46    let batch_size = 2;
47    let src_seq_len = 10;
48    let tgt_seq_len = 8;
49    let d_model = 64;
50
51    println!("\nSample dimensions:");
52    println!("  - Batch size: {}", batch_size);
53    println!("  - Source sequence length: {}", src_seq_len);
54    println!("  - Target sequence length: {}", tgt_seq_len);
55
56    // Create source and target sequence embeddings
57    // In practice, these would come from embedding layers or tokenizers
58    let src_embeddings = Array3::<f64>::from_elem((batch_size, src_seq_len, d_model), 0.1);
59    let tgt_embeddings = Array3::<f64>::from_elem((batch_size, tgt_seq_len, d_model), 0.1);
60
61    // Convert to dyn format once and reuse
62    let src_embeddings_dyn = src_embeddings.clone().into_dyn();
63    let tgt_embeddings_dyn = tgt_embeddings.clone().into_dyn();
64
65    println!("\nRunning encoder-only inference...");
66    // Run encoder-only inference (useful for tasks like classification)
67    let encoder_output = transformer.forward(&src_embeddings_dyn)?;
68    println!("Encoder output shape: {:?}", encoder_output.shape());
69
70    println!("\nRunning full transformer forward pass (training mode)...");
71    // Run full transformer training (teacher forcing)
72    let output_train = transformer.forward_train(&src_embeddings_dyn, &tgt_embeddings_dyn)?;
73    println!("Training output shape: {:?}", output_train.shape());
74
75    println!("\nRunning autoregressive inference (one step)...");
76    // Simulate autoregressive generation (one step)
77    // In practice, we would use a loop to generate tokens one by one
78    let first_token = Array3::<f64>::from_elem((batch_size, 1, d_model), 0.1);
79    let first_token_dyn = first_token.clone().into_dyn();
80    let output_inference = transformer.forward_inference(&src_embeddings_dyn, &first_token_dyn)?;
81    println!("Inference output shape: {:?}", output_inference.shape());
82
83    println!("\nExample completed successfully");
84    Ok(())
85}
Source

pub fn forward_train(&self, src: &Array<F, IxDyn>, tgt: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>

Forward pass with encoder and decoder

§Arguments
  • src - Source sequences [batch, src_len, d_model]
  • tgt - Target sequences [batch, tgt_len, d_model]
§Returns
  • Output tensor [batch, tgt_len, d_model]
Examples found in repository
examples/transformer_example.rs (line 72)
(Listing identical to the one shown under `new` above; `forward_train` is called at line 72.)
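A more compact teacher-forcing sketch (assumes a `transformer` built as in the example above and `ndarray` in scope; the constant 0.1 embeddings are stand-ins for real token embeddings):

use ndarray::Array3;

let (batch, src_len, tgt_len, d_model) = (2, 10, 8, 64);

// Source [batch, src_len, d_model]; target [batch, tgt_len, d_model].
// In a real pipeline the target would be the gold sequence shifted right.
let src = Array3::<f64>::from_elem((batch, src_len, d_model), 0.1).into_dyn();
let tgt = Array3::<f64>::from_elem((batch, tgt_len, d_model), 0.1).into_dyn();

let out = transformer.forward_train(&src, &tgt)?;
assert_eq!(out.shape(), &[batch, tgt_len, d_model]);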
Source

pub fn forward_inference(&self, src: &Array<F, IxDyn>, tgt: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>

Forward pass for autoregressive inference (without teacher forcing)

§Arguments
  • src - Source sequences [batch, src_len, d_model]
  • tgt - Target sequences so far [batch, tgt_len, d_model]
§Returns
  • Output tensor [batch, tgt_len, d_model]
Examples found in repository
examples/transformer_example.rs (line 80)
(Listing identical to the one shown under `new` above; `forward_inference` is called at line 80.)
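The example runs a single decoding step. A full autoregressive loop would repeatedly append the newest output position to the target prefix, as in this hedged sketch (feeding the raw d_model vector back in is only illustrative; a real decoder would project to the vocabulary and re-embed the sampled token):

use ndarray::{concatenate, Array3, Axis, Ix3, Slice};

// `transformer` and a dyn-shaped `src` are assumed from the example above.
let (batch, d_model, max_steps) = (2, 64, 8);

// Start from a single begin-of-sequence position [batch, 1, d_model].
let mut tgt = Array3::<f64>::from_elem((batch, 1, d_model), 0.1);

for _ in 0..max_steps {
    let out = transformer.forward_inference(&src, &tgt.clone().into_dyn())?;
    let last_idx = out.shape()[1] - 1;
    // Keep only the final position [batch, 1, d_model] as the next step.
    let last = out
        .into_dimensionality::<Ix3>()?
        .slice_axis(Axis(1), Slice::from(last_idx..))
        .to_owned();
    tgt = concatenate(Axis(1), &[tgt.view(), last.view()])?;
}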

Trait Implementations§

Source§

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for Transformer<F>

Source§

fn as_any(&self) -> &dyn Any

Get the layer as a dyn Any for downcasting Read more
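A minimal downcasting sketch (assumes `Layer` is object-safe and `transformer` was built as in the example above):

let layer: Box<dyn Layer<f64>> = Box::new(transformer);
if let Some(t) = layer.as_any().downcast_ref::<Transformer<f64>>() {
    println!("Recovered a {} layer", t.layer_type());
}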
Source§

fn as_any_mut(&mut self) -> &mut dyn Any

Get the layer as a mutable dyn Any for downcasting Read more
Source§

fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>

Forward pass of the layer Read more
Source§

fn backward(&self, input: &Array<F, IxDyn>, _grad_output: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>

Backward pass of the layer to compute gradients Read more
Source§

fn update(&mut self, learning_rate: F) -> Result<()>

Update the layer parameters with the given gradients Read more
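A hedged sketch of one update step through the `Layer` interface (the gradient here is a placeholder; a real training loop would derive `grad_output` from a loss function, and `input` is assumed to be a dyn-shaped activation tensor):

// Illustrative single step; `transformer` is assumed from the example above.
let mut model = transformer;
model.set_training(true);

let output = model.forward(&input)?;
// Placeholder gradient with the output's shape; illustration only.
let grad_output = output.clone();
let _grad_input = model.backward(&input, &grad_output)?;
model.update(0.01)?; // SGD-style step with learning rate 0.01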
Source§

fn params(&self) -> Vec<Array<F, IxDyn>>

Get the parameters of the layer Read more
Source§

fn gradients(&self) -> Vec<Array<F, IxDyn>>

Get the gradients of the layer parameters Read more
Source§

fn set_gradients(&mut self, _gradients: &[Array<F, IxDyn>]) -> Result<()>

Set the gradients of the layer parameters Read more
Source§

fn set_params(&mut self, _params: &[Array<F, IxDyn>]) -> Result<()>

Set the parameters of the layer Read more
Source§

fn set_training(&mut self, _training: bool)

Set the layer to training mode (true) or evaluation mode (false) Read more
Source§

fn is_training(&self) -> bool

Get the current training mode Read more
Source§

fn layer_type(&self) -> &str

Get the type of the layer (e.g., “Dense”, “Conv2D”) Read more
Source§

fn parameter_count(&self) -> usize

Get the number of trainable parameters in this layer Read more
Source§

fn layer_description(&self) -> String

Get a detailed description of this layer Read more

Auto Trait Implementations§

§

impl<F> !Freeze for Transformer<F>

§

impl<F> !RefUnwindSafe for Transformer<F>

§

impl<F> !Send for Transformer<F>

§

impl<F> !Sync for Transformer<F>

§

impl<F> Unpin for Transformer<F>
where F: Unpin,

§

impl<F> !UnwindSafe for Transformer<F>

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a value with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V