1use crate::graph::NodeId;
7use indexmap::IndexMap;
8use std::collections::HashMap;
9use torsh_core::{DType, Shape};
10
11#[derive(Debug, Clone)]
13pub struct IrModule {
14 pub name: String,
16
17 pub inputs: Vec<IrValue>,
19
20 pub outputs: Vec<IrValue>,
22
23 pub blocks: IndexMap<BlockId, BasicBlock>,
25
26 pub entry_block: BlockId,
28
29 pub values: IndexMap<IrValue, ValueDef>,
31
32 pub types: IndexMap<IrType, TypeDef>,
34}
35
/// Identifier of a basic block inside a module.
///
/// `IrModule::add_block` hands these out as the number of blocks that already
/// exist, so the first block of a fresh module gets id `0`.
pub type BlockId = u32;
38
/// Handle naming a value; a thin, copyable wrapper around a `u32` index.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct IrValue(pub u32);
42
/// Handle naming a type; a thin, copyable wrapper around a `u32` index.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct IrType(pub u32);
46
47#[derive(Debug, Clone)]
49pub struct BasicBlock {
50 pub id: BlockId,
52
53 pub params: Vec<IrValue>,
55
56 pub instructions: Vec<Instruction>,
58
59 pub terminator: Option<Terminator>,
61}
62
63#[derive(Debug, Clone)]
65pub struct Instruction {
66 pub result: Option<IrValue>,
68
69 pub opcode: IrOpcode,
71
72 pub operands: Vec<IrValue>,
74
75 pub attrs: HashMap<String, IrAttribute>,
77}
78
/// Operation code carried by an instruction.
///
/// `Eq` is derived alongside `PartialEq` and `Hash` so that opcodes can be
/// used as `HashMap` / `HashSet` keys; every variant payload (`String`) is
/// itself `Eq`, so equality is a true equivalence relation.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum IrOpcode {
    // Arithmetic.
    Add,
    Sub,
    Mul,
    Div,
    Rem,
    Neg,
    Abs,

    // Elementwise math functions.
    Exp,
    Log,
    Sqrt,
    Sin,
    Cos,
    Tanh,
    Sigmoid,

    // Logical operations.
    And,
    Or,
    Not,
    Xor,

    // Comparisons.
    Eq,
    Ne,
    Lt,
    Le,
    Gt,
    Ge,

    // Memory operations.
    Load,
    Store,
    Alloca,

    // Shape manipulation.
    Reshape,
    Transpose,
    Slice,
    Concat,
    Split,

    // Linear algebra and neural-network kernels.
    MatMul,
    Conv2d,
    Pool2d,

    // Reductions.
    Sum,
    Mean,
    Max,
    Min,
    Argmax,
    Argmin,

    // Activations.
    Relu,
    Gelu,
    Softmax,

    // Control flow.
    Br,
    CondBr,
    Call,
    Return,

    // Constant materialisation.
    Const,

    // Type conversions.
    Cast,
    Bitcast,

    /// An operation identified only by the name it carries.
    Intrinsic(String),

    // No operation.
    Nop,
}
163
164#[derive(Debug, Clone)]
166pub enum Terminator {
167 Branch { target: BlockId },
169
170 CondBranch {
172 condition: IrValue,
173 then_block: BlockId,
174 else_block: BlockId,
175 },
176
177 Return { value: Option<IrValue> },
179
180 Unreachable,
182}
183
184#[derive(Debug, Clone)]
186pub struct ValueDef {
187 pub ty: IrType,
189
190 pub source_node: Option<NodeId>,
192
193 pub kind: ValueKind,
195}
196
197#[derive(Debug, Clone)]
199pub enum ValueKind {
200 Parameter { index: usize },
202
203 Instruction { block: BlockId, index: usize },
205
206 Constant { data: ConstantData },
208
209 Undef,
211}
212
/// Payload of a constant value.
#[derive(Clone, Debug)]
pub enum ConstantData {
    /// Signed 64-bit integer.
    Int(i64),

    /// 64-bit floating-point number.
    Float(f64),

    /// Boolean.
    Bool(bool),

    /// Owned string.
    String(String),

    /// Sequence of nested constants.
    Array(Vec<ConstantData>),

