//! # Text Processing Tools
//!
//! Text-processing tool set: statistics, transformation, splitting, and more.

5use crate::builtin_tools::BuiltinTool;
6use crate::types::{Layer3Result, ToolCategory};
7use async_trait::async_trait;
8
9// ============================================================================
10// Count Lines Tool
11// ============================================================================
12
13/// 行数统计工具
14pub struct CountLinesTool;
15
16#[async_trait]
17impl BuiltinTool for CountLinesTool {
18    fn name(&self) -> &str {
19        "count_lines"
20    }
21
22    fn description(&self) -> &str {
23        "Count lines, words, and characters in text."
24    }
25
26    fn parameters_schema(&self) -> serde_json::Value {
27        serde_json::json!({
28            "type": "object",
29            "properties": {
30                "text": {
31                    "type": "string",
32                    "description": "Text to analyze"
33                }
34            },
35            "required": ["text"]
36        })
37    }
38
39    fn category(&self) -> ToolCategory {
40        ToolCategory::TextProcessing
41    }
42
43    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
44        let text = args["text"]
45            .as_str()
46            .ok_or_else(|| anyhow::anyhow!("Missing text parameter"))?;
47
48        let lines = text.lines().count();
49        let words = text.split_whitespace().count();
50        let chars = text.chars().count();
51        let bytes = text.len();
52
53        Ok(format!(
54            "Lines: {}\nWords: {}\nCharacters: {}\nBytes: {}",
55            lines, words, chars, bytes
56        ))
57    }
58}
59
60// ============================================================================
61// Word Frequency Tool
62// ============================================================================
63
64/// 词频统计工具
65pub struct WordFrequencyTool;
66
67#[async_trait]
68impl BuiltinTool for WordFrequencyTool {
69    fn name(&self) -> &str {
70        "word_frequency"
71    }
72
73    fn description(&self) -> &str {
74        "Count word frequency in text."
75    }
76
77    fn parameters_schema(&self) -> serde_json::Value {
78        serde_json::json!({
79            "type": "object",
80            "properties": {
81                "text": {
82                    "type": "string",
83                    "description": "Text to analyze"
84                },
85                "top": {
86                    "type": "integer",
87                    "description": "Number of top words to return (default: 10)"
88                },
89                "case_sensitive": {
90                    "type": "boolean",
91                    "description": "Case sensitive counting (default: false)"
92                }
93            },
94            "required": ["text"]
95        })
96    }
97
98    fn category(&self) -> ToolCategory {
99        ToolCategory::TextProcessing
100    }
101
102    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
103        let text = args["text"]
104            .as_str()
105            .ok_or_else(|| anyhow::anyhow!("Missing text parameter"))?;
106
107        let top = args["top"].as_u64().unwrap_or(10) as usize;
108        let case_sensitive = args["case_sensitive"].as_bool().unwrap_or(false);
109
110        use std::collections::HashMap;
111        let mut freq: HashMap<String, usize> = HashMap::new();
112
113        for word in text.split_whitespace() {
114            let word = if case_sensitive {
115                word.to_string()
116            } else {
117                word.to_lowercase()
118            };
119            // Remove punctuation
120            let word: String = word.chars().filter(|c| c.is_alphanumeric()).collect();
121            if !word.is_empty() {
122                *freq.entry(word).or_insert(0) += 1;
123            }
124        }
125
126        let mut freq_vec: Vec<_> = freq.into_iter().collect();
127        freq_vec.sort_by_key(|b| std::cmp::Reverse(b.1));
128
129        let result: Vec<String> = freq_vec
130            .iter()
131            .take(top)
132            .map(|(word, count)| format!("{}: {}", word, count))
133            .collect();
134
135        Ok(result.join("\n"))
136    }
137}
138
139// ============================================================================
140// Text Transform Tool
141// ============================================================================
142
143/// 文本转换工具
144pub struct TextTransformTool;
145
146#[async_trait]
147impl BuiltinTool for TextTransformTool {
148    fn name(&self) -> &str {
149        "text_transform"
150    }
151
152    fn description(&self) -> &str {
153        "Transform text case: uppercase, lowercase, title, snake_case, camelCase, etc."
154    }
155
156    fn parameters_schema(&self) -> serde_json::Value {
157        serde_json::json!({
158            "type": "object",
159            "properties": {
160                "text": {
161                    "type": "string",
162                    "description": "Text to transform"
163                },
164                "transform": {
165                    "type": "string",
166                    "enum": ["uppercase", "lowercase", "title", "capitalize", "snake_case", "camelCase", "PascalCase", "kebab-case", "reverse"],
167                    "description": "Transform to apply"
168                }
169            },
170            "required": ["text", "transform"]
171        })
172    }
173
174    fn category(&self) -> ToolCategory {
175        ToolCategory::TextProcessing
176    }
177
178    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
179        let text = args["text"]
180            .as_str()
181            .ok_or_else(|| anyhow::anyhow!("Missing text parameter"))?;
182
183        let transform = args["transform"]
184            .as_str()
185            .ok_or_else(|| anyhow::anyhow!("Missing transform parameter"))?;
186
187        let result = match transform {
188            "uppercase" => text.to_uppercase(),
189            "lowercase" => text.to_lowercase(),
190            "title" => {
191                let mut result = String::new();
192                let mut capitalize_next = true;
193                for c in text.chars() {
194                    if c.is_whitespace() {
195                        capitalize_next = true;
196                        result.push(c);
197                    } else if capitalize_next {
198                        result.push(c.to_uppercase().next().unwrap_or(c));
199                        capitalize_next = false;
200                    } else {
201                        result.push(c);
202                    }
203                }
204                result
205            }
206            "capitalize" => {
207                let mut chars = text.chars();
208                match chars.next() {
209                    Some(c) => c.to_uppercase().collect::<String>() + chars.as_str(),
210                    None => String::new(),
211                }
212            }
213            "snake_case" => {
214                let mut result = String::new();
215                for (i, c) in text.chars().enumerate() {
216                    if c.is_uppercase() {
217                        if i > 0 {
218                            result.push('_');
219                        }
220                        result.push(c.to_lowercase().next().unwrap_or(c));
221                    } else if c == ' ' || c == '-' {
222                        result.push('_');
223                    } else {
224                        result.push(c);
225                    }
226                }
227                result.to_lowercase()
228            }
229            "camelCase" => {
230                let mut result = String::new();
231                let mut capitalize_next = false;
232                for c in text.chars() {
233                    if c == '_' || c == ' ' || c == '-' {
234                        capitalize_next = true;
235                    } else if capitalize_next {
236                        result.push(c.to_uppercase().next().unwrap_or(c));
237                        capitalize_next = false;
238                    } else {
239                        result.push(c);
240                    }
241                }
242                result
243            }
244            "PascalCase" => {
245                let camel = {
246                    let mut result = String::new();
247                    let mut capitalize_next = true;
248                    for c in text.chars() {
249                        if c == '_' || c == ' ' || c == '-' {
250                            capitalize_next = true;
251                        } else if capitalize_next {
252                            result.push(c.to_uppercase().next().unwrap_or(c));
253                            capitalize_next = false;
254                        } else {
255                            result.push(c);
256                        }
257                    }
258                    result
259                };
260                camel
261            }
262            "kebab-case" => {
263                let mut result = String::new();
264                for (i, c) in text.chars().enumerate() {
265                    if c.is_uppercase() {
266                        if i > 0 {
267                            result.push('-');
268                        }
269                        result.push(c.to_lowercase().next().unwrap_or(c));
270                    } else if c == ' ' || c == '_' {
271                        result.push('-');
272                    } else {
273                        result.push(c);
274                    }
275                }
276                result.to_lowercase()
277            }
278            "reverse" => text.chars().rev().collect(),
279            _ => return Err(anyhow::anyhow!("Unknown transform: {}", transform)),
280        };
281
282        Ok(result)
283    }
284}
285
286// ============================================================================
287// Text Split Tool
288// ============================================================================
289
290/// 文本分割工具
291pub struct TextSplitTool;
292
293#[async_trait]
294impl BuiltinTool for TextSplitTool {
295    fn name(&self) -> &str {
296        "text_split"
297    }
298
299    fn description(&self) -> &str {
300        "Split text by delimiter, line count, or character count."
301    }
302
303    fn parameters_schema(&self) -> serde_json::Value {
304        serde_json::json!({
305            "type": "object",
306            "properties": {
307                "text": {
308                    "type": "string",
309                    "description": "Text to split"
310                },
311                "method": {
312                    "type": "string",
313                    "enum": ["delimiter", "lines", "chars"],
314                    "description": "Split method"
315                },
316                "delimiter": {
317                    "type": "string",
318                    "description": "Delimiter for 'delimiter' method (default: whitespace)"
319                },
320                "count": {
321                    "type": "integer",
322                    "description": "Chunk count/size for 'lines' or 'chars' method"
323                }
324            },
325            "required": ["text", "method"]
326        })
327    }
328
329    fn category(&self) -> ToolCategory {
330        ToolCategory::TextProcessing
331    }
332
333    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
334        let text = args["text"]
335            .as_str()
336            .ok_or_else(|| anyhow::anyhow!("Missing text parameter"))?;
337
338        let method = args["method"]
339            .as_str()
340            .ok_or_else(|| anyhow::anyhow!("Missing method parameter"))?;
341
342        let result = match method {
343            "delimiter" => {
344                let delimiter = args["delimiter"].as_str().unwrap_or(" ");
345                let parts: Vec<&str> = text.split(delimiter).collect();
346                parts
347                    .iter()
348                    .enumerate()
349                    .map(|(i, part)| format!("{}: {}", i + 1, part))
350                    .collect::<Vec<_>>()
351                    .join("\n")
352            }
353            "lines" => {
354                let count = args["count"].as_u64().unwrap_or(10) as usize;
355                let lines: Vec<&str> = text.lines().collect();
356                let chunk_size = lines.len().div_ceil(count);
357                lines
358                    .chunks(chunk_size.max(1))
359                    .enumerate()
360                    .map(|(i, chunk)| format!("--- Chunk {} ---\n{}", i + 1, chunk.join("\n")))
361                    .collect::<Vec<_>>()
362                    .join("\n\n")
363            }
364            "chars" => {
365                let count = args["count"].as_u64().unwrap_or(100) as usize;
366                text.as_bytes()
367                    .chunks(count.max(1))
368                    .enumerate()
369                    .map(|(i, chunk)| {
370                        format!(
371                            "--- Chunk {} ({} chars) ---\n{}",
372                            i + 1,
373                            chunk.len(),
374                            String::from_utf8_lossy(chunk)
375                        )
376                    })
377                    .collect::<Vec<_>>()
378                    .join("\n\n")
379            }
380            _ => return Err(anyhow::anyhow!("Unknown method: {}", method)),
381        };
382
383        Ok(result)
384    }
385}
386
387// ============================================================================
388// Regex Match Tool
389// ============================================================================
390
391/// 正则匹配工具
392pub struct RegexMatchTool;
393
394#[async_trait]
395impl BuiltinTool for RegexMatchTool {
396    fn name(&self) -> &str {
397        "regex_match"
398    }
399
400    fn description(&self) -> &str {
401        "Match a regex pattern against text and return matches."
402    }
403
404    fn parameters_schema(&self) -> serde_json::Value {
405        serde_json::json!({
406            "type": "object",
407            "properties": {
408                "text": {
409                    "type": "string",
410                    "description": "Text to search"
411                },
412                "pattern": {
413                    "type": "string",
414                    "description": "Regex pattern"
415                },
416                "group": {
417                    "type": "integer",
418                    "description": "Capture group to return (default: 0 for full match)"
419                }
420            },
421            "required": ["text", "pattern"]
422        })
423    }
424
425    fn category(&self) -> ToolCategory {
426        ToolCategory::TextProcessing
427    }
428
429    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
430        let text = args["text"]
431            .as_str()
432            .ok_or_else(|| anyhow::anyhow!("Missing text parameter"))?;
433
434        let pattern = args["pattern"]
435            .as_str()
436            .ok_or_else(|| anyhow::anyhow!("Missing pattern parameter"))?;
437
438        let group = args["group"].as_u64().unwrap_or(0) as usize;
439
440        let re = regex::Regex::new(pattern).map_err(|e| anyhow::anyhow!("Invalid regex: {}", e))?;
441
442        let matches: Vec<String> = re
443            .captures_iter(text)
444            .filter_map(|cap| cap.get(group).map(|m| m.as_str().to_string()))
445            .collect();
446
447        if matches.is_empty() {
448            Ok("No matches found".to_string())
449        } else {
450            Ok(format!(
451                "Found {} matches:\n{}",
452                matches.len(),
453                matches.join("\n")
454            ))
455        }
456    }
457}
458
459// ============================================================================
460// Text Diff Tool
461// ============================================================================
462
463/// 文本差异工具
464pub struct TextDiffTool;
465
466#[async_trait]
467impl BuiltinTool for TextDiffTool {
468    fn name(&self) -> &str {
469        "text_diff"
470    }
471
472    fn description(&self) -> &str {
473        "Compare two texts and show differences."
474    }
475
476    fn parameters_schema(&self) -> serde_json::Value {
477        serde_json::json!({
478            "type": "object",
479            "properties": {
480                "text1": {
481                    "type": "string",
482                    "description": "First text"
483                },
484                "text2": {
485                    "type": "string",
486                    "description": "Second text"
487                }
488            },
489            "required": ["text1", "text2"]
490        })
491    }
492
493    fn category(&self) -> ToolCategory {
494        ToolCategory::TextProcessing
495    }
496
497    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
498        let text1 = args["text1"]
499            .as_str()
500            .ok_or_else(|| anyhow::anyhow!("Missing text1 parameter"))?;
501
502        let text2 = args["text2"]
503            .as_str()
504            .ok_or_else(|| anyhow::anyhow!("Missing text2 parameter"))?;
505
506        let lines1: Vec<&str> = text1.lines().collect();
507        let lines2: Vec<&str> = text2.lines().collect();
508
509        // Simple line-by-line diff
510        let mut result = Vec::new();
511        let max_lines = lines1.len().max(lines2.len());
512
513        for i in 0..max_lines {
514            let line1 = lines1.get(i).copied().unwrap_or("");
515            let line2 = lines2.get(i).copied().unwrap_or("");
516
517            if line1 != line2 {
518                if i < lines1.len() {
519                    result.push(format!("- {}: {}", i + 1, line1));
520                }
521                if i < lines2.len() {
522                    result.push(format!("+ {}: {}", i + 1, line2));
523                }
524            }
525        }
526
527        if result.is_empty() {
528            Ok("No differences found".to_string())
529        } else {
530            Ok(format!("Differences:\n{}", result.join("\n")))
531        }
532    }
533}
534
535// ============================================================================
536// Sort Lines Tool
537// ============================================================================
538
539/// 行排序工具
540pub struct SortLinesTool;
541
542#[async_trait]
543impl BuiltinTool for SortLinesTool {
544    fn name(&self) -> &str {
545        "sort_lines"
546    }
547
548    fn description(&self) -> &str {
549        "Sort lines of text alphabetically or numerically."
550    }
551
552    fn parameters_schema(&self) -> serde_json::Value {
553        serde_json::json!({
554            "type": "object",
555            "properties": {
556                "text": {
557                    "type": "string",
558                    "description": "Text to sort"
559                },
560                "reverse": {
561                    "type": "boolean",
562                    "description": "Sort in reverse order (default: false)"
563                },
564                "numeric": {
565                    "type": "boolean",
566                    "description": "Sort numerically (default: false)"
567                },
568                "unique": {
569                    "type": "boolean",
570                    "description": "Remove duplicate lines (default: false)"
571                }
572            },
573            "required": ["text"]
574        })
575    }
576
577    fn category(&self) -> ToolCategory {
578        ToolCategory::TextProcessing
579    }
580
581    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
582        let text = args["text"]
583            .as_str()
584            .ok_or_else(|| anyhow::anyhow!("Missing text parameter"))?;
585
586        let reverse = args["reverse"].as_bool().unwrap_or(false);
587        let numeric = args["numeric"].as_bool().unwrap_or(false);
588        let unique = args["unique"].as_bool().unwrap_or(false);
589
590        let mut lines: Vec<&str> = text.lines().collect();
591
592        if numeric {
593            lines.sort_by(|a, b| {
594                let a_num = a.trim().parse::<f64>().unwrap_or(f64::NEG_INFINITY);
595                let b_num = b.trim().parse::<f64>().unwrap_or(f64::NEG_INFINITY);
596                a_num
597                    .partial_cmp(&b_num)
598                    .unwrap_or(std::cmp::Ordering::Equal)
599            });
600        } else {
601            lines.sort();
602        }
603
604        if reverse {
605            lines.reverse();
606        }
607
608        if unique {
609            lines.dedup();
610        }
611
612        Ok(lines.join("\n"))
613    }
614}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Executes `tool` with `args`, panicking if the tool reports an error.
    async fn run<T: BuiltinTool>(tool: T, args: serde_json::Value) -> String {
        tool.execute(args).await.unwrap()
    }

    #[tokio::test]
    async fn test_count_lines() {
        let out = run(CountLinesTool, json!({"text": "Hello\nWorld\nTest"})).await;
        for expected in ["Lines: 3", "Words: 3"] {
            assert!(out.contains(expected), "{expected:?} not found in {out:?}");
        }
    }

    #[tokio::test]
    async fn test_word_frequency() {
        let out = run(
            WordFrequencyTool,
            json!({"text": "hello world hello test hello", "top": 2}),
        )
        .await;
        assert!(out.contains("hello: 3"), "unexpected output: {out:?}");
    }

    #[tokio::test]
    async fn test_text_transform_uppercase() {
        let args = json!({"text": "hello", "transform": "uppercase"});
        assert_eq!(run(TextTransformTool, args).await, "HELLO");
    }

    #[tokio::test]
    async fn test_text_transform_snake_case() {
        let args = json!({"text": "HelloWorld", "transform": "snake_case"});
        assert_eq!(run(TextTransformTool, args).await, "hello_world");
    }

    #[tokio::test]
    async fn test_regex_match() {
        let out = run(
            RegexMatchTool,
            json!({"text": "hello 123 world 456", "pattern": r"\d+"}),
        )
        .await;
        for number in ["123", "456"] {
            assert!(out.contains(number), "{number:?} not found in {out:?}");
        }
    }

    #[tokio::test]
    async fn test_sort_lines() {
        let out = run(
            SortLinesTool,
            json!({"text": "zebra\napple\nbanana", "unique": false}),
        )
        .await;
        assert!(out.starts_with("apple"), "unexpected first line in {out:?}");
        assert!(out.ends_with("zebra"), "unexpected last line in {out:?}");
    }
}