// trustformers_core/layers/linear.rs
1//! Linear (fully connected) layer implementation.
2//!
3//! This module provides the `Linear` layer, which performs affine transformations
4//! of the form: `y = xW^T + b`, where `W` is the weight matrix and `b` is the
5//! optional bias vector.
6
7use crate::device::Device;
8use crate::errors::{Result, TrustformersError};
9#[cfg(all(target_os = "macos", feature = "metal"))]
10use crate::gpu_ops::dispatch_matmul;
11use crate::tensor::Tensor;
12use crate::traits::Layer;
13use scirs2_core::ndarray::{Array2, Ix2, IxDyn};
14#[cfg(not(target_os = "macos"))]
15use scirs2_core::simd_ops::SimdUnifiedOps;
16
/// Direct BLAS GEMM using OxiBLAS for maximum performance.
///
/// Computes `C = A * B` for row-major `f32` matrices, where `A` is `m x k`,
/// `B` is `k x n`, and `C` is `m x n`. The previous contents of `c` are
/// overwritten (`beta == 0.0`).
#[cfg(target_os = "macos")]
#[inline]
fn blas_sgemm(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    use oxiblas_blas::level3::gemm;
    use oxiblas_matrix::{MatMut, MatRef};

    // The matrix views below are constructed from raw pointers; validate that
    // the slices actually cover the claimed dimensions so an inconsistent
    // caller cannot trigger out-of-bounds reads/writes inside the GEMM.
    debug_assert!(a.len() >= m * k, "A slice too short for {}x{}", m, k);
    debug_assert!(b.len() >= k * n, "B slice too short for {}x{}", k, n);
    debug_assert!(c.len() >= m * n, "C slice too short for {}x{}", m, n);

    // Create matrix views from slices (row-major layout; leading dimension = row width)
    let a_mat = MatRef::new(a.as_ptr(), m, k, k);
    let b_mat = MatRef::new(b.as_ptr(), k, n, n);
    let c_mat = MatMut::new(c.as_mut_ptr(), m, n, n);

    // GEMM: C = 1.0 * A * B + 0.0 * C
    gemm(1.0, a_mat, b_mat, 0.0, c_mat);
}
32
/// Direct BLAS GEMM using OxiBLAS for f64.
///
/// Computes `C = A * B` for row-major `f64` matrices, where `A` is `m x k`,
/// `B` is `k x n`, and `C` is `m x n`. The previous contents of `c` are
/// overwritten (`beta == 0.0`).
#[cfg(target_os = "macos")]
#[inline]
fn blas_dgemm(a: &[f64], b: &[f64], c: &mut [f64], m: usize, k: usize, n: usize) {
    use oxiblas_blas::level3::gemm;
    use oxiblas_matrix::{MatMut, MatRef};

    // The matrix views below are constructed from raw pointers; validate that
    // the slices actually cover the claimed dimensions so an inconsistent
    // caller cannot trigger out-of-bounds reads/writes inside the GEMM.
    debug_assert!(a.len() >= m * k, "A slice too short for {}x{}", m, k);
    debug_assert!(b.len() >= k * n, "B slice too short for {}x{}", k, n);
    debug_assert!(c.len() >= m * n, "C slice too short for {}x{}", m, n);

    // Create matrix views from slices (row-major layout; leading dimension = row width)
    let a_mat = MatRef::new(a.as_ptr(), m, k, k);
    let b_mat = MatRef::new(b.as_ptr(), k, n, n);
    let c_mat = MatMut::new(c.as_mut_ptr(), m, n, n);

    // GEMM: C = 1.0 * A * B + 0.0 * C
    gemm(1.0, a_mat, b_mat, 0.0, c_mat);
}
48
49/// Fallback for non-macOS: use scirs2-core SIMD GEMM for f32
50#[cfg(not(target_os = "macos"))]
51#[inline]
52fn blas_sgemm(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
53 let a_arr = Array2::from_shape_vec((m, k), a.to_vec())
54 .expect("matrix dimensions must match slice length");
55 let b_arr = Array2::from_shape_vec((k, n), b.to_vec())
56 .expect("matrix dimensions must match slice length");
57 let mut c_arr = Array2::from_shape_vec((m, n), c.to_vec())
58 .expect("matrix dimensions must match slice length");
59 f32::simd_gemm(1.0, &a_arr.view(), &b_arr.view(), 0.0, &mut c_arr);
60 c.copy_from_slice(c_arr.as_slice().expect("Array2 must have contiguous slice"));
61}
62
63/// Fallback for non-macOS: use scirs2-core SIMD GEMM for f64
64#[cfg(not(target_os = "macos"))]
65#[inline]
66fn blas_dgemm(a: &[f64], b: &[f64], c: &mut [f64], m: usize, k: usize, n: usize) {
67 let a_arr = Array2::from_shape_vec((m, k), a.to_vec())
68 .expect("matrix dimensions must match slice length");
69 let b_arr = Array2::from_shape_vec((k, n), b.to_vec())
70 .expect("matrix dimensions must match slice length");
71 let mut c_arr = Array2::from_shape_vec((m, n), c.to_vec())
72 .expect("matrix dimensions must match slice length");
73 f64::simd_gemm(1.0, &a_arr.view(), &b_arr.view(), 0.0, &mut c_arr);
74 c.copy_from_slice(c_arr.as_slice().expect("Array2 must have contiguous slice"));
75}
76
/// A linear transformation layer (fully connected layer).
///
/// The `Linear` layer applies a linear transformation to the incoming data:
/// `y = xW^T + b`. This is one of the most fundamental building blocks in
/// neural networks.
///
/// # Parameters
///
/// - `weight`: Learnable weight matrix of shape `[out_features, in_features]`
/// - `bias`: Optional learnable bias vector of shape `[out_features]`
///
/// # Input/Output Shapes
///
/// - Input: `[..., in_features]` - Can be 2D or 3D
/// - Output: `[..., out_features]` - Same number of dimensions as input
///
/// # Example
///
/// ```no_run
/// use trustformers_core::layers::Linear;
/// use trustformers_core::tensor::Tensor;
/// use trustformers_core::traits::Layer;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// // Create a linear layer: 768 → 3072
/// let linear = Linear::new(768, 3072, true);
///
/// // Apply to 2D input: [seq_len, in_features]
/// let input_2d = Tensor::randn(&[128, 768])?;
/// let output_2d = linear.forward(input_2d)?; // Shape: [128, 3072]
///
/// // Apply to 3D input: [batch, seq_len, in_features]
/// let input_3d = Tensor::randn(&[4, 128, 768])?;
/// let output_3d = linear.forward(input_3d)?; // Shape: [4, 128, 3072]
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct Linear {
    // Weight matrix stored as [out_features, in_features] (transposed during
    // the forward pass / before GPU caching).
    weight: Tensor,
    // Optional bias vector of shape [out_features].
    bias: Option<Tensor>,
    // Device used for dispatching the forward computation.
    device: Device,
    // Lazily-populated handle to the (transposed) weight buffer uploaded to
    // the Metal GPU. Shared via Arc so clones of this layer reuse one cache.
    #[cfg(all(target_os = "macos", feature = "metal"))]
    weight_buffer_id: std::sync::Arc<std::sync::RwLock<Option<crate::gpu_ops::BufferId>>>,
    // Lazily-populated handle to the (transposed) weight buffer uploaded to
    // the CUDA GPU. Same sharing semantics as the Metal cache above.
    #[cfg(feature = "cuda")]
    weight_buffer_id_cuda:
        std::sync::Arc<std::sync::RwLock<Option<crate::gpu_ops::cuda::BufferId>>>,
}
125
impl Linear {
    /// Creates a new linear layer.
    ///
    /// # Arguments
    ///
    /// * `in_features` - Size of each input sample
    /// * `out_features` - Size of each output sample
    /// * `bias` - Whether to include a learnable bias
    ///
    /// # Returns
    ///
    /// A new `Linear` layer with randomly initialized weights using a normal
    /// distribution, and bias initialized to zeros if enabled.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use trustformers_core::layers::Linear;
    ///
    /// // Linear layer without bias
    /// let linear1 = Linear::new(512, 1024, false);
    ///
    /// // Linear layer with bias
    /// let linear2 = Linear::new(512, 1024, true);
    /// ```
    pub fn new(in_features: usize, out_features: usize, bias: bool) -> Self {
        // Defaults to CPU; use `new_with_device` for GPU-backed layers.
        Self::new_with_device(in_features, out_features, bias, Device::CPU)
    }

