1pub mod extensions;
24pub use extensions::*;
25
26pub mod neural_exec;
27pub use neural_exec::*;
28
29#[cfg(test)]
30mod tests;
31
32use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
33use scirs2_core::RngExt;
34use std::collections::{HashMap, VecDeque};
35use std::f32::consts::PI;
36
/// Programming language whose keyword table [`CodeTokenizer::tokenize`] applies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Language {
    /// Use the Python keyword table.
    Python,
    /// Use the Rust keyword table.
    Rust,
    /// Use the C keyword table.
    C,
    /// No keyword table: every word is classified as an identifier.
    Generic,
}

/// Lexical category assigned to a [`Token`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenKind {
    /// A word found in the selected language's keyword table.
    Keyword,
    /// A word (letters, digits, `_`) that is not a keyword.
    Identifier,
    /// A quoted string/char literal or a numeric literal.
    Literal,
    /// A one- or two-character operator.
    Operator,
    /// Any other single character (brackets, separators, unknown symbols).
    Punct,
    /// A `//`, `#` or `/* ... */` comment.
    Comment,
    /// A maximal run of whitespace characters.
    Whitespace,
}

/// A classified slice of the source text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Token {
    /// Lexical category of the token.
    pub kind: TokenKind,
    /// Exact source text covered by the token.
    pub text: String,
}

/// Hand-written, language-aware lexer.
///
/// The scanning rules are shared by all languages; only the keyword table
/// differs. Concatenating the `text` of all returned tokens reproduces the
/// input exactly.
pub struct CodeTokenizer {
    python_keywords: Vec<&'static str>,
    rust_keywords: Vec<&'static str>,
    c_keywords: Vec<&'static str>,
}

impl CodeTokenizer {
    /// Creates a tokenizer with the built-in Python, Rust and C keyword tables.
    pub fn new() -> Self {
        Self {
            python_keywords: vec![
                "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class",
                "continue", "def", "del", "elif", "else", "except", "finally", "for", "from",
                "global", "if", "import", "in", "is", "lambda", "nonlocal", "not", "or", "pass",
                "raise", "return", "try", "while", "with", "yield",
            ],
            rust_keywords: vec![
                "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else",
                "enum", "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match",
                "mod", "move", "mut", "pub", "ref", "return", "self", "Self", "static", "struct",
                "super", "trait", "true", "type", "unsafe", "use", "where", "while",
            ],
            c_keywords: vec![
                "auto", "break", "case", "char", "const", "continue", "default", "do", "double",
                "else", "enum", "extern", "float", "for", "goto", "if", "int", "long", "register",
                "return", "short", "signed", "sizeof", "static", "struct", "switch", "typedef",
                "union", "unsigned", "void", "volatile", "while",
            ],
        }
    }

    /// Splits `source` into tokens, classifying words against `lang`'s keywords.
    ///
    /// Rules are tried in this order: whitespace run, line comment (`//` or
    /// `#`), block comment (`/* ... */`), quoted literal (`"` or `'`, with
    /// backslash escapes), numeric literal, word, two-character operator,
    /// single character.
    ///
    /// Unterminated comments and literals extend to the end of the input; the
    /// function never panics, even when the input ends right after a
    /// backslash inside a literal.
    ///
    /// NOTE(review): comment and quote rules are applied regardless of `lang`,
    /// so e.g. `#` always starts a comment and `//` is never an operator.
    pub fn tokenize(&self, source: &str, lang: Language) -> Vec<Token> {
        let keywords: &[&str] = match lang {
            Language::Python => &self.python_keywords,
            Language::Rust => &self.rust_keywords,
            Language::C => &self.c_keywords,
            Language::Generic => &[],
        };
        let chars: Vec<char> = source.chars().collect();
        let mut tokens = Vec::new();
        let mut pos = 0;

        while pos < chars.len() {
            let ch = chars[pos];
            let next = chars.get(pos + 1).copied();

            // `end` is the exclusive end of the token starting at `pos`.
            let (end, kind) = if ch.is_whitespace() {
                (
                    Self::scan_while(&chars, pos, char::is_whitespace),
                    TokenKind::Whitespace,
                )
            } else if ch == '#' || (ch == '/' && next == Some('/')) {
                // Line comment: the newline itself is left for the whitespace rule.
                (
                    Self::scan_while(&chars, pos, |c| c != '\n'),
                    TokenKind::Comment,
                )
            } else if ch == '/' && next == Some('*') {
                (Self::scan_block_comment(&chars, pos), TokenKind::Comment)
            } else if ch == '"' || ch == '\'' {
                (Self::scan_quoted(&chars, pos), TokenKind::Literal)
            } else if ch.is_ascii_digit()
                || (ch == '.' && matches!(next, Some(c) if c.is_ascii_digit()))
            {
                // Deliberately permissive: also swallows suffixes, hex digits and exponents.
                (
                    Self::scan_while(&chars, pos, |c| {
                        c.is_ascii_alphanumeric() || c == '.' || c == '_'
                    }),
                    TokenKind::Literal,
                )
            } else if ch.is_alphabetic() || ch == '_' {
                // Provisionally an identifier; upgraded to a keyword below.
                (
                    Self::scan_while(&chars, pos, |c| c.is_alphanumeric() || c == '_'),
                    TokenKind::Identifier,
                )
            } else if matches!(next, Some(n) if Self::is_two_char_operator(ch, n)) {
                (pos + 2, TokenKind::Operator)
            } else if Self::is_single_char_operator(ch) {
                (pos + 1, TokenKind::Operator)
            } else {
                (pos + 1, TokenKind::Punct)
            };

            let text: String = chars[pos..end].iter().collect();
            let kind = if kind == TokenKind::Identifier && keywords.contains(&text.as_str()) {
                TokenKind::Keyword
            } else {
                kind
            };
            tokens.push(Token { kind, text });
            pos = end;
        }
        tokens
    }

