1use crate::device::Device;
2use crate::{Result, Tensor, TensorError};
7use std::collections::HashMap;
8
/// Memory layouts for batched tensors: N = batch, C = channels,
/// H/W = spatial height/width, D = depth, L = sequence length.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataLayout {
    /// 4-D, channels-first (batch, channels, height, width).
    NCHW,
    /// 4-D, channels-last (batch, height, width, channels).
    NHWC,
    /// 5-D, channels-first (batch, channels, depth, height, width).
    NCDHW,
    /// 5-D, channels-last (batch, depth, height, width, channels).
    NDHWC,
    /// 3-D, channels-first (batch, channels, length).
    NCL,
    /// 3-D, channels-last (batch, length, channels).
    NLC,
    /// Layout not yet determined; must be resolved before use.
    Auto,
}

impl DataLayout {
    /// Number of dimensions the layout describes (0 for `Auto`).
    pub fn ndim(&self) -> usize {
        match self {
            DataLayout::NCL | DataLayout::NLC => 3,
            DataLayout::NCHW | DataLayout::NHWC => 4,
            DataLayout::NCDHW | DataLayout::NDHWC => 5,
            DataLayout::Auto => 0,
        }
    }

    /// Index of the channel axis within the layout.
    ///
    /// # Panics
    /// Panics for `Auto`, which has no defined axes.
    pub fn channel_axis(&self) -> usize {
        match self {
            DataLayout::NCHW | DataLayout::NCDHW | DataLayout::NCL => 1,
            // BUG FIX: NLC was previously grouped with NDHWC and returned 4,
            // but NLC is 3-D (N, L, C), so its channel axis is 2.
            DataLayout::NLC => 2,
            DataLayout::NHWC => 3,
            DataLayout::NDHWC => 4,
            DataLayout::Auto => panic!("Cannot get channel axis for Auto layout"),
        }
    }

    /// True for the channels-first family (NCHW / NCDHW / NCL).
    pub fn is_channels_first(&self) -> bool {
        matches!(self, DataLayout::NCHW | DataLayout::NCDHW | DataLayout::NCL)
    }