    /// Creates a new linear layer with specified device.
    ///
    /// # Arguments
    ///
    /// * `in_features` - Size of each input sample
    /// * `out_features` - Size of each output sample
    /// * `bias` - Whether to include a learnable bias
    /// * `device` - Device to use for computations (CPU, Metal, CUDA, etc.)
    ///
    /// # Returns
    ///
    /// A new `Linear` layer with randomly initialized weights.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use trustformers_core::layers::Linear;
    /// use trustformers_core::Device;
    ///
    /// // Create a linear layer on Metal GPU
    /// let linear = Linear::new_with_device(768, 3072, true, Device::metal_if_available(0));
    /// ```
    pub fn new_with_device(
        in_features: usize,
        out_features: usize,
        bias: bool,
        device: Device,
    ) -> Self {
        // Weight is stored as [out_features, in_features] (PyTorch convention);
        // the forward pass transposes it to compute x @ W^T.
        let weight =
            Tensor::randn(&[out_features, in_features]).expect("Failed to create random tensor");
        let bias = if bias {
            Some(Tensor::zeros(&[out_features]).expect("Failed to create zero tensor"))
        } else {
            None
        };

        Self {
            weight,
            bias,
            device,
            // GPU weight caches start empty and are populated lazily on the
            // first forward pass (or via weights_to_gpu*).
            #[cfg(all(target_os = "macos", feature = "metal"))]
            weight_buffer_id: std::sync::Arc::new(std::sync::RwLock::new(None)),
            #[cfg(feature = "cuda")]
            weight_buffer_id_cuda: std::sync::Arc::new(std::sync::RwLock::new(None)),
        }
    }