    /// Returns the first index at or after `from` whose character fails `pred`
    /// (or `chars.len()` if every remaining character satisfies it).
    fn scan_while(chars: &[char], from: usize, pred: impl Fn(char) -> bool) -> usize {
        chars[from..]
            .iter()
            .position(|&c| !pred(c))
            .map_or(chars.len(), |offset| from + offset)
    }

    /// Returns the exclusive end of the block comment opening at `start`
    /// (which must point at `/*`); an unterminated comment ends at the input end.
    fn scan_block_comment(chars: &[char], start: usize) -> usize {
        let mut i = start + 2;
        while i + 1 < chars.len() && !(chars[i] == '*' && chars[i + 1] == '/') {
            i += 1;
        }
        // Skip the closing `*/`; clamp because the comment may be unterminated.
        (i + 2).min(chars.len())
    }

    /// Returns the exclusive end of the quoted literal opening at `start`.
    ///
    /// The closing quote is the same character as `chars[start]`. A backslash
    /// skips the following character. The result is clamped to the input
    /// length so a trailing backslash cannot push it out of range.
    fn scan_quoted(chars: &[char], start: usize) -> usize {
        let quote = chars[start];
        let mut i = start + 1;
        while i < chars.len() {
            let c = chars[i];
            if c == '\\' {
                i += 2;
            } else {
                i += 1;
                if c == quote {
                    break;
                }
            }
        }
        i.min(chars.len())
    }

    /// Whether `first` followed by `second` forms a recognised two-character operator.
    fn is_two_char_operator(first: char, second: char) -> bool {
        matches!(
            (first, second),
            ('=', '=')
                | ('!', '=')
                | ('<', '=')
                | ('>', '=')
                | ('-', '>')
                | ('=', '>')
                | (':', ':')
                | ('&', '&')
                | ('|', '|')
                | ('+', '=')
                | ('-', '=')
                | ('*', '=')
                | ('/', '=')
                | ('&', '=')
                | ('|', '=')
                | ('^', '=')
                | ('<', '<')
                | ('>', '>')
        )
    }

    /// Whether `ch` on its own is an operator (everything else is punctuation).
    fn is_single_char_operator(ch: char) -> bool {
        matches!(
            ch,
            '+' | '-' | '*' | '/' | '%' | '=' | '<' | '>' | '!' | '&' | '|' | '^' | '~'
        )
    }
}

impl Default for CodeTokenizer {
    fn default() -> Self {
        Self::new()
    }
}
283
/// A node of an abstract syntax tree together with its position in the tree.
#[derive(Debug, Clone)]
pub struct AstNode {
    /// Name of the syntactic construct this node represents.
    pub node_type: String,
    /// Child nodes, in insertion order.
    pub children: Vec<AstNode>,
    /// Depth value supplied by the caller (not derived from the tree shape).
    pub depth: usize,
    /// Sibling index supplied by the caller (not derived from `children` order).
    pub sibling_idx: usize,
}

impl AstNode {
    /// Creates a childless node with the given type, depth and sibling index.
    pub fn new(node_type: impl Into<String>, depth: usize, sibling_idx: usize) -> Self {
        Self {
            node_type: node_type.into(),
            children: Vec::new(),
            depth,
            sibling_idx,
        }
    }

    /// Appends `child` to this node's children. `child.depth` and
    /// `child.sibling_idx` are stored as given; they are not recomputed.
    pub fn add_child(&mut self, child: AstNode) {
        self.children.push(child);
    }
}

/// Encodes AST nodes as fixed-width vectors built from sinusoidal position codes.
pub struct ASTEncoder {
    /// Width of the vectors produced by [`ASTEncoder::encode_tree`].
    pub embed_dim: usize,
}

impl ASTEncoder {
    /// Creates an encoder producing `embed_dim`-wide vectors.
    pub fn new(embed_dim: usize) -> Self {
        Self { embed_dim }
    }

    /// Encodes a single node as a vector of length `embed_dim`.
    ///
    /// With `half = embed_dim / 2`, positions `[0, half)` hold interleaved
    /// sin/cos codes of `node.depth` and positions `[half, embed_dim)` hold
    /// sin/cos codes of `node.sibling_idx`. A small offset in `[-0.1, 0.1]`
    /// derived from a hash of `node.node_type` is added to element 0.
    ///
    /// Note that the width comes from the `embed_dim` argument, not from
    /// `self.embed_dim`. Returns an empty vector when `embed_dim` is 0.
    pub fn encode_node(&self, node: &AstNode, embed_dim: usize) -> Vec<f32> {
        let mut enc = vec![0.0f32; embed_dim];
        let half = embed_dim / 2;

        // The depth code is written across the whole vector first; the sibling
        // code then overwrites the upper half.
        Self::fill_sinusoid(&mut enc, node.depth as f32, half, embed_dim);
        Self::fill_sinusoid(&mut enc[half..], node.sibling_idx as f32, half, embed_dim);

        // Polynomial (base-31) hash of the type name, mapped to [-0.1, 0.1].
        let type_hash = node
            .node_type
            .bytes()
            .fold(0u32, |acc, b| acc.wrapping_mul(31).wrapping_add(u32::from(b)));
        let type_signal = ((type_hash as f32 / u32::MAX as f32) * 2.0 - 1.0) * 0.1;
        // `first_mut` instead of `enc[0]`: a zero-width encoding has no element 0.
        if let Some(first) = enc.first_mut() {
            *first += type_signal;
        }
        enc
    }