    /// Axis permutation converting `self` into `target`, where output axis
    /// `i` takes source axis `perm[i]`. Returns `None` when the pair is
    /// unsupported or the layouts are identical (no permutation needed).
    pub fn to_permutation(&self, target: DataLayout) -> Option<Vec<usize>> {
        match (self, target) {
            (DataLayout::NCHW, DataLayout::NHWC) => Some(vec![0, 2, 3, 1]),
            (DataLayout::NHWC, DataLayout::NCHW) => Some(vec![0, 3, 1, 2]),
            (DataLayout::NCDHW, DataLayout::NDHWC) => Some(vec![0, 2, 3, 4, 1]),
            (DataLayout::NDHWC, DataLayout::NCDHW) => Some(vec![0, 4, 1, 2, 3]),
            // The 3-D swap is its own inverse.
            (DataLayout::NCL, DataLayout::NLC) | (DataLayout::NLC, DataLayout::NCL) => {
                Some(vec![0, 2, 1])
            }
            // Covers both the identity case and unsupported pairs.
            _ => None,
        }
    }
}
68
/// Chooses tensor memory layouts per (device, operation) pair and models the
/// relative cost of converting between layouts.
pub struct LayoutOptimizer {
    /// Preferred layout for each (device, operation) combination.
    layout_preferences: HashMap<(Device, OperationType), DataLayout>,
    /// Relative cost of converting between two layouts; pairs not listed
    /// here fall back to a default cost in `conversion_cost`.
    conversion_costs: HashMap<(DataLayout, DataLayout), f32>,
}
76
/// Coarse operation categories used to key per-device layout preferences.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OperationType {
    // Intensity weights used by the planners in this file:
    // Convolution 3.0, FullyConnected 2.0, Pooling 1.5, Normalization 1.2,
    // Activation 0.8, ElementWise 0.5, Reduction 1.0.
    Convolution,
    FullyConnected,
    Pooling,
    Normalization,
    Activation,
    ElementWise,
    Reduction,
}
88
89impl Default for LayoutOptimizer {
90 fn default() -> Self {
91 let mut layout_preferences = HashMap::new();
92 let mut conversion_costs = HashMap::new();
93
94 #[cfg(feature = "gpu")]
96 {
97 layout_preferences.insert(
98 (Device::Gpu(0), OperationType::Convolution),
99 DataLayout::NCHW,
100 );
101 layout_preferences.insert(
102 (Device::Gpu(0), OperationType::FullyConnected),
103 DataLayout::NCHW,
104 );
105 layout_preferences.insert((Device::Gpu(0), OperationType::Pooling), DataLayout::NCHW);
106 layout_preferences.insert(
107 (Device::Gpu(0), OperationType::Normalization),
108 DataLayout::NCHW,
109 );
110 layout_preferences.insert(
111 (Device::Gpu(0), OperationType::Activation),
112 DataLayout::NCHW,
113 );
114 layout_preferences.insert(
115 (Device::Gpu(0), OperationType::ElementWise),
116 DataLayout::NCHW,
117 );
118 }
119
120 layout_preferences.insert((Device::Cpu, OperationType::Convolution), DataLayout::NHWC);
122 layout_preferences.insert(
123 (Device::Cpu, OperationType::FullyConnected),
124 DataLayout::NHWC,
125 );
126 layout_preferences.insert((Device::Cpu, OperationType::Pooling), DataLayout::NHWC);
127 layout_preferences.insert(
128 (Device::Cpu, OperationType::Normalization),
129 DataLayout::NHWC,
130 );
131 layout_preferences.insert((Device::Cpu, OperationType::Activation), DataLayout::NHWC);
132 layout_preferences.insert((Device::Cpu, OperationType::ElementWise), DataLayout::NHWC);
133
134 conversion_costs.insert((DataLayout::NCHW, DataLayout::NHWC), 1.0);
136 conversion_costs.insert((DataLayout::NHWC, DataLayout::NCHW), 1.0);
137 conversion_costs.insert((DataLayout::NCDHW, DataLayout::NDHWC), 1.5);
138 conversion_costs.insert((DataLayout::NDHWC, DataLayout::NCDHW), 1.5);
139 conversion_costs.insert((DataLayout::NCL, DataLayout::NLC), 0.5);
140 conversion_costs.insert((DataLayout::NLC, DataLayout::NCL), 0.5);
141
142 LayoutOptimizer {
143 layout_preferences,
144 conversion_costs,
145 }
146 }
147}
148
149impl LayoutOptimizer {
150 pub fn preferred_layout(&self, device: &Device, op_type: OperationType) -> DataLayout {
152 self.layout_preferences
153 .get(&(*device, op_type))
154 .copied()
155 .unwrap_or(DataLayout::NCHW) }
157
158 pub fn conversion_cost(&self, from: DataLayout, to: DataLayout) -> f32 {
160 if from == to {
161 return 0.0;
162 }
163 self.conversion_costs
164 .get(&(from, to))
165 .copied()
166 .unwrap_or(2.0) }
168
169 pub fn should_convert(&self, from: DataLayout, to: DataLayout, operation_benefit: f32) -> bool {
171 let cost = self.conversion_cost(from, to);
172 operation_benefit > cost
173 }
174
175 pub fn auto_layout(
177 &self,
178 current_layout: DataLayout,
179 target_device: &Device,
180 op_type: OperationType,
181 operation_intensity: f32,
182 ) -> DataLayout {
183 let preferred = self.preferred_layout(target_device, op_type);
184
185 if self.should_convert(current_layout, preferred, operation_intensity) {
186 preferred
187 } else {
188 current_layout
189 }
190 }
191}
192
/// Permutes the axes of `input` according to `axes` (output axis `i` takes
/// input axis `axes[i]`), materializing a new contiguous tensor.
///
/// Dispatches on the tensor's storage: CPU data is permuted via `ndarray`
/// and re-collected; GPU data (feature = "gpu") goes through a kernel.
fn permute_tensor<T>(input: &Tensor<T>, axes: &[usize]) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + scirs2_core::num_traits::Zero
        + scirs2_core::num_traits::One
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    use crate::tensor::TensorStorage;

    match &input.storage {
        TensorStorage::Cpu(arr) => {
            // `permuted_axes` only reorders strides on the cloned array;
            // no element data moves yet.
            let permuted = arr.clone().permuted_axes(axes);

            // Output dimension i comes from input dimension axes[i].
            let new_shape: Vec<usize> = {
                let old_shape = input.shape().dims();
                axes.iter().map(|&i| old_shape[i]).collect()
            };

            // Iterating the permuted view yields elements in the new logical
            // order, so collecting produces contiguous data for `from_vec`.
            let vec_data: Vec<T> = permuted.iter().cloned().collect();
            Tensor::from_vec(vec_data, &new_shape)
        }
        #[cfg(feature = "gpu")]
        TensorStorage::Gpu(gpu_buffer) => {
            gpu_permute_tensor(gpu_buffer, input.shape().dims(), axes)
        }
    }
}
228
/// GPU path for axis permutation: runs a permutation kernel over
/// `gpu_buffer` and wraps the resulting buffer in a new `Tensor`.
/// Only compiled with the "gpu" feature.
#[cfg(feature = "gpu")]
fn gpu_permute_tensor<T>(
    gpu_buffer: &crate::gpu::buffer::GpuBuffer<T>,
    input_shape: &[usize],
    axes: &[usize],
) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + scirs2_core::num_traits::Zero
        + scirs2_core::num_traits::One
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    use crate::gpu::ops::execute_tensor_permutation;

    // Output dimension i comes from input dimension axes[i]; total element
    // count is unchanged by a permutation.
    let output_shape: Vec<usize> = axes.iter().map(|&i| input_shape[i]).collect();
    let output_len = output_shape.iter().product();

    let result_buffer = execute_tensor_permutation(gpu_buffer, axes, input_shape, output_len)?;

    Ok(Tensor::from_gpu_buffer(
        result_buffer,
        crate::Shape::new(output_shape),
    ))
}
261
262pub fn convert_layout<T>(
264 input: &Tensor<T>,
265 from_layout: DataLayout,
266 to_layout: DataLayout,
267) -> Result<Tensor<T>>
268where
269 T: Clone
270 + Default
271 + scirs2_core::num_traits::Zero
272 + scirs2_core::num_traits::One
273 + Send
274 + Sync
275 + 'static
276 + bytemuck::Pod
277 + bytemuck::Zeroable,
278{
279 if from_layout == to_layout {
280 return Ok(input.clone());
281 }
282
283 if let Some(perm) = from_layout.to_permutation(to_layout) {
284 permute_tensor(input, &perm)
285 } else {
286 Err(TensorError::unsupported_operation_simple(format!(
287 "Layout conversion from {from_layout:?} to {to_layout:?} not supported"
288 )))
289 }
290}
291
292pub fn infer_layout(shape: &[usize], ndim_hint: Option<usize>) -> DataLayout {
294 let ndim = ndim_hint.unwrap_or(shape.len());
295
296 match ndim {
297 3 => {
298 if shape.len() >= 3 && shape[1] <= 512 && shape[1] < shape[2] {
300 DataLayout::NCL
301 } else {
302 DataLayout::NLC
303 }
304 }
305 4 => {
306 if shape.len() >= 4 && shape[1] <= 2048 && shape[1] < shape[2] && shape[1] < shape[3] {
308 DataLayout::NCHW
309 } else {
310 DataLayout::NHWC
311 }
312 }
313 5 => {
314 if shape.len() >= 5 && shape[1] <= 2048 && shape[1] < shape[2] {
316 DataLayout::NCDHW
317 } else {
318 DataLayout::NDHWC
319 }
320 }
321 _ => DataLayout::Auto,
322 }
323}
324
325pub struct LayoutPlan {
327 conversions: Vec<(usize, DataLayout, DataLayout)>, optimal_layouts: HashMap<usize, DataLayout>,
329}
330
331impl LayoutPlan {
332 pub fn optimize(
334 tensor_layouts: &[(usize, DataLayout)],
335 operations: &[(OperationType, Vec<usize>, Device)], optimizer: &LayoutOptimizer,
337 ) -> Self {
338 let mut optimal_layouts = HashMap::new();
339 let mut conversions = Vec::new();
340
341 for &(tensor_id, layout) in tensor_layouts {
343 optimal_layouts.insert(tensor_id, layout);
344 }
345
346 for (op_type, input_ids, device) in operations {
348 for &tensor_id in input_ids {
349 if let Some(¤t_layout) = optimal_layouts.get(&tensor_id) {
350 let preferred = optimizer.preferred_layout(device, *op_type);
351
352 let operation_intensity = match op_type {
354 OperationType::Convolution => 3.0,
355 OperationType::FullyConnected => 2.0,
356 OperationType::Pooling => 1.5,
357 _ => 1.0,
358 };
359
360 if optimizer.should_convert(current_layout, preferred, operation_intensity) {
361 conversions.push((tensor_id, current_layout, preferred));
362 optimal_layouts.insert(tensor_id, preferred);
363 }
364 }
365 }
366 }
367
368 LayoutPlan {
369 conversions,
370 optimal_layouts,
371 }
372 }
373
374 pub fn conversions(&self) -> &[(usize, DataLayout, DataLayout)] {
376 &self.conversions
377 }
378
379 pub fn optimal_layout(&self, tensor_id: usize) -> Option<DataLayout> {
381 self.optimal_layouts.get(&tensor_id).copied()
382 }
383}
384
/// Stateful helper that tracks per-tensor layouts, converts tensors on
/// demand before operations, and accumulates the total conversion cost.
pub struct AutoLayoutOptimizer {
    /// Cost model and per-device layout preferences.
    optimizer: LayoutOptimizer,
    /// Last known layout of each registered tensor id.
    tensor_layouts: HashMap<usize, DataLayout>,
    /// Sum of the costs of all conversions performed so far.
    total_conversion_cost: f32,
}
393
394impl AutoLayoutOptimizer {
395 pub fn new() -> Self {
397 Self {
398 optimizer: LayoutOptimizer::default(),
399 tensor_layouts: HashMap::new(),
400 total_conversion_cost: 0.0,
401 }
402 }
403
404 pub fn register_tensor(&mut self, tensor_id: usize, layout: DataLayout) {
406 self.tensor_layouts.insert(tensor_id, layout);
407 }
408
409 pub fn optimize_for_operation<T>(
411 &mut self,
412 tensors: &mut [&mut Tensor<T>],
413 tensor_ids: &[usize],
414 op_type: OperationType,
415 device: &Device,
416 ) -> Result<()>
417 where
418 T: Clone
419 + Default
420 + scirs2_core::num_traits::Zero
421 + scirs2_core::num_traits::One
422 + Send
423 + Sync
424 + 'static
425 + bytemuck::Pod
426 + bytemuck::Zeroable,
427 {
428 let preferred_layout = self.optimizer.preferred_layout(device, op_type);
429
430 let operation_intensity = match op_type {
432 OperationType::Convolution => 3.0,
433 OperationType::FullyConnected => 2.0,
434 OperationType::Pooling => 1.5,
435 OperationType::Normalization => 1.2,
436 OperationType::Activation => 0.8,
437 OperationType::ElementWise => 0.5,
438 OperationType::Reduction => 1.0,
439 };
440
441 for (tensor, &tensor_id) in tensors.iter_mut().zip(tensor_ids.iter()) {
443 if let Some(¤t_layout) = self.tensor_layouts.get(&tensor_id) {
444 if current_layout != preferred_layout {
445 let conversion_cost = self
446 .optimizer
447 .conversion_cost(current_layout, preferred_layout);
448
449 if operation_intensity > conversion_cost {
450 let converted = convert_layout(tensor, current_layout, preferred_layout)?;
452 **tensor = converted;
453
454 self.tensor_layouts.insert(tensor_id, preferred_layout);
456 self.total_conversion_cost += conversion_cost;
457 }
458 }
459 }
460 }
461
462 Ok(())
463 }
464
465 pub fn get_layout(&self, tensor_id: usize) -> Option<DataLayout> {
467 self.tensor_layouts.get(&tensor_id).copied()
468 }
469
470 pub fn total_cost(&self) -> f32 {
472 self.total_conversion_cost
473 }
474
475 pub fn reset(&mut self) {
477 self.tensor_layouts.clear();
478 self.total_conversion_cost = 0.0;
479 }
480}
481
482impl Default for AutoLayoutOptimizer {
483 fn default() -> Self {
484 Self::new()
485 }
486}
487
/// A caller-supplied preference for the layout of a particular operation type.
#[derive(Debug, Clone)]
pub struct LayoutHint {
    /// Operation the hint applies to.
    pub operation: OperationType,
    /// Layout to prefer for that operation.
    pub preferred_layout: DataLayout,
    /// Hint strength; hints with priority <= 1.0 never override the device
    /// default (see the strict comparison in `LayoutContext::best_layout`).
    pub priority: f32,
}
495
496impl LayoutHint {
497 pub fn new(operation: OperationType, preferred_layout: DataLayout, priority: f32) -> Self {
499 Self {
500 operation,
501 preferred_layout,
502 priority,
503 }
504 }
505
506 pub fn convolution_hint(layout: DataLayout) -> Self {
508 Self::new(OperationType::Convolution, layout, 3.0)
509 }
510
511 pub fn dense_hint(layout: DataLayout) -> Self {
513 Self::new(OperationType::FullyConnected, layout, 2.0)
514 }
515
516 pub fn elementwise_hint(layout: DataLayout) -> Self {
518 Self::new(OperationType::ElementWise, layout, 0.5)
519 }
520}
521
/// Bundles an `AutoLayoutOptimizer` with caller hints and a switch that
/// enables or disables automatic layout selection.
pub struct LayoutContext {
    optimizer: AutoLayoutOptimizer,
    /// Caller-supplied hints, scanned in insertion order.
    hints: Vec<LayoutHint>,
    /// When false, `best_layout` only reports a tensor's recorded layout.
    auto_optimize: bool,
}
530
531impl LayoutContext {
532 pub fn new() -> Self {
534 Self {
535 optimizer: AutoLayoutOptimizer::new(),
536 hints: Vec::new(),
537 auto_optimize: true,
538 }
539 }
540
541 pub fn add_hint(&mut self, hint: LayoutHint) {
543 self.hints.push(hint);
544 }
545
546 pub fn set_auto_optimize(&mut self, enable: bool) {
548 self.auto_optimize = enable;
549 }
550
551 pub fn best_layout(
553 &self,
554 tensor_id: usize,
555 op_type: OperationType,
556 device: &Device,
557 ) -> DataLayout {
558 if !self.auto_optimize {
559 return self
560 .optimizer
561 .get_layout(tensor_id)
562 .unwrap_or(DataLayout::Auto);
563 }
564
565 let mut best_layout = self.optimizer.optimizer.preferred_layout(device, op_type);
567 let mut best_priority = 1.0;
568
569 for hint in &self.hints {
570 if hint.operation == op_type && hint.priority > best_priority {
571 best_layout = hint.preferred_layout;
572 best_priority = hint.priority;
573 }
574 }
575
576 best_layout
577 }
578
579 pub fn clear_hints(&mut self) {
581 self.hints.clear();
582 }
583}
584
585impl Default for LayoutContext {
586 fn default() -> Self {
587 Self::new()
588 }
589}
590
#[cfg(test)]
mod tests {
    use super::*;

