1use crate::graph::{ComputationGraph, Node, NodeId, Operation};
4use crate::{CompiledKernel, JitError, JitResult, KernelMetadata, TensorDesc};
5use torsh_core::DeviceType;
6
7#[cfg(feature = "cranelift-backend")]
8use cranelift::prelude::*;
9
/// Dispatches kernel code generation to a device-specific backend.
pub struct CodeGenerator {
    // Target device this generator emits kernels for.
    device: DeviceType,
    // Native CPU backend; populated only for `DeviceType::Cpu` when the
    // `cranelift-backend` feature is enabled (see `CodeGenerator::new`).
    #[cfg(feature = "cranelift-backend")]
    cranelift: Option<CraneliftBackend>,
}
16
17impl CodeGenerator {
    /// Create a generator targeting `device`.
    ///
    /// With the `cranelift-backend` feature enabled, a Cranelift backend is
    /// instantiated eagerly for CPU targets; every other device leaves it
    /// unset so `generate` falls through to the other code paths.
    pub fn new(device: DeviceType) -> Self {
        Self {
            device,
            #[cfg(feature = "cranelift-backend")]
            cranelift: match device {
                DeviceType::Cpu => Some(CraneliftBackend::new()),
                _ => None,
            },
        }
    }
29
    /// Compile `graph` into executable kernels for the configured device.
    ///
    /// # Errors
    /// Returns `JitError::UnsupportedOp` for devices with no generation
    /// path. Note the CUDA and Metal arms currently also return errors,
    /// as those backends are not implemented yet (see their doc comments).
    pub fn generate(&self, graph: &ComputationGraph) -> JitResult<Vec<CompiledKernel>> {
        match self.device {
            DeviceType::Cpu => self.generate_cpu(graph),
            DeviceType::Cuda(_) => self.generate_cuda(graph),
            DeviceType::Metal(_) => self.generate_metal(graph),
            _ => Err(JitError::UnsupportedOp(format!(
                "Code generation not supported for {:?}",
                self.device
            ))),
        }
    }
42
    /// CPU path: prefer the Cranelift backend when it was compiled in and
    /// initialized; otherwise fall back to interpreter bytecode kernels.
    fn generate_cpu(&self, graph: &ComputationGraph) -> JitResult<Vec<CompiledKernel>> {
        // Early return through the native backend when available.
        #[cfg(feature = "cranelift-backend")]
        if let Some(ref backend) = self.cranelift {
            return backend.generate(graph);
        }

        self.generate_interpreter(graph)
    }
53
54 fn generate_cuda(&self, graph: &ComputationGraph) -> JitResult<Vec<CompiledKernel>> {
63 let node_count = graph.nodes().count();
64 let operation_types: Vec<_> = graph
65 .nodes()
66 .map(|(_, node)| format!("{:?}", node.op))
67 .collect();
68
69 Err(JitError::UnsupportedOp(format!(
70 "CUDA code generation not yet implemented. \
71 Graph contains {} nodes with operations: {}. \
72 To enable CUDA support: \
73 1. Install CUDA toolkit (>=11.0) \
74 2. Enable 'cuda' feature flag \
75 3. Set CUDA_PATH environment variable \
76 \nFallback: Use CPU backend or interpreter mode.",
77 node_count,
78 operation_types.join(", ")
79 )))
80 }
81
82 fn generate_metal(&self, graph: &ComputationGraph) -> JitResult<Vec<CompiledKernel>> {
91 let node_count = graph.nodes().count();
92 let has_matmul = graph
93 .nodes()
94 .any(|(_, node)| matches!(node.op, Operation::MatMul));
95 let has_conv = graph
96 .nodes()
97 .any(|(_, node)| matches!(node.op, Operation::Conv2d { .. }));
98
99 let recommendations = if has_matmul || has_conv {
100 "Consider using Metal Performance Shaders (MPS) backend for matrix/convolution operations."
101 } else {
102 "For element-wise operations, CPU backend may provide sufficient performance."
103 };
104
105 Err(JitError::UnsupportedOp(format!(
106 "Metal code generation not yet implemented. \
107 Graph contains {} nodes. \
108 Detected: {} \
109 To enable Metal support: \
110 1. Ensure macOS 10.15+ or iOS 13+ \
111 2. Enable 'metal' feature flag \
112 3. Install Metal developer tools \
113 \n{} \
114 \nFallback: Use CPU backend or interpreter mode.",
115 node_count,
116 if has_matmul {
117 "matrix multiplication"
118 } else if has_conv {
119 "convolutions"
120 } else {
121 "element-wise ops"
122 },
123 recommendations
124 )))
125 }
126
    /// Compile an IR module into kernels.
    ///
    /// Currently always lowers through the interpreter encoding; a
    /// device-specific IR lowering path may replace this later.
    pub fn generate_from_ir(
        &self,
        ir_module: &crate::ir::IrModule,
    ) -> JitResult<Vec<CompiledKernel>> {
        self.generate_interpreter_from_ir(ir_module)
    }
136
137 pub fn generate_interpreter_from_ir(
139 &self,
140 ir_module: &crate::ir::IrModule,
141 ) -> JitResult<Vec<CompiledKernel>> {
142 let mut kernels = Vec::new();
143
144 for (block_id, block) in &ir_module.blocks {
146 let kernel_id = format!("ir_kernel_{}", block_id);
147
148 let metadata = KernelMetadata {
150 inputs: ir_module
151 .inputs
152 .iter()
153 .filter_map(|&input| self.ir_value_to_tensor_desc(ir_module, input))
154 .collect(),
155 outputs: ir_module
156 .outputs
157 .iter()
158 .filter_map(|&output| self.ir_value_to_tensor_desc(ir_module, output))
159 .collect(),
160 shared_memory: 0,
161 block_size: (1, 1, 1),
162 grid_size: (1, 1, 1),
163 };
164
165 let mut code = Vec::new();
167 for instruction in &block.instructions {
168 let opcode = self.encode_ir_instruction(instruction)?;
169 code.push(opcode);
170 }
171
172 let kernel = CompiledKernel {
173 id: kernel_id,
174 source_nodes: Vec::new(), code,
176 metadata,
177 };
178
179 kernels.push(kernel);
180 }
181
182 Ok(kernels)
183 }
184
185 fn ir_value_to_tensor_desc(
187 &self,
188 ir_module: &crate::ir::IrModule,
189 ir_value: crate::ir::IrValue,
190 ) -> Option<TensorDesc> {
191 if let Some(value_def) = ir_module.get_value(ir_value) {
192 if let Some(type_def) = ir_module.get_type(value_def.ty) {
193 match &type_def.kind {
194 crate::ir::TypeKind::Tensor { shape, .. } => {
195 Some(TensorDesc {
196 dtype: torsh_core::DType::F32, shape: shape.clone(),
198 strides: self.compute_strides(shape),
199 offset: 0,
200 })
201 }
202 _ => None,
203 }
204 } else {
205 None
206 }
207 } else {
208 None
209 }
210 }
211
212 fn encode_ir_instruction(&self, instruction: &crate::ir::Instruction) -> JitResult<u8> {
214 let opcode = match &instruction.opcode {
215 crate::ir::IrOpcode::Add => 1,
216 crate::ir::IrOpcode::Sub => 2,
217 crate::ir::IrOpcode::Mul => 3,
218 crate::ir::IrOpcode::Div => 4,
219 crate::ir::IrOpcode::Neg => 5,
220 crate::ir::IrOpcode::Abs => 6,
221 crate::ir::IrOpcode::Exp => 7,
222 crate::ir::IrOpcode::Log => 8,
223 crate::ir::IrOpcode::Sqrt => 9,
224 crate::ir::IrOpcode::Sin => 10,
225 crate::ir::IrOpcode::Cos => 11,
226 crate::ir::IrOpcode::Tanh => 12,
227 crate::ir::IrOpcode::Sigmoid => 13,
228 crate::ir::IrOpcode::Relu => 14,
229 crate::ir::IrOpcode::Gelu => 15,
230 crate::ir::IrOpcode::MatMul => 16,
231 crate::ir::IrOpcode::Conv2d => 17,
232 crate::ir::IrOpcode::Pool2d => 18,
233 crate::ir::IrOpcode::Reshape => 19,
234 crate::ir::IrOpcode::Transpose => 20,
235 crate::ir::IrOpcode::Sum => 21,
236 crate::ir::IrOpcode::Mean => 22,
237 crate::ir::IrOpcode::Max => 23,
238 crate::ir::IrOpcode::Min => 24,
239 crate::ir::IrOpcode::Load => 25,
240 crate::ir::IrOpcode::Store => 26,
241 crate::ir::IrOpcode::Const => 27,
242 _ => {
243 return Err(JitError::UnsupportedOp(format!(
244 "IR opcode {:?} not supported in interpreter",
245 instruction.opcode
246 )))
247 }
248 };
249
250 Ok(opcode)
251 }
252
253 pub fn generate_interpreter(&self, graph: &ComputationGraph) -> JitResult<Vec<CompiledKernel>> {
255 let mut kernels = Vec::new();
256
257 let order = graph
259 .topological_sort()
260 .map_err(|e| JitError::GraphError(format!("{:?}", e)))?;
261
262 for node_id in order {
264 if let Some(node) = graph.node(node_id) {
265 let kernel = self.generate_interpreter_kernel(graph, node_id, node)?;
266 kernels.push(kernel);
267 }
268 }
269
270 Ok(kernels)
271 }
272
273 fn generate_interpreter_kernel(
275 &self,
276 graph: &ComputationGraph,
277 node_id: NodeId,
278 node: &Node,
279 ) -> JitResult<CompiledKernel> {
280 let input_tensors: Vec<TensorDesc> = graph
282 .get_node_inputs(node_id)
283 .iter()
284 .filter_map(|&input_id| {
285 graph.node(input_id).map(|input_node| TensorDesc {
286 dtype: input_node.dtype,
287 shape: input_node.output_shape.dims().to_vec(),
288 strides: self.compute_strides(input_node.output_shape.dims()),
289 offset: 0,
290 })
291 })
292 .collect();
293
294 let metadata = KernelMetadata {
296 inputs: input_tensors,
297 outputs: vec![TensorDesc {
298 dtype: node.dtype,
299 shape: node.output_shape.dims().to_vec(),
300 strides: self.compute_strides(node.output_shape.dims()),
301 offset: 0,
302 }],
303 shared_memory: 0,
304 block_size: (1, 1, 1),
305 grid_size: (1, 1, 1),
306 };
307
308 let code = self.encode_operation(&node.op)?;
310
311 Ok(CompiledKernel {
312 id: format!("kernel_{:?}", node_id),
313 source_nodes: vec![node_id],
314 code,
315 metadata,
316 })
317 }
318
319 fn compute_strides(&self, shape: &[usize]) -> Vec<usize> {
321 let mut strides = vec![1; shape.len()];
322 for i in (0..shape.len() - 1).rev() {
323 strides[i] = strides[i + 1] * shape[i + 1];
324 }
325 strides
326 }
327
328 fn encode_operation(&self, op: &Operation) -> JitResult<Vec<u8>> {
330 let op_code = match op {
332 Operation::Add => 1,
333 Operation::Sub => 2,
334 Operation::Mul => 3,
335 Operation::Div => 4,
336 Operation::Relu => 5,
337 Operation::Sigmoid => 6,
338 Operation::Tanh => 7,
339 Operation::MatMul => 8,
340 _ => {
342 return Err(JitError::UnsupportedOp(format!(
343 "Operation {:?} not supported in interpreter",
344 op
345 )))
346 }
347 };
348
349 Ok(vec![op_code])
350 }
351}
352
/// Native CPU JIT backend built on Cranelift.
///
/// The fields currently only keep the Cranelift contexts alive between
/// calls (hence the leading underscores); code emission is still a stub.
#[cfg(feature = "cranelift-backend")]
struct CraneliftBackend {
    // Reusable scratch state for building function bodies.
    _builder_context: FunctionBuilderContext,
    // Per-function codegen context.
    _ctx: codegen::Context,
}
358
359#[cfg(feature = "cranelift-backend")]
360impl CraneliftBackend {
    /// Build a backend configured for the host ISA.
    ///
    /// # Panics
    /// Panics if a Cranelift setting is rejected or the native ISA cannot
    /// be constructed (e.g. unsupported host architecture).
    fn new() -> Self {
        let mut flag_builder = settings::builder();
        // PIC and colocated libcalls are disabled for a simple
        // in-process JIT setup.
        flag_builder
            .set("use_colocated_libcalls", "false")
            .expect("setting should be valid");
        flag_builder
            .set("is_pic", "false")
            .expect("setting should be valid");
        let isa_builder = cranelift_native::builder().expect("native builder should succeed");
        let isa = isa_builder
            .finish(settings::Flags::new(flag_builder))
            .expect("ISA creation should succeed");

        let mut ctx = codegen::Context::new();
        // Emit functions with the host's default calling convention.
        ctx.func.signature.call_conv = isa.default_call_conv();

        Self {
            _builder_context: FunctionBuilderContext::new(),
            _ctx: ctx,
        }
    }
382
383 fn generate(&self, graph: &ComputationGraph) -> JitResult<Vec<CompiledKernel>> {
384 let mut kernels = Vec::new();
385
386 let kernel_groups = self.identify_kernel_groups(graph)?;
388
389 for (kernel_id, nodes) in kernel_groups.iter().enumerate() {
391 let kernel = self.generate_kernel(graph, kernel_id, nodes)?;
392 kernels.push(kernel);
393 }
394
395 Ok(kernels)
396 }
397
398 fn identify_kernel_groups(&self, graph: &ComputationGraph) -> JitResult<Vec<Vec<NodeId>>> {
399 let order = graph
402 .topological_sort()
403 .map_err(|e| JitError::GraphError(format!("{:?}", e)))?;
404
405 Ok(order.into_iter().map(|n| vec![n]).collect())
406 }
407
408 fn generate_kernel(
409 &self,
410 _graph: &ComputationGraph,
411 kernel_id: usize,
412 nodes: &[NodeId],
413 ) -> JitResult<CompiledKernel> {
414 Ok(CompiledKernel {
418 id: format!("cranelift_kernel_{}", kernel_id),
419 source_nodes: nodes.to_vec(),
420 code: vec![],
421 metadata: KernelMetadata {
422 inputs: vec![],
423 outputs: vec![],
424 shared_memory: 0,
425 block_size: (1, 1, 1),
426 grid_size: (1, 1, 1),
427 },
428 })
429 }
430}
431
/// Generator for CUDA PTX kernels (currently a planning/diagnostic stub).
pub struct CudaKernelGenerator {
    // Target GPU compute capability as (major, minor), e.g. (8, 0).
    compute_capability: (u32, u32),
    // Whether tensor-core (WMMA) code paths may be used.
    enable_tensor_cores: bool,
    // PTX ISA version to target as (major, minor).
    ptx_version: (u32, u32),
    // Whether cooperative-groups features may be used.
    enable_cooperative_groups: bool,
}
445
446impl CudaKernelGenerator {
447 pub fn new(compute_capability: (u32, u32)) -> Self {
452 let enable_tensor_cores = compute_capability.0 >= 7;
453 let enable_cooperative_groups = compute_capability.0 >= 6;
454
455 Self {
456 compute_capability,
457 enable_tensor_cores,
458 ptx_version: (7, 0), enable_cooperative_groups,
460 }
461 }
462
    /// Enable or disable tensor-core usage.
    ///
    /// The request is only honored on hardware that supports tensor cores
    /// (compute capability major >= 7); otherwise the flag stays off
    /// regardless of `enable`.
    pub fn set_tensor_cores(&mut self, enable: bool) {
        self.enable_tensor_cores = enable && self.compute_capability.0 >= 7;
    }
467
    /// Generate PTX assembly for `graph`.
    ///
    /// Currently a stub: always returns `JitError::UnsupportedOp` carrying
    /// a diagnostic that summarises the graph, the target GPU, and the
    /// features a future PTX backend would use.
    ///
    /// # Errors
    /// Always returns `Err` until PTX generation is implemented.
    pub fn generate_ptx(&self, graph: &ComputationGraph) -> JitResult<String> {
        // Gather graph statistics for the diagnostic message.
        let node_count = graph.nodes().count();
        let matmul_count = graph
            .nodes()
            .filter(|(_, n)| matches!(n.op, Operation::MatMul))
            .count();
        let conv_count = graph
            .nodes()
            .filter(|(_, n)| matches!(n.op, Operation::Conv2d { .. }))
            .count();

        let capability_str = format!(
            "sm_{}{}",
            self.compute_capability.0, self.compute_capability.1
        );
        let features = if self.enable_tensor_cores {
            "tensor cores (WMMA), "
        } else {
            ""
        };
        // NOTE(review): the message below always appends "cooperative
        // groups" regardless of `enable_cooperative_groups` — confirm
        // whether that part of the feature list should be conditional.

        Err(JitError::UnsupportedOp(format!(
            "PTX generation not yet implemented.\n\
             Target: {} (compute capability {}.{})\n\
             Graph statistics:\n\
             - Total nodes: {}\n\
             - MatMul operations: {} {}\n\
             - Conv2D operations: {} {}\n\
             Features: {}cooperative groups\n\
             \n\
             Future PTX generation will support:\n\
             - Automatic kernel fusion for {:.1}x speedup potential\n\
             - Memory coalescing optimization\n\
             - Shared memory tiling for matrix operations\n\
             - Warp-level primitives for reduction operations\n\
             \nFallback: Use CPU backend with BLAS/MKL for good performance.",
            capability_str,
            self.compute_capability.0,
            self.compute_capability.1,
            node_count,
            matmul_count,
            if self.enable_tensor_cores {
                "(tensor core eligible)"
            } else {
                ""
            },
            conv_count,
            if conv_count > 0 {
                "(cudnn eligible)"
            } else {
                ""
            },
            features,
            // Crude speedup estimate: 1.5x per fusible heavy op.
            (matmul_count + conv_count).max(1) as f64 * 1.5
        )))
    }
532
533 pub fn estimate_launch_config(&self, graph: &ComputationGraph) -> LaunchConfiguration {
535 let total_ops: usize = graph
536 .nodes()
537 .map(|(_, node)| node.output_shape.dims().iter().product::<usize>())
538 .sum();
539
540 let threads_per_block = if total_ops < 1024 {
542 128
543 } else if total_ops < 1024 * 1024 {
544 256
545 } else {
546 512
547 };
548
549 let blocks = (total_ops + threads_per_block - 1) / threads_per_block;
550
551 LaunchConfiguration {
552 grid_dim: (blocks.min(65535), 1, 1),
553 block_dim: (threads_per_block, 1, 1),
554 shared_memory_bytes: 0, stream_id: 0,
556 }
557 }
558}
559
/// CUDA kernel launch parameters.
#[derive(Debug, Clone)]
pub struct LaunchConfiguration {
    /// Grid dimensions (number of blocks) in x, y, z.
    pub grid_dim: (usize, usize, usize),
    /// Block dimensions (threads per block) in x, y, z.
    pub block_dim: (usize, usize, usize),
    /// Dynamic shared memory per block, in bytes.
    pub shared_memory_bytes: usize,
    /// Stream to launch on (0 = default stream).
    pub stream_id: i32,
}
572
/// Generator for Metal compute kernels (currently a planning/diagnostic stub).
pub struct MetalKernelGenerator {
    // GPU family string, e.g. "apple7"; parsed in `new` to decide ANE use.
    device_family: String,
    // Whether Metal Performance Shaders may be used for eligible ops.
    enable_mps: bool,
    // Metal language version as (major, minor).
    metal_version: (u32, u32),
    // Whether Apple Neural Engine hints are emitted.
    enable_ane: bool,
}
586
587impl MetalKernelGenerator {
588 pub fn new(device_family: String) -> Self {
593 let enable_ane = device_family.starts_with("apple")
595 && device_family[5..].parse::<u32>().unwrap_or(0) >= 7;
596
597 Self {
598 device_family,
599 enable_mps: true, metal_version: (2, 4), enable_ane,
602 }
603 }
604
    /// Toggle use of Metal Performance Shaders for eligible operations.
    pub fn set_mps(&mut self, enable: bool) {
        self.enable_mps = enable;
    }
609
610 pub fn generate_metal(&self, graph: &ComputationGraph) -> JitResult<String> {
619 let node_count = graph.nodes().count();
620 let matmul_count = graph
621 .nodes()
622 .filter(|(_, n)| matches!(n.op, Operation::MatMul))
623 .count();
624 let conv_count = graph
625 .nodes()
626 .filter(|(_, n)| matches!(n.op, Operation::Conv2d { .. }))
627 .count();
628 let elementwise_count = node_count - matmul_count - conv_count;
629
630 let mps_eligible = matmul_count + conv_count;
631 let ane_hints = if self.enable_ane && conv_count > 0 {
632 format!(
633 "\n- {} convolution ops are ANE-eligible for ultra-low power inference",
634 conv_count
635 )
636 } else {
637 String::new()
638 };
639
640 Err(JitError::UnsupportedOp(format!(
641 "Metal shader generation not yet implemented.\n\
642 Target: {} (Metal {}. {})\n\
643 Graph statistics:\n\
644 - Total nodes: {}\n\
645 - Element-wise ops: {}\n\
646 - MatMul operations: {}\n\
647 - Conv2D operations: {}\n\
648 - MPS-eligible ops: {}{}\n\
649 \n\
650 Future Metal generation will support:\n\
651 - Metal Performance Shaders integration for {:.0}% of operations\n\
652 - Unified memory optimization (zero-copy on Apple Silicon)\n\
653 - Tile memory usage for {:.1}x bandwidth reduction\n\
654 - SIMD-group operations for efficient reduction\n\
655 - Concurrent kernel execution across multiple command buffers\n\
656 \nFallback: Use CPU backend with Accelerate framework for good performance.",
657 self.device_family,
658 self.metal_version.0,
659 self.metal_version.1,
660 node_count,
661 elementwise_count,
662 matmul_count,
663 conv_count,
664 mps_eligible,
665 ane_hints,
666 (mps_eligible as f64 / node_count as f64) * 100.0,
667 2.5 )))
669 }
670
671 pub fn estimate_threadgroup_size(&self, graph: &ComputationGraph) -> ThreadgroupSize {
673 let total_ops: usize = graph
674 .nodes()
675 .map(|(_, node)| node.output_shape.dims().iter().product::<usize>())
676 .sum();
677
678 let threads_per_threadgroup = if total_ops < 1024 {
680 128
681 } else if total_ops < 1024 * 1024 {
682 256
683 } else {
684 512
685 };
686
687 ThreadgroupSize {
688 width: threads_per_threadgroup,
689 height: 1,
690 depth: 1,
691 }
692 }
693}
694
/// Metal threadgroup dimensions.
#[derive(Debug, Clone)]
pub struct ThreadgroupSize {
    /// Threads along x.
    pub width: usize,
    /// Threads along y.
    pub height: usize,
    /// Threads along z.
    pub depth: usize,
}
702
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing a CPU generator must not panic.
    /// (The previous `assert!(true)` was a no-op assertion and is removed;
    /// reaching the end of the test is the success condition.)
    #[test]
    fn test_code_generator_creation() {
        let _gen = CodeGenerator::new(DeviceType::Cpu);
    }

    /// Row-major strides: innermost dimension has stride 1, each earlier
    /// stride is the product of the later dimension sizes.
    #[test]
    fn test_stride_computation() {
        let gen = CodeGenerator::new(DeviceType::Cpu);

        let strides = gen.compute_strides(&[2, 3, 4]);
        assert_eq!(strides, vec![12, 4, 1]);

        let strides = gen.compute_strides(&[10]);
        assert_eq!(strides, vec![1]);

        let strides = gen.compute_strides(&[5, 7]);
        assert_eq!(strides, vec![7, 1]);
    }
}
724}