// tenflowers_core/layout.rs
1use crate::device::Device;
// Data layout optimization for tensors
//
// This module provides utilities for handling different data layouts
// to optimize performance on different hardware (CPU vs GPU).
6use crate::{Result, Tensor, TensorError};
7use std::collections::HashMap;
8
/// Supported data layouts for tensors
///
/// The letters name axes in order: N = batch, C = channels, H/W = spatial
/// height/width, D = depth (3-D), L = sequence length (1-D).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataLayout {
    /// Channels first: [N, C, H, W] - Optimal for GPU/CUDA
    NCHW,
    /// Channels last: [N, H, W, C] - Optimal for CPU/NEON
    NHWC,
    /// Channels first for 3D: [N, C, D, H, W]
    NCDHW,
    /// Channels last for 3D: [N, D, H, W, C]
    NDHWC,
    /// Channels first for 1D: [N, C, L]
    NCL,
    /// Channels last for 1D: [N, L, C]
    NLC,
    /// Auto-detect optimal layout based on device and operation
    Auto,
}
27
28impl DataLayout {
29    /// Get the number of dimensions for this layout
30    pub fn ndim(&self) -> usize {
31        match self {
32            DataLayout::NCL | DataLayout::NLC => 3,
33            DataLayout::NCHW | DataLayout::NHWC => 4,
34            DataLayout::NCDHW | DataLayout::NDHWC => 5,
35            DataLayout::Auto => 0, // Variable
36        }
37    }
38
39    /// Get channel axis for this layout
40    pub fn channel_axis(&self) -> usize {
41        match self {
42            DataLayout::NCHW | DataLayout::NCDHW | DataLayout::NCL => 1,
43            DataLayout::NHWC => 3,
44            DataLayout::NDHWC | DataLayout::NLC => 4,
45            DataLayout::Auto => panic!("Cannot get channel axis for Auto layout"),
46        }
47    }
48
49    /// Check if this is a channels-first layout
50    pub fn is_channels_first(&self) -> bool {
51        matches!(self, DataLayout::NCHW | DataLayout::NCDHW | DataLayout::NCL)
52    }
53
54    /// Get the permutation indices to convert from this layout to target layout
55    pub fn to_permutation(&self, target: DataLayout) -> Option<Vec<usize>> {
56        match (self, target) {
57            (DataLayout::NCHW, DataLayout::NHWC) => Some(vec![0, 2, 3, 1]), // [N,C,H,W] -> [N,H,W,C]
58            (DataLayout::NHWC, DataLayout::NCHW) => Some(vec![0, 3, 1, 2]), // [N,H,W,C] -> [N,C,H,W]
59            (DataLayout::NCDHW, DataLayout::NDHWC) => Some(vec![0, 2, 3, 4, 1]), // [N,C,D,H,W] -> [N,D,H,W,C]
60            (DataLayout::NDHWC, DataLayout::NCDHW) => Some(vec![0, 4, 1, 2, 3]), // [N,D,H,W,C] -> [N,C,D,H,W]
61            (DataLayout::NCL, DataLayout::NLC) => Some(vec![0, 2, 1]), // [N,C,L] -> [N,L,C]
62            (DataLayout::NLC, DataLayout::NCL) => Some(vec![0, 2, 1]), // [N,L,C] -> [N,C,L]
63            _ if self == &target => None,                              // No conversion needed
64            _ => None,                                                 // Unsupported conversion
65        }
66    }
67}
68
/// Layout optimizer that chooses optimal data layouts based on device and operation type
pub struct LayoutOptimizer {
    /// Preferred layouts for different device types and operations,
    /// keyed by `(device, operation)`.
    layout_preferences: HashMap<(Device, OperationType), DataLayout>,
    /// Performance hints for layout conversions: relative cost of converting
    /// from the first layout to the second (see `conversion_cost`).
    conversion_costs: HashMap<(DataLayout, DataLayout), f32>,
}
76
/// Types of operations that may benefit from specific layouts
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OperationType {
    /// Spatial convolution (1-D/2-D/3-D).
    Convolution,
    /// Dense / matmul-style layers.
    FullyConnected,
    /// Max/average pooling.
    Pooling,
    /// Batch/layer/instance normalization.
    Normalization,
    /// Pointwise activation functions.
    Activation,
    /// Generic element-wise arithmetic.
    ElementWise,
    /// Axis reductions (sum, mean, …).
    Reduction,
}
88
89impl Default for LayoutOptimizer {
90    fn default() -> Self {
91        let mut layout_preferences = HashMap::new();
92        let mut conversion_costs = HashMap::new();
93
94        // GPU preferences (CUDA-style)
95        #[cfg(feature = "gpu")]
96        {
97            layout_preferences.insert(
98                (Device::Gpu(0), OperationType::Convolution),
99                DataLayout::NCHW,
100            );
101            layout_preferences.insert(
102                (Device::Gpu(0), OperationType::FullyConnected),
103                DataLayout::NCHW,
104            );
105            layout_preferences.insert((Device::Gpu(0), OperationType::Pooling), DataLayout::NCHW);
106            layout_preferences.insert(
107                (Device::Gpu(0), OperationType::Normalization),
108                DataLayout::NCHW,
109            );
110            layout_preferences.insert(
111                (Device::Gpu(0), OperationType::Activation),
112                DataLayout::NCHW,
113            );
114            layout_preferences.insert(
115                (Device::Gpu(0), OperationType::ElementWise),
116                DataLayout::NCHW,
117            );
118        }
119
120        // CPU preferences
121        layout_preferences.insert((Device::Cpu, OperationType::Convolution), DataLayout::NHWC);
122        layout_preferences.insert(
123            (Device::Cpu, OperationType::FullyConnected),
124            DataLayout::NHWC,
125        );
126        layout_preferences.insert((Device::Cpu, OperationType::Pooling), DataLayout::NHWC);
127        layout_preferences.insert(
128            (Device::Cpu, OperationType::Normalization),
129            DataLayout::NHWC,
130        );
131        layout_preferences.insert((Device::Cpu, OperationType::Activation), DataLayout::NHWC);
132        layout_preferences.insert((Device::Cpu, OperationType::ElementWise), DataLayout::NHWC);
133
134        // Conversion costs (relative units)
135        conversion_costs.insert((DataLayout::NCHW, DataLayout::NHWC), 1.0);
136        conversion_costs.insert((DataLayout::NHWC, DataLayout::NCHW), 1.0);
137        conversion_costs.insert((DataLayout::NCDHW, DataLayout::NDHWC), 1.5);
138        conversion_costs.insert((DataLayout::NDHWC, DataLayout::NCDHW), 1.5);
139        conversion_costs.insert((DataLayout::NCL, DataLayout::NLC), 0.5);
140        conversion_costs.insert((DataLayout::NLC, DataLayout::NCL), 0.5);
141
142        LayoutOptimizer {
143            layout_preferences,
144            conversion_costs,
145        }
146    }
147}
148
149impl LayoutOptimizer {
150    /// Get the preferred layout for a given device and operation type
151    pub fn preferred_layout(&self, device: &Device, op_type: OperationType) -> DataLayout {
152        self.layout_preferences
153            .get(&(*device, op_type))
154            .copied()
155            .unwrap_or(DataLayout::NCHW) // Default fallback
156    }
157
158    /// Get the cost of converting between two layouts
159    pub fn conversion_cost(&self, from: DataLayout, to: DataLayout) -> f32 {
160        if from == to {
161            return 0.0;
162        }
163        self.conversion_costs
164            .get(&(from, to))
165            .copied()
166            .unwrap_or(2.0) // Default high cost for unsupported conversions
167    }
168
169    /// Determine if a layout conversion is beneficial
170    pub fn should_convert(&self, from: DataLayout, to: DataLayout, operation_benefit: f32) -> bool {
171        let cost = self.conversion_cost(from, to);
172        operation_benefit > cost
173    }
174
175    /// Auto-select the best layout for a tensor given the target device and operation
176    pub fn auto_layout(
177        &self,
178        current_layout: DataLayout,
179        target_device: &Device,
180        op_type: OperationType,
181        operation_intensity: f32,
182    ) -> DataLayout {
183        let preferred = self.preferred_layout(target_device, op_type);
184
185        if self.should_convert(current_layout, preferred, operation_intensity) {
186            preferred
187        } else {
188            current_layout
189        }
190    }
191}
192
/// Permute tensor dimensions according to given axes order
///
/// `axes[i]` names the input axis that becomes output axis `i`; the output
/// shape is therefore `old_shape` reordered by `axes` (see `new_shape` below).
fn permute_tensor<T>(input: &Tensor<T>, axes: &[usize]) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + scirs2_core::num_traits::Zero
        + scirs2_core::num_traits::One
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    use crate::tensor::TensorStorage;

    match &input.storage {
        TensorStorage::Cpu(arr) => {
            // NOTE(review): assumes ndarray-style `permuted_axes` semantics
            // (axis `axes[i]` of the input becomes axis `i` of the view) —
            // confirm against the array type stored in TensorStorage::Cpu.
            let permuted = arr.clone().permuted_axes(axes);

            // Convert the permuted array to a vec and create new tensor
            let new_shape: Vec<usize> = {
                let old_shape = input.shape().dims();
                axes.iter().map(|&i| old_shape[i]).collect()
            };

            // Iterating the permuted view yields elements in its logical
            // order, so the collected buffer is contiguous for `new_shape`.
            let vec_data: Vec<T> = permuted.iter().cloned().collect();
            Tensor::from_vec(vec_data, &new_shape)
        }
        #[cfg(feature = "gpu")]
        TensorStorage::Gpu(gpu_buffer) => {
            // GPU tensor permutation using compute shader
            gpu_permute_tensor(gpu_buffer, input.shape().dims(), axes)
        }
    }
}
228
/// GPU tensor permutation using compute shader
///
/// Dispatches `execute_tensor_permutation` on the device buffer and wraps the
/// resulting buffer in a new `Tensor` whose shape is `input_shape` reordered
/// by `axes`. Errors from the GPU dispatch are propagated via `?`.
#[cfg(feature = "gpu")]
fn gpu_permute_tensor<T>(
    gpu_buffer: &crate::gpu::buffer::GpuBuffer<T>,
    input_shape: &[usize],
    axes: &[usize],
) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + scirs2_core::num_traits::Zero
        + scirs2_core::num_traits::One
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    use crate::gpu::ops::execute_tensor_permutation;

    // Calculate output shape: output axis i takes input axis axes[i].
    let output_shape: Vec<usize> = axes.iter().map(|&i| input_shape[i]).collect();
    let output_len = output_shape.iter().product();

    // Execute GPU permutation
    let result_buffer = execute_tensor_permutation(gpu_buffer, axes, input_shape, output_len)?;

    Ok(Tensor::from_gpu_buffer(
        result_buffer,
        crate::Shape::new(output_shape),
    ))
}
261
262/// Convert tensor between different data layouts
263pub fn convert_layout<T>(
264    input: &Tensor<T>,
265    from_layout: DataLayout,
266    to_layout: DataLayout,
267) -> Result<Tensor<T>>
268where
269    T: Clone
270        + Default
271        + scirs2_core::num_traits::Zero
272        + scirs2_core::num_traits::One
273        + Send
274        + Sync
275        + 'static
276        + bytemuck::Pod
277        + bytemuck::Zeroable,
278{
279    if from_layout == to_layout {
280        return Ok(input.clone());
281    }
282
283    if let Some(perm) = from_layout.to_permutation(to_layout) {
284        permute_tensor(input, &perm)
285    } else {
286        Err(TensorError::unsupported_operation_simple(format!(
287            "Layout conversion from {from_layout:?} to {to_layout:?} not supported"
288        )))
289    }
290}
291
292/// Infer the likely data layout from tensor shape and context
293pub fn infer_layout(shape: &[usize], ndim_hint: Option<usize>) -> DataLayout {
294    let ndim = ndim_hint.unwrap_or(shape.len());
295
296    match ndim {
297        3 => {
298            // For 3D tensors, assume NCL if channel dimension is small
299            if shape.len() >= 3 && shape[1] <= 512 && shape[1] < shape[2] {
300                DataLayout::NCL
301            } else {
302                DataLayout::NLC
303            }
304        }
305        4 => {
306            // For 4D tensors, assume NCHW if channel dimension is small
307            if shape.len() >= 4 && shape[1] <= 2048 && shape[1] < shape[2] && shape[1] < shape[3] {
308                DataLayout::NCHW
309            } else {
310                DataLayout::NHWC
311            }
312        }
313        5 => {
314            // For 5D tensors, assume NCDHW if channel dimension is small
315            if shape.len() >= 5 && shape[1] <= 2048 && shape[1] < shape[2] {
316                DataLayout::NCDHW
317            } else {
318                DataLayout::NDHWC
319            }
320        }
321        _ => DataLayout::Auto,
322    }
323}
324
/// Smart layout converter that minimizes conversions in a computation graph
pub struct LayoutPlan {
    /// Planned conversions in application order.
    conversions: Vec<(usize, DataLayout, DataLayout)>, // (tensor_id, from, to)
    /// Final layout chosen for each tensor id after planning.
    optimal_layouts: HashMap<usize, DataLayout>,
}
330
331impl LayoutPlan {
332    /// Create an optimal layout plan for a sequence of operations
333    pub fn optimize(
334        tensor_layouts: &[(usize, DataLayout)],
335        operations: &[(OperationType, Vec<usize>, Device)], // (op_type, input_tensor_ids, device)
336        optimizer: &LayoutOptimizer,
337    ) -> Self {
338        let mut optimal_layouts = HashMap::new();
339        let mut conversions = Vec::new();
340
341        // Initialize with current layouts
342        for &(tensor_id, layout) in tensor_layouts {
343            optimal_layouts.insert(tensor_id, layout);
344        }
345
346        // Process operations and determine optimal layouts
347        for (op_type, input_ids, device) in operations {
348            for &tensor_id in input_ids {
349                if let Some(&current_layout) = optimal_layouts.get(&tensor_id) {
350                    let preferred = optimizer.preferred_layout(device, *op_type);
351
352                    // Simple heuristic: convert if operation intensity is high
353                    let operation_intensity = match op_type {
354                        OperationType::Convolution => 3.0,
355                        OperationType::FullyConnected => 2.0,
356                        OperationType::Pooling => 1.5,
357                        _ => 1.0,
358                    };
359
360                    if optimizer.should_convert(current_layout, preferred, operation_intensity) {
361                        conversions.push((tensor_id, current_layout, preferred));
362                        optimal_layouts.insert(tensor_id, preferred);
363                    }
364                }
365            }
366        }
367
368        LayoutPlan {
369            conversions,
370            optimal_layouts,
371        }
372    }
373
374    /// Get the planned conversions
375    pub fn conversions(&self) -> &[(usize, DataLayout, DataLayout)] {
376        &self.conversions
377    }
378
379    /// Get the optimal layout for a tensor
380    pub fn optimal_layout(&self, tensor_id: usize) -> Option<DataLayout> {
381        self.optimal_layouts.get(&tensor_id).copied()
382    }
383}
384
/// Automatic layout optimization pass for computation graphs
pub struct AutoLayoutOptimizer {
    /// Cost model and per-device layout preferences.
    optimizer: LayoutOptimizer,
    /// Track tensor layouts throughout the computation, keyed by tensor id.
    tensor_layouts: HashMap<usize, DataLayout>,
    /// Track conversion costs accrued across all optimizations so far.
    total_conversion_cost: f32,
}
393
394impl AutoLayoutOptimizer {
395    /// Create a new automatic layout optimizer
396    pub fn new() -> Self {
397        Self {
398            optimizer: LayoutOptimizer::default(),
399            tensor_layouts: HashMap::new(),
400            total_conversion_cost: 0.0,
401        }
402    }
403
404    /// Register a tensor with its initial layout
405    pub fn register_tensor(&mut self, tensor_id: usize, layout: DataLayout) {
406        self.tensor_layouts.insert(tensor_id, layout);
407    }
408
409    /// Optimize layout for a specific operation
410    pub fn optimize_for_operation<T>(
411        &mut self,
412        tensors: &mut [&mut Tensor<T>],
413        tensor_ids: &[usize],
414        op_type: OperationType,
415        device: &Device,
416    ) -> Result<()>
417    where
418        T: Clone
419            + Default
420            + scirs2_core::num_traits::Zero
421            + scirs2_core::num_traits::One
422            + Send
423            + Sync
424            + 'static
425            + bytemuck::Pod
426            + bytemuck::Zeroable,
427    {
428        let preferred_layout = self.optimizer.preferred_layout(device, op_type);
429
430        // Determine operation intensity for cost-benefit analysis
431        let operation_intensity = match op_type {
432            OperationType::Convolution => 3.0,
433            OperationType::FullyConnected => 2.0,
434            OperationType::Pooling => 1.5,
435            OperationType::Normalization => 1.2,
436            OperationType::Activation => 0.8,
437            OperationType::ElementWise => 0.5,
438            OperationType::Reduction => 1.0,
439        };
440
441        // Check if conversion is beneficial for each tensor
442        for (tensor, &tensor_id) in tensors.iter_mut().zip(tensor_ids.iter()) {
443            if let Some(&current_layout) = self.tensor_layouts.get(&tensor_id) {
444                if current_layout != preferred_layout {
445                    let conversion_cost = self
446                        .optimizer
447                        .conversion_cost(current_layout, preferred_layout);
448
449                    if operation_intensity > conversion_cost {
450                        // Convert the tensor
451                        let converted = convert_layout(tensor, current_layout, preferred_layout)?;
452                        **tensor = converted;
453
454                        // Update tracking
455                        self.tensor_layouts.insert(tensor_id, preferred_layout);
456                        self.total_conversion_cost += conversion_cost;
457                    }
458                }
459            }
460        }
461
462        Ok(())
463    }
464
465    /// Get the current layout of a tensor
466    pub fn get_layout(&self, tensor_id: usize) -> Option<DataLayout> {
467        self.tensor_layouts.get(&tensor_id).copied()
468    }
469
470    /// Get the total conversion cost incurred
471    pub fn total_cost(&self) -> f32 {
472        self.total_conversion_cost
473    }
474
475    /// Reset the optimizer state
476    pub fn reset(&mut self) {
477        self.tensor_layouts.clear();
478        self.total_conversion_cost = 0.0;
479    }
480}
481
impl Default for AutoLayoutOptimizer {
    /// Equivalent to `AutoLayoutOptimizer::new()`.
    fn default() -> Self {
        Self::new()
    }
}
487
/// Layout optimization hint for specific operations
#[derive(Debug, Clone)]
pub struct LayoutHint {
    // Operation class this hint applies to.
    pub operation: OperationType,
    // Layout to prefer when the matching operation runs.
    pub preferred_layout: DataLayout,
    // Strength of the hint; compared against other hints (see
    // `LayoutContext::best_layout`, which uses an implicit 1.0 baseline).
    pub priority: f32,
}
495
496impl LayoutHint {
497    /// Create a new layout hint
498    pub fn new(operation: OperationType, preferred_layout: DataLayout, priority: f32) -> Self {
499        Self {
500            operation,
501            preferred_layout,
502            priority,
503        }
504    }
505
506    /// High priority hint for convolution operations
507    pub fn convolution_hint(layout: DataLayout) -> Self {
508        Self::new(OperationType::Convolution, layout, 3.0)
509    }
510
511    /// Medium priority hint for fully connected operations
512    pub fn dense_hint(layout: DataLayout) -> Self {
513        Self::new(OperationType::FullyConnected, layout, 2.0)
514    }
515
516    /// Low priority hint for element-wise operations
517    pub fn elementwise_hint(layout: DataLayout) -> Self {
518        Self::new(OperationType::ElementWise, layout, 0.5)
519    }
520}
521
/// Global layout optimization context
pub struct LayoutContext {
    /// Underlying per-tensor layout tracker and cost model.
    optimizer: AutoLayoutOptimizer,
    /// Hints for upcoming operations, consulted by `best_layout`.
    hints: Vec<LayoutHint>,
    /// Enable/disable automatic optimization in `best_layout`.
    auto_optimize: bool,
}
530
531impl LayoutContext {
532    /// Create a new layout context
533    pub fn new() -> Self {
534        Self {
535            optimizer: AutoLayoutOptimizer::new(),
536            hints: Vec::new(),
537            auto_optimize: true,
538        }
539    }
540
541    /// Add a layout hint for future operations
542    pub fn add_hint(&mut self, hint: LayoutHint) {
543        self.hints.push(hint);
544    }
545
546    /// Enable or disable automatic layout optimization
547    pub fn set_auto_optimize(&mut self, enable: bool) {
548        self.auto_optimize = enable;
549    }
550
551    /// Get the best layout for a tensor considering all hints
552    pub fn best_layout(
553        &self,
554        tensor_id: usize,
555        op_type: OperationType,
556        device: &Device,
557    ) -> DataLayout {
558        if !self.auto_optimize {
559            return self
560                .optimizer
561                .get_layout(tensor_id)
562                .unwrap_or(DataLayout::Auto);
563        }
564
565        // Consider hints first
566        let mut best_layout = self.optimizer.optimizer.preferred_layout(device, op_type);
567        let mut best_priority = 1.0;
568
569        for hint in &self.hints {
570            if hint.operation == op_type && hint.priority > best_priority {
571                best_layout = hint.preferred_layout;
572                best_priority = hint.priority;
573            }
574        }
575
576        best_layout
577    }
578
579    /// Clear all hints
580    pub fn clear_hints(&mut self) {
581        self.hints.clear();
582    }
583}
584
impl Default for LayoutContext {
    /// Equivalent to `LayoutContext::new()`.
    fn default() -> Self {
        Self::new()
    }
}
590
#[cfg(test)]
mod tests {
    use super::*;

