// torsh_jit/codegen.rs
1//! Code generation backend for JIT compilation
2
3use crate::graph::{ComputationGraph, Node, NodeId, Operation};
4use crate::{CompiledKernel, JitError, JitResult, KernelMetadata, TensorDesc};
5use torsh_core::DeviceType;
6
7#[cfg(feature = "cranelift-backend")]
8use cranelift::prelude::*;
9
/// Code generator for different backends
///
/// Dispatches kernel generation to a device-specific backend: a Cranelift
/// JIT backend for CPU (when the `cranelift-backend` feature is enabled),
/// with an interpreter-based fallback for everything else.
pub struct CodeGenerator {
    // Target device this generator emits kernels for.
    device: DeviceType,
    // Cranelift CPU backend; `Some` only when `device` is `DeviceType::Cpu`.
    #[cfg(feature = "cranelift-backend")]
    cranelift: Option<CraneliftBackend>,
}
16
17impl CodeGenerator {
18    /// Create a new code generator for the target device
19    pub fn new(device: DeviceType) -> Self {
20        Self {
21            device,
22            #[cfg(feature = "cranelift-backend")]
23            cranelift: match device {
24                DeviceType::Cpu => Some(CraneliftBackend::new()),
25                _ => None,
26            },
27        }
28    }
29
30    /// Generate code for the computation graph
31    pub fn generate(&self, graph: &ComputationGraph) -> JitResult<Vec<CompiledKernel>> {
32        match self.device {
33            DeviceType::Cpu => self.generate_cpu(graph),
34            DeviceType::Cuda(_) => self.generate_cuda(graph),
35            DeviceType::Metal(_) => self.generate_metal(graph),
36            _ => Err(JitError::UnsupportedOp(format!(
37                "Code generation not supported for {:?}",
38                self.device
39            ))),
40        }
41    }
42
43    /// Generate CPU code
44    fn generate_cpu(&self, graph: &ComputationGraph) -> JitResult<Vec<CompiledKernel>> {
45        #[cfg(feature = "cranelift-backend")]
46        if let Some(ref backend) = self.cranelift {
47            return backend.generate(graph);
48        }
49
50        // Fallback to interpreter mode
51        self.generate_interpreter(graph)
52    }
53
54    /// Generate CUDA code
55    ///
56    /// Future implementation will support:
57    /// - PTX code generation for NVIDIA GPUs
58    /// - Kernel fusion for memory bandwidth optimization
59    /// - Tensor core utilization for matrix operations
60    /// - Automatic memory coalescing
61    /// - Multi-stream execution support
62    fn generate_cuda(&self, graph: &ComputationGraph) -> JitResult<Vec<CompiledKernel>> {
63        let node_count = graph.nodes().count();
64        let operation_types: Vec<_> = graph
65            .nodes()
66            .map(|(_, node)| format!("{:?}", node.op))
67            .collect();
68
69        Err(JitError::UnsupportedOp(format!(
70            "CUDA code generation not yet implemented. \
71             Graph contains {} nodes with operations: {}. \
72             To enable CUDA support: \
73             1. Install CUDA toolkit (>=11.0) \
74             2. Enable 'cuda' feature flag \
75             3. Set CUDA_PATH environment variable \
76             \nFallback: Use CPU backend or interpreter mode.",
77            node_count,
78            operation_types.join(", ")
79        )))
80    }
81
82    /// Generate Metal code
83    ///
84    /// Future implementation will support:
85    /// - Metal Shading Language (MSL) generation
86    /// - Metal Performance Shaders (MPS) integration
87    /// - Unified memory architecture optimization
88    /// - Apple Neural Engine (ANE) acceleration
89    /// - Multi-GPU support for Mac Pro
90    fn generate_metal(&self, graph: &ComputationGraph) -> JitResult<Vec<CompiledKernel>> {
91        let node_count = graph.nodes().count();
92        let has_matmul = graph
93            .nodes()
94            .any(|(_, node)| matches!(node.op, Operation::MatMul));
95        let has_conv = graph
96            .nodes()
97            .any(|(_, node)| matches!(node.op, Operation::Conv2d { .. }));
98
99        let recommendations = if has_matmul || has_conv {
100            "Consider using Metal Performance Shaders (MPS) backend for matrix/convolution operations."
101        } else {
102            "For element-wise operations, CPU backend may provide sufficient performance."
103        };
104
105        Err(JitError::UnsupportedOp(format!(
106            "Metal code generation not yet implemented. \
107             Graph contains {} nodes. \
108             Detected: {} \
109             To enable Metal support: \
110             1. Ensure macOS 10.15+ or iOS 13+ \
111             2. Enable 'metal' feature flag \
112             3. Install Metal developer tools \
113             \n{} \
114             \nFallback: Use CPU backend or interpreter mode.",
115            node_count,
116            if has_matmul {
117                "matrix multiplication"
118            } else if has_conv {
119                "convolutions"
120            } else {
121                "element-wise ops"
122            },
123            recommendations
124        )))
125    }
126
127    /// Generate code from IR module  
128    pub fn generate_from_ir(
129        &self,
130        ir_module: &crate::ir::IrModule,
131    ) -> JitResult<Vec<CompiledKernel>> {
132        // For now, convert IR back to graph-like representation and use existing logic
133        // In a real implementation, this would generate code directly from IR
134        self.generate_interpreter_from_ir(ir_module)
135    }
136
137    /// Generate interpreter kernels from IR
138    pub fn generate_interpreter_from_ir(
139        &self,
140        ir_module: &crate::ir::IrModule,
141    ) -> JitResult<Vec<CompiledKernel>> {
142        let mut kernels = Vec::new();
143
144        // For each basic block, create a kernel
145        for (block_id, block) in &ir_module.blocks {
146            let kernel_id = format!("ir_kernel_{}", block_id);
147
148            // Create simple metadata
149            let metadata = KernelMetadata {
150                inputs: ir_module
151                    .inputs
152                    .iter()
153                    .filter_map(|&input| self.ir_value_to_tensor_desc(ir_module, input))
154                    .collect(),
155                outputs: ir_module
156                    .outputs
157                    .iter()
158                    .filter_map(|&output| self.ir_value_to_tensor_desc(ir_module, output))
159                    .collect(),
160                shared_memory: 0,
161                block_size: (1, 1, 1),
162                grid_size: (1, 1, 1),
163            };
164
165            // Encode the instructions
166            let mut code = Vec::new();
167            for instruction in &block.instructions {
168                let opcode = self.encode_ir_instruction(instruction)?;
169                code.push(opcode);
170            }
171
172            let kernel = CompiledKernel {
173                id: kernel_id,
174                source_nodes: Vec::new(), // Would need mapping from IR to original nodes
175                code,
176                metadata,
177            };
178
179            kernels.push(kernel);
180        }
181
182        Ok(kernels)
183    }
184
185    /// Convert IR value to tensor descriptor
186    fn ir_value_to_tensor_desc(
187        &self,
188        ir_module: &crate::ir::IrModule,
189        ir_value: crate::ir::IrValue,
190    ) -> Option<TensorDesc> {
191        if let Some(value_def) = ir_module.get_value(ir_value) {
192            if let Some(type_def) = ir_module.get_type(value_def.ty) {
193                match &type_def.kind {
194                    crate::ir::TypeKind::Tensor { shape, .. } => {
195                        Some(TensorDesc {
196                            dtype: torsh_core::DType::F32, // Simplified
197                            shape: shape.clone(),
198                            strides: self.compute_strides(shape),
199                            offset: 0,
200                        })
201                    }
202                    _ => None,
203                }
204            } else {
205                None
206            }
207        } else {
208            None
209        }
210    }
211
212    /// Encode an IR instruction
213    fn encode_ir_instruction(&self, instruction: &crate::ir::Instruction) -> JitResult<u8> {
214        let opcode = match &instruction.opcode {
215            crate::ir::IrOpcode::Add => 1,
216            crate::ir::IrOpcode::Sub => 2,
217            crate::ir::IrOpcode::Mul => 3,
218            crate::ir::IrOpcode::Div => 4,
219            crate::ir::IrOpcode::Neg => 5,
220            crate::ir::IrOpcode::Abs => 6,
221            crate::ir::IrOpcode::Exp => 7,
222            crate::ir::IrOpcode::Log => 8,
223            crate::ir::IrOpcode::Sqrt => 9,
224            crate::ir::IrOpcode::Sin => 10,
225            crate::ir::IrOpcode::Cos => 11,
226            crate::ir::IrOpcode::Tanh => 12,
227            crate::ir::IrOpcode::Sigmoid => 13,
228            crate::ir::IrOpcode::Relu => 14,
229            crate::ir::IrOpcode::Gelu => 15,
230            crate::ir::IrOpcode::MatMul => 16,
231            crate::ir::IrOpcode::Conv2d => 17,
232            crate::ir::IrOpcode::Pool2d => 18,
233            crate::ir::IrOpcode::Reshape => 19,
234            crate::ir::IrOpcode::Transpose => 20,
235            crate::ir::IrOpcode::Sum => 21,
236            crate::ir::IrOpcode::Mean => 22,
237            crate::ir::IrOpcode::Max => 23,
238            crate::ir::IrOpcode::Min => 24,
239            crate::ir::IrOpcode::Load => 25,
240            crate::ir::IrOpcode::Store => 26,
241            crate::ir::IrOpcode::Const => 27,
242            _ => {
243                return Err(JitError::UnsupportedOp(format!(
244                    "IR opcode {:?} not supported in interpreter",
245                    instruction.opcode
246                )))
247            }
248        };
249
250        Ok(opcode)
251    }
252
253    /// Generate interpreter-based kernels (fallback)
254    pub fn generate_interpreter(&self, graph: &ComputationGraph) -> JitResult<Vec<CompiledKernel>> {
255        let mut kernels = Vec::new();
256
257        // Get topological order
258        let order = graph
259            .topological_sort()
260            .map_err(|e| JitError::GraphError(format!("{:?}", e)))?;
261
262        // Generate a kernel for each node (simple approach)
263        for node_id in order {
264            if let Some(node) = graph.node(node_id) {
265                let kernel = self.generate_interpreter_kernel(graph, node_id, node)?;
266                kernels.push(kernel);
267            }
268        }
269
270        Ok(kernels)
271    }
272
273    /// Generate an interpreter kernel for a single node
274    fn generate_interpreter_kernel(
275        &self,
276        graph: &ComputationGraph,
277        node_id: NodeId,
278        node: &Node,
279    ) -> JitResult<CompiledKernel> {
280        // Populate inputs from graph
281        let input_tensors: Vec<TensorDesc> = graph
282            .get_node_inputs(node_id)
283            .iter()
284            .filter_map(|&input_id| {
285                graph.node(input_id).map(|input_node| TensorDesc {
286                    dtype: input_node.dtype,
287                    shape: input_node.output_shape.dims().to_vec(),
288                    strides: self.compute_strides(input_node.output_shape.dims()),
289                    offset: 0,
290                })
291            })
292            .collect();
293
294        // Generate metadata
295        let metadata = KernelMetadata {
296            inputs: input_tensors,
297            outputs: vec![TensorDesc {
298                dtype: node.dtype,
299                shape: node.output_shape.dims().to_vec(),
300                strides: self.compute_strides(node.output_shape.dims()),
301                offset: 0,
302            }],
303            shared_memory: 0,
304            block_size: (1, 1, 1),
305            grid_size: (1, 1, 1),
306        };
307
308        // Encode operation as "code"
309        let code = self.encode_operation(&node.op)?;
310
311        Ok(CompiledKernel {
312            id: format!("kernel_{:?}", node_id),
313            source_nodes: vec![node_id],
314            code,
315            metadata,
316        })
317    }
318
319    /// Compute strides for a shape
320    fn compute_strides(&self, shape: &[usize]) -> Vec<usize> {
321        let mut strides = vec![1; shape.len()];
322        for i in (0..shape.len() - 1).rev() {
323            strides[i] = strides[i + 1] * shape[i + 1];
324        }
325        strides
326    }
327
328    /// Encode an operation for interpreter execution
329    fn encode_operation(&self, op: &Operation) -> JitResult<Vec<u8>> {
330        // Simple encoding scheme for interpreter
331        let op_code = match op {
332            Operation::Add => 1,
333            Operation::Sub => 2,
334            Operation::Mul => 3,
335            Operation::Div => 4,
336            Operation::Relu => 5,
337            Operation::Sigmoid => 6,
338            Operation::Tanh => 7,
339            Operation::MatMul => 8,
340            // ... more operations
341            _ => {
342                return Err(JitError::UnsupportedOp(format!(
343                    "Operation {:?} not supported in interpreter",
344                    op
345                )))
346            }
347        };
348
349        Ok(vec![op_code])
350    }
351}
352
/// Cranelift-based CPU JIT backend.
///
/// Holds reusable Cranelift contexts for function building and code
/// generation. Fields are underscore-prefixed because actual machine-code
/// emission is not implemented yet (see `generate_kernel`).
#[cfg(feature = "cranelift-backend")]
struct CraneliftBackend {
    _builder_context: FunctionBuilderContext,
    _ctx: codegen::Context,
}
358
#[cfg(feature = "cranelift-backend")]
impl CraneliftBackend {
    /// Set up a Cranelift backend targeting the host ISA.
    fn new() -> Self {
        // Configure codegen flags before building the ISA.
        let mut flags = settings::builder();
        flags
            .set("use_colocated_libcalls", "false")
            .expect("setting should be valid");
        flags.set("is_pic", "false").expect("setting should be valid");

        let isa = cranelift_native::builder()
            .expect("native builder should succeed")
            .finish(settings::Flags::new(flags))
            .expect("ISA creation should succeed");

        // Pre-seed the codegen context with the host calling convention.
        let mut context = codegen::Context::new();
        context.func.signature.call_conv = isa.default_call_conv();

        Self {
            _builder_context: FunctionBuilderContext::new(),
            _ctx: context,
        }
    }