    /// Returns the device this layer uses for computations.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Move this layer to a different device.
    ///
    /// # Arguments
    ///
    /// * `device` - Target device
    ///
    /// # Returns
    ///
    /// Self with updated device.
    ///
    /// # Note
    ///
    /// This only updates the device tag and invalidates cached GPU weight
    /// buffers; the weight tensor itself stays where it is and is re-uploaded
    /// lazily on the next forward pass.
    pub fn to_device(mut self, device: Device) -> Self {
        self.device = device;
        // Clear cached buffer when changing device
        #[cfg(all(target_os = "macos", feature = "metal"))]
        {
            if let Ok(mut buffer_id) = self.weight_buffer_id.write() {
                *buffer_id = None;
            }
        }
        #[cfg(feature = "cuda")]
        {
            if let Ok(mut buffer_id) = self.weight_buffer_id_cuda.write() {
                *buffer_id = None;
            }
        }
        self
    }

    /// Sets the weight matrix for this layer.
    ///
    /// # Arguments
    ///
    /// * `weight` - The new weight tensor, must have shape `[out_features, in_features]`
    ///
    /// # Returns
    ///
    /// `Ok(())` if successful.
    ///
    /// # Note
    ///
    /// This method is typically used when loading pretrained weights.
    /// NOTE(review): the shape is not validated here — a mismatched weight
    /// only surfaces as an error on the next forward pass.
    pub fn set_weight(&mut self, weight: Tensor) -> Result<()> {
        self.weight = weight;
        // Clear cached buffer when weights are updated
        #[cfg(all(target_os = "macos", feature = "metal"))]
        {
            if let Ok(mut buffer_id) = self.weight_buffer_id.write() {
                *buffer_id = None;
            }
        }
        #[cfg(feature = "cuda")]
        {
            if let Ok(mut buffer_id) = self.weight_buffer_id_cuda.write() {
                *buffer_id = None;
            }
        }
        Ok(())
    }

    /// Sets the bias vector for this layer.
    ///
    /// # Arguments
    ///
    /// * `bias` - The new bias tensor, must have shape `[out_features]`
    ///
    /// # Returns
    ///
    /// `Ok(())` if successful.
    ///
    /// # Note
    ///
    /// This will enable bias even if the layer was created without bias.
    pub fn set_bias(&mut self, bias: Tensor) -> Result<()> {
        self.bias = Some(bias);
        Ok(())
    }

    /// Returns a reference to the weight matrix.
    ///
    /// # Returns
    ///
    /// A reference to the weight tensor of shape `[out_features, in_features]`.
    pub fn weight(&self) -> &Tensor {
        &self.weight
    }

