sklears_compose/workflow_language/
dsl_language.rs

1//! Domain-Specific Language (DSL) for Machine Learning Pipelines
2//!
3//! This module provides a text-based syntax for defining pipelines in a concise, readable format.
4//! Includes lexical analysis, parsing, and AST generation for the workflow DSL.
5//!
6//! Example DSL syntax:
7//! ```text
8//! pipeline "Customer Churn Prediction" {
9//!     version "1.0.0"
10//!     author "Data Science Team"
11//!
12//!     input features: Matrix<f64> [samples, features]
13//!     input labels: Array<f64> [samples]
14//!
15//!     step scaler: StandardScaler {
16//!         with_mean: true,
17//!         with_std: true
18//!     }
19//!
20//!     step model: RandomForestClassifier {
21//!         n_estimators: 100,
22//!         max_depth: 10,
23//!         random_state: 42
24//!     }
25//!
26//!     flow features -> scaler.X
27//!     flow scaler.X_scaled -> model.X
28//!     flow labels -> model.y
29//!
30//!     output predictions: model.predictions
31//!
32//!     execute {
33//!         mode: parallel,
34//!         workers: 4
35//!     }
36//! }
37//! ```
38
39use serde::{Deserialize, Serialize};
40use sklears_core::error::{Result as SklResult, SklearsError};
41use std::collections::{BTreeMap, VecDeque};
42
43use super::workflow_definitions::{
44    Connection, ConnectionType, DataType, ExecutionConfig, ExecutionMode, InputDefinition,
45    OutputDefinition, ParallelConfig, ParameterValue, StepDefinition, StepType, WorkflowDefinition,
46};
47
/// Domain-Specific Language (DSL) for machine learning pipelines.
///
/// Bundles a [`DslLexer`] and a [`DslParser`] to convert DSL text into a
/// `WorkflowDefinition` (see [`PipelineDSL::parse`]) and back into DSL text
/// (see [`PipelineDSL::generate`]).
#[derive(Debug)]
pub struct PipelineDSL {
    /// DSL lexer — turns raw text into a token stream
    lexer: DslLexer,
    /// DSL parser — turns the token stream into a workflow definition
    parser: DslParser,
}
56
/// DSL Lexer for tokenizing pipeline definitions.
///
/// The lexer is reusable: each call to `tokenize` resets all of the state
/// below before scanning.
#[derive(Debug)]
pub struct DslLexer {
    /// Input text (owned copy of the string passed to `tokenize`)
    input: String,
    /// Cached input characters for efficient O(1) indexing
    /// (a `String` is UTF-8 and cannot be indexed by char position directly)
    input_chars: Vec<char>,
    /// Current position as an index into `input_chars` (chars, not bytes)
    position: usize,
    /// Current line number (1-based), used for error reporting
    line: usize,
    /// Current column number (1-based), used for error reporting
    column: usize,
}
71
/// DSL Parser for converting tokens to workflow definitions.
///
/// Consumes the token stream produced by [`DslLexer::tokenize`] via a simple
/// recursive-descent strategy.
#[derive(Debug)]
pub struct DslParser {
    /// Token stream being parsed (replaced on each call to `parse`)
    tokens: VecDeque<Token>,
    /// Current token index into `tokens`
    // NOTE(review): the navigation helpers (advance/peek/consume) that read
    // this index are defined further down in the impl — verify they treat
    // `tokens` as an indexed buffer rather than popping from the deque.
    current: usize,
}
80
/// Token types for the DSL.
///
/// Produced by [`DslLexer::tokenize`]; keyword and type tokens are recognized
/// from identifiers in `scan_identifier`.
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // Keywords
    /// `pipeline` keyword — opens a pipeline definition
    Pipeline,
    /// `version` keyword — pipeline version metadata
    Version,
    /// `author` keyword — pipeline author metadata
    Author,
    /// `description` keyword — pipeline description metadata
    Description,
    /// `input` keyword — declares a pipeline input
    Input,
    /// `output` keyword — declares a pipeline output
    Output,
    /// `step` keyword — declares a pipeline step
    Step,
    /// `flow` keyword — declares a connection between steps
    Flow,
    /// `execute` keyword — opens the execution-configuration block
    Execute,

    // Data types
    /// `Matrix` type constructor, e.g. `Matrix<f64>`
    Matrix,
    /// `Array` type constructor, e.g. `Array<f64>`
    Array,
    /// `f32` primitive type
    Float32,
    /// `f64` primitive type
    Float64,
    /// `i32` primitive type
    Int32,
    /// `i64` primitive type
    Int64,
    /// `bool` primitive type
    Bool,
    /// `String` primitive type
    String,

    // Literals
    /// Identifier (step names, parameter names, algorithm names, ...)
    Identifier(String),
    /// Double-quoted string literal (quotes stripped)
    StringLiteral(String),
    /// Numeric literal; integers and floats are both lexed as `f64`
    NumberLiteral(f64),
    /// Boolean literal (`true` / `false`)
    BooleanLiteral(bool),

    // Operators and punctuation
    /// `{`
    LeftBrace,
    /// `}`
    RightBrace,
    /// `[`
    LeftBracket,
    /// `]`
    RightBracket,
    /// `(`
    LeftParen,
    /// `)`
    RightParen,
    /// `<` (generic-type delimiter)
    LeftAngle,
    /// `>` (generic-type delimiter)
    RightAngle,
    /// `,`
    Comma,
    /// `:`
    Colon,
    /// `;`
    Semicolon,
    /// `.`
    Dot,
    /// `->` (flow connector)
    Arrow,

    // Special
    /// Newline — produced by the lexer but filtered out of the token stream
    Newline,
    /// End-of-input sentinel, always the last token in the stream
    Eof,
}
166
/// DSL parsing error with source-location information.
///
/// Collected (rather than thrown) by [`PipelineDSL::validate_syntax`];
/// serializable so errors can be reported over tooling boundaries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DslError {
    /// Human-readable error message
    pub message: String,
    /// Line number where the error occurred (1-based)
    pub line: usize,
    /// Column number where the error occurred (1-based)
    pub column: usize,
}
177
178impl PipelineDSL {
179    /// Create a new DSL processor
180    #[must_use]
181    pub fn new() -> Self {
182        Self {
183            lexer: DslLexer::new(),
184            parser: DslParser::new(),
185        }
186    }
187
188    /// Parse DSL text into a workflow definition
189    pub fn parse(&mut self, input: &str) -> SklResult<WorkflowDefinition> {
190        // Tokenize input
191        let tokens = self.lexer.tokenize(input)?;
192
193        // Parse tokens into workflow
194        self.parser.parse(tokens)
195    }
196
197    /// Generate DSL text from workflow definition
198    #[must_use]
199    pub fn generate(&self, workflow: &WorkflowDefinition) -> String {
200        let mut dsl = String::new();
201
202        // Pipeline header
203        dsl.push_str(&format!("pipeline \"{}\" {{\n", workflow.metadata.name));
204        dsl.push_str(&format!("    version \"{}\"\n", workflow.metadata.version));
205
206        if let Some(author) = &workflow.metadata.author {
207            dsl.push_str(&format!("    author \"{author}\"\n"));
208        }
209
210        if let Some(description) = &workflow.metadata.description {
211            dsl.push_str(&format!("    description \"{description}\"\n"));
212        }
213
214        dsl.push('\n');
215
216        // Inputs
217        for input in &workflow.inputs {
218            dsl.push_str(&format!(
219                "    input {}: {}\n",
220                input.name,
221                self.format_data_type(&input.data_type)
222            ));
223        }
224
225        if !workflow.inputs.is_empty() {
226            dsl.push('\n');
227        }
228
229        // Steps
230        for step in &workflow.steps {
231            dsl.push_str(&format!("    step {}: {} {{\n", step.id, step.algorithm));
232
233            for (param_name, param_value) in &step.parameters {
234                dsl.push_str(&format!(
235                    "        {}: {},\n",
236                    param_name,
237                    self.format_parameter_value(param_value)
238                ));
239            }
240
241            dsl.push_str("    }\n\n");
242        }
243
244        // Connections
245        for connection in &workflow.connections {
246            dsl.push_str(&format!(
247                "    flow {}.{} -> {}.{}\n",
248                connection.from_step,
249                connection.from_output,
250                connection.to_step,
251                connection.to_input
252            ));
253        }
254
255        if !workflow.connections.is_empty() {
256            dsl.push('\n');
257        }
258
259        // Outputs
260        for output in &workflow.outputs {
261            dsl.push_str(&format!("    output {}\n", output.name));
262        }
263
264        if !workflow.outputs.is_empty() {
265            dsl.push('\n');
266        }
267
268        // Execution configuration
269        if workflow.execution.mode != ExecutionMode::Sequential {
270            dsl.push_str("    execute {\n");
271            dsl.push_str(&format!("        mode: {:?},\n", workflow.execution.mode));
272
273            if let Some(parallel_config) = &workflow.execution.parallel {
274                dsl.push_str(&format!(
275                    "        workers: {}\n",
276                    parallel_config.num_workers
277                ));
278            }
279
280            dsl.push_str("    }\n");
281        }
282
283        dsl.push_str("}\n");
284        dsl
285    }
286
287    /// Format data type for DSL output
288    fn format_data_type(&self, data_type: &DataType) -> String {
289        match data_type {
290            DataType::Float32 => "f32".to_string(),
291            DataType::Float64 => "f64".to_string(),
292            DataType::Int32 => "i32".to_string(),
293            DataType::Int64 => "i64".to_string(),
294            DataType::Boolean => "bool".to_string(),
295            DataType::String => "String".to_string(),
296            DataType::Array(inner) => format!("Array<{}>", self.format_data_type(inner)),
297            DataType::Matrix(inner) => format!("Matrix<{}>", self.format_data_type(inner)),
298            DataType::Custom(name) => name.clone(),
299        }
300    }
301
302    /// Format parameter value for DSL output
303    fn format_parameter_value(&self, value: &ParameterValue) -> String {
304        match value {
305            ParameterValue::Float(f) => f.to_string(),
306            ParameterValue::Int(i) => i.to_string(),
307            ParameterValue::Bool(b) => b.to_string(),
308            ParameterValue::String(s) => format!("\"{s}\""),
309            ParameterValue::Array(arr) => {
310                let items: Vec<String> =
311                    arr.iter().map(|v| self.format_parameter_value(v)).collect();
312                format!("[{}]", items.join(", "))
313            }
314        }
315    }
316
317    /// Validate DSL syntax
318    pub fn validate_syntax(&mut self, input: &str) -> SklResult<Vec<DslError>> {
319        let mut errors = Vec::new();
320
321        // Attempt to tokenize
322        match self.lexer.tokenize(input) {
323            Ok(tokens) => {
324                // Attempt to parse
325                match self.parser.parse(tokens) {
326                    Ok(_) => {
327                        // Syntax is valid
328                    }
329                    Err(e) => {
330                        errors.push(DslError {
331                            message: e.to_string(),
332                            line: 1, // Parser should provide line/column info
333                            column: 1,
334                        });
335                    }
336                }
337            }
338            Err(e) => {
339                errors.push(DslError {
340                    message: e.to_string(),
341                    line: self.lexer.line,
342                    column: self.lexer.column,
343                });
344            }
345        }
346
347        Ok(errors)
348    }
349
350    /// Get syntax highlighting information
351    pub fn get_syntax_highlighting(&mut self, input: &str) -> Vec<SyntaxHighlight> {
352        let mut highlights = Vec::new();
353
354        if let Ok(tokens) = self.lexer.tokenize(input) {
355            let mut position = 0;
356
357            for token in tokens {
358                let (token_type, length) = match &token {
359                    Token::Pipeline
360                    | Token::Version
361                    | Token::Author
362                    | Token::Description
363                    | Token::Input
364                    | Token::Output
365                    | Token::Step
366                    | Token::Flow
367                    | Token::Execute => ("keyword", self.estimate_token_length(&token)),
368                    Token::Matrix
369                    | Token::Array
370                    | Token::Float32
371                    | Token::Float64
372                    | Token::Int32
373                    | Token::Int64
374                    | Token::Bool
375                    | Token::String => ("type", self.estimate_token_length(&token)),
376                    Token::StringLiteral(_) => ("string", self.estimate_token_length(&token)),
377                    Token::NumberLiteral(_) => ("number", self.estimate_token_length(&token)),
378                    Token::BooleanLiteral(_) => ("boolean", self.estimate_token_length(&token)),
379                    Token::Identifier(_) => ("identifier", self.estimate_token_length(&token)),
380                    _ => ("punctuation", self.estimate_token_length(&token)),
381                };
382
383                highlights.push(SyntaxHighlight {
384                    start: position,
385                    end: position + length,
386                    token_type: token_type.to_string(),
387                });
388
389                position += length;
390            }
391        }
392
393        highlights
394    }
395
396    /// Estimate token length for highlighting
397    fn estimate_token_length(&self, token: &Token) -> usize {
398        match token {
399            Token::Identifier(s) | Token::StringLiteral(s) => s.len(),
400            Token::NumberLiteral(n) => n.to_string().len(),
401            Token::BooleanLiteral(b) => b.to_string().len(),
402            Token::Pipeline => "pipeline".len(),
403            Token::Version => "version".len(),
404            Token::Author => "author".len(),
405            Token::Description => "description".len(),
406            Token::Input => "input".len(),
407            Token::Output => "output".len(),
408            Token::Step => "step".len(),
409            Token::Flow => "flow".len(),
410            Token::Execute => "execute".len(),
411            Token::Matrix => "Matrix".len(),
412            Token::Array => "Array".len(),
413            Token::Float32 => "f32".len(),
414            Token::Float64 => "f64".len(),
415            Token::Int32 => "i32".len(),
416            Token::Int64 => "i64".len(),
417            Token::Bool => "bool".len(),
418            Token::String => "String".len(),
419            Token::Arrow => "->".len(),
420            _ => 1,
421        }
422    }
423}
424
/// Syntax highlighting information for one token.
///
/// Positions are char offsets produced by `PipelineDSL::get_syntax_highlighting`;
/// they are estimates (inter-token whitespace is not tracked).
#[derive(Debug, Clone)]
pub struct SyntaxHighlight {
    /// Start position in text (inclusive)
    pub start: usize,
    /// End position in text (exclusive)
    pub end: usize,
    /// Token type for styling (e.g. "keyword", "string", "punctuation")
    pub token_type: String,
}
435
436impl DslLexer {
437    /// Create a new lexer
438    #[must_use]
439    pub fn new() -> Self {
440        Self {
441            input: String::new(),
442            input_chars: Vec::new(),
443            position: 0,
444            line: 1,
445            column: 1,
446        }
447    }
448
449    /// Tokenize input text
450    pub fn tokenize(&mut self, input: &str) -> SklResult<VecDeque<Token>> {
451        self.input = input.to_string();
452        self.input_chars = self.input.chars().collect();
453        self.position = 0;
454        self.line = 1;
455        self.column = 1;
456
457        let mut tokens = VecDeque::new();
458
459        while !self.is_at_end() {
460            self.skip_whitespace();
461
462            if self.is_at_end() {
463                break;
464            }
465
466            let token = self.scan_token()?;
467            if token != Token::Newline {
468                // Skip newlines for now
469                tokens.push_back(token);
470            }
471        }
472
473        tokens.push_back(Token::Eof);
474        Ok(tokens)
475    }
476
477    /// Scan next token
478    fn scan_token(&mut self) -> SklResult<Token> {
479        let c = self.advance();
480
481        match c {
482            '{' => Ok(Token::LeftBrace),
483            '}' => Ok(Token::RightBrace),
484            '[' => Ok(Token::LeftBracket),
485            ']' => Ok(Token::RightBracket),
486            '(' => Ok(Token::LeftParen),
487            ')' => Ok(Token::RightParen),
488            '<' => Ok(Token::LeftAngle),
489            '>' => Ok(Token::RightAngle),
490            ',' => Ok(Token::Comma),
491            ':' => Ok(Token::Colon),
492            ';' => Ok(Token::Semicolon),
493            '.' => Ok(Token::Dot),
494            '\n' => {
495                self.line += 1;
496                self.column = 1;
497                Ok(Token::Newline)
498            }
499            '-' => {
500                if self.match_char('>') {
501                    Ok(Token::Arrow)
502                } else {
503                    self.scan_number()
504                }
505            }
506            '"' => self.scan_string(),
507            _ if c.is_ascii_digit() => {
508                self.position -= 1; // Back up to scan full number
509                self.column -= 1;
510                self.scan_number()
511            }
512            _ if c.is_ascii_alphabetic() || c == '_' => {
513                self.position -= 1; // Back up to scan full identifier
514                self.column -= 1;
515                self.scan_identifier()
516            }
517            _ => Err(SklearsError::InvalidInput(format!(
518                "Unexpected character '{}' at line {}, column {}",
519                c, self.line, self.column
520            ))),
521        }
522    }
523
524    /// Scan string literal
525    fn scan_string(&mut self) -> SklResult<Token> {
526        let mut value = String::new();
527
528        while !self.is_at_end() && self.peek() != '"' {
529            if self.peek() == '\n' {
530                self.line += 1;
531                self.column = 1;
532            }
533            value.push(self.advance());
534        }
535
536        if self.is_at_end() {
537            return Err(SklearsError::InvalidInput(format!(
538                "Unterminated string at line {}",
539                self.line
540            )));
541        }
542
543        // Consume closing quote
544        self.advance();
545
546        Ok(Token::StringLiteral(value))
547    }
548
549    /// Scan number literal
550    fn scan_number(&mut self) -> SklResult<Token> {
551        let mut value = String::new();
552
553        // Handle negative numbers
554        if self.peek() == '-' {
555            value.push(self.advance());
556        }
557
558        while !self.is_at_end() && (self.peek().is_ascii_digit() || self.peek() == '.') {
559            value.push(self.advance());
560        }
561
562        match value.parse::<f64>() {
563            Ok(number) => Ok(Token::NumberLiteral(number)),
564            Err(_) => Err(SklearsError::InvalidInput(format!(
565                "Invalid number '{}' at line {}, column {}",
566                value, self.line, self.column
567            ))),
568        }
569    }
570
571    /// Scan identifier or keyword
572    fn scan_identifier(&mut self) -> SklResult<Token> {
573        let mut value = String::new();
574
575        while !self.is_at_end() && (self.peek().is_ascii_alphanumeric() || self.peek() == '_') {
576            value.push(self.advance());
577        }
578
579        // Check for keywords
580        let token = match value.as_str() {
581            "pipeline" => Token::Pipeline,
582            "version" => Token::Version,
583            "author" => Token::Author,
584            "description" => Token::Description,
585            "input" => Token::Input,
586            "output" => Token::Output,
587            "step" => Token::Step,
588            "flow" => Token::Flow,
589            "execute" => Token::Execute,
590            "Matrix" => Token::Matrix,
591            "Array" => Token::Array,
592            "f32" => Token::Float32,
593            "f64" => Token::Float64,
594            "i32" => Token::Int32,
595            "i64" => Token::Int64,
596            "bool" => Token::Bool,
597            "String" => Token::String,
598            "true" => Token::BooleanLiteral(true),
599            "false" => Token::BooleanLiteral(false),
600            _ => Token::Identifier(value),
601        };
602
603        Ok(token)
604    }
605
606    /// Skip whitespace characters
607    fn skip_whitespace(&mut self) {
608        while !self.is_at_end() {
609            match self.peek() {
610                ' ' | '\r' | '\t' => {
611                    self.advance();
612                }
613                '/' if self.peek_next() == '/' => {
614                    // Line comment
615                    while !self.is_at_end() && self.peek() != '\n' {
616                        self.advance();
617                    }
618                }
619                _ => break,
620            }
621        }
622    }
623
624    /// Advance to next character
625    fn advance(&mut self) -> char {
626        if let Some(&c) = self.input_chars.get(self.position) {
627            self.position += 1;
628            self.column += 1;
629            c
630        } else {
631            '\0'
632        }
633    }
634
635    /// Peek at current character
636    fn peek(&self) -> char {
637        self.input_chars.get(self.position).copied().unwrap_or('\0')
638    }
639
640    /// Peek at next character
641    fn peek_next(&self) -> char {
642        self.input_chars
643            .get(self.position + 1)
644            .copied()
645            .unwrap_or('\0')
646    }
647
648    /// Check if character matches expected
649    fn match_char(&mut self, expected: char) -> bool {
650        if self.is_at_end() || self.peek() != expected {
651            false
652        } else {
653            self.advance();
654            true
655        }
656    }
657
658    /// Check if at end of input
659    fn is_at_end(&self) -> bool {
660        self.position >= self.input_chars.len()
661    }
662}
663
664impl DslParser {
665    /// Create a new parser
666    #[must_use]
667    pub fn new() -> Self {
668        Self {
669            tokens: VecDeque::new(),
670            current: 0,
671        }
672    }
673
674    /// Parse tokens into workflow definition
675    pub fn parse(&mut self, tokens: VecDeque<Token>) -> SklResult<WorkflowDefinition> {
676        self.tokens = tokens;
677        self.current = 0;
678
679        self.parse_pipeline()
680    }
681
682    /// Parse pipeline definition
683    fn parse_pipeline(&mut self) -> SklResult<WorkflowDefinition> {
684        self.consume(Token::Pipeline, "Expected 'pipeline'")?;
685
686        let name = if let Token::StringLiteral(name) = self.advance() {
687            name
688        } else {
689            return Err(SklearsError::InvalidInput(
690                "Expected pipeline name".to_string(),
691            ));
692        };
693
694        self.consume(Token::LeftBrace, "Expected '{' after pipeline name")?;
695
696        let mut workflow = WorkflowDefinition::default();
697        workflow.metadata.name = name;
698
699        // Parse pipeline body
700        while !self.check(&Token::RightBrace) && !self.is_at_end() {
701            match &self.peek() {
702                Token::Version => {
703                    self.advance();
704                    if let Token::StringLiteral(version) = self.advance() {
705                        workflow.metadata.version = version;
706                    }
707                }
708                Token::Author => {
709                    self.advance();
710                    if let Token::StringLiteral(author) = self.advance() {
711                        workflow.metadata.author = Some(author);
712                    }
713                }
714                Token::Description => {
715                    self.advance();
716                    if let Token::StringLiteral(description) = self.advance() {
717                        workflow.metadata.description = Some(description);
718                    }
719                }
720                Token::Input => {
721                    workflow.inputs.push(self.parse_input()?);
722                }
723                Token::Output => {
724                    workflow.outputs.push(self.parse_output()?);
725                }
726                Token::Step => {
727                    workflow.steps.push(self.parse_step()?);
728                }
729                Token::Flow => {
730                    workflow.connections.push(self.parse_flow()?);
731                }
732                Token::Execute => {
733                    workflow.execution = self.parse_execute()?;
734                }
735                _ => {
736                    return Err(SklearsError::InvalidInput(format!(
737                        "Unexpected token: {:?}",
738                        self.peek()
739                    )));
740                }
741            }
742        }
743
744        self.consume(Token::RightBrace, "Expected '}' after pipeline body")?;
745
746        Ok(workflow)
747    }
748
749    /// Parse input definition
750    fn parse_input(&mut self) -> SklResult<InputDefinition> {
751        self.consume(Token::Input, "Expected 'input'")?;
752
753        let name = if let Token::Identifier(name) = self.advance() {
754            name
755        } else {
756            return Err(SklearsError::InvalidInput(
757                "Expected input name".to_string(),
758            ));
759        };
760
761        self.consume(Token::Colon, "Expected ':' after input name")?;
762
763        let data_type = self.parse_data_type()?;
764
765        Ok(InputDefinition {
766            name,
767            data_type,
768            shape: None,
769            constraints: None,
770            description: None,
771        })
772    }
773
774    /// Parse output definition
775    fn parse_output(&mut self) -> SklResult<OutputDefinition> {
776        self.consume(Token::Output, "Expected 'output'")?;
777
778        let name = if let Token::Identifier(name) = self.advance() {
779            name
780        } else {
781            return Err(SklearsError::InvalidInput(
782                "Expected output name".to_string(),
783            ));
784        };
785
786        Ok(OutputDefinition {
787            name,
788            data_type: DataType::Float64, // Default type
789            shape: None,
790            description: None,
791        })
792    }
793
794    /// Parse step definition
795    fn parse_step(&mut self) -> SklResult<StepDefinition> {
796        self.consume(Token::Step, "Expected 'step'")?;
797
798        let id = if let Token::Identifier(id) = self.advance() {
799            id
800        } else {
801            return Err(SklearsError::InvalidInput(
802                "Expected step identifier".to_string(),
803            ));
804        };
805
806        self.consume(Token::Colon, "Expected ':' after step identifier")?;
807
808        let algorithm = if let Token::Identifier(algorithm) = self.advance() {
809            algorithm
810        } else {
811            return Err(SklearsError::InvalidInput(
812                "Expected algorithm name".to_string(),
813            ));
814        };
815
816        let mut parameters = BTreeMap::new();
817
818        if self.check(&Token::LeftBrace) {
819            self.advance(); // consume '{'
820
821            while !self.check(&Token::RightBrace) && !self.is_at_end() {
822                let param_name = if let Token::Identifier(name) = self.advance() {
823                    name
824                } else {
825                    return Err(SklearsError::InvalidInput(
826                        "Expected parameter name".to_string(),
827                    ));
828                };
829
830                self.consume(Token::Colon, "Expected ':' after parameter name")?;
831
832                let param_value = self.parse_parameter_value()?;
833                parameters.insert(param_name, param_value);
834
835                if self.check(&Token::Comma) {
836                    self.advance();
837                }
838            }
839
840            self.consume(Token::RightBrace, "Expected '}' after step parameters")?;
841        }
842
843        Ok(StepDefinition {
844            id,
845            step_type: StepType::Custom("DSL".to_string()),
846            algorithm,
847            parameters,
848            inputs: Vec::new(),
849            outputs: Vec::new(),
850            condition: None,
851            description: None,
852        })
853    }
854
855    /// Parse flow/connection definition
856    fn parse_flow(&mut self) -> SklResult<Connection> {
857        self.consume(Token::Flow, "Expected 'flow'")?;
858
859        // Parse source (step.output)
860        let from_step = if let Token::Identifier(step) = self.advance() {
861            step
862        } else {
863            return Err(SklearsError::InvalidInput(
864                "Expected source step".to_string(),
865            ));
866        };
867
868        self.consume(Token::Dot, "Expected '.' after source step")?;
869
870        let from_output = match self.advance() {
871            Token::Identifier(output) => output,
872            Token::Output => "output".to_string(),
873            Token::Input => "input".to_string(),
874            other => {
875                return Err(SklearsError::InvalidInput(format!(
876                    "Expected source output, got {:?}",
877                    other
878                )))
879            }
880        };
881
882        self.consume(Token::Arrow, "Expected '->' in flow")?;
883
884        // Parse target (step.input)
885        let to_step = if let Token::Identifier(step) = self.advance() {
886            step
887        } else {
888            return Err(SklearsError::InvalidInput(
889                "Expected target step".to_string(),
890            ));
891        };
892
893        self.consume(Token::Dot, "Expected '.' after target step")?;
894
895        let to_input = match self.advance() {
896            Token::Identifier(input) => input,
897            Token::Input => "input".to_string(),
898            other => {
899                return Err(SklearsError::InvalidInput(format!(
900                    "Expected target input, got {:?}",
901                    other
902                )))
903            }
904        };
905
906        Ok(Connection {
907            from_step,
908            from_output,
909            to_step,
910            to_input,
911            connection_type: ConnectionType::Direct,
912            transform: None,
913        })
914    }
915
916    /// Parse execution configuration
917    fn parse_execute(&mut self) -> SklResult<ExecutionConfig> {
918        self.consume(Token::Execute, "Expected 'execute'")?;
919        self.consume(Token::LeftBrace, "Expected '{' after 'execute'")?;
920
921        let mut config = ExecutionConfig {
922            mode: ExecutionMode::Sequential,
923            parallel: None,
924            resources: None,
925            caching: None,
926        };
927
928        while !self.check(&Token::RightBrace) && !self.is_at_end() {
929            let key = if let Token::Identifier(key) = self.advance() {
930                key
931            } else {
932                return Err(SklearsError::InvalidInput(
933                    "Expected configuration key".to_string(),
934                ));
935            };
936
937            self.consume(Token::Colon, "Expected ':' after configuration key")?;
938
939            match key.as_str() {
940                "mode" => {
941                    if let Token::Identifier(mode) = self.advance() {
942                        config.mode = match mode.as_str() {
943                            "parallel" => ExecutionMode::Parallel,
944                            "sequential" => ExecutionMode::Sequential,
945                            "distributed" => ExecutionMode::Distributed,
946                            "gpu" => ExecutionMode::GPU,
947                            "adaptive" => ExecutionMode::Adaptive,
948                            _ => ExecutionMode::Sequential,
949                        };
950                    }
951                }
952                "workers" => {
953                    if let Token::NumberLiteral(workers) = self.advance() {
954                        config.parallel = Some(ParallelConfig {
955                            num_workers: workers as usize,
956                            chunk_size: None,
957                            load_balancing: "round_robin".to_string(),
958                        });
959                    }
960                }
961                _ => {
962                    // Skip unknown keys
963                    self.advance();
964                }
965            }
966
967            if self.check(&Token::Comma) {
968                self.advance();
969            }
970        }
971
972        self.consume(
973            Token::RightBrace,
974            "Expected '}' after execute configuration",
975        )?;
976
977        Ok(config)
978    }
979
980    /// Parse data type
981    fn parse_data_type(&mut self) -> SklResult<DataType> {
982        match &self.advance() {
983            Token::Float32 => Ok(DataType::Float32),
984            Token::Float64 => Ok(DataType::Float64),
985            Token::Int32 => Ok(DataType::Int32),
986            Token::Int64 => Ok(DataType::Int64),
987            Token::Bool => Ok(DataType::Boolean),
988            Token::String => Ok(DataType::String),
989            Token::Array => {
990                self.consume(Token::LeftAngle, "Expected '<' after 'Array'")?;
991                let inner_type = self.parse_data_type()?;
992                self.consume(Token::RightAngle, "Expected '>' after array type")?;
993                Ok(DataType::Array(Box::new(inner_type)))
994            }
995            Token::Matrix => {
996                self.consume(Token::LeftAngle, "Expected '<' after 'Matrix'")?;
997                let inner_type = self.parse_data_type()?;
998                self.consume(Token::RightAngle, "Expected '>' after matrix type")?;
999                Ok(DataType::Matrix(Box::new(inner_type)))
1000            }
1001            Token::Identifier(name) => Ok(DataType::Custom(name.clone())),
1002            _ => Err(SklearsError::InvalidInput("Expected data type".to_string())),
1003        }
1004    }
1005
1006    /// Parse parameter value
1007    fn parse_parameter_value(&mut self) -> SklResult<ParameterValue> {
1008        match &self.advance() {
1009            Token::NumberLiteral(n) => {
1010                if n.fract() == 0.0 {
1011                    Ok(ParameterValue::Int(*n as i64))
1012                } else {
1013                    Ok(ParameterValue::Float(*n))
1014                }
1015            }
1016            Token::BooleanLiteral(b) => Ok(ParameterValue::Bool(*b)),
1017            Token::StringLiteral(s) => Ok(ParameterValue::String(s.clone())),
1018            _ => Err(SklearsError::InvalidInput(
1019                "Expected parameter value".to_string(),
1020            )),
1021        }
1022    }
1023
1024    /// Consume expected token
1025    fn consume(&mut self, expected: Token, message: &str) -> SklResult<Token> {
1026        if self.check(&expected) {
1027            Ok(self.advance())
1028        } else {
1029            Err(SklearsError::InvalidInput(format!(
1030                "{}, got {:?}",
1031                message,
1032                self.peek()
1033            )))
1034        }
1035    }
1036
1037    /// Check if current token matches expected
1038    fn check(&self, token_type: &Token) -> bool {
1039        if self.is_at_end() {
1040            false
1041        } else {
1042            std::mem::discriminant(&self.peek()) == std::mem::discriminant(token_type)
1043        }
1044    }
1045
1046    /// Advance to next token
1047    fn advance(&mut self) -> Token {
1048        if !self.is_at_end() {
1049            self.current += 1;
1050        }
1051        self.previous()
1052    }
1053
1054    /// Get current token
1055    fn peek(&self) -> Token {
1056        if let Some(token) = self.tokens.get(self.current) {
1057            token.clone()
1058        } else {
1059            Token::Eof
1060        }
1061    }
1062
1063    /// Get previous token
1064    fn previous(&self) -> Token {
1065        if self.current > 0 {
1066            if let Some(token) = self.tokens.get(self.current - 1) {
1067                token.clone()
1068            } else {
1069                Token::Eof
1070            }
1071        } else {
1072            Token::Eof
1073        }
1074    }
1075
1076    /// Check if at end of tokens
1077    fn is_at_end(&self) -> bool {
1078        matches!(self.peek(), Token::Eof)
1079    }
1080}
1081
1082impl std::fmt::Display for DslError {
1083    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1084        write!(
1085            f,
1086            "DSL Error at line {}, column {}: {}",
1087            self.line, self.column, self.message
1088        )
1089    }
1090}
1091
// Marker impl: `DslError` satisfies `std::error::Error` via its own
// `Display` and derived `Debug`; no provided methods are overridden.
impl std::error::Error for DslError {}
1093
1094impl Default for PipelineDSL {
1095    fn default() -> Self {
1096        Self::new()
1097    }
1098}
1099
1100impl Default for DslLexer {
1101    fn default() -> Self {
1102        Self::new()
1103    }
1104}
1105
1106impl Default for DslParser {
1107    fn default() -> Self {
1108        Self::new()
1109    }
1110}
1111
/// Abstract Syntax Tree node for DSL parsing
///
/// `Pipeline` is the root node; every other construct of the DSL appears as
/// one of its `children`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AstNode {
    /// Root node: a named pipeline with its metadata and all child constructs
    Pipeline {
        name: String,
        metadata: PipelineMetadata,
        children: Vec<AstNode>,
    },
    /// A processing step: `name` identifies it, `algorithm` names the
    /// component it instantiates, `parameters` holds `Parameter` nodes
    Step {
        name: String,
        algorithm: String,
        parameters: Vec<AstNode>,
    },
    /// A data-flow edge between two endpoints; each `port_mapping` pair is
    /// presumably (source port, target port) — TODO confirm against the parser
    Connection {
        from: String,
        to: String,
        port_mapping: Vec<(String, String)>,
    },
    /// A single `key: value` entry inside a step body
    Parameter { name: String, value: ParameterValue },
    /// A pipeline input declaration (`input name: Type`)
    Input { name: String, data_type: DataType },
    /// A pipeline output declaration (`output name: Type`)
    Output { name: String, data_type: DataType },
    /// A key/value entry from an `execute { ... }` configuration block
    Config { key: String, value: String },
}
1142
/// Pipeline metadata for AST
///
/// Mirrors the optional header entries of a `pipeline { ... }` block
/// (`version`, `author`, and a free-form description).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineMetadata {
    /// Pipeline version string (e.g. "1.0.0")
    pub version: String,
    /// Optional free-form pipeline description
    pub description: Option<String>,
    /// Optional pipeline author
    pub author: Option<String>,
}
1153
/// Auto completer for DSL editing
///
/// Holds flat suggestion pools plus a per-context map; `get_suggestions`
/// merges the keyword pool with the entries registered for a given context.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoCompleter {
    /// DSL keywords (`pipeline`, `step`, `input`, ...)
    pub keywords: Vec<String>,
    /// Pipeline operation names (`fit`, `predict`, ...)
    pub functions: Vec<String>,
    /// Known component/algorithm names (`StandardScaler`, ...)
    pub components: Vec<String>,
    /// Extra suggestions keyed by the syntactic context they apply to
    /// (e.g. component names after the `step` keyword)
    pub context_suggestions: BTreeMap<String, Vec<String>>,
}
1166
/// DSL configuration settings
///
/// Editor/tooling knobs; the shipped values come from this type's
/// `Default` implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DslConfig {
    /// Enable syntax highlighting (default: true)
    pub syntax_highlighting: bool,
    /// Enable auto completion (default: true)
    pub auto_completion: bool,
    /// Enable real-time validation (default: true)
    pub real_time_validation: bool,
    /// Indentation size in spaces (default: 4)
    pub indent_size: usize,
    /// Maximum line length in characters (default: 100)
    pub max_line_length: usize,
    /// Accepted comment style(s) (default: `Both`)
    pub comment_style: CommentStyle,
}
1183
/// Comment styles the DSL tooling accepts
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CommentStyle {
    /// Single line comments with //
    SingleLine,
    /// Multi-line comments with /* */
    MultiLine,
    /// Both single-line and multi-line styles supported
    Both,
}
1194
/// Lexical analysis error
///
/// The trailing `usize` pair on each variant is presumably the (line, column)
/// position of the offending input — TODO confirm against the lexer's call
/// sites. NOTE(review): unlike `ParseError`, this enum implements neither
/// `Display` nor `std::error::Error`; consider deriving `thiserror::Error`
/// for consistency.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LexError {
    /// A character with no meaning in the DSL
    UnexpectedCharacter(char, usize, usize),
    /// A string literal missing its closing quote
    UnterminatedString(usize, usize),
    /// A malformed numeric literal (raw text carried in the `String`)
    InvalidNumber(String, usize, usize),
    /// An unrecognized escape sequence inside a string literal
    InvalidEscape(String, usize, usize),
    /// Input ended while more characters were required
    UnexpectedEof(usize, usize),
}
1209
/// Parse error types
///
/// Every variant carries its source position as trailing (line, column)
/// fields; human-readable messages are rendered by the `thiserror`-derived
/// `Display` implementation.
#[derive(Debug, Clone, Serialize, Deserialize, thiserror::Error)]
pub enum ParseError {
    /// Unexpected token
    #[error("Unexpected token '{0}' at line {1}, column {2}")]
    UnexpectedToken(String, usize, usize),
    /// Missing token
    #[error("Missing token '{0}' at line {1}, column {2}")]
    MissingToken(String, usize, usize),
    /// Invalid syntax
    #[error("Invalid syntax: {0} at line {1}, column {2}")]
    InvalidSyntax(String, usize, usize),
    /// Semantic error
    #[error("Semantic error: {0} at line {1}, column {2}")]
    SemanticError(String, usize, usize),
    /// Unknown identifier
    #[error("Unknown identifier '{0}' at line {1}, column {2}")]
    UnknownIdentifier(String, usize, usize),
    /// Type mismatch — NOTE(review): per the format string, field 0 is the
    /// *found* type and field 1 the *expected* type; keep construction order
    /// consistent with that.
    #[error("Type mismatch: expected {1}, found {0} at line {2}, column {3}")]
    TypeMismatch(String, String, usize, usize),
}
1232
/// Convenience alias for parser-internal results carrying a [`ParseError`]
pub type ParseResult<T> = Result<T, ParseError>;
1235
/// Semantic analyzer for DSL validation
///
/// Bundles the symbol table, type checker, and the rule set to run;
/// diagnostics discovered during analysis are collected into `errors`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticAnalyzer {
    /// Symbol table tracking declared identifiers
    pub symbol_table: SymbolTable,
    /// Type checker for compatibility/coercion queries
    pub type_checker: TypeChecker,
    /// Semantic validation rules to apply
    pub rules: Vec<SemanticRule>,
    /// Accumulated diagnostics from analysis
    pub errors: Vec<ParseError>,
}
1248
/// Semantic validation rule descriptor
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticRule {
    /// Rule name
    pub name: String,
    /// Rule description
    pub description: String,
    /// Severity applied when the rule fires
    pub severity: RuleSeverity,
    /// Name of the checker function — presumably resolved by the analyzer
    /// at run time; TODO confirm the lookup mechanism
    pub checker: String,
}
1261
/// Severity levels for semantic rules
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RuleSeverity {
    /// Error - compilation fails
    Error,
    /// Warning - compilation succeeds with warning
    Warning,
    /// Info - informational message only
    Info,
}
1272
/// Symbol table for identifier tracking
///
/// Symbols live in a single flat map keyed by name; scopes only record which
/// names were declared in them, so a redefinition in any scope shadows the
/// previous entry globally.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SymbolTable {
    /// All defined symbols, keyed by name
    pub symbols: BTreeMap<String, Symbol>,
    /// Scope stack; index 0 is the global scope
    pub scopes: Vec<Scope>,
    /// Index into `scopes` of the currently active scope
    pub current_scope: usize,
}
1283
/// A single declared identifier and where/how it was defined
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Symbol {
    /// Symbol name (also the key in `SymbolTable::symbols`)
    pub name: String,
    /// What kind of identifier this is (variable, step, ...)
    pub symbol_type: SymbolType,
    /// Declared data type
    pub data_type: DataType,
    /// Scope level at the time of definition
    pub scope: usize,
    /// Source line where defined
    pub line: usize,
    /// Source column where defined
    pub column: usize,
}
1300
/// Kinds of identifiers the symbol table distinguishes
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SymbolType {
    /// Variable symbol
    Variable,
    /// Function symbol
    Function,
    /// Type symbol
    Type,
    /// Constant symbol
    Constant,
    /// Pipeline step symbol
    Step,
    /// Step parameter symbol
    Parameter,
}
1317
/// A lexical scope in the symbol table
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Scope {
    /// Scope name (e.g. "global")
    pub name: String,
    /// Parent scope — presumably an index into `SymbolTable::scopes`;
    /// TODO confirm
    pub parent: Option<usize>,
    /// Names of symbols declared in this scope
    pub symbols: Vec<String>,
}
1328
/// Alias for [`SyntaxHighlight`] (defined elsewhere in this module), kept so
/// callers can use the more conventional "Highlighter" name
pub type SyntaxHighlighter = SyntaxHighlight;
1331
/// Token types for lexical analysis
///
/// NOTE(review): the lexer/parser in this module operate on the separate
/// `Token` enum; this richer, category-style token type appears intended for
/// tooling (e.g. highlighting) — confirm its consumers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TokenType {
    /// Keyword tokens
    Keyword(String),
    /// Identifier tokens
    Identifier(String),
    /// Number literal tokens
    Number(f64),
    /// String literal tokens
    String(String),
    /// Boolean literal tokens
    Boolean(bool),
    /// Operator tokens
    Operator(String),
    /// Punctuation tokens
    Punctuation(char),
    /// Comment tokens (text without delimiters — TODO confirm)
    Comment(String),
    /// Whitespace tokens
    Whitespace(String),
    /// End of file token
    Eof,
}
1356
/// Type checker for semantic analysis
///
/// Answers compatibility queries: two types match when equal or when a
/// registered coercion rule converts one into the other.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TypeChecker {
    /// Type checking rules
    pub rules: Vec<TypeRule>,
    /// Known named types and their metadata
    pub types: BTreeMap<String, TypeInfo>,
    /// Directional coercion rules consulted by `can_coerce`
    pub coercion_rules: Vec<CoercionRule>,
}
1367
/// A single type checking rule
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TypeRule {
    /// Rule name
    pub name: String,
    /// Source type the rule applies to
    pub source_type: DataType,
    /// Target type the rule applies to
    pub target_type: DataType,
    /// Name of the checker function — presumably resolved elsewhere;
    /// TODO confirm the lookup mechanism
    pub checker: String,
}
1380
/// Metadata about a known named type
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TypeInfo {
    /// Type name (key in `TypeChecker::types`)
    pub name: String,
    /// Underlying base type
    pub base_type: DataType,
    /// Free-form type constraints
    pub constraints: Vec<String>,
    /// Arbitrary key/value type metadata
    pub metadata: BTreeMap<String, String>,
}
1393
/// A directional type coercion rule (`from_type` → `to_type`)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoercionRule {
    /// Source type of the coercion
    pub from_type: DataType,
    /// Target type of the coercion
    pub to_type: DataType,
    /// Relative cost — presumably lower is preferred when several coercions
    /// apply; TODO confirm
    pub cost: u32,
    /// Name of the coercion function — presumably resolved elsewhere
    pub coercion_fn: String,
}
1406
1407impl Default for DslConfig {
1408    fn default() -> Self {
1409        Self {
1410            syntax_highlighting: true,
1411            auto_completion: true,
1412            real_time_validation: true,
1413            indent_size: 4,
1414            max_line_length: 100,
1415            comment_style: CommentStyle::Both,
1416        }
1417    }
1418}
1419
1420impl AutoCompleter {
1421    /// Create a new auto completer with default suggestions
1422    #[must_use]
1423    pub fn new() -> Self {
1424        let mut completer = Self {
1425            keywords: vec![
1426                "pipeline".to_string(),
1427                "step".to_string(),
1428                "connect".to_string(),
1429                "input".to_string(),
1430                "output".to_string(),
1431                "execute".to_string(),
1432                "version".to_string(),
1433            ],
1434            functions: vec![
1435                "transform".to_string(),
1436                "fit".to_string(),
1437                "predict".to_string(),
1438                "evaluate".to_string(),
1439            ],
1440            components: vec![
1441                "StandardScaler".to_string(),
1442                "LinearRegression".to_string(),
1443                "RandomForest".to_string(),
1444                "SVM".to_string(),
1445            ],
1446            context_suggestions: BTreeMap::new(),
1447        };
1448
1449        // Add context-specific suggestions
1450        completer
1451            .context_suggestions
1452            .insert("step".to_string(), completer.components.clone());
1453
1454        completer
1455    }
1456
1457    /// Get suggestions for a given context
1458    #[must_use]
1459    pub fn get_suggestions(&self, context: &str, prefix: &str) -> Vec<String> {
1460        let mut suggestions = Vec::new();
1461
1462        // Add keyword suggestions
1463        for keyword in &self.keywords {
1464            if keyword.starts_with(prefix) {
1465                suggestions.push(keyword.clone());
1466            }
1467        }
1468
1469        // Add context-specific suggestions
1470        if let Some(context_suggestions) = self.context_suggestions.get(context) {
1471            for suggestion in context_suggestions {
1472                if suggestion.starts_with(prefix) {
1473                    suggestions.push(suggestion.clone());
1474                }
1475            }
1476        }
1477
1478        suggestions.sort();
1479        suggestions.dedup();
1480        suggestions
1481    }
1482}
1483
1484impl Default for AutoCompleter {
1485    fn default() -> Self {
1486        Self::new()
1487    }
1488}
1489
1490impl SymbolTable {
1491    /// Create a new symbol table
1492    #[must_use]
1493    pub fn new() -> Self {
1494        Self {
1495            symbols: BTreeMap::new(),
1496            scopes: vec![Scope {
1497                name: "global".to_string(),
1498                parent: None,
1499                symbols: Vec::new(),
1500            }],
1501            current_scope: 0,
1502        }
1503    }
1504
1505    /// Add a symbol to the current scope
1506    pub fn add_symbol(&mut self, symbol: Symbol) {
1507        self.symbols.insert(symbol.name.clone(), symbol.clone());
1508        if let Some(scope) = self.scopes.get_mut(self.current_scope) {
1509            scope.symbols.push(symbol.name);
1510        }
1511    }
1512
1513    /// Look up a symbol
1514    #[must_use]
1515    pub fn lookup(&self, name: &str) -> Option<&Symbol> {
1516        self.symbols.get(name)
1517    }
1518}
1519
1520impl Default for SymbolTable {
1521    fn default() -> Self {
1522        Self::new()
1523    }
1524}
1525
1526impl TypeChecker {
1527    /// Create a new type checker
1528    #[must_use]
1529    pub fn new() -> Self {
1530        Self {
1531            rules: Vec::new(),
1532            types: BTreeMap::new(),
1533            coercion_rules: Vec::new(),
1534        }
1535    }
1536
1537    /// Check if two types are compatible
1538    #[must_use]
1539    pub fn are_compatible(&self, type1: &DataType, type2: &DataType) -> bool {
1540        type1 == type2 || self.can_coerce(type1, type2)
1541    }
1542
1543    /// Check if type can be coerced
1544    #[must_use]
1545    pub fn can_coerce(&self, from: &DataType, to: &DataType) -> bool {
1546        self.coercion_rules
1547            .iter()
1548            .any(|rule| rule.from_type == *from && rule.to_type == *to)
1549    }
1550}
1551
1552impl Default for TypeChecker {
1553    fn default() -> Self {
1554        Self::new()
1555    }
1556}
1557
// Removed an unexplained `#[allow(non_snake_case)]`: every test function in
// this module is snake_case, so the suppression was dead.
#[cfg(test)]
mod tests {
    use super::*;

