Skip to main content

torsh_jit/
ir.rs

//! Intermediate Representation (IR) for JIT compilation
//!
//! This module provides a lower-level IR that sits between the high-level
//! computation graph and the target-specific code generation.
5
6use crate::graph::NodeId;
7use indexmap::IndexMap;
8use std::collections::HashMap;
9use torsh_core::{DType, Shape};
10
/// Intermediate Representation of a computation.
///
/// A module owns its basic blocks, a definition for every value referenced by
/// instructions and terminators, and a table of types. The three tables are
/// `IndexMap`s, so iterating over them follows insertion order.
#[derive(Debug, Clone)]
pub struct IrModule {
    /// Module name
    pub name: String,

    /// Input values
    pub inputs: Vec<IrValue>,

    /// Output values
    pub outputs: Vec<IrValue>,

    /// Basic blocks containing instructions, keyed by block id.
    /// `IrModule::validate` requires each key to equal the block's own `id`.
    pub blocks: IndexMap<BlockId, BasicBlock>,

    /// Entry block. `IrModule::new` sets it to 0, and `IrModule::validate`
    /// checks that a block with this id exists.
    pub entry_block: BlockId,

    /// Value definitions (the type and origin of each `IrValue`)
    pub values: IndexMap<IrValue, ValueDef>,

    /// Type definitions. `IrModule::add_type` deduplicates entries by `TypeKind`.
    pub types: IndexMap<IrType, TypeDef>,
}
35
/// Basic block identifier (key into `IrModule::blocks`)
pub type BlockId = u32;

/// IR value identifier (key into `IrModule::values`).
///
/// This is a bare index. Look it up in `IrModule::values` to obtain the
/// value's type and kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct IrValue(pub u32);

/// IR type identifier (key into `IrModule::types`)
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct IrType(pub u32);
46
/// Basic block containing a sequence of instructions.
///
/// Control leaves the block through `terminator`. `IrModule::validate`
/// accepts a block that has no terminator.
#[derive(Debug, Clone)]
pub struct BasicBlock {
    /// Block identifier. It must equal this block's key in `IrModule::blocks`
    /// (checked by `IrModule::validate`).
    pub id: BlockId,

    /// Block parameters (phi nodes). Not checked by `IrModule::validate`.
    pub params: Vec<IrValue>,

    /// Instructions in execution order
    pub instructions: Vec<Instruction>,

    /// Block terminator. It is `None` until one is set, for example via
    /// `IrBuilder::set_terminator`.
    pub terminator: Option<Terminator>,
}
62
/// Single instruction in IR
#[derive(Debug, Clone)]
pub struct Instruction {
    /// Result value (if any). When present, `IrModule::validate` requires it
    /// to have an entry in `IrModule::values`.
    pub result: Option<IrValue>,

    /// Operation to perform
    pub opcode: IrOpcode,

    /// Input operands. Each one must have an entry in `IrModule::values`
    /// (checked by `IrModule::validate`).
    pub operands: Vec<IrValue>,

    /// Additional named attributes. Which keys each opcode expects is not
    /// defined in this module.
    pub attrs: HashMap<String, IrAttribute>,
}
78
/// IR operation codes.
///
/// An opcode only names the operation. Operands are stored in
/// `Instruction::operands` and extra parameters in `Instruction::attrs`.
/// This module does not define per-opcode operand counts, operand types, or
/// attribute keys.
#[derive(Debug, Clone, PartialEq, Hash)]
pub enum IrOpcode {
    // Arithmetic operations
    Add,
    Sub,
    Mul,
    Div,
    Rem,
    Neg,
    Abs,

    // Mathematical functions
    Exp,
    Log,
    Sqrt,
    Sin,
    Cos,
    Tanh,
    Sigmoid,

    // Logical operations
    And,
    Or,
    Not,
    Xor,

    // Comparison operations
    Eq,
    Ne,
    Lt,
    Le,
    Gt,
    Ge,

    // Memory operations
    Load,
    Store,
    Alloca,

    // Tensor operations
    Reshape,
    Transpose,
    Slice,
    Concat,
    Split,

    // Matrix operations
    MatMul,
    Conv2d,
    Pool2d,

    // Reduction operations
    Sum,
    Mean,
    Max,
    Min,
    Argmax,
    Argmin,