    /// Tensor given as a `shape` plus `f32` elements in `data`.
    Tensor { shape: Vec<usize>, data: Vec<f32> },
}
234
235#[derive(Debug, Clone)]
237pub struct TypeDef {
238 pub kind: TypeKind,
240
241 pub size: Option<usize>,
243
244 pub align: Option<usize>,
246}
247
248#[derive(Debug, Clone, PartialEq, Eq, Hash)]
250pub enum TypeKind {
251 Void,
253
254 Bool,
256
257 I8,
259 I16,
260 I32,
261 I64,
262
263 U8,
265 U16,
266 U32,
267 U64,
268
269 F16,
271 F32,
272 F64,
273
274 C64,
276 C128,
277
278 Pointer {
280 pointee: IrType,
281 },
282
283 Array {
285 element: IrType,
286 length: usize,
287 },
288
289 Tensor {
291 element: IrType,
292 shape: Vec<usize>,
293 },
294
295 Function {
297 params: Vec<IrType>,
298 return_type: Option<IrType>,
299 },
300
301 Struct {
303 fields: Vec<IrType>,
304 },
305}
306
/// Value of a named attribute attached to an instruction.
#[derive(Clone, Debug)]
pub enum IrAttribute {
    /// Signed 64-bit integer.
    Int(i64),
    /// 64-bit floating-point number.
    Float(f64),
    /// Owned string.
    String(String),
    /// Boolean.
    Bool(bool),
    /// Sequence of nested attributes.
    Array(Vec<IrAttribute>),
}
316
317impl IrModule {
318 pub fn new(name: String) -> Self {
320 Self {
321 name,
322 inputs: Vec::new(),
323 outputs: Vec::new(),
324 blocks: IndexMap::new(),
325 entry_block: 0,
326 values: IndexMap::new(),
327 types: IndexMap::new(),
328 }
329 }
330
331 pub fn add_block(&mut self) -> BlockId {
333 let id = self.blocks.len() as BlockId;
334 let block = BasicBlock {
335 id,
336 params: Vec::new(),
337 instructions: Vec::new(),
338 terminator: None,
339 };
340 self.blocks.insert(id, block);
341 id
342 }
343
344 pub fn add_value(&mut self, def: ValueDef) -> IrValue {
346 let id = IrValue(self.values.len() as u32);
347 self.values.insert(id, def);
348 id
349 }
350
351 pub fn add_value_external(&mut self, def: ValueDef) -> IrValue {
353 self.add_value(def)
354 }
355
356 pub fn add_type(&mut self, def: TypeDef) -> IrType {
358 for (&existing_id, existing_def) in &self.types {
360 if existing_def.kind == def.kind {
361 return existing_id;
362 }
363 }
364
365 let id = IrType(self.types.len() as u32);
366 self.types.insert(id, def);
367 id
368 }
369
370 pub fn get_block(&self, id: BlockId) -> Option<&BasicBlock> {
372 self.blocks.get(&id)
373 }
374
375 pub fn get_block_mut(&mut self, id: BlockId) -> Option<&mut BasicBlock> {
377 self.blocks.get_mut(&id)
378 }
379
380 pub fn get_value(&self, value: IrValue) -> Option<&ValueDef> {
382 self.values.get(&value)
383 }
384
385 pub fn get_type(&self, ty: IrType) -> Option<&TypeDef> {
387 self.types.get(&ty)
388 }
389
390 pub fn validate(&self) -> Result<(), String> {
392 if !self.blocks.contains_key(&self.entry_block) {
394 return Err("Entry block does not exist".to_string());
395 }
396
397 for (id, block) in &self.blocks {
399 if *id != block.id {
400 return Err(format!("Block ID mismatch: {} != {}", id, block.id));
401 }
402
403 for (i, instr) in block.instructions.iter().enumerate() {
405 self.validate_instruction(instr, *id, i)?;
406 }
407
408 if let Some(ref term) = block.terminator {
410 self.validate_terminator(term)?;
411 }
412 }
413
414 for (value, def) in &self.values {
416 self.validate_value(*value, def)?;
417 }
418
419 Ok(())
420 }
421
422 fn validate_instruction(
423 &self,
424 instr: &Instruction,
425 block_id: BlockId,
426 _index: usize,
427 ) -> Result<(), String> {
428 for &operand in &instr.operands {
430 if !self.values.contains_key(&operand) {
431 return Err(format!(
432 "Operand {:?} not found in block {}",
433 operand, block_id
434 ));
435 }
436 }
437
438 if let Some(result) = instr.result {
440 if !self.values.contains_key(&result) {
441 return Err(format!(
442 "Result {:?} not found in block {}",
443 result, block_id
444 ));
445 }
446 }
447
448 Ok(())
449 }
450
451 fn validate_terminator(&self, term: &Terminator) -> Result<(), String> {
452 match term {
453 Terminator::Branch { target } => {
454 if !self.blocks.contains_key(target) {
455 return Err(format!("Branch target block {} does not exist", target));
456 }
457 }
458 Terminator::CondBranch {
459 condition,
460 then_block,
461 else_block,
462 } => {
463 if !self.values.contains_key(condition) {
465 return Err(format!("Condition value {:?} does not exist", condition));
466 }
467 if !self.blocks.contains_key(then_block) {
469 return Err(format!("Then block {} does not exist", then_block));
470 }
471 if !self.blocks.contains_key(else_block) {
472 return Err(format!("Else block {} does not exist", else_block));
473 }
474 }
475 Terminator::Return { value } => {
476 if let Some(val) = value {
477 if !self.values.contains_key(val) {
478 return Err(format!("Return value {:?} does not exist", val));
479 }
480 }
481 }
482 Terminator::Unreachable => {
483 }
485 }
486 Ok(())
487 }
488
489 fn validate_value(&self, _value: IrValue, def: &ValueDef) -> Result<(), String> {
490 if !self.types.contains_key(&def.ty) {
492 return Err(format!("Type {:?} not found", def.ty));
493 }
494
495 Ok(())
496 }
497
498 pub fn get_function(&self, _name: &str) -> Option<&BasicBlock> {
501 self.blocks.get(&self.entry_block)
502 }
503
504 pub fn inline_small_functions(&mut self) -> crate::JitResult<()> {
506 Ok(())
509 }
510
511 pub fn functions_mut(&mut self) -> impl Iterator<Item = &mut BasicBlock> {
513 self.blocks.values_mut()
514 }
515
516 pub fn instructions(&self) -> impl Iterator<Item = &Instruction> {
518 self.blocks
519 .values()
520 .flat_map(|block| block.instructions.iter())
521 }
522
523 pub fn instructions_mut(&mut self) -> impl Iterator<Item = &mut Instruction> {
525 self.blocks
526 .values_mut()
527 .flat_map(|block| block.instructions.iter_mut())
528 }
529
530 pub fn retain_instructions<F>(&mut self, mut predicate: F)
532 where
533 F: FnMut(usize, &Instruction) -> bool,
534 {
535 let mut global_idx = 0;
536 for block in self.blocks.values_mut() {
537 block.instructions.retain(|instruction| {
538 let keep = predicate(global_idx, instruction);
539 global_idx += 1;
540 keep
541 });
542 }
543 }
544
545 pub fn remove_unused_functions(&mut self) -> crate::JitResult<()> {
547 Ok(())
550 }
551}
552
553impl BasicBlock {
554 pub fn instructions(&self) -> &Vec<Instruction> {
556 &self.instructions
557 }
558}
559
560impl Instruction {
561 pub fn produces_value(&self) -> bool {
563 self.result.is_some()
564 }
565
566 pub fn operands(&self) -> &Vec<IrValue> {
568 &self.operands
569 }
570}
571
572pub struct IrBuilder {
574 pub module: IrModule,
575 current_block: Option<BlockId>,
576 #[allow(dead_code)]
577 value_counter: u32,
578 type_cache: HashMap<TypeKind, IrType>,
579}
580
581impl IrBuilder {
582 pub fn new(module_name: String) -> Self {
584 Self {
585 module: IrModule::new(module_name),
586 current_block: None,
587 value_counter: 0,
588 type_cache: HashMap::new(),
589 }
590 }
591
592 pub fn set_current_block(&mut self, block: BlockId) {
594 self.current_block = Some(block);
595 }
596
597 pub fn add_block(&mut self) -> BlockId {
599 let id = self.module.add_block();
600 self.current_block = Some(id);
601 id
602 }
603
604 pub fn get_type(&mut self, kind: TypeKind) -> IrType {
606 if let Some(&existing) = self.type_cache.get(&kind) {
607 return existing;
608 }
609
610 let size = match &kind {
611 TypeKind::Void => Some(0),
612 TypeKind::Bool | TypeKind::I8 | TypeKind::U8 => Some(1),
613 TypeKind::I16 | TypeKind::U16 | TypeKind::F16 => Some(2),
614 TypeKind::I32 | TypeKind::U32 | TypeKind::F32 => Some(4),
615 TypeKind::I64 | TypeKind::U64 | TypeKind::F64 | TypeKind::C64 => Some(8),
616 TypeKind::C128 => Some(16),
617 _ => None,
618 };
619
620 let ty = self.module.add_type(TypeDef {
621 kind: kind.clone(),
622 size,
623 align: size,
624 });
625
626 self.type_cache.insert(kind, ty);
627 ty
628 }
629
630 pub fn const_int(&mut self, value: i64, ty: IrType) -> IrValue {
632 let val_def = ValueDef {
633 ty,
634 source_node: None,
635 kind: ValueKind::Constant {
636 data: ConstantData::Int(value),
637 },
638 };
639 self.module.add_value(val_def)
640 }
641
642 pub fn const_float(&mut self, value: f64, ty: IrType) -> IrValue {
644 let val_def = ValueDef {
645 ty,
646 source_node: None,
647 kind: ValueKind::Constant {
648 data: ConstantData::Float(value),
649 },
650 };
651 self.module.add_value(val_def)
652 }
653
654 pub fn add_instruction(
656 &mut self,
657 opcode: IrOpcode,
658 operands: Vec<IrValue>,
659 result_type: Option<IrType>,
660 ) -> Option<IrValue> {
661 let current_block = self.current_block.expect("No current block set");
662
663 let result = if let Some(ty) = result_type {
664 let val_def = ValueDef {
665 ty,
666 source_node: None,
667 kind: ValueKind::Instruction {
668 block: current_block,
669 index: 0, },
671 };
672 Some(self.module.add_value(val_def))
673 } else {
674 None
675 };
676
677 let instr = Instruction {
678 result,
679 opcode,
680 operands,
681 attrs: HashMap::new(),
682 };
683
684 if let Some(block) = self.module.get_block_mut(current_block) {
685 block.instructions.push(instr);
686 }
687
688 result
689 }
690
691 pub fn set_terminator(&mut self, terminator: Terminator) {
693 if let Some(current_block) = self.current_block {
694 if let Some(block) = self.module.get_block_mut(current_block) {
695 block.terminator = Some(terminator);
696 }
697 }
698 }
699
700 pub fn build(self) -> IrModule {
702 self.module
703 }
704}
705
706pub fn dtype_to_ir_type(dtype: DType) -> TypeKind {
708 match dtype {
709 DType::Bool => TypeKind::Bool,
710 DType::I8 => TypeKind::I8,
711 DType::I16 => TypeKind::I16,
712 DType::I32 => TypeKind::I32,
713 DType::I64 => TypeKind::I64,
714 DType::U8 => TypeKind::U8,
715 DType::U32 => TypeKind::U32,
716 DType::U64 => TypeKind::U64,
717 DType::F16 => TypeKind::F16,
718 DType::BF16 => TypeKind::F16, DType::F32 => TypeKind::F32,
720 DType::F64 => TypeKind::F64,
721 DType::C64 => TypeKind::C64,
722 DType::C128 => TypeKind::C128,
723 DType::QInt8 => TypeKind::I8, DType::QUInt8 => TypeKind::U8, DType::QInt32 => TypeKind::I32, }
727}
728
729pub fn shape_dtype_to_tensor_type(shape: &Shape, dtype: DType) -> TypeKind {
731 let _element_type_kind = dtype_to_ir_type(dtype);
732 TypeKind::Tensor {
733 element: IrType(0), shape: shape.dims().to_vec(),
735 }
736}
737
738pub type IrFunction = IrModule;
740pub type IrInstruction = Instruction;
741
/// Result type whose error is a plain message string.
pub type InterproceduralResult<T> = std::result::Result<T, String>;
/// Result type whose error is a plain message string.
pub type AnalysisResult<T> = std::result::Result<T, String>;
745
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh module keeps its name and starts without blocks or values.
    #[test]
    fn test_ir_module_creation() {
        let fresh = IrModule::new(String::from("test"));

        assert_eq!(fresh.name, "test");
        assert!(fresh.blocks.is_empty());
        assert!(fresh.values.is_empty());
    }

    /// Types, a block, constants and one instruction can be added through
    /// the builder, and the built module reflects them.
    #[test]
    fn test_ir_builder() {
        let mut builder = IrBuilder::new(String::from("test_module"));

        let int_ty = builder.get_type(TypeKind::I32);
        let float_ty = builder.get_type(TypeKind::F32);

        // The first block of a new module gets id 0.
        let first_block = builder.add_block();
        assert_eq!(first_block, 0);

        let lhs = builder.const_int(42, int_ty);
        let rhs = builder.const_float(3.14, float_ty);

        let sum = builder.add_instruction(IrOpcode::Add, vec![lhs, rhs], Some(float_ty));
        assert!(sum.is_some());

        let built = builder.build();
        assert_eq!(built.name, "test_module");
        assert!(!built.blocks.is_empty());
        assert!(!built.values.is_empty());
    }

    /// A single block ending in a bare return validates successfully.
    #[test]
    fn test_module_validation() {
        let mut builder = IrBuilder::new(String::from("valid_module"));
        let _entry = builder.add_block();
        builder.set_terminator(Terminator::Return { value: None });

        let built = builder.build();
        assert!(built.validate().is_ok());
    }
}