    /// Encodes every node reachable from `root` in breadth-first order, using
    /// `self.embed_dim` as the vector width. The root's encoding comes first.
    pub fn encode_tree(&self, root: &AstNode) -> Vec<Vec<f32>> {
        let mut encodings = Vec::new();
        let mut queue: VecDeque<&AstNode> = VecDeque::new();
        queue.push_back(root);
        while let Some(node) = queue.pop_front() {
            encodings.push(self.encode_node(node, self.embed_dim));
            queue.extend(node.children.iter());
        }
        encodings
    }

    /// Writes up to `pairs` (sin, cos) pairs of `value / 10000^(2k / embed_dim)`
    /// into `slots[2k]` and `slots[2k + 1]`, skipping positions past the end.
    fn fill_sinusoid(slots: &mut [f32], value: f32, pairs: usize, embed_dim: usize) {
        for k in 0..pairs {
            if 2 * k >= slots.len() {
                break;
            }
            let denom = 10_000_f32.powf(2.0 * k as f32 / embed_dim as f32);
            let angle = value / denom;
            slots[2 * k] = angle.sin();
            if let Some(slot) = slots.get_mut(2 * k + 1) {
                *slot = angle.cos();
            }
        }
    }
}
370
/// Hyper-parameters for [`CodeBert`].
#[derive(Clone, Debug)]
pub struct CodeBertConfig {
    /// Number of rows in the token-embedding and output-projection tables.
    pub vocab_size: usize,
    /// Width of every hidden vector.
    pub embed_dim: usize,
    /// Number of attention heads per layer.
    pub n_heads: usize,
    /// Number of stacked encoder layers.
    pub n_layers: usize,
    /// Maximum number of positions; longer inputs are truncated.
    pub max_seq_len: usize,
}