    // Lexer produces the keyword/punctuation tokens plus a trailing Eof.
    #[test]
    fn test_lexer_basic_tokens() {
        let mut lexer = DslLexer::new();
        let tokens = lexer.tokenize("pipeline { }").unwrap();

        assert_eq!(tokens[0], Token::Pipeline);
        assert_eq!(tokens[1], Token::LeftBrace);
        assert_eq!(tokens[2], Token::RightBrace);
        assert_eq!(tokens[3], Token::Eof);
    }

    // Quoted text becomes a StringLiteral with the quotes stripped.
    #[test]
    fn test_lexer_string_literal() {
        let mut lexer = DslLexer::new();
        let tokens = lexer.tokenize("\"hello world\"").unwrap();

        if let Token::StringLiteral(s) = &tokens[0] {
            assert_eq!(s, "hello world");
        } else {
            panic!("Expected string literal");
        }
    }

    // Decimal input becomes a NumberLiteral carrying the f64 value.
    #[test]
    fn test_lexer_number_literal() {
        let mut lexer = DslLexer::new();
        let tokens = lexer.tokenize("42.5").unwrap();

        if let Token::NumberLiteral(n) = &tokens[0] {
            assert_eq!(*n, 42.5);
        } else {
            panic!("Expected number literal");
        }
    }