    // The 4-D layout permutations are inverses of each other.
    #[test]
    fn test_layout_permutations() {
        assert_eq!(
            DataLayout::NCHW.to_permutation(DataLayout::NHWC),
            Some(vec![0, 2, 3, 1])
        );
        assert_eq!(
            DataLayout::NHWC.to_permutation(DataLayout::NCHW),
            Some(vec![0, 3, 1, 2])
        );
    }

    // Shape heuristics: a small dim-1 (3) reads as channels-first; a large
    // dim-1 (224) with a small trailing dim reads as channels-last.
    #[test]
    fn test_layout_inference() {
        assert_eq!(infer_layout(&[32, 3, 224, 224], None), DataLayout::NCHW);

        assert_eq!(infer_layout(&[32, 224, 224, 256], None), DataLayout::NHWC);
    }

    // Default preferences: GPU is channels-first, CPU is channels-last.
    #[test]
    fn test_layout_optimizer() {
        let optimizer = LayoutOptimizer::default();

        #[cfg(feature = "gpu")]
        assert_eq!(
            optimizer.preferred_layout(&Device::Gpu(0), OperationType::Convolution),
            DataLayout::NCHW
        );

        assert_eq!(
            optimizer.preferred_layout(&Device::Cpu, OperationType::Convolution),
            DataLayout::NHWC
        );
    }

