Skip to main content

tenflowers_neural/program_synthesis_ml/
mod.rs

1//! # Program Synthesis ML
2//!
3//! ML methods for code understanding, analysis and program synthesis.
4//!
5//! This module provides:
6//!
7//! * **[`CodeTokenizer`]** — language-agnostic lexer for source code
8//! * **[`ASTEncoder`]** — tree-positional encoding for AST nodes
9//! * **[`CodeBert`]** — masked language model for code (CodeBERT-style)
10//! * **[`CodeContrastive`]** — contrastive learning between code and docstrings
11//! * **[`FlashFillSolver`]** — example-based string transformation via DSL
//! * **[`NeuralProgramInducer`]** — differentiable interpreter for program induction
13//! * **[`CodeSummarizer`]** — pointer-generator network for code summarization
14//! * **[`BugLocalizerGnn`]** — graph neural network to locate buggy lines
15//! * **[`TestCaseGenerator`]** — mutation-based test case generation
16//! * **[`CodeMetrics`]** — static analysis metrics (cyclomatic, Halstead, MI)
17//!
18//! ## Randomness Policy
19//!
20//! All randomness is sourced from `scirs2_core::random` — never from `rand`.
21
22
23pub mod extensions;
24pub use extensions::*;
25
26pub mod neural_exec;
27pub use neural_exec::*;
28
29#[cfg(test)]
30mod tests;
31
32use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
33use scirs2_core::RngExt;
34use std::collections::{HashMap, VecDeque};
35use std::f32::consts::PI;
36
37// ---------------------------------------------------------------------------
38// 1. CodeTokenizer
39// ---------------------------------------------------------------------------
40
/// Programming language for tokenization hints.
///
/// The hint only selects which keyword table [`CodeTokenizer::tokenize`]
/// consults; the lexing rules themselves are shared across all languages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Language {
    /// Use the Python keyword table.
    Python,
    /// Use the Rust keyword table.
    Rust,
    /// Use the C keyword table.
    C,
    /// No keyword table: every word is classified as an identifier.
    Generic,
}
49
/// Coarse token category.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenKind {
    /// A word found in the active language's keyword table.
    Keyword,
    /// A word (letters, digits, `_`) not present in the keyword table.
    Identifier,
    /// A string, character, or numeric literal (delimiters included).
    Literal,
    /// A single- or multi-character operator such as `+` or `==`.
    Operator,
    /// Punctuation or any other unclassified single character.
    Punct,
    /// A line (`//`, `#`) or block (`/* ... */`) comment, delimiters included.
    Comment,
    /// A maximal run of whitespace characters.
    Whitespace,
}
61
/// A single lexical token produced by [`CodeTokenizer`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Token {
    /// Coarse category of the token.
    pub kind: TokenKind,
    /// Exact source text of the token (delimiters and whitespace preserved).
    pub text: String,
}
68
/// Language-agnostic lexer for source code.
///
/// Holds one keyword table per supported language; the [`Language`] hint
/// passed to `tokenize` selects which table is consulted.
pub struct CodeTokenizer {
    // Keyword tables are plain vectors; lookup is a linear scan, which is
    // fine at this table size.
    python_keywords: Vec<&'static str>,
    rust_keywords: Vec<&'static str>,
    c_keywords: Vec<&'static str>,
}
75
76impl CodeTokenizer {
77    pub fn new() -> Self {
78        Self {
79            python_keywords: vec![
80                "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class",
81                "continue", "def", "del", "elif", "else", "except", "finally", "for", "from",
82                "global", "if", "import", "in", "is", "lambda", "nonlocal", "not", "or", "pass",
83                "raise", "return", "try", "while", "with", "yield",
84            ],
85            rust_keywords: vec![
86                "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else",
87                "enum", "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match",
88                "mod", "move", "mut", "pub", "ref", "return", "self", "Self", "static", "struct",
89                "super", "trait", "true", "type", "unsafe", "use", "where", "while",
90            ],
91            c_keywords: vec![
92                "auto", "break", "case", "char", "const", "continue", "default", "do", "double",
93                "else", "enum", "extern", "float", "for", "goto", "if", "int", "long", "register",
94                "return", "short", "signed", "sizeof", "static", "struct", "switch", "typedef",
95                "union", "unsigned", "void", "volatile", "while",
96            ],
97        }
98    }
99
100    /// Tokenize `source` according to the given `lang` hint.
101    pub fn tokenize(&self, source: &str, lang: Language) -> Vec<Token> {
102        let keywords: &[&str] = match lang {
103            Language::Python => &self.python_keywords,
104            Language::Rust => &self.rust_keywords,
105            Language::C => &self.c_keywords,
106            Language::Generic => &[],
107        };
108        let mut tokens = Vec::new();
109        let chars: Vec<char> = source.chars().collect();
110        let mut i = 0;
111
112        while i < chars.len() {
113            let ch = chars[i];
114
115            // Whitespace
116            if ch.is_whitespace() {
117                let start = i;
118                while i < chars.len() && chars[i].is_whitespace() {
119                    i += 1;
120                }
121                tokens.push(Token {
122                    kind: TokenKind::Whitespace,
123                    text: chars[start..i].iter().collect(),
124                });
125                continue;
126            }
127
128            // Line comment (// or #)
129            if ch == '/' && i + 1 < chars.len() && chars[i + 1] == '/' {
130                let start = i;
131                while i < chars.len() && chars[i] != '\n' {
132                    i += 1;
133                }
134                tokens.push(Token {
135                    kind: TokenKind::Comment,
136                    text: chars[start..i].iter().collect(),
137                });
138                continue;
139            }
140            if ch == '#' {
141                let start = i;
142                while i < chars.len() && chars[i] != '\n' {
143                    i += 1;
144                }
145                tokens.push(Token {
146                    kind: TokenKind::Comment,
147                    text: chars[start..i].iter().collect(),
148                });
149                continue;
150            }
151            // Block comment /* ... */
152            if ch == '/' && i + 1 < chars.len() && chars[i + 1] == '*' {
153                let start = i;
154                i += 2;
155                while i + 1 < chars.len() && !(chars[i] == '*' && chars[i + 1] == '/') {
156                    i += 1;
157                }
158                i += 2; // consume */
159                tokens.push(Token {
160                    kind: TokenKind::Comment,
161                    text: chars[start..i.min(chars.len())].iter().collect(),
162                });
163                continue;
164            }
165
166            // String literal (double-quote)
167            if ch == '"' {
168                let start = i;
169                i += 1;
170                while i < chars.len() {
171                    if chars[i] == '\\' {
172                        i += 2;
173                    } else if chars[i] == '"' {
174                        i += 1;
175                        break;
176                    } else {
177                        i += 1;
178                    }
179                }
180                tokens.push(Token {
181                    kind: TokenKind::Literal,
182                    text: chars[start..i].iter().collect(),
183                });
184                continue;
185            }
186            // String literal (single-quote)
187            if ch == '\'' {
188                let start = i;
189                i += 1;
190                while i < chars.len() {
191                    if chars[i] == '\\' {
192                        i += 2;
193                    } else if chars[i] == '\'' {
194                        i += 1;
195                        break;
196                    } else {
197                        i += 1;
198                    }
199                }
200                tokens.push(Token {
201                    kind: TokenKind::Literal,
202                    text: chars[start..i].iter().collect(),
203                });
204                continue;
205            }
206
207            // Numeric literal
208            if ch.is_ascii_digit()
209                || (ch == '.' && i + 1 < chars.len() && chars[i + 1].is_ascii_digit())
210            {
211                let start = i;
212                while i < chars.len()
213                    && (chars[i].is_ascii_alphanumeric() || chars[i] == '.' || chars[i] == '_')
214                {
215                    i += 1;
216                }
217                tokens.push(Token {
218                    kind: TokenKind::Literal,
219                    text: chars[start..i].iter().collect(),
220                });
221                continue;
222            }
223
224            // Identifier or keyword
225            if ch.is_alphabetic() || ch == '_' {
226                let start = i;
227                while i < chars.len() && (chars[i].is_alphanumeric() || chars[i] == '_') {
228                    i += 1;
229                }
230                let word: String = chars[start..i].iter().collect();
231                let kind = if keywords.contains(&word.as_str()) {
232                    TokenKind::Keyword
233                } else {
234                    TokenKind::Identifier
235                };
236                tokens.push(Token { kind, text: word });
237                continue;
238            }
239
240            // Operators (multi-char first)
241            let op2: Option<String> = if i + 1 < chars.len() {
242                let s: String = chars[i..i + 2].iter().collect();
243                match s.as_str() {
244                    "==" | "!=" | "<=" | ">=" | "->" | "=>" | "::" | "&&" | "||" | "+=" | "-="
245                    | "*=" | "/=" | "&=" | "|=" | "^=" | "<<" | ">>" => Some(s),
246                    _ => None,
247                }
248            } else {
249                None
250            };
251            if let Some(op) = op2 {
252                i += 2;
253                tokens.push(Token {
254                    kind: TokenKind::Operator,
255                    text: op,
256                });
257                continue;
258            }
259
260            // Single-char operators / punctuation
261            let kind = match ch {
262                '+' | '-' | '*' | '/' | '%' | '=' | '<' | '>' | '!' | '&' | '|' | '^' | '~' => {
263                    TokenKind::Operator
264                }
265                '(' | ')' | '[' | ']' | '{' | '}' | ',' | ';' | ':' | '.' | '@' => TokenKind::Punct,
266                _ => TokenKind::Punct,
267            };
268            tokens.push(Token {
269                kind,
270                text: ch.to_string(),
271            });
272            i += 1;
273        }
274        tokens
275    }
276}
277
impl Default for CodeTokenizer {
    // Delegates to `new()`, so `CodeTokenizer::default()` carries the same
    // built-in keyword tables.
    fn default() -> Self {
        Self::new()
    }
}
283
284// ---------------------------------------------------------------------------
285// 2. ASTEncoder
286// ---------------------------------------------------------------------------
287
/// A node in an Abstract Syntax Tree.
#[derive(Debug, Clone)]
pub struct AstNode {
    /// Free-form label describing the node's syntactic kind.
    pub node_type: String,
    /// Child nodes, in source order.
    pub children: Vec<AstNode>,
    /// Distance from the root (presumably root = 0 — confirm with callers).
    pub depth: usize,
    /// Position among this node's siblings.
    pub sibling_idx: usize,
}
296
297impl AstNode {
298    pub fn new(node_type: impl Into<String>, depth: usize, sibling_idx: usize) -> Self {
299        Self {
300            node_type: node_type.into(),
301            children: Vec::new(),
302            depth,
303            sibling_idx,
304        }
305    }
306
307    pub fn add_child(&mut self, child: AstNode) {
308        self.children.push(child);
309    }
310}
311
/// Tree-positional encoder for AST nodes using sinusoidal encoding.
pub struct ASTEncoder {
    /// Dimensionality of each node encoding vector.
    pub embed_dim: usize,
}
316
317impl ASTEncoder {
318    pub fn new(embed_dim: usize) -> Self {
319        Self { embed_dim }
320    }
321
322    /// Encode a single node using depth + sibling-index sinusoidal encoding.
323    pub fn encode_node(&self, node: &AstNode, embed_dim: usize) -> Vec<f32> {
324        let mut enc = vec![0.0f32; embed_dim];
325        let d = embed_dim / 2;
326        for k in 0..d {
327            let denom = 10_000_f32.powf(2.0 * k as f32 / embed_dim as f32);
328            // First half: depth encoding
329            enc[2 * k] = (node.depth as f32 / denom).sin();
330            if 2 * k + 1 < embed_dim {
331                enc[2 * k + 1] = (node.depth as f32 / denom).cos();
332            }
333        }
334        // Second half: sibling index encoding (interleaved with depth if dim is small)
335        let half = embed_dim / 2;
336        for k in 0..half {
337            let denom = 10_000_f32.powf(2.0 * k as f32 / embed_dim as f32);
338            let idx = half + 2 * k;
339            if idx < embed_dim {
340                enc[idx] = (node.sibling_idx as f32 / denom).sin();
341            }
342            if idx + 1 < embed_dim {
343                enc[idx + 1] = (node.sibling_idx as f32 / denom).cos();
344            }
345        }
346        // Add a node-type hash component into the first position
347        let type_hash = node
348            .node_type
349            .bytes()
350            .fold(0u32, |acc, b| acc.wrapping_mul(31).wrapping_add(b as u32));
351        let type_signal = ((type_hash as f32 / u32::MAX as f32) * 2.0 - 1.0) * 0.1;
352        enc[0] += type_signal;
353        enc
354    }
355
356    /// BFS-order encoding of an entire tree.
357    pub fn encode_tree(&self, root: &AstNode) -> Vec<Vec<f32>> {
358        let mut result = Vec::new();
359        let mut queue: VecDeque<&AstNode> = VecDeque::new();
360        queue.push_back(root);
361        while let Some(node) = queue.pop_front() {
362            result.push(self.encode_node(node, self.embed_dim));
363            for child in &node.children {
364                queue.push_back(child);
365            }
366        }
367        result
368    }
369}
370
371// ---------------------------------------------------------------------------
372// 3. CodeBert
373// ---------------------------------------------------------------------------
374
/// Configuration for the CodeBERT masked language model.
#[derive(Debug, Clone)]
pub struct CodeBertConfig {
    /// Number of entries in the token vocabulary.
    pub vocab_size: usize,
    /// Width of token/positional embeddings and hidden states.
    pub embed_dim: usize,
    /// Attention heads per layer; head width is `embed_dim / n_heads`
    /// (integer division), so `n_heads` should evenly divide `embed_dim`.
    pub n_heads: usize,
    /// Number of stacked transformer layers.
    pub n_layers: usize,
    /// Maximum sequence length; longer inputs are truncated in `forward`.
    pub max_seq_len: usize,
}
384
impl CodeBertConfig {
    /// Plain field-by-field constructor; no validation is performed.
    pub fn new(
        vocab_size: usize,
        embed_dim: usize,
        n_heads: usize,
        n_layers: usize,
        max_seq_len: usize,
    ) -> Self {
        Self {
            vocab_size,
            embed_dim,
            n_heads,
            n_layers,
            max_seq_len,
        }
    }
}
402
/// CodeBERT: masked language model for code.
pub struct CodeBert {
    config: CodeBertConfig,
    /// Token embedding table: [vocab_size x embed_dim]
    token_embed: Vec<Vec<f32>>,
    /// Positional embedding table: [max_seq_len x embed_dim] (fixed sinusoidal)
    pos_embed: Vec<Vec<f32>>,
    /// Stacked transformer layers (attention weights + feed-forward weights)
    layers: Vec<CodeBertLayer>,
    /// Output projection: [vocab_size x embed_dim] — one row per vocab entry
    /// (each hidden vector is dotted against every row in `mlm_logits`)
    output_proj: Vec<Vec<f32>>,
}
415
/// Parameters of a single transformer encoder layer.
struct CodeBertLayer {
    /// Query projection [embed_dim x embed_dim].
    wq: Vec<Vec<f32>>,
    /// Key projection [embed_dim x embed_dim].
    wk: Vec<Vec<f32>>,
    /// Value projection [embed_dim x embed_dim].
    wv: Vec<Vec<f32>>,
    /// Attention output projection [embed_dim x embed_dim].
    wo: Vec<Vec<f32>>,
    /// Feed-forward up-projection [ff_dim x embed_dim].
    ff1: Vec<Vec<f32>>,
    /// Feed-forward down-projection [embed_dim x ff_dim].
    ff2: Vec<Vec<f32>>,
}
424
425impl CodeBertLayer {
426    fn new(embed_dim: usize, ff_dim: usize, rng: &mut StdRng) -> Self {
427        let init = |rows: usize, cols: usize, rng: &mut StdRng| -> Vec<Vec<f32>> {
428            let scale = (2.0 / (rows + cols) as f32).sqrt();
429            (0..rows)
430                .map(|_| {
431                    (0..cols)
432                        .map(|_| (rng.random::<f32>() * 2.0 - 1.0) * scale)
433                        .collect()
434                })
435                .collect()
436        };
437        Self {
438            wq: init(embed_dim, embed_dim, rng),
439            wk: init(embed_dim, embed_dim, rng),
440            wv: init(embed_dim, embed_dim, rng),
441            wo: init(embed_dim, embed_dim, rng),
442            // ff1: maps embed_dim -> ff_dim  (ff_dim rows, each of embed_dim cols)
443            ff1: init(ff_dim, embed_dim, rng),
444            // ff2: maps ff_dim -> embed_dim  (embed_dim rows, each of ff_dim cols)
445            ff2: init(embed_dim, ff_dim, rng),
446        }
447    }
448
449    fn forward(&self, x: &[Vec<f32>], n_heads: usize) -> Vec<Vec<f32>> {
450        let seq_len = x.len();
451        let embed_dim = x[0].len();
452        let head_dim = embed_dim / n_heads;
453        let scale = (head_dim as f32).sqrt();
454
455        // Project Q, K, V
456        let proj = |input: &[Vec<f32>], w: &[Vec<f32>]| -> Vec<Vec<f32>> {
457            input
458                .iter()
459                .map(|xi| {
460                    w.iter()
461                        .map(|row| row.iter().zip(xi.iter()).map(|(a, b)| a * b).sum::<f32>())
462                        .collect()
463                })
464                .collect()
465        };
466        let q = proj(x, &self.wq);
467        let k = proj(x, &self.wk);
468        let v = proj(x, &self.wv);
469
470        // Multi-head attention
471        let mut attn_out = vec![vec![0.0f32; embed_dim]; seq_len];
472        for h in 0..n_heads {
473            let s = h * head_dim;
474            let e = s + head_dim;
475            // Compute scores
476            let mut scores = vec![vec![0.0f32; seq_len]; seq_len];
477            for i in 0..seq_len {
478                for j in 0..seq_len {
479                    scores[i][j] = q[i][s..e]
480                        .iter()
481                        .zip(k[j][s..e].iter())
482                        .map(|(a, b)| a * b)
483                        .sum::<f32>()
484                        / scale;
485                }
486            }
487            // Softmax per row
488            for i in 0..seq_len {
489                let max_s = scores[i].iter().cloned().fold(f32::NEG_INFINITY, f32::max);
490                let exp: Vec<f32> = scores[i].iter().map(|&x| (x - max_s).exp()).collect();
491                let sum: f32 = exp.iter().sum();
492                for j in 0..seq_len {
493                    scores[i][j] = exp[j] / sum.max(1e-8);
494                }
495            }
496            // Weighted sum of V
497            for i in 0..seq_len {
498                for j in 0..seq_len {
499                    for k_d in s..e {
500                        attn_out[i][k_d] += scores[i][j] * v[j][k_d];
501                    }
502                }
503            }
504        }
505        // Output projection
506        let out = proj(&attn_out, &self.wo);
507        // Residual + layer norm
508        let mut res: Vec<Vec<f32>> = out
509            .iter()
510            .zip(x.iter())
511            .map(|(o, xi)| o.iter().zip(xi.iter()).map(|(a, b)| a + b).collect())
512            .collect();
513        layer_norm_2d(&mut res);
514
515        // Feed-forward
516        let mut ff_out = vec![vec![0.0f32; embed_dim]; seq_len];
517        for (i, xi) in res.iter().enumerate() {
518            let h1: Vec<f32> = self
519                .ff1
520                .iter()
521                .map(|row| {
522                    let sum: f32 = row.iter().zip(xi.iter()).map(|(a, b)| a * b).sum();
523                    gelu(sum)
524                })
525                .collect();
526            for (j, row) in self.ff2.iter().enumerate() {
527                ff_out[i][j] = row.iter().zip(h1.iter()).map(|(a, b)| a * b).sum::<f32>();
528            }
529        }
530        // Residual + layer norm
531        let mut final_out: Vec<Vec<f32>> = ff_out
532            .iter()
533            .zip(res.iter())
534            .map(|(f, r)| f.iter().zip(r.iter()).map(|(a, b)| a + b).collect())
535            .collect();
536        layer_norm_2d(&mut final_out);
537        final_out
538    }
539}
540
/// GELU activation using the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
fn gelu(x: f32) -> f32 {
    let cubic = x + 0.044715 * x.powi(3);
    let inner = (2.0 / PI).sqrt() * cubic;
    0.5 * x * (1.0 + inner.tanh())
}
544
/// Layer-normalise each row of `x` in place: subtract the row mean and
/// divide by the row standard deviation (epsilon 1e-5 for stability).
fn layer_norm_2d(x: &mut [Vec<f32>]) {
    for row in x.iter_mut() {
        let n = row.len() as f32;
        let mean = row.iter().sum::<f32>() / n;
        let var = row.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / n;
        let std = (var + 1e-5).sqrt();
        row.iter_mut().for_each(|v| *v = (*v - mean) / std);
    }
}
555
impl CodeBert {
    /// Create a new CodeBERT model with deterministic, scaled-uniform
    /// weights (fixed RNG seed 42, so construction is reproducible).
    pub fn new(config: CodeBertConfig) -> Self {
        let mut rng = StdRng::seed_from_u64(42);
        let scale_emb = (1.0 / config.embed_dim as f32).sqrt();
        let token_embed: Vec<Vec<f32>> = (0..config.vocab_size)
            .map(|_| {
                (0..config.embed_dim)
                    .map(|_| (rng.random::<f32>() * 2.0 - 1.0) * scale_emb)
                    .collect()
            })
            .collect();
        // Fixed (non-learned) sinusoidal positional encodings.
        let pos_embed: Vec<Vec<f32>> = (0..config.max_seq_len)
            .map(|pos| {
                (0..config.embed_dim)
                    .enumerate()
                    .map(|(i, _)| {
                        let denom = 10_000_f32.powf(2.0 * (i / 2) as f32 / config.embed_dim as f32);
                        if i % 2 == 0 {
                            (pos as f32 / denom).sin()
                        } else {
                            (pos as f32 / denom).cos()
                        }
                    })
                    .collect()
            })
            .collect();
        // Conventional 4x widening for the feed-forward inner dimension.
        let ff_dim = config.embed_dim * 4;
        let layers = (0..config.n_layers)
            .map(|_| CodeBertLayer::new(config.embed_dim, ff_dim, &mut rng))
            .collect();
        let scale_out = (2.0 / (config.embed_dim + config.vocab_size) as f32).sqrt();
        // output_proj: [vocab_size x embed_dim] — each row is a vocab entry's projection vector
        let output_proj: Vec<Vec<f32>> = (0..config.vocab_size)
            .map(|_| {
                (0..config.embed_dim)
                    .map(|_| (rng.random::<f32>() * 2.0 - 1.0) * scale_out)
                    .collect()
            })
            .collect();
        Self {
            config,
            token_embed,
            pos_embed,
            layers,
            output_proj,
        }
    }