impl CodeBertConfig {
    /// Bundles the given hyper-parameters; no validation is performed.
    pub fn new(
        vocab_size: usize,
        embed_dim: usize,
        n_heads: usize,
        n_layers: usize,
        max_seq_len: usize,
    ) -> Self {
        CodeBertConfig {
            vocab_size,
            embed_dim,
            n_heads,
            n_layers,
            max_seq_len,
        }
    }
}
402
/// Transformer-style encoder over token ids with a vocabulary-sized output
/// projection for masked-token logits.
pub struct CodeBert {
    /// Hyper-parameters the model was built with.
    config: CodeBertConfig,
    /// Token embedding table: `vocab_size` rows of `embed_dim` values.
    token_embed: Vec<Vec<f32>>,
    /// Sinusoidal position table: `max_seq_len` rows of `embed_dim` values.
    pos_embed: Vec<Vec<f32>>,
    /// Encoder layers, applied in order (`n_layers` of them).
    layers: Vec<CodeBertLayer>,
    /// Output projection: `vocab_size` rows of `embed_dim` values.
    output_proj: Vec<Vec<f32>>,
}
415
416struct CodeBertLayer {
417 wq: Vec<Vec<f32>>,
418 wk: Vec<Vec<f32>>,
419 wv: Vec<Vec<f32>>,
420 wo: Vec<Vec<f32>>,
421 ff1: Vec<Vec<f32>>,
422 ff2: Vec<Vec<f32>>,
423}
424
425impl CodeBertLayer {
426 fn new(embed_dim: usize, ff_dim: usize, rng: &mut StdRng) -> Self {
427 let init = |rows: usize, cols: usize, rng: &mut StdRng| -> Vec<Vec<f32>> {
428 let scale = (2.0 / (rows + cols) as f32).sqrt();
429 (0..rows)
430 .map(|_| {
431 (0..cols)
432 .map(|_| (rng.random::<f32>() * 2.0 - 1.0) * scale)
433 .collect()
434 })
435 .collect()
436 };
437 Self {
438 wq: init(embed_dim, embed_dim, rng),
439 wk: init(embed_dim, embed_dim, rng),
440 wv: init(embed_dim, embed_dim, rng),
441 wo: init(embed_dim, embed_dim, rng),
442 ff1: init(ff_dim, embed_dim, rng),
444 ff2: init(embed_dim, ff_dim, rng),
446 }
447 }
448
449 fn forward(&self, x: &[Vec<f32>], n_heads: usize) -> Vec<Vec<f32>> {
450 let seq_len = x.len();
451 let embed_dim = x[0].len();
452 let head_dim = embed_dim / n_heads;
453 let scale = (head_dim as f32).sqrt();
454
455 let proj = |input: &[Vec<f32>], w: &[Vec<f32>]| -> Vec<Vec<f32>> {
457 input
458 .iter()
459 .map(|xi| {
460 w.iter()
461 .map(|row| row.iter().zip(xi.iter()).map(|(a, b)| a * b).sum::<f32>())
462 .collect()
463 })
464 .collect()
465 };
466 let q = proj(x, &self.wq);
467 let k = proj(x, &self.wk);
468 let v = proj(x, &self.wv);
469
470 let mut attn_out = vec![vec![0.0f32; embed_dim]; seq_len];
472 for h in 0..n_heads {
473 let s = h * head_dim;
474 let e = s + head_dim;
475 let mut scores = vec![vec![0.0f32; seq_len]; seq_len];
477 for i in 0..seq_len {
478 for j in 0..seq_len {
479 scores[i][j] = q[i][s..e]
480 .iter()
481 .zip(k[j][s..e].iter())
482 .map(|(a, b)| a * b)
483 .sum::<f32>()
484 / scale;
485 }
486 }
487 for i in 0..seq_len {
489 let max_s = scores[i].iter().cloned().fold(f32::NEG_INFINITY, f32::max);
490 let exp: Vec<f32> = scores[i].iter().map(|&x| (x - max_s).exp()).collect();
491 let sum: f32 = exp.iter().sum();
492 for j in 0..seq_len {
493 scores[i][j] = exp[j] / sum.max(1e-8);
494 }
495 }
496 for i in 0..seq_len {
498 for j in 0..seq_len {
499 for k_d in s..e {
500 attn_out[i][k_d] += scores[i][j] * v[j][k_d];
501 }
502 }
503 }
504 }
505 let out = proj(&attn_out, &self.wo);
507 let mut res: Vec<Vec<f32>> = out
509 .iter()
510 .zip(x.iter())
511 .map(|(o, xi)| o.iter().zip(xi.iter()).map(|(a, b)| a + b).collect())
512 .collect();
513 layer_norm_2d(&mut res);
514
515 let mut ff_out = vec![vec![0.0f32; embed_dim]; seq_len];
517 for (i, xi) in res.iter().enumerate() {
518 let h1: Vec<f32> = self
519 .ff1
520 .iter()
521 .map(|row| {
522 let sum: f32 = row.iter().zip(xi.iter()).map(|(a, b)| a * b).sum();
523 gelu(sum)
524 })
525 .collect();
526 for (j, row) in self.ff2.iter().enumerate() {
527 ff_out[i][j] = row.iter().zip(h1.iter()).map(|(a, b)| a * b).sum::<f32>();
528 }
529 }
530 let mut final_out: Vec<Vec<f32>> = ff_out
532 .iter()
533 .zip(res.iter())
534 .map(|(f, r)| f.iter().zip(r.iter()).map(|(a, b)| a + b).collect())
535 .collect();
536 layer_norm_2d(&mut final_out);
537 final_out
538 }
539}
540
/// Tanh approximation of the GELU activation:
/// `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`.
fn gelu(x: f32) -> f32 {
    let cubic = x + 0.044715 * x.powi(3);
    let gate = ((2.0 / PI).sqrt() * cubic).tanh();
    0.5 * x * (1.0 + gate)
}
544
/// Normalises every row of `x` in place to zero mean and unit variance.
///
/// Uses the population variance and adds `1e-5` under the square root; there
/// are no learned gain or bias terms.
fn layer_norm_2d(x: &mut [Vec<f32>]) {
    const EPS: f32 = 1e-5;
    for row in x {
        let width = row.len() as f32;
        let mean = row.iter().sum::<f32>() / width;
        let variance = row
            .iter()
            .map(|&value| (value - mean).powi(2))
            .sum::<f32>()
            / width;
        let std_dev = (variance + EPS).sqrt();
        row.iter_mut()
            .for_each(|value| *value = (*value - mean) / std_dev);
    }
}
555
556impl CodeBert {
557 pub fn new(config: CodeBertConfig) -> Self {
559 let mut rng = StdRng::seed_from_u64(42);
560 let scale_emb = (1.0 / config.embed_dim as f32).sqrt();
561 let token_embed: Vec<Vec<f32>> = (0..config.vocab_size)
562 .map(|_| {
563 (0..config.embed_dim)
564 .map(|_| (rng.random::<f32>() * 2.0 - 1.0) * scale_emb)
565 .collect()
566 })
567 .collect();
568 let pos_embed: Vec<Vec<f32>> = (0..config.max_seq_len)
569 .map(|pos| {
570 (0..config.embed_dim)
571 .enumerate()
572 .map(|(i, _)| {
573 let denom = 10_000_f32.powf(2.0 * (i / 2) as f32 / config.embed_dim as f32);
574 if i % 2 == 0 {
575 (pos as f32 / denom).sin()
576 } else {
577 (pos as f32 / denom).cos()
578 }
579 })
580 .collect()
581 })
582 .collect();
583 let ff_dim = config.embed_dim * 4;
584 let layers = (0..config.n_layers)
585 .map(|_| CodeBertLayer::new(config.embed_dim, ff_dim, &mut rng))
586 .collect();
587 let scale_out = (2.0 / (config.embed_dim + config.vocab_size) as f32).sqrt();
588 let output_proj: Vec<Vec<f32>> = (0..config.vocab_size)
590 .map(|_| {
591 (0..config.embed_dim)
592 .map(|_| (rng.random::<f32>() * 2.0 - 1.0) * scale_out)
593 .collect()
594 })
595 .collect();
596 Self {
597 config,
598 token_embed,
599 pos_embed,
600 layers,
601 output_proj,
602 }
603 }
604
605 pub fn forward(&self, token_ids: &[usize]) -> Vec<Vec<f32>> {
607 let seq_len = token_ids.len().min(self.config.max_seq_len);
608 let mut x: Vec<Vec<f32>> = token_ids[..seq_len]
610 .iter()
611 .enumerate()
612 .map(|(pos, &tid)| {
613 let t_idx = tid % self.config.vocab_size;
614 self.token_embed[t_idx]
615 .iter()
616 .zip(self.pos_embed[pos].iter())
617 .map(|(t, p)| t + p)
618 .collect()
619 })
620 .collect();
621 for layer in &self.layers {
623 x = layer.forward(&x, self.config.n_heads);
624 }
625 x
626 }
627
628 pub fn mlm_logits(&self, hidden: &[Vec<f32>]) -> Vec<Vec<f32>> {
630 hidden
631 .iter()
632 .map(|h| {
633 self.output_proj
634 .iter()
635 .map(|row| row.iter().zip(h.iter()).map(|(a, b)| a * b).sum::<f32>())
636 .collect()
637 })
638 .collect()
639 }
640}
641
642pub struct CodeContrastive {
648 pub embed_dim: usize,
649 code_proj: Vec<Vec<f32>>,
650 doc_proj: Vec<Vec<f32>>,
651 token_embed: Vec<Vec<f32>>,
652 vocab_size: usize,
653}
654
655impl CodeContrastive {
656 pub fn new(vocab_size: usize, embed_dim: usize) -> Self {
657 let mut rng = StdRng::seed_from_u64(123);
658 let scale = (1.0 / embed_dim as f32).sqrt();
659 let random_matrix = |rng: &mut StdRng| -> Vec<Vec<f32>> {
660 (0..embed_dim)
661 .map(|_| {
662 (0..embed_dim)
663 .map(|_| (rng.random::<f32>() * 2.0 - 1.0) * scale)
664 .collect()
665 })
666 .collect()
667 };
668 let token_embed: Vec<Vec<f32>> = (0..vocab_size)
669 .map(|_| {
670 (0..embed_dim)
671 .map(|_| (rng.random::<f32>() * 2.0 - 1.0) * scale)
672 .collect()
673 })
674 .collect();
675 Self {
676 embed_dim,
677 code_proj: random_matrix(&mut rng),
678 doc_proj: random_matrix(&mut rng),
679 token_embed,
680 vocab_size,
681 }
682 }
683
684 fn mean_pool(&self, tokens: &[usize]) -> Vec<f32> {
685 if tokens.is_empty() {
686 return vec![0.0; self.embed_dim];
687 }
688 let mut sum = vec![0.0f32; self.embed_dim];
689 for &tid in tokens {
690 let idx = tid % self.vocab_size;
691 for (s, e) in sum.iter_mut().zip(self.token_embed[idx].iter()) {
692 *s += e;
693 }
694 }
695 let n = tokens.len() as f32;
696 sum.iter_mut().for_each(|v| *v /= n);
697 sum
698 }
699
700 fn project(vec: &[f32], mat: &[Vec<f32>]) -> Vec<f32> {
701 mat.iter()
702 .map(|row| row.iter().zip(vec.iter()).map(|(a, b)| a * b).sum::<f32>())
703 .collect()
704 }
705
706 fn l2_norm(v: &mut [f32]) {
707 let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
708 v.iter_mut().for_each(|x| *x /= norm);
709 }
710
711 pub fn encode_code(&self, tokens: &[usize]) -> Vec<f32> {
713 let pooled = self.mean_pool(tokens);
714 let mut proj = Self::project(&pooled, &self.code_proj);
715 Self::l2_norm(&mut proj);
716 proj
717 }
718
719 pub fn encode_doc(&self, tokens: &[usize]) -> Vec<f32> {
721 let pooled = self.mean_pool(tokens);
722 let mut proj = Self::project(&pooled, &self.doc_proj);
723 Self::l2_norm(&mut proj);
724 proj
725 }
726
727 pub fn contrastive_loss(
732 &self,
733 code_embeds: &[Vec<f32>],
734 doc_embeds: &[Vec<f32>],
735 temperature: f32,
736 ) -> f32 {
737 let n = code_embeds.len();
738 if n == 0 {
739 return 0.0;
740 }
741 let dot =
742 |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b.iter()).map(|(x, y)| x * y).sum() };
743 let mut total_loss = 0.0f32;
744 for i in 0..n {
745 let logits: Vec<f32> = (0..n)
747 .map(|j| dot(&code_embeds[i], &doc_embeds[j]) / temperature)
748 .collect();
749 let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
750 let exp: Vec<f32> = logits.iter().map(|&l| (l - max_l).exp()).collect();
751 let sum: f32 = exp.iter().sum();
752 total_loss -= (exp[i] / sum.max(1e-8)).ln();
753 }
754 for j in 0..n {
755 let logits: Vec<f32> = (0..n)
757 .map(|i| dot(&doc_embeds[j], &code_embeds[i]) / temperature)
758 .collect();
759 let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
760 let exp: Vec<f32> = logits.iter().map(|&l| (l - max_l).exp()).collect();
761 let sum: f32 = exp.iter().sum();
762 total_loss -= (exp[j] / sum.max(1e-8)).ln();
763 }
764 total_loss / (2.0 * n as f32)
765 }
766}
767
/// Operations of the small string-transformation language used by
/// [`FlashFillSolver`]. Each operation maps one input string to one output string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StringDsl {
    /// Applies both operations to the same input and concatenates the results.
    Concat(Box<StringDsl>, Box<StringDsl>),
    /// Characters `start..=end` (character indices, `end` inclusive), clipped
    /// to the input length; empty when `start > end`.
    Substr(usize, usize),
    /// Replaces every occurrence of the first string with the second.
    Replace(String, String),
    /// Upper-cases the input.
    Upper,
    /// Lower-cases the input.
    Lower,
    /// Trims leading and trailing whitespace.
    Strip,
    /// Splits on the delimiter and keeps the piece at the given index
    /// (empty string if there is no such piece).
    Split(String, usize),
    /// Splits on whitespace and joins the pieces with the given delimiter.
    Join(String),
    /// Splits on the pattern, keeps the piece at the given index and trims it.
    /// The pattern is matched as a literal substring, not as a regular expression.
    Regex(String, usize),
}