    // Activation functions
    Relu,
    Gelu,
    Softmax,

    // Control flow
    // NOTE(review): block-ending control flow is also modelled by `Terminator`
    // (stored in `BasicBlock::terminator`). This module does not define how
    // these opcodes relate to it.
    Br,
    CondBr,
    Call,
    Return,

    // Constants
    Const,

    // Type conversions
    Cast,
    Bitcast,

    // Intrinsics
    /// Named intrinsic. The string identifies the operation (presumably
    /// resolved by a backend; not interpreted in this module).
    Intrinsic(String),

    // No-operation
    Nop,
}
163
/// Block terminator (control flow).
///
/// `IrModule::validate` checks that every block id and value referenced here
/// exists in the module.
#[derive(Debug, Clone)]
pub enum Terminator {
    /// Unconditional branch
    Branch { target: BlockId },

    /// Conditional branch on `condition` to `then_block` or `else_block`
    CondBranch {
        condition: IrValue,
        then_block: BlockId,
        else_block: BlockId,
    },

    /// Return from function, optionally with a value
    Return { value: Option<IrValue> },

    /// Unreachable code
    Unreachable,
}
183
/// Value definition: the type and origin of one `IrValue`
#[derive(Debug, Clone)]
pub struct ValueDef {
    /// Value type. It must exist in `IrModule::types` (checked by
    /// `IrModule::validate`).
    pub ty: IrType,

    /// Originating computation-graph node, if known (for debugging)
    pub source_node: Option<NodeId>,

    /// Value kind
    pub kind: ValueKind,
}
196
/// Kind of value: how a value comes into existence
#[derive(Debug, Clone)]
pub enum ValueKind {
    /// Function parameter. `index` is presumably its position in
    /// `IrModule::inputs` (TODO confirm with producers).
    Parameter { index: usize },

    /// Instruction result. `index` is presumably the position of the defining
    /// instruction in the block's `instructions`. `IrModule::validate` does
    /// not check it, and passes that remove instructions do not update it.
    Instruction { block: BlockId, index: usize },

    /// Constant value
    Constant { data: ConstantData },

    /// Undefined value
    Undef,
}
212
/// Constant data.
///
/// The payload representation is fixed per variant and is independent of the
/// owning value's declared `IrType`.
#[derive(Debug, Clone)]
pub enum ConstantData {
    /// Scalar integer (always stored as `i64`)
    Int(i64),

    /// Scalar float (always stored as `f64`)
    Float(f64),

    /// Boolean
    Bool(bool),

    /// String
    String(String),

    /// Array of constants
    Array(Vec<ConstantData>),

    /// Tensor data. Elements are always stored as `f32`. They are presumably
    /// row-major with `shape.iter().product()` elements; this is not checked
    /// here.
    Tensor { shape: Vec<usize>, data: Vec<f32> },
}
234
/// Type definition
#[derive(Debug, Clone)]
pub struct TypeDef {
    /// Type kind; the deduplication key used by `IrModule::add_type`
    pub kind: TypeKind,

    /// Size in bytes (if known). `IrBuilder::get_type` fills this in for
    /// scalar kinds and leaves it `None` for the others.
    pub size: Option<usize>,