    /// Forward pass: returns contextual embeddings [seq_len x embed_dim].
    ///
    /// Inputs longer than `max_seq_len` are truncated; token ids are mapped
    /// modulo `vocab_size` rather than rejected.
    pub fn forward(&self, token_ids: &[usize]) -> Vec<Vec<f32>> {
        let seq_len = token_ids.len().min(self.config.max_seq_len);
        // Embedding lookup + positional
        let mut x: Vec<Vec<f32>> = token_ids[..seq_len]
            .iter()
            .enumerate()
            .map(|(pos, &tid)| {
                let t_idx = tid % self.config.vocab_size;
                self.token_embed[t_idx]
                    .iter()
                    .zip(self.pos_embed[pos].iter())
                    .map(|(t, p)| t + p)
                    .collect()
            })
            .collect();
        // Transformer layers
        for layer in &self.layers {
            x = layer.forward(&x, self.config.n_heads);
        }
        x
    }

    /// Compute per-position vocab logits for masked language modelling.
    ///
    /// `hidden` is the output of [`CodeBert::forward`]; each hidden vector
    /// is dotted against every row of `output_proj`. No softmax is applied.
    pub fn mlm_logits(&self, hidden: &[Vec<f32>]) -> Vec<Vec<f32>> {
        hidden
            .iter()
            .map(|h| {
                self.output_proj
                    .iter()
                    .map(|row| row.iter().zip(h.iter()).map(|(a, b)| a * b).sum::<f32>())
                    .collect()
            })
            .collect()
    }
}
641
642// ---------------------------------------------------------------------------
643// 4. CodeContrastive
644// ---------------------------------------------------------------------------
645
/// Contrastive learning between code and natural-language docstrings.
pub struct CodeContrastive {
    /// Dimensionality of embeddings and projection outputs.
    pub embed_dim: usize,
    /// Projection applied to pooled code embeddings [embed_dim x embed_dim].
    code_proj: Vec<Vec<f32>>,
    /// Projection applied to pooled doc embeddings [embed_dim x embed_dim].
    doc_proj: Vec<Vec<f32>>,
    /// Token embedding table shared by both encoders [vocab_size x embed_dim].
    token_embed: Vec<Vec<f32>>,
    /// Number of token embeddings; ids are taken modulo this value.
    vocab_size: usize,
}
654
impl CodeContrastive {
    /// Create a model with deterministic weights (fixed RNG seed 123).
    ///
    /// RNG consumption order matters for reproducibility: the token table is
    /// drawn first, then `code_proj`, then `doc_proj`.
    pub fn new(vocab_size: usize, embed_dim: usize) -> Self {
        let mut rng = StdRng::seed_from_u64(123);
        let scale = (1.0 / embed_dim as f32).sqrt();
        let random_matrix = |rng: &mut StdRng| -> Vec<Vec<f32>> {
            (0..embed_dim)
                .map(|_| {
                    (0..embed_dim)
                        .map(|_| (rng.random::<f32>() * 2.0 - 1.0) * scale)
                        .collect()
                })
                .collect()
        };
        let token_embed: Vec<Vec<f32>> = (0..vocab_size)
            .map(|_| {
                (0..embed_dim)
                    .map(|_| (rng.random::<f32>() * 2.0 - 1.0) * scale)
                    .collect()
            })
            .collect();
        Self {
            embed_dim,
            code_proj: random_matrix(&mut rng),
            doc_proj: random_matrix(&mut rng),
            token_embed,
            vocab_size,
        }
    }