    // End-to-end parse: metadata and a single step survive the round trip.
    #[test]
    fn test_parser_simple_pipeline() {
        let mut dsl = PipelineDSL::new();
        let input = r#"
            pipeline "Test Pipeline" {
                version "1.0.0"
                step scaler: StandardScaler {
                    with_mean: true
                }
            }
        "#;

        let workflow = dsl.parse(input).unwrap();
        assert_eq!(workflow.metadata.name, "Test Pipeline");
        assert_eq!(workflow.metadata.version, "1.0.0");
        assert_eq!(workflow.steps.len(), 1);
        assert_eq!(workflow.steps[0].algorithm, "StandardScaler");
    }

    // Generation emits the pipeline header, version, step, and parameters.
    #[test]
    fn test_dsl_generation() {
        let mut workflow = WorkflowDefinition::default();
        workflow.metadata.name = "Test Workflow".to_string();
        workflow.metadata.version = "1.0.0".to_string();

        let step = StepDefinition::new("scaler", StepType::Transformer, "StandardScaler")
            .with_parameter("with_mean", ParameterValue::Bool(true));
        workflow.steps.push(step);

        let dsl = PipelineDSL::new();
        let generated = dsl.generate(&workflow);

        assert!(generated.contains("pipeline \"Test Workflow\""));
        assert!(generated.contains("version \"1.0.0\""));
        assert!(generated.contains("step scaler: StandardScaler"));
        assert!(generated.contains("with_mean: true"));
    }

    // validate_syntax reports no errors on valid input and at least one on
    // malformed input.
    #[test]
    fn test_syntax_validation() {
        let mut dsl = PipelineDSL::new();

        // Valid syntax
        let valid_input = r#"
            pipeline "Valid" {
                version "1.0.0"
            }
        "#;
        let errors = dsl.validate_syntax(valid_input).unwrap();
        assert!(errors.is_empty());

        // Invalid syntax
        let invalid_input = r#"
            pipeline "Invalid" {
                version // missing value
            }
        "#;
        let errors = dsl.validate_syntax(invalid_input).unwrap();
        assert!(!errors.is_empty());
    }
}