    // NCHW <-> NHWC permutations must be inverses of each other.
    #[test]
    fn test_layout_permutations() {
        assert_eq!(
            DataLayout::NCHW.to_permutation(DataLayout::NHWC),
            Some(vec![0, 2, 3, 1])
        );
        assert_eq!(
            DataLayout::NHWC.to_permutation(DataLayout::NCHW),
            Some(vec![0, 3, 1, 2])
        );
    }

    // Shape-based heuristic: small second dim => channels-first.
    #[test]
    fn test_layout_inference() {
        // Typical image tensor: small channel dim
        assert_eq!(infer_layout(&[32, 3, 224, 224], None), DataLayout::NCHW);

        // Typical feature map: large channel dim
        assert_eq!(infer_layout(&[32, 224, 224, 256], None), DataLayout::NHWC);
    }

    // Default preference tables: GPU => NCHW, CPU => NHWC.
    #[test]
    fn test_layout_optimizer() {
        let optimizer = LayoutOptimizer::default();

        // GPU should prefer NCHW for convolution (if GPU feature is enabled)
        #[cfg(feature = "gpu")]
        assert_eq!(
            optimizer.preferred_layout(&Device::Gpu(0), OperationType::Convolution),
            DataLayout::NCHW
        );

        // CPU should prefer NHWC for convolution
        assert_eq!(
            optimizer.preferred_layout(&Device::Cpu, OperationType::Convolution),
            DataLayout::NHWC
        );
    }

