// trustformers_core/traits.rs
1//! Core traits defining the fundamental abstractions of TrustformeRS.
2//!
3//! This module contains the essential traits that form the foundation of the TrustformeRS
4//! transformer library. These traits define the interfaces for models, layers, configuration,
5//! tokenization, optimization, and parameter initialization.
6//!
7//! # Overview
8//!
9//! The traits in this module establish a consistent API across all transformer implementations:
10//!
11//! - [`Model`]: The main trait for transformer models with forward pass and loading capabilities
12//! - [`Layer`]: Building blocks for neural network architectures
13//! - [`Config`]: Configuration management for models and components
14//! - [`WeightReader`]: Interface for loading pretrained model weights
15//! - [`Tokenizer`]: Text tokenization and encoding/decoding
16//! - [`Optimizer`]: Parameter optimization algorithms
17//! - [`ParameterInit`]: Weight initialization strategies
18//!
19//! # Examples
20//!
//! ```no_run
//! use trustformers_core::traits::{Model, Config};
//! use trustformers_core::tensor::Tensor;
//! use trustformers_core::errors::Result;
//! use std::io::Read;
//! use serde::{Deserialize, Serialize};
//!
//! #[derive(Deserialize, Serialize)]
//! struct MyConfig {}
//!
//! impl Config for MyConfig {
//!     fn architecture(&self) -> &'static str {
//!         "my_model"
//!     }
//! }
//!
//! // Example model implementation
//! struct MyModel {
//!     config: MyConfig,
//!     // ... model layers
//! }
//!
//! impl Model for MyModel {
//!     type Config = MyConfig;
//!     type Input = Tensor;
//!     type Output = Tensor;
//!
//!     fn forward(&self, input: Self::Input) -> Result<Self::Output> {
//!         // Model forward pass implementation
//!         Ok(input)
//!     }
//!
//!     fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
//!         // Load pretrained weights
//!         Ok(())
//!     }
//!
//!     fn get_config(&self) -> &Self::Config {
//!         &self.config
//!     }
//!
//!     fn num_parameters(&self) -> usize {
//!         0
//!     }
//! }
//! ```
51
52use crate::errors::Result;
53use crate::tensor::Tensor;
54use serde::{Deserialize, Serialize};
55use std::io::Read;
56
/// The main trait for transformer models.
///
/// Defines the interface every transformer model implements: a forward pass,
/// pretrained-weight loading, configuration access, and parameter counting.
///
/// # Associated Types
///
/// - `Config`: the model's configuration type; must implement [`Config`]
/// - `Input`: the input type consumed by [`Model::forward`]
/// - `Output`: the output type produced by [`Model::forward`]
///
/// # Thread Safety
///
/// Models must be `Send + Sync` to support multi-threaded inference and training.
///
/// # Example
///
/// ```no_run
/// use trustformers_core::traits::{Model, Config};
/// use trustformers_core::tensor::Tensor;
/// use trustformers_core::errors::Result;
/// use std::io::Read;
/// use serde::{Deserialize, Serialize};
///
/// #[derive(Deserialize, Serialize)]
/// struct BertConfig {
///     hidden_size: usize,
///     num_attention_heads: usize,
///     // ... other config fields
/// }
///
/// impl Config for BertConfig {
///     fn architecture(&self) -> &'static str {
///         "bert"
///     }
/// }
///
/// struct BertModel {
///     config: BertConfig,
///     // ... model layers
/// }
///
/// impl Model for BertModel {
///     type Config = BertConfig;
///     type Input = Tensor;
///     type Output = Tensor;
///
///     fn forward(&self, input: Self::Input) -> Result<Self::Output> {
///         // BERT forward pass implementation
///         Ok(input)
///     }
///
///     fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
///         // Load BERT weights from reader
///         Ok(())
///     }
///
///     fn get_config(&self) -> &Self::Config {
///         &self.config
///     }
///
///     fn num_parameters(&self) -> usize {
///         0 // sum of parameter counts across all layers
///     }
/// }
/// ```
pub trait Model: Send + Sync {
    /// The configuration type for this model.
    type Config: Config;
    /// The input type consumed by [`Model::forward`].
    type Input;
    /// The output type produced by [`Model::forward`].
    type Output;

