1use serde::{Deserialize, Serialize};
40use sklears_core::error::{Result as SklResult, SklearsError};
41use std::collections::{BTreeMap, VecDeque};
42
43use super::workflow_definitions::{
44 Connection, ConnectionType, DataType, ExecutionConfig, ExecutionMode, InputDefinition,
45 OutputDefinition, ParallelConfig, ParameterValue, StepDefinition, StepType, WorkflowDefinition,
46};
47
/// Facade combining the DSL lexer and parser, plus text-generation helpers
/// for turning a workflow back into DSL source.
#[derive(Debug)]
pub struct PipelineDSL {
    /// Tokenizes raw DSL text into a [`Token`] stream.
    lexer: DslLexer,
    /// Builds a `WorkflowDefinition` from the token stream.
    parser: DslParser,
}
56
/// Hand-written scanner that turns DSL source text into [`Token`]s.
#[derive(Debug)]
pub struct DslLexer {
    /// Raw source text being scanned.
    input: String,
    /// `input` exploded into chars for O(1) indexed access.
    input_chars: Vec<char>,
    /// Index of the next unread char in `input_chars`.
    position: usize,
    /// 1-based line of the current position (for error reporting).
    line: usize,
    /// 1-based column of the current position (for error reporting).
    column: usize,
}
71
/// Recursive-descent parser over the token stream produced by [`DslLexer`].
#[derive(Debug)]
pub struct DslParser {
    /// Token stream being parsed (terminated by `Token::Eof`).
    tokens: VecDeque<Token>,
    /// Index of the next unconsumed token.
    current: usize,
}
80
/// Lexical tokens of the pipeline DSL.
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // Keywords.
    Pipeline,
    Version,
    Author,
    Description,
    Input,
    Output,
    Step,
    Flow,
    Execute,

    // Built-in type names.
    Matrix,
    Array,
    Float32,
    Float64,
    Int32,
    Int64,
    Bool,
    String,

    // Literals and identifiers (payload carries the scanned value).
    Identifier(String),
    StringLiteral(String),
    NumberLiteral(f64),
    BooleanLiteral(bool),

    // Punctuation and operators.
    LeftBrace,
    RightBrace,
    LeftBracket,
    RightBracket,
    LeftParen,
    RightParen,
    LeftAngle,
    RightAngle,
    Comma,
    Colon,
    Semicolon,
    Dot,
    Arrow,

    // Structural markers; Newline is filtered out before parsing.
    Newline,
    Eof,
}
166
/// A user-facing DSL diagnostic with its source position.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DslError {
    /// Human-readable description of the problem.
    pub message: String,
    /// 1-based line where the error was detected.
    pub line: usize,
    /// 1-based column where the error was detected.
    pub column: usize,
}
177
178impl PipelineDSL {
179 #[must_use]
181 pub fn new() -> Self {
182 Self {
183 lexer: DslLexer::new(),
184 parser: DslParser::new(),
185 }
186 }
187
188 pub fn parse(&mut self, input: &str) -> SklResult<WorkflowDefinition> {
190 let tokens = self.lexer.tokenize(input)?;
192
193 self.parser.parse(tokens)
195 }
196
197 #[must_use]
199 pub fn generate(&self, workflow: &WorkflowDefinition) -> String {
200 let mut dsl = String::new();
201
202 dsl.push_str(&format!("pipeline \"{}\" {{\n", workflow.metadata.name));
204 dsl.push_str(&format!(" version \"{}\"\n", workflow.metadata.version));
205
206 if let Some(author) = &workflow.metadata.author {
207 dsl.push_str(&format!(" author \"{author}\"\n"));
208 }
209
210 if let Some(description) = &workflow.metadata.description {
211 dsl.push_str(&format!(" description \"{description}\"\n"));
212 }
213
214 dsl.push('\n');
215
216 for input in &workflow.inputs {
218 dsl.push_str(&format!(
219 " input {}: {}\n",
220 input.name,
221 self.format_data_type(&input.data_type)
222 ));
223 }
224
225 if !workflow.inputs.is_empty() {
226 dsl.push('\n');
227 }
228
229 for step in &workflow.steps {
231 dsl.push_str(&format!(" step {}: {} {{\n", step.id, step.algorithm));
232
233 for (param_name, param_value) in &step.parameters {
234 dsl.push_str(&format!(
235 " {}: {},\n",
236 param_name,
237 self.format_parameter_value(param_value)
238 ));
239 }
240
241 dsl.push_str(" }\n\n");
242 }
243
244 for connection in &workflow.connections {
246 dsl.push_str(&format!(
247 " flow {}.{} -> {}.{}\n",
248 connection.from_step,
249 connection.from_output,
250 connection.to_step,
251 connection.to_input
252 ));
253 }
254
255 if !workflow.connections.is_empty() {
256 dsl.push('\n');
257 }
258
259 for output in &workflow.outputs {
261 dsl.push_str(&format!(" output {}\n", output.name));
262 }
263
264 if !workflow.outputs.is_empty() {
265 dsl.push('\n');
266 }
267
268 if workflow.execution.mode != ExecutionMode::Sequential {
270 dsl.push_str(" execute {\n");
271 dsl.push_str(&format!(" mode: {:?},\n", workflow.execution.mode));
272
273 if let Some(parallel_config) = &workflow.execution.parallel {
274 dsl.push_str(&format!(
275 " workers: {}\n",
276 parallel_config.num_workers
277 ));
278 }
279
280 dsl.push_str(" }\n");
281 }
282
283 dsl.push_str("}\n");
284 dsl
285 }
286
287 fn format_data_type(&self, data_type: &DataType) -> String {
289 match data_type {
290 DataType::Float32 => "f32".to_string(),
291 DataType::Float64 => "f64".to_string(),
292 DataType::Int32 => "i32".to_string(),
293 DataType::Int64 => "i64".to_string(),
294 DataType::Boolean => "bool".to_string(),
295 DataType::String => "String".to_string(),
296 DataType::Array(inner) => format!("Array<{}>", self.format_data_type(inner)),
297 DataType::Matrix(inner) => format!("Matrix<{}>", self.format_data_type(inner)),
298 DataType::Custom(name) => name.clone(),
299 }
300 }
301
302 fn format_parameter_value(&self, value: &ParameterValue) -> String {
304 match value {
305 ParameterValue::Float(f) => f.to_string(),
306 ParameterValue::Int(i) => i.to_string(),
307 ParameterValue::Bool(b) => b.to_string(),
308 ParameterValue::String(s) => format!("\"{s}\""),
309 ParameterValue::Array(arr) => {
310 let items: Vec<String> =
311 arr.iter().map(|v| self.format_parameter_value(v)).collect();
312 format!("[{}]", items.join(", "))
313 }
314 }
315 }
316
317 pub fn validate_syntax(&mut self, input: &str) -> SklResult<Vec<DslError>> {
319 let mut errors = Vec::new();
320
321 match self.lexer.tokenize(input) {
323 Ok(tokens) => {
324 match self.parser.parse(tokens) {
326 Ok(_) => {
327 }
329 Err(e) => {
330 errors.push(DslError {
331 message: e.to_string(),
332 line: 1, column: 1,
334 });
335 }
336 }
337 }
338 Err(e) => {
339 errors.push(DslError {
340 message: e.to_string(),
341 line: self.lexer.line,
342 column: self.lexer.column,
343 });
344 }
345 }
346
347 Ok(errors)
348 }
349
350 pub fn get_syntax_highlighting(&mut self, input: &str) -> Vec<SyntaxHighlight> {
352 let mut highlights = Vec::new();
353
354 if let Ok(tokens) = self.lexer.tokenize(input) {
355 let mut position = 0;
356
357 for token in tokens {
358 let (token_type, length) = match &token {
359 Token::Pipeline
360 | Token::Version
361 | Token::Author
362 | Token::Description
363 | Token::Input
364 | Token::Output
365 | Token::Step
366 | Token::Flow
367 | Token::Execute => ("keyword", self.estimate_token_length(&token)),
368 Token::Matrix
369 | Token::Array
370 | Token::Float32
371 | Token::Float64
372 | Token::Int32
373 | Token::Int64
374 | Token::Bool
375 | Token::String => ("type", self.estimate_token_length(&token)),
376 Token::StringLiteral(_) => ("string", self.estimate_token_length(&token)),
377 Token::NumberLiteral(_) => ("number", self.estimate_token_length(&token)),
378 Token::BooleanLiteral(_) => ("boolean", self.estimate_token_length(&token)),
379 Token::Identifier(_) => ("identifier", self.estimate_token_length(&token)),
380 _ => ("punctuation", self.estimate_token_length(&token)),
381 };
382
383 highlights.push(SyntaxHighlight {
384 start: position,
385 end: position + length,
386 token_type: token_type.to_string(),
387 });
388
389 position += length;
390 }
391 }
392
393 highlights
394 }
395
396 fn estimate_token_length(&self, token: &Token) -> usize {
398 match token {
399 Token::Identifier(s) | Token::StringLiteral(s) => s.len(),
400 Token::NumberLiteral(n) => n.to_string().len(),
401 Token::BooleanLiteral(b) => b.to_string().len(),
402 Token::Pipeline => "pipeline".len(),
403 Token::Version => "version".len(),
404 Token::Author => "author".len(),
405 Token::Description => "description".len(),
406 Token::Input => "input".len(),
407 Token::Output => "output".len(),
408 Token::Step => "step".len(),
409 Token::Flow => "flow".len(),
410 Token::Execute => "execute".len(),
411 Token::Matrix => "Matrix".len(),
412 Token::Array => "Array".len(),
413 Token::Float32 => "f32".len(),
414 Token::Float64 => "f64".len(),
415 Token::Int32 => "i32".len(),
416 Token::Int64 => "i64".len(),
417 Token::Bool => "bool".len(),
418 Token::String => "String".len(),
419 Token::Arrow => "->".len(),
420 _ => 1,
421 }
422 }
423}
424
/// A highlighted span of source text, expressed as character offsets.
#[derive(Debug, Clone)]
pub struct SyntaxHighlight {
    /// Inclusive start offset of the span.
    pub start: usize,
    /// Exclusive end offset of the span.
    pub end: usize,
    /// Highlight class, e.g. "keyword", "string", "punctuation".
    pub token_type: String,
}
435
436impl DslLexer {
437 #[must_use]
439 pub fn new() -> Self {
440 Self {
441 input: String::new(),
442 input_chars: Vec::new(),
443 position: 0,
444 line: 1,
445 column: 1,
446 }
447 }
448
449 pub fn tokenize(&mut self, input: &str) -> SklResult<VecDeque<Token>> {
451 self.input = input.to_string();
452 self.input_chars = self.input.chars().collect();
453 self.position = 0;
454 self.line = 1;
455 self.column = 1;
456
457 let mut tokens = VecDeque::new();
458
459 while !self.is_at_end() {
460 self.skip_whitespace();
461
462 if self.is_at_end() {
463 break;
464 }
465
466 let token = self.scan_token()?;
467 if token != Token::Newline {
468 tokens.push_back(token);
470 }
471 }
472
473 tokens.push_back(Token::Eof);
474 Ok(tokens)
475 }
476
477 fn scan_token(&mut self) -> SklResult<Token> {
479 let c = self.advance();
480
481 match c {
482 '{' => Ok(Token::LeftBrace),
483 '}' => Ok(Token::RightBrace),
484 '[' => Ok(Token::LeftBracket),
485 ']' => Ok(Token::RightBracket),
486 '(' => Ok(Token::LeftParen),
487 ')' => Ok(Token::RightParen),
488 '<' => Ok(Token::LeftAngle),
489 '>' => Ok(Token::RightAngle),
490 ',' => Ok(Token::Comma),
491 ':' => Ok(Token::Colon),
492 ';' => Ok(Token::Semicolon),
493 '.' => Ok(Token::Dot),
494 '\n' => {
495 self.line += 1;
496 self.column = 1;
497 Ok(Token::Newline)
498 }
499 '-' => {
500 if self.match_char('>') {
501 Ok(Token::Arrow)
502 } else {
503 self.scan_number()
504 }
505 }
506 '"' => self.scan_string(),
507 _ if c.is_ascii_digit() => {
508 self.position -= 1; self.column -= 1;
510 self.scan_number()
511 }
512 _ if c.is_ascii_alphabetic() || c == '_' => {
513 self.position -= 1; self.column -= 1;
515 self.scan_identifier()
516 }
517 _ => Err(SklearsError::InvalidInput(format!(
518 "Unexpected character '{}' at line {}, column {}",
519 c, self.line, self.column
520 ))),
521 }
522 }
523
524 fn scan_string(&mut self) -> SklResult<Token> {
526 let mut value = String::new();
527
528 while !self.is_at_end() && self.peek() != '"' {
529 if self.peek() == '\n' {
530 self.line += 1;
531 self.column = 1;
532 }
533 value.push(self.advance());
534 }
535
536 if self.is_at_end() {
537 return Err(SklearsError::InvalidInput(format!(
538 "Unterminated string at line {}",
539 self.line
540 )));
541 }
542
543 self.advance();
545
546 Ok(Token::StringLiteral(value))
547 }
548
549 fn scan_number(&mut self) -> SklResult<Token> {
551 let mut value = String::new();
552
553 if self.peek() == '-' {
555 value.push(self.advance());
556 }
557
558 while !self.is_at_end() && (self.peek().is_ascii_digit() || self.peek() == '.') {
559 value.push(self.advance());
560 }
561
562 match value.parse::<f64>() {
563 Ok(number) => Ok(Token::NumberLiteral(number)),
564 Err(_) => Err(SklearsError::InvalidInput(format!(
565 "Invalid number '{}' at line {}, column {}",
566 value, self.line, self.column
567 ))),
568 }
569 }
570
571 fn scan_identifier(&mut self) -> SklResult<Token> {
573 let mut value = String::new();
574
575 while !self.is_at_end() && (self.peek().is_ascii_alphanumeric() || self.peek() == '_') {
576 value.push(self.advance());
577 }
578
579 let token = match value.as_str() {
581 "pipeline" => Token::Pipeline,
582 "version" => Token::Version,
583 "author" => Token::Author,
584 "description" => Token::Description,
585 "input" => Token::Input,
586 "output" => Token::Output,
587 "step" => Token::Step,
588 "flow" => Token::Flow,
589 "execute" => Token::Execute,
590 "Matrix" => Token::Matrix,
591 "Array" => Token::Array,
592 "f32" => Token::Float32,
593 "f64" => Token::Float64,
594 "i32" => Token::Int32,
595 "i64" => Token::Int64,
596 "bool" => Token::Bool,
597 "String" => Token::String,
598 "true" => Token::BooleanLiteral(true),
599 "false" => Token::BooleanLiteral(false),
600 _ => Token::Identifier(value),
601 };
602
603 Ok(token)
604 }
605
606 fn skip_whitespace(&mut self) {
608 while !self.is_at_end() {
609 match self.peek() {
610 ' ' | '\r' | '\t' => {
611 self.advance();
612 }
613 '/' if self.peek_next() == '/' => {
614 while !self.is_at_end() && self.peek() != '\n' {
616 self.advance();
617 }
618 }
619 _ => break,
620 }
621 }
622 }
623
624 fn advance(&mut self) -> char {
626 if let Some(&c) = self.input_chars.get(self.position) {
627 self.position += 1;
628 self.column += 1;
629 c
630 } else {
631 '\0'
632 }
633 }
634
635 fn peek(&self) -> char {
637 self.input_chars.get(self.position).copied().unwrap_or('\0')
638 }
639
640 fn peek_next(&self) -> char {
642 self.input_chars
643 .get(self.position + 1)
644 .copied()
645 .unwrap_or('\0')
646 }
647
648 fn match_char(&mut self, expected: char) -> bool {
650 if self.is_at_end() || self.peek() != expected {
651 false
652 } else {
653 self.advance();
654 true
655 }
656 }
657
658 fn is_at_end(&self) -> bool {
660 self.position >= self.input_chars.len()
661 }
662}
663
664impl DslParser {
665 #[must_use]
667 pub fn new() -> Self {
668 Self {
669 tokens: VecDeque::new(),
670 current: 0,
671 }
672 }
673
674 pub fn parse(&mut self, tokens: VecDeque<Token>) -> SklResult<WorkflowDefinition> {
676 self.tokens = tokens;
677 self.current = 0;
678
679 self.parse_pipeline()
680 }
681
682 fn parse_pipeline(&mut self) -> SklResult<WorkflowDefinition> {
684 self.consume(Token::Pipeline, "Expected 'pipeline'")?;
685
686 let name = if let Token::StringLiteral(name) = self.advance() {
687 name
688 } else {
689 return Err(SklearsError::InvalidInput(
690 "Expected pipeline name".to_string(),
691 ));
692 };
693
694 self.consume(Token::LeftBrace, "Expected '{' after pipeline name")?;
695
696 let mut workflow = WorkflowDefinition::default();
697 workflow.metadata.name = name;
698
699 while !self.check(&Token::RightBrace) && !self.is_at_end() {
701 match &self.peek() {
702 Token::Version => {
703 self.advance();
704 if let Token::StringLiteral(version) = self.advance() {
705 workflow.metadata.version = version;
706 }
707 }
708 Token::Author => {
709 self.advance();
710 if let Token::StringLiteral(author) = self.advance() {
711 workflow.metadata.author = Some(author);
712 }
713 }
714 Token::Description => {
715 self.advance();
716 if let Token::StringLiteral(description) = self.advance() {
717 workflow.metadata.description = Some(description);
718 }
719 }
720 Token::Input => {
721 workflow.inputs.push(self.parse_input()?);
722 }
723 Token::Output => {
724 workflow.outputs.push(self.parse_output()?);
725 }
726 Token::Step => {
727 workflow.steps.push(self.parse_step()?);
728 }
729 Token::Flow => {
730 workflow.connections.push(self.parse_flow()?);
731 }
732 Token::Execute => {
733 workflow.execution = self.parse_execute()?;
734 }
735 _ => {
736 return Err(SklearsError::InvalidInput(format!(
737 "Unexpected token: {:?}",
738 self.peek()
739 )));
740 }
741 }
742 }
743
744 self.consume(Token::RightBrace, "Expected '}' after pipeline body")?;
745
746 Ok(workflow)
747 }
748
749 fn parse_input(&mut self) -> SklResult<InputDefinition> {
751 self.consume(Token::Input, "Expected 'input'")?;
752
753 let name = if let Token::Identifier(name) = self.advance() {
754 name
755 } else {
756 return Err(SklearsError::InvalidInput(
757 "Expected input name".to_string(),
758 ));
759 };
760
761 self.consume(Token::Colon, "Expected ':' after input name")?;
762
763 let data_type = self.parse_data_type()?;
764
765 Ok(InputDefinition {
766 name,
767 data_type,
768 shape: None,
769 constraints: None,
770 description: None,
771 })
772 }
773
774 fn parse_output(&mut self) -> SklResult<OutputDefinition> {
776 self.consume(Token::Output, "Expected 'output'")?;
777
778 let name = if let Token::Identifier(name) = self.advance() {
779 name
780 } else {
781 return Err(SklearsError::InvalidInput(
782 "Expected output name".to_string(),
783 ));
784 };
785
786 Ok(OutputDefinition {
787 name,
788 data_type: DataType::Float64, shape: None,
790 description: None,
791 })
792 }
793
794 fn parse_step(&mut self) -> SklResult<StepDefinition> {
796 self.consume(Token::Step, "Expected 'step'")?;
797
798 let id = if let Token::Identifier(id) = self.advance() {
799 id
800 } else {
801 return Err(SklearsError::InvalidInput(
802 "Expected step identifier".to_string(),
803 ));
804 };
805
806 self.consume(Token::Colon, "Expected ':' after step identifier")?;
807
808 let algorithm = if let Token::Identifier(algorithm) = self.advance() {
809 algorithm
810 } else {
811 return Err(SklearsError::InvalidInput(
812 "Expected algorithm name".to_string(),
813 ));
814 };
815
816 let mut parameters = BTreeMap::new();
817
818 if self.check(&Token::LeftBrace) {
819 self.advance(); while !self.check(&Token::RightBrace) && !self.is_at_end() {
822 let param_name = if let Token::Identifier(name) = self.advance() {
823 name
824 } else {
825 return Err(SklearsError::InvalidInput(
826 "Expected parameter name".to_string(),
827 ));
828 };
829
830 self.consume(Token::Colon, "Expected ':' after parameter name")?;
831
832 let param_value = self.parse_parameter_value()?;
833 parameters.insert(param_name, param_value);
834
835 if self.check(&Token::Comma) {
836 self.advance();
837 }
838 }
839
840 self.consume(Token::RightBrace, "Expected '}' after step parameters")?;
841 }
842
843 Ok(StepDefinition {
844 id,
845 step_type: StepType::Custom("DSL".to_string()),
846 algorithm,
847 parameters,
848 inputs: Vec::new(),
849 outputs: Vec::new(),
850 condition: None,
851 description: None,
852 })
853 }
854
855 fn parse_flow(&mut self) -> SklResult<Connection> {
857 self.consume(Token::Flow, "Expected 'flow'")?;
858
859 let from_step = if let Token::Identifier(step) = self.advance() {
861 step
862 } else {
863 return Err(SklearsError::InvalidInput(
864 "Expected source step".to_string(),
865 ));
866 };
867
868 self.consume(Token::Dot, "Expected '.' after source step")?;
869
870 let from_output = match self.advance() {
871 Token::Identifier(output) => output,
872 Token::Output => "output".to_string(),
873 Token::Input => "input".to_string(),
874 other => {
875 return Err(SklearsError::InvalidInput(format!(
876 "Expected source output, got {:?}",
877 other
878 )))
879 }
880 };
881
882 self.consume(Token::Arrow, "Expected '->' in flow")?;
883
884 let to_step = if let Token::Identifier(step) = self.advance() {
886 step
887 } else {
888 return Err(SklearsError::InvalidInput(
889 "Expected target step".to_string(),
890 ));
891 };
892
893 self.consume(Token::Dot, "Expected '.' after target step")?;
894
895 let to_input = match self.advance() {
896 Token::Identifier(input) => input,
897 Token::Input => "input".to_string(),
898 other => {
899 return Err(SklearsError::InvalidInput(format!(
900 "Expected target input, got {:?}",
901 other
902 )))
903 }
904 };
905
906 Ok(Connection {
907 from_step,
908 from_output,
909 to_step,
910 to_input,
911 connection_type: ConnectionType::Direct,
912 transform: None,
913 })
914 }
915
916 fn parse_execute(&mut self) -> SklResult<ExecutionConfig> {
918 self.consume(Token::Execute, "Expected 'execute'")?;
919 self.consume(Token::LeftBrace, "Expected '{' after 'execute'")?;
920
921 let mut config = ExecutionConfig {
922 mode: ExecutionMode::Sequential,
923 parallel: None,
924 resources: None,
925 caching: None,
926 };
927
928 while !self.check(&Token::RightBrace) && !self.is_at_end() {
929 let key = if let Token::Identifier(key) = self.advance() {
930 key
931 } else {
932 return Err(SklearsError::InvalidInput(
933 "Expected configuration key".to_string(),
934 ));
935 };
936
937 self.consume(Token::Colon, "Expected ':' after configuration key")?;
938
939 match key.as_str() {
940 "mode" => {
941 if let Token::Identifier(mode) = self.advance() {
942 config.mode = match mode.as_str() {
943 "parallel" => ExecutionMode::Parallel,
944 "sequential" => ExecutionMode::Sequential,
945 "distributed" => ExecutionMode::Distributed,
946 "gpu" => ExecutionMode::GPU,
947 "adaptive" => ExecutionMode::Adaptive,
948 _ => ExecutionMode::Sequential,
949 };
950 }
951 }
952 "workers" => {
953 if let Token::NumberLiteral(workers) = self.advance() {
954 config.parallel = Some(ParallelConfig {
955 num_workers: workers as usize,
956 chunk_size: None,
957 load_balancing: "round_robin".to_string(),
958 });
959 }
960 }
961 _ => {
962 self.advance();
964 }
965 }
966
967 if self.check(&Token::Comma) {
968 self.advance();
969 }
970 }
971
972 self.consume(
973 Token::RightBrace,
974 "Expected '}' after execute configuration",
975 )?;
976
977 Ok(config)
978 }
979
980 fn parse_data_type(&mut self) -> SklResult<DataType> {
982 match &self.advance() {
983 Token::Float32 => Ok(DataType::Float32),
984 Token::Float64 => Ok(DataType::Float64),
985 Token::Int32 => Ok(DataType::Int32),
986 Token::Int64 => Ok(DataType::Int64),
987 Token::Bool => Ok(DataType::Boolean),
988 Token::String => Ok(DataType::String),
989 Token::Array => {
990 self.consume(Token::LeftAngle, "Expected '<' after 'Array'")?;
991 let inner_type = self.parse_data_type()?;
992 self.consume(Token::RightAngle, "Expected '>' after array type")?;
993 Ok(DataType::Array(Box::new(inner_type)))
994 }
995 Token::Matrix => {
996 self.consume(Token::LeftAngle, "Expected '<' after 'Matrix'")?;
997 let inner_type = self.parse_data_type()?;
998 self.consume(Token::RightAngle, "Expected '>' after matrix type")?;
999 Ok(DataType::Matrix(Box::new(inner_type)))
1000 }
1001 Token::Identifier(name) => Ok(DataType::Custom(name.clone())),
1002 _ => Err(SklearsError::InvalidInput("Expected data type".to_string())),
1003 }
1004 }
1005
1006 fn parse_parameter_value(&mut self) -> SklResult<ParameterValue> {
1008 match &self.advance() {
1009 Token::NumberLiteral(n) => {
1010 if n.fract() == 0.0 {
1011 Ok(ParameterValue::Int(*n as i64))
1012 } else {
1013 Ok(ParameterValue::Float(*n))
1014 }
1015 }
1016 Token::BooleanLiteral(b) => Ok(ParameterValue::Bool(*b)),
1017 Token::StringLiteral(s) => Ok(ParameterValue::String(s.clone())),
1018 _ => Err(SklearsError::InvalidInput(
1019 "Expected parameter value".to_string(),
1020 )),
1021 }
1022 }
1023
1024 fn consume(&mut self, expected: Token, message: &str) -> SklResult<Token> {
1026 if self.check(&expected) {
1027 Ok(self.advance())
1028 } else {
1029 Err(SklearsError::InvalidInput(format!(
1030 "{}, got {:?}",
1031 message,
1032 self.peek()
1033 )))
1034 }
1035 }
1036
1037 fn check(&self, token_type: &Token) -> bool {
1039 if self.is_at_end() {
1040 false
1041 } else {
1042 std::mem::discriminant(&self.peek()) == std::mem::discriminant(token_type)
1043 }
1044 }
1045
1046 fn advance(&mut self) -> Token {
1048 if !self.is_at_end() {
1049 self.current += 1;
1050 }
1051 self.previous()
1052 }
1053
1054 fn peek(&self) -> Token {
1056 if let Some(token) = self.tokens.get(self.current) {
1057 token.clone()
1058 } else {
1059 Token::Eof
1060 }
1061 }
1062
1063 fn previous(&self) -> Token {
1065 if self.current > 0 {
1066 if let Some(token) = self.tokens.get(self.current - 1) {
1067 token.clone()
1068 } else {
1069 Token::Eof
1070 }
1071 } else {
1072 Token::Eof
1073 }
1074 }
1075
1076 fn is_at_end(&self) -> bool {
1078 matches!(self.peek(), Token::Eof)
1079 }
1080}
1081
1082impl std::fmt::Display for DslError {
1083 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1084 write!(
1085 f,
1086 "DSL Error at line {}, column {}: {}",
1087 self.line, self.column, self.message
1088 )
1089 }
1090}
1091
1092impl std::error::Error for DslError {}
1093
/// Defaults to a fresh lexer/parser pair, same as [`PipelineDSL::new`].
impl Default for PipelineDSL {
    fn default() -> Self {
        Self::new()
    }
}
1099
/// Defaults to an empty lexer, same as [`DslLexer::new`].
impl Default for DslLexer {
    fn default() -> Self {
        Self::new()
    }
}
1105
/// Defaults to an empty parser, same as [`DslParser::new`].
impl Default for DslParser {
    fn default() -> Self {
        Self::new()
    }
}
1111
/// Node of the DSL abstract syntax tree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AstNode {
    /// Root node: a named pipeline with metadata and child nodes.
    Pipeline {
        name: String,
        metadata: PipelineMetadata,
        children: Vec<AstNode>,
    },
    /// A processing step bound to an algorithm, with parameter child nodes.
    Step {
        name: String,
        algorithm: String,
        parameters: Vec<AstNode>,
    },
    /// A data-flow edge between two steps, with (output, input) port pairs.
    Connection {
        from: String,
        to: String,
        port_mapping: Vec<(String, String)>,
    },
    /// A named parameter value.
    Parameter { name: String, value: ParameterValue },
    /// A typed pipeline input.
    Input { name: String, data_type: DataType },
    /// A typed pipeline output.
    Output { name: String, data_type: DataType },
    /// A key/value execution-configuration entry.
    Config { key: String, value: String },
}
1142
/// Metadata attached to the root [`AstNode::Pipeline`] node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineMetadata {
    /// Pipeline version string (e.g. "1.0.0").
    pub version: String,
    /// Optional free-form description.
    pub description: Option<String>,
    /// Optional author name.
    pub author: Option<String>,
}
1153
/// Suggestion source for DSL editors: keywords, functions, components, and
/// per-context suggestion lists.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoCompleter {
    /// DSL keywords, always eligible for suggestion.
    pub keywords: Vec<String>,
    /// Known function names (not currently used by `get_suggestions`).
    pub functions: Vec<String>,
    /// Known component/algorithm names.
    pub components: Vec<String>,
    /// Extra suggestions keyed by editing context (e.g. "step").
    pub context_suggestions: BTreeMap<String, Vec<String>>,
}
1166
/// Editor-facing configuration for the DSL tooling.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DslConfig {
    /// Enable syntax highlighting.
    pub syntax_highlighting: bool,
    /// Enable auto-completion.
    pub auto_completion: bool,
    /// Validate while typing rather than on demand.
    pub real_time_validation: bool,
    /// Indentation width in spaces.
    pub indent_size: usize,
    /// Preferred maximum line length.
    pub max_line_length: usize,
    /// Which comment styles the editor should offer.
    pub comment_style: CommentStyle,
}
1183
/// Comment flavors an editor may insert.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CommentStyle {
    /// `// ...` comments only.
    SingleLine,
    /// `/* ... */` comments only.
    MultiLine,
    /// Both single- and multi-line comments.
    Both,
}
1194
/// Structured lexer errors; payload order is (detail, line, column).
/// NOTE(review): `DslLexer` currently reports via `SklearsError` instead of
/// this type — presumably kept for a future richer API; confirm before removal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LexError {
    /// A character that starts no valid token.
    UnexpectedCharacter(char, usize, usize),
    /// A string literal with no closing quote.
    UnterminatedString(usize, usize),
    /// Text that looked numeric but failed to parse.
    InvalidNumber(String, usize, usize),
    /// An unrecognized escape sequence.
    InvalidEscape(String, usize, usize),
    /// Input ended mid-token.
    UnexpectedEof(usize, usize),
}
1209
/// Structured parser errors with source positions.
/// NOTE(review): `DslParser` currently reports via `SklearsError` instead of
/// this type — presumably kept for a future richer API; confirm before removal.
#[derive(Debug, Clone, Serialize, Deserialize, thiserror::Error)]
pub enum ParseError {
    /// A token that does not fit the grammar at this point.
    #[error("Unexpected token '{0}' at line {1}, column {2}")]
    UnexpectedToken(String, usize, usize),
    /// A required token was absent.
    #[error("Missing token '{0}' at line {1}, column {2}")]
    MissingToken(String, usize, usize),
    /// A structurally malformed construct.
    #[error("Invalid syntax: {0} at line {1}, column {2}")]
    InvalidSyntax(String, usize, usize),
    /// A construct that parses but violates semantic rules.
    #[error("Semantic error: {0} at line {1}, column {2}")]
    SemanticError(String, usize, usize),
    /// Reference to a name that was never declared.
    #[error("Unknown identifier '{0}' at line {1}, column {2}")]
    UnknownIdentifier(String, usize, usize),
    /// Payload order: (found, expected, line, column).
    #[error("Type mismatch: expected {1}, found {0} at line {2}, column {3}")]
    TypeMismatch(String, String, usize, usize),
}
1232
1233pub type ParseResult<T> = Result<T, ParseError>;
1235
/// Aggregates the pieces needed for semantic analysis of a parsed pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticAnalyzer {
    /// Declared names and their metadata.
    pub symbol_table: SymbolTable,
    /// Type compatibility / coercion checks.
    pub type_checker: TypeChecker,
    /// Semantic rules to apply.
    pub rules: Vec<SemanticRule>,
    /// Errors collected during analysis.
    pub errors: Vec<ParseError>,
}
1248
/// A named semantic rule applied during analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticRule {
    /// Rule identifier.
    pub name: String,
    /// What the rule verifies.
    pub description: String,
    /// How violations are reported.
    pub severity: RuleSeverity,
    /// Name of the checker routine implementing the rule.
    pub checker: String,
}
1261
/// Severity level for semantic-rule violations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RuleSeverity {
    /// Must be fixed; analysis fails.
    Error,
    /// Should be fixed; analysis continues.
    Warning,
    /// Informational only.
    Info,
}
1272
/// Flat symbol table with a scope stack for semantic analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SymbolTable {
    /// All symbols, keyed by name (later insertions shadow earlier ones).
    pub symbols: BTreeMap<String, Symbol>,
    /// Scope list; index 0 is the global scope.
    pub scopes: Vec<Scope>,
    /// Index into `scopes` of the scope receiving new symbols.
    pub current_scope: usize,
}
1283
/// A declared name with its kind, type, and declaration site.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Symbol {
    /// The declared name.
    pub name: String,
    /// What kind of entity the name denotes.
    pub symbol_type: SymbolType,
    /// The entity's data type.
    pub data_type: DataType,
    /// Index of the scope that owns this symbol.
    pub scope: usize,
    /// 1-based declaration line.
    pub line: usize,
    /// 1-based declaration column.
    pub column: usize,
}
1300
/// Kinds of entities a [`Symbol`] may denote.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SymbolType {
    Variable,
    Function,
    Type,
    Constant,
    /// A pipeline step.
    Step,
    /// A step parameter.
    Parameter,
}
1317
/// A lexical scope in the [`SymbolTable`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Scope {
    /// Scope label (e.g. "global").
    pub name: String,
    /// Index of the enclosing scope, if any.
    pub parent: Option<usize>,
    /// Names declared directly in this scope.
    pub symbols: Vec<String>,
}
1328
1329pub type SyntaxHighlighter = SyntaxHighlight;
1331
/// Serializable token classification carrying the token's text/value.
/// NOTE(review): not produced by `DslLexer` (which emits [`Token`]) —
/// presumably intended for editor integrations; confirm before removal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TokenType {
    Keyword(String),
    Identifier(String),
    Number(f64),
    String(String),
    Boolean(bool),
    Operator(String),
    Punctuation(char),
    Comment(String),
    Whitespace(String),
    Eof,
}
1356
/// Type compatibility checker driven by registered rules and coercions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TypeChecker {
    /// Registered type rules.
    pub rules: Vec<TypeRule>,
    /// Known named types.
    pub types: BTreeMap<String, TypeInfo>,
    /// Allowed implicit conversions.
    pub coercion_rules: Vec<CoercionRule>,
}
1367
/// A named rule relating a source type to a target type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TypeRule {
    /// Rule identifier.
    pub name: String,
    /// The type the rule applies to.
    pub source_type: DataType,
    /// The type the rule relates it to.
    pub target_type: DataType,
    /// Name of the checker routine implementing the rule.
    pub checker: String,
}
1380
/// Metadata describing a named type known to the checker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TypeInfo {
    /// Type name.
    pub name: String,
    /// Underlying base type.
    pub base_type: DataType,
    /// Constraint descriptions attached to the type.
    pub constraints: Vec<String>,
    /// Free-form key/value metadata.
    pub metadata: BTreeMap<String, String>,
}
1393
/// An allowed implicit conversion between two types.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoercionRule {
    /// Source type of the conversion.
    pub from_type: DataType,
    /// Target type of the conversion.
    pub to_type: DataType,
    /// Relative cost, for picking the cheapest coercion path.
    pub cost: u32,
    /// Name of the routine performing the conversion.
    pub coercion_fn: String,
}
1406
/// Sensible editor defaults: all assistance on, 4-space indent, 100-column
/// lines, both comment styles allowed.
impl Default for DslConfig {
    fn default() -> Self {
        Self {
            syntax_highlighting: true,
            auto_completion: true,
            real_time_validation: true,
            indent_size: 4,
            max_line_length: 100,
            comment_style: CommentStyle::Both,
        }
    }
}
1419
1420impl AutoCompleter {
1421 #[must_use]
1423 pub fn new() -> Self {
1424 let mut completer = Self {
1425 keywords: vec![
1426 "pipeline".to_string(),
1427 "step".to_string(),
1428 "connect".to_string(),
1429 "input".to_string(),
1430 "output".to_string(),
1431 "execute".to_string(),
1432 "version".to_string(),
1433 ],
1434 functions: vec![
1435 "transform".to_string(),
1436 "fit".to_string(),
1437 "predict".to_string(),
1438 "evaluate".to_string(),
1439 ],
1440 components: vec![
1441 "StandardScaler".to_string(),
1442 "LinearRegression".to_string(),
1443 "RandomForest".to_string(),
1444 "SVM".to_string(),
1445 ],
1446 context_suggestions: BTreeMap::new(),
1447 };
1448
1449 completer
1451 .context_suggestions
1452 .insert("step".to_string(), completer.components.clone());
1453
1454 completer
1455 }
1456
1457 #[must_use]
1459 pub fn get_suggestions(&self, context: &str, prefix: &str) -> Vec<String> {
1460 let mut suggestions = Vec::new();
1461
1462 for keyword in &self.keywords {
1464 if keyword.starts_with(prefix) {
1465 suggestions.push(keyword.clone());
1466 }
1467 }
1468
1469 if let Some(context_suggestions) = self.context_suggestions.get(context) {
1471 for suggestion in context_suggestions {
1472 if suggestion.starts_with(prefix) {
1473 suggestions.push(suggestion.clone());
1474 }
1475 }
1476 }
1477
1478 suggestions.sort();
1479 suggestions.dedup();
1480 suggestions
1481 }
1482}
1483
/// Defaults to the preloaded completer from [`AutoCompleter::new`].
impl Default for AutoCompleter {
    fn default() -> Self {
        Self::new()
    }
}
1489
1490impl SymbolTable {
1491 #[must_use]
1493 pub fn new() -> Self {
1494 Self {
1495 symbols: BTreeMap::new(),
1496 scopes: vec![Scope {
1497 name: "global".to_string(),
1498 parent: None,
1499 symbols: Vec::new(),
1500 }],
1501 current_scope: 0,
1502 }
1503 }
1504
1505 pub fn add_symbol(&mut self, symbol: Symbol) {
1507 self.symbols.insert(symbol.name.clone(), symbol.clone());
1508 if let Some(scope) = self.scopes.get_mut(self.current_scope) {
1509 scope.symbols.push(symbol.name);
1510 }
1511 }
1512
1513 #[must_use]
1515 pub fn lookup(&self, name: &str) -> Option<&Symbol> {
1516 self.symbols.get(name)
1517 }
1518}
1519
/// Defaults to a table with one empty global scope, same as [`SymbolTable::new`].
impl Default for SymbolTable {
    fn default() -> Self {
        Self::new()
    }
}
1525
1526impl TypeChecker {
1527 #[must_use]
1529 pub fn new() -> Self {
1530 Self {
1531 rules: Vec::new(),
1532 types: BTreeMap::new(),
1533 coercion_rules: Vec::new(),
1534 }
1535 }
1536
1537 #[must_use]
1539 pub fn are_compatible(&self, type1: &DataType, type2: &DataType) -> bool {
1540 type1 == type2 || self.can_coerce(type1, type2)
1541 }
1542
1543 #[must_use]
1545 pub fn can_coerce(&self, from: &DataType, to: &DataType) -> bool {
1546 self.coercion_rules
1547 .iter()
1548 .any(|rule| rule.from_type == *from && rule.to_type == *to)
1549 }
1550}
1551
/// Defaults to an empty checker, same as [`TypeChecker::new`].
impl Default for TypeChecker {
    fn default() -> Self {
        Self::new()
    }
}
1557
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Keywords and punctuation lex to the expected tokens, then Eof.
    #[test]
    fn test_lexer_basic_tokens() {
        let mut lexer = DslLexer::new();
        let tokens = lexer.tokenize("pipeline { }").unwrap();

        let expected = [
            Token::Pipeline,
            Token::LeftBrace,
            Token::RightBrace,
            Token::Eof,
        ];
        for (index, want) in expected.iter().enumerate() {
            assert_eq!(&tokens[index], want);
        }
    }

    /// Quoted text lexes to a StringLiteral without the quotes.
    #[test]
    fn test_lexer_string_literal() {
        let mut lexer = DslLexer::new();
        let tokens = lexer.tokenize("\"hello world\"").unwrap();

        match &tokens[0] {
            Token::StringLiteral(s) => assert_eq!(s, "hello world"),
            _ => panic!("Expected string literal"),
        }
    }

    /// Decimal text lexes to a NumberLiteral with the right value.
    #[test]
    fn test_lexer_number_literal() {
        let mut lexer = DslLexer::new();
        let tokens = lexer.tokenize("42.5").unwrap();

        match &tokens[0] {
            Token::NumberLiteral(n) => assert_eq!(*n, 42.5),
            _ => panic!("Expected number literal"),
        }
    }

    /// A minimal pipeline parses into the expected workflow structure.
    #[test]
    fn test_parser_simple_pipeline() {
        let mut dsl = PipelineDSL::new();
        let input = r#"
            pipeline "Test Pipeline" {
                version "1.0.0"
                step scaler: StandardScaler {
                    with_mean: true
                }
            }
        "#;

        let workflow = dsl.parse(input).unwrap();

        assert_eq!(workflow.metadata.name, "Test Pipeline");
        assert_eq!(workflow.metadata.version, "1.0.0");
        assert_eq!(workflow.steps.len(), 1);
        assert_eq!(workflow.steps[0].algorithm, "StandardScaler");
    }

    /// Generated DSL text mentions the workflow's metadata and steps.
    #[test]
    fn test_dsl_generation() {
        let mut workflow = WorkflowDefinition::default();
        workflow.metadata.name = "Test Workflow".to_string();
        workflow.metadata.version = "1.0.0".to_string();

        workflow.steps.push(
            StepDefinition::new("scaler", StepType::Transformer, "StandardScaler")
                .with_parameter("with_mean", ParameterValue::Bool(true)),
        );

        let generated = PipelineDSL::new().generate(&workflow);

        for needle in [
            "pipeline \"Test Workflow\"",
            "version \"1.0.0\"",
            "step scaler: StandardScaler",
            "with_mean: true",
        ] {
            assert!(generated.contains(needle));
        }
    }

    /// Valid input yields no diagnostics; malformed input yields at least one.
    #[test]
    fn test_syntax_validation() {
        let mut dsl = PipelineDSL::new();

        let valid_input = r#"
            pipeline "Valid" {
                version "1.0.0"
            }
        "#;
        assert!(dsl.validate_syntax(valid_input).unwrap().is_empty());

        let invalid_input = r#"
            pipeline "Invalid" {
                version // missing value
            }
        "#;
        assert!(!dsl.validate_syntax(invalid_input).unwrap().is_empty());
    }
}
1658}