    /// Returns a reference to the bias vector if present.
    ///
    /// # Returns
    ///
    /// `Some(&bias)` if bias is enabled, `None` otherwise.
    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }

    /// Returns the total number of learnable parameters in this layer.
    ///
    /// # Returns
    ///
    /// The total parameter count including weights and bias (if present).
    pub fn parameter_count(&self) -> usize {
        let weight_count = self.weight.len();
        let bias_count = self.bias.as_ref().map_or(0, |b| b.len());
        weight_count + bias_count
    }

    /// Initialize persistent GPU weight buffer for Metal device.
    /// This is called automatically on first forward pass with Metal device.
    ///
    /// Uses a read-lock fast path followed by a double-checked write lock, so
    /// concurrent callers upload the weight at most once.
    #[cfg(all(target_os = "macos", feature = "metal"))]
    fn ensure_weight_buffer_cached(&self) -> Result<()> {
        use crate::gpu_ops::metal::get_metal_backend;

        // Check if already cached (read lock is cheaper)
        if let Ok(buffer_id) = self.weight_buffer_id.read() {
            if buffer_id.is_some() {
                return Ok(()); // Already cached
            }
        }

        // Only cache if using Metal
        if matches!(self.device, Device::Metal(_)) {
            // Get write lock to cache the buffer
            let mut buffer_id_guard = self.weight_buffer_id.write().map_err(|_| {
                TrustformersError::hardware_error(
                    "Failed to acquire write lock on buffer cache",
                    "ensure_weight_buffer_cached",
                )
            })?;

            // Double-check after acquiring write lock (another thread might have cached it)
            if buffer_id_guard.is_some() {
                return Ok(());
            }

            // Get weight data as f32 slice
            // CRITICAL FIX: Cache the TRANSPOSED weight, not the original!
            // The Metal shader expects weight in [in_features, out_features] layout
            // but self.weight is stored as [out_features, in_features]
            let weight_t = self.weight.transpose(0, 1)?;
            match &weight_t {
                Tensor::F32(arr) => {
                    if arr.ndim() != 2 {
                        return Err(TrustformersError::shape_error(
                            "Weight tensor must be 2D for Metal caching".to_string(),
                        ));
                    }

                    // Convert to contiguous vec for GPU upload
                    // Using as_standard_layout() ensures proper row-major order
                    let contiguous_arr = arr.as_standard_layout();
                    let weight_data: Vec<f32> = contiguous_arr.iter().copied().collect();

                    // Get Metal backend and cache the buffer
                    let backend = get_metal_backend()?;
                    let new_buffer_id = backend.create_persistent_buffer(&weight_data)?;
                    *buffer_id_guard = Some(new_buffer_id);
                },
                _ => {
                    return Err(TrustformersError::tensor_op_error(
                        "Only F32 tensors supported for Metal caching",
                        "ensure_weight_buffer_cached",
                    ));
                },
            }
        }
        Ok(())
    }

    /// Pre-cache layer weights on GPU for zero-transfer pipeline.
    /// This uploads weights to GPU memory in advance to avoid transfers during forward pass.
    #[cfg(all(target_os = "macos", feature = "metal"))]
    pub fn weights_to_gpu(&mut self, device: &crate::device::Device) -> Result<()> {
        use crate::device::Device;

        if !matches!(device, Device::Metal(_)) {
            return Ok(()); // Nothing to do for non-Metal devices
        }

        // Update device setting
        self.device = *device;

        // Pre-cache weight buffer on GPU (keeps weight as F32 CPU tensor)
        // The caching mechanism handles the GPU upload internally
        self.ensure_weight_buffer_cached()?;

        // Upload bias to GPU if present (for GPU bias addition kernel)
        if let Some(ref bias) = self.bias {
            self.bias = Some(bias.to_device_enum(device)?);
        }

        Ok(())
    }

    /// Initialize persistent GPU weight buffer for CUDA device.
    /// This is called automatically on first forward pass with CUDA device.
    ///
    /// Mirrors `ensure_weight_buffer_cached`: read-lock fast path, then a
    /// double-checked write lock so the upload happens at most once.
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    fn ensure_weight_buffer_cached_cuda(&self) -> Result<()> {
        use crate::gpu_ops::cuda::get_cuda_backend;

        // Check if already cached (read lock is cheaper)
        if let Ok(buffer_id) = self.weight_buffer_id_cuda.read() {
            if buffer_id.is_some() {
                return Ok(()); // Already cached
            }
        }

        // Only cache if using CUDA
        if matches!(self.device, Device::CUDA(_)) {
            // Get write lock to cache the buffer
            let mut buffer_id_guard = self.weight_buffer_id_cuda.write().map_err(|_| {
                TrustformersError::hardware_error(
                    "Failed to acquire write lock on CUDA buffer cache",
                    "ensure_weight_buffer_cached_cuda",
                )
            })?;

            // Double-check after acquiring write lock (another thread might have cached it)
            if buffer_id_guard.is_some() {
                return Ok(());
            }

            // Get weight data as f32 slice
            // CRITICAL FIX: Cache the TRANSPOSED weight, not the original!
            // The CUDA kernel expects weight in [in_features, out_features] layout
            // but self.weight is stored as [out_features, in_features]
            let weight_t = self.weight.transpose(0, 1)?;
            match &weight_t {
                Tensor::F32(arr) => {
                    if arr.ndim() != 2 {
                        return Err(TrustformersError::shape_error(
                            "Weight tensor must be 2D for CUDA caching".to_string(),
                        ));
                    }

                    // Convert to contiguous vec for GPU upload
                    // Using as_standard_layout() ensures proper row-major order
                    let contiguous_arr = arr.as_standard_layout();
                    let weight_data: Vec<f32> = contiguous_arr.iter().copied().collect();

                    // Get CUDA backend and cache the buffer
                    let device_id = if let Device::CUDA(id) = self.device {
                        id
                    } else {
                        0 // Default to device 0
                    };
                    let backend = get_cuda_backend(device_id)?;
                    let new_buffer_id = backend.create_persistent_buffer(&weight_data)?;
                    *buffer_id_guard = Some(new_buffer_id);
                },
                _ => {
                    return Err(TrustformersError::tensor_op_error(
                        "Only F32 tensors supported for CUDA caching",
                        "ensure_weight_buffer_cached_cuda",
                    ));
                },
            }
        }
        Ok(())
    }

    /// Pre-cache layer weights on GPU for zero-transfer pipeline (CUDA).
    /// This uploads weights to GPU memory in advance to avoid transfers during forward pass.
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    pub fn weights_to_gpu_cuda(&mut self, device: &crate::device::Device) -> Result<()> {
        use crate::device::Device;

        if !matches!(device, Device::CUDA(_)) {
            return Ok(()); // Nothing to do for non-CUDA devices
        }

        // Update device setting
        self.device = *device;

        // Pre-cache weight buffer on GPU (keeps weight as F32 CPU tensor)
        // The caching mechanism handles the GPU upload internally
        self.ensure_weight_buffer_cached_cuda()?;

        // Upload bias to GPU if present (for GPU bias addition kernel)
        if let Some(ref bias) = self.bias {
            self.bias = Some(bias.to_device_enum(device)?);
        }

        Ok(())
    }
}
491
impl Layer for Linear {
    type Input = Tensor;
    type Output = Tensor;