/// A straight-line program: operations applied left to right, each consuming
/// the previous operation's output.
#[derive(Debug, Clone)]
pub struct FlashFillProgram {
    /// Operations in execution order.
    pub ops: Vec<StringDsl>,
}

/// Enumerative programming-by-example solver over [`StringDsl`].
pub struct FlashFillSolver;

impl FlashFillSolver {
    /// Creates a solver (the solver is stateless).
    pub fn new() -> Self {
        Self
    }

    /// Evaluates a single operation on `input`.
    fn execute_op(op: &StringDsl, input: &str) -> String {
        match op {
            StringDsl::Concat(a, b) => {
                let mut out = Self::execute_op(a, input);
                out.push_str(&Self::execute_op(b, input));
                out
            }
            StringDsl::Substr(start, end) => {
                // `end` is inclusive; saturating arithmetic keeps `end == usize::MAX`
                // ("to the end of the string") from overflowing.
                let count = end.saturating_add(1).saturating_sub(*start);
                input.chars().skip(*start).take(count).collect()
            }
            StringDsl::Replace(from, to) => input.replace(from.as_str(), to.as_str()),
            StringDsl::Upper => input.to_uppercase(),
            StringDsl::Lower => input.to_lowercase(),
            StringDsl::Strip => input.trim().to_string(),
            StringDsl::Split(delim, idx) => input
                .split(delim.as_str())
                .nth(*idx)
                .unwrap_or("")
                .to_string(),
            StringDsl::Join(delim) => input
                .split_whitespace()
                .collect::<Vec<_>>()
                .join(delim.as_str()),
            StringDsl::Regex(pattern, idx) => input
                .split(pattern.as_str())
                .nth(*idx)
                .unwrap_or("")
                .trim()
                .to_string(),
        }
    }

