Skip to main content

spikard_cli/codegen/formatters/
python.rs

1//! Python-specific code formatter
2//!
3//! Implements the `Formatter` trait for Python code generation, ensuring output
4//! adheres to PEP 8 style guidelines, integrates with standard tools (ruff, mypy),
5//! and follows spikard's async-friendly patterns.
6//!
7//! # Features
8//!
9//! - **Headers**: Shebang, ruff directives, module docstrings
10//! - **Imports**: Grouped and sorted (future, stdlib, third-party, local)
//! - **Docstrings**: Triple-quoted string literals, with embedded `"""` sequences escaped
12//! - **Spacing**: PEP 8 compliant (2 blank lines between top-level definitions)
13
14use super::{Formatter, HeaderMetadata, Import, Section};
15use std::collections::BTreeMap;
16
/// Formatter that renders generated Python modules.
///
/// `PythonFormatter` is a stateless unit struct. Its [`Formatter`] methods turn
/// pieces of generator output into Python source text:
///
/// - a header with a `python3` shebang, a `# ruff: noqa` directive, an optional
///   "DO NOT EDIT" banner, and a module docstring;
/// - import statements grouped as `__future__`, stdlib, third-party, and local;
/// - triple-quoted docstring literals;
/// - a final module assembled from header, imports, and body sections.
///
/// # Example
///
/// ```
/// use spikard_cli::codegen::formatters::{Formatter, HeaderMetadata, PythonFormatter};
///
/// let header = PythonFormatter::default().format_header(&HeaderMetadata {
///     auto_generated: true,
///     schema_file: Some("schema.graphql".to_string()),
///     generator_version: Some("0.6.2".to_string()),
/// });
///
/// assert!(header.starts_with("#!/usr/bin/env python3\n"));
/// assert!(header.contains("# ruff: noqa"));
/// ```
#[derive(Debug, Clone)]
pub struct PythonFormatter;

impl PythonFormatter {
    /// Construct a `PythonFormatter`.
    ///
    /// The type holds no state, so this is usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        PythonFormatter
    }
}

