1use super::{Formatter, HeaderMetadata, Import, Section};
15use std::collections::BTreeMap;
16
/// Stateless [`Formatter`] that renders generated GraphQL types as Python source.
///
/// The type carries no configuration, so [`PythonFormatter::new`] and
/// [`Default::default`] produce interchangeable values.
#[derive(Debug, Clone, Default)]
pub struct PythonFormatter;

impl PythonFormatter {
    /// Creates a formatter; callable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        PythonFormatter
    }
}
57
58impl Formatter for PythonFormatter {
59 fn format_header(&self, metadata: &HeaderMetadata) -> String {
60 let mut header = String::new();
61
62 header.push_str("#!/usr/bin/env python3\n");
64
65 header.push_str("# ruff: noqa: EXE001, I001\n");
68
69 if metadata.auto_generated {
70 header.push_str("# DO NOT EDIT - Auto-generated by Spikard CLI\n");
71 if let Some(schema_file) = &metadata.schema_file {
72 header.push_str(&format!("# Schema: {schema_file}\n"));
73 }
74 if let Some(version) = &metadata.generator_version {
75 header.push_str(&format!("# Generator: Spikard {version}\n"));
76 }
77 }
78
79 header.push_str("\n\"\"\"GraphQL types generated from schema.\"\"\"\n");
81
82 header
83 }
84
85 fn format_imports(&self, imports: &[Import]) -> String {
86 if imports.is_empty() {
87 return String::new();
88 }
89
90 let mut future_imports = Vec::new();
92 let mut stdlib_imports = BTreeMap::new();
93 let mut third_party_imports = BTreeMap::new();
94 let mut local_imports = BTreeMap::new();
95
96 let stdlib_modules = [
98 "abc",
99 "argparse",
100 "array",
101 "asyncio",
102 "bisect",
103 "builtins",
104 "calendar",
105 "cmath",
106 "cmd",
107 "code",
108 "codeop",
109 "collections",
110 "colorsys",
111 "compileall",
112 "concurrent",
113 "configparser",
114 "contextlib",
115 "contextvars",
116 "copy",
117 "copyreg",
118 "cprofile",
119 "csv",
120 "ctypes",
121 "curses",
122 "dataclasses",
123 "datetime",
124 "dbm",
125 "decimal",
126 "difflib",
127 "dis",
128 "doctest",
129 "email",
130 "encodings",
131 "enum",
132 "errno",
133 "faulthandler",
134 "fcntl",
135 "filecmp",
136 "fileinput",
137 "fnmatch",
138 "fractions",
139 "ftplib",
140 "functools",
141 "gc",
142 "getopt",
143 "getpass",
144 "gettext",
145 "glob",
146 "grp",
147 "gzip",
148 "hashlib",
149 "heapq",
150 "hmac",
151 "html",
152 "http",
153 "idlelib",
154 "imaplib",
155 "imghdr",
156 "imp",
157 "importlib",
158 "inspect",
159 "io",
160 "ipaddress",
161 "itertools",
162 "json",
163 "keyword",
164 "lib2to3",
165 "linecache",
166 "locale",
167 "logging",
168 "lzma",
169 "mailbox",
170 "mailcap",
171 "marshal",
172 "math",
173 "mimetypes",
174 "mmap",
175 "modulefinder",
176 "msilib",
177 "msvcrt",
178 "multiprocessing",
179 "netrc",
180 "nis",
181 "nntplib",
182 "numbers",
183 "operator",
184 "optparse",
185 "os",
186 "ossaudiodev",
187 "parser",
188 "pathlib",
189 "pdb",
190 "pickle",
191 "pickletools",
192 "pipes",
193 "pkgutil",
194 "platform",
195 "plistlib",
196 "poplib",
197 "posix",
198 "posixpath",
199 "pprint",
200 "profile",
201 "pstats",
202 "pty",
203 "pwd",
204 "py_compile",
205 "pyclbr",
206 "pydoc",
207 "queue",
208 "quopri",
209 "random",
210 "re",
211 "readline",
212 "reprlib",
213 "resource",
214 "rlcompleter",
215 "runpy",
216 "sched",
217 "secrets",
218 "select",
219 "selectors",
220 "shelve",
221 "shlex",
222 "shutil",
223 "signal",
224 "site",
225 "smtpd",
226 "smtplib",
227 "sndhdr",
228 "socket",
229 "socketserver",
230 "spwd",
231 "sqlite3",
232 "ssl",
233 "stat",
234 "statistics",
235 "string",
236 "stringprep",
237 "struct",
238 "subprocess",
239 "sunau",
240 "symbol",
241 "symtable",
242 "sys",
243 "sysconfig",
244 "syslog",
245 "tabnanny",
246 "tarfile",
247 "telnetlib",
248 "tempfile",
249 "termios",
250 "test",
251 "textwrap",
252 "threading",
253 "time",
254 "timeit",
255 "tkinter",
256 "token",
257 "tokenize",
258 "trace",
259 "traceback",
260 "tracemalloc",
261 "tty",
262 "turtle",
263 "types",
264 "typing",
265 "typing_extensions",
266 "unicodedata",
267 "unittest",
268 "urllib",
269 "uu",
270 "uuid",
271 "venv",
272 "warnings",
273 "wave",
274 "weakref",
275 "webbrowser",
276 "winreg",
277 "winsound",
278 "wsgiref",
279 "xdrlib",
280 "xml",
281 "xmlrpc",
282 "zipapp",
283 "zipfile",
284 "zipimport",
285 "zlib",
286 ];
287
288 for import in imports {
289 let module_name = import.module.split('.').next().unwrap_or(&import.module);
290
291 if module_name == "__future__" {
292 future_imports.push(import.clone());
293 } else if stdlib_modules.contains(&module_name) {
294 stdlib_imports
295 .entry(import.module.clone())
296 .or_insert_with(Vec::new)
297 .push(import.clone());
298 } else if module_name.starts_with('.') {
299 local_imports
300 .entry(import.module.clone())
301 .or_insert_with(Vec::new)
302 .push(import.clone());
303 } else {
304 third_party_imports
305 .entry(import.module.clone())
306 .or_insert_with(Vec::new)
307 .push(import.clone());
308 }
309 }
310
311 let mut output = String::new();
312
313 if !future_imports.is_empty() {
315 for import in &future_imports {
316 output.push_str(&format_python_import(import));
317 output.push('\n');
318 }
319 output.push('\n');
320 }
321
322 if !stdlib_imports.is_empty() {
324 for imports_vec in stdlib_imports.values() {
325 for import in imports_vec {
326 output.push_str(&format_python_import(import));
327 output.push('\n');
328 }
329 }
330 output.push('\n');
331 }
332
333 if !third_party_imports.is_empty() {
335 for imports_vec in third_party_imports.values() {
336 for import in imports_vec {
337 output.push_str(&format_python_import(import));
338 output.push('\n');
339 }
340 }
341 output.push('\n');
342 }
343
344 if !local_imports.is_empty() {
346 for imports_vec in local_imports.values() {
347 for import in imports_vec {
348 output.push_str(&format_python_import(import));
349 output.push('\n');
350 }
351 }
352 }
353
354 output.trim_end().to_string()
356 }
357
358 fn format_docstring(&self, content: &str) -> String {
359 let escaped = content.replace("\"\"\"", r#"\"\"\""#);
361
362 format!("\"\"\"{escaped}\"\"\"")
364 }
365
366 fn merge_sections(&self, sections: &[Section]) -> String {
367 let mut header = String::new();
368 let mut imports = String::new();
369 let mut body = String::new();
370
371 for section in sections {
373 match section {
374 Section::Header(content) => {
375 if header.is_empty() {
376 header = content.clone();
377 }
378 }
380 Section::Imports(content) => {
381 if imports.is_empty() {
382 imports = content.clone();
383 }
384 }
385 Section::Body(content) => {
386 if body.is_empty() {
387 body = content.clone();
388 }
389 }
390 }
391 }
392
393 let mut output = String::new();
394
395 if !header.is_empty() {
397 output.push_str(header.trim_end());
398 output.push_str("\n\n");
399 }
400
401 if !imports.is_empty() {
403 output.push_str(imports.trim_end());
404 output.push_str("\n\n");
405 }
406
407 if !body.is_empty() {
409 output.push_str(body.trim_end());
410 output.push('\n');
411 }
412
413 output.trim_end().to_string() + "\n"
415 }
416}
417
418fn format_python_import(import: &Import) -> String {
420 if import.items.is_empty() {
421 format!("import {}", import.module)
423 } else {
424 let items = import.items.join(", ");
426 format!("from {} import {}", import.module, items)
427 }
428}
429
// Unit tests for `PythonFormatter`: header contents, import grouping,
// docstring quoting and section merging.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_format_header_with_metadata() {
        let formatter = PythonFormatter::new();
        let metadata = HeaderMetadata {
            auto_generated: true,
            schema_file: Some("schema.graphql".to_string()),
            generator_version: Some("0.6.2".to_string()),
        };

        let header = formatter.format_header(&metadata);
        assert!(header.contains("#!/usr/bin/env python3"));
        assert!(header.contains("# ruff: noqa: EXE001, I001"));
        assert!(header.contains("# DO NOT EDIT - Auto-generated by Spikard CLI"));
        assert!(header.contains("# Schema: schema.graphql"));
        assert!(header.contains("# Generator: Spikard 0.6.2"));
        assert!(header.contains("\"\"\"GraphQL types generated from schema.\"\"\""));
    }

    #[test]
    fn test_format_header_without_metadata() {
        let formatter = PythonFormatter::new();
        let metadata = HeaderMetadata {
            auto_generated: false,
            schema_file: None,
            generator_version: None,
        };

        let header = formatter.format_header(&metadata);
        assert!(header.contains("#!/usr/bin/env python3"));
        assert!(!header.contains("# DO NOT EDIT"));
    }

    #[test]
    fn test_format_imports_empty() {
        let formatter = PythonFormatter::new();
        // Explicit element type: an untyped `[]` relies on inference through
        // the unsizing coercion at the call site.
        let imports: [Import; 0] = [];
        let output = formatter.format_imports(&imports);
        assert!(output.is_empty());
    }

    #[test]
    fn test_format_imports_grouped_and_sorted() {
        let formatter = PythonFormatter::new();
        let imports = vec![
            Import::with_items("typing", vec!["List", "Dict"]),
            Import::with_items("__future__", vec!["annotations"]),
            Import::new("msgspec"),
            Import::new("graphql"),
        ];

        let output = formatter.format_imports(&imports);
        let lines: Vec<&str> = output.lines().collect();

        // `__future__` imports must come first for Python to accept them.
        assert_eq!(lines[0], "from __future__ import annotations");
        assert!(lines.contains(&"from typing import List, Dict"));
        assert!(lines.contains(&"import graphql"));
        assert!(lines.contains(&"import msgspec"));
    }

    #[test]
    fn test_format_imports_simple_module() {
        let formatter = PythonFormatter::new();
        let imports = vec![Import::new("asyncio")];

        let output = formatter.format_imports(&imports);
        assert_eq!(output.trim(), "import asyncio");
    }

    #[test]
    fn test_format_imports_with_items() {
        let formatter = PythonFormatter::new();
        let imports = vec![Import::with_items("typing", vec!["Optional", "Union"])];

        let output = formatter.format_imports(&imports);
        assert_eq!(output.trim(), "from typing import Optional, Union");
    }

    #[test]
    fn test_format_docstring() {
        let formatter = PythonFormatter::new();
        let content = "This is a test docstring";
        let output = formatter.format_docstring(content);
        assert_eq!(output, "\"\"\"This is a test docstring\"\"\"");
    }

    #[test]
    fn test_format_docstring_with_quotes() {
        let formatter = PythonFormatter::new();
        let content = r#"This says "hello""""#;
        let output = formatter.format_docstring(content);
        assert!(output.contains(r#"\"\"\""#));
    }

    #[test]
    fn test_merge_sections_in_order() {
        let formatter = PythonFormatter::new();
        let sections = vec![
            Section::Header("#!/usr/bin/env python3\n# Auto-gen".to_string()),
            Section::Imports("from typing import List".to_string()),
            Section::Body("class MyType:\n    pass".to_string()),
        ];

        let output = formatter.merge_sections(&sections);
        let lines: Vec<&str> = output.lines().collect();

        assert!(lines[0].contains("#!/usr/bin/env python3"));
        assert!(lines.iter().any(|l| l.contains("from typing import List")));
        assert!(lines.iter().any(|l| l.contains("class MyType")));
    }

    #[test]
    fn test_merge_sections_duplicate_headers() {
        let formatter = PythonFormatter::new();
        let sections = vec![
            Section::Header("#!/usr/bin/env python3".to_string()),
            Section::Header("#!/usr/bin/env python3".to_string()),
            Section::Body("class MyType:\n    pass".to_string()),
        ];

        let output = formatter.merge_sections(&sections);
        let header_count = output.matches("#!/usr/bin/env python3").count();
        assert_eq!(header_count, 1, "Should not duplicate headers");
    }

    #[test]
    fn test_merge_sections_trailing_newline() {
        let formatter = PythonFormatter::new();
        let sections = vec![Section::Body("class MyType:\n    pass".to_string())];

        let output = formatter.merge_sections(&sections);
        assert!(output.ends_with('\n'));
        assert!(!output.ends_with("\n\n"));
    }
}