    /// Computes `y = x W^T + b`.
    ///
    /// Dispatch order:
    /// 1. GPU-resident Metal input (`Tensor::Metal`) — zero-copy GPU matmul.
    /// 2. GPU-resident CUDA input (`Tensor::CUDA`) — zero-copy GPU matmul.
    /// 3. CPU path for 2D/3D `f32`/`f64` inputs, with an optional
    ///    cached-weight Metal matmul and a BLAS GEMM fallback.
    ///
    /// Returns a shape error if the input's last dimension does not equal
    /// `in_features`, and a tensor-op error for unsupported ranks/dtypes.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        // =====================================================================
        // GPU-TO-GPU PATH: Tensor::Metal (ZERO CPU TRANSFERS!)
        // =====================================================================
        #[cfg(all(target_os = "macos", feature = "metal"))]
        if let Tensor::Metal(ref input_metal) = input {
            use crate::gpu_ops::metal::get_metal_backend;
            use crate::tensor::MetalTensorData;

            // Ensure weight buffer is cached on GPU
            self.ensure_weight_buffer_cached()?;

            // Get cached weight buffer ID
            let weight_buffer_id = {
                let buffer_id_guard = self.weight_buffer_id.read().map_err(|_| {
                    TrustformersError::hardware_error(
                        "Failed to acquire read lock on buffer cache",
                        "Linear::forward",
                    )
                })?;

                if let Some(id) = *buffer_id_guard {
                    id
                } else {
                    // Weight not cached (e.g. layer device is not Metal) -
                    // move the input to CPU and re-enter forward on that path.
                    let cpu_input = input.to_device_enum(&crate::device::Device::CPU)?;
                    return self.forward(cpu_input);
                }
            };

            // Get Metal backend
            let backend = get_metal_backend()?;

            // Extract input shape and calculate matmul dimensions
            let shape = &input_metal.shape;
            let weight_shape = self.weight.shape();
            let in_features = shape[shape.len() - 1];

            // Check shape compatibility (weight is [out_features, in_features])
            if in_features != weight_shape[1] {
                return Err(TrustformersError::shape_error(format!(
                    "Linear layer input features {} doesn't match weight shape {:?}",
                    in_features, weight_shape
                )));
            }

            // Flatten input to [batch, in_features] for matmul
            // Works for both 2D [seq_len, in_features] and 3D [batch, seq_len, in_features]
            let batch_dims: usize = shape[..shape.len() - 1].iter().product();
            let m = batch_dims; // number of rows in output
            let k = in_features; // shared dimension
            let n = self.weight.shape()[0]; // out_features

            // Perform GPU-to-GPU matmul using MPS (the cached buffer already
            // holds W^T, so this is a plain [m,k] x [k,n] product).
            // Try MPS first, fallback to naive kernel if MPS unavailable.
            let output_buffer_id = backend
                .matmul_gpu_to_gpu_mps(&input_metal.buffer_id, &weight_buffer_id, m, k, n)
                .or_else(|_e| {
                    // Fallback to naive Metal kernel if MPS fails
                    backend.matmul_gpu_to_gpu(&input_metal.buffer_id, &weight_buffer_id, m, k, n)
                })?;

            // Calculate output shape (preserve batch dimensions, change last dim)
            let mut output_shape = shape[..shape.len() - 1].to_vec();
            output_shape.push(n);

            // Create output Metal tensor
            let mut output = Tensor::Metal(MetalTensorData {
                buffer_id: output_buffer_id,
                shape: output_shape.clone(),
                dtype: input_metal.dtype,
            });

            // Handle bias if present
            if let Some(ref bias) = self.bias {
                // Try GPU-to-GPU bias addition if bias is on GPU
                match bias {
                    #[cfg(all(target_os = "macos", feature = "metal"))]
                    Tensor::Metal(bias_data) => {
                        // Both output and bias are Metal tensors - use GPU kernel!
                        if let Tensor::Metal(output_data) = &output {
                            let output_buffer_id = backend.add_bias_gpu_to_gpu(
                                &output_data.buffer_id,
                                &bias_data.buffer_id,
                                batch_dims,
                                n,
                            )?;

                            return Ok(Tensor::Metal(MetalTensorData {
                                buffer_id: output_buffer_id,
                                shape: output_shape.clone(),
                                dtype: output_data.dtype,
                            }));
                        }
                    },
                    _ => {
                        // Bias is not GPU-resident; fall through to CPU addition.
                    },
                }

                // Fallback: CPU bias addition (download, add, optionally re-upload)
                output = output.to_device_enum(&crate::device::Device::CPU)?;
                output = output.add(bias)?;

                // Convert back to Metal tensor if needed
                if matches!(self.device, crate::device::Device::Metal(_)) {
                    output = output.to_device_enum(&self.device)?;
                }
            } else {
                // No bias configured; the matmul result is final.
            }

            return Ok(output);
        }