    /// Performs a forward pass through the model.
    ///
    /// # Arguments
    ///
    /// * `input` - The input data for the model
    ///
    /// # Errors
    ///
    /// May return errors for invalid input dimensions, numerical computation
    /// errors, or out-of-memory conditions.
    fn forward(&self, input: Self::Input) -> Result<Self::Output>;

    /// Loads pretrained weights into the model.
    ///
    /// Reads model weights from `reader` (typically a file or network stream)
    /// and updates the model's parameters accordingly.
    ///
    /// # Arguments
    ///
    /// * `reader` - A reader providing access to the pretrained weight data
    ///
    /// # Errors
    ///
    /// May return errors for IO failures, incompatible weight formats, or
    /// mismatched tensor shapes.
    fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()>;

    /// Returns a reference to the model's configuration.
    // NOTE(review): Rust convention would name this `config()`; kept as
    // `get_config` because callers depend on the existing name.
    fn get_config(&self) -> &Self::Config;

    /// Returns the total number of trainable parameters in the model.
    ///
    /// Useful for compression metrics, model-size analysis, and memory-usage
    /// estimates (e.g. `num_parameters() * 4` bytes for f32 weights).
    fn num_parameters(&self) -> usize;
}
196
/// A building block for neural network architectures.
///
/// A `Layer` is a single computational unit — e.g. a linear transformation,
/// attention mechanism, or normalization — that can be composed with other
/// layers into a complete model.
///
/// # Associated Types
///
/// - `Input`: the input type accepted by this layer
/// - `Output`: the output type produced by this layer
///
/// # Thread Safety
///
/// Layers must be `Send + Sync` to support parallel computation.
///
/// # Example
///
/// ```no_run
/// use trustformers_core::traits::Layer;
/// use trustformers_core::tensor::Tensor;
/// use trustformers_core::errors::Result;
///
/// struct LinearLayer {
///     weight: Tensor,
///     bias: Option<Tensor>,
/// }
///
/// impl Layer for LinearLayer {
///     type Input = Tensor;
///     type Output = Tensor;
///
///     fn forward(&self, input: Self::Input) -> Result<Self::Output> {
///         // Compute linear transformation: y = xW^T + b
///         let output = input.matmul(&self.weight.transpose()?)?;
///         if let Some(bias) = &self.bias {
///             output.add(bias)
///         } else {
///             Ok(output)
///         }
///     }
/// }
/// ```
pub trait Layer: Send + Sync {
    /// The input type accepted by this layer.
    type Input;
    /// The output type produced by this layer.
    type Output;

    /// Performs the forward computation of this layer.
    ///
    /// # Arguments
    ///
    /// * `input` - The input data to process
    ///
    /// # Errors
    ///
    /// May return errors for invalid input dimensions, numerical errors, or
    /// resource allocation failures.
    fn forward(&self, input: Self::Input) -> Result<Self::Output>;
}
264
/// Configuration trait for models and components.
///
/// Standardized interface for configuration objects that can be serialized,
/// deserialized, and validated. All model configurations must implement this
/// trait for compatibility with the TrustformeRS ecosystem.
///
/// # Requirements
///
/// Implementing types must be serializable and deserializable with serde
/// (the `for<'de> Deserialize<'de> + Serialize` supertrait bound).
///
/// # Example
///
/// ```no_run
/// use trustformers_core::traits::Config;
/// use trustformers_core::errors::Result;
/// use serde::{Deserialize, Serialize};
///
/// #[derive(Debug, Clone, Deserialize, Serialize)]
/// struct GPT2Config {
///     vocab_size: usize,
///     hidden_size: usize,
///     num_layers: usize,
///     num_heads: usize,
/// }
///
/// impl Config for GPT2Config {
///     fn validate(&self) -> Result<()> {
///         // NOTE(review): in real code return the crate's error type here
///         // instead of panicking; the exact constructor depends on
///         // `trustformers_core::errors`.
///         assert!(
///             self.hidden_size % self.num_heads == 0,
///             "hidden_size must be divisible by num_heads"
///         );
///         Ok(())
///     }
///
///     fn architecture(&self) -> &'static str {
///         "gpt2"
///     }
/// }
/// ```
pub trait Config: for<'de> Deserialize<'de> + Serialize {
    /// Validates the configuration for correctness.
    ///
    /// Should check that all parameters are valid and mutually compatible,
    /// e.g. dimensions divide evenly, values are within acceptable ranges,
    /// required fields are set. The default implementation accepts every
    /// configuration as valid.
    ///
    /// # Errors
    ///
    /// Returns an error describing the validation failure.
    fn validate(&self) -> Result<()> {
        Ok(())
    }