    /// Runs `program` on `input`, feeding each operation the output of the
    /// previous one. An empty program returns `input` unchanged.
    pub fn execute(program: &FlashFillProgram, input: &str) -> String {
        program
            .ops
            .iter()
            .fold(input.to_string(), |current, op| {
                Self::execute_op(op, &current)
            })
    }

    /// Whether `program` maps every example input to its expected output.
    fn verify(program: &FlashFillProgram, examples: &[(&str, &str)]) -> bool {
        examples
            .iter()
            .all(|(inp, out)| Self::execute(program, inp) == *out)
    }

    /// Searches a fixed candidate list for the first program consistent with
    /// all `(input, output)` examples.
    ///
    /// Candidates are tried in this order:
    /// 1. `Strip`, `Upper`, `Lower`;
    /// 2. for each common delimiter: `Split(delim, 0..4)` then `Join(delim)`;
    /// 3. `Substr(start, end)` for `start < min(max_len, 8)` and
    ///    `start <= end < min(max_len, 16)`, where `max_len` is the longest
    ///    example input in characters;
    /// 4. for each delimiter present in an example input: `Replace` with every
    ///    other delimiter, then with the empty string;
    /// 5. every ordered pair of `Strip`/`Upper`/`Lower`;
    /// 6. `Split(delim, 0)` followed by `Strip`.
    ///
    /// Returns `None` when `examples` is empty or no candidate fits.
    pub fn synthesize(&self, examples: &[(&str, &str)]) -> Option<FlashFillProgram> {
        /// Appends `candidate` unless an identical program is already queued.
        fn push_unique(list: &mut Vec<Vec<StringDsl>>, candidate: Vec<StringDsl>) {
            if !list.contains(&candidate) {
                list.push(candidate);
            }
        }

        if examples.is_empty() {
            return None;
        }

        let common_delimiters = [" ", "-", "_", ",", ".", "/", ":"];
        let single_ops = [StringDsl::Strip, StringDsl::Upper, StringDsl::Lower];
        let mut candidates: Vec<Vec<StringDsl>> = Vec::new();

        // 1. Single whole-string operations.
        candidates.extend(single_ops.iter().map(|op| vec![op.clone()]));

        // 2. Field extraction and whitespace re-joining.
        for &delim in &common_delimiters {
            for idx in 0..4 {
                candidates.push(vec![StringDsl::Split(delim.to_string(), idx)]);
            }
            candidates.push(vec![StringDsl::Join(delim.to_string())]);
        }

        // 3. Fixed character ranges, bounded to keep the search small.
        let max_len = examples
            .iter()
            .map(|(inp, _)| inp.chars().count())
            .max()
            .unwrap_or(0);
        for start in 0..max_len.min(8) {
            for end in start..max_len.min(16) {
                candidates.push(vec![StringDsl::Substr(start, end)]);
            }
        }

        // 4. Delimiter substitutions suggested by the example inputs. Several
        //    examples usually suggest the same substitution, hence the dedup.
        for &(inp, _) in examples {
            for &delim in &common_delimiters {
                if !inp.contains(delim) {
                    continue;
                }
                for &new_delim in &common_delimiters {
                    if new_delim != delim {
                        push_unique(
                            &mut candidates,
                            vec![StringDsl::Replace(delim.to_string(), new_delim.to_string())],
                        );
                    }
                }
                push_unique(
                    &mut candidates,
                    vec![StringDsl::Replace(delim.to_string(), String::new())],
                );
            }
        }

        // 5. Two-step combinations of the whole-string operations.
        for op1 in &single_ops {
            for op2 in &single_ops {
                candidates.push(vec![op1.clone(), op2.clone()]);
            }
        }

        // 6. First field, trimmed.
        for &delim in &common_delimiters {
            candidates.push(vec![
                StringDsl::Split(delim.to_string(), 0),
                StringDsl::Strip,
            ]);
        }

        candidates
            .into_iter()
            .map(|ops| FlashFillProgram { ops })
            .find(|program| Self::verify(program, examples))
    }
}

impl Default for FlashFillSolver {
    fn default() -> Self {
        Self::new()
    }
}
956
/// Operations understood by [`DifferentiableInterpreter`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpCode {
    /// `a + b`, clamped to `[-1e6, 1e6]`.
    Add,
    /// `a - b`, clamped to `[-1e6, 1e6]`.
    Sub,
    /// `tanh(a * b) * sqrt(|a * b|)`, clamped to `[-1e6, 1e6]`.
    Mul,
    /// `a / b` clamped to `[-1e6, 1e6]`, or `0` when `|b| < 1e-7`.
    Div,
    /// `a`.
    Copy,
    /// `max(a, b)`.
    Max,
    /// `min(a, b)`.
    Min,
    /// `sigmoid(a) * sigmoid(b)`.
    And,
    /// `1 - (1 - sigmoid(a)) * (1 - sigmoid(b))`.
    Or,
    /// `1 - sigmoid(a)`.
    Not,
}