        // =====================================================================
        // GPU-TO-GPU PATH: Tensor::CUDA (ZERO CPU TRANSFERS!)
        // =====================================================================
        #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
        if let Tensor::CUDA(ref input_cuda) = input {
            use crate::gpu_ops::cuda::get_cuda_backend;
            use crate::tensor::CudaTensorData;

            // Ensure weight buffer is cached on GPU
            self.ensure_weight_buffer_cached_cuda()?;

            // Get cached weight buffer ID
            let weight_buffer_id = {
                let buffer_id_guard = self.weight_buffer_id_cuda.read().map_err(|_| {
                    TrustformersError::hardware_error(
                        "Failed to acquire read lock on CUDA buffer cache",
                        "Linear::forward",
                    )
                })?;

                if let Some(id) = *buffer_id_guard {
                    id
                } else {
                    // Weight not cached - fallback to CPU
                    let cpu_input = input.to_device_enum(&crate::device::Device::CPU)?;
                    return self.forward(cpu_input);
                }
            };

            // Get CUDA backend
            let device_id = if let Device::CUDA(id) = self.device {
                id
            } else {
                0 // Default to device 0
            };
            let backend = get_cuda_backend(device_id)?;

            // Extract input shape and calculate matmul dimensions
            let shape = &input_cuda.shape;
            let weight_shape = self.weight.shape();
            let in_features = shape[shape.len() - 1];

            // Check shape compatibility (weight is [out_features, in_features])
            if in_features != weight_shape[1] {
                return Err(TrustformersError::shape_error(format!(
                    "Linear layer input features {} doesn't match weight shape {:?}",
                    in_features, weight_shape
                )));
            }

            // Flatten input to [batch, in_features] for matmul
            // Works for both 2D [seq_len, in_features] and 3D [batch, seq_len, in_features]
            let batch_dims: usize = shape[..shape.len() - 1].iter().product();
            let m = batch_dims; // number of rows in output
            let k = in_features; // shared dimension
            let n = self.weight.shape()[0]; // out_features

            // Perform GPU-to-GPU matmul (ZERO CPU TRANSFERS!)
            let output_buffer_id =
                backend.matmul_gpu_to_gpu(&input_cuda.buffer_id, &weight_buffer_id, m, k, n)?;

            // Calculate output shape (preserve batch dimensions, change last dim)
            let mut output_shape = shape[..shape.len() - 1].to_vec();
            output_shape.push(n);

            // Create output CUDA tensor
            let mut output = Tensor::CUDA(CudaTensorData {
                buffer_id: output_buffer_id,
                shape: output_shape.clone(),
                dtype: input_cuda.dtype,
            });

            // Handle bias if present
            if let Some(ref bias) = self.bias {
                // Try GPU-to-GPU bias addition if bias is on GPU
                match bias {
                    #[cfg(feature = "cuda")]
                    Tensor::CUDA(bias_data) => {
                        // Both output and bias are CUDA tensors - use GPU kernel!
                        if let Tensor::CUDA(output_data) = &output {
                            let output_buffer_id = backend.add_bias_gpu_to_gpu(
                                &output_data.buffer_id,
                                &bias_data.buffer_id,
                                batch_dims,
                                n,
                            )?;

                            return Ok(Tensor::CUDA(CudaTensorData {
                                buffer_id: output_buffer_id,
                                shape: output_shape.clone(),
                                dtype: output_data.dtype,
                            }));
                        }
                    },
                    _ => {},
                }

                // Fallback: CPU bias addition (download, add, optionally re-upload)
                output = output.to_device_enum(&crate::device::Device::CPU)?;
                output = output.add(bias)?;

                // Convert back to CUDA tensor if needed
                if matches!(self.device, crate::device::Device::CUDA(_)) {
                    output = output.to_device_enum(&self.device)?;
                }
            }

            return Ok(output);
        }

        // =====================================================================
        // CPU/F32 PATH (existing implementation)
        // =====================================================================
        // Handle different input shapes for matmul
        let input_shape = input.shape();
        // Transpose weight to [in_features, out_features] so output = input @ W^T.
        let weight_t = self.weight.transpose(0, 1)?;