    /// Average the embeddings of `tokens` (ids taken modulo `vocab_size`).
    /// An empty sequence yields the zero vector.
    fn mean_pool(&self, tokens: &[usize]) -> Vec<f32> {
        if tokens.is_empty() {
            return vec![0.0; self.embed_dim];
        }
        let mut sum = vec![0.0f32; self.embed_dim];
        for &tid in tokens {
            let idx = tid % self.vocab_size;
            for (s, e) in sum.iter_mut().zip(self.token_embed[idx].iter()) {
                *s += e;
            }
        }
        let n = tokens.len() as f32;
        sum.iter_mut().for_each(|v| *v /= n);
        sum
    }

    /// Matrix-vector product: one dot product per row of `mat`.
    fn project(vec: &[f32], mat: &[Vec<f32>]) -> Vec<f32> {
        mat.iter()
            .map(|row| row.iter().zip(vec.iter()).map(|(a, b)| a * b).sum::<f32>())
            .collect()
    }

    /// Normalise `v` to unit L2 length in place (1e-8 floor avoids div-by-0).
    fn l2_norm(v: &mut [f32]) {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
        v.iter_mut().for_each(|x| *x /= norm);
    }

    /// Encode a token sequence as a mean-pooled, projected, L2-normalised
    /// code embedding.
    pub fn encode_code(&self, tokens: &[usize]) -> Vec<f32> {
        let pooled = self.mean_pool(tokens);
        let mut proj = Self::project(&pooled, &self.code_proj);
        Self::l2_norm(&mut proj);
        proj
    }