    /// Returns the architecture name for this configuration.
    ///
    /// A static string identifying the model architecture ("bert", "gpt2",
    /// "t5", ...), used for model registration and automatic model selection.
    fn architecture(&self) -> &'static str;
}
341
/// Interface for reading model weights from various sources.
///
/// `WeightReader` abstracts over weight storage formats, letting models load
/// pretrained parameters from files, network sources, or other backends.
///
/// # Supported Formats
///
/// Implementations may support formats such as SafeTensors (`.safetensors`),
/// PyTorch checkpoints (`.pt`, `.bin`), NumPy arrays (`.npz`), or custom
/// formats.
///
/// # Example
///
/// ```no_run
/// use trustformers_core::traits::WeightReader;
/// use trustformers_core::tensor::Tensor;
/// use trustformers_core::errors::Result;
///
/// struct SafeTensorsReader {
///     // ... implementation details
/// }
///
/// impl WeightReader for SafeTensorsReader {
///     fn read_tensor(&mut self, name: &str) -> Result<Tensor> {
///         // Read tensor from SafeTensors file
///         Tensor::zeros(&[768, 768])
///     }
///
///     fn list_tensors(&self) -> Vec<String> {
///         vec![
///             "bert.embeddings.word_embeddings.weight".to_string(),
///             "bert.encoder.layer.0.attention.self.query.weight".to_string(),
///             // ... more tensor names
///         ]
///     }
/// }
/// ```
pub trait WeightReader {
    /// Reads a tensor by name from the weight source.
    ///
    /// # Arguments
    ///
    /// * `name` - The name/key of the tensor to read
    ///   (e.g. "encoder.layer.0.weight")
    ///
    /// # Errors
    ///
    /// May return errors when the tensor is not found, on IO failures, for
    /// corrupted tensor data, or for an unsupported tensor format.
    fn read_tensor(&mut self, name: &str) -> Result<Tensor>;

    /// Lists all available tensor names in the weight source.
    ///
    /// Useful for debugging and for discovering the structure of saved model
    /// weights.
    fn list_tensors(&self) -> Vec<String>;
}
417
/// Text tokenization interface for transformer models.
///
/// Converts between text and token IDs for preparing model input.
/// Implementations may use WordPiece, BPE, SentencePiece, or other algorithms.
///
/// # Thread Safety
///
/// Tokenizers must be `Send + Sync` to support concurrent tokenization.
///
/// # Example
///
/// ```no_run
/// use trustformers_core::traits::{Tokenizer, TokenizedInput};
/// use trustformers_core::errors::Result;
///
/// struct BertTokenizer {
///     vocab: std::collections::HashMap<String, u32>,
///     // ... other fields
/// }
///
/// impl Tokenizer for BertTokenizer {
///     fn encode(&self, text: &str) -> Result<TokenizedInput> {
///         // Tokenize text into subwords
///         let tokens = vec![101, 2023, 2003, 1037, 3231, 102]; // [CLS] this is a test [SEP]
///         Ok(TokenizedInput {
///             input_ids: tokens,
///             attention_mask: vec![1; 6],
///             token_type_ids: Some(vec![0; 6]),
///             ..Default::default()
///         })
///     }
///
///     fn encode_pair(&self, text: &str, text2: &str) -> Result<TokenizedInput> {
///         // Encode two texts for tasks like question answering
///         let first = self.encode(text)?;
///         let second = self.encode(text2)?;
///         // Record segment lengths before moving the token buffers
///         let first_len = first.input_ids.len();
///         let second_len = second.input_ids.len();
///         // Combine the two token sequences
///         let mut input_ids = first.input_ids;
///         input_ids.extend_from_slice(&second.input_ids);
///         // Segment IDs: 0 for the first sequence, 1 for the second
///         let mut token_type_ids = vec![0u32; first_len];
///         token_type_ids.extend(vec![1u32; second_len]);
///         Ok(TokenizedInput {
///             attention_mask: vec![1; input_ids.len()],
///             input_ids,
///             token_type_ids: Some(token_type_ids),
///             ..Default::default()
///         })
///     }
///
///     fn decode(&self, ids: &[u32]) -> Result<String> {
///         // Convert token IDs back to text
///         Ok("this is a test".to_string())
///     }
///
///     fn vocab_size(&self) -> usize {
///         30522 // BERT base vocabulary size
///     }
///
///     fn get_vocab(&self) -> std::collections::HashMap<String, u32> {
///         self.vocab.clone()
///     }
///
///     fn token_to_id(&self, token: &str) -> Option<u32> {
///         self.vocab.get(token).copied()
///     }
///
///     fn id_to_token(&self, id: u32) -> Option<String> {
///         self.vocab.iter().find(|(_, &v)| v == id).map(|(k, _)| k.clone())
///     }
/// }
/// ```
pub trait Tokenizer: Send + Sync {
    /// Encodes a single text string into tokens.
    ///
    /// Returns a [`TokenizedInput`] with `input_ids`, an `attention_mask`
    /// (real vs padding tokens), and optional `token_type_ids`.
    ///
    /// # Errors
    ///
    /// May return errors for invalid UTF-8 sequences, text exceeding the
    /// maximum length, or unhandleable unknown tokens.
    fn encode(&self, text: &str) -> Result<TokenizedInput>;