        let output = if input_shape.len() == 2 {
            // Standard 2D input: [seq_len, hidden_size] x [hidden_size, out_features]

            // Try to use cached Metal buffer if available (ZERO-COPY OPTIMIZATION)
            #[cfg(all(target_os = "macos", feature = "metal"))]
            if matches!(self.device, Device::Metal(_)) {
                // Ensure buffer is cached
                self.ensure_weight_buffer_cached()?;

                // Try to use cached buffer
                if let Ok(buffer_id_guard) = self.weight_buffer_id.read() {
                    if let Some(buffer_id) = *buffer_id_guard {
                        // We have a cached buffer! Use it for ZERO-COPY matmul
                        use crate::gpu_ops::metal::get_metal_backend;

                        if let (Tensor::F32(inp), Tensor::F32(w_t)) = (&input, &weight_t) {
                            if inp.ndim() == 2 && w_t.ndim() == 2 {
                                let inp_shape = inp.shape();
                                let w_shape = w_t.shape();
                                let m = inp_shape[0];
                                let k = inp_shape[1];
                                let k2 = w_shape[0];
                                let n = w_shape[1];

                                if k == k2 {
                                    // Get Metal backend
                                    if let Ok(backend) = get_metal_backend() {
                                        // Convert input to contiguous
                                        let input_data: Vec<f32> = inp.iter().copied().collect();

                                        // Call matmul with CACHED weight buffer (no weight transfer!)
                                        if let Ok(result) = backend.matmul_with_cached_weight(
                                            &input_data,
                                            &buffer_id,
                                            m,
                                            k,
                                            n,
                                        ) {
                                            let result_arr =
                                                scirs2_core::ndarray::Array2::from_shape_vec(
                                                    (m, n),
                                                    result,
                                                )
                                                .map_err(|e| {
                                                    TrustformersError::shape_error(format!(
                                                        "Result reshape failed: {}",
                                                        e
                                                    ))
                                                })?;

                                            // Add bias if present
                                            let mut output = Tensor::F32(result_arr.into_dyn());
                                            if let Some(ref bias) = self.bias {
                                                output = output.add(bias)?;
                                            }
                                            return Ok(output);
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }

            // Fallback: Use standard dispatch (for non-Metal or if cached path failed)
            #[cfg(all(target_os = "macos", feature = "metal"))]
            {
                if self.device.is_gpu() {
                    dispatch_matmul(&input, &weight_t, &self.device)?
                } else {
                    input.matmul(&weight_t)?
                }
            }
            #[cfg(not(all(target_os = "macos", feature = "metal")))]
            {
                input.matmul(&weight_t)?
            }
        } else if input_shape.len() == 3 {
            // Batched 3D input: [batch, seq_len, hidden_size] x [hidden_size, out_features]
            // Handle manually since tensor.matmul doesn't support 3D x 2D
            match (&input, &weight_t) {
                (Tensor::F32(inp), Tensor::F32(w)) => {
                    let batch = input_shape[0];
                    let seq_len = input_shape[1];
                    let hidden = input_shape[2];
                    let out_features = w.shape()[1];

                    // Try to use cached Metal buffer for 3D case (CRITICAL OPTIMIZATION)
                    #[cfg(all(target_os = "macos", feature = "metal"))]
                    if matches!(self.device, Device::Metal(_)) {
                        // Ensure buffer is cached
                        self.ensure_weight_buffer_cached()?;

                        if let Ok(buffer_id_guard) = self.weight_buffer_id.read() {
                            if let Some(buffer_id) = *buffer_id_guard {
                                // Reshape 3D input to 2D for matmul
                                let m = batch * seq_len;
                                let k = hidden;
                                let n = out_features;

                                // Get Metal backend
                                use crate::gpu_ops::metal::get_metal_backend;
                                if let Ok(backend) = get_metal_backend() {
                                    // Convert input to contiguous 2D data
                                    let input_data: Vec<f32> = inp.iter().copied().collect();

                                    // Call matmul with CACHED weight buffer (no weight transfer!)
                                    if let Ok(result) = backend.matmul_with_cached_weight(
                                        &input_data,
                                        &buffer_id,
                                        m,
                                        k,
                                        n,
                                    ) {
                                        // Reshape result from 2D [m, n] back to 3D [batch, seq_len, out_features]
                                        let result_arr =
                                            scirs2_core::ndarray::Array2::from_shape_vec(
                                                (m, n),
                                                result,
                                            )
                                            .map_err(
                                                |e| {
                                                    TrustformersError::shape_error(format!(
                                                        "Result reshape failed: {}",
                                                        e
                                                    ))
                                                },
                                            )?;

                                        let result_3d = result_arr
                                            .into_shape_with_order(IxDyn(&[
                                                batch,
                                                seq_len,
                                                out_features,
                                            ]))
                                            .map_err(|e| {
                                                TrustformersError::shape_error(format!(
                                                    "3D reshape failed: {}",
                                                    e
                                                ))
                                            })?;

                                        // Add bias if present
                                        let mut output = Tensor::F32(result_3d);
                                        if let Some(ref bias) = self.bias {
                                            output = output.add(bias)?;
                                        }
                                        return Ok(output);
                                    }
                                }
                            }
                        }
                    }

                    // Fallback: CPU path with ndarray

                    // Ensure contiguous layout before reshaping input to 2D for dot product
                    let inp_contiguous = inp.to_owned();
                    let inp_2d = inp_contiguous
                        .into_shape_with_order([batch * seq_len, hidden])
                        .map_err(|e| {
                            TrustformersError::shape_error(format!(
                                "Failed to reshape input: {}",
                                e
                            ))
                        })?;

                    // Ensure contiguous layout for weight and convert to 2D for GEMM
                    let w_contiguous = w.to_owned();
                    let w_2d = w_contiguous.into_dimensionality::<Ix2>().map_err(|e| {
                        TrustformersError::shape_error(format!(
                            "Failed to convert weight to 2D: {}",
                            e
                        ))
                    })?;

                    // Use direct BLAS for maximum performance (Accelerate on macOS)
                    let m = inp_2d.nrows();
                    let n = w_2d.ncols();
                    let k = inp_2d.ncols();
                    // Small matrices: ndarray dot avoids BLAS call overhead.
                    const MIN_SIZE_FOR_BLAS: usize = 32;
                    let out_2d = if m < MIN_SIZE_FOR_BLAS
                        || n < MIN_SIZE_FOR_BLAS
                        || k < MIN_SIZE_FOR_BLAS
                    {
                        inp_2d.dot(&w_2d)
                    } else {
                        // Use direct BLAS GEMM for 10-50x speedup
                        let inp_slice = inp_2d.as_slice().unwrap_or(&[]);
                        let w_slice = w_2d.as_slice().unwrap_or(&[]);
                        if !inp_slice.is_empty() && !w_slice.is_empty() {
                            let mut result_vec = vec![0.0f32; m * n];
                            blas_sgemm(inp_slice, w_slice, &mut result_vec, m, k, n);
                            Array2::from_shape_vec((m, n), result_vec)
                                .expect("BLAS result shape must match m x n")
                        } else {
                            // Fallback to ndarray dot if slices aren't contiguous
                            inp_2d.dot(&w_2d)
                        }
                    };

                    // Reshape back to 3D
                    let out_3d = out_2d
                        .into_shape_with_order(IxDyn(&[batch, seq_len, out_features]))
                        .map_err(|e| {
                            TrustformersError::shape_error(format!(
                                "Failed to reshape output: {}",
                                e
                            ))
                        })?;

                    Tensor::F32(out_3d)
                },
                (Tensor::F64(inp), Tensor::F64(w)) => {
                    let batch = input_shape[0];
                    let seq_len = input_shape[1];
                    let hidden = input_shape[2];
                    let out_features = w.shape()[1];

                    // Ensure contiguous layout before reshaping
                    let inp_contiguous = inp.to_owned();
                    let inp_2d = inp_contiguous
                        .into_shape_with_order([batch * seq_len, hidden])
                        .map_err(|e| {
                            TrustformersError::shape_error(format!(
                                "Failed to reshape input: {}",
                                e
                            ))
                        })?;

                    // Ensure contiguous layout for weight and convert to 2D for GEMM
                    let w_contiguous = w.to_owned();
                    let w_2d = w_contiguous.into_dimensionality::<Ix2>().map_err(|e| {
                        TrustformersError::shape_error(format!(
                            "Failed to convert weight to 2D: {}",
                            e
                        ))
                    })?;

                    // Use direct BLAS for maximum performance (Accelerate on macOS)
                    let m = inp_2d.nrows();
                    let n = w_2d.ncols();
                    let k = inp_2d.ncols();
                    // Small matrices: ndarray dot avoids BLAS call overhead.
                    const MIN_SIZE_FOR_BLAS: usize = 32;
                    let out_2d = if m < MIN_SIZE_FOR_BLAS
                        || n < MIN_SIZE_FOR_BLAS
                        || k < MIN_SIZE_FOR_BLAS
                    {
                        inp_2d.dot(&w_2d)
                    } else {
                        // Use direct BLAS GEMM for 10-50x speedup
                        let inp_slice = inp_2d.as_slice().unwrap_or(&[]);
                        let w_slice = w_2d.as_slice().unwrap_or(&[]);
                        if !inp_slice.is_empty() && !w_slice.is_empty() {
                            let mut result_vec = vec![0.0f64; m * n];
                            blas_dgemm(inp_slice, w_slice, &mut result_vec, m, k, n);
                            Array2::from_shape_vec((m, n), result_vec)
                                .expect("BLAS result shape must match m x n")
                        } else {
                            // Fallback to ndarray dot if slices aren't contiguous
                            inp_2d.dot(&w_2d)
                        }
                    };

                    let out_3d = out_2d
                        .into_shape_with_order(IxDyn(&[batch, seq_len, out_features]))
                        .map_err(|e| {
                            TrustformersError::shape_error(format!(
                                "Failed to reshape output: {}",
                                e
                            ))
                        })?;

                    Tensor::F64(out_3d)
                },
                _ => {
                    return Err(TrustformersError::tensor_op_error(
                        "Unsupported tensor types for 3D linear layer",
                        "Linear::forward",
                    ))
                },
            }
        } else {
            // Only 2D and 3D inputs are supported on the CPU path.
            return Err(TrustformersError::tensor_op_error(
                &format!(
                    "Linear layer doesn't support input with {} dimensions",
                    input_shape.len()
                ),
                "Linear::forward",
            ));
        };

        if let Some(ref bias) = self.bias {
            // Handle broadcasting for bias addition
            match (&output, bias) {
                (Tensor::F32(out_arr), Tensor::F32(bias_arr)) => {
                    // Broadcast bias [out_features] across the leading dims of output
                    let result = out_arr + bias_arr;
                    Ok(Tensor::F32(result))
                },
                _ => output.add(bias),
            }
        } else {
            Ok(output)
        }
    }
}
1063}