/// One instruction of a register program.
#[derive(Debug, Clone)]
pub struct Instruction {
    /// Operation to perform.
    pub op: OpCode,
    /// Register indices `[a, b, dst]`. Missing entries default to register 0;
    /// this also applies to the destination of one-operand instructions.
    pub args: Vec<usize>,
}

/// Register machine that evaluates a fixed program with smooth ("soft")
/// versions of arithmetic and logic operations.
pub struct DifferentiableInterpreter {
    /// Register file; it keeps its contents between calls to `execute_soft`.
    pub registers: Vec<f32>,
    /// Instructions executed in order by `execute_soft`.
    pub program: Vec<Instruction>,
}

impl DifferentiableInterpreter {
    /// Creates an interpreter with `n_registers` zeroed registers.
    pub fn new(n_registers: usize, program: Vec<Instruction>) -> Self {
        Self {
            registers: vec![0.0; n_registers],
            program,
        }
    }

    /// Loads `input` into the leading registers, runs the program once and
    /// returns a copy of the register file.
    ///
    /// Input values beyond the register count are dropped. Reads from
    /// out-of-range registers yield `0.0`; writes to out-of-range registers
    /// are discarded. Registers not overwritten by `input` keep their values
    /// from any previous call.
    pub fn execute_soft(&mut self, input: &[f32]) -> Vec<f32> {
        for (slot, &value) in self.registers.iter_mut().zip(input) {
            *slot = value;
        }

        for instr in &self.program {
            let operand = |position: usize| instr.args.get(position).copied().unwrap_or(0);
            let (src_a, src_b, dst) = (operand(0), operand(1), operand(2));
            let lhs = Self::read_register(&self.registers, src_a);
            let rhs = Self::read_register(&self.registers, src_b);

            let value = match instr.op {
                OpCode::Add => Self::saturate(lhs + rhs),
                OpCode::Sub => Self::saturate(lhs - rhs),
                OpCode::Mul => {
                    let product = lhs * rhs;
                    Self::saturate(product.tanh() * product.abs().sqrt())
                }
                // Near-zero denominators short-circuit to 0 instead of exploding.
                OpCode::Div if rhs.abs() < 1e-7 => 0.0,
                OpCode::Div => Self::saturate(lhs / rhs),
                OpCode::Copy => lhs,
                OpCode::Max => lhs.max(rhs),
                OpCode::Min => lhs.min(rhs),
                OpCode::And => sigmoid_f32(lhs) * sigmoid_f32(rhs),
                OpCode::Or => 1.0 - (1.0 - sigmoid_f32(lhs)) * (1.0 - sigmoid_f32(rhs)),
                OpCode::Not => 1.0 - sigmoid_f32(lhs),
            };

            if let Some(slot) = self.registers.get_mut(dst) {
                *slot = value;
            }
        }
        self.registers.clone()
    }

    /// Value of register `idx`, or `0.0` when `idx` is out of range.
    fn read_register(registers: &[f32], idx: usize) -> f32 {
        registers.get(idx).copied().unwrap_or(0.0)
    }

    /// Limits arithmetic results to `[-1e6, 1e6]`.
    fn saturate(value: f32) -> f32 {
        value.clamp(-1e6, 1e6)
    }
}

