// trustformers_core/layers/linear.rs
1//! Linear (fully connected) layer implementation.
2//!
3//! This module provides the `Linear` layer, which performs affine transformations
4//! of the form: `y = xW^T + b`, where `W` is the weight matrix and `b` is the
5//! optional bias vector.
6
7use crate::device::Device;
8use crate::errors::{Result, TrustformersError};
9#[cfg(all(target_os = "macos", feature = "metal"))]
10use crate::gpu_ops::dispatch_matmul;
11use crate::tensor::Tensor;
12use crate::traits::Layer;
13use scirs2_core::ndarray::{Array2, Ix2, IxDyn};
14#[cfg(not(target_os = "macos"))]
15use scirs2_core::simd_ops::SimdUnifiedOps;
16
/// Direct BLAS GEMM using OxiBLAS for maximum performance.
///
/// Computes `C = 1.0 * A * B + 0.0 * C` over row-major `f32` slices, where
/// `A` is `m x k`, `B` is `k x n`, and `C` is `m x n`.
#[cfg(target_os = "macos")]
#[inline]
fn blas_sgemm(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    use oxiblas_blas::level3::gemm;
    use oxiblas_matrix::{MatMut, MatRef};

    // The raw-pointer matrix views below trust the caller's dimensions; catch
    // mismatched slice lengths in debug builds before they turn into
    // out-of-bounds reads/writes inside the BLAS kernel.
    debug_assert_eq!(a.len(), m * k, "A slice length must equal m * k");
    debug_assert_eq!(b.len(), k * n, "B slice length must equal k * n");
    debug_assert_eq!(c.len(), m * n, "C slice length must equal m * n");

    // Create matrix views from slices (row-major layout; the last argument is
    // presumably the row stride / leading dimension — TODO(review): confirm
    // against the oxiblas_matrix API docs).
    let a_mat = MatRef::new(a.as_ptr(), m, k, k);
    let b_mat = MatRef::new(b.as_ptr(), k, n, n);
    let c_mat = MatMut::new(c.as_mut_ptr(), m, n, n);

    // GEMM: C = 1.0 * A * B + 0.0 * C
    gemm(1.0, a_mat, b_mat, 0.0, c_mat);
}
32
/// Direct BLAS GEMM using OxiBLAS for f64.
///
/// Computes `C = 1.0 * A * B + 0.0 * C` over row-major `f64` slices, where
/// `A` is `m x k`, `B` is `k x n`, and `C` is `m x n`.
#[cfg(target_os = "macos")]
#[inline]
fn blas_dgemm(a: &[f64], b: &[f64], c: &mut [f64], m: usize, k: usize, n: usize) {
    use oxiblas_blas::level3::gemm;
    use oxiblas_matrix::{MatMut, MatRef};

    // The raw-pointer matrix views below trust the caller's dimensions; catch
    // mismatched slice lengths in debug builds before they turn into
    // out-of-bounds reads/writes inside the BLAS kernel.
    debug_assert_eq!(a.len(), m * k, "A slice length must equal m * k");
    debug_assert_eq!(b.len(), k * n, "B slice length must equal k * n");
    debug_assert_eq!(c.len(), m * n, "C slice length must equal m * n");

    // Create matrix views from slices (row-major layout; the last argument is
    // presumably the row stride / leading dimension — TODO(review): confirm
    // against the oxiblas_matrix API docs).
    let a_mat = MatRef::new(a.as_ptr(), m, k, k);
    let b_mat = MatRef::new(b.as_ptr(), k, n, n);
    let c_mat = MatMut::new(c.as_mut_ptr(), m, n, n);

    // GEMM: C = 1.0 * A * B + 0.0 * C
    gemm(1.0, a_mat, b_mat, 0.0, c_mat);
}
48
49/// Fallback for non-macOS: use scirs2-core SIMD GEMM for f32
50#[cfg(not(target_os = "macos"))]
51#[inline]
52fn blas_sgemm(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
53    let a_arr = Array2::from_shape_vec((m, k), a.to_vec())
54        .expect("matrix dimensions must match slice length");
55    let b_arr = Array2::from_shape_vec((k, n), b.to_vec())
56        .expect("matrix dimensions must match slice length");
57    let mut c_arr = Array2::from_shape_vec((m, n), c.to_vec())
58        .expect("matrix dimensions must match slice length");
59    f32::simd_gemm(1.0, &a_arr.view(), &b_arr.view(), 0.0, &mut c_arr);
60    c.copy_from_slice(c_arr.as_slice().expect("Array2 must have contiguous slice"));
61}
62
63/// Fallback for non-macOS: use scirs2-core SIMD GEMM for f64
64#[cfg(not(target_os = "macos"))]
65#[inline]
66fn blas_dgemm(a: &[f64], b: &[f64], c: &mut [f64], m: usize, k: usize, n: usize) {
67    let a_arr = Array2::from_shape_vec((m, k), a.to_vec())
68        .expect("matrix dimensions must match slice length");
69    let b_arr = Array2::from_shape_vec((k, n), b.to_vec())
70        .expect("matrix dimensions must match slice length");
71    let mut c_arr = Array2::from_shape_vec((m, n), c.to_vec())
72        .expect("matrix dimensions must match slice length");
73    f64::simd_gemm(1.0, &a_arr.view(), &b_arr.view(), 0.0, &mut c_arr);
74    c.copy_from_slice(c_arr.as_slice().expect("Array2 must have contiguous slice"));
75}
76
/// A linear transformation layer (fully connected layer).
///
/// The `Linear` layer applies a linear transformation to the incoming data:
/// `y = xW^T + b`. This is one of the most fundamental building blocks in
/// neural networks.
///
/// # Parameters
///
/// - `weight`: Learnable weight matrix of shape `[out_features, in_features]`
/// - `bias`: Optional learnable bias vector of shape `[out_features]`
///
/// # Input/Output Shapes
///
/// - Input: `[..., in_features]` - Can be 2D or 3D
/// - Output: `[..., out_features]` - Same number of dimensions as input
///
/// # Example
///
/// ```no_run
/// use trustformers_core::layers::Linear;
/// use trustformers_core::tensor::Tensor;
/// use trustformers_core::traits::Layer;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// // Create a linear layer: 768 → 3072
/// let linear = Linear::new(768, 3072, true);
///
/// // Apply to 2D input: [seq_len, in_features]
/// let input_2d = Tensor::randn(&[128, 768])?;
/// let output_2d = linear.forward(input_2d)?;  // Shape: [128, 3072]
///
/// // Apply to 3D input: [batch, seq_len, in_features]
/// let input_3d = Tensor::randn(&[4, 128, 768])?;
/// let output_3d = linear.forward(input_3d)?;  // Shape: [4, 128, 3072]
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct Linear {
    /// Learnable weight matrix, stored as `[out_features, in_features]`.
    weight: Tensor,
    /// Optional learnable bias vector of shape `[out_features]`.
    bias: Option<Tensor>,
    /// Device used for forward computation (CPU by default).
    device: Device,
    /// Lazily-populated handle to the weight matrix cached in Metal GPU memory
    /// (uploaded transposed, see `ensure_weight_buffer_cached`); `None` until
    /// the first Metal forward pass, and cleared whenever weights or device
    /// change. `Arc<RwLock<...>>` so clones of the layer share the cache.
    #[cfg(all(target_os = "macos", feature = "metal"))]
    weight_buffer_id: std::sync::Arc<std::sync::RwLock<Option<crate::gpu_ops::BufferId>>>,
    /// CUDA counterpart of the Metal weight-buffer cache above.
    #[cfg(feature = "cuda")]
    weight_buffer_id_cuda:
        std::sync::Arc<std::sync::RwLock<Option<crate::gpu_ops::cuda::BufferId>>>,
}
125
126impl Linear {
127    /// Creates a new linear layer.
128    ///
129    /// # Arguments
130    ///
131    /// * `in_features` - Size of each input sample
132    /// * `out_features` - Size of each output sample
133    /// * `bias` - Whether to include a learnable bias
134    ///
135    /// # Returns
136    ///
137    /// A new `Linear` layer with randomly initialized weights using a normal
138    /// distribution, and bias initialized to zeros if enabled.
139    ///
140    /// # Example
141    ///
142    /// ```no_run
143    /// use trustformers_core::layers::Linear;
144    ///
145    /// // Linear layer without bias
146    /// let linear1 = Linear::new(512, 1024, false);
147    ///
148    /// // Linear layer with bias
149    /// let linear2 = Linear::new(512, 1024, true);
150    /// ```
151    pub fn new(in_features: usize, out_features: usize, bias: bool) -> Self {
152        Self::new_with_device(in_features, out_features, bias, Device::CPU)
153    }
154
155    /// Creates a new linear layer with specified device.
156    ///
157    /// # Arguments
158    ///
159    /// * `in_features` - Size of each input sample
160    /// * `out_features` - Size of each output sample
161    /// * `bias` - Whether to include a learnable bias
162    /// * `device` - Device to use for computations (CPU, Metal, CUDA, etc.)
163    ///
164    /// # Returns
165    ///
166    /// A new `Linear` layer with randomly initialized weights.
167    ///
168    /// # Example
169    ///
170    /// ```no_run
171    /// use trustformers_core::layers::Linear;
172    /// use trustformers_core::Device;
173    ///
174    /// // Create a linear layer on Metal GPU
175    /// let linear = Linear::new_with_device(768, 3072, true, Device::metal_if_available(0));
176    /// ```
177    pub fn new_with_device(
178        in_features: usize,
179        out_features: usize,
180        bias: bool,
181        device: Device,
182    ) -> Self {
183        let weight =
184            Tensor::randn(&[out_features, in_features]).expect("Failed to create random tensor");
185        let bias = if bias {
186            Some(Tensor::zeros(&[out_features]).expect("Failed to create zero tensor"))
187        } else {
188            None
189        };
190
191        Self {
192            weight,
193            bias,
194            device,
195            #[cfg(all(target_os = "macos", feature = "metal"))]
196            weight_buffer_id: std::sync::Arc::new(std::sync::RwLock::new(None)),
197            #[cfg(feature = "cuda")]
198            weight_buffer_id_cuda: std::sync::Arc::new(std::sync::RwLock::new(None)),
199        }
200    }
201
202    /// Returns the device this layer uses for computations.
203    pub fn device(&self) -> Device {
204        self.device
205    }
206
207    /// Move this layer to a different device.
208    ///
209    /// # Arguments
210    ///
211    /// * `device` - Target device
212    ///
213    /// # Returns
214    ///
215    /// Self with updated device.
216    pub fn to_device(mut self, device: Device) -> Self {
217        self.device = device;
218        // Clear cached buffer when changing device
219        #[cfg(all(target_os = "macos", feature = "metal"))]
220        {
221            if let Ok(mut buffer_id) = self.weight_buffer_id.write() {
222                *buffer_id = None;
223            }
224        }
225        #[cfg(feature = "cuda")]
226        {
227            if let Ok(mut buffer_id) = self.weight_buffer_id_cuda.write() {
228                *buffer_id = None;
229            }
230        }
231        self
232    }
233
234    /// Sets the weight matrix for this layer.
235    ///
236    /// # Arguments
237    ///
238    /// * `weight` - The new weight tensor, must have shape `[out_features, in_features]`
239    ///
240    /// # Returns
241    ///
242    /// `Ok(())` if successful.
243    ///
244    /// # Note
245    ///
246    /// This method is typically used when loading pretrained weights.
247    pub fn set_weight(&mut self, weight: Tensor) -> Result<()> {
248        self.weight = weight;
249        // Clear cached buffer when weights are updated
250        #[cfg(all(target_os = "macos", feature = "metal"))]
251        {
252            if let Ok(mut buffer_id) = self.weight_buffer_id.write() {
253                *buffer_id = None;
254            }
255        }
256        #[cfg(feature = "cuda")]
257        {
258            if let Ok(mut buffer_id) = self.weight_buffer_id_cuda.write() {
259                *buffer_id = None;
260            }
261        }
262        Ok(())
263    }
264
265    /// Sets the bias vector for this layer.
266    ///
267    /// # Arguments
268    ///
269    /// * `bias` - The new bias tensor, must have shape `[out_features]`
270    ///
271    /// # Returns
272    ///
273    /// `Ok(())` if successful.
274    ///
275    /// # Note
276    ///
277    /// This will enable bias even if the layer was created without bias.
278    pub fn set_bias(&mut self, bias: Tensor) -> Result<()> {
279        self.bias = Some(bias);
280        Ok(())
281    }
282
283    /// Returns a reference to the weight matrix.
284    ///
285    /// # Returns
286    ///
287    /// A reference to the weight tensor of shape `[out_features, in_features]`.
288    pub fn weight(&self) -> &Tensor {
289        &self.weight
290    }
291
292    /// Returns a reference to the bias vector if present.
293    ///
294    /// # Returns
295    ///
296    /// `Some(&bias)` if bias is enabled, `None` otherwise.
297    pub fn bias(&self) -> Option<&Tensor> {
298        self.bias.as_ref()
299    }
300
301    /// Returns the total number of learnable parameters in this layer.
302    ///
303    /// # Returns
304    ///
305    /// The total parameter count including weights and bias (if present).
306    pub fn parameter_count(&self) -> usize {
307        let weight_count = self.weight.len();
308        let bias_count = self.bias.as_ref().map_or(0, |b| b.len());
309        weight_count + bias_count
310    }
311
312    /// Initialize persistent GPU weight buffer for Metal device
313    /// This is called automatically on first forward pass with Metal device
314    #[cfg(all(target_os = "macos", feature = "metal"))]
315    fn ensure_weight_buffer_cached(&self) -> Result<()> {
316        use crate::gpu_ops::metal::get_metal_backend;
317
318        // Check if already cached (read lock is cheaper)
319        if let Ok(buffer_id) = self.weight_buffer_id.read() {
320            if buffer_id.is_some() {
321                return Ok(()); // Already cached
322            }
323        }
324
325        // Only cache if using Metal
326        if matches!(self.device, Device::Metal(_)) {
327            // Get write lock to cache the buffer
328            let mut buffer_id_guard = self.weight_buffer_id.write().map_err(|_| {
329                TrustformersError::hardware_error(
330                    "Failed to acquire write lock on buffer cache",
331                    "ensure_weight_buffer_cached",
332                )
333            })?;
334
335            // Double-check after acquiring write lock (another thread might have cached it)
336            if buffer_id_guard.is_some() {
337                return Ok(());
338            }
339
340            // Get weight data as f32 slice
341            // CRITICAL FIX: Cache the TRANSPOSED weight, not the original!
342            // The Metal shader expects weight in [in_features, out_features] layout
343            // but self.weight is stored as [out_features, in_features]
344            let weight_t = self.weight.transpose(0, 1)?;
345            match &weight_t {
346                Tensor::F32(arr) => {
347                    if arr.ndim() != 2 {
348                        return Err(TrustformersError::shape_error(
349                            "Weight tensor must be 2D for Metal caching".to_string(),
350                        ));
351                    }
352
353                    // Convert to contiguous vec for GPU upload
354                    // Using as_standard_layout() ensures proper row-major order
355                    let contiguous_arr = arr.as_standard_layout();
356                    let weight_data: Vec<f32> = contiguous_arr.iter().copied().collect();
357
358                    // Get Metal backend and cache the buffer
359                    let backend = get_metal_backend()?;
360                    let new_buffer_id = backend.create_persistent_buffer(&weight_data)?;
361                    *buffer_id_guard = Some(new_buffer_id);
362                },
363                _ => {
364                    return Err(TrustformersError::tensor_op_error(
365                        "Only F32 tensors supported for Metal caching",
366                        "ensure_weight_buffer_cached",
367                    ));
368                },
369            }
370        }
371        Ok(())
372    }
373
374    /// Pre-cache layer weights on GPU for zero-transfer pipeline
375    /// This uploads weights to GPU memory in advance to avoid transfers during forward pass
376    #[cfg(all(target_os = "macos", feature = "metal"))]
377    pub fn weights_to_gpu(&mut self, device: &crate::device::Device) -> Result<()> {
378        use crate::device::Device;
379
380        if !matches!(device, Device::Metal(_)) {
381            return Ok(()); // Nothing to do for non-Metal devices
382        }
383
384        // Update device setting
385        self.device = *device;
386
387        // Pre-cache weight buffer on GPU (keeps weight as F32 CPU tensor)
388        // The caching mechanism handles the GPU upload internally
389        self.ensure_weight_buffer_cached()?;
390
391        // Upload bias to GPU if present (for GPU bias addition kernel)
392        if let Some(ref bias) = self.bias {
393            self.bias = Some(bias.to_device_enum(device)?);
394        }
395
396        Ok(())
397    }
398
399    /// Initialize persistent GPU weight buffer for CUDA device
400    /// This is called automatically on first forward pass with CUDA device
401    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
402    fn ensure_weight_buffer_cached_cuda(&self) -> Result<()> {
403        use crate::gpu_ops::cuda::get_cuda_backend;
404
405        // Check if already cached (read lock is cheaper)
406        if let Ok(buffer_id) = self.weight_buffer_id_cuda.read() {
407            if buffer_id.is_some() {
408                return Ok(()); // Already cached
409            }
410        }
411
412        // Only cache if using CUDA
413        if matches!(self.device, Device::CUDA(_)) {
414            // Get write lock to cache the buffer
415            let mut buffer_id_guard = self.weight_buffer_id_cuda.write().map_err(|_| {
416                TrustformersError::hardware_error(
417                    "Failed to acquire write lock on CUDA buffer cache",
418                    "ensure_weight_buffer_cached_cuda",
419                )
420            })?;
421
422            // Double-check after acquiring write lock (another thread might have cached it)
423            if buffer_id_guard.is_some() {
424                return Ok(());
425            }
426
427            // Get weight data as f32 slice
428            // CRITICAL FIX: Cache the TRANSPOSED weight, not the original!
429            // The CUDA kernel expects weight in [in_features, out_features] layout
430            // but self.weight is stored as [out_features, in_features]
431            let weight_t = self.weight.transpose(0, 1)?;
432            match &weight_t {
433                Tensor::F32(arr) => {
434                    if arr.ndim() != 2 {
435                        return Err(TrustformersError::shape_error(
436                            "Weight tensor must be 2D for CUDA caching".to_string(),
437                        ));
438                    }
439
440                    // Convert to contiguous vec for GPU upload
441                    // Using as_standard_layout() ensures proper row-major order
442                    let contiguous_arr = arr.as_standard_layout();
443                    let weight_data: Vec<f32> = contiguous_arr.iter().copied().collect();
444
445                    // Get CUDA backend and cache the buffer
446                    let device_id = if let Device::CUDA(id) = self.device {
447                        id
448                    } else {
449                        0 // Default to device 0
450                    };
451                    let backend = get_cuda_backend(device_id)?;
452                    let new_buffer_id = backend.create_persistent_buffer(&weight_data)?;
453                    *buffer_id_guard = Some(new_buffer_id);
454                },
455                _ => {
456                    return Err(TrustformersError::tensor_op_error(
457                        "Only F32 tensors supported for CUDA caching",
458                        "ensure_weight_buffer_cached_cuda",
459                    ));
460                },
461            }
462        }
463        Ok(())
464    }
465
466    /// Pre-cache layer weights on GPU for zero-transfer pipeline (CUDA)
467    /// This uploads weights to GPU memory in advance to avoid transfers during forward pass
468    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
469    pub fn weights_to_gpu_cuda(&mut self, device: &crate::device::Device) -> Result<()> {
470        use crate::device::Device;
471
472        if !matches!(device, Device::CUDA(_)) {
473            return Ok(()); // Nothing to do for non-CUDA devices
474        }
475
476        // Update device setting
477        self.device = *device;
478
479        // Pre-cache weight buffer on GPU (keeps weight as F32 CPU tensor)
480        // The caching mechanism handles the GPU upload internally
481        self.ensure_weight_buffer_cached_cuda()?;
482
483        // Upload bias to GPU if present (for GPU bias addition kernel)
484        if let Some(ref bias) = self.bias {
485            self.bias = Some(bias.to_device_enum(device)?);
486        }
487
488        Ok(())
489    }
490}
491
492impl Layer for Linear {
493    type Input = Tensor;
494    type Output = Tensor;
495
496    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
497        // =====================================================================
498        // GPU-TO-GPU PATH: Tensor::Metal (ZERO CPU TRANSFERS!)
499        // =====================================================================
500        #[cfg(all(target_os = "macos", feature = "metal"))]
501        if let Tensor::Metal(ref input_metal) = input {
502            use crate::gpu_ops::metal::get_metal_backend;
503            use crate::tensor::MetalTensorData;
504
505            // eprintln!("🎯 Linear::forward - GPU-to-GPU path triggered (Tensor::Metal input)");
506
507            // Ensure weight buffer is cached on GPU
508            self.ensure_weight_buffer_cached()?;
509
510            // Get cached weight buffer ID
511            let weight_buffer_id = {
512                let buffer_id_guard = self.weight_buffer_id.read().map_err(|_| {
513                    TrustformersError::hardware_error(
514                        "Failed to acquire read lock on buffer cache",
515                        "Linear::forward",
516                    )
517                })?;
518
519                if let Some(id) = *buffer_id_guard {
520                    id
521                } else {
522                    // Weight not cached - fallback to CPU
523                    let cpu_input = input.to_device_enum(&crate::device::Device::CPU)?;
524                    return self.forward(cpu_input);
525                }
526            };
527
528            // Get Metal backend
529            let backend = get_metal_backend()?;
530
531            // Extract input shape and calculate matmul dimensions
532            let shape = &input_metal.shape;
533            let weight_shape = self.weight.shape();
534            let in_features = shape[shape.len() - 1];
535
536            // Check shape compatibility
537            if in_features != weight_shape[1] {
538                return Err(TrustformersError::shape_error(format!(
539                    "Linear layer input features {} doesn't match weight shape {:?}",
540                    in_features, weight_shape
541                )));
542            }
543
544            // Flatten input to [batch, in_features] for matmul
545            // Works for both 2D [seq_len, in_features] and 3D [batch, seq_len, in_features]
546            let batch_dims: usize = shape[..shape.len() - 1].iter().product();
547            let m = batch_dims; // number of rows in output
548            let k = in_features; // shared dimension
549            let n = self.weight.shape()[0]; // out_features
550
551            // Perform GPU-to-GPU matmul using MPS (100-500x faster!)
552            // Try MPS first, fallback to naive kernel if MPS unavailable
553            let output_buffer_id = backend
554                .matmul_gpu_to_gpu_mps(&input_metal.buffer_id, &weight_buffer_id, m, k, n)
555                .or_else(|_e| {
556                    // eprintln!(
557                    //     "⚠️  MPS matmul failed: {:?}, falling back to naive Metal kernel",
558                    //     e
559                    // );
560                    // Fallback to naive Metal kernel if MPS fails
561                    backend.matmul_gpu_to_gpu(&input_metal.buffer_id, &weight_buffer_id, m, k, n)
562                })?;
563
564            // Calculate output shape (preserve batch dimensions, change last dim)
565            let mut output_shape = shape[..shape.len() - 1].to_vec();
566            output_shape.push(n);
567
568            // Create output Metal tensor
569            let mut output = Tensor::Metal(MetalTensorData {
570                buffer_id: output_buffer_id,
571                shape: output_shape.clone(),
572                dtype: input_metal.dtype,
573            });
574
575            // Handle bias if present
576            if let Some(ref bias) = self.bias {
577                // eprintln!("🔍 Linear: Has bias, checking type...");
578                // Try GPU-to-GPU bias addition if bias is on GPU
579                match bias {
580                    #[cfg(all(target_os = "macos", feature = "metal"))]
581                    Tensor::Metal(bias_data) => {
582                        // eprintln!("🔍 Linear: Bias is Metal, using GPU-to-GPU bias addition");
583                        // Both output and bias are Metal tensors - use GPU kernel!
584                        if let Tensor::Metal(output_data) = &output {
585                            // eprintln!("🔍 Linear: Output is Metal, calling add_bias_gpu_to_gpu");
586                            let output_buffer_id = backend.add_bias_gpu_to_gpu(
587                                &output_data.buffer_id,
588                                &bias_data.buffer_id,
589                                batch_dims,
590                                n,
591                            )?;
592                            // eprintln!(
593                            //     "🔍 Linear: add_bias_gpu_to_gpu succeeded, returning Metal tensor"
594                            // );
595
596                            return Ok(Tensor::Metal(MetalTensorData {
597                                buffer_id: output_buffer_id,
598                                shape: output_shape.clone(),
599                                dtype: output_data.dtype,
600                            }));
601                        }
602                        // eprintln!("🔍 Linear: Output is NOT Metal, falling back to CPU");
603                    },
604                    _ => {
605                        // eprintln!(
606                        //     "🔍 Linear: Bias is NOT Metal (type={:?}), falling back to CPU",
607                        //     std::mem::discriminant(bias)
608                        // );
609                    },
610                }
611
612                // Fallback: CPU bias addition
613                // eprintln!("🔍 Linear: Using CPU bias fallback");
614                output = output.to_device_enum(&crate::device::Device::CPU)?;
615                // eprintln!("🔍 Linear: Converted output to CPU");
616                output = output.add(bias)?;
617                // eprintln!("🔍 Linear: Added bias on CPU");
618
619                // Convert back to Metal tensor if needed
620                if matches!(self.device, crate::device::Device::Metal(_)) {
621                    // eprintln!("🔍 Linear: Converting back to Metal device");
622                    output = output.to_device_enum(&self.device)?;
623                    // eprintln!(
624                    //     "🔍 Linear: Converted back to Metal, type={:?}",
625                    //     std::mem::discriminant(&output)
626                    // );
627                }
628            } else {
629                // eprintln!("🔍 Linear: No bias");
630            }
631
632            // eprintln!(
633            //     "🔍 Linear: Returning output, type={:?}",
634            //     std::mem::discriminant(&output)
635            // );
636            return Ok(output);
637        }
638
639        // =====================================================================
640        // GPU-TO-GPU PATH: Tensor::CUDA (ZERO CPU TRANSFERS!)
641        // =====================================================================
642        #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
643        if let Tensor::CUDA(ref input_cuda) = input {
644            use crate::gpu_ops::cuda::get_cuda_backend;
645            use crate::tensor::CudaTensorData;
646
647            // Ensure weight buffer is cached on GPU
648            self.ensure_weight_buffer_cached_cuda()?;
649
650            // Get cached weight buffer ID
651            let weight_buffer_id = {
652                let buffer_id_guard = self.weight_buffer_id_cuda.read().map_err(|_| {
653                    TrustformersError::hardware_error(
654                        "Failed to acquire read lock on CUDA buffer cache",
655                        "Linear::forward",
656                    )
657                })?;
658
659                if let Some(id) = *buffer_id_guard {
660                    id
661                } else {
662                    // Weight not cached - fallback to CPU
663                    let cpu_input = input.to_device_enum(&crate::device::Device::CPU)?;
664                    return self.forward(cpu_input);
665                }
666            };
667
668            // Get CUDA backend
669            let device_id = if let Device::CUDA(id) = self.device {
670                id
671            } else {
672                0 // Default to device 0
673            };
674            let backend = get_cuda_backend(device_id)?;
675
676            // Extract input shape and calculate matmul dimensions
677            let shape = &input_cuda.shape;
678            let weight_shape = self.weight.shape();
679            let in_features = shape[shape.len() - 1];
680
681            // Check shape compatibility
682            if in_features != weight_shape[1] {
683                return Err(TrustformersError::shape_error(format!(
684                    "Linear layer input features {} doesn't match weight shape {:?}",
685                    in_features, weight_shape
686                )));
687            }
688
689            // Flatten input to [batch, in_features] for matmul
690            // Works for both 2D [seq_len, in_features] and 3D [batch, seq_len, in_features]
691            let batch_dims: usize = shape[..shape.len() - 1].iter().product();
692            let m = batch_dims; // number of rows in output
693            let k = in_features; // shared dimension
694            let n = self.weight.shape()[0]; // out_features
695
696            // Perform GPU-to-GPU matmul (ZERO CPU TRANSFERS!)
697            let output_buffer_id =
698                backend.matmul_gpu_to_gpu(&input_cuda.buffer_id, &weight_buffer_id, m, k, n)?;
699
700            // Calculate output shape (preserve batch dimensions, change last dim)
701            let mut output_shape = shape[..shape.len() - 1].to_vec();
702            output_shape.push(n);
703
704            // Create output CUDA tensor
705            let mut output = Tensor::CUDA(CudaTensorData {
706                buffer_id: output_buffer_id,
707                shape: output_shape.clone(),
708                dtype: input_cuda.dtype,
709            });
710
711            // Handle bias if present
712            if let Some(ref bias) = self.bias {
713                // Try GPU-to-GPU bias addition if bias is on GPU
714                match bias {
715                    #[cfg(feature = "cuda")]
716                    Tensor::CUDA(bias_data) => {
717                        // Both output and bias are CUDA tensors - use GPU kernel!
718                        if let Tensor::CUDA(output_data) = &output {
719                            let output_buffer_id = backend.add_bias_gpu_to_gpu(
720                                &output_data.buffer_id,
721                                &bias_data.buffer_id,
722                                batch_dims,
723                                n,
724                            )?;
725
726                            return Ok(Tensor::CUDA(CudaTensorData {
727                                buffer_id: output_buffer_id,
728                                shape: output_shape.clone(),
729                                dtype: output_data.dtype,
730                            }));
731                        }
732                    },
733                    _ => {},
734                }
735
736                // Fallback: CPU bias addition
737                output = output.to_device_enum(&crate::device::Device::CPU)?;
738                output = output.add(bias)?;
739
740                // Convert back to CUDA tensor if needed
741                if matches!(self.device, crate::device::Device::CUDA(_)) {
742                    output = output.to_device_enum(&self.device)?;
743                }
744            }
745
746            return Ok(output);
747        }
748
749        // =====================================================================
750        // CPU/F32 PATH (existing implementation)
751        // =====================================================================
752        // Handle different input shapes for matmul
753        let input_shape = input.shape();
754        let weight_t = self.weight.transpose(0, 1)?;
755
756        let output = if input_shape.len() == 2 {
757            // Standard 2D input: [seq_len, hidden_size] x [hidden_size, out_features]
758
759            // Try to use cached Metal buffer if available (ZERO-COPY OPTIMIZATION)
760            #[cfg(all(target_os = "macos", feature = "metal"))]
761            if matches!(self.device, Device::Metal(_)) {
762                // Ensure buffer is cached
763                self.ensure_weight_buffer_cached()?;
764
765                // Try to use cached buffer
766                if let Ok(buffer_id_guard) = self.weight_buffer_id.read() {
767                    if let Some(buffer_id) = *buffer_id_guard {
768                        // We have a cached buffer! Use it for ZERO-COPY matmul
769                        use crate::gpu_ops::metal::get_metal_backend;
770
771                        if let (Tensor::F32(inp), Tensor::F32(w_t)) = (&input, &weight_t) {
772                            if inp.ndim() == 2 && w_t.ndim() == 2 {
773                                let inp_shape = inp.shape();
774                                let w_shape = w_t.shape();
775                                let m = inp_shape[0];
776                                let k = inp_shape[1];
777                                let k2 = w_shape[0];
778                                let n = w_shape[1];
779
780                                if k == k2 {
781                                    // Get Metal backend
782                                    if let Ok(backend) = get_metal_backend() {
783                                        // Convert input to contiguous
784                                        let input_data: Vec<f32> = inp.iter().copied().collect();
785
786                                        // Call matmul with CACHED weight buffer (no weight transfer!)
787                                        if let Ok(result) = backend.matmul_with_cached_weight(
788                                            &input_data,
789                                            &buffer_id,
790                                            m,
791                                            k,
792                                            n,
793                                        ) {
794                                            let result_arr =
795                                                scirs2_core::ndarray::Array2::from_shape_vec(
796                                                    (m, n),
797                                                    result,
798                                                )
799                                                .map_err(|e| {
800                                                    TrustformersError::shape_error(format!(
801                                                        "Result reshape failed: {}",
802                                                        e
803                                                    ))
804                                                })?;
805
806                                            // Add bias if present
807                                            let mut output = Tensor::F32(result_arr.into_dyn());
808                                            if let Some(ref bias) = self.bias {
809                                                output = output.add(bias)?;
810                                            }
811                                            return Ok(output);
812                                        }
813                                    }
814                                }
815                            }
816                        }
817                    }
818                }
819            }
820
821            // Fallback: Use standard dispatch (for non-Metal or if cached path failed)
822            #[cfg(all(target_os = "macos", feature = "metal"))]
823            {
824                if self.device.is_gpu() {
825                    dispatch_matmul(&input, &weight_t, &self.device)?
826                } else {
827                    input.matmul(&weight_t)?
828                }
829            }
830            #[cfg(not(all(target_os = "macos", feature = "metal")))]
831            {
832                input.matmul(&weight_t)?
833            }
834        } else if input_shape.len() == 3 {
835            // Batched 3D input: [batch, seq_len, hidden_size] x [hidden_size, out_features]
836            // Handle manually since tensor.matmul doesn't support 3D x 2D
837            match (&input, &weight_t) {
838                (Tensor::F32(inp), Tensor::F32(w)) => {
839                    let batch = input_shape[0];
840                    let seq_len = input_shape[1];
841                    let hidden = input_shape[2];
842                    let out_features = w.shape()[1];
843
844                    // Try to use cached Metal buffer for 3D case (CRITICAL OPTIMIZATION)
845                    #[cfg(all(target_os = "macos", feature = "metal"))]
846                    if matches!(self.device, Device::Metal(_)) {
847                        // Ensure buffer is cached
848                        self.ensure_weight_buffer_cached()?;
849
850                        if let Ok(buffer_id_guard) = self.weight_buffer_id.read() {
851                            if let Some(buffer_id) = *buffer_id_guard {
852                                // Reshape 3D input to 2D for matmul
853                                let m = batch * seq_len;
854                                let k = hidden;
855                                let n = out_features;
856
857                                // Get Metal backend
858                                use crate::gpu_ops::metal::get_metal_backend;
859                                if let Ok(backend) = get_metal_backend() {
860                                    // Convert input to contiguous 2D data
861                                    let input_data: Vec<f32> = inp.iter().copied().collect();
862
863                                    // Call matmul with CACHED weight buffer (no weight transfer!)
864                                    if let Ok(result) = backend.matmul_with_cached_weight(
865                                        &input_data,
866                                        &buffer_id,
867                                        m,
868                                        k,
869                                        n,
870                                    ) {
871                                        // Reshape result from 2D [m, n] back to 3D [batch, seq_len, out_features]
872                                        let result_arr =
873                                            scirs2_core::ndarray::Array2::from_shape_vec(
874                                                (m, n),
875                                                result,
876                                            )
877                                            .map_err(
878                                                |e| {
879                                                    TrustformersError::shape_error(format!(
880                                                        "Result reshape failed: {}",
881                                                        e
882                                                    ))
883                                                },
884                                            )?;
885
886                                        let result_3d = result_arr
887                                            .into_shape_with_order(IxDyn(&[
888                                                batch,
889                                                seq_len,
890                                                out_features,
891                                            ]))
892                                            .map_err(|e| {
893                                                TrustformersError::shape_error(format!(
894                                                    "3D reshape failed: {}",
895                                                    e
896                                                ))
897                                            })?;
898
899                                        // Add bias if present
900                                        let mut output = Tensor::F32(result_3d);
901                                        if let Some(ref bias) = self.bias {
902                                            output = output.add(bias)?;
903                                        }
904                                        return Ok(output);
905                                    }
906                                }
907                            }
908                        }
909                    }
910
911                    // Fallback: CPU path with ndarray
912
913                    // Ensure contiguous layout before reshaping input to 2D for dot product
914                    let inp_contiguous = inp.to_owned();
915                    let inp_2d = inp_contiguous
916                        .into_shape_with_order([batch * seq_len, hidden])
917                        .map_err(|e| {
918                            TrustformersError::shape_error(format!(
919                                "Failed to reshape input: {}",
920                                e
921                            ))
922                        })?;
923
924                    // Ensure contiguous layout for weight and convert to 2D for GEMM
925                    let w_contiguous = w.to_owned();
926                    let w_2d = w_contiguous.into_dimensionality::<Ix2>().map_err(|e| {
927                        TrustformersError::shape_error(format!(
928                            "Failed to convert weight to 2D: {}",
929                            e
930                        ))
931                    })?;
932
933                    // Use direct BLAS for maximum performance (Accelerate on macOS)
934                    let m = inp_2d.nrows();
935                    let n = w_2d.ncols();
936                    let k = inp_2d.ncols();
937                    const MIN_SIZE_FOR_BLAS: usize = 32;
938                    let out_2d = if m < MIN_SIZE_FOR_BLAS
939                        || n < MIN_SIZE_FOR_BLAS
940                        || k < MIN_SIZE_FOR_BLAS
941                    {
942                        inp_2d.dot(&w_2d)
943                    } else {
944                        // Use direct BLAS GEMM for 10-50x speedup
945                        let inp_slice = inp_2d.as_slice().unwrap_or(&[]);
946                        let w_slice = w_2d.as_slice().unwrap_or(&[]);
947                        if !inp_slice.is_empty() && !w_slice.is_empty() {
948                            let mut result_vec = vec![0.0f32; m * n];
949                            blas_sgemm(inp_slice, w_slice, &mut result_vec, m, k, n);
950                            Array2::from_shape_vec((m, n), result_vec)
951                                .expect("BLAS result shape must match m x n")
952                        } else {
953                            // Fallback to ndarray dot if slices aren't contiguous
954                            inp_2d.dot(&w_2d)
955                        }
956                    };
957
958                    // Reshape back to 3D
959                    let out_3d = out_2d
960                        .into_shape_with_order(IxDyn(&[batch, seq_len, out_features]))
961                        .map_err(|e| {
962                            TrustformersError::shape_error(format!(
963                                "Failed to reshape output: {}",
964                                e
965                            ))
966                        })?;
967
968                    Tensor::F32(out_3d)
969                },
970                (Tensor::F64(inp), Tensor::F64(w)) => {
971                    let batch = input_shape[0];
972                    let seq_len = input_shape[1];
973                    let hidden = input_shape[2];
974                    let out_features = w.shape()[1];
975
976                    // Ensure contiguous layout before reshaping
977                    let inp_contiguous = inp.to_owned();
978                    let inp_2d = inp_contiguous
979                        .into_shape_with_order([batch * seq_len, hidden])
980                        .map_err(|e| {
981                            TrustformersError::shape_error(format!(
982                                "Failed to reshape input: {}",
983                                e
984                            ))
985                        })?;
986
987                    // Ensure contiguous layout for weight and convert to 2D for GEMM
988                    let w_contiguous = w.to_owned();
989                    let w_2d = w_contiguous.into_dimensionality::<Ix2>().map_err(|e| {
990                        TrustformersError::shape_error(format!(
991                            "Failed to convert weight to 2D: {}",
992                            e
993                        ))
994                    })?;
995
996                    // Use direct BLAS for maximum performance (Accelerate on macOS)
997                    let m = inp_2d.nrows();
998                    let n = w_2d.ncols();
999                    let k = inp_2d.ncols();
1000                    const MIN_SIZE_FOR_BLAS: usize = 32;
1001                    let out_2d = if m < MIN_SIZE_FOR_BLAS
1002                        || n < MIN_SIZE_FOR_BLAS
1003                        || k < MIN_SIZE_FOR_BLAS
1004                    {
1005                        inp_2d.dot(&w_2d)
1006                    } else {
1007                        // Use direct BLAS GEMM for 10-50x speedup
1008                        let inp_slice = inp_2d.as_slice().unwrap_or(&[]);
1009                        let w_slice = w_2d.as_slice().unwrap_or(&[]);
1010                        if !inp_slice.is_empty() && !w_slice.is_empty() {
1011                            let mut result_vec = vec![0.0f64; m * n];
1012                            blas_dgemm(inp_slice, w_slice, &mut result_vec, m, k, n);
1013                            Array2::from_shape_vec((m, n), result_vec)
1014                                .expect("BLAS result shape must match m x n")
1015                        } else {
1016                            // Fallback to ndarray dot if slices aren't contiguous
1017                            inp_2d.dot(&w_2d)
1018                        }
1019                    };
1020
1021                    let out_3d = out_2d
1022                        .into_shape_with_order(IxDyn(&[batch, seq_len, out_features]))
1023                        .map_err(|e| {
1024                            TrustformersError::shape_error(format!(
1025                                "Failed to reshape output: {}",
1026                                e
1027                            ))
1028                        })?;
1029
1030                    Tensor::F64(out_3d)
1031                },
1032                _ => {
1033                    return Err(TrustformersError::tensor_op_error(
1034                        "Unsupported tensor types for 3D linear layer",
1035                        "Linear::forward",
1036                    ))
1037                },
1038            }
1039        } else {
1040            return Err(TrustformersError::tensor_op_error(
1041                &format!(
1042                    "Linear layer doesn't support input with {} dimensions",
1043                    input_shape.len()
1044                ),
1045                "Linear::forward",
1046            ));
1047        };
1048
1049        if let Some(ref bias) = self.bias {
1050            // Handle broadcasting for bias addition
1051            match (&output, bias) {
1052                (Tensor::F32(out_arr), Tensor::F32(bias_arr)) => {
1053                    // Broadcast bias to match output shape
1054                    let result = out_arr + bias_arr;
1055                    Ok(Tensor::F32(result))
1056                },
1057                _ => output.add(bias),
1058            }
1059        } else {
1060            Ok(output)
1061        }
1062    }
1063}