    /// Alignment requirement in bytes (if known)
    pub align: Option<usize>,
}
247
/// Kind of type.
///
/// `IrModule::add_type` uses this as the deduplication key: two `TypeDef`s
/// with equal kinds share one `IrType`. Composite kinds refer to their
/// component types by `IrType` id, so they are only meaningful relative to
/// the module that owns those ids.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TypeKind {
    /// Void type
    Void,

    /// Boolean type
    Bool,

    // Signed integer types
    /// 8-bit signed integer
    I8,
    /// 16-bit signed integer
    I16,
    /// 32-bit signed integer
    I32,
    /// 64-bit signed integer
    I64,

    // Unsigned integer types
    /// 8-bit unsigned integer
    U8,
    /// 16-bit unsigned integer
    U16,
    /// 32-bit unsigned integer
    U32,
    /// 64-bit unsigned integer
    U64,

    // Floating point types
    /// 16-bit float. `dtype_to_ir_type` maps both `DType::F16` and
    /// `DType::BF16` here.
    F16,
    /// 32-bit float
    F32,
    /// 64-bit float
    F64,

    // Complex types
    /// 8-byte complex number (presumably two `f32` components)
    C64,
    /// 16-byte complex number (presumably two `f64` components)
    C128,

    /// Pointer to a value of type `pointee`
    Pointer {
        pointee: IrType,
    },

    /// Fixed-length array of `length` elements of type `element`
    Array {
        element: IrType,
        length: usize,
    },

    /// Tensor with element type `element` and dimensions `shape`
    Tensor {
        element: IrType,
        shape: Vec<usize>,
    },

    /// Function type. `return_type` is `None` when nothing is returned.
    Function {
        params: Vec<IrType>,
        return_type: Option<IrType>,
    },

    /// Struct type: unnamed fields with the given types, in order
    Struct {
        fields: Vec<IrType>,
    },
}
306
/// IR attribute value: a typed entry stored in `Instruction::attrs`.
/// Arrays may nest other attributes.
#[derive(Debug, Clone)]
pub enum IrAttribute {
    Int(i64),
    Float(f64),
    String(String),
    Bool(bool),
    Array(Vec<IrAttribute>),
}
316
317impl IrModule {
318    /// Create a new empty IR module
319    pub fn new(name: String) -> Self {
320        Self {
321            name,
322            inputs: Vec::new(),
323            outputs: Vec::new(),
324            blocks: IndexMap::new(),
325            entry_block: 0,
326            values: IndexMap::new(),
327            types: IndexMap::new(),
328        }
329    }
330
331    /// Add a new basic block
332    pub fn add_block(&mut self) -> BlockId {
333        let id = self.blocks.len() as BlockId;
334        let block = BasicBlock {
335            id,
336            params: Vec::new(),
337            instructions: Vec::new(),
338            terminator: None,
339        };
340        self.blocks.insert(id, block);
341        id
342    }
343
344    /// Add a new value
345    pub fn add_value(&mut self, def: ValueDef) -> IrValue {
346        let id = IrValue(self.values.len() as u32);
347        self.values.insert(id, def);
348        id
349    }
350
351    /// Add a value (for external access)
352    pub fn add_value_external(&mut self, def: ValueDef) -> IrValue {
353        self.add_value(def)
354    }
355
356    /// Add a new type
357    pub fn add_type(&mut self, def: TypeDef) -> IrType {
358        // Check if type already exists
359        for (&existing_id, existing_def) in &self.types {
360            if existing_def.kind == def.kind {
361                return existing_id;
362            }
363        }
364
365        let id = IrType(self.types.len() as u32);
366        self.types.insert(id, def);
367        id
368    }
369
370    /// Get a block by ID
371    pub fn get_block(&self, id: BlockId) -> Option<&BasicBlock> {
372        self.blocks.get(&id)
373    }
374
375    /// Get a mutable block by ID
376    pub fn get_block_mut(&mut self, id: BlockId) -> Option<&mut BasicBlock> {
377        self.blocks.get_mut(&id)
378    }
379
380    /// Get value definition
381    pub fn get_value(&self, value: IrValue) -> Option<&ValueDef> {
382        self.values.get(&value)
383    }
384
385    /// Get type definition
386    pub fn get_type(&self, ty: IrType) -> Option<&TypeDef> {
387        self.types.get(&ty)
388    }
389
390    /// Validate the IR module
391    pub fn validate(&self) -> Result<(), String> {
392        // Check entry block exists
393        if !self.blocks.contains_key(&self.entry_block) {
394            return Err("Entry block does not exist".to_string());
395        }
396
397        // Validate each block
398        for (id, block) in &self.blocks {
399            if *id != block.id {
400                return Err(format!("Block ID mismatch: {} != {}", id, block.id));
401            }
402
403            // Validate instructions
404            for (i, instr) in block.instructions.iter().enumerate() {
405                self.validate_instruction(instr, *id, i)?;
406            }
407
408            // Validate terminator
409            if let Some(ref term) = block.terminator {
410                self.validate_terminator(term)?;
411            }
412        }
413
414        // Validate values
415        for (value, def) in &self.values {
416            self.validate_value(*value, def)?;
417        }
418
419        Ok(())
420    }
421
422    fn validate_instruction(
423        &self,
424        instr: &Instruction,
425        block_id: BlockId,
426        _index: usize,
427    ) -> Result<(), String> {
428        // Check operands exist
429        for &operand in &instr.operands {
430            if !self.values.contains_key(&operand) {
431                return Err(format!(
432                    "Operand {:?} not found in block {}",
433                    operand, block_id
434                ));
435            }
436        }
437
438        // Check result type consistency
439        if let Some(result) = instr.result {
440            if !self.values.contains_key(&result) {
441                return Err(format!(
442                    "Result {:?} not found in block {}",
443                    result, block_id
444                ));
445            }
446        }
447
448        Ok(())
449    }
450
451    fn validate_terminator(&self, term: &Terminator) -> Result<(), String> {
452        match term {
453            Terminator::Branch { target } => {
454                if !self.blocks.contains_key(target) {
455                    return Err(format!("Branch target block {} does not exist", target));
456                }
457            }
458            Terminator::CondBranch {
459                condition,
460                then_block,
461                else_block,
462            } => {
463                // Validate condition value exists
464                if !self.values.contains_key(condition) {
465                    return Err(format!("Condition value {:?} does not exist", condition));
466                }
467                // Validate target blocks exist
468                if !self.blocks.contains_key(then_block) {
469                    return Err(format!("Then block {} does not exist", then_block));
470                }
471                if !self.blocks.contains_key(else_block) {
472                    return Err(format!("Else block {} does not exist", else_block));
473                }
474            }
475            Terminator::Return { value } => {
476                if let Some(val) = value {
477                    if !self.values.contains_key(val) {
478                        return Err(format!("Return value {:?} does not exist", val));
479                    }
480                }
481            }
482            Terminator::Unreachable => {
483                // Nothing to validate for unreachable
484            }
485        }
486        Ok(())
487    }
488
489    fn validate_value(&self, _value: IrValue, def: &ValueDef) -> Result<(), String> {
490        // Check type exists
491        if !self.types.contains_key(&def.ty) {
492            return Err(format!("Type {:?} not found", def.ty));
493        }
494
495        Ok(())
496    }
497
498    /// Get a function-like interface for debugging compatibility
499    /// For now, treat the entry block as the main "function"
500    pub fn get_function(&self, _name: &str) -> Option<&BasicBlock> {
501        self.blocks.get(&self.entry_block)
502    }
503
504    /// Inline small functions (placeholder implementation)
505    pub fn inline_small_functions(&mut self) -> crate::JitResult<()> {
506        // Placeholder implementation for function inlining
507        // This would analyze the module and inline small functions
508        Ok(())
509    }
510
511    /// Get all functions in the module (returns an iterator over blocks as functions)
512    pub fn functions_mut(&mut self) -> impl Iterator<Item = &mut BasicBlock> {
513        self.blocks.values_mut()
514    }
515
516    /// Get all instructions in the module
517    pub fn instructions(&self) -> impl Iterator<Item = &Instruction> {
518        self.blocks
519            .values()
520            .flat_map(|block| block.instructions.iter())
521    }
522
523    /// Get all instructions in the module (mutable)
524    pub fn instructions_mut(&mut self) -> impl Iterator<Item = &mut Instruction> {
525        self.blocks
526            .values_mut()
527            .flat_map(|block| block.instructions.iter_mut())
528    }
529
530    /// Retain instructions that satisfy the predicate
531    pub fn retain_instructions<F>(&mut self, mut predicate: F)
532    where
533        F: FnMut(usize, &Instruction) -> bool,
534    {
535        let mut global_idx = 0;
536        for block in self.blocks.values_mut() {
537            block.instructions.retain(|instruction| {
538                let keep = predicate(global_idx, instruction);
539                global_idx += 1;
540                keep
541            });
542        }
543    }
544
545    /// Remove unused functions (placeholder implementation)
546    pub fn remove_unused_functions(&mut self) -> crate::JitResult<()> {
547        // Placeholder - in a real implementation, this would analyze function usage
548        // and remove unused function blocks
549        Ok(())
550    }
551}
552
553impl BasicBlock {
554    /// Get the instructions in this block
555    pub fn instructions(&self) -> &Vec<Instruction> {
556        &self.instructions
557    }
558}
559
560impl Instruction {
561    /// Check if this instruction produces a value
562    pub fn produces_value(&self) -> bool {
563        self.result.is_some()
564    }
565
566    /// Get the operands of this instruction
567    pub fn operands(&self) -> &Vec<IrValue> {
568        &self.operands
569    }
570}
571
/// IR builder for constructing IR modules.
///
/// The builder tracks a "current block" into which `add_instruction` appends.
/// It also memoizes `TypeKind` -> `IrType` lookups made through `get_type`.
pub struct IrBuilder {
    /// The module under construction. Public, so callers may read or modify
    /// it directly.
    pub module: IrModule,
    /// Block that receives new instructions (`None` until one is selected)
    current_block: Option<BlockId>,
    // Initialized in `new` but never read in this module, hence the lint allowance.
    #[allow(dead_code)]
    value_counter: u32,
    /// Cache of types already created through `get_type`
    type_cache: HashMap<TypeKind, IrType>,
}
580
581impl IrBuilder {
582    /// Create a new IR builder
583    pub fn new(module_name: String) -> Self {
584        Self {
585            module: IrModule::new(module_name),
586            current_block: None,
587            value_counter: 0,
588            type_cache: HashMap::new(),
589        }
590    }
591
592    /// Set current block for instruction insertion
593    pub fn set_current_block(&mut self, block: BlockId) {
594        self.current_block = Some(block);
595    }
596
597    /// Add a new basic block and set it as current
598    pub fn add_block(&mut self) -> BlockId {
599        let id = self.module.add_block();
600        self.current_block = Some(id);
601        id
602    }
603
604    /// Get or create a type
605    pub fn get_type(&mut self, kind: TypeKind) -> IrType {
606        if let Some(&existing) = self.type_cache.get(&kind) {
607            return existing;
608        }
609
610        let size = match &kind {
611            TypeKind::Void => Some(0),
612            TypeKind::Bool | TypeKind::I8 | TypeKind::U8 => Some(1),
613            TypeKind::I16 | TypeKind::U16 | TypeKind::F16 => Some(2),
614            TypeKind::I32 | TypeKind::U32 | TypeKind::F32 => Some(4),
615            TypeKind::I64 | TypeKind::U64 | TypeKind::F64 | TypeKind::C64 => Some(8),
616            TypeKind::C128 => Some(16),
617            _ => None,
618        };
619
620        let ty = self.module.add_type(TypeDef {
621            kind: kind.clone(),
622            size,
623            align: size,
624        });
625
626        self.type_cache.insert(kind, ty);
627        ty
628    }
629
630    /// Create a constant value
631    pub fn const_int(&mut self, value: i64, ty: IrType) -> IrValue {
632        let val_def = ValueDef {
633            ty,
634            source_node: None,
635            kind: ValueKind::Constant {
636                data: ConstantData::Int(value),
637            },
638        };
639        self.module.add_value(val_def)
640    }
641
642    /// Create a constant float
643    pub fn const_float(&mut self, value: f64, ty: IrType) -> IrValue {
644        let val_def = ValueDef {
645            ty,
646            source_node: None,
647            kind: ValueKind::Constant {
648                data: ConstantData::Float(value),
649            },
650        };
651        self.module.add_value(val_def)
652    }
653
654    /// Add an instruction to the current block
655    pub fn add_instruction(
656        &mut self,
657        opcode: IrOpcode,
658        operands: Vec<IrValue>,
659        result_type: Option<IrType>,
660    ) -> Option<IrValue> {
661        let current_block = self.current_block.expect("No current block set");
662
663        let result = if let Some(ty) = result_type {
664            let val_def = ValueDef {
665                ty,
666                source_node: None,
667                kind: ValueKind::Instruction {
668                    block: current_block,
669                    index: 0, // Will be updated
670                },
671            };
672            Some(self.module.add_value(val_def))
673        } else {
674            None
675        };
676
677        let instr = Instruction {
678            result,
679            opcode,
680            operands,
681            attrs: HashMap::new(),
682        };
683
684        if let Some(block) = self.module.get_block_mut(current_block) {
685            block.instructions.push(instr);
686        }
687
688        result
689    }
690
691    /// Set block terminator
692    pub fn set_terminator(&mut self, terminator: Terminator) {
693        if let Some(current_block) = self.current_block {
694            if let Some(block) = self.module.get_block_mut(current_block) {
695                block.terminator = Some(terminator);
696            }
697        }
698    }
699
700    /// Build the final IR module
701    pub fn build(self) -> IrModule {
702        self.module
703    }
704}
705
706/// Convert a torsh DType to IR type
707pub fn dtype_to_ir_type(dtype: DType) -> TypeKind {
708    match dtype {
709        DType::Bool => TypeKind::Bool,
710        DType::I8 => TypeKind::I8,
711        DType::I16 => TypeKind::I16,
712        DType::I32 => TypeKind::I32,
713        DType::I64 => TypeKind::I64,
714        DType::U8 => TypeKind::U8,
715        DType::U32 => TypeKind::U32,
716        DType::U64 => TypeKind::U64,
717        DType::F16 => TypeKind::F16,
718        DType::BF16 => TypeKind::F16, // Approximate
719        DType::F32 => TypeKind::F32,
720        DType::F64 => TypeKind::F64,
721        DType::C64 => TypeKind::C64,
722        DType::C128 => TypeKind::C128,
723        DType::QInt8 => TypeKind::I8,   // Quantized as regular int8
724        DType::QUInt8 => TypeKind::U8,  // Quantized as regular uint8
725        DType::QInt32 => TypeKind::I32, // Quantized as regular int32
726    }
727}
728
729/// Convert a shape and dtype to tensor type
730pub fn shape_dtype_to_tensor_type(shape: &Shape, dtype: DType) -> TypeKind {
731    let _element_type_kind = dtype_to_ir_type(dtype);
732    TypeKind::Tensor {
733        element: IrType(0), // Placeholder, will be resolved by builder
734        shape: shape.dims().to_vec(),
735    }
736}
737
// Compatibility type aliases for advanced features.
// NOTE(review): presumably these let code written against a function/instruction
// vocabulary use this module-based IR; confirm against callers.
/// A whole [`IrModule`] stands in for a function (compare `IrModule::get_function`).
pub type IrFunction = IrModule;
/// Alias for [`Instruction`].
pub type IrInstruction = Instruction;

