1use crate::{
7 CompiledModule, ComputationGraph, JitCompiler, JitConfig, JitError, JitResult, Node, NodeId,
8 ScriptableModule,
9};
10use petgraph::visit::EdgeRef;
11use std::collections::HashMap;
12use torsh_core::{DType, Shape};
13
14pub struct ScriptCompiler {
16 jit_compiler: JitCompiler,
17 type_annotations: HashMap<String, TypeAnnotation>,
18}
19
20impl ScriptCompiler {
21 pub fn new(config: JitConfig) -> Self {
23 Self {
24 jit_compiler: JitCompiler::new(config),
25 type_annotations: HashMap::new(),
26 }
27 }
28
29 pub fn script<M: ScriptableModule>(&mut self, module: M) -> JitResult<CompiledModule> {
31 let graph = module.to_graph()?;
33
34 let annotated_graph = self.apply_type_annotations(graph)?;
36
37 self.jit_compiler.compile(annotated_graph)
39 }
40
41 pub fn add_type_annotation(&mut self, name: String, annotation: TypeAnnotation) {
43 self.type_annotations.insert(name, annotation);
44 }
45
46 fn apply_type_annotations(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
48 let node_ids: Vec<_> = graph.nodes().map(|(id, _)| id).collect();
50 for node_id in node_ids {
51 if let Some(node) = graph.node(node_id) {
52 let node_name = node.name.clone();
53 if let Some(annotation) = self.type_annotations.get(&node_name) {
54 if let Some(node_mut) = graph.node_mut(node_id) {
55 match annotation {
56 TypeAnnotation::Tensor { dtype, shape } => {
57 node_mut.dtype = *dtype;
58 node_mut.output_shape = Shape::new(shape.clone());
59 }
60 TypeAnnotation::Scalar(dtype) => {
61 node_mut.dtype = *dtype;
62 node_mut.output_shape = Shape::new(vec![1]);
63 }
64 TypeAnnotation::List { element_type, size } => {
65 node_mut.attrs.insert(
67 "list_element_type".to_string(),
68 crate::graph::Attribute::String(format!("{:?}", element_type)),
69 );
70 node_mut.attrs.insert(
71 "list_size".to_string(),
72 crate::graph::Attribute::Int(*size as i64),
73 );
74 }
75 }
76 }
77 }
78 }
79 }
80
81 Ok(graph)
82 }
83}
84
85#[derive(Debug, Clone)]
87pub enum TypeAnnotation {
88 Tensor { dtype: DType, shape: Vec<usize> },
90 Scalar(DType),
92 List {
94 element_type: Box<TypeAnnotation>,
95 size: usize,
96 },
97}
98
99#[derive(Debug, Clone)]
101pub enum ScriptAst {
102 Function {
104 name: String,
105 params: Vec<Parameter>,
106 return_type: Option<TypeAnnotation>,
107 body: Box<ScriptAst>,
108 },
109 Let {
111 name: String,
112 type_ann: Option<TypeAnnotation>,
113 value: Box<ScriptAst>,
114 },
115 BinOp {
117 op: BinaryOp,
118 left: Box<ScriptAst>,
119 right: Box<ScriptAst>,
120 },
121 UnaryOp {
123 op: UnaryOp,
124 operand: Box<ScriptAst>,
125 },
126 Call { func: String, args: Vec<ScriptAst> },
128 If {
130 condition: Box<ScriptAst>,
131 then_branch: Box<ScriptAst>,
132 else_branch: Option<Box<ScriptAst>>,
133 },
134 For {
136 var: String,
137 iter: Box<ScriptAst>,
138 body: Box<ScriptAst>,
139 },
140 Block(Vec<ScriptAst>),
142 Var(String),
144 Literal(LiteralValue),
146 Return(Box<ScriptAst>),
148}
149
150#[derive(Debug, Clone)]
152pub struct Parameter {
153 pub name: String,
154 pub type_ann: TypeAnnotation,
155}
156
/// Binary operators that can appear in a [`ScriptAst::BinOp`] node.
#[derive(Debug, Clone)]
pub enum BinaryOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `**`
    Pow,
    /// `==`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `and`
    And,
    /// `or`
    Or,
}
174
/// Unary operators that can appear in a [`ScriptAst::UnaryOp`] node.
#[derive(Debug, Clone)]
pub enum UnaryOp {
    /// Arithmetic negation (`-x`).
    Neg,
    /// Logical negation (`not x`).
    Not,
}
181
/// Constant value written directly in the script source.
#[derive(Debug, Clone)]
pub enum LiteralValue {
    /// Integer literal.
    Int(i64),
    /// Floating-point literal.
    Float(f64),
    /// Boolean literal.
    Bool(bool),
    /// String literal (contents without the surrounding quotes).
    String(String),
}
190
191pub struct ScriptParser;
193
194impl ScriptParser {
195 pub fn parse(code: &str) -> JitResult<ScriptAst> {
197 let mut parser = PythonParser::new(code);
198 parser.parse()
199 }
200}
201
202pub struct PythonParser {
204 tokens: Vec<Token>,
205 current: usize,
206}
207
/// Lexical tokens recognised by [`PythonParser`].
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // --- Literals -------------------------------------------------------
    /// Integer literal.
    Integer(i64),
    /// Floating-point literal.
    Float(f64),
    /// Boolean literal; the tokenizer emits this for `True` and `False`.
    Boolean(bool),
    /// String literal (contents without the surrounding quotes).
    String(String),

    // --- Names ----------------------------------------------------------
    /// Any identifier that is not a keyword.
    Identifier(String),

    // --- Keywords -------------------------------------------------------
    /// `def`
    Def,
    /// `if`
    If,
    /// `else`
    Else,
    /// `for`
    For,
    /// `in`
    In,
    /// `return`
    Return,
    /// Keyword form of `True`; the tokenizer in this file emits
    /// [`Token::Boolean`] instead.
    True,
    /// Keyword form of `False`; the tokenizer in this file emits
    /// [`Token::Boolean`] instead.
    False,

    // --- Operators ------------------------------------------------------
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `**`
    DoubleStar,
    /// `=`
    Equal,
    /// `==`
    EqualEqual,
    /// `!=`
    NotEqual,
    /// `<`
    Less,
    /// `<=`
    LessEqual,
    /// `>`
    Greater,
    /// `>=`
    GreaterEqual,
    /// `and`
    And,
    /// `or`
    Or,
    /// `not`
    Not,

    // --- Delimiters -----------------------------------------------------
    /// `(`
    LeftParen,
    /// `)`
    RightParen,
    /// `[`
    LeftBracket,
    /// `]`
    RightBracket,
    /// `,`
    Comma,
    /// `:`
    Colon,
    /// `->`
    Arrow,

    // --- Layout ---------------------------------------------------------
    /// Line break (`\n`).
    Newline,
    /// Indentation increase; not produced by the tokenizer in this file.
    Indent,
    /// Indentation decrease; not produced by the tokenizer in this file.
    Dedent,
    /// End of input; always the last token of a tokenized script.
    Eof,
}
262
263impl PythonParser {
264 pub fn new(code: &str) -> Self {
266 let tokens = Self::tokenize(code);
267 Self { tokens, current: 0 }
268 }
269
270 fn tokenize(code: &str) -> Vec<Token> {
272 let mut tokens = Vec::new();
273 let mut chars = code.chars().peekable();
274
275 while let Some(&ch) = chars.peek() {
276 match ch {
277 ' ' | '\t' => {
278 chars.next();
279 }
280 '\n' => {
281 chars.next();
282 tokens.push(Token::Newline);
283 }
284 '(' => {
285 chars.next();
286 tokens.push(Token::LeftParen);
287 }
288 ')' => {
289 chars.next();
290 tokens.push(Token::RightParen);
291 }
292 '[' => {
293 chars.next();
294 tokens.push(Token::LeftBracket);
295 }
296 ']' => {
297 chars.next();
298 tokens.push(Token::RightBracket);
299 }
300 ',' => {
301 chars.next();
302 tokens.push(Token::Comma);
303 }
304 ':' => {
305 chars.next();
306 tokens.push(Token::Colon);
307 }
308 '+' => {
309 chars.next();
310 tokens.push(Token::Plus);
311 }
312 '-' => {
313 chars.next();
314 if chars.peek() == Some(&'>') {
315 chars.next();
316 tokens.push(Token::Arrow);
317 } else {
318 tokens.push(Token::Minus);
319 }
320 }
321 '*' => {
322 chars.next();
323 if chars.peek() == Some(&'*') {
324 chars.next();
325 tokens.push(Token::DoubleStar);
326 } else {
327 tokens.push(Token::Star);
328 }
329 }
330 '/' => {
331 chars.next();
332 tokens.push(Token::Slash);
333 }
334 '=' => {
335 chars.next();
336 if chars.peek() == Some(&'=') {
337 chars.next();
338 tokens.push(Token::EqualEqual);
339 } else {
340 tokens.push(Token::Equal);
341 }
342 }
343 '!' => {
344 chars.next();
345 if chars.peek() == Some(&'=') {
346 chars.next();
347 tokens.push(Token::NotEqual);
348 }
349 }
350 '<' => {
351 chars.next();
352 if chars.peek() == Some(&'=') {
353 chars.next();
354 tokens.push(Token::LessEqual);
355 } else {
356 tokens.push(Token::Less);
357 }
358 }
359 '>' => {
360 chars.next();
361 if chars.peek() == Some(&'=') {
362 chars.next();
363 tokens.push(Token::GreaterEqual);
364 } else {
365 tokens.push(Token::Greater);
366 }
367 }
368 '"' => {
369 chars.next();
370 let mut string_val = String::new();
371 while let Some(&ch) = chars.peek() {
372 if ch == '"' {
373 chars.next();
374 break;
375 }
376 string_val.push(ch);
377 chars.next();
378 }
379 tokens.push(Token::String(string_val));
380 }
381 c if c.is_ascii_digit() => {
382 let mut number = String::new();
383 let mut is_float = false;
384 while let Some(&ch) = chars.peek() {
385 if ch.is_ascii_digit() {
386 number.push(ch);
387 chars.next();
388 } else if ch == '.' && !is_float {
389 is_float = true;
390 number.push(ch);
391 chars.next();
392 } else {
393 break;
394 }
395 }
396
397 if is_float {
398 if let Ok(val) = number.parse::<f64>() {
399 tokens.push(Token::Float(val));
400 }
401 } else if let Ok(val) = number.parse::<i64>() {
402 tokens.push(Token::Integer(val));
403 }
404 }
405 c if c.is_ascii_alphabetic() || c == '_' => {
406 let mut ident = String::new();
407 while let Some(&ch) = chars.peek() {
408 if ch.is_ascii_alphanumeric() || ch == '_' {
409 ident.push(ch);
410 chars.next();
411 } else {
412 break;
413 }
414 }
415
416 let token = match ident.as_str() {
417 "def" => Token::Def,
418 "if" => Token::If,
419 "else" => Token::Else,
420 "for" => Token::For,
421 "in" => Token::In,
422 "return" => Token::Return,
423 "True" => Token::Boolean(true),
424 "False" => Token::Boolean(false),
425 "and" => Token::And,
426 "or" => Token::Or,
427 "not" => Token::Not,
428 _ => Token::Identifier(ident),
429 };
430 tokens.push(token);
431 }
432 _ => {
433 chars.next(); }
435 }
436 }
437
438 tokens.push(Token::Eof);
439 tokens
440 }
441
442 pub fn parse(&mut self) -> JitResult<ScriptAst> {
444 let mut statements = Vec::new();
445
446 while !self.is_at_end() {
447 if self.match_token(&Token::Newline) {
448 continue;
449 }
450 statements.push(self.parse_statement()?);
451 }
452
453 Ok(ScriptAst::Block(statements))
454 }
455
456 fn parse_statement(&mut self) -> JitResult<ScriptAst> {
458 if self.match_token(&Token::Def) {
459 self.parse_function()
460 } else if self.match_token(&Token::Return) {
461 let expr = self.parse_expression()?;
462 Ok(ScriptAst::Return(Box::new(expr)))
463 } else if self.match_token(&Token::If) {
464 self.parse_if()
465 } else if self.match_token(&Token::For) {
466 self.parse_for()
467 } else {
468 let expr = self.parse_expression()?;
470 if self.match_token(&Token::Equal) {
471 if let ScriptAst::Var(name) = expr {
472 let value = self.parse_expression()?;
473 Ok(ScriptAst::Let {
474 name,
475 type_ann: None,
476 value: Box::new(value),
477 })
478 } else {
479 Err(JitError::CompilationError(
480 "Invalid assignment target".to_string(),
481 ))
482 }
483 } else {
484 Ok(expr)
485 }
486 }
487 }
488
489 fn parse_function(&mut self) -> JitResult<ScriptAst> {
491 let name = if let Some(Token::Identifier(name)) = self.advance() {
492 name.clone()
493 } else {
494 return Err(JitError::CompilationError(
495 "Expected function name".to_string(),
496 ));
497 };
498
499 self.consume(&Token::LeftParen, "Expected '(' after function name")?;
500
501 let mut params = Vec::new();
502 while !self.check(&Token::RightParen) && !self.is_at_end() {
503 if let Some(Token::Identifier(param_name)) = self.advance() {
504 params.push(Parameter {
506 name: param_name.clone(),
507 type_ann: TypeAnnotation::Tensor {
508 dtype: DType::F32,
509 shape: vec![], },
511 });
512
513 if !self.check(&Token::RightParen) {
514 self.consume(&Token::Comma, "Expected ',' between parameters")?;
515 }
516 }
517 }
518
519 self.consume(&Token::RightParen, "Expected ')' after parameters")?;
520 self.consume(&Token::Colon, "Expected ':' after function signature")?;
521
522 let body = self.parse_block()?;
523
524 Ok(ScriptAst::Function {
525 name,
526 params,
527 return_type: None,
528 body: Box::new(body),
529 })
530 }
531
532 fn parse_if(&mut self) -> JitResult<ScriptAst> {
534 let condition = self.parse_expression()?;
535 self.consume(&Token::Colon, "Expected ':' after if condition")?;
536
537 let then_branch = self.parse_block()?;
538
539 let else_branch = if self.match_token(&Token::Else) {
540 self.consume(&Token::Colon, "Expected ':' after else")?;
541 Some(Box::new(self.parse_block()?))
542 } else {
543 None
544 };
545
546 Ok(ScriptAst::If {
547 condition: Box::new(condition),
548 then_branch: Box::new(then_branch),
549 else_branch,
550 })
551 }
552
553 fn parse_for(&mut self) -> JitResult<ScriptAst> {
555 let var = if let Some(Token::Identifier(name)) = self.advance() {
556 name.clone()
557 } else {
558 return Err(JitError::CompilationError(
559 "Expected variable name in for loop".to_string(),
560 ));
561 };
562
563 self.consume(&Token::In, "Expected 'in' in for loop")?;
564 let iter = self.parse_expression()?;
565 self.consume(&Token::Colon, "Expected ':' after for loop header")?;
566
567 let body = self.parse_block()?;
568
569 Ok(ScriptAst::For {
570 var,
571 iter: Box::new(iter),
572 body: Box::new(body),
573 })
574 }
575
576 fn parse_block(&mut self) -> JitResult<ScriptAst> {
578 let mut statements = Vec::new();
579
580 while !self.is_at_end() && !self.check(&Token::Else) && !self.check(&Token::Def) {
582 if self.match_token(&Token::Newline) {
583 continue;
584 }
585 statements.push(self.parse_statement()?);
586 break; }
588
589 Ok(ScriptAst::Block(statements))
590 }
591
592 fn parse_expression(&mut self) -> JitResult<ScriptAst> {
594 self.parse_or()
595 }
596
597 fn parse_or(&mut self) -> JitResult<ScriptAst> {
599 let mut expr = self.parse_and()?;
600
601 while self.match_token(&Token::Or) {
602 let right = self.parse_and()?;
603 expr = ScriptAst::BinOp {
604 op: BinaryOp::Or,
605 left: Box::new(expr),
606 right: Box::new(right),
607 };
608 }
609
610 Ok(expr)
611 }
612
613 fn parse_and(&mut self) -> JitResult<ScriptAst> {
615 let mut expr = self.parse_equality()?;
616
617 while self.match_token(&Token::And) {
618 let right = self.parse_equality()?;
619 expr = ScriptAst::BinOp {
620 op: BinaryOp::And,
621 left: Box::new(expr),
622 right: Box::new(right),
623 };
624 }
625
626 Ok(expr)
627 }
628
629 fn parse_equality(&mut self) -> JitResult<ScriptAst> {
631 let mut expr = self.parse_comparison()?;
632
633 while let Some(op) = self.match_equality_op() {
634 let right = self.parse_comparison()?;
635 expr = ScriptAst::BinOp {
636 op,
637 left: Box::new(expr),
638 right: Box::new(right),
639 };
640 }
641
642 Ok(expr)
643 }
644
645 fn parse_comparison(&mut self) -> JitResult<ScriptAst> {
647 let mut expr = self.parse_term()?;
648
649 while let Some(op) = self.match_comparison_op() {
650 let right = self.parse_term()?;
651 expr = ScriptAst::BinOp {
652 op,
653 left: Box::new(expr),
654 right: Box::new(right),
655 };
656 }
657
658 Ok(expr)
659 }
660
661 fn parse_term(&mut self) -> JitResult<ScriptAst> {
663 let mut expr = self.parse_factor()?;
664
665 while self.check(&Token::Plus) || self.check(&Token::Minus) {
666 let op = if self.match_token(&Token::Plus) {
667 BinaryOp::Add
668 } else {
669 self.advance();
670 BinaryOp::Sub
671 };
672
673 let right = self.parse_factor()?;
674 expr = ScriptAst::BinOp {
675 op,
676 left: Box::new(expr),
677 right: Box::new(right),
678 };
679 }
680
681 Ok(expr)
682 }
683
684 fn parse_factor(&mut self) -> JitResult<ScriptAst> {
686 let mut expr = self.parse_unary()?;
687
688 while self.check(&Token::Star)
689 || self.check(&Token::Slash)
690 || self.check(&Token::DoubleStar)
691 {
692 let op = if self.match_token(&Token::Star) {
693 BinaryOp::Mul
694 } else if self.match_token(&Token::Slash) {
695 BinaryOp::Div
696 } else {
697 self.advance();
698 BinaryOp::Pow
699 };
700
701 let right = self.parse_unary()?;
702 expr = ScriptAst::BinOp {
703 op,
704 left: Box::new(expr),
705 right: Box::new(right),
706 };
707 }
708
709 Ok(expr)
710 }
711
712 fn parse_unary(&mut self) -> JitResult<ScriptAst> {
714 if self.match_token(&Token::Not) {
715 let operand = self.parse_unary()?;
716 Ok(ScriptAst::UnaryOp {
717 op: UnaryOp::Not,
718 operand: Box::new(operand),
719 })
720 } else if self.match_token(&Token::Minus) {
721 let operand = self.parse_unary()?;
722 Ok(ScriptAst::UnaryOp {
723 op: UnaryOp::Neg,
724 operand: Box::new(operand),
725 })
726 } else {
727 self.parse_call()
728 }
729 }
730
731 fn parse_call(&mut self) -> JitResult<ScriptAst> {
733 let mut expr = self.parse_primary()?;
734
735 while self.match_token(&Token::LeftParen) {
736 let mut args = Vec::new();
737 while !self.check(&Token::RightParen) && !self.is_at_end() {
738 args.push(self.parse_expression()?);
739 if !self.check(&Token::RightParen) {
740 self.consume(&Token::Comma, "Expected ',' between arguments")?;
741 }
742 }
743 self.consume(&Token::RightParen, "Expected ')' after arguments")?;
744
745 if let ScriptAst::Var(func_name) = expr {
746 expr = ScriptAst::Call {
747 func: func_name,
748 args,
749 };
750 }
751 }
752
753 Ok(expr)
754 }
755
756 fn parse_primary(&mut self) -> JitResult<ScriptAst> {
758 if let Some(token) = self.advance() {
759 match token {
760 Token::Integer(val) => Ok(ScriptAst::Literal(LiteralValue::Int(*val))),
761 Token::Float(val) => Ok(ScriptAst::Literal(LiteralValue::Float(*val))),
762 Token::Boolean(val) => Ok(ScriptAst::Literal(LiteralValue::Bool(*val))),
763 Token::String(val) => Ok(ScriptAst::Literal(LiteralValue::String(val.clone()))),
764 Token::Identifier(name) => Ok(ScriptAst::Var(name.clone())),
765 Token::LeftParen => {
766 let expr = self.parse_expression()?;
767 self.consume(&Token::RightParen, "Expected ')' after expression")?;
768 Ok(expr)
769 }
770 _ => Err(JitError::CompilationError(
771 "Unexpected token in expression".to_string(),
772 )),
773 }
774 } else {
775 Err(JitError::CompilationError(
776 "Unexpected end of input".to_string(),
777 ))
778 }
779 }
780
781 fn match_equality_op(&mut self) -> Option<BinaryOp> {
783 if self.match_token(&Token::EqualEqual) {
784 Some(BinaryOp::Eq)
785 } else if self.match_token(&Token::NotEqual) {
786 Some(BinaryOp::Ne)
787 } else {
788 None
789 }
790 }
791
792 fn match_comparison_op(&mut self) -> Option<BinaryOp> {
794 if self.match_token(&Token::Greater) {
795 Some(BinaryOp::Gt)
796 } else if self.match_token(&Token::GreaterEqual) {
797 Some(BinaryOp::Ge)
798 } else if self.match_token(&Token::Less) {
799 Some(BinaryOp::Lt)
800 } else if self.match_token(&Token::LessEqual) {
801 Some(BinaryOp::Le)
802 } else {
803 None
804 }
805 }
806
807 fn match_token(&mut self, expected: &Token) -> bool {
809 if self.check(expected) {
810 self.advance();
811 true
812 } else {
813 false
814 }
815 }
816
817 fn check(&self, expected: &Token) -> bool {
818 if self.is_at_end() {
819 false
820 } else {
821 std::mem::discriminant(&self.tokens[self.current]) == std::mem::discriminant(expected)
822 }
823 }
824
825 fn advance(&mut self) -> Option<&Token> {
826 if !self.is_at_end() {
827 self.current += 1;
828 }
829 self.previous()
830 }
831
832 fn is_at_end(&self) -> bool {
833 self.current >= self.tokens.len()
834 || matches!(self.tokens.get(self.current), Some(Token::Eof))
835 }
836
837 fn previous(&self) -> Option<&Token> {
838 self.tokens.get(self.current.saturating_sub(1))
839 }
840
841 fn consume(&mut self, expected: &Token, message: &str) -> JitResult<()> {
842 if self.check(expected) {
843 self.advance();
844 Ok(())
845 } else {
846 Err(JitError::CompilationError(message.to_string()))
847 }
848 }
849}
850
851pub struct AstToGraphConverter {
853 graph: ComputationGraph,
854 var_map: HashMap<String, NodeId>,
855 next_id: usize,
856}
857
858impl Default for AstToGraphConverter {
859 fn default() -> Self {
860 Self::new()
861 }
862}
863
864impl AstToGraphConverter {
865 pub fn new() -> Self {
867 Self {
868 graph: ComputationGraph::new(),
869 var_map: HashMap::new(),
870 next_id: 0,
871 }
872 }
873
874 pub fn convert(&mut self, ast: ScriptAst) -> JitResult<ComputationGraph> {
876 self.convert_ast(ast)?;
877 Ok(self.graph.clone())
878 }
879
880 fn convert_ast(&mut self, ast: ScriptAst) -> JitResult<NodeId> {
882 match ast {
883 ScriptAst::BinOp { op, left, right } => {
884 let left_id = self.convert_ast(*left)?;
885 let right_id = self.convert_ast(*right)?;
886 self.create_binop_node(op, left_id, right_id)
887 }
888 ScriptAst::UnaryOp { op, operand } => {
889 let operand_id = self.convert_ast(*operand)?;
890 self.create_unaryop_node(op, operand_id)
891 }
892 ScriptAst::Call { func, args } => {
893 let arg_ids: Vec<_> = args
894 .into_iter()
895 .map(|arg| self.convert_ast(arg))
896 .collect::<JitResult<Vec<_>>>()?;
897 self.create_call_node(func, arg_ids)
898 }
899 ScriptAst::Var(name) => self
900 .var_map
901 .get(&name)
902 .copied()
903 .ok_or_else(|| JitError::GraphError(format!("Undefined variable: {}", name))),
904 ScriptAst::Literal(lit) => self.create_literal_node(lit),
905 ScriptAst::Let { name, value, .. } => {
906 let value_id = self.convert_ast(*value)?;
907 self.var_map.insert(name, value_id);
908 Ok(value_id)
909 }
910 ScriptAst::Block(stmts) => {
911 let mut last_id = None;
912 for stmt in stmts {
913 last_id = Some(self.convert_ast(stmt)?);
914 }
915 last_id.ok_or_else(|| JitError::GraphError("Empty block".to_string()))
916 }
917 _ => Err(JitError::GraphError("Unsupported AST node".to_string())),
918 }
919 }
920
921 fn create_binop_node(
923 &mut self,
924 op: BinaryOp,
925 left: NodeId,
926 right: NodeId,
927 ) -> JitResult<NodeId> {
928 use crate::graph::{Edge, Operation};
929 use torsh_core::DeviceType;
930
931 let operation = match op {
932 BinaryOp::Add => Operation::Add,
933 BinaryOp::Sub => Operation::Sub,
934 BinaryOp::Mul => Operation::Mul,
935 BinaryOp::Div => Operation::Div,
936 _ => return Err(JitError::UnsupportedOp(format!("{:?}", op))),
937 };
938
939 let mut node = Node::new(operation, format!("binop_{}", self.next_id));
940 node.device = DeviceType::Cpu;
941 node.inputs = vec![];
942 node.is_output = false;
943
944 let node_id = self.graph.add_node(node);
945 self.graph.add_edge(left, node_id, Edge::default());
946 self.graph.add_edge(right, node_id, Edge::default());
947 self.next_id += 1;
948 Ok(node_id)
949 }
950
951 fn create_unaryop_node(&mut self, op: UnaryOp, operand: NodeId) -> JitResult<NodeId> {
953 use crate::graph::{Edge, Operation};
954 use torsh_core::DeviceType;
955
956 let operation = match op {
957 UnaryOp::Neg => Operation::Neg,
958 _ => return Err(JitError::UnsupportedOp(format!("{:?}", op))),
959 };
960
961 let mut node = Node::new(operation, format!("unaryop_{}", self.next_id));
962 node.device = DeviceType::Cpu;
963 node.inputs = vec![];
964 node.is_output = false;
965
966 let node_id = self.graph.add_node(node);
967 self.graph.add_edge(operand, node_id, Edge::default());
968 self.next_id += 1;
969 Ok(node_id)
970 }
971
972 fn create_call_node(&mut self, func: String, args: Vec<NodeId>) -> JitResult<NodeId> {
974 use crate::graph::{Edge, Operation};
975 use torsh_core::DeviceType;
976
977 let operation = match func.as_str() {
978 "relu" => Operation::Relu,
979 "sigmoid" => Operation::Sigmoid,
980 "tanh" => Operation::Tanh,
981 "matmul" => Operation::MatMul,
982 _ => Operation::Custom(func),
983 };
984
985 let mut node = Node::new(operation, format!("call_{}", self.next_id));
986 node.device = DeviceType::Cpu;
987 node.inputs = vec![];
988 node.is_output = false;
989
990 let node_id = self.graph.add_node(node);
991 for (i, arg_id) in args.iter().enumerate() {
992 let edge = Edge {
993 src_output: 0,
994 dst_input: i,
995 };
996 self.graph.add_edge(*arg_id, node_id, edge);
997 }
998 self.next_id += 1;
999 Ok(node_id)
1000 }
1001
1002 fn create_literal_node(&mut self, lit: LiteralValue) -> JitResult<NodeId> {
1004 use crate::graph::{Attribute, ConstantInfo, ConstantValue, Operation};
1005 use torsh_core::DeviceType;
1006
1007 let (dtype, constant_value) = match lit {
1008 LiteralValue::Int(v) => (DType::I64, ConstantValue::IntScalar(v)),
1009 LiteralValue::Float(v) => (DType::F32, ConstantValue::Scalar(v)),
1010 LiteralValue::Bool(v) => (DType::Bool, ConstantValue::IntScalar(if v { 1 } else { 0 })),
1011 LiteralValue::String(v) => {
1012 let mut node = Node::new(
1014 Operation::Custom("string_literal".to_string()),
1015 format!("string_literal_{}", self.next_id),
1016 );
1017 node.device = DeviceType::Cpu;
1018 node.attrs.insert("value".to_string(), Attribute::String(v));
1019 node.inputs = vec![];
1020 node.is_output = false;
1021 let node_id = self.graph.add_node(node);
1022 self.next_id += 1;
1023 return Ok(node_id);
1024 }
1025 };
1026
1027 let mut node = Node::new(
1028 Operation::Constant(ConstantInfo {
1029 value: constant_value,
1030 }),
1031 format!("constant_{}", self.next_id),
1032 );
1033 node.device = DeviceType::Cpu;
1034 node.inputs = vec![];
1035 node.is_output = false;
1036
1037 let node_id = self.graph.add_node(node);
1038 self.next_id += 1;
1039 Ok(node_id)
1040 }
1041}
1042
1043pub fn export_torchscript(module: &CompiledModule, path: &str) -> JitResult<()> {
1045 use std::fs::File;
1046 use std::io::Write;
1047
1048 let ts_repr = TorchScriptModule {
1050 version: 1,
1051 graph: module.graph.clone(),
1052 constants: extract_constants_from_graph(&module.graph),
1053 metadata: create_metadata_from_module(module),
1054 };
1055
1056 let torchscript_ir = generate_torchscript_ir(&ts_repr)?;
1058
1059 let mut file = File::create(path)
1061 .map_err(|e| JitError::RuntimeError(format!("Failed to create file {}: {}", path, e)))?;
1062
1063 file.write_all(torchscript_ir.as_bytes())
1064 .map_err(|e| JitError::RuntimeError(format!("Failed to write file {}: {}", path, e)))?;
1065
1066 Ok(())
1067}
1068
1069pub fn import_torchscript(path: &str, config: JitConfig) -> JitResult<CompiledModule> {
1071 use std::fs::File;
1072 use std::io::Read;
1073
1074 let mut file = File::open(path)
1076 .map_err(|e| JitError::RuntimeError(format!("Failed to open file {}: {}", path, e)))?;
1077
1078 let mut contents = String::new();
1079 file.read_to_string(&mut contents)
1080 .map_err(|e| JitError::RuntimeError(format!("Failed to read file {}: {}", path, e)))?;
1081
1082 let ts_module = parse_torchscript_ir(&contents)?;
1084
1085 let mut jit_compiler = JitCompiler::new(config);
1087 let compiled_module = jit_compiler.compile(ts_module.graph)?;
1088
1089 Ok(compiled_module)
1090}
1091
1092#[derive(Debug, Clone)]
1094struct TorchScriptModule {
1095 version: u32,
1096 graph: ComputationGraph,
1097 constants: HashMap<String, Vec<f32>>,
1098 metadata: HashMap<String, String>,
1099}
1100
1101fn extract_constants_from_graph(graph: &ComputationGraph) -> HashMap<String, Vec<f32>> {
1103 use crate::graph::{ConstantValue, Operation};
1104
1105 let mut constants = HashMap::new();
1106
1107 for (node_id, node) in graph.nodes() {
1108 if let Operation::Constant(ref const_info) = node.op {
1109 let const_name = format!("const_{:?}", node_id);
1110 match &const_info.value {
1111 ConstantValue::Scalar(val) => {
1112 constants.insert(const_name, vec![*val as f32]);
1113 }
1114 ConstantValue::IntScalar(val) => {
1115 constants.insert(const_name, vec![*val as f32]);
1116 }
1117 ConstantValue::Tensor {
1118 shape: _,
1119 data,
1120 dtype: _,
1121 } => {
1122 constants.insert(const_name, data.iter().map(|&x| x as f32).collect());
1123 }
1124 ConstantValue::Bool(val) => {
1125 constants.insert(const_name, vec![if *val { 1.0 } else { 0.0 }]);
1126 }
1127 ConstantValue::Int(val) => {
1128 constants.insert(const_name, vec![*val as f32]);
1129 }
1130 ConstantValue::UInt(val) => {
1131 constants.insert(const_name, vec![*val as f32]);
1132 }
1133 ConstantValue::Float(val) => {
1134 constants.insert(const_name, vec![*val as f32]);
1135 }
1136 ConstantValue::String(_) => {
1137 constants.insert(const_name, vec![0.0]); }
1139 ConstantValue::FloatArray(arr) => {
1140 constants.insert(const_name, arr.clone());
1141 }
1142 ConstantValue::IntArray(arr) => {
1143 constants.insert(const_name, arr.iter().map(|&x| x as f32).collect());
1144 }
1145 ConstantValue::Array(arr) => {
1146 constants.insert(const_name, vec![arr.len() as f32]);
1148 }
1149 ConstantValue::Complex { real, imag: _ } => {
1150 constants.insert(const_name, vec![*real as f32]);
1151 }
1152 ConstantValue::None => {
1153 constants.insert(const_name, vec![0.0]);
1154 }
1155 ConstantValue::Undefined => {
1156 constants.insert(const_name, vec![0.0]);
1157 }
1158 }
1159 }
1160 }
1161
1162 constants
1163}
1164
1165fn create_metadata_from_module(module: &CompiledModule) -> HashMap<String, String> {
1167 let mut metadata = HashMap::new();
1168
1169 metadata.insert("producer".to_string(), "torsh-jit".to_string());
1170 metadata.insert("producer_version".to_string(), "0.1.0".to_string());
1171 metadata.insert("graph_name".to_string(), "main".to_string());
1172 metadata.insert(
1173 "node_count".to_string(),
1174 module.graph.node_count().to_string(),
1175 );
1176 metadata.insert(
1177 "edge_count".to_string(),
1178 module.graph.edge_count().to_string(),
1179 );
1180
1181 metadata
1182}
1183
1184fn generate_torchscript_ir(ts_module: &TorchScriptModule) -> JitResult<String> {
1186 use crate::graph::{ConstantValue, Operation};
1187
1188 let mut ir = String::new();
1189
1190 ir.push_str(&format!("graph():\n"));
1192
1193 for (name, values) in &ts_module.constants {
1195 ir.push_str(&format!(
1196 " %{} : Float({}) = prim::Constant[value={}]()\n",
1197 name,
1198 values.len(),
1199 format_tensor_values(values)
1200 ));
1201 }
1202
1203 let mut output_counter = 0;
1205 for (node_id, node) in ts_module.graph.nodes() {
1206 match &node.op {
1207 Operation::Add => {
1208 let inputs = get_node_inputs(&ts_module.graph, node_id);
1209 ir.push_str(&format!(
1210 " %{} : Float = aten::add({}, {})\n",
1211 output_counter, inputs[0], inputs[1]
1212 ));
1213 }
1214 Operation::Mul => {
1215 let inputs = get_node_inputs(&ts_module.graph, node_id);
1216 ir.push_str(&format!(
1217 " %{} : Float = aten::mul({}, {})\n",
1218 output_counter, inputs[0], inputs[1]
1219 ));
1220 }
1221 Operation::MatMul => {
1222 let inputs = get_node_inputs(&ts_module.graph, node_id);
1223 ir.push_str(&format!(
1224 " %{} : Float = aten::mm({}, {})\n",
1225 output_counter, inputs[0], inputs[1]
1226 ));
1227 }
1228 Operation::Relu => {
1229 let inputs = get_node_inputs(&ts_module.graph, node_id);
1230 ir.push_str(&format!(
1231 " %{} : Float = aten::relu({})\n",
1232 output_counter, inputs[0]
1233 ));
1234 }
1235 Operation::Sigmoid => {
1236 let inputs = get_node_inputs(&ts_module.graph, node_id);
1237 ir.push_str(&format!(
1238 " %{} : Float = aten::sigmoid({})\n",
1239 output_counter, inputs[0]
1240 ));
1241 }
1242 Operation::Constant(const_info) => match &const_info.value {
1243 ConstantValue::Scalar(val) => {
1244 ir.push_str(&format!(
1245 " %{} : Float = prim::Constant[value={}]()\n",
1246 output_counter, val
1247 ));
1248 }
1249 ConstantValue::IntScalar(val) => {
1250 ir.push_str(&format!(
1251 " %{} : int = prim::Constant[value={}]()\n",
1252 output_counter, val
1253 ));
1254 }
1255 ConstantValue::Tensor {
1256 shape: _,
1257 data,
1258 dtype: _,
1259 } => {
1260 let data_f32: Vec<f32> = data.iter().map(|&x| x as f32).collect();
1261 ir.push_str(&format!(
1262 " %{} : Float = prim::Constant[value={}]()\n",
1263 output_counter,
1264 format_tensor_values(&data_f32)
1265 ));
1266 }
1267 ConstantValue::Bool(val) => {
1268 ir.push_str(&format!(
1269 " %{} : bool = prim::Constant[value={}]()\n",
1270 output_counter, val
1271 ));
1272 }
1273 ConstantValue::Int(val) => {
1274 ir.push_str(&format!(
1275 " %{} : int = prim::Constant[value={}]()\n",
1276 output_counter, val
1277 ));
1278 }
1279 ConstantValue::UInt(val) => {
1280 ir.push_str(&format!(
1281 " %{} : int = prim::Constant[value={}]()\n",
1282 output_counter, val
1283 ));
1284 }
1285 ConstantValue::Float(val) => {
1286 ir.push_str(&format!(
1287 " %{} : Float = prim::Constant[value={}]()\n",
1288 output_counter, val
1289 ));
1290 }
1291 ConstantValue::String(val) => {
1292 ir.push_str(&format!(
1293 " %{} : str = prim::Constant[value=\"{}\"]()\n",
1294 output_counter, val
1295 ));
1296 }
1297 ConstantValue::FloatArray(arr) => {
1298 ir.push_str(&format!(
1299 " %{} : Float[] = prim::Constant[value={}]()\n",
1300 output_counter,
1301 format_tensor_values(arr)
1302 ));
1303 }
1304 ConstantValue::IntArray(arr) => {
1305 let arr_str = arr
1306 .iter()
1307 .map(|x| x.to_string())
1308 .collect::<Vec<_>>()
1309 .join(", ");
1310 ir.push_str(&format!(
1311 " %{} : int[] = prim::Constant[value=[{}]]()\n",
1312 output_counter, arr_str
1313 ));
1314 }
1315 ConstantValue::Array(_) => {
1316 ir.push_str(&format!(
1317 " %{} : Tensor = prim::Constant[value=<complex_array>]()\n",
1318 output_counter
1319 ));
1320 }
1321 ConstantValue::Complex { real, imag } => {
1322 ir.push_str(&format!(
1323 " %{} : complex = prim::Constant[value={}+{}i]()\n",
1324 output_counter, real, imag
1325 ));
1326 }
1327 ConstantValue::None => {
1328 ir.push_str(&format!(
1329 " %{} : NoneType = prim::Constant[value=None]()\n",
1330 output_counter
1331 ));
1332 }
1333 ConstantValue::Undefined => {
1334 ir.push_str(&format!(
1335 " %{} : Tensor = prim::Constant[value=<undefined>]()\n",
1336 output_counter
1337 ));
1338 }
1339 },
1340 Operation::Custom(name) => {
1341 let inputs = get_node_inputs(&ts_module.graph, node_id);
1342 let input_str = inputs.join(", ");
1343 ir.push_str(&format!(
1344 " %{} : Float = custom::{}({})\n",
1345 output_counter, name, input_str
1346 ));
1347 }
1348 _ => {
1349 let inputs = get_node_inputs(&ts_module.graph, node_id);
1351 let input_str = inputs.join(", ");
1352 ir.push_str(&format!(
1353 " %{} : Float = aten::{:?}({})\n",
1354 output_counter, node.op, input_str
1355 ));
1356 }
1357 }
1358 output_counter += 1;
1359 }
1360
1361 if output_counter > 0 {
1363 ir.push_str(&format!(" return (%{})\n", output_counter - 1));
1364 } else {
1365 ir.push_str(" return ()\n");
1366 }
1367
1368 Ok(ir)
1369}
1370
1371fn parse_torchscript_ir(ir: &str) -> JitResult<TorchScriptModule> {
1373 let mut graph = ComputationGraph::new();
1374 let mut constants = HashMap::new();
1375 let mut metadata = HashMap::new();
1376
1377 let lines: Vec<&str> = ir.lines().collect();
1379 let mut node_counter = 0;
1380
1381 for line in lines {
1382 let line = line.trim();
1383
1384 if line.starts_with('%') && line.contains("prim::Constant") {
1385 if let Some(value_start) = line.find("value=") {
1387 let value_part = &line[value_start + 6..];
1388 if let Some(value_end) = value_part.find(']') {
1389 let value_str = &value_part[..value_end];
1390 if let Ok(val) = value_str.parse::<f32>() {
1391 let const_name = format!("const_{}", node_counter);
1392 constants.insert(const_name, vec![val]);
1393
1394 add_constant_node_to_graph(&mut graph, val, node_counter);
1396 node_counter += 1;
1397 }
1398 }
1399 }
1400 } else if line.starts_with('%') && line.contains("aten::") {
1401 parse_aten_operation(&mut graph, line, node_counter)?;
1403 node_counter += 1;
1404 }
1405 }
1406
1407 metadata.insert("producer".to_string(), "torchscript".to_string());
1409 metadata.insert("version".to_string(), "1.0".to_string());
1410
1411 Ok(TorchScriptModule {
1412 version: 1,
1413 graph,
1414 constants,
1415 metadata,
1416 })
1417}
1418
/// Renders a slice of `f32` values for embedding in the IR text.
///
/// A slice holding exactly one value is written as that bare number. Any other
/// length (including an empty slice) is written as a bracketed,
/// comma-separated list such as `[1, 2.5]`.
fn format_tensor_values(values: &[f32]) -> String {
    match values {
        [only] => only.to_string(),
        several => {
            let rendered: Vec<String> = several.iter().map(ToString::to_string).collect();
            format!("[{}]", rendered.join(", "))
        }
    }
}
1434
1435fn get_node_inputs(graph: &ComputationGraph, node_id: NodeId) -> Vec<String> {
1437 let mut inputs = Vec::new();
1438
1439 for edge in graph.edges_directed(node_id, petgraph::Direction::Incoming) {
1440 let src_id = edge.source();
1441 inputs.push(format!("%{:?}", src_id));
1442 }
1443
1444 if inputs.is_empty() {
1446 inputs.push(format!("%input_{:?}", node_id));
1447 }
1448
1449 inputs
1450}
1451
1452fn add_constant_node_to_graph(graph: &mut ComputationGraph, value: f32, node_id: usize) {
1454 use crate::graph::{ConstantInfo, ConstantValue, Operation};
1455 use torsh_core::DeviceType;
1456
1457 let mut node = Node::new(
1458 Operation::Constant(ConstantInfo {
1459 value: ConstantValue::Scalar(value as f64),
1460 }),
1461 format!("const_{}", node_id),
1462 );
1463 node = node
1464 .with_output_shapes(vec![Some(Shape::new(vec![1]))])
1465 .with_dtypes(vec![DType::F32])
1466 .with_device(DeviceType::Cpu);
1467 node.inputs = vec![];
1468 node.is_output = false;
1469
1470 graph.add_node(node);
1471}
1472
1473fn parse_aten_operation(graph: &mut ComputationGraph, line: &str, node_id: usize) -> JitResult<()> {
1475 use crate::graph::Operation;
1476 use torsh_core::DeviceType;
1477
1478 let operation = if line.contains("aten::add") {
1479 Operation::Add
1480 } else if line.contains("aten::mul") {
1481 Operation::Mul
1482 } else if line.contains("aten::mm") {
1483 Operation::MatMul
1484 } else if line.contains("aten::relu") {
1485 Operation::Relu
1486 } else if line.contains("aten::sigmoid") {
1487 Operation::Sigmoid
1488 } else {
1489 if let Some(op_start) = line.find("aten::") {
1491 let op_part = &line[op_start + 6..];
1492 if let Some(op_end) = op_part.find('(') {
1493 let op_name = &op_part[..op_end];
1494 Operation::Custom(op_name.to_string())
1495 } else {
1496 Operation::Custom("unknown".to_string())
1497 }
1498 } else {
1499 Operation::Custom("unknown".to_string())
1500 }
1501 };
1502
1503 let mut node = Node::new(operation, format!("op_{}", node_id));
1504 node = node
1505 .with_output_shapes(vec![Some(Shape::new(vec![]))]) .with_dtypes(vec![DType::F32])
1507 .with_device(DeviceType::Cpu);
1508 node.inputs = vec![];
1509 node.is_output = false;
1510
1511 graph.add_node(node);
1512 Ok(())
1513}
1514
1515pub fn script<M: ScriptableModule>(module: M) -> JitResult<CompiledModule> {
1517 let config = JitConfig::default();
1518 let mut compiler = ScriptCompiler::new(config);
1519 compiler.script(module)
1520}
1521
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Tensor` annotation keeps the dtype and shape it was built with.
    #[test]
    fn test_type_annotation() {
        let tensor_ann = TypeAnnotation::Tensor {
            dtype: DType::F32,
            shape: vec![10, 20],
        };

        if let TypeAnnotation::Tensor { dtype, shape } = tensor_ann {
            assert_eq!(dtype, DType::F32);
            assert_eq!(shape, vec![10, 20]);
        } else {
            panic!("Wrong type annotation");
        }
    }

    /// Converting a float literal AST node succeeds.
    #[test]
    fn test_ast_to_graph_converter() {
        let mut converter = AstToGraphConverter::new();

        // The assertion only checks that conversion succeeds, so the exact value
        // is arbitrary; `3.14` is avoided because it trips Clippy's
        // deny-by-default `approx_constant` lint.
        let lit_ast = ScriptAst::Literal(LiteralValue::Float(2.5));
        assert!(converter.convert(lit_ast).is_ok());
    }

    /// A freshly created compiler starts without any type annotations.
    #[test]
    fn test_script_compiler_creation() {
        let compiler = ScriptCompiler::new(JitConfig::default());
        assert!(compiler.type_annotations.is_empty());
    }
}