    /// Encodes a pair of texts for sequence-pair tasks.
    ///
    /// Used for tasks requiring two input sequences (question answering,
    /// textual entailment, sequence classification). Both sequences are
    /// encoded and separated by appropriate special tokens (e.g. `[SEP]`).
    ///
    /// # Errors
    ///
    /// Same failure modes as [`Tokenizer::encode`].
    fn encode_pair(&self, text: &str, text2: &str) -> Result<TokenizedInput>;

    /// Decodes token IDs back into text.
    ///
    /// Special tokens may be included or excluded depending on the
    /// implementation.
    ///
    /// # Errors
    ///
    /// May return errors for invalid token IDs or decoding failures.
    fn decode(&self, ids: &[u32]) -> Result<String>;

    /// Returns the total number of tokens in the vocabulary.
    fn vocab_size(&self) -> usize;

    /// Returns a copy of the vocabulary as a token → ID mapping.
    fn get_vocab(&self) -> std::collections::HashMap<String, u32>;

    /// Converts a token string to its ID, or `None` if it is not in the
    /// vocabulary.
    fn token_to_id(&self, token: &str) -> Option<u32>;

    /// Converts a token ID to its token string, or `None` if the ID is not in
    /// the vocabulary.
    fn id_to_token(&self, id: u32) -> Option<String>;
}
571
/// Represents tokenized input ready for model consumption.
///
/// Bundles everything a transformer model needs after tokenization: token
/// IDs, an attention mask, and several optional per-token annotations.
///
/// # Example
///
/// ```no_run
/// use trustformers_core::traits::TokenizedInput;
///
/// let input = TokenizedInput {
///     input_ids: vec![101, 2023, 2003, 1037, 3231, 102], // [CLS] this is a test [SEP]
///     attention_mask: vec![1, 1, 1, 1, 1, 1],            // all tokens are real (not padding)
///     token_type_ids: Some(vec![0, 0, 0, 0, 0, 0]),      // all tokens from first segment
///     ..Default::default()
/// };
/// ```
#[derive(Debug, Clone, Default)]
pub struct TokenizedInput {
    /// Token IDs representing the encoded text; each corresponds to an entry
    /// in the tokenizer's vocabulary.
    pub input_ids: Vec<u32>,

    /// Binary attention mask: real tokens (1) vs padding tokens (0).
    /// Prevents the model from attending to padding.
    pub attention_mask: Vec<u8>,

    /// Optional segment IDs for models like BERT: typically 0 for the first
    /// sequence and 1 for the second.
    pub token_type_ids: Option<Vec<u32>>,