    // Identity conversions are free; real conversions cost something.
    #[test]
    fn test_conversion_costs() {
        let optimizer = LayoutOptimizer::default();

        assert_eq!(
            optimizer.conversion_cost(DataLayout::NCHW, DataLayout::NCHW),
            0.0
        );
        assert!(optimizer.conversion_cost(DataLayout::NCHW, DataLayout::NHWC) > 0.0);
    }

    // Registration and initial state of the automatic optimizer.
    #[test]
    fn test_auto_layout_optimizer() {
        let mut auto_optimizer = AutoLayoutOptimizer::new();

        // Register a tensor with NCHW layout
        auto_optimizer.register_tensor(0, DataLayout::NCHW);

        // Check initial layout
        assert_eq!(auto_optimizer.get_layout(0), Some(DataLayout::NCHW));

        // Check that total cost starts at 0
        assert_eq!(auto_optimizer.total_cost(), 0.0);
    }

    // The named hint constructors set the documented priorities.
    #[test]
    fn test_layout_hints() {
        let hint = LayoutHint::convolution_hint(DataLayout::NCHW);
        assert_eq!(hint.operation, OperationType::Convolution);
        assert_eq!(hint.preferred_layout, DataLayout::NCHW);
        assert_eq!(hint.priority, 3.0);

        let hint = LayoutHint::dense_hint(DataLayout::NHWC);
        assert_eq!(hint.operation, OperationType::FullyConnected);
        assert_eq!(hint.preferred_layout, DataLayout::NHWC);
        assert_eq!(hint.priority, 2.0);
    }