impl Default for PythonFormatter {
    /// Equivalent to [`PythonFormatter::new`].
    fn default() -> Self {
        PythonFormatter::new()
    }
}
57
58impl Formatter for PythonFormatter {
59    fn format_header(&self, metadata: &HeaderMetadata) -> String {
60        let mut header = String::new();
61
62        // Shebang for executable Python scripts
63        header.push_str("#!/usr/bin/env python3\n");
64
65        // Ruff directives to suppress common auto-gen warnings
66        // EXE001: shebang, I001: unsorted imports (we handle sorting)
67        header.push_str("# ruff: noqa: EXE001, I001\n");
68
69        if metadata.auto_generated {
70            header.push_str("# DO NOT EDIT - Auto-generated by Spikard CLI\n");
71            if let Some(schema_file) = &metadata.schema_file {
72                header.push_str(&format!("# Schema: {schema_file}\n"));
73            }
74            if let Some(version) = &metadata.generator_version {
75                header.push_str(&format!("# Generator: Spikard {version}\n"));
76            }
77        }
78
79        // Standard module docstring
80        header.push_str("\n\"\"\"GraphQL types generated from schema.\"\"\"\n");
81
82        header
83    }
84
85    fn format_imports(&self, imports: &[Import]) -> String {
86        if imports.is_empty() {
87            return String::new();
88        }
89
90        // Categorize imports
91        let mut future_imports = Vec::new();
92        let mut stdlib_imports = BTreeMap::new();
93        let mut third_party_imports = BTreeMap::new();
94        let mut local_imports = BTreeMap::new();
95
96        // Known Python standard library modules (common ones)
97        let stdlib_modules = [
98            "abc",
99            "argparse",
100            "array",
101            "asyncio",
102            "bisect",
103            "builtins",
104            "calendar",
105            "cmath",
106            "cmd",
107            "code",
108            "codeop",
109            "collections",
110            "colorsys",
111            "compileall",
112            "concurrent",
113            "configparser",
114            "contextlib",
115            "contextvars",
116            "copy",
117            "copyreg",
118            "cprofile",
119            "csv",
120            "ctypes",
121            "curses",
122            "dataclasses",
123            "datetime",
124            "dbm",
125            "decimal",
126            "difflib",
127            "dis",
128            "doctest",
129            "email",
130            "encodings",
131            "enum",
132            "errno",
133            "faulthandler",
134            "fcntl",
135            "filecmp",
136            "fileinput",
137            "fnmatch",
138            "fractions",
139            "ftplib",
140            "functools",
141            "gc",
142            "getopt",
143            "getpass",
144            "gettext",
145            "glob",
146            "grp",
147            "gzip",
148            "hashlib",
149            "heapq",
150            "hmac",
151            "html",
152            "http",
153            "idlelib",
154            "imaplib",
155            "imghdr",
156            "imp",
157            "importlib",
158            "inspect",
159            "io",
160            "ipaddress",
161            "itertools",
162            "json",
163            "keyword",
164            "lib2to3",
165            "linecache",
166            "locale",
167            "logging",
168            "lzma",
169            "mailbox",
170            "mailcap",
171            "marshal",
172            "math",
173            "mimetypes",
174            "mmap",
175            "modulefinder",
176            "msilib",
177            "msvcrt",
178            "multiprocessing",
179            "netrc",
180            "nis",
181            "nntplib",
182            "numbers",
183            "operator",
184            "optparse",
185            "os",
186            "ossaudiodev",
187            "parser",
188            "pathlib",
189            "pdb",
190            "pickle",
191            "pickletools",
192            "pipes",
193            "pkgutil",
194            "platform",
195            "plistlib",
196            "poplib",
197            "posix",
198            "posixpath",
199            "pprint",
200            "profile",
201            "pstats",
202            "pty",
203            "pwd",
204            "py_compile",
205            "pyclbr",
206            "pydoc",
207            "queue",
208            "quopri",
209            "random",
210            "re",
211            "readline",
212            "reprlib",
213            "resource",
214            "rlcompleter",
215            "runpy",
216            "sched",
217            "secrets",
218            "select",
219            "selectors",
220            "shelve",
221            "shlex",
222            "shutil",
223            "signal",
224            "site",
225            "smtpd",
226            "smtplib",
227            "sndhdr",
228            "socket",
229            "socketserver",
230            "spwd",
231            "sqlite3",
232            "ssl",
233            "stat",
234            "statistics",
235            "string",
236            "stringprep",
237            "struct",
238            "subprocess",
239            "sunau",
240            "symbol",
241            "symtable",
242            "sys",
243            "sysconfig",
244            "syslog",
245            "tabnanny",
246            "tarfile",
247            "telnetlib",
248            "tempfile",
249            "termios",
250            "test",
251            "textwrap",
252            "threading",
253            "time",
254            "timeit",
255            "tkinter",
256            "token",
257            "tokenize",
258            "trace",
259            "traceback",
260            "tracemalloc",
261            "tty",
262            "turtle",
263            "types",
264            "typing",
265            "typing_extensions",
266            "unicodedata",
267            "unittest",
268            "urllib",
269            "uu",
270            "uuid",
271            "venv",
272            "warnings",
273            "wave",
274            "weakref",
275            "webbrowser",
276            "winreg",
277            "winsound",
278            "wsgiref",
279            "xdrlib",
280            "xml",
281            "xmlrpc",
282            "zipapp",
283            "zipfile",
284            "zipimport",
285            "zlib",
286        ];
287
288        for import in imports {
289            let module_name = import.module.split('.').next().unwrap_or(&import.module);
290
291            if module_name == "__future__" {
292                future_imports.push(import.clone());
293            } else if stdlib_modules.contains(&module_name) {
294                stdlib_imports
295                    .entry(import.module.clone())
296                    .or_insert_with(Vec::new)
297                    .push(import.clone());
298            } else if module_name.starts_with('.') {
299                local_imports
300                    .entry(import.module.clone())
301                    .or_insert_with(Vec::new)
302                    .push(import.clone());
303            } else {
304                third_party_imports
305                    .entry(import.module.clone())
306                    .or_insert_with(Vec::new)
307                    .push(import.clone());
308            }
309        }
310
311        let mut output = String::new();
312
313        // Format future imports (always first)
314        if !future_imports.is_empty() {
315            for import in &future_imports {
316                output.push_str(&format_python_import(import));
317                output.push('\n');
318            }
319            output.push('\n');
320        }
321
322        // Format stdlib imports
323        if !stdlib_imports.is_empty() {
324            for imports_vec in stdlib_imports.values() {
325                for import in imports_vec {
326                    output.push_str(&format_python_import(import));
327                    output.push('\n');
328                }
329            }
330            output.push('\n');
331        }
332
333        // Format third-party imports
334        if !third_party_imports.is_empty() {
335            for imports_vec in third_party_imports.values() {
336                for import in imports_vec {
337                    output.push_str(&format_python_import(import));
338                    output.push('\n');
339                }
340            }
341            output.push('\n');
342        }
343
344        // Format local imports
345        if !local_imports.is_empty() {
346            for imports_vec in local_imports.values() {
347                for import in imports_vec {
348                    output.push_str(&format_python_import(import));
349                    output.push('\n');
350                }
351            }
352        }
353
354        // Remove trailing blank lines
355        output.trim_end().to_string()
356    }
357
358    fn format_docstring(&self, content: &str) -> String {
359        // Escape triple quotes in the content to avoid breaking the docstring
360        let escaped = content.replace("\"\"\"", r#"\"\"\""#);
361
362        // Format as triple-quoted string
363        format!("\"\"\"{escaped}\"\"\"")
364    }
365
366    fn merge_sections(&self, sections: &[Section]) -> String {
367        let mut header = String::new();
368        let mut imports = String::new();
369        let mut body = String::new();
370
371        // Parse sections into their components
372        for section in sections {
373            match section {
374                Section::Header(content) => {
375                    if header.is_empty() {
376                        header = content.clone();
377                    }
378                    // Skip duplicate headers
379                }
380                Section::Imports(content) => {
381                    if imports.is_empty() {
382                        imports = content.clone();
383                    }
384                }
385                Section::Body(content) => {
386                    if body.is_empty() {
387                        body = content.clone();
388                    }
389                }
390            }
391        }
392
393        let mut output = String::new();
394
395        // Add header (trim trailing whitespace)
396        if !header.is_empty() {
397            output.push_str(header.trim_end());
398            output.push_str("\n\n");
399        }
400
401        // Add imports (trim trailing whitespace)
402        if !imports.is_empty() {
403            output.push_str(imports.trim_end());
404            output.push_str("\n\n");
405        }
406
407        // Add body (trim trailing whitespace)
408        if !body.is_empty() {
409            output.push_str(body.trim_end());
410            output.push('\n');
411        }
412
413        // Ensure single trailing newline
414        output.trim_end().to_string() + "\n"
415    }
416}
417
418/// Format a single Python import statement
419fn format_python_import(import: &Import) -> String {
420    if import.items.is_empty() {
421        // Simple module import: import module
422        format!("import {}", import.module)
423    } else {
424        // Specific items: from module import item1, item2
425        let items = import.items.join(", ");
426        format!("from {} import {}", import.module, items)
427    }
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the formatter under test (it carries no state).
    fn formatter() -> PythonFormatter {
        PythonFormatter::new()
    }

    /// Build header metadata from borrowed strings.
    fn metadata(
        auto_generated: bool,
        schema_file: Option<&str>,
        generator_version: Option<&str>,
    ) -> HeaderMetadata {
        HeaderMetadata {
            auto_generated,
            schema_file: schema_file.map(str::to_string),
            generator_version: generator_version.map(str::to_string),
        }
    }

    #[test]
    fn test_format_header_with_metadata() {
        let header =
            formatter().format_header(&metadata(true, Some("schema.graphql"), Some("0.6.2")));

        for fragment in [
            "#!/usr/bin/env python3",
            "# ruff: noqa: EXE001, I001",
            "# DO NOT EDIT - Auto-generated by Spikard CLI",
            "# Schema: schema.graphql",
            "# Generator: Spikard 0.6.2",
            "\"\"\"GraphQL types generated from schema.\"\"\"",
        ] {
            assert!(header.contains(fragment), "header lacks {fragment:?}:\n{header}");
        }
    }

    #[test]
    fn test_format_header_without_metadata() {
        let header = formatter().format_header(&metadata(false, None, None));
        assert!(header.contains("#!/usr/bin/env python3"));
        assert!(!header.contains("# DO NOT EDIT"));
    }

    #[test]
    fn test_format_imports_empty() {
        let none: [Import; 0] = [];
        assert!(formatter().format_imports(&none).is_empty());
    }

    #[test]
    fn test_format_imports_grouped_and_sorted() {
        let rendered = formatter().format_imports(&[
            Import::with_items("typing", vec!["List", "Dict"]),
            Import::with_items("__future__", vec!["annotations"]),
            Import::new("msgspec"),
            Import::new("graphql"),
        ]);
        let lines: Vec<&str> = rendered.lines().collect();

        // `__future__` leads; stdlib (typing) and third-party (graphql, msgspec) follow.
        assert_eq!(lines.first().copied(), Some("from __future__ import annotations"));
        for expected in ["from typing import List, Dict", "import graphql", "import msgspec"] {
            assert!(lines.contains(&expected), "missing {expected:?} in {lines:?}");
        }
    }

    #[test]
    fn test_format_imports_simple_module() {
        let rendered = formatter().format_imports(&[Import::new("asyncio")]);
        assert_eq!(rendered.trim(), "import asyncio");
    }

    #[test]
    fn test_format_imports_with_items() {
        let rendered =
            formatter().format_imports(&[Import::with_items("typing", vec!["Optional", "Union"])]);
        assert_eq!(rendered.trim(), "from typing import Optional, Union");
    }

    #[test]
    fn test_format_docstring() {
        assert_eq!(
            formatter().format_docstring("This is a test docstring"),
            "\"\"\"This is a test docstring\"\"\""
        );
    }

    #[test]
    fn test_format_docstring_with_quotes() {
        let rendered = formatter().format_docstring(r#"This says "hello""""#);
        assert!(rendered.contains(r#"\"\"\""#));
    }

    #[test]
    fn test_merge_sections_in_order() {
        let merged = formatter().merge_sections(&[
            Section::Header("#!/usr/bin/env python3\n# Auto-gen".to_string()),
            Section::Imports("from typing import List".to_string()),
            Section::Body("class MyType:\n    pass".to_string()),
        ]);

        // Header comes first; imports and body both appear afterwards.
        let first_line = merged.lines().next().expect("merged output should not be empty");
        assert!(first_line.contains("#!/usr/bin/env python3"));
        assert!(merged.lines().any(|line| line.contains("from typing import List")));
        assert!(merged.lines().any(|line| line.contains("class MyType")));
    }

    #[test]
    fn test_merge_sections_duplicate_headers() {
        let shebang = "#!/usr/bin/env python3";
        let merged = formatter().merge_sections(&[
            Section::Header(shebang.to_string()),
            Section::Header(shebang.to_string()),
            Section::Body("class MyType:\n    pass".to_string()),
        ]);
        assert_eq!(merged.matches(shebang).count(), 1, "Should not duplicate headers");
    }

    #[test]
    fn test_merge_sections_trailing_newline() {
        let merged =
            formatter().merge_sections(&[Section::Body("class MyType:\n    pass".to_string())]);
        assert!(merged.ends_with('\n'));
        assert!(!merged.ends_with("\n\n"));
    }
}