    /// Optional mask marking special tokens (1) such as [CLS]/[SEP]/[PAD]
    /// vs regular tokens (0).
    pub special_tokens_mask: Option<Vec<u8>>,

    /// Optional (start, end) character offsets of each token in the original
    /// text.
    pub offset_mapping: Option<Vec<(usize, usize)>>,

    /// Optional tokens truncated from the input when it exceeded the maximum
    /// length.
    pub overflowing_tokens: Option<Vec<u32>>,
}

impl TokenizedInput {
    /// Creates a `TokenizedInput` with the two required fields; all optional
    /// fields are `None`.
    pub fn new(input_ids: Vec<u32>, attention_mask: Vec<u8>) -> Self {
        Self {
            input_ids,
            attention_mask,
            ..Self::default()
        }
    }

    /// Creates a `TokenizedInput` that also carries segment (token type) IDs.
    pub fn with_token_type_ids(
        input_ids: Vec<u32>,
        attention_mask: Vec<u8>,
        token_type_ids: Option<Vec<u32>>,
    ) -> Self {
        Self {
            input_ids,
            attention_mask,
            token_type_ids,
            ..Self::default()
        }
    }
}
651
/// Parameter optimization algorithms for training neural networks.
///
/// Interface for gradient-based optimizers (SGD, Adam, AdamW, ...) that
/// update model parameters from computed gradients to minimize a loss.
///
/// # Thread Safety
///
/// Optimizers must be `Send + Sync` to support distributed training.
///
/// # Example
///
/// ```no_run
/// use trustformers_core::traits::Optimizer;
/// use trustformers_core::tensor::Tensor;
/// use trustformers_core::errors::Result;
///
/// struct SGD {
///     learning_rate: f32,
///     momentum: f32,
///     velocity: std::collections::HashMap<String, Tensor>,
/// }
///
/// impl Optimizer for SGD {
///     fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
///         // SGD with momentum: v = momentum * v - lr * grad
///         // parameter += v
///         Ok(())
///     }
///
///     fn zero_grad(&mut self) {
///         // Clear accumulated gradients
///     }
///
///     fn step(&mut self) {
///         // Apply updates to all parameters
///     }
///
///     fn get_lr(&self) -> f32 {
///         self.learning_rate
///     }
///
///     fn set_lr(&mut self, lr: f32) {
///         self.learning_rate = lr;
///     }
/// }
/// ```
pub trait Optimizer: Send + Sync {
    /// Updates a parameter in place based on its gradient.
    ///
    /// # Arguments
    ///
    /// * `parameter` - The parameter tensor to update
    /// * `grad` - The gradient tensor for this parameter
    ///
    /// # Errors
    ///
    /// May return errors for mismatched tensor shapes, numerical errors
    /// (NaN/Inf), or allocation failures.
    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()>;

    /// Clears all accumulated gradients.
    ///
    /// Call before each backward pass so gradients don't accumulate across
    /// batches (unless gradient accumulation is intentional).
    fn zero_grad(&mut self);

    /// Performs a single optimization step, applying all pending parameter
    /// updates. Call after gradients have been computed for all parameters.
    fn step(&mut self);

    /// Returns the current learning rate.
    fn get_lr(&self) -> f32;

    /// Sets a new learning rate.
    ///
    /// Useful for implementing learning-rate schedules.
    fn set_lr(&mut self, lr: f32);

    /// Accumulates a gradient for gradient accumulation across batches.
    ///
    /// # Default Implementation
    ///
    /// The default simply forwards to [`Optimizer::update`], i.e. the
    /// gradient is applied immediately rather than stored. Optimizers that
    /// need true accumulation (store now, apply averaged later) must
    /// override this together with
    /// [`Optimizer::apply_accumulated_grads`].
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying update produces.
    fn accumulate_grad(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        // Default: apply immediately via update(); no gradient is stored.
        self.update(parameter, grad)
    }