    // Hints override the device preference; clearing them restores it.
    #[test]
    fn test_layout_context() {
        let mut context = LayoutContext::new();

        // Add a convolution hint
        context.add_hint(LayoutHint::convolution_hint(DataLayout::NCHW));

        // Check that it returns the hinted layout
        let best_layout = context.best_layout(0, OperationType::Convolution, &Device::Cpu);
        assert_eq!(best_layout, DataLayout::NCHW);

        // Clear hints
        context.clear_hints();

        // Should now return device-preferred layout
        let best_layout = context.best_layout(0, OperationType::Convolution, &Device::Cpu);
        assert_eq!(best_layout, DataLayout::NHWC); // CPU prefers NHWC
    }

    // Disabling auto-optimization falls back to the tracked (or Auto) layout.
    #[test]
    fn test_layout_context_auto_optimize() {
        let mut context = LayoutContext::new();

        // Disable auto optimization
        context.set_auto_optimize(false);

        // Should return Auto layout when disabled
        let best_layout = context.best_layout(0, OperationType::Convolution, &Device::Cpu);
        assert_eq!(best_layout, DataLayout::Auto);

        // Re-enable auto optimization
        context.set_auto_optimize(true);

        // Should return device-preferred layout when enabled
        let best_layout = context.best_layout(0, OperationType::Convolution, &Device::Cpu);
        assert_eq!(best_layout, DataLayout::NHWC);
    }
}