/// Logistic sigmoid `1 / (1 + e^(-x))`.
pub(crate) fn sigmoid_f32(x: f32) -> f32 {
    (1.0 + (-x).exp()).recip()
}
1071
1072#[derive(Debug, Clone)]
1078pub struct PointerGeneratorConfig {
1079 pub vocab_size: usize,
1080 pub hidden_dim: usize,
1081 pub attn_dim: usize,
1082}
1083
1084pub struct CodeSummarizer {
1086 config: PointerGeneratorConfig,
1087 encoder_embed: Vec<Vec<f32>>,
1089 enc_w_in: Vec<Vec<f32>>,
1091 enc_w_rec: Vec<Vec<f32>>,
1092 decoder_embed: Vec<Vec<f32>>,
1094 dec_w_in: Vec<Vec<f32>>,
1096 dec_w_rec: Vec<Vec<f32>>,
1097 w_enc: Vec<Vec<f32>>,
1099 w_dec: Vec<Vec<f32>>,
1100 v_attn: Vec<f32>,
1101 w_vocab: Vec<Vec<f32>>,
1103 w_copy_gate: Vec<f32>,
1105}
1106
1107impl CodeSummarizer {
1108 pub fn new(config: PointerGeneratorConfig) -> Self {
1109 let mut rng = StdRng::seed_from_u64(77);
1110 let h = config.hidden_dim;
1111 let v = config.vocab_size;
1112 let a = config.attn_dim;
1113 let scale = |d: usize| (1.0 / d as f32).sqrt();
1114 let mat = |rows: usize, cols: usize, rng: &mut StdRng| -> Vec<Vec<f32>> {
1115 let s = scale(rows);
1116 (0..rows)
1117 .map(|_| {
1118 (0..cols)
1119 .map(|_| (rng.random::<f32>() * 2.0 - 1.0) * s)
1120 .collect()
1121 })
1122 .collect()
1123 };
1124 let vec_init = |n: usize, rng: &mut StdRng| -> Vec<f32> {
1125 (0..n)
1126 .map(|_| (rng.random::<f32>() * 2.0 - 1.0) * scale(n))
1127 .collect()
1128 };
1129 Self {
1130 encoder_embed: mat(v, h, &mut rng),
1131 enc_w_in: mat(h, h, &mut rng),
1132 enc_w_rec: mat(h, h, &mut rng),
1133 decoder_embed: mat(v, h, &mut rng),
1134 dec_w_in: mat(h, h, &mut rng),
1135 dec_w_rec: mat(h, h, &mut rng),
1136 w_enc: mat(a, h, &mut rng),
1137 w_dec: mat(a, h, &mut rng),
1138 v_attn: vec_init(a, &mut rng),
1139 w_vocab: mat(v, h, &mut rng),
1140 w_copy_gate: vec_init(h, &mut rng),
1141 config,
1142 }
1143 }
1144
1145 fn mat_vec(mat: &[Vec<f32>], v: &[f32]) -> Vec<f32> {
1146 mat.iter()
1147 .map(|row| row.iter().zip(v.iter()).map(|(a, b)| a * b).sum::<f32>())
1148 .collect()
1149 }
1150
1151 fn gru_step(x: &[f32], h: &[f32], w_in: &[Vec<f32>], w_rec: &[Vec<f32>]) -> Vec<f32> {
1152 let dim = h.len();
1153 let ax = Self::mat_vec(w_in, x);
1154 let ah = Self::mat_vec(w_rec, h);
1155 (0..dim)
1156 .map(|i| (ax.get(i).copied().unwrap_or(0.0) + ah.get(i).copied().unwrap_or(0.0)).tanh())
1157 .collect()
1158 }
1159
1160 pub fn encode_source(&self, tokens: &[usize]) -> Vec<Vec<f32>> {
1162 let h = self.config.hidden_dim;
1163 let v = self.config.vocab_size;
1164 let mut hidden = vec![0.0f32; h];
1165 let mut states = Vec::with_capacity(tokens.len());
1166 for &tid in tokens {
1167 let idx = tid % v;
1168 let embed = &self.encoder_embed[idx];
1169 hidden = Self::gru_step(embed, &hidden, &self.enc_w_in, &self.enc_w_rec);
1170 states.push(hidden.clone());
1171 }
1172 states
1173 }
1174
1175 pub fn decode_step(
1177 &self,
1178 prev_token: usize,
1179 hidden: &[f32],
1180 encoder_states: &[Vec<f32>],
1181 ) -> (Vec<f32>, Vec<f32>) {
1182 let v = self.config.vocab_size;
1183 let a = self.config.attn_dim;
1184 let idx = prev_token % v;
1185 let embed = &self.decoder_embed[idx];
1186 let new_hidden = Self::gru_step(embed, hidden, &self.dec_w_in, &self.dec_w_rec);
1187
1188 let dec_proj = Self::mat_vec(&self.w_dec, &new_hidden);
1190 let mut attn_scores: Vec<f32> = encoder_states
1191 .iter()
1192 .map(|es| {
1193 let enc_proj = Self::mat_vec(&self.w_enc, es);
1194 let combined: Vec<f32> = enc_proj
1195 .iter()
1196 .zip(dec_proj.iter())
1197 .map(|(e, d)| (e + d).tanh())
1198 .collect();
1199 self.v_attn
1200 .iter()
1201 .zip(combined.iter())
1202 .map(|(v, c)| v * c)
1203 .sum::<f32>()
1204 })
1205 .collect();
1206 let max_s = attn_scores
1208 .iter()
1209 .cloned()
1210 .fold(f32::NEG_INFINITY, f32::max);
1211 let exp: Vec<f32> = attn_scores.iter().map(|&s| (s - max_s).exp()).collect();
1212 let sum: f32 = exp.iter().sum::<f32>().max(1e-8);
1213 attn_scores = exp.iter().map(|e| e / sum).collect();
1214
1215 let mut ctx = vec![0.0f32; self.config.hidden_dim];
1217 for (weight, es) in attn_scores.iter().zip(encoder_states.iter()) {
1218 for (c, e) in ctx.iter_mut().zip(es.iter()) {
1219 *c += weight * e;
1220 }
1221 }
1222
1223 let combined: Vec<f32> = new_hidden
1225 .iter()
1226 .zip(ctx.iter())
1227 .map(|(h, c)| h + c)
1228 .collect();
1229 let logits = Self::mat_vec(&self.w_vocab, &combined);
1230 let vocab_dist = softmax_vec(&logits);
1231
1232 let src_count = encoder_states.len();
1234 let mut copy_dist = vec![0.0f32; src_count];
1235 copy_dist.copy_from_slice(&attn_scores);
1236
1237 (vocab_dist, copy_dist)
1238 }
1239
1240 pub fn copy_mechanism(
1242 &self,
1243 attn_weights: &[f32],
1244 src_tokens: &[usize],
1245 vocab_size: usize,
1246 ) -> Vec<f32> {
1247 let mut copy_vocab = vec![0.0f32; vocab_size];
1248 for (&weight, &tid) in attn_weights.iter().zip(src_tokens.iter()) {
1249 let idx = tid % vocab_size;
1250 copy_vocab[idx] += weight;
1251 }
1252 copy_vocab
1253 }
1254}
1255
/// Numerically stable softmax.
///
/// The maximum logit is subtracted before exponentiation and the normaliser
/// is floored at `1e-8`. An empty input yields an empty output.
pub(crate) fn softmax_vec(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut probs: Vec<f32> = logits.iter().map(|&logit| (logit - peak).exp()).collect();
    let total = probs.iter().sum::<f32>().max(1e-8);
    for p in &mut probs {
        *p /= total;
    }
    probs
}