    /// Encode a token sequence as a mean-pooled, projected, L2-normalised
    /// doc embedding.
    pub fn encode_doc(&self, tokens: &[usize]) -> Vec<f32> {
        let pooled = self.mean_pool(tokens);
        let mut proj = Self::project(&pooled, &self.doc_proj);
        Self::l2_norm(&mut proj);
        proj
    }

    /// InfoNCE contrastive loss with in-batch negatives.
    ///
    /// `code_embeds` and `doc_embeds` must have the same length N;
    /// diagonal entries are positive pairs and off-diagonal are negatives.
    /// Both directions (code→doc and doc→code) are computed and the total is
    /// averaged over 2N terms. Returns 0.0 for an empty batch.
    pub fn contrastive_loss(
        &self,
        code_embeds: &[Vec<f32>],
        doc_embeds: &[Vec<f32>],
        temperature: f32,
    ) -> f32 {
        let n = code_embeds.len();
        if n == 0 {
            return 0.0;
        }
        let dot =
            |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b.iter()).map(|(x, y)| x * y).sum() };
        let mut total_loss = 0.0f32;
        for i in 0..n {
            // Code → Doc direction: positive pair is (i, i).
            let logits: Vec<f32> = (0..n)
                .map(|j| dot(&code_embeds[i], &doc_embeds[j]) / temperature)
                .collect();
            let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exp: Vec<f32> = logits.iter().map(|&l| (l - max_l).exp()).collect();
            let sum: f32 = exp.iter().sum();
            total_loss -= (exp[i] / sum.max(1e-8)).ln();
        }
        for j in 0..n {
            // Doc → Code direction: positive pair is (j, j).
            let logits: Vec<f32> = (0..n)
                .map(|i| dot(&doc_embeds[j], &code_embeds[i]) / temperature)
                .collect();
            let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exp: Vec<f32> = logits.iter().map(|&l| (l - max_l).exp()).collect();
            let sum: f32 = exp.iter().sum();
            total_loss -= (exp[j] / sum.max(1e-8)).ln();
        }
        total_loss / (2.0 * n as f32)
    }
}
767
768// ---------------------------------------------------------------------------
769// 5. FlashFillSolver
770// ---------------------------------------------------------------------------
771
/// DSL operations for string transformations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StringDsl {
    /// Concatenate the results of two sub-programs, both applied to the
    /// same input.
    Concat(Box<StringDsl>, Box<StringDsl>),
    /// Extract substring: Substr(start, end) — both inclusive, 0-indexed,
    /// measured in `char`s and clamped to the input length.
    Substr(usize, usize),
    /// Replace all occurrences of the first string with the second.
    Replace(String, String),
    /// Convert to uppercase.
    Upper,
    /// Convert to lowercase.
    Lower,
    /// Strip leading/trailing whitespace.
    Strip,
    /// Split on delimiter and return the n-th piece ("" if out of range).
    Split(String, usize),
    /// Join a whitespace split of the input with a new delimiter.
    Join(String),
    /// Not a true regex: the pattern is used as a literal delimiter and the
    /// idx-th piece is extracted and trimmed.
    Regex(String, usize),
}
794
/// A program in the string-transformation DSL.
#[derive(Debug, Clone)]
pub struct FlashFillProgram {
    /// Ops applied sequentially: each op consumes the previous op's output.
    pub ops: Vec<StringDsl>,
}
800
/// FlashFill-style example-based string synthesis (stateless unit struct).
pub struct FlashFillSolver;
803
804impl FlashFillSolver {
    /// Create a solver; the type carries no state.
    pub fn new() -> Self {
        Self
    }
808
809    /// Execute a single DSL op on `input`.
810    fn execute_op(op: &StringDsl, input: &str) -> String {
811        match op {
812            StringDsl::Concat(a, b) => {
813                format!(
814                    "{}{}",
815                    Self::execute_op(a, input),
816                    Self::execute_op(b, input)
817                )
818            }
819            StringDsl::Substr(start, end) => {
820                let chars: Vec<char> = input.chars().collect();
821                let s = (*start).min(chars.len());
822                let e = (*end + 1).min(chars.len());
823                if s >= e {
824                    String::new()
825                } else {
826                    chars[s..e].iter().collect()
827                }
828            }
829            StringDsl::Replace(from, to) => input.replace(from.as_str(), to.as_str()),
830            StringDsl::Upper => input.to_uppercase(),
831            StringDsl::Lower => input.to_lowercase(),
832            StringDsl::Strip => input.trim().to_string(),
833            StringDsl::Split(delim, idx) => {
834                let parts: Vec<&str> = input.split(delim.as_str()).collect();
835                (*parts.get(*idx).unwrap_or(&"")).to_string()
836            }
837            StringDsl::Join(delim) => input
838                .split_whitespace()
839                .collect::<Vec<_>>()
840                .join(delim.as_str()),
841            StringDsl::Regex(pattern, idx) => {
842                // Simple pattern: treat pattern as a literal delimiter and extract idx-th piece
843                let parts: Vec<&str> = input.split(pattern.as_str()).collect();
844                parts.get(*idx).unwrap_or(&"").trim().to_string()
845            }
846        }
847    }
848
849    /// Execute a full program (sequential ops compose: each op applied to input).
850    pub fn execute(program: &FlashFillProgram, input: &str) -> String {
851        let mut current = input.to_string();
852        for op in &program.ops {
853            current = Self::execute_op(op, &current);
854        }
855        current
856    }
857
858    /// Verify a program against all examples.
859    fn verify(program: &FlashFillProgram, examples: &[(&str, &str)]) -> bool {
860        examples
861            .iter()
862            .all(|(inp, out)| Self::execute(program, inp) == *out)
863    }
864
865    /// Enumerate short programs and verify against examples.
866    /// Returns the first program that satisfies all examples, if any.
867    pub fn synthesize(&self, examples: &[(&str, &str)]) -> Option<FlashFillProgram> {
868        if examples.is_empty() {
869            return None;
870        }
871
872        // Build candidate atomic ops from example data
873        let mut candidates: Vec<Vec<StringDsl>> = Vec::new();
874
875        // Identity-like ops
876        candidates.push(vec![StringDsl::Strip]);
877        candidates.push(vec![StringDsl::Upper]);
878        candidates.push(vec![StringDsl::Lower]);
879
880        // Collect delimiters/replacements from examples
881        let common_delimiters = [" ", "-", "_", ",", ".", "/", ":"];
882        for &delim in &common_delimiters {
883            for idx in 0..4 {
884                candidates.push(vec![StringDsl::Split(delim.to_string(), idx)]);
885            }
886            candidates.push(vec![StringDsl::Join(delim.to_string())]);
887        }
888
889        // Substring ranges based on example output lengths
890        let max_len = examples
891            .iter()
892            .map(|(i, _)| i.chars().count())
893            .max()
894            .unwrap_or(0);
895        for start in 0..max_len.min(8) {
896            for end in start..max_len.min(16) {
897                candidates.push(vec![StringDsl::Substr(start, end)]);
898            }
899        }
900
901        // Replace ops from example pairs
902        for &(inp, out) in examples {
903            // Simple prefix strip
904            if out.len() < inp.len() && inp.starts_with(out) {
905                // output is a prefix
906            }
907            // Replacement heuristics
908            for &delim in &common_delimiters {
909                if inp.contains(delim) {
910                    for &new_delim in &common_delimiters {
911                        if new_delim != delim {
912                            candidates.push(vec![StringDsl::Replace(
913                                delim.to_string(),
914                                new_delim.to_string(),
915                            )]);
916                        }
917                    }
918                    candidates.push(vec![StringDsl::Replace(delim.to_string(), "".to_string())]);
919                }
920            }
921        }
922
923        // Two-op compositions
924        let single_ops: Vec<StringDsl> = vec![StringDsl::Strip, StringDsl::Upper, StringDsl::Lower];
925        for op1 in &single_ops {
926            for op2 in &single_ops {
927                candidates.push(vec![op1.clone(), op2.clone()]);
928            }
929        }
930        // Split then join compositions
931        for &delim in &common_delimiters {
932            for &new_delim in &common_delimiters {
933                candidates.push(vec![StringDsl::Join(new_delim.to_string())]);
934                candidates.push(vec![
935                    StringDsl::Split(delim.to_string(), 0),
936                    StringDsl::Strip,
937                ]);
938            }
939        }
940
941        for ops in candidates {
942            let program = FlashFillProgram { ops };
943            if Self::verify(&program, examples) {
944                return Some(program);
945            }
946        }
947        None
948    }
949}
950
951impl Default for FlashFillSolver {
952    fn default() -> Self {
953        Self::new()
954    }
955}
956
957// ---------------------------------------------------------------------------
958// 6. NeuralProgramInducer
959// ---------------------------------------------------------------------------
960
/// Opcodes for a simple register machine.
///
/// Semantics are defined by [`DifferentiableInterpreter::execute_soft`]:
/// arithmetic results are clamped, and the boolean ops operate on
/// sigmoid-squashed register values ("soft" logic).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpCode {
    /// Clamped addition of two source registers.
    Add,
    /// Clamped subtraction.
    Sub,
    /// Soft multiplication: tanh-gated to limit magnitude.
    Mul,
    /// Division; yields 0.0 when the denominator is near zero.
    Div,
    /// Copy the first source register.
    Copy,
    /// Maximum of two source registers.
    Max,
    /// Minimum of two source registers.
    Min,
    /// Soft AND: product of sigmoids.
    And,
    /// Soft OR: `1 - (1 - s0) * (1 - s1)` over sigmoids.
    Or,
    /// Soft NOT: `1 - sigmoid` of the first source.
    Not,
}
975
/// A single instruction in the register machine.
///
/// Argument convention (see [`DifferentiableInterpreter::execute_soft`]):
/// `args[0]` and `args[1]` are source registers, `args[2]` is the
/// destination; missing arguments default to register 0.
#[derive(Debug, Clone)]
pub struct Instruction {
    pub op: OpCode,
    /// Argument indices into the register file.
    pub args: Vec<usize>,
}
983
/// Differentiable interpreter with soft-gated execution.
pub struct DifferentiableInterpreter {
    /// Register file; inputs are loaded into the low registers on execution,
    /// and the registers are mutated in place by [`Self::execute_soft`].
    pub registers: Vec<f32>,
    /// Straight-line program, executed in order (no branching opcodes exist).
    pub program: Vec<Instruction>,
}
989
990impl DifferentiableInterpreter {
991    pub fn new(n_registers: usize, program: Vec<Instruction>) -> Self {
992        Self {
993            registers: vec![0.0; n_registers],
994            program,
995        }
996    }
997
998    /// Execute the program with soft (differentiable) semantics.
999    ///
1000    /// Inputs are loaded into registers 0..input.len(). The function
1001    /// returns the final register file after execution.
1002    pub fn execute_soft(&mut self, input: &[f32]) -> Vec<f32> {
1003        // Load input into registers
1004        for (i, &v) in input.iter().enumerate() {
1005            if i < self.registers.len() {
1006                self.registers[i] = v;
1007            }
1008        }
1009
1010        let n_regs = self.registers.len();
1011        let clamp = |v: f32| v.clamp(-1e6, 1e6);
1012
1013        for instr in &self.program {
1014            let get = |idx: usize| -> f32 {
1015                if idx < n_regs {
1016                    self.registers[idx]
1017                } else {
1018                    0.0
1019                }
1020            };
1021            let a0 = instr.args.first().copied().unwrap_or(0);
1022            let a1 = instr.args.get(1).copied().unwrap_or(0);
1023            let a2 = instr.args.get(2).copied().unwrap_or(0);
1024            let result = match instr.op {
1025                OpCode::Add => clamp(get(a0) + get(a1)),
1026                OpCode::Sub => clamp(get(a0) - get(a1)),
1027                OpCode::Mul => {
1028                    // Soft mul: use tanh gate to prevent explosion
1029                    let product = get(a0) * get(a1);
1030                    clamp(product.tanh() * product.abs().sqrt())
1031                }
1032                OpCode::Div => {
1033                    let denom = get(a1);
1034                    if denom.abs() < 1e-7 {
1035                        0.0
1036                    } else {
1037                        clamp(get(a0) / denom)
1038                    }
1039                }
1040                OpCode::Copy => get(a0),
1041                OpCode::Max => get(a0).max(get(a1)),
1042                OpCode::Min => get(a0).min(get(a1)),
1043                OpCode::And => {
1044                    // Soft AND: product of sigmoids
1045                    let s0 = sigmoid_f32(get(a0));
1046                    let s1 = sigmoid_f32(get(a1));
1047                    s0 * s1
1048                }
1049                OpCode::Or => {
1050                    // Soft OR: 1 - (1-s0)*(1-s1)
1051                    let s0 = sigmoid_f32(get(a0));
1052                    let s1 = sigmoid_f32(get(a1));
1053                    1.0 - (1.0 - s0) * (1.0 - s1)
1054                }
1055                OpCode::Not => {
1056                    // Soft NOT: 1 - sigmoid
1057                    1.0 - sigmoid_f32(get(a0))
1058                }
1059            };
1060            if a2 < n_regs {
1061                self.registers[a2] = result;
1062            }
1063        }
1064        self.registers.clone()
1065    }
1066}
1067
/// Logistic sigmoid: `1 / (1 + e^(-x))`.
pub(crate) fn sigmoid_f32(x: f32) -> f32 {
    (1.0 + (-x).exp()).recip()
}
1071
1072// ---------------------------------------------------------------------------
1073// 7. CodeSummarizer
1074// ---------------------------------------------------------------------------
1075
/// Configuration for the pointer-generator summarizer.
#[derive(Debug, Clone)]
pub struct PointerGeneratorConfig {
    /// Vocabulary size shared by the encoder and decoder embedding tables.
    pub vocab_size: usize,
    /// Width of embeddings and GRU hidden states.
    pub hidden_dim: usize,
    /// Width of the additive-attention projection space.
    pub attn_dim: usize,
}
1083
/// Pointer-generator network for code summarization.
///
/// All weights are plain `Vec<Vec<f32>>` matrices, randomly initialized
/// (deterministically) in [`CodeSummarizer::new`].
pub struct CodeSummarizer {
    config: PointerGeneratorConfig,
    /// Encoder embedding table [vocab_size x hidden_dim]
    encoder_embed: Vec<Vec<f32>>,
    /// Encoder GRU weights (simplified: W_in [hidden x hidden], W_rec [hidden x hidden])
    enc_w_in: Vec<Vec<f32>>,
    enc_w_rec: Vec<Vec<f32>>,
    /// Decoder embedding [vocab_size x hidden_dim]
    decoder_embed: Vec<Vec<f32>>,
    /// Decoder step weights
    dec_w_in: Vec<Vec<f32>>,
    dec_w_rec: Vec<Vec<f32>>,
    /// Attention: W_enc [attn x hidden], W_dec [attn x hidden], V [1 x attn]
    w_enc: Vec<Vec<f32>>,
    w_dec: Vec<Vec<f32>>,
    v_attn: Vec<f32>,
    /// Vocabulary distribution projection [vocab x hidden]
    w_vocab: Vec<Vec<f32>>,
    /// Copy gate weights [1 x hidden]
    // NOTE(review): initialized in `new` but not read by any method in this
    // impl — presumably consumed by a sibling module (e.g. `extensions`); confirm.
    w_copy_gate: Vec<f32>,
}
1106
1107impl CodeSummarizer {
1108    pub fn new(config: PointerGeneratorConfig) -> Self {
1109        let mut rng = StdRng::seed_from_u64(77);
1110        let h = config.hidden_dim;
1111        let v = config.vocab_size;
1112        let a = config.attn_dim;
1113        let scale = |d: usize| (1.0 / d as f32).sqrt();
1114        let mat = |rows: usize, cols: usize, rng: &mut StdRng| -> Vec<Vec<f32>> {
1115            let s = scale(rows);
1116            (0..rows)
1117                .map(|_| {
1118                    (0..cols)
1119                        .map(|_| (rng.random::<f32>() * 2.0 - 1.0) * s)
1120                        .collect()
1121                })
1122                .collect()
1123        };
1124        let vec_init = |n: usize, rng: &mut StdRng| -> Vec<f32> {
1125            (0..n)
1126                .map(|_| (rng.random::<f32>() * 2.0 - 1.0) * scale(n))
1127                .collect()
1128        };
1129        Self {
1130            encoder_embed: mat(v, h, &mut rng),
1131            enc_w_in: mat(h, h, &mut rng),
1132            enc_w_rec: mat(h, h, &mut rng),
1133            decoder_embed: mat(v, h, &mut rng),
1134            dec_w_in: mat(h, h, &mut rng),
1135            dec_w_rec: mat(h, h, &mut rng),
1136            w_enc: mat(a, h, &mut rng),
1137            w_dec: mat(a, h, &mut rng),
1138            v_attn: vec_init(a, &mut rng),
1139            w_vocab: mat(v, h, &mut rng),
1140            w_copy_gate: vec_init(h, &mut rng),
1141            config,
1142        }
1143    }
1144
1145    fn mat_vec(mat: &[Vec<f32>], v: &[f32]) -> Vec<f32> {
1146        mat.iter()
1147            .map(|row| row.iter().zip(v.iter()).map(|(a, b)| a * b).sum::<f32>())
1148            .collect()
1149    }
1150
1151    fn gru_step(x: &[f32], h: &[f32], w_in: &[Vec<f32>], w_rec: &[Vec<f32>]) -> Vec<f32> {
1152        let dim = h.len();
1153        let ax = Self::mat_vec(w_in, x);
1154        let ah = Self::mat_vec(w_rec, h);
1155        (0..dim)
1156            .map(|i| (ax.get(i).copied().unwrap_or(0.0) + ah.get(i).copied().unwrap_or(0.0)).tanh())
1157            .collect()
1158    }
1159
1160    /// Encode source tokens into hidden states.
1161    pub fn encode_source(&self, tokens: &[usize]) -> Vec<Vec<f32>> {
1162        let h = self.config.hidden_dim;
1163        let v = self.config.vocab_size;
1164        let mut hidden = vec![0.0f32; h];
1165        let mut states = Vec::with_capacity(tokens.len());
1166        for &tid in tokens {
1167            let idx = tid % v;
1168            let embed = &self.encoder_embed[idx];
1169            hidden = Self::gru_step(embed, &hidden, &self.enc_w_in, &self.enc_w_rec);
1170            states.push(hidden.clone());
1171        }
1172        states
1173    }
1174
1175    /// One decode step: returns (vocab_distribution, copy_distribution).
1176    pub fn decode_step(
1177        &self,
1178        prev_token: usize,
1179        hidden: &[f32],
1180        encoder_states: &[Vec<f32>],
1181    ) -> (Vec<f32>, Vec<f32>) {
1182        let v = self.config.vocab_size;
1183        let a = self.config.attn_dim;
1184        let idx = prev_token % v;
1185        let embed = &self.decoder_embed[idx];
1186        let new_hidden = Self::gru_step(embed, hidden, &self.dec_w_in, &self.dec_w_rec);
1187
1188        // Attention
1189        let dec_proj = Self::mat_vec(&self.w_dec, &new_hidden);
1190        let mut attn_scores: Vec<f32> = encoder_states
1191            .iter()
1192            .map(|es| {
1193                let enc_proj = Self::mat_vec(&self.w_enc, es);
1194                let combined: Vec<f32> = enc_proj
1195                    .iter()
1196                    .zip(dec_proj.iter())
1197                    .map(|(e, d)| (e + d).tanh())
1198                    .collect();
1199                self.v_attn
1200                    .iter()
1201                    .zip(combined.iter())
1202                    .map(|(v, c)| v * c)
1203                    .sum::<f32>()
1204            })
1205            .collect();
1206        // Softmax
1207        let max_s = attn_scores
1208            .iter()
1209            .cloned()
1210            .fold(f32::NEG_INFINITY, f32::max);
1211        let exp: Vec<f32> = attn_scores.iter().map(|&s| (s - max_s).exp()).collect();
1212        let sum: f32 = exp.iter().sum::<f32>().max(1e-8);
1213        attn_scores = exp.iter().map(|e| e / sum).collect();
1214
1215        // Context vector
1216        let mut ctx = vec![0.0f32; self.config.hidden_dim];
1217        for (weight, es) in attn_scores.iter().zip(encoder_states.iter()) {
1218            for (c, e) in ctx.iter_mut().zip(es.iter()) {
1219                *c += weight * e;
1220            }
1221        }
1222
1223        // Vocab distribution
1224        let combined: Vec<f32> = new_hidden
1225            .iter()
1226            .zip(ctx.iter())
1227            .map(|(h, c)| h + c)
1228            .collect();
1229        let logits = Self::mat_vec(&self.w_vocab, &combined);
1230        let vocab_dist = softmax_vec(&logits);
1231
1232        // Copy distribution (just attention weights padded to vocab_size)
1233        let src_count = encoder_states.len();
1234        let mut copy_dist = vec![0.0f32; src_count];
1235        copy_dist.copy_from_slice(&attn_scores);
1236
1237        (vocab_dist, copy_dist)
1238    }
1239
1240    /// Scatter copy-attention weights into vocab-space.
1241    pub fn copy_mechanism(
1242        &self,
1243        attn_weights: &[f32],
1244        src_tokens: &[usize],
1245        vocab_size: usize,
1246    ) -> Vec<f32> {
1247        let mut copy_vocab = vec![0.0f32; vocab_size];
1248        for (&weight, &tid) in attn_weights.iter().zip(src_tokens.iter()) {
1249            let idx = tid % vocab_size;
1250            copy_vocab[idx] += weight;
1251        }
1252        copy_vocab
1253    }
1254}
1255
/// Numerically stable softmax: subtracts the max logit before exponentiating,
/// and floors the normalizer at 1e-8 to avoid division by zero.
/// An empty slice yields an empty vector.
pub(crate) fn softmax_vec(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut weights: Vec<f32> = logits.iter().map(|&l| (l - peak).exp()).collect();
    let total = weights.iter().sum::<f32>().max(1e-8);
    for w in &mut weights {
        *w /= total;
    }
    weights
}