Trait Layer

pub trait Layer<F: Float + Debug + ScalarOperand> {
    // Required methods
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>;
    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>>;
    fn update(&mut self, learning_rate: F) -> Result<()>;
    fn as_any(&self) -> &dyn Any;
    fn as_any_mut(&mut self) -> &mut dyn Any;

    // Provided methods
    fn params(&self) -> Vec<Array<F, IxDyn>> { ... }
    fn gradients(&self) -> Vec<Array<F, IxDyn>> { ... }
    fn set_gradients(&mut self, _gradients: &[Array<F, IxDyn>]) -> Result<()> { ... }
    fn set_params(&mut self, _params: &[Array<F, IxDyn>]) -> Result<()> { ... }
    fn set_training(&mut self, _training: bool) { ... }
    fn is_training(&self) -> bool { ... }
    fn layer_type(&self) -> &str { ... }
    fn parameter_count(&self) -> usize { ... }
    fn layer_description(&self) -> String { ... }
}

Base trait for neural network layers

This trait defines the core interface that all neural network layers must implement. It supports forward propagation, backpropagation, parameter management, and training/evaluation mode switching.

§Core Methods

  • forward: Compute layer output given input
  • backward: Compute gradients for backpropagation
  • update: Apply parameter updates using computed gradients
  • set_training/is_training: Control training vs evaluation behavior
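To make the division of labor between these methods concrete, here is a minimal, std-only sketch of a parameterless ReLU layer. It does not use scirs2_neural: SimpleLayer and Relu are hypothetical names, Vec<f64> stands in for Array<F, IxDyn>, and String stands in for the crate's error type.

```rust
/// Hypothetical, simplified stand-in for `Layer<F>`:
/// `Vec<f64>` replaces `Array<F, IxDyn>` and `String` replaces the error type.
pub trait SimpleLayer {
    fn forward(&self, input: &[f64]) -> Result<Vec<f64>, String>;
    fn backward(&self, input: &[f64], grad_output: &[f64]) -> Result<Vec<f64>, String>;
    fn update(&mut self, learning_rate: f64) -> Result<(), String>;
    fn set_training(&mut self, training: bool);
    fn is_training(&self) -> bool;
}

/// A parameterless ReLU layer that only tracks its mode flag.
pub struct Relu {
    training: bool,
}

impl Relu {
    pub fn new() -> Self {
        Relu { training: true }
    }
}

impl SimpleLayer for Relu {
    fn forward(&self, input: &[f64]) -> Result<Vec<f64>, String> {
        // y = max(0, x), elementwise
        Ok(input.iter().map(|&x| x.max(0.0)).collect())
    }

    fn backward(&self, input: &[f64], grad_output: &[f64]) -> Result<Vec<f64>, String> {
        if input.len() != grad_output.len() {
            return Err(format!(
                "shape mismatch: {} inputs vs {} output gradients",
                input.len(),
                grad_output.len()
            ));
        }
        // dy/dx is 1 where x > 0 and 0 elsewhere (0 is used at x == 0)
        Ok(input
            .iter()
            .zip(grad_output)
            .map(|(&x, &g)| if x > 0.0 { g } else { 0.0 })
            .collect())
    }

    fn update(&mut self, _learning_rate: f64) -> Result<(), String> {
        // Nothing to update: ReLU has no trainable parameters
        Ok(())
    }

    fn set_training(&mut self, training: bool) {
        self.training = training;
    }

    fn is_training(&self) -> bool {
        self.training
    }
}
```

Because ReLU has no parameters, update is a no-op here; a layer with weights would apply its stored gradients at that point.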

§Examples

use scirs2_neural::layers::{Layer, Dense};
use ndarray::Array;
use rand::rngs::SmallRng;
use rand::SeedableRng;

let mut rng = SmallRng::seed_from_u64(42);
let layer = Dense::<f64>::new(10, 5, None, &mut rng)?;

let input = Array::zeros((2, 10)).into_dyn();
let output = layer.forward(&input)?;
assert_eq!(output.shape(), &[2, 5]);

// Check layer properties
println!("Layer type: {}", layer.layer_type());
println!("Parameter count: {}", layer.parameter_count());
println!("Training mode: {}", layer.is_training());

Required Methods§

Source

fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>

Forward pass of the layer

Computes the output of the layer given an input tensor. This method applies the layer’s transformation (e.g., linear transformation, convolution, activation function) to the input.

§Arguments
  • input - Input tensor of dynamic dimensionality (IxDyn); the expected shape depends on the layer (for example, (batch_size, input_dim) for Dense)
§Returns

Output tensor after applying the layer’s transformation

§Examples
use scirs2_neural::layers::{Layer, Dense};
use ndarray::Array;
use rand::rngs::SmallRng;
use rand::SeedableRng;

let mut rng = SmallRng::seed_from_u64(42);
let layer = Dense::<f64>::new(3, 2, Some("relu"), &mut rng)?;

let input = Array::from_shape_vec((1, 3), vec![1.0, 2.0, 3.0])?.into_dyn();
let output = layer.forward(&input)?;
assert_eq!(output.shape(), &[1, 2]);
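For a Dense layer, the transformation is an affine map followed by the optional activation. The sketch below spells out that arithmetic for a batch of rows, assuming the weights are laid out as (input_dim, output_dim) and using ReLU as the only supported activation; dense_forward is a hypothetical helper, not the crate's implementation.

```rust
/// Hypothetical dense forward pass: y = x·W + b, then an optional ReLU.
/// Assumes `weights` is laid out as [input_dim][output_dim].
pub fn dense_forward(
    input: &[Vec<f64>],   // shape (batch, input_dim)
    weights: &[Vec<f64>], // shape (input_dim, output_dim)
    bias: &[f64],         // shape (output_dim)
    relu: bool,
) -> Vec<Vec<f64>> {
    input
        .iter()
        .map(|row| {
            assert_eq!(row.len(), weights.len(), "input_dim mismatch");
            (0..bias.len())
                .map(|j| {
                    // z_j = sum_k x_k * W[k][j] + b_j
                    let z = row
                        .iter()
                        .zip(weights)
                        .map(|(x, w_row)| x * w_row[j])
                        .sum::<f64>()
                        + bias[j];
                    if relu { z.max(0.0) } else { z }
                })
                .collect::<Vec<f64>>()
        })
        .collect()
}
```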

fn backward(
    &self,
    input: &Array<F, IxDyn>,
    grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>>

Backward pass of the layer to compute gradients

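Computes the gradient of the loss with respect to the layer's input, which is passed to the preceding layer during backpropagation. Implementations typically also record the gradients of the layer's own parameters; because this method takes &self, doing so requires interior mutability (for example, a RefCell field).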

§Arguments
  • input - Original input to the forward pass
  • grad_output - Gradient of loss with respect to this layer’s output
§Returns

Gradient of loss with respect to this layer’s input

§Examples
use scirs2_neural::layers::{Layer, Dense};
use ndarray::Array;
use rand::rngs::SmallRng;
use rand::SeedableRng;

let mut rng = SmallRng::seed_from_u64(42);
let layer = Dense::<f64>::new(3, 2, None, &mut rng)?;

let input = Array::zeros((1, 3)).into_dyn();
let grad_output = Array::ones((1, 2)).into_dyn();

let grad_input = layer.backward(&input, &grad_output)?;
assert_eq!(grad_input.shape(), input.shape());
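For a dense layer without activation, y = x·W + b, the chain rule gives the three gradients the backward pass is responsible for: dL/dx = dL/dy · Wᵀ (returned to the caller), dL/dW = xᵀ · dL/dy, and dL/db = the sum of dL/dy over the batch (kept for the parameter update). The hypothetical dense_backward helper below computes all three with plain loops, under the same assumed (input_dim, output_dim) weight layout.

```rust
/// Gradients of a hypothetical dense layer without activation, y = x·W + b.
/// Returns (grad_input, grad_weights, grad_bias).
pub fn dense_backward(
    input: &[Vec<f64>],       // (batch, in_dim)
    weights: &[Vec<f64>],     // (in_dim, out_dim)
    grad_output: &[Vec<f64>], // (batch, out_dim)
) -> (Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<f64>) {
    let in_dim = weights.len();
    let out_dim = weights.first().map_or(0, |r| r.len());

    // dL/dx = dL/dy · Wᵀ, one row per sample
    let grad_input: Vec<Vec<f64>> = grad_output
        .iter()
        .map(|g| {
            (0..in_dim)
                .map(|k| (0..out_dim).map(|j| g[j] * weights[k][j]).sum::<f64>())
                .collect::<Vec<f64>>()
        })
        .collect();

    // dL/dW = xᵀ · dL/dy and dL/db = sum over the batch of dL/dy
    let mut grad_weights = vec![vec![0.0; out_dim]; in_dim];
    let mut grad_bias = vec![0.0; out_dim];
    for (x, g) in input.iter().zip(grad_output) {
        for k in 0..in_dim {
            for j in 0..out_dim {
                grad_weights[k][j] += x[k] * g[j];
            }
        }
        for j in 0..out_dim {
            grad_bias[j] += g[j];
        }
    }
    (grad_input, grad_weights, grad_bias)
}
```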

fn update(&mut self, learning_rate: F) -> Result<()>

Update the layer parameters using the stored gradients

Applies parameter updates using the provided learning rate and the gradients computed during the backward pass. This is typically called by optimizers.

§Arguments
  • learning_rate - Step size for parameter updates
§Examples
use scirs2_neural::layers::{Layer, Dense};
use ndarray::Array;
use rand::rngs::SmallRng;
use rand::SeedableRng;

let mut rng = SmallRng::seed_from_u64(42);
let mut layer = Dense::<f64>::new(3, 2, None, &mut rng)?;

// Simulate forward/backward pass
let input = Array::zeros((1, 3)).into_dyn();
let _output = layer.forward(&input)?;
let grad_output = Array::ones((1, 2)).into_dyn();
let _grad_input = layer.backward(&input, &grad_output)?;

// Update parameters
layer.update(0.01)?; // learning rate = 0.01
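The exact update rule is up to each implementation. Assuming plain gradient descent, the step applied to every parameter is param <- param - learning_rate * grad, as in the hypothetical sgd_step helper below, which works on a flattened parameter slice rather than an ndarray array.

```rust
/// Plain gradient-descent step on a flattened parameter slice:
/// param <- param - learning_rate * grad.
pub fn sgd_step(params: &mut [f64], grads: &[f64], learning_rate: f64) -> Result<(), String> {
    if params.len() != grads.len() {
        return Err(format!("{} params but {} gradients", params.len(), grads.len()));
    }
    for (p, g) in params.iter_mut().zip(grads) {
        *p -= learning_rate * g;
    }
    Ok(())
}
```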

fn as_any(&self) -> &dyn Any

Get the layer as a dyn Any for downcasting

This method enables runtime type checking and downcasting to specific layer types when needed.


fn as_any_mut(&mut self) -> &mut dyn Any

Get the layer as a mutable dyn Any for downcasting

This method enables runtime type checking and downcasting to specific layer types when mutable access is needed.
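A typical use of as_any and as_any_mut is to inspect or modify concrete layers inside a heterogeneous collection of trait objects. The sketch below reproduces the pattern with std::any::Any on a hypothetical AnyLayer trait and two hypothetical layer structs; downcast_ref and downcast_mut return None when the concrete type does not match.

```rust
use std::any::Any;

/// Hypothetical minimal trait mirroring `as_any` / `as_any_mut`.
pub trait AnyLayer {
    fn as_any(&self) -> &dyn Any;
    fn as_any_mut(&mut self) -> &mut dyn Any;
}

pub struct DenseSketch {
    pub units: usize,
}

pub struct DropoutSketch {
    pub rate: f64,
}

impl AnyLayer for DenseSketch {
    fn as_any(&self) -> &dyn Any { self }
    fn as_any_mut(&mut self) -> &mut dyn Any { self }
}

impl AnyLayer for DropoutSketch {
    fn as_any(&self) -> &dyn Any { self }
    fn as_any_mut(&mut self) -> &mut dyn Any { self }
}

/// Read-only downcast: sum `units` over the DenseSketch layers only.
pub fn total_dense_units(layers: &[Box<dyn AnyLayer>]) -> usize {
    layers
        .iter()
        .filter_map(|l| l.as_any().downcast_ref::<DenseSketch>())
        .map(|d| d.units)
        .sum()
}

/// Mutable downcast: change every DropoutSketch rate in place.
pub fn set_dropout_rate(layers: &mut [Box<dyn AnyLayer>], rate: f64) {
    for l in layers.iter_mut() {
        if let Some(d) = l.as_any_mut().downcast_mut::<DropoutSketch>() {
            d.rate = rate;
        }
    }
}
```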

Provided Methods§


fn params(&self) -> Vec<Array<F, IxDyn>>

Get the parameters of the layer

Returns all trainable parameters (e.g. weights and biases) as a vector of arrays. The default implementation returns an empty vector, which is appropriate for parameterless layers.

§Examples
use scirs2_neural::layers::{Layer, Dense};
use rand::rngs::SmallRng;
use rand::SeedableRng;

let mut rng = SmallRng::seed_from_u64(42);
let layer = Dense::<f64>::new(3, 2, None, &mut rng)?;

let params = layer.params();
// Dense layer has weights and biases
assert_eq!(params.len(), 2);
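Note that params().len() counts parameter arrays (2 for Dense: weights and biases), not scalar parameters. To count scalars, sum the array lengths, as the repository examples do with model.params().iter().map(|p| p.len()).sum(). The sketch below shows both counts on plain vectors; the helper names are hypothetical, and dense_param_count assumes an input_dim x output_dim weight matrix plus one bias per output unit.

```rust
/// Number of parameter arrays; this is what `params().len()` reports.
pub fn count_arrays(params: &[Vec<f64>]) -> usize {
    params.len()
}

/// Total number of scalar parameters across all arrays.
pub fn count_scalars(params: &[Vec<f64>]) -> usize {
    params.iter().map(|p| p.len()).sum()
}

/// Scalar parameter count of a dense layer, assuming an
/// input_dim x output_dim weight matrix plus one bias per output.
pub fn dense_param_count(input_dim: usize, output_dim: usize) -> usize {
    input_dim * output_dim + output_dim
}
```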
Examples found in repository:
examples/seq2seq_example.rs (line 118)
9fn main() -> Result<()> {
10    println!("Sequence-to-Sequence (Seq2Seq) Model Example");
11    println!("--------------------------------------------");
12
13    // Define vocabulary sizes
14    let src_vocab_size = 10000; // Source language vocabulary size
15    let tgt_vocab_size = 8000; // Target language vocabulary size
16
17    // Create random input sequences (batch_size=2, sequence_length=10)
18    let input_shape = [2, 10];
19    let mut input_seq = Array::<f32, _>::zeros(input_shape).into_dyn();
20
21    // Fill with random token IDs (between 0 and src_vocab_size-1)
22    let mut rng = rand::rng();
23    for elem in input_seq.iter_mut() {
24        *elem = (rng.random_range(0.0..1.0) * (src_vocab_size as f32 - 1.0)).floor();
25    }
26
27    // Create random target sequences for teacher forcing (batch_size=2, sequence_length=8)
28    let target_shape = [2, 8];
29    let mut target_seq = Array::<f32, _>::zeros(target_shape).into_dyn();
30
31    // Fill with random token IDs (between 0 and tgt_vocab_size-1)
32    for elem in target_seq.iter_mut() {
33        *elem = (rng.random_range(0.0..1.0) * (tgt_vocab_size as f32 - 1.0)).floor();
34    }
35
36    // 1. Create a basic translation model
37    println!("\nCreating Basic Translation Model...");
38    let mut translation_model = Seq2Seq::create_translation_model(
39        src_vocab_size,
40        tgt_vocab_size,
41        256, // Hidden dimension
42    )?;
43
44    // Run forward pass with teacher forcing
45    println!("Running forward pass with teacher forcing...");
46    let train_output = translation_model.forward_train(&input_seq, &target_seq)?;
47    println!("Training output shape: {:?}", train_output.shape());
48
49    // Generate sequences
50    println!("\nGenerating sequences...");
51    let generated = translation_model.generate(
52        &input_seq,
53        Some(15), // Maximum length
54        1,        // Start token ID (usually 1 for <START>)
55        Some(2),  // End token ID (usually 2 for <END>)
56    )?;
57    println!("Generated sequence shape: {:?}", generated.shape());
58
59    // Print generated sequences (token IDs)
60    println!("Generated sequences (token IDs):");
61    for b in 0..generated.shape()[0] {
62        print!("  Sequence {}: ", b);
63        for t in 0..generated.shape()[1] {
64            if generated[[b, t]] > 0.0 {
65                print!("{} ", generated[[b, t]]);
66            }
67        }
68        println!();
69    }
70
71    // 2. Create a custom Seq2Seq model with different configuration
72    println!("\nCreating Custom Seq2Seq Model...");
73    let custom_config = Seq2SeqConfig {
74        input_vocab_size: src_vocab_size,
75        output_vocab_size: tgt_vocab_size,
76        embedding_dim: 128,
77        hidden_dim: 256,
78        num_layers: 2,
79        encoder_cell_type: RNNCellType::GRU,
80        decoder_cell_type: RNNCellType::LSTM, // Mixing cell types
81        bidirectional_encoder: true,
82        use_attention: true,
83        dropout_rate: 0.2,
84        max_seq_len: 50,
85    };
86
87    let custom_model = Seq2Seq::<f32>::new(custom_config)?;
88    println!("Custom model created successfully.");
89
90    // 3. Creating a small and fast model for quick experimentation
91    println!("\nCreating Small Seq2Seq Model...");
92    let small_model = Seq2Seq::create_small_model(src_vocab_size, tgt_vocab_size)?;
93
94    let small_generated = small_model.generate(&input_seq, Some(10), 1, Some(2))?;
95    println!(
96        "Small model generated sequence shape: {:?}",
97        small_generated.shape()
98    );
99
100    // 4. Demonstrate switching between training and inference modes
101    println!("\nDemonstrating Training/Inference Mode Switching:");
102
103    // Set to training mode
104    translation_model.set_training(true);
105    println!("Is in training mode: {}", translation_model.is_training());
106
107    // Set to inference mode
108    translation_model.set_training(false);
109    println!(
110        "Is in training mode after switching: {}",
111        translation_model.is_training()
112    );
113
114    // 5. Example of model parameter count
115    println!("\nModel Parameter Counts:");
116    println!(
117        "Translation model parameters: {}",
118        translation_model.params().len()
119    );
120    println!("Custom model parameters: {}", custom_model.params().len());
121    println!("Small model parameters: {}", small_model.params().len());
122
123    println!("\nSeq2Seq Example Completed Successfully!");
124
125    Ok(())
126}
More examples
examples/text_classification_complete.rs (line 509)
458fn train_text_classifier() -> StdResult<()> {
459    println!("📝 Starting Text Classification Training Example");
460    println!("{}", "=".repeat(60));
461
462    let mut rng = SmallRng::seed_from_u64(42);
463
464    // Dataset parameters
465    let num_samples = 800;
466    let num_classes = 3;
467    let max_length = 20;
468    let embedding_dim = 64;
469    let hidden_dim = 128;
470
471    println!("📊 Dataset Configuration:");
472    println!("   - Samples: {}", num_samples);
473    println!(
474        "   - Classes: {} (Positive, Negative, Neutral)",
475        num_classes
476    );
477    println!("   - Max sequence length: {}", max_length);
478    println!("   - Embedding dimension: {}", embedding_dim);
479
480    // Create synthetic text dataset
481    println!("\n🔄 Creating synthetic text dataset...");
482    let dataset = TextDataset::create_synthetic_dataset(num_samples, num_classes, max_length);
483    let (train_dataset, val_dataset) = dataset.train_val_split(0.2);
484
485    println!("   - Vocabulary size: {}", dataset.vocab.vocab_size);
486    println!("   - Training samples: {}", train_dataset.len());
487    println!("   - Validation samples: {}", val_dataset.len());
488
489    // Show some example texts
490    println!("\n📄 Sample texts:");
491    for i in 0..3.min(train_dataset.texts.len()) {
492        println!(
493            "   [Class {}]: {}",
494            train_dataset.labels[i], train_dataset.texts[i]
495        );
496    }
497
498    // Build model
499    println!("\n🏗️  Building text classification model...");
500    let model = build_text_model(
501        dataset.vocab.vocab_size,
502        embedding_dim,
503        hidden_dim,
504        num_classes,
505        max_length,
506        &mut rng,
507    )?;
508
509    let total_params: usize = model.params().iter().map(|p| p.len()).sum();
510    println!("   - Model layers: {}", model.len());
511    println!("   - Total parameters: {}", total_params);
512
513    // Training configuration
514    let config = TrainingConfig {
515        batch_size: 16,
516        epochs: 30,
517        learning_rate: 0.001,
518        shuffle: true,
519        verbose: 1,
520        validation: Some(ValidationSettings {
521            enabled: true,
522            validation_split: 0.2,
523            batch_size: 32,
524            num_workers: 0,
525        }),
526        gradient_accumulation: None,
527        mixed_precision: None,
528        num_workers: 0,
529    };
530
531    println!("\n⚙️  Training Configuration:");
532    println!("   - Batch size: {}", config.batch_size);
533    println!("   - Learning rate: {}", config.learning_rate);
534    println!("   - Epochs: {}", config.epochs);
535
536    // Set up training
537    let loss_fn = CrossEntropyLoss::new(1e-7);
538    let optimizer = Adam::new(config.learning_rate as f32, 0.9, 0.999, 1e-8);
539
540    let mut trainer = Trainer::new(model, optimizer, loss_fn, config);
541
542    // Train the model
543    println!("\n🏋️  Starting training...");
544    println!("{}", "-".repeat(40));
545
546    let training_session = trainer.train(&train_dataset, Some(&val_dataset))?;
547
548    println!("\n✅ Training completed!");
549    println!("   - Epochs trained: {}", training_session.epochs_trained);
550
551    // Evaluate model
552    println!("\n📊 Final Evaluation:");
553    let val_metrics = trainer.validate(&val_dataset)?;
554
555    for (metric, value) in &val_metrics {
556        println!("   - {}: {:.4}", metric, value);
557    }
558
559    // Test on sample texts
560    println!("\n🔍 Sample Predictions:");
561    let sample_indices = vec![0, 1, 2, 3, 4];
562
563    // Manually collect batch since get_batch is not part of Dataset trait
564    let mut batch_tokens = Vec::new();
565    let mut batch_targets = Vec::new();
566
567    for &idx in &sample_indices {
568        let (tokens, targets) = val_dataset.get(idx)?;
569        batch_tokens.push(tokens);
570        batch_targets.push(targets);
571    }
572
573    // Concatenate into batch arrays
574    let sample_tokens = ndarray::concatenate(
575        ndarray::Axis(0),
576        &batch_tokens.iter().map(|a| a.view()).collect::<Vec<_>>(),
577    )?;
578    let sample_targets = ndarray::concatenate(
579        ndarray::Axis(0),
580        &batch_targets.iter().map(|a| a.view()).collect::<Vec<_>>(),
581    )?;
582
583    let model = trainer.get_model();
584    let predictions = model.forward(&sample_tokens)?;
585
586    let class_names = ["Positive", "Negative", "Neutral"];
587
588    for i in 0..sample_indices.len().min(val_dataset.texts.len()) {
589        let pred_row = predictions.slice(s![i, ..]);
590        let target_row = sample_targets.slice(s![i, ..]);
591
592        let pred_class = pred_row
593            .iter()
594            .enumerate()
595            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
596            .map(|(i, _)| i)
597            .unwrap_or(0);
598
599        let true_class = target_row
600            .iter()
601            .enumerate()
602            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
603            .map(|(i, _)| i)
604            .unwrap_or(0);
605
606        let confidence = pred_row[pred_class];
607
608        if sample_indices[i] < val_dataset.texts.len() {
609            println!("   Text: \"{}\"", val_dataset.texts[sample_indices[i]]);
610            println!(
611                "   Predicted: {} (confidence: {:.3})",
612                class_names[pred_class], confidence
613            );
614            println!("   Actual: {}", class_names[true_class]);
615            println!();
616        }
617    }
618
619    // Calculate detailed metrics
620    let detailed_metrics = calculate_text_metrics(&predictions, &sample_targets);
621    println!("📈 Detailed Metrics:");
622    for (metric, value) in &detailed_metrics {
623        println!("   - {}: {:.4}", metric, value);
624    }
625
626    Ok(())
627}
628
629/// Demonstrate advanced text model with attention
630fn demonstrate_attention_model() -> StdResult<()> {
631    println!("\n🎯 Attention-Based Model Demo:");
632    println!("{}", "-".repeat(40));
633
634    let mut rng = SmallRng::seed_from_u64(123);
635
636    // Create attention model
637    let model = build_attention_text_model(1000, 128, 256, 3, 20, &mut rng)?;
638
639    println!("   - Attention model created");
640    println!(
641        "   - Parameters: {}",
642        model.params().iter().map(|p| p.len()).sum::<usize>()
643    );
644    println!("   ✅ Attention mechanism simulation completed");
645
646    Ok(())
647}
examples/image_classification_complete.rs (line 265)
233fn train_image_classifier() -> Result<()> {
234    println!("🚀 Starting Image Classification Training Example");
235    println!("{}", "=".repeat(60));
236
237    // Set up reproducible random number generator
238    let mut rng = SmallRng::seed_from_u64(42);
239
240    // Dataset parameters
241    let num_samples = 1000;
242    let num_classes = 5;
243    let image_size = (32, 32);
244    let input_channels = 3;
245
246    println!("📊 Dataset Configuration:");
247    println!("   - Samples: {}", num_samples);
248    println!("   - Classes: {}", num_classes);
249    println!("   - Image Size: {}x{}", image_size.0, image_size.1);
250    println!("   - Channels: {}", input_channels);
251
252    // Create synthetic dataset
253    println!("\n🔄 Creating synthetic dataset...");
254    let dataset = SyntheticImageDataset::new(num_samples, num_classes, image_size);
255    let (train_dataset, val_dataset) = dataset.train_val_split(0.2);
256
257    println!("   - Training samples: {}", train_dataset.len());
258    println!("   - Validation samples: {}", val_dataset.len());
259
260    // Build model
261    println!("\n🏗️  Building CNN model...");
262    let model = build_cnn_model(input_channels, num_classes, &mut rng)?;
263
264    // Count parameters
265    let total_params: usize = model.params().iter().map(|p| p.len()).sum();
266    println!("   - Model layers: {}", model.len());
267    println!("   - Total parameters: {}", total_params);
268
269    // Create training configuration
270    let config = create_training_config();
271    println!("\n⚙️  Training Configuration:");
272    println!("   - Batch size: {}", config.batch_size);
273    println!("   - Learning rate: {}", config.learning_rate);
274    println!("   - Epochs: {}", config.epochs);
275    println!(
276        "   - Validation split: {:.1}%",
277        config.validation.as_ref().unwrap().validation_split * 100.0
278    );
279
280    // Set up training components
281    let loss_fn = CrossEntropyLoss::new(1e-7);
282    let optimizer = Adam::new(config.learning_rate as f32, 0.9, 0.999, 1e-8);
283
284    // Create trainer
285    let mut trainer = Trainer::new(model, optimizer, loss_fn, config);
286
287    // Add callbacks
288    trainer.add_callback(Box::new(|| {
289        // Custom callback for additional logging
290        println!("🔄 Epoch completed");
291        Ok(())
292    }));
293
294    // Train the model
295    println!("\n🏋️  Starting training...");
296    println!("{}", "-".repeat(40));
297
298    let training_session = trainer.train(&train_dataset, Some(&val_dataset))?;
299
300    println!("\n✅ Training completed!");
301    println!("   - Epochs trained: {}", training_session.epochs_trained);
302    println!(
303        "   - Final learning rate: {:.6}",
304        training_session.initial_learning_rate
305    );
306
307    // Evaluate on validation set
308    println!("\n📊 Final Evaluation:");
309    let val_metrics = trainer.validate(&val_dataset)?;
310
311    for (metric, value) in &val_metrics {
312        println!("   - {}: {:.4}", metric, value);
313    }
314
315    // Test predictions on a few samples
316    println!("\n🔍 Sample Predictions:");
317    let sample_indices = vec![0, 1, 2, 3, 4];
318
319    // Manually collect batch since get_batch is not part of Dataset trait
320    let mut batch_images = Vec::new();
321    let mut batch_targets = Vec::new();
322
323    for &idx in &sample_indices {
324        let (img, target) = val_dataset.get(idx)?;
325        batch_images.push(img);
326        batch_targets.push(target);
327    }
328
329    // Concatenate into batch arrays
330    let sample_images = ndarray::concatenate(
331        Axis(0),
332        &batch_images.iter().map(|a| a.view()).collect::<Vec<_>>(),
333    )?;
334    let sample_targets = ndarray::concatenate(
335        Axis(0),
336        &batch_targets.iter().map(|a| a.view()).collect::<Vec<_>>(),
337    )?;
338
339    let model = trainer.get_model();
340    let predictions = model.forward(&sample_images)?;
341
342    for i in 0..sample_indices.len() {
343        let pred_row = predictions.slice(s![i, ..]);
344        let target_row = sample_targets.slice(s![i, ..]);
345
346        let pred_class = pred_row
347            .iter()
348            .enumerate()
349            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
350            .map(|(i, _)| i)
351            .unwrap_or(0);
352
353        let target_class = target_row
354            .iter()
355            .enumerate()
356            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
357            .map(|(i, _)| i)
358            .unwrap_or(0);
359
360        let confidence = pred_row[pred_class];
361
362        println!(
363            "   Sample {}: Predicted={}, Actual={}, Confidence={:.3}",
364            i + 1,
365            pred_class,
366            target_class,
367            confidence
368        );
369    }
370
371    // Calculate overall accuracy
372    let overall_predictions = trainer.get_model().forward(&sample_images)?;
373    let accuracy = calculate_accuracy(&overall_predictions, &sample_targets);
374    println!("\n🎯 Sample Accuracy: {:.2}%", accuracy * 100.0);
375
376    // Model summary
377    println!("\n📋 Training Summary:");
378    let session = trainer.get_session();
379    if let Some(loss_history) = session.get_metric("loss") {
380        if !loss_history.is_empty() {
381            println!("   - Initial loss: {:.4}", loss_history[0]);
382            println!(
383                "   - Final loss: {:.4}",
384                loss_history[loss_history.len() - 1]
385            );
386        }
387    }
388
389    if let Some(val_loss_history) = session.get_metric("val_loss") {
390        if !val_loss_history.is_empty() {
391            println!(
392                "   - Final validation loss: {:.4}",
393                val_loss_history[val_loss_history.len() - 1]
394            );
395        }
396    }
397
398    println!("\n🎉 Image classification example completed successfully!");
399
400    Ok(())
401}
402
403/// Demonstrate data augmentation techniques
404fn demonstrate_augmentation() -> Result<()> {
405    println!("\n🔄 Data Augmentation Demo:");
406    println!("{}", "-".repeat(30));
407
408    // Create augmentation manager
409    let _aug_manager: AugmentationManager<f32> = AugmentationManager::new(Some(42));
410
411    // Note: Augmentation API is being updated
412    // For now, demonstrate basic concept
413    println!("   - Augmentation manager created with seed 42");
414    println!("   - Basic augmentations (rotation, flipping, etc.) would be applied here");
415
416    // Create sample image
417    let sample_image = Array4::<f32>::ones((1, 3, 32, 32));
418    println!("   - Sample image shape: {:?}", sample_image.shape());
419
420    // Note: Apply augmentation when API is stabilized
421    // let augmented = aug_manager.apply(&sample_image)?;
422
423    println!("   - Original shape: {:?}", sample_image.shape());
424    println!("   - Augmentation functionality available (API being finalized)");
425    println!("   ✅ Augmentation framework initialized successfully");
426
427    Ok(())
428}
429
430/// Demonstrate model saving and loading
431fn demonstrate_model_persistence() -> Result<()> {
432    println!("\n💾 Model Persistence Demo:");
433    println!("{}", "-".repeat(30));
434
435    let mut rng = SmallRng::seed_from_u64(123);
436
437    // Create a simple model
438    let model = build_cnn_model(3, 5, &mut rng)?;
439
440    // Save model (would save to file in real scenario)
441    println!(
442        "   - Model created with {} parameters",
443        model.params().iter().map(|p| p.len()).sum::<usize>()
444    );
445    println!("   ✅ Model persistence simulation completed");
446
447    Ok(())
448}

fn gradients(&self) -> Vec<Array<F, IxDyn>>

Get the gradients of the layer parameters

Returns the gradients of all trainable parameters. The values are only meaningful after a backward pass has been run.
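Because backward takes &self, a layer that wants gradients() to return what the last backward pass computed has to store them through interior mutability. The following sketch uses std::cell::RefCell on a hypothetical single-weight ScaleLayer (y_i = w * x_i); it illustrates the pattern and is not taken from scirs2_neural.

```rust
use std::cell::RefCell;

/// Hypothetical layer y_i = w * x_i with a single trainable weight.
pub struct ScaleLayer {
    weight: f64,
    // Interior mutability lets `backward(&self, ..)` record the gradient
    grad: RefCell<f64>,
}

impl ScaleLayer {
    pub fn new(weight: f64) -> Self {
        ScaleLayer { weight, grad: RefCell::new(0.0) }
    }

    pub fn backward(&self, input: &[f64], grad_output: &[f64]) -> Vec<f64> {
        // dL/dw = sum_i x_i * dL/dy_i, stored for later retrieval
        *self.grad.borrow_mut() = input
            .iter()
            .zip(grad_output)
            .map(|(x, g)| x * g)
            .sum::<f64>();
        // dL/dx_i = w * dL/dy_i, returned to the caller
        grad_output.iter().map(|g| self.weight * g).collect()
    }

    pub fn gradients(&self) -> Vec<f64> {
        vec![*self.grad.borrow()]
    }
}
```

Before any backward call the stored gradient is just its initial value (0.0 here), which is why the values are only meaningful after a backward pass.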


fn set_gradients(&mut self, _gradients: &[Array<F, IxDyn>]) -> Result<()>

Set the gradients of the layer parameters

Used by optimizers to set computed gradients. The default implementation does nothing, which is appropriate for parameterless layers.
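One way an optimizer might use gradients and set_gradients together is gradient clipping: read the gradients, rescale them so their global L2 norm does not exceed a threshold, and write them back. The sketch below assumes a hypothetical HasGradients trait whose default set_gradients is a no-op, mirroring the behavior described above, so a parameterless layer passes through unchanged.

```rust
/// Hypothetical trait exposing only the gradient accessors.
pub trait HasGradients {
    fn gradients(&self) -> Vec<Vec<f64>>;
    /// Default: a no-op, intended for parameterless layers.
    fn set_gradients(&mut self, _gradients: &[Vec<f64>]) -> Result<(), String> {
        Ok(())
    }
}

/// Rescale all gradients so their combined L2 norm is at most `max_norm`,
/// write them back, and return the norm measured before clipping.
pub fn clip_grad_norm<L: HasGradients>(layer: &mut L, max_norm: f64) -> Result<f64, String> {
    let mut grads = layer.gradients();
    let norm = grads.iter().flatten().map(|g| g * g).sum::<f64>().sqrt();
    if norm > max_norm {
        let scale = max_norm / norm;
        for g in grads.iter_mut().flatten() {
            *g *= scale;
        }
    }
    layer.set_gradients(&grads)?;
    Ok(norm)
}

pub struct TwoParams {
    grads: Vec<Vec<f64>>,
}

impl HasGradients for TwoParams {
    fn gradients(&self) -> Vec<Vec<f64>> {
        self.grads.clone()
    }
    fn set_gradients(&mut self, gradients: &[Vec<f64>]) -> Result<(), String> {
        if gradients.len() != self.grads.len() {
            return Err("wrong number of gradient arrays".into());
        }
        self.grads = gradients.to_vec();
        Ok(())
    }
}

/// Relies on the default no-op `set_gradients`.
pub struct NoParams;

impl HasGradients for NoParams {
    fn gradients(&self) -> Vec<Vec<f64>> {
        Vec::new()
    }
}
```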


fn set_params(&mut self, _params: &[Array<F, IxDyn>]) -> Result<()>

Set the parameters of the layer

Used to load pre-trained weights or to apply parameter updates. The default implementation does nothing, which is appropriate for parameterless layers.
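A common use of params and set_params together is copying weights from one layer into another, for example when restoring a checkpoint. In the hypothetical TinyDense sketch below, set_params validates the number and lengths of the arrays before touching any field, so a mismatched checkpoint is rejected and leaves the layer unchanged.

```rust
/// Hypothetical layer with a flattened weight matrix and a bias vector.
pub struct TinyDense {
    weights: Vec<f64>,
    bias: Vec<f64>,
}

impl TinyDense {
    pub fn new(input_dim: usize, output_dim: usize, init: f64) -> Self {
        TinyDense {
            weights: vec![init; input_dim * output_dim],
            bias: vec![init; output_dim],
        }
    }

    pub fn params(&self) -> Vec<Vec<f64>> {
        vec![self.weights.clone(), self.bias.clone()]
    }

    pub fn set_params(&mut self, params: &[Vec<f64>]) -> Result<(), String> {
        match params {
            // Check count and lengths first so a bad checkpoint changes nothing
            [w, b] if w.len() == self.weights.len() && b.len() == self.bias.len() => {
                self.weights.copy_from_slice(w);
                self.bias.copy_from_slice(b);
                Ok(())
            }
            _ => Err(format!(
                "expected 2 arrays of lengths {} and {}",
                self.weights.len(),
                self.bias.len()
            )),
        }
    }
}
```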


fn set_training(&mut self, _training: bool)

Set the layer to training mode (true) or evaluation mode (false)

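Training mode enables behavior that applies only during training, such as randomly dropping units in dropout, or using batch statistics and updating the running statistics in batch normalization. Evaluation mode disables this behavior for inference.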

§Examples
use scirs2_neural::layers::{Layer, Dropout};
use rand::rngs::SmallRng;
use rand::SeedableRng;

let mut rng = SmallRng::seed_from_u64(42);
let mut dropout = Dropout::<f32>::new(0.5, &mut rng).unwrap();
assert!(dropout.is_training()); // Default is training mode

dropout.set_training(false); // Switch to evaluation
assert!(!dropout.is_training());
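The effect of the training flag on dropout can be shown with inverted dropout, one common formulation: in training mode each element is either zeroed or scaled by 1 / (1 - rate) so the expected activation is unchanged, and in evaluation mode the input passes through untouched. Whether scirs2_neural's Dropout uses exactly this scaling is an assumption; the random keep/drop decisions are passed in as a mask so the hypothetical dropout_forward helper needs no RNG.

```rust
/// Inverted dropout (assumed formulation). `keep_mask[i]` is the caller-supplied
/// random decision to keep element i; `rate` is the drop probability.
pub fn dropout_forward(input: &[f64], keep_mask: &[bool], rate: f64, training: bool) -> Vec<f64> {
    assert!((0.0..1.0).contains(&rate), "rate must be in [0, 1)");
    assert_eq!(input.len(), keep_mask.len(), "mask length must match input length");
    if !training {
        // Evaluation mode: identity, no scaling needed
        return input.to_vec();
    }
    // Scale kept units so the expected value of each output equals its input
    let scale = 1.0 / (1.0 - rate);
    input
        .iter()
        .zip(keep_mask)
        .map(|(&x, &keep)| if keep { x * scale } else { 0.0 })
        .collect()
}
```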
Examples found in repository:
examples/seq2seq_example.rs (line 104; excerpt from main(), where translation_model is created with Seq2Seq::create_translation_model)
100    // 4. Demonstrate switching between training and inference modes
101    println!("\nDemonstrating Training/Inference Mode Switching:");
102
103    // Set to training mode
104    translation_model.set_training(true);
105    println!("Is in training mode: {}", translation_model.is_training());
106
107    // Set to inference mode
108    translation_model.set_training(false);
109    println!(
110        "Is in training mode after switching: {}",
111        translation_model.is_training()
112    );

fn is_training(&self) -> bool

Get the current training mode

Returns true if the layer is in training mode and false if it is in evaluation mode.
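For composite models such as the Seq2Seq model in the repository example, the mode only has an effect if it reaches every sublayer. A plausible pattern, sketched here with hypothetical ModeSwitch, LeafSketch, and ContainerSketch types, is for the container's set_training to forward the flag to each child so that is_training agrees at every level.

```rust
/// Hypothetical mode-switching interface mirroring set_training / is_training.
pub trait ModeSwitch {
    fn set_training(&mut self, training: bool);
    fn is_training(&self) -> bool;
}

pub struct LeafSketch {
    training: bool,
}

impl ModeSwitch for LeafSketch {
    fn set_training(&mut self, training: bool) {
        self.training = training;
    }
    fn is_training(&self) -> bool {
        self.training
    }
}

/// A container that owns sublayers and forwards mode changes to each of them.
pub struct ContainerSketch {
    layers: Vec<Box<dyn ModeSwitch>>,
    training: bool,
}

impl ModeSwitch for ContainerSketch {
    fn set_training(&mut self, training: bool) {
        self.training = training;
        for layer in self.layers.iter_mut() {
            layer.set_training(training);
        }
    }
    fn is_training(&self) -> bool {
        self.training
    }
}
```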

Examples found in repository:
examples/seq2seq_example.rs (line 105; excerpt from main(), where translation_model is created with Seq2Seq::create_translation_model)
100    // 4. Demonstrate switching between training and inference modes
101    println!("\nDemonstrating Training/Inference Mode Switching:");
102
103    // Set to training mode
104    translation_model.set_training(true);
105    println!("Is in training mode: {}", translation_model.is_training());
106
107    // Set to inference mode
108    translation_model.set_training(false);
109    println!(
110        "Is in training mode after switching: {}",
111        translation_model.is_training()
112    );
91    println!("\nCreating Small Seq2Seq Model...");
92    let small_model = Seq2Seq::create_small_model(src_vocab_size, tgt_vocab_size)?;
93
94    let small_generated = small_model.generate(&input_seq, Some(10), 1, Some(2))?;
95    println!(
96        "Small model generated sequence shape: {:?}",
97        small_generated.shape()
98    );
99
100    // 4. Demonstrate switching between training and inference modes
101    println!("\nDemonstrating Training/Inference Mode Switching:");
102
103    // Set to training mode
104    translation_model.set_training(true);
105    println!("Is in training mode: {}", translation_model.is_training());
106
107    // Set to inference mode
108    translation_model.set_training(false);
109    println!(
110        "Is in training mode after switching: {}",
111        translation_model.is_training()
112    );
113
114    // 5. Example of model parameter count
115    println!("\nModel Parameter Counts:");
116    println!(
117        "Translation model parameters: {}",
118        translation_model.params().len()
119    );
120    println!("Custom model parameters: {}", custom_model.params().len());
121    println!("Small model parameters: {}", small_model.params().len());
122
123    println!("\nSeq2Seq Example Completed Successfully!");
124
125    Ok(())
126}

fn layer_type(&self) -> &str

Get the type of the layer (e.g., “Dense”, “Conv2D”)

Returns a string identifier for the layer type, useful for debugging and model introspection.
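A typical introspection use of layer_type is summarizing a model by layer kind. The sketch below is self-contained and hypothetical: LayerInfo mirrors only layer_type, MyDense, MyDropout, and MyCustom are placeholder types rather than scirs2_neural types, and the "Unknown" value returned by the provided default is an assumption, not the crate's documented default.

```rust
use std::collections::BTreeMap;

/// Hypothetical, cut-down stand-in for the introspection part of `Layer`.
trait LayerInfo {
    /// Provided default for implementors that do not override it (assumed value).
    fn layer_type(&self) -> &str {
        "Unknown"
    }
}

// Hypothetical placeholder layers, not scirs2_neural types.
struct MyDense;
struct MyDropout;
struct MyCustom;

impl LayerInfo for MyDense {
    fn layer_type(&self) -> &str {
        "Dense"
    }
}

impl LayerInfo for MyDropout {
    fn layer_type(&self) -> &str {
        "Dropout"
    }
}

// Relies on the provided default.
impl LayerInfo for MyCustom {}

/// Counts how many layers of each type a model contains.
fn type_histogram(layers: &[Box<dyn LayerInfo>]) -> BTreeMap<String, usize> {
    let mut counts = BTreeMap::new();
    for layer in layers {
        *counts.entry(layer.layer_type().to_string()).or_insert(0) += 1;
    }
    counts
}
```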

fn parameter_count(&self) -> usize

Get the number of trainable parameters in this layer

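Returns the total count of all trainable parameters (weights, biases, etc.) across all of the layer's parameter arrays. This generally differs from params().len(), which counts the arrays themselves. Useful for model analysis and memory estimation.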
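As a worked example, a dense layer mapping 10 inputs to 5 outputs (the shape used in the trait-level example) would typically hold a 10 × 5 weight matrix and a 5-element bias: 10 · 5 + 5 = 55 parameters, or 55 · 8 = 440 bytes as f64. Whether scirs2_neural's Dense stores exactly these two arrays is an assumption here. The self-contained sketch below uses a hypothetical Params trait, with flattened Vec<f64> tensors standing in for Array<F, IxDyn>, to show one way such a default can be computed and turned into a rough memory estimate.

```rust
use std::mem::size_of;

/// Hypothetical stand-in for the parameter part of `Layer`;
/// each parameter tensor is stored flattened.
trait Params {
    fn params(&self) -> Vec<Vec<f64>>;

    /// Total number of scalar parameters, summed over all tensors.
    fn parameter_count(&self) -> usize {
        self.params().iter().map(|p| p.len()).sum()
    }
}

/// Hypothetical dense layer: inputs x outputs weights plus one bias per output.
struct ToyDense {
    inputs: usize,
    outputs: usize,
}

impl Params for ToyDense {
    fn params(&self) -> Vec<Vec<f64>> {
        vec![
            vec![0.0; self.inputs * self.outputs], // weight matrix, flattened
            vec![0.0; self.outputs],               // bias vector
        ]
    }
}

/// Rough footprint of the parameters alone
/// (excludes gradients, optimizer state, and activations).
fn param_bytes(layer: &dyn Params) -> usize {
    layer.parameter_count() * size_of::<f64>()
}
```

For ToyDense { inputs: 10, outputs: 5 }, params().len() is 2 while parameter_count() is 55.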

fn layer_description(&self) -> String

Get a detailed description of this layer

Returns a human-readable description including layer type and key properties. Can be overridden for more detailed layer-specific information.
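The pattern described here can be sketched as a provided default that builds a description from layer_type and parameter_count, plus an implementor that overrides it to report its own configuration. The Describe trait, both structs, and the exact string format are hypothetical; the format of scirs2_neural's actual default is not documented on this page.

```rust
/// Hypothetical, cut-down stand-in for the descriptive part of `Layer`.
trait Describe {
    fn layer_type(&self) -> &str;
    fn parameter_count(&self) -> usize;

    /// Provided default: type name plus parameter count (format is assumed).
    fn layer_description(&self) -> String {
        format!("type:{}, parameters:{}", self.layer_type(), self.parameter_count())
    }
}

/// Keeps the provided default description.
struct PlainLayer;

impl Describe for PlainLayer {
    fn layer_type(&self) -> &str {
        "Plain"
    }
    fn parameter_count(&self) -> usize {
        0
    }
}

/// Overrides the default to include its configured output size.
struct PoolLayer {
    output_size: (usize, usize),
}

impl Describe for PoolLayer {
    fn layer_type(&self) -> &str {
        "AdaptivePool"
    }
    fn parameter_count(&self) -> usize {
        0
    }
    fn layer_description(&self) -> String {
        format!(
            "type:{}, output_size:{}x{}, parameters:{}",
            self.layer_type(),
            self.output_size.0,
            self.output_size.1,
            self.parameter_count()
        )
    }
}
```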

Examples found in repository?
examples/new_features_showcase.rs (line 48)
33fn demonstrate_adaptive_pooling() -> Result<(), Box<dyn std::error::Error>> {
34    println!("🔧 Adaptive Pooling Layers Demonstration");
35    println!("========================================\n");
36
37    // Create input tensor: batch_size=2, channels=3, height=32, width=32
38    let input = Array4::<f64>::from_elem((2, 3, 32, 32), 1.5);
39    println!("Input shape: {:?}", input.shape());
40
41    // Adaptive Average Pooling to 7x7
42    println!("\n1. Adaptive Average Pooling (32x32 → 7x7):");
43    let adaptive_avg_pool = AdaptiveAvgPool2D::new((7, 7), Some("adaptive_avg_7x7"))?;
44    let avg_output = adaptive_avg_pool.forward(&input.clone().into_dyn())?;
45    println!("   Output shape: {:?}", avg_output.shape());
46    println!(
47        "   Layer description: {}",
48        adaptive_avg_pool.layer_description()
49    );
50
51    // Adaptive Max Pooling to 4x4
52    println!("\n2. Adaptive Max Pooling (32x32 → 4x4):");
53    let adaptive_max_pool = AdaptiveMaxPool2D::new((4, 4), Some("adaptive_max_4x4"))?;
54    let max_output = adaptive_max_pool.forward(&input.into_dyn())?;
55    println!("   Output shape: {:?}", max_output.shape());
56    println!(
57        "   Layer description: {}",
58        adaptive_max_pool.layer_description()
59    );
60
61    // Non-square adaptive pooling
62    println!("\n3. Non-square Adaptive Pooling (16x20 → 3x5):");
63    let non_square_pool = AdaptiveAvgPool2D::new((3, 5), Some("non_square"))?;
64    let non_square_output =
65        non_square_pool.forward(&Array4::<f64>::from_elem((1, 2, 16, 20), 2.0).into_dyn())?;
66    println!("   Input shape: [1, 2, 16, 20]");
67    println!("   Output shape: {:?}", non_square_output.shape());
68
69    println!("✅ Adaptive pooling demonstration completed!\n");
70    Ok(())
71}
72
73fn demonstrate_activity_regularization() -> Result<(), Box<dyn std::error::Error>> {
74    println!("🎯 Activity Regularization Demonstration");
75    println!("=======================================\n");
76
77    // Create some activations to regularize
78    let activations =
79        Array::from_shape_vec((2, 4), vec![1.5, -2.0, 0.5, 3.0, -1.0, 0.0, 2.5, -0.5])?.into_dyn();
80
81    println!("Input activations:");
82    println!("{:?}\n", activations);
83
84    // 1. L1 Activity Regularization
85    println!("1. L1 Activity Regularization (factor=0.1):");
86    let l1_reg = L1ActivityRegularization::new(0.1, Some("l1_regularizer"))?;
87    let l1_output = l1_reg.forward(&activations)?;
88    let l1_loss = l1_reg.get_activity_loss()?;
89    println!("   Output (unchanged): {:?}", l1_output.shape());
90    println!("   L1 activity loss: {:.4}", l1_loss);
91    println!("   Layer description: {}", l1_reg.layer_description());
92
93    // 2. L2 Activity Regularization
94    println!("\n2. L2 Activity Regularization (factor=0.05):");
95    let l2_reg = L2ActivityRegularization::new(0.05, Some("l2_regularizer"))?;
96    let l2_output = l2_reg.forward(&activations)?;
97    let l2_loss = l2_reg.get_activity_loss()?;
98    println!("   Output (unchanged): {:?}", l2_output.shape());
99    println!("   L2 activity loss: {:.4}", l2_loss);
100    println!("   Layer description: {}", l2_reg.layer_description());
101
102    // 3. Combined L1 + L2 Activity Regularization
103    println!("\n3. Combined L1+L2 Activity Regularization:");
104    let combined_reg = ActivityRegularization::new(Some(0.1), Some(0.05), Some("combined_reg"))?;
105    let combined_output = combined_reg.forward(&activations)?;
106    let combined_loss = combined_reg.get_activity_loss()?;
107    println!("   Output (unchanged): {:?}", combined_output.shape());
108    println!("   Combined activity loss: {:.4}", combined_loss);
109    println!("   Layer description: {}", combined_reg.layer_description());
110
111    // Demonstrate backward pass
112    println!("\n4. Backward Pass with Gradient Modification:");
113    let grad_output = Array::ones(activations.raw_dim());
114    let grad_input = combined_reg.backward(&activations, &grad_output)?;
115    println!("   Gradient input shape: {:?}", grad_input.shape());
116    println!(
117        "   Sample gradient values: [{:.3}, {:.3}, {:.3}, {:.3}]",
118        grad_input[[0, 0]],
119        grad_input[[0, 1]],
120        grad_input[[0, 2]],
121        grad_input[[0, 3]]
122    );
123
124    println!("✅ Activity regularization demonstration completed!\n");
125    Ok(())
126}

Implementors§

impl<F: Float + Debug + ScalarOperand + 'static> Layer<F> for Dense<F>

impl<F: Float + Debug + ScalarOperand + 'static> Layer<F> for Bidirectional<F>

impl<F: Float + Debug + ScalarOperand + 'static> Layer<F> for GRU<F>

impl<F: Float + Debug + ScalarOperand + 'static> Layer<F> for LSTM<F>

impl<F: Float + Debug + ScalarOperand + 'static> Layer<F> for LayerNorm2D<F>

impl<F: Float + Debug + ScalarOperand + 'static> Layer<F> for LayerNorm<F>

impl<F: Float + Debug + ScalarOperand + 'static> Layer<F> for MultiHeadAttention<F>

impl<F: Float + Debug + ScalarOperand + 'static> Layer<F> for SelfAttention<F>

impl<F: Float + Debug + ScalarOperand + Clone + Send + Sync + 'static> Layer<F> for Conv2D<F>

impl<F: Float + Debug + ScalarOperand + Clone + Send + Sync> Layer<F> for VisionTransformer<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for CLIP<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for CLIPTextEncoder<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for CLIPVisionEncoder<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for FeedForward<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for Transformer<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for TransformerDecoder<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for TransformerDecoderLayer<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for TransformerEncoder<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for TransformerEncoderLayer<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for RNN<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for ActivityRegularization<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for AdaptiveAvgPool1D<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for AdaptiveAvgPool2D<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for AdaptiveAvgPool3D<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for AdaptiveMaxPool1D<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for AdaptiveMaxPool2D<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for AdaptiveMaxPool3D<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for BatchNorm<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for Dropout<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for GlobalAvgPool2D<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for L1ActivityRegularization<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for L2ActivityRegularization<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for MaxPool2D<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for ThreadSafeBidirectional<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for ThreadSafeRNN<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for BertModel<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for ConvNeXt<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for ConvNeXtBlock<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for ConvNeXtDownsample<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for ConvNeXtStage<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for EfficientNet<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for BilinearFusion<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for CrossModalAttention<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for FeatureAlignment<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for FeatureFusion<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for FiLMModule<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for GPTModel<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for MobileNet<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for ResNet<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for Attention<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for Seq2Seq<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for Seq2SeqDecoder<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for Seq2SeqEncoder<F>

impl<F: Float + Debug + ScalarOperand + Send + Sync> Layer<F> for PatchEmbedding<F>

impl<F: Float + Debug + ScalarOperand> Layer<F> for GELU

impl<F: Float + Debug + ScalarOperand> Layer<F> for Embedding<F>

impl<F: Float + Debug + ScalarOperand> Layer<F> for PositionalEmbedding<F>

impl<F: Float + Debug + ScalarOperand> Layer<F> for Sequential<F>
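All of the types listed above implement the same trait, so code that holds layers as trait objects can call their methods without knowing the concrete type. The following self-contained sketch illustrates what an implementor supplies, using a hypothetical MiniLayer trait that keeps only forward, backward, and update and replaces Array<F, IxDyn> and Result with plain Vec<f64>. Real implementors must also provide as_any and as_any_mut and satisfy the trait bounds shown. The Scale layer (y = w · x with one learnable scalar) and the Chain container are illustrative names, not crate types.

```rust
use std::cell::Cell;

/// Hypothetical, heavily simplified mirror of the required `Layer` methods.
trait MiniLayer {
    fn forward(&self, input: &[f64]) -> Vec<f64>;
    fn backward(&self, input: &[f64], grad_output: &[f64]) -> Vec<f64>;
    fn update(&mut self, learning_rate: f64);
}

/// y_i = w * x_i with a single learnable weight `w`.
struct Scale {
    w: f64,
    grad_w: Cell<f64>, // gradient cached by `backward` (which takes &self) for `update`
}

impl MiniLayer for Scale {
    fn forward(&self, input: &[f64]) -> Vec<f64> {
        input.iter().map(|x| self.w * x).collect()
    }

    fn backward(&self, input: &[f64], grad_output: &[f64]) -> Vec<f64> {
        // dL/dw = sum_i g_i * x_i, stored for the next `update`.
        let gw: f64 = input.iter().zip(grad_output).map(|(x, g)| x * g).sum();
        self.grad_w.set(gw);
        // dL/dx_i = g_i * w, passed to the previous layer.
        grad_output.iter().map(|g| g * self.w).collect()
    }

    fn update(&mut self, learning_rate: f64) {
        // Plain gradient-descent step.
        self.w -= learning_rate * self.grad_w.get();
    }
}

/// Hypothetical container that runs layers in order through the trait.
struct Chain {
    layers: Vec<Box<dyn MiniLayer>>,
}

impl Chain {
    fn forward(&self, input: &[f64]) -> Vec<f64> {
        self.layers.iter().fold(input.to_vec(), |x, layer| layer.forward(&x))
    }
}
```

With w = 2, input [1, 3], and an upstream gradient of [1, 1], backward caches dL/dw = 4, so update(0.5) moves w to 2 - 0.5 · 4 = 0.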