// Additional compatibility types that might be needed
/// Result type with a `String` error, named for interprocedural analyses.
pub type InterproceduralResult<T> = Result<T, String>;
/// Result type with a `String` error, named for analyses.
pub type AnalysisResult<T> = Result<T, String>;
745
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ir_module_creation() {
        let module = IrModule::new("test".to_string());
        assert_eq!(module.name, "test");
        assert!(module.blocks.is_empty());
        assert!(module.values.is_empty());
    }

    #[test]
    fn test_ir_builder() {
        let mut builder = IrBuilder::new("test_module".to_string());

        // Create types
        let i32_type = builder.get_type(TypeKind::I32);
        let f32_type = builder.get_type(TypeKind::F32);

        // Create a block
        let block = builder.add_block();
        assert_eq!(block, 0);

        // Create constants. Use a value that is not close to a std constant;
        // `3.14` trips Clippy's deny-by-default `approx_constant` lint.
        let const1 = builder.const_int(42, i32_type);
        let const2 = builder.const_float(2.5, f32_type);

        // Add instruction
        let result = builder.add_instruction(IrOpcode::Add, vec![const1, const2], Some(f32_type));
        assert!(result.is_some());

        // Build module
        let module = builder.build();
        assert_eq!(module.name, "test_module");
        assert!(!module.blocks.is_empty());
        assert!(!module.values.is_empty());
        assert!(module.validate().is_ok());
    }

    #[test]
    fn test_module_validation() {
        let mut builder = IrBuilder::new("valid_module".to_string());
        let _block = builder.add_block();
        builder.set_terminator(Terminator::Return { value: None });

        let module = builder.build();
        assert!(module.validate().is_ok());
    }

    #[test]
    fn test_validation_rejects_missing_entry_block() {
        let module = IrModule::new("empty".to_string());
        assert!(module.validate().is_err());
    }

    #[test]
    fn test_validation_rejects_missing_branch_target() {
        let mut builder = IrBuilder::new("bad_branch".to_string());
        builder.add_block();
        builder.set_terminator(Terminator::Branch { target: 7 });

        let err = builder.build().validate().unwrap_err();
        assert!(err.contains("does not exist"));
    }

    #[test]
    fn test_types_are_deduplicated() {
        let mut builder = IrBuilder::new("types".to_string());
        let first = builder.get_type(TypeKind::F32);
        let second = builder.get_type(TypeKind::F32);
        assert_eq!(first, second);
        assert_eq!(builder.build().types.len(), 1);
    }
}