    // Identity conversion is free; registered pairs have a positive cost.
    #[test]
    fn test_conversion_costs() {
        let optimizer = LayoutOptimizer::default();

        assert_eq!(
            optimizer.conversion_cost(DataLayout::NCHW, DataLayout::NCHW),
            0.0
        );
        assert!(optimizer.conversion_cost(DataLayout::NCHW, DataLayout::NHWC) > 0.0);
    }

    // Registering a tensor records its layout without incurring any cost.
    #[test]
    fn test_auto_layout_optimizer() {
        let mut auto_optimizer = AutoLayoutOptimizer::new();

        auto_optimizer.register_tensor(0, DataLayout::NCHW);

        assert_eq!(auto_optimizer.get_layout(0), Some(DataLayout::NCHW));

        assert_eq!(auto_optimizer.total_cost(), 0.0);
    }

    // Hint constructors carry the expected operation, layout, and priority.
    #[test]
    fn test_layout_hints() {
        let hint = LayoutHint::convolution_hint(DataLayout::NCHW);
        assert_eq!(hint.operation, OperationType::Convolution);
        assert_eq!(hint.preferred_layout, DataLayout::NCHW);
        assert_eq!(hint.priority, 3.0);

        let hint = LayoutHint::dense_hint(DataLayout::NHWC);
        assert_eq!(hint.operation, OperationType::FullyConnected);
        assert_eq!(hint.preferred_layout, DataLayout::NHWC);
        assert_eq!(hint.priority, 2.0);
    }

    // A high-priority hint overrides the CPU default (NHWC); once cleared,
    // the device preference wins again.
    #[test]
    fn test_layout_context() {
        let mut context = LayoutContext::new();

        context.add_hint(LayoutHint::convolution_hint(DataLayout::NCHW));

        let best_layout = context.best_layout(0, OperationType::Convolution, &Device::Cpu);
        assert_eq!(best_layout, DataLayout::NCHW);

        context.clear_hints();

        let best_layout = context.best_layout(0, OperationType::Convolution, &Device::Cpu);
        assert_eq!(best_layout, DataLayout::NHWC);
    }

    // With auto-optimize off and no registered layout, `Auto` is returned;
    // re-enabling restores the CPU device preference.
    #[test]
    fn test_layout_context_auto_optimize() {
        let mut context = LayoutContext::new();

        context.set_auto_optimize(false);

        let best_layout = context.best_layout(0, OperationType::Convolution, &Device::Cpu);
        assert_eq!(best_layout, DataLayout::Auto);

        context.set_auto_optimize(true);

        let best_layout = context.best_layout(0, OperationType::Convolution, &Device::Cpu);
        assert_eq!(best_layout, DataLayout::NHWC);
    }
}