    /// Lower the graph into one compiled kernel per kernel group.
    fn generate(&self, graph: &ComputationGraph) -> JitResult<Vec<CompiledKernel>> {
        // Group nodes into kernels based on fusion information, then emit
        // each group; the first failure short-circuits via collect().
        self.identify_kernel_groups(graph)?
            .iter()
            .enumerate()
            .map(|(kernel_id, nodes)| self.generate_kernel(graph, kernel_id, nodes))
            .collect()
    }

    /// Partition the graph into kernel groups (currently one node per group).
    fn identify_kernel_groups(&self, graph: &ComputationGraph) -> JitResult<Vec<Vec<NodeId>>> {
        // A real implementation would consult fusion information; for now
        // every node in topological order becomes its own group.
        let order = graph
            .topological_sort()
            .map_err(|e| JitError::GraphError(format!("{:?}", e)))?;

        Ok(order.into_iter().map(|node| vec![node]).collect())
    }

    /// Emit a (placeholder) kernel for one group of nodes.
    fn generate_kernel(
        &self,
        _graph: &ComputationGraph,
        kernel_id: usize,
        nodes: &[NodeId],
    ) -> JitResult<CompiledKernel> {
        // TODO: Implement actual Cranelift code generation

        // Placeholder kernel: empty code, trivial launch metadata.
        Ok(CompiledKernel {
            id: format!("cranelift_kernel_{}", kernel_id),
            source_nodes: nodes.to_vec(),
            code: vec![],
            metadata: KernelMetadata {
                inputs: vec![],
                outputs: vec![],
                shared_memory: 0,
                block_size: (1, 1, 1),
                grid_size: (1, 1, 1),
            },
        })
    }
}
431
/// CUDA kernel generator
///
/// Generates PTX (Parallel Thread Execution) code for NVIDIA GPUs.
/// Supports compute capabilities from 5.0 (Maxwell) to 9.0 (Hopper).
pub struct CudaKernelGenerator {
    /// Target GPU compute capability as (major, minor), e.g. (7, 5) for sm_75
    compute_capability: (u32, u32),
    /// Enable tensor core usage for matrix operations (compute capability >= 7.0)
    enable_tensor_cores: bool,
    /// Target PTX ISA version
    ptx_version: (u32, u32),
    /// Enable cooperative groups
    enable_cooperative_groups: bool,
}
445
446impl CudaKernelGenerator {
447    /// Create a new CUDA kernel generator
448    ///
449    /// # Arguments
450    /// * `compute_capability` - GPU compute capability (e.g., (7, 5) for sm_75)
451    pub fn new(compute_capability: (u32, u32)) -> Self {
452        let enable_tensor_cores = compute_capability.0 >= 7;
453        let enable_cooperative_groups = compute_capability.0 >= 6;
454
455        Self {
456            compute_capability,
457            enable_tensor_cores,
458            ptx_version: (7, 0), // Default to PTX 7.0
459            enable_cooperative_groups,
460        }
461    }
462
463    /// Enable or disable tensor core usage
464    pub fn set_tensor_cores(&mut self, enable: bool) {
465        self.enable_tensor_cores = enable && self.compute_capability.0 >= 7;
466    }
467
468    /// Generate PTX assembly code for the computation graph
469    ///
470    /// Future implementation will:
471    /// - Analyze graph for optimal thread block configuration
472    /// - Generate fused kernels for element-wise operation chains
473    /// - Emit specialized tensor core instructions (WMMA) for matrix ops
474    /// - Apply memory coalescing patterns
475    /// - Generate multi-kernel launches for large graphs
476    pub fn generate_ptx(&self, graph: &ComputationGraph) -> JitResult<String> {
477        let node_count = graph.nodes().count();
478        let matmul_count = graph
479            .nodes()
480            .filter(|(_, n)| matches!(n.op, Operation::MatMul))
481            .count();
482        let conv_count = graph
483            .nodes()
484            .filter(|(_, n)| matches!(n.op, Operation::Conv2d { .. }))
485            .count();
486
487        let capability_str = format!(
488            "sm_{}{}",
489            self.compute_capability.0, self.compute_capability.1
490        );
491        let features = if self.enable_tensor_cores {
492            "tensor cores (WMMA), "
493        } else {
494            ""
495        };
496
497        Err(JitError::UnsupportedOp(format!(
498            "PTX generation not yet implemented.\n\
499             Target: {} (compute capability {}.{})\n\
500             Graph statistics:\n\
501             - Total nodes: {}\n\
502             - MatMul operations: {} {}\n\
503             - Conv2D operations: {} {}\n\
504             Features: {}cooperative groups\n\
505             \n\
506             Future PTX generation will support:\n\
507             - Automatic kernel fusion for {:.1}x speedup potential\n\
508             - Memory coalescing optimization\n\
509             - Shared memory tiling for matrix operations\n\
510             - Warp-level primitives for reduction operations\n\
511             \nFallback: Use CPU backend with BLAS/MKL for good performance.",
512            capability_str,
513            self.compute_capability.0,
514            self.compute_capability.1,
515            node_count,
516            matmul_count,
517            if self.enable_tensor_cores {
518                "(tensor core eligible)"
519            } else {
520                ""
521            },
522            conv_count,
523            if conv_count > 0 {
524                "(cudnn eligible)"
525            } else {
526                ""
527            },
528            features,
529            (matmul_count + conv_count).max(1) as f64 * 1.5 // Estimated fusion speedup
530        )))
531    }
532
533    /// Estimate kernel launch configuration for a graph
534    pub fn estimate_launch_config(&self, graph: &ComputationGraph) -> LaunchConfiguration {
535        let total_ops: usize = graph
536            .nodes()
537            .map(|(_, node)| node.output_shape.dims().iter().product::<usize>())
538            .sum();
539
540        // Simple heuristic for block size
541        let threads_per_block = if total_ops < 1024 {
542            128
543        } else if total_ops < 1024 * 1024 {
544            256
545        } else {
546            512
547        };
548
549        let blocks = (total_ops + threads_per_block - 1) / threads_per_block;
550
551        LaunchConfiguration {
552            grid_dim: (blocks.min(65535), 1, 1),
553            block_dim: (threads_per_block, 1, 1),
554            shared_memory_bytes: 0, // Would be calculated based on kernel
555            stream_id: 0,
556        }
557    }
558}
559
/// CUDA kernel launch configuration
///
/// Describes how a kernel should be launched: grid/block geometry, dynamic
/// shared memory, and the stream to enqueue on. Produced by
/// `CudaKernelGenerator::estimate_launch_config`.
#[derive(Debug, Clone)]
pub struct LaunchConfiguration {
    /// Grid dimensions (number of blocks)
    pub grid_dim: (usize, usize, usize),
    /// Block dimensions (threads per block)
    pub block_dim: (usize, usize, usize),
    /// Shared memory per block in bytes
    pub shared_memory_bytes: usize,
    /// CUDA stream ID
    pub stream_id: i32,
}
572
/// Metal kernel generator
///
/// Generates Metal Shading Language (MSL) code for Apple GPUs.
/// Supports macOS 10.15+, iOS 13+, and Apple Silicon.
pub struct MetalKernelGenerator {
    /// Metal GPU family string (e.g., "apple7" for M1)
    device_family: String,
    /// Enable Metal Performance Shaders (MPS) integration
    enable_mps: bool,
    /// Metal language version
    metal_version: (u32, u32),
    /// Target Apple Neural Engine (ANE) when available
    enable_ane: bool,
}
586
587impl MetalKernelGenerator {
588    /// Create a new Metal kernel generator
589    ///
590    /// # Arguments
591    /// * `device_family` - Metal GPU family (e.g., "apple7" for M1)
592    pub fn new(device_family: String) -> Self {
593        // Detect if ANE is available (A11+ or Apple Silicon)
594        let enable_ane = device_family.starts_with("apple")
595            && device_family[5..].parse::<u32>().unwrap_or(0) >= 7;
596
597        Self {
598            device_family,
599            enable_mps: true,      // MPS available on all modern devices
600            metal_version: (2, 4), // Metal 2.4 for macOS 12+
601            enable_ane,
602        }
603    }
604
605    /// Enable or disable Metal Performance Shaders integration
606    pub fn set_mps(&mut self, enable: bool) {
607        self.enable_mps = enable;
608    }
609
610    /// Generate Metal Shading Language code for the computation graph
611    ///
612    /// Future implementation will:
613    /// - Generate optimized MSL kernels for each operation
614    /// - Integrate with Metal Performance Shaders for standard ops
615    /// - Utilize tile memory for data reuse
616    /// - Emit SIMD-group operations for reduction
617    /// - Generate ANE-compatible operations when possible
618    pub fn generate_metal(&self, graph: &ComputationGraph) -> JitResult<String> {
619        let node_count = graph.nodes().count();
620        let matmul_count = graph
621            .nodes()
622            .filter(|(_, n)| matches!(n.op, Operation::MatMul))
623            .count();
624        let conv_count = graph
625            .nodes()
626            .filter(|(_, n)| matches!(n.op, Operation::Conv2d { .. }))
627            .count();
628        let elementwise_count = node_count - matmul_count - conv_count;
629
630        let mps_eligible = matmul_count + conv_count;
631        let ane_hints = if self.enable_ane && conv_count > 0 {
632            format!(
633                "\n- {} convolution ops are ANE-eligible for ultra-low power inference",
634                conv_count
635            )
636        } else {
637            String::new()
638        };
639
640        Err(JitError::UnsupportedOp(format!(
641            "Metal shader generation not yet implemented.\n\
642             Target: {} (Metal {}. {})\n\
643             Graph statistics:\n\
644             - Total nodes: {}\n\
645             - Element-wise ops: {}\n\
646             - MatMul operations: {}\n\
647             - Conv2D operations: {}\n\
648             - MPS-eligible ops: {}{}\n\
649             \n\
650             Future Metal generation will support:\n\
651             - Metal Performance Shaders integration for {:.0}% of operations\n\
652             - Unified memory optimization (zero-copy on Apple Silicon)\n\
653             - Tile memory usage for {:.1}x bandwidth reduction\n\
654             - SIMD-group operations for efficient reduction\n\
655             - Concurrent kernel execution across multiple command buffers\n\
656             \nFallback: Use CPU backend with Accelerate framework for good performance.",
657            self.device_family,
658            self.metal_version.0,
659            self.metal_version.1,
660            node_count,
661            elementwise_count,
662            matmul_count,
663            conv_count,
664            mps_eligible,
665            ane_hints,
666            (mps_eligible as f64 / node_count as f64) * 100.0,
667            2.5 // Estimated bandwidth reduction from tile memory
668        )))
669    }
670
671    /// Estimate threadgroup size for a graph
672    pub fn estimate_threadgroup_size(&self, graph: &ComputationGraph) -> ThreadgroupSize {
673        let total_ops: usize = graph
674            .nodes()
675            .map(|(_, node)| node.output_shape.dims().iter().product::<usize>())
676            .sum();
677
678        // Metal recommends threadgroup sizes in multiples of SIMD width (32)
679        let threads_per_threadgroup = if total_ops < 1024 {
680            128
681        } else if total_ops < 1024 * 1024 {
682            256
683        } else {
684            512
685        };
686
687        ThreadgroupSize {
688            width: threads_per_threadgroup,
689            height: 1,
690            depth: 1,
691        }
692    }
693}
694
/// Metal threadgroup size configuration
///
/// Produced by `MetalKernelGenerator::estimate_threadgroup_size`.
#[derive(Debug, Clone)]
pub struct ThreadgroupSize {
    /// Threads along the x dimension
    pub width: usize,
    /// Threads along the y dimension
    pub height: usize,
    /// Threads along the z dimension
    pub depth: usize,
}
702
#[cfg(test)]
mod tests {
    use super::*;

    /// A CPU code generator can be constructed without panicking.
    /// (The previous `assert!(true)` was a no-op and has been removed.)
    #[test]
    fn test_code_generator_creation() {
        let _gen = CodeGenerator::new(DeviceType::Cpu);
    }

    /// Strides are row-major (C-contiguous): the last dimension has
    /// stride 1 and each earlier stride is the product of later dims.
    #[test]
    fn test_stride_computation() {
        let gen = CodeGenerator::new(DeviceType::Cpu);

        let strides = gen.compute_strides(&[2, 3, 4]);
        assert_eq!(strides, vec![12, 4, 1]);

        let strides = gen.compute_strides(&[10]);
        assert_eq!(strides, vec![1]);
    }
}