Skip to main content

torsh_jit/
script.rs

1//! Script mode for JIT compilation
2//!
3//! This module implements TorchScript-style script mode compilation,
4//! allowing models to be exported and optimized without tracing.
5
6use crate::{
7    CompiledModule, ComputationGraph, JitCompiler, JitConfig, JitError, JitResult, Node, NodeId,
8    ScriptableModule,
9};
10use petgraph::visit::EdgeRef;
11use std::collections::HashMap;
12use torsh_core::{DType, Shape};
13
14/// Script compiler for converting modules to JIT-compiled form
15pub struct ScriptCompiler {
16    jit_compiler: JitCompiler,
17    type_annotations: HashMap<String, TypeAnnotation>,
18}
19
20impl ScriptCompiler {
21    /// Create a new script compiler
22    pub fn new(config: JitConfig) -> Self {
23        Self {
24            jit_compiler: JitCompiler::new(config),
25            type_annotations: HashMap::new(),
26        }
27    }
28
29    /// Script a module into a compiled module
30    pub fn script<M: ScriptableModule>(&mut self, module: M) -> JitResult<CompiledModule> {
31        // Convert module to computation graph
32        let graph = module.to_graph()?;
33
34        // Apply type annotations if available
35        let annotated_graph = self.apply_type_annotations(graph)?;
36
37        // Compile the graph
38        self.jit_compiler.compile(annotated_graph)
39    }
40
41    /// Add type annotation for a parameter or variable
42    pub fn add_type_annotation(&mut self, name: String, annotation: TypeAnnotation) {
43        self.type_annotations.insert(name, annotation);
44    }
45
46    /// Apply type annotations to the graph
47    fn apply_type_annotations(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
48        // Apply annotations to nodes by name
49        let node_ids: Vec<_> = graph.nodes().map(|(id, _)| id).collect();
50        for node_id in node_ids {
51            if let Some(node) = graph.node(node_id) {
52                let node_name = node.name.clone();
53                if let Some(annotation) = self.type_annotations.get(&node_name) {
54                    if let Some(node_mut) = graph.node_mut(node_id) {
55                        match annotation {
56                            TypeAnnotation::Tensor { dtype, shape } => {
57                                node_mut.dtype = *dtype;
58                                node_mut.output_shape = Shape::new(shape.clone());
59                            }
60                            TypeAnnotation::Scalar(dtype) => {
61                                node_mut.dtype = *dtype;
62                                node_mut.output_shape = Shape::new(vec![1]);
63                            }
64                            TypeAnnotation::List { element_type, size } => {
65                                // Handle list types via attributes
66                                node_mut.attrs.insert(
67                                    "list_element_type".to_string(),
68                                    crate::graph::Attribute::String(format!("{:?}", element_type)),
69                                );
70                                node_mut.attrs.insert(
71                                    "list_size".to_string(),
72                                    crate::graph::Attribute::Int(*size as i64),
73                                );
74                            }
75                        }
76                    }
77                }
78            }
79        }
80
81        Ok(graph)
82    }
83}
84
85/// Type annotation for script mode
86#[derive(Debug, Clone)]
87pub enum TypeAnnotation {
88    /// Tensor with specific dtype and shape
89    Tensor { dtype: DType, shape: Vec<usize> },
90    /// Scalar value
91    Scalar(DType),
92    /// List of elements
93    List {
94        element_type: Box<TypeAnnotation>,
95        size: usize,
96    },
97}
98
99/// Script AST node for parsing script code
100#[derive(Debug, Clone)]
101pub enum ScriptAst {
102    /// Function definition
103    Function {
104        name: String,
105        params: Vec<Parameter>,
106        return_type: Option<TypeAnnotation>,
107        body: Box<ScriptAst>,
108    },
109    /// Variable declaration
110    Let {
111        name: String,
112        type_ann: Option<TypeAnnotation>,
113        value: Box<ScriptAst>,
114    },
115    /// Binary operation
116    BinOp {
117        op: BinaryOp,
118        left: Box<ScriptAst>,
119        right: Box<ScriptAst>,
120    },
121    /// Unary operation
122    UnaryOp {
123        op: UnaryOp,
124        operand: Box<ScriptAst>,
125    },
126    /// Function call
127    Call { func: String, args: Vec<ScriptAst> },
128    /// Conditional
129    If {
130        condition: Box<ScriptAst>,
131        then_branch: Box<ScriptAst>,
132        else_branch: Option<Box<ScriptAst>>,
133    },
134    /// Loop
135    For {
136        var: String,
137        iter: Box<ScriptAst>,
138        body: Box<ScriptAst>,
139    },
140    /// Block of statements
141    Block(Vec<ScriptAst>),
142    /// Variable reference
143    Var(String),
144    /// Literal value
145    Literal(LiteralValue),
146    /// Return statement
147    Return(Box<ScriptAst>),
148}
149
150/// Function parameter
151#[derive(Debug, Clone)]
152pub struct Parameter {
153    pub name: String,
154    pub type_ann: TypeAnnotation,
155}
156
/// Binary operators of the script language.
///
/// Fieldless, so it is `Copy`. It is also comparable and hashable, which
/// lets operators be matched, compared, or used as map keys cheaply.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `**`
    Pow,
    /// `==`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `and`
    And,
    /// `or`
    Or,
}
174
/// Prefix unary operators of the script language.
///
/// Fieldless, so it is `Copy`, comparable and hashable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    /// Arithmetic negation, `-x`.
    Neg,
    /// Logical negation, `not x`.
    Not,
}
181
/// Literal constants that can appear in script source.
///
/// `PartialEq` is derived so parsed literals can be compared. `Eq` is not,
/// because of the `f64` payload.
#[derive(Debug, Clone, PartialEq)]
pub enum LiteralValue {
    /// Integer literal, e.g. `42`.
    Int(i64),
    /// Floating-point literal, e.g. `1.5`.
    Float(f64),
    /// `True` / `False`.
    Bool(bool),
    /// String literal contents (without quotes).
    String(String),
}
190
191/// Script parser
192pub struct ScriptParser;
193
194impl ScriptParser {
195    /// Parse script code into AST
196    pub fn parse(code: &str) -> JitResult<ScriptAst> {
197        let mut parser = PythonParser::new(code);
198        parser.parse()
199    }
200}
201
202/// Python subset parser for JIT compilation
203pub struct PythonParser {
204    tokens: Vec<Token>,
205    current: usize,
206}
207
/// Lexical tokens for the script-mode Python subset.
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // ---- Literals ----
    /// Integer literal.
    Integer(i64),
    /// Floating-point literal.
    Float(f64),
    /// `True` / `False` (this is what the tokenizer emits for those keywords).
    Boolean(bool),
    /// String literal contents.
    String(String),

    // ---- Names ----
    /// Any non-keyword identifier.
    Identifier(String),

    // ---- Keywords ----
    /// `def`
    Def,
    /// `if`
    If,
    /// `else`
    Else,
    /// `for`
    For,
    /// `in`
    In,
    /// `return`
    Return,
    /// Reserved; the tokenizer currently emits `Boolean(true)` instead.
    True,
    /// Reserved; the tokenizer currently emits `Boolean(false)` instead.
    False,

    // ---- Operators ----
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `**`
    DoubleStar,
    /// `=`
    Equal,
    /// `==`
    EqualEqual,
    /// `!=`
    NotEqual,
    /// `<`
    Less,
    /// `<=`
    LessEqual,
    /// `>`
    Greater,
    /// `>=`
    GreaterEqual,
    /// `and`
    And,
    /// `or`
    Or,
    /// `not`
    Not,

    // ---- Punctuation ----
    /// `(`
    LeftParen,
    /// `)`
    RightParen,
    /// `[`
    LeftBracket,
    /// `]`
    RightBracket,
    /// `,`
    Comma,
    /// `:`
    Colon,
    /// `->`
    Arrow,

    // ---- Layout / control ----
    /// Line break.
    Newline,
    /// Reserved for indentation tracking; not emitted by the current tokenizer.
    Indent,
    /// Reserved for indentation tracking; not emitted by the current tokenizer.
    Dedent,
    /// End of input; always the final token.
    Eof,
}
262
263impl PythonParser {
264    /// Create a new Python parser
265    pub fn new(code: &str) -> Self {
266        let tokens = Self::tokenize(code);
267        Self { tokens, current: 0 }
268    }
269
270    /// Simple tokenizer for Python subset
271    fn tokenize(code: &str) -> Vec<Token> {
272        let mut tokens = Vec::new();
273        let mut chars = code.chars().peekable();
274
275        while let Some(&ch) = chars.peek() {
276            match ch {
277                ' ' | '\t' => {
278                    chars.next();
279                }
280                '\n' => {
281                    chars.next();
282                    tokens.push(Token::Newline);
283                }
284                '(' => {
285                    chars.next();
286                    tokens.push(Token::LeftParen);
287                }
288                ')' => {
289                    chars.next();
290                    tokens.push(Token::RightParen);
291                }
292                '[' => {
293                    chars.next();
294                    tokens.push(Token::LeftBracket);
295                }
296                ']' => {
297                    chars.next();
298                    tokens.push(Token::RightBracket);
299                }
300                ',' => {
301                    chars.next();
302                    tokens.push(Token::Comma);
303                }
304                ':' => {
305                    chars.next();
306                    tokens.push(Token::Colon);
307                }
308                '+' => {
309                    chars.next();
310                    tokens.push(Token::Plus);
311                }
312                '-' => {
313                    chars.next();
314                    if chars.peek() == Some(&'>') {
315                        chars.next();
316                        tokens.push(Token::Arrow);
317                    } else {
318                        tokens.push(Token::Minus);
319                    }
320                }
321                '*' => {
322                    chars.next();
323                    if chars.peek() == Some(&'*') {
324                        chars.next();
325                        tokens.push(Token::DoubleStar);
326                    } else {
327                        tokens.push(Token::Star);
328                    }
329                }
330                '/' => {
331                    chars.next();
332                    tokens.push(Token::Slash);
333                }
334                '=' => {
335                    chars.next();
336                    if chars.peek() == Some(&'=') {
337                        chars.next();
338                        tokens.push(Token::EqualEqual);
339                    } else {
340                        tokens.push(Token::Equal);
341                    }
342                }
343                '!' => {
344                    chars.next();
345                    if chars.peek() == Some(&'=') {
346                        chars.next();
347                        tokens.push(Token::NotEqual);
348                    }
349                }
350                '<' => {
351                    chars.next();
352                    if chars.peek() == Some(&'=') {
353                        chars.next();
354                        tokens.push(Token::LessEqual);
355                    } else {
356                        tokens.push(Token::Less);
357                    }
358                }
359                '>' => {
360                    chars.next();
361                    if chars.peek() == Some(&'=') {
362                        chars.next();
363                        tokens.push(Token::GreaterEqual);
364                    } else {
365                        tokens.push(Token::Greater);
366                    }
367                }
368                '"' => {
369                    chars.next();
370                    let mut string_val = String::new();
371                    while let Some(&ch) = chars.peek() {
372                        if ch == '"' {
373                            chars.next();
374                            break;
375                        }
376                        string_val.push(ch);
377                        chars.next();
378                    }
379                    tokens.push(Token::String(string_val));
380                }
381                c if c.is_ascii_digit() => {
382                    let mut number = String::new();
383                    let mut is_float = false;
384                    while let Some(&ch) = chars.peek() {
385                        if ch.is_ascii_digit() {
386                            number.push(ch);
387                            chars.next();
388                        } else if ch == '.' && !is_float {
389                            is_float = true;
390                            number.push(ch);
391                            chars.next();
392                        } else {
393                            break;
394                        }
395                    }
396
397                    if is_float {
398                        if let Ok(val) = number.parse::<f64>() {
399                            tokens.push(Token::Float(val));
400                        }
401                    } else if let Ok(val) = number.parse::<i64>() {
402                        tokens.push(Token::Integer(val));
403                    }
404                }
405                c if c.is_ascii_alphabetic() || c == '_' => {
406                    let mut ident = String::new();
407                    while let Some(&ch) = chars.peek() {
408                        if ch.is_ascii_alphanumeric() || ch == '_' {
409                            ident.push(ch);
410                            chars.next();
411                        } else {
412                            break;
413                        }
414                    }
415
416                    let token = match ident.as_str() {
417                        "def" => Token::Def,
418                        "if" => Token::If,
419                        "else" => Token::Else,
420                        "for" => Token::For,
421                        "in" => Token::In,
422                        "return" => Token::Return,
423                        "True" => Token::Boolean(true),
424                        "False" => Token::Boolean(false),
425                        "and" => Token::And,
426                        "or" => Token::Or,
427                        "not" => Token::Not,
428                        _ => Token::Identifier(ident),
429                    };
430                    tokens.push(token);
431                }
432                _ => {
433                    chars.next(); // Skip unknown characters
434                }
435            }
436        }
437
438        tokens.push(Token::Eof);
439        tokens
440    }
441
442    /// Parse tokens into AST
443    pub fn parse(&mut self) -> JitResult<ScriptAst> {
444        let mut statements = Vec::new();
445
446        while !self.is_at_end() {
447            if self.match_token(&Token::Newline) {
448                continue;
449            }
450            statements.push(self.parse_statement()?);
451        }
452
453        Ok(ScriptAst::Block(statements))
454    }
455
456    /// Parse a statement
457    fn parse_statement(&mut self) -> JitResult<ScriptAst> {
458        if self.match_token(&Token::Def) {
459            self.parse_function()
460        } else if self.match_token(&Token::Return) {
461            let expr = self.parse_expression()?;
462            Ok(ScriptAst::Return(Box::new(expr)))
463        } else if self.match_token(&Token::If) {
464            self.parse_if()
465        } else if self.match_token(&Token::For) {
466            self.parse_for()
467        } else {
468            // Assignment or expression statement
469            let expr = self.parse_expression()?;
470            if self.match_token(&Token::Equal) {
471                if let ScriptAst::Var(name) = expr {
472                    let value = self.parse_expression()?;
473                    Ok(ScriptAst::Let {
474                        name,
475                        type_ann: None,
476                        value: Box::new(value),
477                    })
478                } else {
479                    Err(JitError::CompilationError(
480                        "Invalid assignment target".to_string(),
481                    ))
482                }
483            } else {
484                Ok(expr)
485            }
486        }
487    }
488
489    /// Parse function definition
490    fn parse_function(&mut self) -> JitResult<ScriptAst> {
491        let name = if let Some(Token::Identifier(name)) = self.advance() {
492            name.clone()
493        } else {
494            return Err(JitError::CompilationError(
495                "Expected function name".to_string(),
496            ));
497        };
498
499        self.consume(&Token::LeftParen, "Expected '(' after function name")?;
500
501        let mut params = Vec::new();
502        while !self.check(&Token::RightParen) && !self.is_at_end() {
503            if let Some(Token::Identifier(param_name)) = self.advance() {
504                // For now, default all parameters to float tensors
505                params.push(Parameter {
506                    name: param_name.clone(),
507                    type_ann: TypeAnnotation::Tensor {
508                        dtype: DType::F32,
509                        shape: vec![], // Will be inferred
510                    },
511                });
512
513                if !self.check(&Token::RightParen) {
514                    self.consume(&Token::Comma, "Expected ',' between parameters")?;
515                }
516            }
517        }
518
519        self.consume(&Token::RightParen, "Expected ')' after parameters")?;
520        self.consume(&Token::Colon, "Expected ':' after function signature")?;
521
522        let body = self.parse_block()?;
523
524        Ok(ScriptAst::Function {
525            name,
526            params,
527            return_type: None,
528            body: Box::new(body),
529        })
530    }
531
532    /// Parse if statement
533    fn parse_if(&mut self) -> JitResult<ScriptAst> {
534        let condition = self.parse_expression()?;
535        self.consume(&Token::Colon, "Expected ':' after if condition")?;
536
537        let then_branch = self.parse_block()?;
538
539        let else_branch = if self.match_token(&Token::Else) {
540            self.consume(&Token::Colon, "Expected ':' after else")?;
541            Some(Box::new(self.parse_block()?))
542        } else {
543            None
544        };
545
546        Ok(ScriptAst::If {
547            condition: Box::new(condition),
548            then_branch: Box::new(then_branch),
549            else_branch,
550        })
551    }
552
553    /// Parse for loop
554    fn parse_for(&mut self) -> JitResult<ScriptAst> {
555        let var = if let Some(Token::Identifier(name)) = self.advance() {
556            name.clone()
557        } else {
558            return Err(JitError::CompilationError(
559                "Expected variable name in for loop".to_string(),
560            ));
561        };
562
563        self.consume(&Token::In, "Expected 'in' in for loop")?;
564        let iter = self.parse_expression()?;
565        self.consume(&Token::Colon, "Expected ':' after for loop header")?;
566
567        let body = self.parse_block()?;
568
569        Ok(ScriptAst::For {
570            var,
571            iter: Box::new(iter),
572            body: Box::new(body),
573        })
574    }
575
576    /// Parse a block of statements
577    fn parse_block(&mut self) -> JitResult<ScriptAst> {
578        let mut statements = Vec::new();
579
580        // Simple block parsing - in a real implementation this would handle indentation
581        while !self.is_at_end() && !self.check(&Token::Else) && !self.check(&Token::Def) {
582            if self.match_token(&Token::Newline) {
583                continue;
584            }
585            statements.push(self.parse_statement()?);
586            break; // For simplicity, just parse one statement per block
587        }
588
589        Ok(ScriptAst::Block(statements))
590    }
591
592    /// Parse expression
593    fn parse_expression(&mut self) -> JitResult<ScriptAst> {
594        self.parse_or()
595    }
596
597    /// Parse logical OR
598    fn parse_or(&mut self) -> JitResult<ScriptAst> {
599        let mut expr = self.parse_and()?;
600
601        while self.match_token(&Token::Or) {
602            let right = self.parse_and()?;
603            expr = ScriptAst::BinOp {
604                op: BinaryOp::Or,
605                left: Box::new(expr),
606                right: Box::new(right),
607            };
608        }
609
610        Ok(expr)
611    }
612
613    /// Parse logical AND
614    fn parse_and(&mut self) -> JitResult<ScriptAst> {
615        let mut expr = self.parse_equality()?;
616
617        while self.match_token(&Token::And) {
618            let right = self.parse_equality()?;
619            expr = ScriptAst::BinOp {
620                op: BinaryOp::And,
621                left: Box::new(expr),
622                right: Box::new(right),
623            };
624        }
625
626        Ok(expr)
627    }
628
629    /// Parse equality operations
630    fn parse_equality(&mut self) -> JitResult<ScriptAst> {
631        let mut expr = self.parse_comparison()?;
632
633        while let Some(op) = self.match_equality_op() {
634            let right = self.parse_comparison()?;
635            expr = ScriptAst::BinOp {
636                op,
637                left: Box::new(expr),
638                right: Box::new(right),
639            };
640        }
641
642        Ok(expr)
643    }
644
645    /// Parse comparison operations
646    fn parse_comparison(&mut self) -> JitResult<ScriptAst> {
647        let mut expr = self.parse_term()?;
648
649        while let Some(op) = self.match_comparison_op() {
650            let right = self.parse_term()?;
651            expr = ScriptAst::BinOp {
652                op,
653                left: Box::new(expr),
654                right: Box::new(right),
655            };
656        }
657
658        Ok(expr)
659    }
660
661    /// Parse addition and subtraction
662    fn parse_term(&mut self) -> JitResult<ScriptAst> {
663        let mut expr = self.parse_factor()?;
664
665        while self.check(&Token::Plus) || self.check(&Token::Minus) {
666            let op = if self.match_token(&Token::Plus) {
667                BinaryOp::Add
668            } else {
669                self.advance();
670                BinaryOp::Sub
671            };
672
673            let right = self.parse_factor()?;
674            expr = ScriptAst::BinOp {
675                op,
676                left: Box::new(expr),
677                right: Box::new(right),
678            };
679        }
680
681        Ok(expr)
682    }
683
684    /// Parse multiplication, division, and power
685    fn parse_factor(&mut self) -> JitResult<ScriptAst> {
686        let mut expr = self.parse_unary()?;
687
688        while self.check(&Token::Star)
689            || self.check(&Token::Slash)
690            || self.check(&Token::DoubleStar)
691        {
692            let op = if self.match_token(&Token::Star) {
693                BinaryOp::Mul
694            } else if self.match_token(&Token::Slash) {
695                BinaryOp::Div
696            } else {
697                self.advance();
698                BinaryOp::Pow
699            };
700
701            let right = self.parse_unary()?;
702            expr = ScriptAst::BinOp {
703                op,
704                left: Box::new(expr),
705                right: Box::new(right),
706            };
707        }
708
709        Ok(expr)
710    }
711
712    /// Parse unary operations
713    fn parse_unary(&mut self) -> JitResult<ScriptAst> {
714        if self.match_token(&Token::Not) {
715            let operand = self.parse_unary()?;
716            Ok(ScriptAst::UnaryOp {
717                op: UnaryOp::Not,
718                operand: Box::new(operand),
719            })
720        } else if self.match_token(&Token::Minus) {
721            let operand = self.parse_unary()?;
722            Ok(ScriptAst::UnaryOp {
723                op: UnaryOp::Neg,
724                operand: Box::new(operand),
725            })
726        } else {
727            self.parse_call()
728        }
729    }
730
731    /// Parse function calls
732    fn parse_call(&mut self) -> JitResult<ScriptAst> {
733        let mut expr = self.parse_primary()?;
734
735        while self.match_token(&Token::LeftParen) {
736            let mut args = Vec::new();
737            while !self.check(&Token::RightParen) && !self.is_at_end() {
738                args.push(self.parse_expression()?);
739                if !self.check(&Token::RightParen) {
740                    self.consume(&Token::Comma, "Expected ',' between arguments")?;
741                }
742            }
743            self.consume(&Token::RightParen, "Expected ')' after arguments")?;
744
745            if let ScriptAst::Var(func_name) = expr {
746                expr = ScriptAst::Call {
747                    func: func_name,
748                    args,
749                };
750            }
751        }
752
753        Ok(expr)
754    }
755
756    /// Parse primary expressions
757    fn parse_primary(&mut self) -> JitResult<ScriptAst> {
758        if let Some(token) = self.advance() {
759            match token {
760                Token::Integer(val) => Ok(ScriptAst::Literal(LiteralValue::Int(*val))),
761                Token::Float(val) => Ok(ScriptAst::Literal(LiteralValue::Float(*val))),
762                Token::Boolean(val) => Ok(ScriptAst::Literal(LiteralValue::Bool(*val))),
763                Token::String(val) => Ok(ScriptAst::Literal(LiteralValue::String(val.clone()))),
764                Token::Identifier(name) => Ok(ScriptAst::Var(name.clone())),
765                Token::LeftParen => {
766                    let expr = self.parse_expression()?;
767                    self.consume(&Token::RightParen, "Expected ')' after expression")?;
768                    Ok(expr)
769                }
770                _ => Err(JitError::CompilationError(
771                    "Unexpected token in expression".to_string(),
772                )),
773            }
774        } else {
775            Err(JitError::CompilationError(
776                "Unexpected end of input".to_string(),
777            ))
778        }
779    }
780
781    /// Match equality operators
782    fn match_equality_op(&mut self) -> Option<BinaryOp> {
783        if self.match_token(&Token::EqualEqual) {
784            Some(BinaryOp::Eq)
785        } else if self.match_token(&Token::NotEqual) {
786            Some(BinaryOp::Ne)
787        } else {
788            None
789        }
790    }
791
792    /// Match comparison operators
793    fn match_comparison_op(&mut self) -> Option<BinaryOp> {
794        if self.match_token(&Token::Greater) {
795            Some(BinaryOp::Gt)
796        } else if self.match_token(&Token::GreaterEqual) {
797            Some(BinaryOp::Ge)
798        } else if self.match_token(&Token::Less) {
799            Some(BinaryOp::Lt)
800        } else if self.match_token(&Token::LessEqual) {
801            Some(BinaryOp::Le)
802        } else {
803            None
804        }
805    }
806
807    /// Helper methods for parsing
808    fn match_token(&mut self, expected: &Token) -> bool {
809        if self.check(expected) {
810            self.advance();
811            true
812        } else {
813            false
814        }
815    }
816
817    fn check(&self, expected: &Token) -> bool {
818        if self.is_at_end() {
819            false
820        } else {
821            std::mem::discriminant(&self.tokens[self.current]) == std::mem::discriminant(expected)
822        }
823    }
824
825    fn advance(&mut self) -> Option<&Token> {
826        if !self.is_at_end() {
827            self.current += 1;
828        }
829        self.previous()
830    }
831
832    fn is_at_end(&self) -> bool {
833        self.current >= self.tokens.len()
834            || matches!(self.tokens.get(self.current), Some(Token::Eof))
835    }
836
837    fn previous(&self) -> Option<&Token> {
838        self.tokens.get(self.current.saturating_sub(1))
839    }
840
841    fn consume(&mut self, expected: &Token, message: &str) -> JitResult<()> {
842        if self.check(expected) {
843            self.advance();
844            Ok(())
845        } else {
846            Err(JitError::CompilationError(message.to_string()))
847        }
848    }
849}
850
851/// Convert script AST to computation graph
852pub struct AstToGraphConverter {
853    graph: ComputationGraph,
854    var_map: HashMap<String, NodeId>,
855    next_id: usize,
856}
857
858impl Default for AstToGraphConverter {
859    fn default() -> Self {
860        Self::new()
861    }
862}
863
864impl AstToGraphConverter {
865    /// Create a new converter
866    pub fn new() -> Self {
867        Self {
868            graph: ComputationGraph::new(),
869            var_map: HashMap::new(),
870            next_id: 0,
871        }
872    }
873
874    /// Convert AST to computation graph
875    pub fn convert(&mut self, ast: ScriptAst) -> JitResult<ComputationGraph> {
876        self.convert_ast(ast)?;
877        Ok(self.graph.clone())
878    }
879
880    /// Convert an AST node
881    fn convert_ast(&mut self, ast: ScriptAst) -> JitResult<NodeId> {
882        match ast {
883            ScriptAst::BinOp { op, left, right } => {
884                let left_id = self.convert_ast(*left)?;
885                let right_id = self.convert_ast(*right)?;
886                self.create_binop_node(op, left_id, right_id)
887            }
888            ScriptAst::UnaryOp { op, operand } => {
889                let operand_id = self.convert_ast(*operand)?;
890                self.create_unaryop_node(op, operand_id)
891            }
892            ScriptAst::Call { func, args } => {
893                let arg_ids: Vec<_> = args
894                    .into_iter()
895                    .map(|arg| self.convert_ast(arg))
896                    .collect::<JitResult<Vec<_>>>()?;
897                self.create_call_node(func, arg_ids)
898            }
899            ScriptAst::Var(name) => self
900                .var_map
901                .get(&name)
902                .copied()
903                .ok_or_else(|| JitError::GraphError(format!("Undefined variable: {}", name))),
904            ScriptAst::Literal(lit) => self.create_literal_node(lit),
905            ScriptAst::Let { name, value, .. } => {
906                let value_id = self.convert_ast(*value)?;
907                self.var_map.insert(name, value_id);
908                Ok(value_id)
909            }
910            ScriptAst::Block(stmts) => {
911                let mut last_id = None;
912                for stmt in stmts {
913                    last_id = Some(self.convert_ast(stmt)?);
914                }
915                last_id.ok_or_else(|| JitError::GraphError("Empty block".to_string()))
916            }
917            _ => Err(JitError::GraphError("Unsupported AST node".to_string())),
918        }
919    }
920
921    /// Create a binary operation node
922    fn create_binop_node(
923        &mut self,
924        op: BinaryOp,
925        left: NodeId,
926        right: NodeId,
927    ) -> JitResult<NodeId> {
928        use crate::graph::{Edge, Operation};
929        use torsh_core::DeviceType;
930
931        let operation = match op {
932            BinaryOp::Add => Operation::Add,
933            BinaryOp::Sub => Operation::Sub,
934            BinaryOp::Mul => Operation::Mul,
935            BinaryOp::Div => Operation::Div,
936            _ => return Err(JitError::UnsupportedOp(format!("{:?}", op))),
937        };
938
939        let mut node = Node::new(operation, format!("binop_{}", self.next_id));
940        node.device = DeviceType::Cpu;
941        node.inputs = vec![];
942        node.is_output = false;
943
944        let node_id = self.graph.add_node(node);
945        self.graph.add_edge(left, node_id, Edge::default());
946        self.graph.add_edge(right, node_id, Edge::default());
947        self.next_id += 1;
948        Ok(node_id)
949    }
950
951    /// Create a unary operation node
952    fn create_unaryop_node(&mut self, op: UnaryOp, operand: NodeId) -> JitResult<NodeId> {
953        use crate::graph::{Edge, Operation};
954        use torsh_core::DeviceType;
955
956        let operation = match op {
957            UnaryOp::Neg => Operation::Neg,
958            _ => return Err(JitError::UnsupportedOp(format!("{:?}", op))),
959        };
960
961        let mut node = Node::new(operation, format!("unaryop_{}", self.next_id));
962        node.device = DeviceType::Cpu;
963        node.inputs = vec![];
964        node.is_output = false;
965
966        let node_id = self.graph.add_node(node);
967        self.graph.add_edge(operand, node_id, Edge::default());
968        self.next_id += 1;
969        Ok(node_id)
970    }
971
972    /// Create a function call node
973    fn create_call_node(&mut self, func: String, args: Vec<NodeId>) -> JitResult<NodeId> {
974        use crate::graph::{Edge, Operation};
975        use torsh_core::DeviceType;
976
977        let operation = match func.as_str() {
978            "relu" => Operation::Relu,
979            "sigmoid" => Operation::Sigmoid,
980            "tanh" => Operation::Tanh,
981            "matmul" => Operation::MatMul,
982            _ => Operation::Custom(func),
983        };
984
985        let mut node = Node::new(operation, format!("call_{}", self.next_id));
986        node.device = DeviceType::Cpu;
987        node.inputs = vec![];
988        node.is_output = false;
989
990        let node_id = self.graph.add_node(node);
991        for (i, arg_id) in args.iter().enumerate() {
992            let edge = Edge {
993                src_output: 0,
994                dst_input: i,
995            };
996            self.graph.add_edge(*arg_id, node_id, edge);
997        }
998        self.next_id += 1;
999        Ok(node_id)
1000    }
1001
1002    /// Create a literal node
1003    fn create_literal_node(&mut self, lit: LiteralValue) -> JitResult<NodeId> {
1004        use crate::graph::{Attribute, ConstantInfo, ConstantValue, Operation};
1005        use torsh_core::DeviceType;
1006
1007        let (dtype, constant_value) = match lit {
1008            LiteralValue::Int(v) => (DType::I64, ConstantValue::IntScalar(v)),
1009            LiteralValue::Float(v) => (DType::F32, ConstantValue::Scalar(v)),
1010            LiteralValue::Bool(v) => (DType::Bool, ConstantValue::IntScalar(if v { 1 } else { 0 })),
1011            LiteralValue::String(v) => {
1012                // String literals need special handling
1013                let mut node = Node::new(
1014                    Operation::Custom("string_literal".to_string()),
1015                    format!("string_literal_{}", self.next_id),
1016                );
1017                node.device = DeviceType::Cpu;
1018                node.attrs.insert("value".to_string(), Attribute::String(v));
1019                node.inputs = vec![];
1020                node.is_output = false;
1021                let node_id = self.graph.add_node(node);
1022                self.next_id += 1;
1023                return Ok(node_id);
1024            }
1025        };
1026
1027        let mut node = Node::new(
1028            Operation::Constant(ConstantInfo {
1029                value: constant_value,
1030            }),
1031            format!("constant_{}", self.next_id),
1032        );
1033        node.device = DeviceType::Cpu;
1034        node.inputs = vec![];
1035        node.is_output = false;
1036
1037        let node_id = self.graph.add_node(node);
1038        self.next_id += 1;
1039        Ok(node_id)
1040    }
1041}
1042
1043/// Export a compiled module to TorchScript format
1044pub fn export_torchscript(module: &CompiledModule, path: &str) -> JitResult<()> {
1045    use std::fs::File;
1046    use std::io::Write;
1047
1048    // Create TorchScript representation
1049    let ts_repr = TorchScriptModule {
1050        version: 1,
1051        graph: module.graph.clone(),
1052        constants: extract_constants_from_graph(&module.graph),
1053        metadata: create_metadata_from_module(module),
1054    };
1055
1056    // Convert to TorchScript IR format
1057    let torchscript_ir = generate_torchscript_ir(&ts_repr)?;
1058
1059    // Write to file
1060    let mut file = File::create(path)
1061        .map_err(|e| JitError::RuntimeError(format!("Failed to create file {}: {}", path, e)))?;
1062
1063    file.write_all(torchscript_ir.as_bytes())
1064        .map_err(|e| JitError::RuntimeError(format!("Failed to write file {}: {}", path, e)))?;
1065
1066    Ok(())
1067}
1068
1069/// Import a module from TorchScript format
1070pub fn import_torchscript(path: &str, config: JitConfig) -> JitResult<CompiledModule> {
1071    use std::fs::File;
1072    use std::io::Read;
1073
1074    // Read file
1075    let mut file = File::open(path)
1076        .map_err(|e| JitError::RuntimeError(format!("Failed to open file {}: {}", path, e)))?;
1077
1078    let mut contents = String::new();
1079    file.read_to_string(&mut contents)
1080        .map_err(|e| JitError::RuntimeError(format!("Failed to read file {}: {}", path, e)))?;
1081
1082    // Parse TorchScript IR
1083    let ts_module = parse_torchscript_ir(&contents)?;
1084
1085    // Convert to our internal representation
1086    let mut jit_compiler = JitCompiler::new(config);
1087    let compiled_module = jit_compiler.compile(ts_module.graph)?;
1088
1089    Ok(compiled_module)
1090}
1091
/// TorchScript module representation for serialization
///
/// In-memory bundle produced by `export_torchscript` (then rendered with
/// `generate_torchscript_ir`) and by `parse_torchscript_ir` (then compiled by
/// `import_torchscript`).
#[derive(Debug, Clone)]
struct TorchScriptModule {
    /// Format version; both producers in this file set it to 1.
    version: u32,
    /// The computation graph being exported, or the one reconstructed from IR.
    graph: ComputationGraph,
    /// Constant payloads keyed by constant name, flattened to `f32` values.
    constants: HashMap<String, Vec<f32>>,
    /// Free-form string metadata (e.g. `producer`, version, node/edge counts).
    metadata: HashMap<String, String>,
}
1100
1101/// Extract constants from computation graph
1102fn extract_constants_from_graph(graph: &ComputationGraph) -> HashMap<String, Vec<f32>> {
1103    use crate::graph::{ConstantValue, Operation};
1104
1105    let mut constants = HashMap::new();
1106
1107    for (node_id, node) in graph.nodes() {
1108        if let Operation::Constant(ref const_info) = node.op {
1109            let const_name = format!("const_{:?}", node_id);
1110            match &const_info.value {
1111                ConstantValue::Scalar(val) => {
1112                    constants.insert(const_name, vec![*val as f32]);
1113                }
1114                ConstantValue::IntScalar(val) => {
1115                    constants.insert(const_name, vec![*val as f32]);
1116                }
1117                ConstantValue::Tensor {
1118                    shape: _,
1119                    data,
1120                    dtype: _,
1121                } => {
1122                    constants.insert(const_name, data.iter().map(|&x| x as f32).collect());
1123                }
1124                ConstantValue::Bool(val) => {
1125                    constants.insert(const_name, vec![if *val { 1.0 } else { 0.0 }]);
1126                }
1127                ConstantValue::Int(val) => {
1128                    constants.insert(const_name, vec![*val as f32]);
1129                }
1130                ConstantValue::UInt(val) => {
1131                    constants.insert(const_name, vec![*val as f32]);
1132                }
1133                ConstantValue::Float(val) => {
1134                    constants.insert(const_name, vec![*val as f32]);
1135                }
1136                ConstantValue::String(_) => {
1137                    constants.insert(const_name, vec![0.0]); // String as placeholder
1138                }
1139                ConstantValue::FloatArray(arr) => {
1140                    constants.insert(const_name, arr.clone());
1141                }
1142                ConstantValue::IntArray(arr) => {
1143                    constants.insert(const_name, arr.iter().map(|&x| x as f32).collect());
1144                }
1145                ConstantValue::Array(arr) => {
1146                    // Convert array of values to f32 - simplified
1147                    constants.insert(const_name, vec![arr.len() as f32]);
1148                }
1149                ConstantValue::Complex { real, imag: _ } => {
1150                    constants.insert(const_name, vec![*real as f32]);
1151                }
1152                ConstantValue::None => {
1153                    constants.insert(const_name, vec![0.0]);
1154                }
1155                ConstantValue::Undefined => {
1156                    constants.insert(const_name, vec![0.0]);
1157                }
1158            }
1159        }
1160    }
1161
1162    constants
1163}
1164
1165/// Create metadata from compiled module
1166fn create_metadata_from_module(module: &CompiledModule) -> HashMap<String, String> {
1167    let mut metadata = HashMap::new();
1168
1169    metadata.insert("producer".to_string(), "torsh-jit".to_string());
1170    metadata.insert("producer_version".to_string(), "0.1.0".to_string());
1171    metadata.insert("graph_name".to_string(), "main".to_string());
1172    metadata.insert(
1173        "node_count".to_string(),
1174        module.graph.node_count().to_string(),
1175    );
1176    metadata.insert(
1177        "edge_count".to_string(),
1178        module.graph.edge_count().to_string(),
1179    );
1180
1181    metadata
1182}
1183
1184/// Generate TorchScript IR from TorchScript module
1185fn generate_torchscript_ir(ts_module: &TorchScriptModule) -> JitResult<String> {
1186    use crate::graph::{ConstantValue, Operation};
1187
1188    let mut ir = String::new();
1189
1190    // Header
1191    ir.push_str(&format!("graph():\n"));
1192
1193    // Constants section
1194    for (name, values) in &ts_module.constants {
1195        ir.push_str(&format!(
1196            "  %{} : Float({}) = prim::Constant[value={}]()\n",
1197            name,
1198            values.len(),
1199            format_tensor_values(values)
1200        ));
1201    }
1202
1203    // Nodes section
1204    let mut output_counter = 0;
1205    for (node_id, node) in ts_module.graph.nodes() {
1206        match &node.op {
1207            Operation::Add => {
1208                let inputs = get_node_inputs(&ts_module.graph, node_id);
1209                ir.push_str(&format!(
1210                    "  %{} : Float = aten::add({}, {})\n",
1211                    output_counter, inputs[0], inputs[1]
1212                ));
1213            }
1214            Operation::Mul => {
1215                let inputs = get_node_inputs(&ts_module.graph, node_id);
1216                ir.push_str(&format!(
1217                    "  %{} : Float = aten::mul({}, {})\n",
1218                    output_counter, inputs[0], inputs[1]
1219                ));
1220            }
1221            Operation::MatMul => {
1222                let inputs = get_node_inputs(&ts_module.graph, node_id);
1223                ir.push_str(&format!(
1224                    "  %{} : Float = aten::mm({}, {})\n",
1225                    output_counter, inputs[0], inputs[1]
1226                ));
1227            }
1228            Operation::Relu => {
1229                let inputs = get_node_inputs(&ts_module.graph, node_id);
1230                ir.push_str(&format!(
1231                    "  %{} : Float = aten::relu({})\n",
1232                    output_counter, inputs[0]
1233                ));
1234            }
1235            Operation::Sigmoid => {
1236                let inputs = get_node_inputs(&ts_module.graph, node_id);
1237                ir.push_str(&format!(
1238                    "  %{} : Float = aten::sigmoid({})\n",
1239                    output_counter, inputs[0]
1240                ));
1241            }
1242            Operation::Constant(const_info) => match &const_info.value {
1243                ConstantValue::Scalar(val) => {
1244                    ir.push_str(&format!(
1245                        "  %{} : Float = prim::Constant[value={}]()\n",
1246                        output_counter, val
1247                    ));
1248                }
1249                ConstantValue::IntScalar(val) => {
1250                    ir.push_str(&format!(
1251                        "  %{} : int = prim::Constant[value={}]()\n",
1252                        output_counter, val
1253                    ));
1254                }
1255                ConstantValue::Tensor {
1256                    shape: _,
1257                    data,
1258                    dtype: _,
1259                } => {
1260                    let data_f32: Vec<f32> = data.iter().map(|&x| x as f32).collect();
1261                    ir.push_str(&format!(
1262                        "  %{} : Float = prim::Constant[value={}]()\n",
1263                        output_counter,
1264                        format_tensor_values(&data_f32)
1265                    ));
1266                }
1267                ConstantValue::Bool(val) => {
1268                    ir.push_str(&format!(
1269                        "  %{} : bool = prim::Constant[value={}]()\n",
1270                        output_counter, val
1271                    ));
1272                }
1273                ConstantValue::Int(val) => {
1274                    ir.push_str(&format!(
1275                        "  %{} : int = prim::Constant[value={}]()\n",
1276                        output_counter, val
1277                    ));
1278                }
1279                ConstantValue::UInt(val) => {
1280                    ir.push_str(&format!(
1281                        "  %{} : int = prim::Constant[value={}]()\n",
1282                        output_counter, val
1283                    ));
1284                }
1285                ConstantValue::Float(val) => {
1286                    ir.push_str(&format!(
1287                        "  %{} : Float = prim::Constant[value={}]()\n",
1288                        output_counter, val
1289                    ));
1290                }
1291                ConstantValue::String(val) => {
1292                    ir.push_str(&format!(
1293                        "  %{} : str = prim::Constant[value=\"{}\"]()\n",
1294                        output_counter, val
1295                    ));
1296                }
1297                ConstantValue::FloatArray(arr) => {
1298                    ir.push_str(&format!(
1299                        "  %{} : Float[] = prim::Constant[value={}]()\n",
1300                        output_counter,
1301                        format_tensor_values(arr)
1302                    ));
1303                }
1304                ConstantValue::IntArray(arr) => {
1305                    let arr_str = arr
1306                        .iter()
1307                        .map(|x| x.to_string())
1308                        .collect::<Vec<_>>()
1309                        .join(", ");
1310                    ir.push_str(&format!(
1311                        "  %{} : int[] = prim::Constant[value=[{}]]()\n",
1312                        output_counter, arr_str
1313                    ));
1314                }
1315                ConstantValue::Array(_) => {
1316                    ir.push_str(&format!(
1317                        "  %{} : Tensor = prim::Constant[value=<complex_array>]()\n",
1318                        output_counter
1319                    ));
1320                }
1321                ConstantValue::Complex { real, imag } => {
1322                    ir.push_str(&format!(
1323                        "  %{} : complex = prim::Constant[value={}+{}i]()\n",
1324                        output_counter, real, imag
1325                    ));
1326                }
1327                ConstantValue::None => {
1328                    ir.push_str(&format!(
1329                        "  %{} : NoneType = prim::Constant[value=None]()\n",
1330                        output_counter
1331                    ));
1332                }
1333                ConstantValue::Undefined => {
1334                    ir.push_str(&format!(
1335                        "  %{} : Tensor = prim::Constant[value=<undefined>]()\n",
1336                        output_counter
1337                    ));
1338                }
1339            },
1340            Operation::Custom(name) => {
1341                let inputs = get_node_inputs(&ts_module.graph, node_id);
1342                let input_str = inputs.join(", ");
1343                ir.push_str(&format!(
1344                    "  %{} : Float = custom::{}({})\n",
1345                    output_counter, name, input_str
1346                ));
1347            }
1348            _ => {
1349                // Generic operation handling
1350                let inputs = get_node_inputs(&ts_module.graph, node_id);
1351                let input_str = inputs.join(", ");
1352                ir.push_str(&format!(
1353                    "  %{} : Float = aten::{:?}({})\n",
1354                    output_counter, node.op, input_str
1355                ));
1356            }
1357        }
1358        output_counter += 1;
1359    }
1360
1361    // Return the last output
1362    if output_counter > 0 {
1363        ir.push_str(&format!("  return (%{})\n", output_counter - 1));
1364    } else {
1365        ir.push_str("  return ()\n");
1366    }
1367
1368    Ok(ir)
1369}
1370
1371/// Parse TorchScript IR into TorchScript module
1372fn parse_torchscript_ir(ir: &str) -> JitResult<TorchScriptModule> {
1373    let mut graph = ComputationGraph::new();
1374    let mut constants = HashMap::new();
1375    let mut metadata = HashMap::new();
1376
1377    // Simple line-based parser for TorchScript IR
1378    let lines: Vec<&str> = ir.lines().collect();
1379    let mut node_counter = 0;
1380
1381    for line in lines {
1382        let line = line.trim();
1383
1384        if line.starts_with('%') && line.contains("prim::Constant") {
1385            // Parse constant
1386            if let Some(value_start) = line.find("value=") {
1387                let value_part = &line[value_start + 6..];
1388                if let Some(value_end) = value_part.find(']') {
1389                    let value_str = &value_part[..value_end];
1390                    if let Ok(val) = value_str.parse::<f32>() {
1391                        let const_name = format!("const_{}", node_counter);
1392                        constants.insert(const_name, vec![val]);
1393
1394                        // Add constant node to graph
1395                        add_constant_node_to_graph(&mut graph, val, node_counter);
1396                        node_counter += 1;
1397                    }
1398                }
1399            }
1400        } else if line.starts_with('%') && line.contains("aten::") {
1401            // Parse operation
1402            parse_aten_operation(&mut graph, line, node_counter)?;
1403            node_counter += 1;
1404        }
1405    }
1406
1407    // Add default metadata
1408    metadata.insert("producer".to_string(), "torchscript".to_string());
1409    metadata.insert("version".to_string(), "1.0".to_string());
1410
1411    Ok(TorchScriptModule {
1412        version: 1,
1413        graph,
1414        constants,
1415        metadata,
1416    })
1417}
1418
/// Render `f32` values as a TorchScript IR literal.
///
/// A single value is written bare (`1.5`). Any other count, including zero, is
/// written as a bracketed, comma-separated list (`[1, 2.5]`, `[]`). Values use
/// `f32`'s `Display` form.
fn format_tensor_values(values: &[f32]) -> String {
    match values {
        [only] => only.to_string(),
        many => {
            let items: Vec<String> = many.iter().map(ToString::to_string).collect();
            format!("[{}]", items.join(", "))
        }
    }
}
1434
1435/// Helper function to get node inputs as strings
1436fn get_node_inputs(graph: &ComputationGraph, node_id: NodeId) -> Vec<String> {
1437    let mut inputs = Vec::new();
1438
1439    for edge in graph.edges_directed(node_id, petgraph::Direction::Incoming) {
1440        let src_id = edge.source();
1441        inputs.push(format!("%{:?}", src_id));
1442    }
1443
1444    // If no inputs, assume it's an input node
1445    if inputs.is_empty() {
1446        inputs.push(format!("%input_{:?}", node_id));
1447    }
1448
1449    inputs
1450}
1451
1452/// Add constant node to computation graph
1453fn add_constant_node_to_graph(graph: &mut ComputationGraph, value: f32, node_id: usize) {
1454    use crate::graph::{ConstantInfo, ConstantValue, Operation};
1455    use torsh_core::DeviceType;
1456
1457    let mut node = Node::new(
1458        Operation::Constant(ConstantInfo {
1459            value: ConstantValue::Scalar(value as f64),
1460        }),
1461        format!("const_{}", node_id),
1462    );
1463    node = node
1464        .with_output_shapes(vec![Some(Shape::new(vec![1]))])
1465        .with_dtypes(vec![DType::F32])
1466        .with_device(DeviceType::Cpu);
1467    node.inputs = vec![];
1468    node.is_output = false;
1469
1470    graph.add_node(node);
1471}
1472
1473/// Parse aten operation from TorchScript IR line
1474fn parse_aten_operation(graph: &mut ComputationGraph, line: &str, node_id: usize) -> JitResult<()> {
1475    use crate::graph::Operation;
1476    use torsh_core::DeviceType;
1477
1478    let operation = if line.contains("aten::add") {
1479        Operation::Add
1480    } else if line.contains("aten::mul") {
1481        Operation::Mul
1482    } else if line.contains("aten::mm") {
1483        Operation::MatMul
1484    } else if line.contains("aten::relu") {
1485        Operation::Relu
1486    } else if line.contains("aten::sigmoid") {
1487        Operation::Sigmoid
1488    } else {
1489        // Extract operation name
1490        if let Some(op_start) = line.find("aten::") {
1491            let op_part = &line[op_start + 6..];
1492            if let Some(op_end) = op_part.find('(') {
1493                let op_name = &op_part[..op_end];
1494                Operation::Custom(op_name.to_string())
1495            } else {
1496                Operation::Custom("unknown".to_string())
1497            }
1498        } else {
1499            Operation::Custom("unknown".to_string())
1500        }
1501    };
1502
1503    let mut node = Node::new(operation, format!("op_{}", node_id));
1504    node = node
1505        .with_output_shapes(vec![Some(Shape::new(vec![]))]) // Will be inferred
1506        .with_dtypes(vec![DType::F32])
1507        .with_device(DeviceType::Cpu);
1508    node.inputs = vec![];
1509    node.is_output = false;
1510
1511    graph.add_node(node);
1512    Ok(())
1513}
1514
1515/// Implementation of script function
1516pub fn script<M: ScriptableModule>(module: M) -> JitResult<CompiledModule> {
1517    let config = JitConfig::default();
1518    let mut compiler = ScriptCompiler::new(config);
1519    compiler.script(module)
1520}
1521
#[cfg(test)]
mod tests {
    use super::*;

    /// A tensor annotation keeps the dtype and shape it was built with.
    #[test]
    fn test_type_annotation() {
        let tensor_ann = TypeAnnotation::Tensor {
            dtype: DType::F32,
            shape: vec![10, 20],
        };

        match tensor_ann {
            TypeAnnotation::Tensor { dtype, shape } => {
                assert_eq!(dtype, DType::F32);
                assert_eq!(shape, vec![10, 20]);
            }
            _ => panic!("Wrong type annotation"),
        }
    }

    /// Converting a lone float literal succeeds.
    #[test]
    fn test_ast_to_graph_converter() {
        let mut converter = AstToGraphConverter::new();

        // Any float works here. Avoid values close to `std::f64::consts` constants
        // (e.g. 3.14), which clippy's deny-by-default `approx_constant` lint rejects.
        let lit_ast = ScriptAst::Literal(LiteralValue::Float(2.5));
        let result = converter.convert(lit_ast);
        assert!(result.is_ok());
    }

    /// A freshly created script compiler starts with no type annotations.
    #[test]
    fn test_script_compiler_creation() {
        let config = JitConfig::default();
        let compiler = ScriptCompiler::new(config);
        assert!(compiler.type_annotations.is_empty());
    }
}