    /// Applies accumulated gradients after gradient accumulation.
    ///
    /// Call after accumulating gradients from multiple batches to apply the
    /// averaged update.
    ///
    /// # Arguments
    ///
    /// * `accumulation_steps` - The number of accumulation steps performed
    ///
    /// # Default Implementation
    ///
    /// A no-op; override for optimizers that implement gradient accumulation.
    fn apply_accumulated_grads(&mut self, accumulation_steps: usize) -> Result<()> {
        // Default: no-op; the parameter only exists for overriders.
        let _ = accumulation_steps;
        Ok(())
    }
}
797
/// Weight initialization strategies for neural network parameters.
///
/// Provides initialization methods that help ensure proper gradient flow and
/// training stability; different strategies suit different activation
/// functions and architectures.
///
/// # Example
///
/// ```no_run
/// use trustformers_core::traits::ParameterInit;
/// use trustformers_core::tensor::Tensor;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let mut weight = Tensor::zeros(&[768, 768])?;
///
/// // Xavier/Glorot uniform for tanh activations
/// weight.xavier_uniform();
///
/// // Kaiming/He initialization for ReLU activations
/// weight.kaiming_normal("fan_in", "relu");
/// # Ok(())
/// # }
/// ```
pub trait ParameterInit {
    /// Fills the tensor with samples from a normal distribution.
    ///
    /// # Arguments
    ///
    /// * `mean` - Mean of the normal distribution
    /// * `std` - Standard deviation of the normal distribution
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use trustformers_core::tensor::Tensor;
    /// # use trustformers_core::traits::ParameterInit;
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut tensor = Tensor::zeros(&[100, 100])?;
    /// tensor.normal(0.0, 0.02); // Common for transformer embeddings
    /// # Ok(())
    /// # }
    /// ```
    fn normal(&mut self, mean: f32, std: f32);

    /// Fills the tensor with samples from a uniform distribution.
    ///
    /// # Arguments
    ///
    /// * `min` - Minimum value (inclusive)
    /// * `max` - Maximum value (exclusive)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use trustformers_core::tensor::Tensor;
    /// # use trustformers_core::traits::ParameterInit;
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut tensor = Tensor::zeros(&[100, 100])?;
    /// tensor.uniform(-0.1, 0.1);
    /// # Ok(())
    /// # }
    /// ```
    fn uniform(&mut self, min: f32, max: f32);

    /// Xavier/Glorot uniform initialization.
    ///
    /// Maintains variance across layers; optimal for tanh and sigmoid
    /// activations. Samples from [-x, x] where
    /// x = sqrt(6 / (fan_in + fan_out)).
    ///
    /// # References
    ///
    /// Glorot & Bengio (2010): "Understanding the difficulty of training
    /// deep feedforward neural networks"
    fn xavier_uniform(&mut self);

    /// Xavier/Glorot normal initialization.
    ///
    /// Like [`ParameterInit::xavier_uniform`] but samples a normal
    /// distribution with std = sqrt(2 / (fan_in + fan_out)).
    fn xavier_normal(&mut self);

    /// Kaiming/He uniform initialization.
    ///
    /// Designed for ReLU-like activations; maintains variance when half the
    /// neurons are zeroed out.
    ///
    /// # Arguments
    ///
    /// * `mode` - `"fan_in"` or `"fan_out"`, selecting which dimension to use
    /// * `nonlinearity` - Activation name (`"relu"`, `"leaky_relu"`, `"linear"`)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use trustformers_core::tensor::Tensor;
    /// # use trustformers_core::traits::ParameterInit;
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut conv_weight = Tensor::zeros(&[64, 32, 3, 3])?;
    /// conv_weight.kaiming_uniform("fan_in", "relu");
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # References
    ///
    /// He et al. (2015): "Delving Deep into Rectifiers: Surpassing
    /// Human-Level Performance on ImageNet Classification"
    // NOTE(review): stringly-typed `mode`/`nonlinearity` would be cleaner as
    // enums, but changing the signature would break existing implementors.
    fn kaiming_uniform(&mut self, mode: &str, nonlinearity: &str);

    /// Kaiming/He normal initialization.
    ///
    /// Like [`ParameterInit::kaiming_uniform`] but samples a normal
    /// distribution; generally preferred for deeper networks.
    ///
    /// # Arguments
    ///
    /// * `mode` - `"fan_in"` or `"fan_out"`
    /// * `nonlinearity` - Activation name (`"relu"`, `"leaky_relu"`, `"linear"`)
    fn kaiming_normal(&mut self, mode: &str, nonlinearity: &str);
}
918}