1#![allow(clippy::too_many_arguments)]
2
3use lsp_types::{Diagnostic, DiagnosticSeverity, Position, Range};
4use rustpython_parser::Parse;
5use rustpython_parser::ast::{
6 self, Arguments, Constant, Expr, ExprBinOp, ExprCall, ExprCompare, ExprList, ExprTuple,
7 Identifier, Operator, Stmt, Suite,
8};
9use rustpython_parser::text_size::{TextRange, TextSize};
10use std::borrow::Cow;
11use std::collections::HashMap;
12use std::fs;
13use std::path::{Path, PathBuf};
14mod infer;
15pub mod op_groups;
16use crate::infer::{
17 ShapeOrExpr, Transpose, infer_broadcastable_poswise, infer_conv, infer_creation_size,
18 infer_flatten, infer_index, infer_matmul_shapes, infer_noop, infer_permute, infer_range_size,
19 infer_repeat, infer_repeat_interleave, infer_squeeze, infer_to, infer_unsqueeze,
20 infer_view_like, shape_dims_equal,
21};
22pub use crate::op_groups::AGGR_ALIASES;
23use crate::op_groups::{BroadcastOp, Imports, TORCH_DTYPES, TorchOp, collect_imports};
24
/// A tensor shape tracked by the analyzer: an optional dtype plus symbolic dimension labels.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Shape {
    /// Element dtype name, if one was determined.
    pub dtype: Option<String>,
    /// Dimension labels, in the order they are rendered.
    pub dims: Vec<String>,
}

impl Shape {
    /// Bracketed, space-separated form of the dimensions, e.g. `[b n d]`.
    pub fn render(&self) -> String {
        let inner = self.dim_string();
        format!("[{inner}]")
    }

    /// Space-separated dimension labels without brackets, e.g. `b n d`.
    pub fn dim_string(&self) -> String {
        self.dims.join(" ")
    }
}
39
/// Information shown when hovering over an analyzed source range.
#[derive(Debug, Clone)]
pub struct HoverInfo {
    /// Shape to display for the range, if one is known.
    pub shape: Option<Shape>,
}

/// Result of analyzing one Python source: LSP diagnostics plus hover data.
#[derive(Debug, Default)]
pub struct Analysis {
    /// Diagnostics (e.g. parse errors, shape mismatches) in the order they were produced.
    pub diagnostics: Vec<Diagnostic>,
    /// Source ranges paired with the hover info recorded for them; queried by `hover`.
    pub hover_entries: Vec<(Range, HoverInfo)>,
}
50
51impl Analysis {
52 pub fn hover(&self, position: Position) -> Option<&HoverInfo> {
53 if let Some((_, info)) = self
55 .hover_entries
56 .iter()
57 .filter(|(range, _)| within(range, &position))
58 .min_by_key(|(range, _)| range_span(range, &position))
59 {
60 return Some(info);
61 }
62
63 if let Some((_, info)) = self
65 .hover_entries
66 .iter()
67 .filter(|(range, _)| range.start.line == position.line)
68 .min_by_key(|(range, _)| {
69 let a = range.start.character as i64;
70 let b = position.character as i64;
71 (a - b).abs()
72 })
73 {
74 return Some(info);
75 }
76 None
77 }
78}
79
80fn within(range: &Range, pos: &Position) -> bool {
81 (pos.line > range.start.line
82 || (pos.line == range.start.line && pos.character >= range.start.character))
83 && (pos.line < range.end.line
84 || (pos.line == range.end.line && pos.character <= range.end.character))
85}
86
87fn range_span(range: &Range, _pos: &Position) -> u32 {
88 let line_span = (range.end.line as i64 - range.start.line as i64).unsigned_abs() as u32;
90 let char_span = if range.start.line == range.end.line {
91 (range.end.character as i64 - range.start.character as i64).unsigned_abs() as u32
92 } else {
93 1000
94 };
95 line_span * 1000 + char_span
96}
97
/// What the simulator knows about one local variable.
#[derive(Debug, Clone, Default)]
struct VarState {
    /// Shape parsed from the variable's annotation, if it had one.
    annotated: Option<Shape>,
    /// Shape inferred from the value assigned to the variable.
    inferred: Option<Shape>,
    /// Set when the variable was bound by a constructor call of a known class, so that calls
    /// through the variable (`v(...)`, `v.method(...)`) can be resolved against that class.
    class_ref: Option<ClassRef>,
}
104
105#[derive(Debug, Clone, Default)]
109enum ReturnValue {
110 Single(Shape),
111 Tuple(Vec<Option<Shape>>),
112 #[default]
113 None,
114}
115
116impl ReturnValue {
117 fn from_shape(shape: Option<Shape>) -> Self {
118 shape.map(Self::Single).unwrap_or(Self::None)
119 }
120
121 fn from_tuple(shapes: Vec<Option<Shape>>) -> Self {
122 Self::Tuple(shapes)
123 }
124
125 fn first(&self) -> Option<&Shape> {
126 match self {
127 Self::Single(shape) => Some(shape),
128 Self::Tuple(tuple) => tuple.first().and_then(|x| x.as_ref()),
129 Self::None => None,
130 }
131 }
132
133 fn tuple(&self) -> Option<&[Option<Shape>]> {
134 match self {
135 Self::Tuple(tuple) => Some(tuple.as_ref()),
136 _ => None,
137 }
138 }
139
140 fn is_some(&self) -> bool {
141 match self {
142 Self::Single(_) => true,
143 Self::Tuple(tup) => !tup.is_empty(),
144 Self::None => false,
145 }
146 }
147}
148
/// A function definition captured so it can be simulated later (e.g. at call sites).
#[derive(Clone)]
struct FunctionInfo {
    /// Parameter list, cloned from the `def`.
    args: Box<Arguments>,
    /// Statement body, cloned from the `def`.
    body: Vec<Stmt>,
    /// Return annotation, if any.
    returns: Option<Box<Expr>>,
}

/// Functions keyed by name.
type FuncMap = HashMap<Identifier, FunctionInfo>;

/// A class referenced by name, optionally qualified with the module that defines it.
#[derive(Debug, Clone)]
struct ClassRef {
    name: Identifier,
    /// `None` for a class in the current module's class map; otherwise a dotted module
    /// name that is loaded through `ModuleCache` (see `with_class_info`).
    module: Option<String>,
}

/// Per-class data produced by `collect_class_defs`.
#[derive(Clone)]
struct ClassInfo {
    /// Whether the class derives from `torch.nn.Module`, either directly or through base
    /// classes defined in the same module.
    is_torch_module: bool,
    /// The class's `forward` method, if it defines one.
    forward: Option<FunctionInfo>,
    /// All methods collected from the class body (including `forward`), keyed by name.
    methods: HashMap<Identifier, FunctionInfo>,
}

/// Classes keyed by name.
type ClassMap = HashMap<Identifier, ClassInfo>;
172
/// A parsed Python module held by `ModuleCache`.
#[derive(Clone)]
pub(crate) struct CachedModule {
    /// File the module was loaded from.
    path: PathBuf,
    /// Module text after `normalize_return_annotations`; the module's AST was parsed from
    /// this text, so its ranges refer to it.
    source: String,
    /// Top-level functions of the module.
    func_map: FuncMap,
    /// Imports declared by the module.
    imports: Imports,
    /// Top-level classes of the module.
    class_map: ClassMap,
}
181
182pub(crate) struct ModuleCache {
183 modules: HashMap<String, CachedModule>,
184 project_root: Option<PathBuf>,
185}
186
187impl ModuleCache {
188 fn new(current_file: &Path) -> Self {
189 let project_root = find_project_root(current_file);
190 Self {
191 modules: HashMap::new(),
192 project_root,
193 }
194 }
195
196 fn project_root(&self) -> Option<&Path> {
197 self.project_root.as_deref()
198 }
199
200 fn get_module(&mut self, module_name: &str, current_file: &Path) -> Option<CachedModule> {
202 if let Some(cached) = self.modules.get(module_name) {
203 return Some(cached.clone());
204 }
205 let path = resolve_module_path(module_name, current_file, self.project_root.as_deref())?;
206 let source = fs::read_to_string(&path).ok()?;
207 let normalized = normalize_return_annotations(&source);
208 let source = normalized.into_owned();
209 let module = Suite::parse(&source, module_name).ok()?;
210 let imports = collect_imports(&module, Some(&path), self.project_root());
211 let mut func_map: FuncMap = HashMap::new();
212 collect_function_defs(&module, &mut func_map);
213 let class_map = collect_class_defs(&module, &imports);
214 let cached = CachedModule {
215 path,
216 source,
217 func_map,
218 imports,
219 class_map,
220 };
221 self.modules.insert(module_name.to_string(), cached.clone());
222 Some(cached)
223 }
224}
225
/// Rewrites bare tuple return annotations so the Python parser accepts them.
///
/// Shape-annotated code sometimes writes `def f(x) -> A, B:`, which is not valid Python.
/// For each line, the annotation runs from the first `->` to the first `:` after it that is
/// outside brackets and string quotes and before any `#` comment. If that annotation
/// contains a comma and is not already parenthesized, it is rewritten as `-> (A, B):`.
///
/// Line terminators (`\n`, `\r\n`, and a final newline) are preserved. Returns
/// `Cow::Borrowed(source)` when nothing was rewritten.
fn normalize_return_annotations<'a>(source: &'a str) -> Cow<'a, str> {
    let mut changed = false;
    let mut out = String::with_capacity(source.len());
    for line in source.split_inclusive('\n') {
        if let Some(rewritten) = parenthesize_tuple_return(line) {
            changed = true;
            out.push_str(&rewritten);
        } else {
            out.push_str(line);
        }
    }
    if changed {
        Cow::Owned(out)
    } else {
        Cow::Borrowed(source)
    }
}

/// Returns `line` with its return annotation wrapped in parentheses, or `None` if the line
/// has no unparenthesized, comma-containing annotation between `->` and a top-level `:`.
fn parenthesize_tuple_return(line: &str) -> Option<String> {
    let ann_start = line.find("->")? + 2;
    let colon_idx = ann_start + find_annotation_colon(&line[ann_start..])?;
    let ann_trim = line[ann_start..colon_idx].trim();
    if !ann_trim.contains(',') || ann_trim.starts_with('(') {
        return None;
    }
    let mut out = String::with_capacity(line.len() + 3);
    out.push_str(&line[..ann_start]);
    out.push_str(" (");
    out.push_str(ann_trim);
    out.push(')');
    out.push_str(&line[colon_idx..]);
    Some(out)
}

/// Byte offset of the first `:` in `text` that is at bracket depth zero and outside quotes.
///
/// Gives up (`None`) at a `#` comment or an unmatched closing bracket. The second case means
/// the `->` was not the start of a return annotation.
fn find_annotation_colon(text: &str) -> Option<usize> {
    let mut depth: usize = 0;
    let mut quote: Option<char> = None;
    let mut escaped = false;
    for (idx, ch) in text.char_indices() {
        if let Some(q) = quote {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == q {
                quote = None;
            }
            continue;
        }
        match ch {
            '\'' | '"' => quote = Some(ch),
            '(' | '[' | '{' => depth += 1,
            ')' | ']' | '}' => depth = depth.checked_sub(1)?,
            ':' if depth == 0 => return Some(idx),
            '#' => return None,
            _ => {}
        }
    }
    None
}
260
/// Finds the nearest ancestor directory of `start` that looks like a Python project root.
///
/// A directory qualifies if it contains `.git`, `pyproject.toml`, `setup.py`, or `setup.cfg`.
/// The search begins at `start`'s parent and walks upward. It stops (returning `None`) after
/// checking `$HOME`, or when the filesystem root is passed.
fn find_project_root(start: &Path) -> Option<PathBuf> {
    const MARKERS: [&str; 4] = [".git", "pyproject.toml", "setup.py", "setup.cfg"];
    let home = std::env::var_os("HOME").map(PathBuf::from);
    for dir in start.ancestors().skip(1) {
        if MARKERS.iter().any(|marker| dir.join(marker).exists()) {
            return Some(dir.to_path_buf());
        }
        if home.as_deref() == Some(dir) {
            return None;
        }
    }
    None
}
279
280fn with_class_info<R, F>(
282 class_ref: &ClassRef,
283 source: &str,
284 func_map: &FuncMap,
285 imports: &Imports,
286 class_map: &ClassMap,
287 module_cache: &mut Option<&mut ModuleCache>,
288 module_path: Option<&Path>,
289 f: F,
290) -> Option<R>
291where
292 F: FnOnce(
293 &ClassInfo,
294 &str,
295 &FuncMap,
296 &Imports,
297 &ClassMap,
298 Option<&Path>,
299 &mut Option<&mut ModuleCache>,
300 ) -> R,
301{
302 match &class_ref.module {
303 Some(module_name) => {
304 let module = {
305 let (Some(cache), Some(cur_path)) = (module_cache.as_deref_mut(), module_path)
306 else {
307 return None;
308 };
309 cache.get_module(module_name, cur_path)?
310 };
311 let class_info = module.class_map.get(&class_ref.name)?;
312 Some(f(
313 class_info,
314 &module.source,
315 &module.func_map,
316 &module.imports,
317 &module.class_map,
318 Some(module.path.as_path()),
319 module_cache,
320 ))
321 }
322 None => {
323 let class_info = class_map.get(&class_ref.name)?;
324 Some(f(
325 class_info,
326 source,
327 func_map,
328 imports,
329 class_map,
330 module_path,
331 module_cache,
332 ))
333 }
334 }
335}
336
337fn resolve_module_path(
338 module: &str,
339 current_file: &Path,
340 project_root: Option<&Path>,
341) -> Option<PathBuf> {
342 let mut search_roots: Vec<PathBuf> = Vec::new();
343
344 if let Some(parent) = current_file.parent() {
345 search_roots.push(parent.to_path_buf());
346 }
347 if let Some(prj) = project_root {
348 search_roots.push(prj.to_path_buf());
349 let src_dir = prj.join("src");
350 if src_dir.exists() {
351 search_roots.push(src_dir);
352 }
353 }
354
355 if let Ok(path_var) = std::env::var("PATH") {
357 for entry in path_var.split(':') {
358 if !entry.contains(".venv") {
359 continue;
360 }
361 let mut p = PathBuf::from(entry);
362 while let Some(parent) = p.parent() {
363 if let Some(name) = parent.file_name()
364 && name.to_string_lossy().contains(".venv")
365 {
366 let venv_dir = parent.to_path_buf();
367 if let Some(parent_parent) = venv_dir.parent() {
369 search_roots.push(parent_parent.to_path_buf());
370 }
371 let lib_dir = venv_dir.join("lib");
373 if lib_dir.exists()
374 && let Ok(entries) = fs::read_dir(&lib_dir)
375 {
376 for entry in entries.flatten() {
377 let fname = entry.file_name();
378 if fname.to_string_lossy().starts_with("python") {
379 let sp = entry.path().join("site-packages");
380 if sp.exists() {
381 search_roots.push(sp);
382 }
383 }
384 }
385 }
386 break;
387 }
388 p = parent.to_path_buf();
389 }
390 }
391 }
392
393 let parts: Vec<&str> = module.split('.').collect();
395 for root in search_roots {
396 let mut base = root.clone();
397 for part in &parts {
398 base.push(part);
399 }
400 let file_candidate = base.with_extension("py");
401 if file_candidate.exists() {
402 return Some(file_candidate);
403 }
404 let init_candidate = base.join("__init__.py");
405 if init_candidate.exists() {
406 return Some(init_candidate);
407 }
408 if let Some(root_name) = root.file_name()
411 && root_name
412 == parts
413 .first()
414 .map(std::ffi::OsStr::new)
415 .unwrap_or_else(|| std::ffi::OsStr::new(""))
416 && parts.len() > 1
417 {
418 let mut base = root.clone();
419 for part in parts.iter().skip(1) {
420 base.push(part);
421 }
422 let file_candidate = base.with_extension("py");
423 if file_candidate.exists() {
424 return Some(file_candidate);
425 }
426 let init_candidate = base.join("__init__.py");
427 if init_candidate.exists() {
428 return Some(init_candidate);
429 }
430 }
431 }
432 None
433}
434
435pub fn analyze_source(source: &str) -> Analysis {
436 analyze_source_internal(source, None, None)
437}
438
439pub fn analyze_source_at_path(source: &str, path: &Path) -> Analysis {
441 let mut cache = ModuleCache::new(path);
442 analyze_source_internal(source, Some(path), Some(&mut cache))
443}
444
445pub fn analyze_file(path: &Path) -> Analysis {
447 match fs::read_to_string(path) {
448 Ok(src) => {
449 let mut cache = ModuleCache::new(path);
450 analyze_source_internal(&src, Some(path), Some(&mut cache))
451 }
452 Err(err) => Analysis {
453 diagnostics: vec![Diagnostic {
454 range: default_range(),
455 severity: Some(DiagnosticSeverity::ERROR),
456 code: None,
457 code_description: None,
458 source: Some("shapels".into()),
459 message: format!("Failed to read file: {err}"),
460 related_information: None,
461 tags: None,
462 data: None,
463 }],
464 hover_entries: Vec::new(),
465 },
466 }
467}
468
469fn analyze_source_internal<'a>(
470 source: &'a str,
471 current_path: Option<&'a Path>,
472 mut module_cache: Option<&mut ModuleCache>,
473) -> Analysis {
474 let mut analysis = Analysis::default();
475 let normalized = normalize_return_annotations(source);
476 let source = normalized.as_ref();
477 let parse_name = current_path
478 .map(|p| p.to_string_lossy().to_string())
479 .unwrap_or_else(|| "<memory>".to_string());
480 match Suite::parse(source, &parse_name) {
481 Ok(module) => {
482 let imports = collect_imports(
483 &module,
484 current_path,
485 module_cache.as_deref().and_then(|c| c.project_root()),
486 );
487 let mut func_map: FuncMap = HashMap::new();
489 collect_function_defs(&module, &mut func_map);
490 let class_map = collect_class_defs(&module, &imports);
491
492 let empty_args = Arguments {
494 range: ast::OptionalRange::from(TextRange::new(
495 TextSize::from(0),
496 TextSize::from(0),
497 )),
498 posonlyargs: Vec::new(),
499 args: Vec::new(),
500 vararg: None,
501 kwonlyargs: Vec::new(),
502 kwarg: None,
503 };
504 let (mut top_diags, mut top_hovers, _) = simulate_function(
505 &empty_args,
506 &module,
507 source,
508 &func_map,
509 &imports,
510 &class_map,
511 &mut Vec::new(),
512 HashMap::new(),
513 true,
514 module_cache.as_deref_mut(),
515 current_path,
516 );
517 analysis.diagnostics.append(&mut top_diags);
518 analysis.hover_entries.append(&mut top_hovers);
519
520 analyze_function_bodies(
521 &module,
522 source,
523 &func_map,
524 &imports,
525 &class_map,
526 &mut analysis,
527 module_cache,
528 current_path,
529 );
530 }
531 Err(err) => {
532 analysis.diagnostics.push(Diagnostic {
533 range: default_range(),
534 severity: Some(DiagnosticSeverity::ERROR),
535 code: None,
536 code_description: None,
537 source: Some("shapels".into()),
538 message: format!("Parse error: {err}"),
539 related_information: None,
540 tags: None,
541 data: None,
542 });
543 }
544 }
545 analysis
546}
547
548fn collect_function_defs(body: &[Stmt], func_map: &mut FuncMap) {
549 for stmt in body {
550 if let Stmt::FunctionDef(func) = stmt {
551 func_map.insert(
552 func.name.clone(),
553 FunctionInfo {
554 args: func.args.clone(),
555 body: func.body.clone(),
556 returns: func.returns.clone(),
557 },
558 );
559 }
560 }
561}
562
/// Intermediate per-class record built by `collect_class_defs` before torch-module status is
/// propagated through locally defined base classes.
struct ClassDefInfo {
    /// The class's `forward` method, if present.
    forward: Option<FunctionInfo>,
    /// Methods collected from the class body, keyed by name.
    methods: HashMap<Identifier, FunctionInfo>,
    /// Bases written as plain names; attribute bases such as `nn.Module` are not listed.
    base_names: Vec<Identifier>,
    /// Whether some base is recognized directly by `is_torch_nn_module_base`.
    direct_torch: bool,
}
569
570fn is_torch_nn_module_base(expr: &Expr, imports: &Imports) -> bool {
571 let Expr::Attribute(attr) = expr else {
572 return false;
573 };
574 if attr.attr.as_str() != "Module" {
575 return false;
576 }
577 match attr.value.as_ref() {
578 Expr::Attribute(nn_attr) if nn_attr.attr.as_str() == "nn" => {
579 if let Expr::Name(torch_name) = nn_attr.value.as_ref() {
580 return imports.torch_aliases.contains(&torch_name.id);
581 }
582 false
583 }
584 Expr::Name(nn_name) => imports
585 .module_aliases
586 .get(&nn_name.id)
587 .map(|module| module == "torch.nn")
588 .unwrap_or(false),
589 _ => false,
590 }
591}
592
593fn collect_class_defs(body: &[Stmt], imports: &Imports) -> ClassMap {
594 let mut defs: HashMap<Identifier, ClassDefInfo> = HashMap::new();
595 for stmt in body {
596 if let Stmt::ClassDef(class_def) = stmt {
597 let mut methods = HashMap::new();
598 let mut forward = None;
599 for stmt in class_def.body.iter() {
600 if let Stmt::FunctionDef(func) = stmt {
601 let info = FunctionInfo {
602 args: func.args.clone(),
603 body: func.body.clone(),
604 returns: func.returns.clone(),
605 };
606 if func.name.as_str() == "forward" {
607 forward = Some(info.clone());
608 }
609 methods.insert(func.name.clone(), info);
610 }
611 }
612 let base_names = class_def
613 .bases
614 .iter()
615 .filter_map(|base| match base {
616 Expr::Name(name) => Some(name.id.clone()),
617 _ => None,
618 })
619 .collect();
620 let direct_torch = class_def
621 .bases
622 .iter()
623 .any(|base| is_torch_nn_module_base(base, imports));
624 defs.insert(
625 class_def.name.clone(),
626 ClassDefInfo {
627 forward,
628 methods,
629 base_names,
630 direct_torch,
631 },
632 );
633 }
634 }
635
636 let mut is_torch: HashMap<Identifier, bool> = defs
637 .iter()
638 .map(|(name, info)| (name.clone(), info.direct_torch))
639 .collect();
640 let mut changed = true;
641 while changed {
642 changed = false;
643 for (name, info) in defs.iter() {
644 if !is_torch.get(name).copied().unwrap_or(false)
645 && info
646 .base_names
647 .iter()
648 .any(|base| is_torch.get(base).copied().unwrap_or(false))
649 {
650 is_torch.insert(name.clone(), true);
651 changed = true;
652 }
653 }
654 }
655
656 let mut class_map = HashMap::new();
657 for (name, info) in defs {
658 class_map.insert(
659 name.clone(),
660 ClassInfo {
661 is_torch_module: is_torch.get(&name).copied().unwrap_or(false),
662 forward: info.forward,
663 methods: info.methods,
664 },
665 );
666 }
667 class_map
668}
669
670fn analyze_function_bodies(
675 body: &[Stmt],
676 source: &str,
677 func_map: &FuncMap,
678 imports: &Imports,
679 class_map: &ClassMap,
680 analysis: &mut Analysis,
681 mut module_cache: Option<&mut ModuleCache>,
682 module_path: Option<&Path>,
683) {
684 for stmt in body {
685 match stmt {
686 Stmt::FunctionDef(func) => {
687 let mut func_analysis = analyze_function(
688 &func.args,
689 &func.body,
690 source,
691 func_map,
692 imports,
693 class_map,
694 &mut Vec::new(),
695 module_cache.as_deref_mut(),
696 module_path,
697 );
698 analysis.diagnostics.append(&mut func_analysis.diagnostics);
699 analysis
700 .hover_entries
701 .append(&mut func_analysis.hover_entries);
702 }
703 Stmt::ClassDef(class_def) => analyze_function_bodies(
704 &class_def.body,
705 source,
706 func_map,
707 imports,
708 class_map,
709 analysis,
710 module_cache.as_deref_mut(),
711 module_path,
712 ),
713 _ => {}
714 }
715 }
716}
717
718fn analyze_function(
719 args: &Arguments,
720 body: &[Stmt],
721 source: &str,
722 func_map: &FuncMap,
723 imports: &Imports,
724 class_map: &ClassMap,
725 call_stack: &mut Vec<Identifier>,
726 module_cache: Option<&mut ModuleCache>,
727 module_path: Option<&Path>,
728) -> Analysis {
729 let (diagnostics, hover_entries, _) = simulate_function(
730 args,
731 body,
732 source,
733 func_map,
734 imports,
735 class_map,
736 call_stack,
737 HashMap::new(),
738 true,
739 module_cache,
740 module_path,
741 );
742
743 Analysis {
744 diagnostics,
745 hover_entries,
746 }
747}
748
749fn simulate_function(
752 args: &Arguments,
753 body: &[Stmt],
754 source: &str,
755 func_map: &FuncMap,
756 imports: &Imports,
757 class_map: &ClassMap,
758 call_stack: &mut Vec<Identifier>,
759 mut initial_vars: HashMap<Identifier, VarState>,
760 record_hovers: bool,
761 module_cache: Option<&mut ModuleCache>,
762 module_path: Option<&Path>,
763) -> (Vec<Diagnostic>, Vec<(Range, HoverInfo)>, ReturnValue) {
764 let mut diagnostics = Vec::new();
765 let mut hover_entries = Vec::new();
766 let mut vars: HashMap<Identifier, VarState> = HashMap::new();
767
768 seed_args_from_annotations(
769 args,
770 source,
771 &mut vars,
772 &mut hover_entries,
773 Some(&mut initial_vars),
774 record_hovers,
775 imports,
776 class_map,
777 );
778
779 let mut return_value = ReturnValue::default();
780
781 let _ = simulate_block(
782 body,
783 &mut vars,
784 &mut diagnostics,
785 &mut hover_entries,
786 &mut return_value,
787 source,
788 func_map,
789 imports,
790 class_map,
791 call_stack,
792 record_hovers,
793 module_cache,
794 module_path,
795 false,
796 );
797
798 (diagnostics, hover_entries, return_value)
799}
800
/// How control leaves a block simulated by `simulate_block`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum BlockFlow {
    /// Execution fell off the end of the block.
    None,
    /// A `break` was reached while simulating inside a loop.
    Break,
    /// A `continue` was reached while simulating inside a loop.
    Continue,
    /// A `return` statement was reached.
    Return,
}
808
809fn simulate_block(
811 body: &[Stmt],
812 vars: &mut HashMap<Identifier, VarState>,
813 diagnostics: &mut Vec<Diagnostic>,
814 hover_entries: &mut Vec<(Range, HoverInfo)>,
815 return_value: &mut ReturnValue,
816 source: &str,
817 func_map: &FuncMap,
818 imports: &Imports,
819 class_map: &ClassMap,
820 call_stack: &mut Vec<Identifier>,
821 record_hovers: bool,
822 mut module_cache: Option<&mut ModuleCache>,
823 module_path: Option<&Path>,
824 in_loop: bool,
825) -> BlockFlow {
826 for stmt in body {
827 match stmt {
828 Stmt::AnnAssign(assign) => {
829 if assignment_shape_checks(
830 &assign.target,
831 assign.value.as_deref().unwrap_or(assign.target.as_ref()),
832 vars,
833 diagnostics,
834 hover_entries,
835 record_hovers,
836 source,
837 ) {
838 continue;
839 }
840 if let Some(name) = name_from_expr(&assign.target) {
841 let ann_shape = parse_shape_annotation(&assign.annotation);
842 let range = text_range_to_lsp(expr_text_range(&assign.target), source);
843 let mut inferred = None;
844 if let Some(val) = &assign.value {
845 if assignment_shape_checks(
846 val,
847 val,
848 vars,
849 diagnostics,
850 hover_entries,
851 record_hovers,
852 source,
853 ) {
854 inferred = vars
855 .get(&name)
856 .and_then(|v| v.annotated.clone().or(v.inferred.clone()));
857 } else {
858 inferred = infer_expr_shape(
859 val,
860 vars,
861 func_map,
862 imports,
863 class_map,
864 call_stack,
865 diagnostics,
866 hover_entries,
867 record_hovers,
868 source,
869 module_cache.as_deref_mut(),
870 module_path,
871 );
872 }
873 }
874 if let (Some(ann), Some(inf)) = (ann_shape.clone(), inferred.clone())
875 && !shape_dims_equal(&ann, &inf)
876 {
877 if let Expr::Subscript(_sub) = &*assign.annotation
879 && ann.dims.len() == inf.dims.len()
880 && matches!(
881 assign.value.as_deref(),
882 Some(Expr::Name(_)) | Some(Expr::Attribute(_))
883 )
884 {
885 let mut renamed = inf.clone();
886 renamed.dims = ann.dims.clone();
887 vars.insert(
888 name.clone(),
889 VarState {
890 annotated: ann_shape.clone(),
891 inferred: Some(renamed.clone()),
892 class_ref: None,
893 },
894 );
895 if record_hovers {
896 hover_entries.push((
897 range,
898 HoverInfo {
899 shape: Some(renamed),
900 },
901 ));
902 }
903 continue;
904 }
905 diagnostics.push(Diagnostic {
906 range,
907 severity: Some(DiagnosticSeverity::ERROR),
908 code: None,
909 code_description: None,
910 source: Some("shapels".into()),
911 message: format!(
912 "Shape mismatch: annotation {} vs inferred {}",
913 ann.render(),
914 inf.render()
915 ),
916 related_information: None,
917 tags: None,
918 data: None,
919 });
920 }
921 let chosen_shape = ann_shape.clone().or(inferred.clone());
922 if let Some(shape) = chosen_shape {
923 vars.insert(
924 name.clone(),
925 VarState {
926 annotated: ann_shape.clone(),
927 inferred,
928 class_ref: None,
929 },
930 );
931 if record_hovers {
932 hover_entries.push((range, HoverInfo { shape: Some(shape) }));
933 }
934 } else if let Some(val) = &assign.value
935 && let Some(class_ref) = class_ref_from_constructor_call(
936 val,
937 class_map,
938 imports,
939 module_cache.as_deref_mut(),
940 module_path,
941 )
942 {
943 vars.insert(
944 name.clone(),
945 VarState {
946 annotated: ann_shape.clone(),
947 inferred: None,
948 class_ref: Some(class_ref),
949 },
950 );
951 }
952 }
953 }
954 Stmt::Assign(assign) => {
955 if assign.targets.len() == 1
957 && assignment_shape_checks(
958 &assign.targets[0],
959 &assign.value,
960 vars,
961 diagnostics,
962 hover_entries,
963 record_hovers,
964 source,
965 )
966 {
967 continue;
968 }
969 if assign.targets.len() == 1
970 && matches!(assign.targets[0], Expr::Tuple(_))
971 && let Expr::Tuple(target_tuple) = &assign.targets[0]
972 && let Some(tuple_shapes) = infer_tuple_elements(
973 &assign.value,
974 vars,
975 func_map,
976 imports,
977 class_map,
978 call_stack,
979 diagnostics,
980 hover_entries,
981 record_hovers,
982 source,
983 module_cache.as_deref_mut(),
984 module_path,
985 )
986 {
987 for (target_expr, shape_opt) in
988 target_tuple.elts.iter().zip(tuple_shapes.into_iter())
989 {
990 if let (Some(name), Some(shape)) = (name_from_expr(target_expr), shape_opt)
991 {
992 vars.insert(
993 name.clone(),
994 VarState {
995 annotated: None,
996 inferred: Some(shape.clone()),
997 class_ref: None,
998 },
999 );
1000 if record_hovers {
1001 let range = text_range_to_lsp(expr_text_range(target_expr), source);
1002 hover_entries.push((range, HoverInfo { shape: Some(shape) }));
1003 }
1004 }
1005 }
1006 continue;
1007 }
1008 if assign.targets.len() == 1
1009 && let Some(name) = name_from_expr(&assign.targets[0])
1010 {
1011 let range = text_range_to_lsp(expr_text_range(&assign.targets[0]), source);
1012 let shape = infer_expr_shape(
1013 &assign.value,
1014 vars,
1015 func_map,
1016 imports,
1017 class_map,
1018 call_stack,
1019 diagnostics,
1020 hover_entries,
1021 record_hovers,
1022 source,
1023 module_cache.as_deref_mut(),
1024 module_path,
1025 );
1026 if let Some(shape) = shape {
1027 vars.insert(
1028 name.clone(),
1029 VarState {
1030 annotated: None,
1031 inferred: Some(shape.clone()),
1032 class_ref: None,
1033 },
1034 );
1035 if record_hovers {
1036 hover_entries.push((range, HoverInfo { shape: Some(shape) }));
1037 }
1038 } else if let Some(class_ref) = class_ref_from_constructor_call(
1039 &assign.value,
1040 class_map,
1041 imports,
1042 module_cache.as_deref_mut(),
1043 module_path,
1044 ) {
1045 vars.insert(
1046 name.clone(),
1047 VarState {
1048 annotated: None,
1049 inferred: None,
1050 class_ref: Some(class_ref),
1051 },
1052 );
1053 }
1054 }
1055 }
1056 Stmt::For(for_stmt) => {
1057 let flow = simulate_block(
1058 &for_stmt.body,
1059 vars,
1060 diagnostics,
1061 hover_entries,
1062 return_value,
1063 source,
1064 func_map,
1065 imports,
1066 class_map,
1067 call_stack,
1068 record_hovers,
1069 module_cache.as_deref_mut(),
1070 module_path,
1071 true,
1072 );
1073 if flow == BlockFlow::Return {
1074 return BlockFlow::Return;
1075 }
1076 if flow != BlockFlow::Break {
1077 let else_flow = simulate_block(
1078 &for_stmt.orelse,
1079 vars,
1080 diagnostics,
1081 hover_entries,
1082 return_value,
1083 source,
1084 func_map,
1085 imports,
1086 class_map,
1087 call_stack,
1088 record_hovers,
1089 module_cache.as_deref_mut(),
1090 module_path,
1091 true,
1092 );
1093 if else_flow == BlockFlow::Return {
1094 return BlockFlow::Return;
1095 }
1096 }
1097 }
1098 Stmt::While(while_stmt) => {
1099 let flow = simulate_block(
1100 &while_stmt.body,
1101 vars,
1102 diagnostics,
1103 hover_entries,
1104 return_value,
1105 source,
1106 func_map,
1107 imports,
1108 class_map,
1109 call_stack,
1110 record_hovers,
1111 module_cache.as_deref_mut(),
1112 module_path,
1113 true,
1114 );
1115 if flow == BlockFlow::Return {
1116 return BlockFlow::Return;
1117 }
1118 if flow != BlockFlow::Break {
1119 let else_flow = simulate_block(
1120 &while_stmt.orelse,
1121 vars,
1122 diagnostics,
1123 hover_entries,
1124 return_value,
1125 source,
1126 func_map,
1127 imports,
1128 class_map,
1129 call_stack,
1130 record_hovers,
1131 module_cache.as_deref_mut(),
1132 module_path,
1133 true,
1134 );
1135 if else_flow == BlockFlow::Return {
1136 return BlockFlow::Return;
1137 }
1138 }
1139 }
1140 Stmt::If(if_stmt) => {
1141 let mut body_vars = vars.clone();
1142 let body_flow = simulate_block(
1143 &if_stmt.body,
1144 &mut body_vars,
1145 diagnostics,
1146 hover_entries,
1147 return_value,
1148 source,
1149 func_map,
1150 imports,
1151 class_map,
1152 call_stack,
1153 record_hovers,
1154 module_cache.as_deref_mut(),
1155 module_path,
1156 in_loop,
1157 );
1158 let mut else_vars = vars.clone();
1159 let else_flow = simulate_block(
1160 &if_stmt.orelse,
1161 &mut else_vars,
1162 diagnostics,
1163 hover_entries,
1164 return_value,
1165 source,
1166 func_map,
1167 imports,
1168 class_map,
1169 call_stack,
1170 record_hovers,
1171 module_cache.as_deref_mut(),
1172 module_path,
1173 in_loop,
1174 );
1175 if body_flow == else_flow
1176 && matches!(
1177 body_flow,
1178 BlockFlow::Break | BlockFlow::Continue | BlockFlow::Return
1179 )
1180 {
1181 return body_flow;
1182 }
1183 }
1184 Stmt::Break(_) => {
1185 if in_loop {
1186 return BlockFlow::Break;
1187 }
1188 }
1189 Stmt::Continue(_) => {
1190 if in_loop {
1191 return BlockFlow::Continue;
1192 }
1193 }
1194 Stmt::Return(ret) => {
1195 if let Some(val) = &ret.value {
1196 let new_value = match val.as_ref() {
1197 Expr::Tuple(tuple) => {
1198 let tuple_shapes: Vec<Option<Shape>> = tuple
1199 .elts
1200 .iter()
1201 .map(|elt| {
1202 infer_expr_shape(
1203 elt,
1204 vars,
1205 func_map,
1206 imports,
1207 class_map,
1208 call_stack,
1209 diagnostics,
1210 hover_entries,
1211 record_hovers,
1212 source,
1213 module_cache.as_deref_mut(),
1214 module_path,
1215 )
1216 })
1217 .collect();
1218 ReturnValue::from_tuple(tuple_shapes)
1219 }
1220 _ => ReturnValue::from_shape(infer_expr_shape(
1221 val,
1222 vars,
1223 func_map,
1224 imports,
1225 class_map,
1226 call_stack,
1227 diagnostics,
1228 hover_entries,
1229 record_hovers,
1230 source,
1231 module_cache.as_deref_mut(),
1232 module_path,
1233 )),
1234 };
1235 if record_hovers && let Some(shape) = new_value.first() {
1236 let range = text_range_to_lsp(expr_text_range(val), source);
1237 hover_entries.push((
1238 range,
1239 HoverInfo {
1240 shape: Some(shape.clone()),
1241 },
1242 ));
1243 }
1244 if new_value.is_some() {
1245 *return_value = new_value;
1246 }
1247 }
1248 return BlockFlow::Return;
1249 }
1250 _ => {}
1251 }
1252 }
1253 BlockFlow::None
1254}
1255
1256fn infer_expr_shape(
1261 expr: &Expr,
1262 vars: &HashMap<Identifier, VarState>,
1263 func_map: &FuncMap,
1264 imports: &Imports,
1265 class_map: &ClassMap,
1266 call_stack: &mut Vec<Identifier>,
1267 diagnostics: &mut Vec<Diagnostic>,
1268 hover_entries: &mut Vec<(Range, HoverInfo)>,
1269 record_hovers: bool,
1270 source: &str,
1271 mut module_cache: Option<&mut ModuleCache>,
1272 module_path: Option<&Path>,
1273) -> Option<Shape> {
1274 match expr {
1275 Expr::BinOp(ExprBinOp {
1276 left,
1277 op,
1278 right,
1279 range: expr_range,
1280 }) => match op {
1281 Operator::Mult | Operator::Add | Operator::Sub | Operator::Div => {
1282 infer_broadcastable_poswise(
1283 &ShapeOrExpr::Expr(left),
1284 right,
1285 vars,
1286 func_map,
1287 imports,
1288 class_map,
1289 call_stack,
1290 diagnostics,
1291 hover_entries,
1292 record_hovers,
1293 source,
1294 *expr_range,
1295 module_cache.as_deref_mut(),
1296 module_path,
1297 BroadcastOp::Arithmetic,
1298 )
1299 }
1300 Operator::BitAnd
1301 | Operator::BitXor
1302 | Operator::BitOr
1303 | Operator::LShift
1304 | Operator::RShift => infer_broadcastable_poswise(
1305 &ShapeOrExpr::Expr(left),
1306 right,
1307 vars,
1308 func_map,
1309 imports,
1310 class_map,
1311 call_stack,
1312 diagnostics,
1313 hover_entries,
1314 record_hovers,
1315 source,
1316 *expr_range,
1317 module_cache.as_deref_mut(),
1318 module_path,
1319 BroadcastOp::Bitwise,
1320 ),
1321 Operator::MatMult => infer_matmul_shapes(
1322 left,
1323 right,
1324 vars,
1325 func_map,
1326 imports,
1327 class_map,
1328 call_stack,
1329 diagnostics,
1330 hover_entries,
1331 record_hovers,
1332 source,
1333 *expr_range,
1334 module_cache.as_deref_mut(),
1335 module_path,
1336 ),
1337 _ => None,
1338 },
1339 Expr::Compare(ExprCompare {
1340 left: init_left,
1341 comparators,
1343 range: expr_range,
1344 ..
1345 }) => {
1346 let mut iter = comparators.iter();
1348 let first_right = iter.next()?;
1349 let init = infer_broadcastable_poswise(
1350 &ShapeOrExpr::Expr(init_left.as_ref()),
1351 first_right,
1352 vars,
1353 func_map,
1354 imports,
1355 class_map,
1356 call_stack,
1357 diagnostics,
1358 hover_entries,
1359 record_hovers,
1360 source,
1361 *expr_range,
1362 module_cache.as_deref_mut(),
1363 module_path,
1364 BroadcastOp::Eq,
1365 );
1366
1367 iter.fold(init, |left, right| {
1369 infer_broadcastable_poswise(
1370 &ShapeOrExpr::Shape(left.as_ref()),
1371 right,
1372 vars,
1373 func_map,
1374 imports,
1375 class_map,
1376 call_stack,
1377 diagnostics,
1378 hover_entries,
1379 record_hovers,
1380 source,
1381 *expr_range,
1382 module_cache.as_deref_mut(),
1383 module_path,
1384 BroadcastOp::Eq,
1385 )
1386 })
1387 }
1388 Expr::Call(call) => {
1389 if let Expr::Name(func_name) = call.func.as_ref() {
1390 if let Some(class_ref) = vars.get(&func_name.id).and_then(|v| v.class_ref.as_ref())
1391 && let Some(ret) = infer_class_call_return(
1392 call,
1393 class_ref,
1394 vars,
1395 func_map,
1396 imports,
1397 class_map,
1398 call_stack,
1399 diagnostics,
1400 hover_entries,
1401 record_hovers,
1402 source,
1403 module_cache.as_deref_mut(),
1404 module_path,
1405 )
1406 {
1407 return ret.first().cloned();
1408 }
1409 let torchop_shape = torch_op_to_shape(
1410 vars,
1411 func_map,
1412 imports,
1413 class_map,
1414 call_stack,
1415 diagnostics,
1416 hover_entries,
1417 record_hovers,
1418 source,
1419 &mut module_cache,
1420 module_path,
1421 call,
1422 func_name.id.as_str(),
1423 TorchOp::as_call(&func_name.id, imports),
1424 TorchOpKind::Function,
1425 None,
1426 );
1427 if torchop_shape.is_some() {
1428 return torchop_shape;
1429 }
1430
1431 if let Some((module_name, original)) = imports.from_imports.get(&func_name.id)
1432 && let (Some(cache), Some(cur_path)) =
1433 (module_cache.as_deref_mut(), module_path)
1434 && let Some(module) = cache.get_module(module_name, cur_path)
1435 && let Some(callee_info) = module.func_map.get(original)
1436 && let Some(ret) = infer_call_return_from_info(
1437 call,
1438 &func_name.id,
1439 callee_info,
1440 &module.source,
1441 &module.func_map,
1442 &module.imports,
1443 &module.class_map,
1444 vars,
1445 func_map,
1446 imports,
1447 class_map,
1448 call_stack,
1449 diagnostics,
1450 hover_entries,
1451 record_hovers,
1452 source,
1453 module_cache.as_deref_mut(),
1454 Some(module.path.as_path()),
1455 0,
1456 false,
1457 )
1458 {
1459 return ret.first().cloned();
1460 }
1461 if let Some(callee_info) = func_map.get(&func_name.id)
1462 && let Some(ret) = infer_call_return_from_info(
1463 call,
1464 &func_name.id,
1465 callee_info,
1466 source,
1467 func_map,
1468 imports,
1469 class_map,
1470 vars,
1471 func_map,
1472 imports,
1473 class_map,
1474 call_stack,
1475 diagnostics,
1476 hover_entries,
1477 record_hovers,
1478 source,
1479 module_cache.as_deref_mut(),
1480 module_path,
1481 0,
1482 false,
1483 )
1484 {
1485 return ret.first().cloned();
1486 }
1487 }
1488 if let Expr::Attribute(attr) = call.func.as_ref() {
1490 let attr_name: &str = attr.attr.as_ref();
1491 if let Expr::Name(base_name) = attr.value.as_ref()
1492 && let Some(class_ref) =
1493 vars.get(&base_name.id).and_then(|v| v.class_ref.as_ref())
1494 && let Some(ret) = infer_class_method_call_return(
1495 call,
1496 class_ref,
1497 &attr.attr,
1498 vars,
1499 func_map,
1500 imports,
1501 class_map,
1502 call_stack,
1503 diagnostics,
1504 hover_entries,
1505 record_hovers,
1506 source,
1507 module_cache.as_deref_mut(),
1508 module_path,
1509 )
1510 {
1511 return ret.first().cloned();
1512 }
1513 if let Expr::Name(module_ident) = attr.value.as_ref()
1514 && let Some(module_name) = imports.module_aliases.get(&module_ident.id)
1515 && module_name != "torch"
1516 && let (Some(cache), Some(cur_path)) =
1517 (module_cache.as_deref_mut(), module_path)
1518 && let Some(module) = cache.get_module(module_name, cur_path)
1519 && let Some(callee_info) = module.func_map.get(&attr.attr)
1520 && let Some(ret) = infer_call_return_from_info(
1521 call,
1522 &attr.attr,
1523 callee_info,
1524 &module.source,
1525 &module.func_map,
1526 &module.imports,
1527 &module.class_map,
1528 vars,
1529 func_map,
1530 imports,
1531 class_map,
1532 call_stack,
1533 diagnostics,
1534 hover_entries,
1535 record_hovers,
1536 source,
1537 module_cache.as_deref_mut(),
1538 Some(module.path.as_path()),
1539 0,
1540 false,
1541 )
1542 {
1543 return ret.first().cloned();
1544 }
1545 let op_kind = function_or_method(&attr.value, imports);
1547 let aliased_shape = torch_op_to_shape(
1548 vars,
1549 func_map,
1550 imports,
1551 class_map,
1552 call_stack,
1553 diagnostics,
1554 hover_entries,
1555 record_hovers,
1556 source,
1557 &mut module_cache,
1558 module_path,
1559 call,
1560 attr_name,
1561 TorchOp::from_attr(attr_name),
1562 op_kind,
1563 Some(attr),
1564 );
1565 if aliased_shape.is_some() {
1566 return aliased_shape;
1567 }
1568 }
1569 None
1570 }
1571 Expr::Subscript(sub) => infer_index(
1572 &sub.value,
1573 &sub.slice,
1574 vars,
1575 func_map,
1576 imports,
1577 class_map,
1578 call_stack,
1579 diagnostics,
1580 hover_entries,
1581 record_hovers,
1582 source,
1583 module_cache.as_deref_mut(),
1584 module_path,
1585 ),
1586 Expr::Attribute(attr) => {
1588 let attr_name: &str = attr.attr.as_ref();
1589 if attr_name == "T" {
1590 let base_hint =
1591 lookup_shape(&attr.value, vars, hover_entries, record_hovers, source).or_else(
1592 || {
1593 infer_expr_shape(
1594 &attr.value,
1595 vars,
1596 func_map,
1597 imports,
1598 class_map,
1599 call_stack,
1600 diagnostics,
1601 hover_entries,
1602 false,
1603 source,
1604 module_cache.as_deref_mut(),
1605 module_path,
1606 )
1607 },
1608 );
1609 let order_args: Vec<&Expr> = Vec::new();
1610 let res = infer_permute(
1611 &attr.value,
1612 &order_args,
1613 Transpose::T,
1614 base_hint,
1615 vars,
1616 diagnostics,
1617 hover_entries,
1618 record_hovers,
1619 source,
1620 attr.range,
1621 );
1622 if record_hovers && let Some(s) = res.clone() {
1623 let range = text_range_to_lsp(attr.range, source);
1624 hover_entries.push((range, HoverInfo { shape: Some(s) }));
1625 }
1626 return res;
1627 } else if attr_name == "shape" || attr_name == "dtype" {
1628 return lookup_shape(&attr.value, vars, hover_entries, record_hovers, source)
1629 .or_else(|| {
1630 infer_expr_shape(
1631 &attr.value,
1632 vars,
1633 func_map,
1634 imports,
1635 class_map,
1636 call_stack,
1637 diagnostics,
1638 hover_entries,
1639 false,
1640 source,
1641 module_cache,
1642 module_path,
1643 )
1644 });
1645 }
1646 None
1647 }
1648 Expr::Name(expr_name) => {
1649 let shape = vars
1650 .get(&expr_name.id)
1651 .and_then(|v| v.annotated.clone().or_else(|| v.inferred.clone()));
1652 if record_hovers && let Some(s) = shape.clone() {
1653 let range = text_range_to_lsp(expr_text_range(expr), source);
1654 hover_entries.push((range, HoverInfo { shape: Some(s) }));
1655 }
1656 shape
1657 }
1658 _ => None,
1659 }
1660}
1661
1662fn torch_op_to_shape(
1668 vars: &HashMap<Identifier, VarState>,
1669 func_map: &HashMap<Identifier, FunctionInfo>,
1670 imports: &Imports,
1671 class_map: &HashMap<Identifier, ClassInfo>,
1672 call_stack: &mut Vec<Identifier>,
1673 diagnostics: &mut Vec<Diagnostic>,
1674 hover_entries: &mut Vec<(Range, HoverInfo)>,
1675 record_hovers: bool,
1676 source: &str,
1677 module_cache: &mut Option<&mut ModuleCache>,
1678 module_path: Option<&Path>,
1679 call: &ExprCall,
1680 attr_name: &str,
1681 torch_op: TorchOp,
1682 op_kind: TorchOpKind,
1683 maybe_attr: Option<&ast::ExprAttribute>,
1684) -> Option<Shape> {
1685 use TorchOpKind::*;
1686
1687 let (may_arg0, may_arg1, offset) = match op_kind {
1688 Function => (call.args.first(), call.args.get(1), 1),
1689 Method => (maybe_attr.map(|x| x.value.as_ref()), call.args.first(), 0),
1690 };
1691 match (torch_op, may_arg0, may_arg1, &op_kind) {
1692 (TorchOp::MatMul, Some(arg0), Some(arg1), _) => infer_matmul_shapes(
1693 arg0,
1694 arg1,
1695 vars,
1696 func_map,
1697 imports,
1698 class_map,
1699 call_stack,
1700 diagnostics,
1701 hover_entries,
1702 record_hovers,
1703 source,
1704 call.range,
1705 module_cache.as_deref_mut(),
1706 module_path,
1707 ),
1708 (op @ (TorchOp::Squeeze | TorchOp::Aggr), Some(arg0), _, _) => infer_squeeze(
1709 arg0,
1710 get_arg(call, "dim", offset),
1711 vars,
1712 func_map,
1713 imports,
1714 class_map,
1715 call_stack,
1716 diagnostics,
1717 hover_entries,
1718 record_hovers,
1719 source,
1720 call.range,
1721 module_cache.as_deref_mut(),
1722 module_path,
1723 matches!(op, TorchOp::Squeeze),
1724 get_arg(call, "keepdim", offset + 1),
1725 ),
1726 (TorchOp::NoopDim, Some(base), _, _) => {
1727 let base_hint = infer_expr_shape(
1728 base,
1729 vars,
1730 func_map,
1731 imports,
1732 class_map,
1733 call_stack,
1734 diagnostics,
1735 hover_entries,
1736 false,
1737 source,
1738 module_cache.as_deref_mut(),
1739 module_path,
1740 );
1741 infer_noop(
1742 base_hint,
1743 get_arg(call, "dim", offset),
1744 vars,
1745 diagnostics,
1746 source,
1747 call.range,
1748 attr_name.contains("soft"), )
1750 }
1751 (TorchOp::Noop, Some(base), _, _) => infer_expr_shape(
1752 base,
1753 vars,
1754 func_map,
1755 imports,
1756 class_map,
1757 call_stack,
1758 diagnostics,
1759 hover_entries,
1760 false,
1761 source,
1762 module_cache.as_deref_mut(),
1763 module_path,
1764 ),
1765 (TorchOp::Creation { is_size }, _, _, Function) => {
1766 let shape_assign = tensor_or_shape_as_arg(
1767 is_size,
1768 vars,
1769 func_map,
1770 imports,
1771 class_map,
1772 call_stack,
1773 diagnostics,
1774 hover_entries,
1775 record_hovers,
1776 source,
1777 module_cache,
1778 module_path,
1779 call,
1780 );
1781 let dtype = get_arg(call, "dtype", 200).and_then(|expr| {
1782 let attr_dtype = if matches!(expr, Expr::Attribute(_)) {
1783 infer_expr_shape(
1784 expr,
1785 vars,
1786 func_map,
1787 imports,
1788 class_map,
1789 call_stack,
1790 diagnostics,
1791 hover_entries,
1792 record_hovers,
1793 source,
1794 module_cache.as_deref_mut(),
1795 module_path,
1796 )
1797 .and_then(|shape| shape.dtype)
1798 } else {
1799 None
1800 };
1801 attr_dtype.or_else(|| get_dtype(expr, imports).map(|x| x.to_string()))
1802 });
1803
1804 infer_creation_size(
1805 call,
1806 vars,
1807 diagnostics,
1808 source,
1809 shape_assign,
1810 dtype,
1811 is_size,
1812 )
1813 }
1814 (TorchOp::RangeOp(range_op), _, _, Function) => infer_range_size(
1815 call,
1816 range_op,
1817 vars,
1818 func_map,
1819 imports,
1820 class_map,
1821 call_stack,
1822 diagnostics,
1823 hover_entries,
1824 record_hovers,
1825 source,
1826 module_cache.as_deref_mut(),
1827 module_path,
1828 ),
1829 (TorchOp::NoArg { predef_dtype }, Some(base), _, _)
1830 if predef_dtype.is_none() || matches!(op_kind, Method) =>
1832 {
1833 let base_expr = infer_expr_shape(
1834 base,
1835 vars,
1836 func_map,
1837 imports,
1838 class_map,
1839 call_stack,
1840 diagnostics,
1841 hover_entries,
1842 false,
1843 source,
1844 module_cache.as_deref_mut(),
1845 module_path,
1846 );
1847 let dtype_expr = if let Some(dtype) = predef_dtype {
1848 Some(&Expr::Constant(ast::ExprConstant {
1849 range: expr_text_range(base),
1850 value: Constant::Str(dtype.to_string()),
1851 kind: None,
1852 }))
1853 } else {
1854 get_arg(call, "dtype", 0)
1855 };
1856 infer_to(base_expr, dtype_expr, diagnostics, source, imports)
1857 }
1858 (TorchOp::Broadcastable(broadcast_op), Some(left), Some(right), _) => {
1859 infer_broadcastable_poswise(
1860 &ShapeOrExpr::Expr(left),
1861 right,
1862 vars,
1863 func_map,
1864 imports,
1865 class_map,
1866 call_stack,
1867 diagnostics,
1868 hover_entries,
1869 record_hovers,
1870 source,
1871 call.range,
1872 module_cache.as_deref_mut(),
1873 module_path,
1874 broadcast_op,
1875 )
1876 }
1877 (op @ (TorchOp::View | TorchOp::Expand), Some(base), _, kind)
1878 if matches!(kind, Method) || matches!(op, TorchOp::View) =>
1879 {
1880 let base_hint =
1881 lookup_shape(base, vars, hover_entries, record_hovers, source).or_else(|| {
1882 infer_expr_shape(
1883 base,
1884 vars,
1885 func_map,
1886 imports,
1887 class_map,
1888 call_stack,
1889 diagnostics,
1890 hover_entries,
1891 false,
1892 source,
1893 module_cache.as_deref_mut(),
1894 module_path,
1895 )
1896 });
1897 infer_view_like(
1898 base,
1899 &call.args.iter().collect::<Vec<_>>(),
1900 &op,
1901 base_hint,
1902 vars,
1903 diagnostics,
1904 hover_entries,
1905 record_hovers,
1906 source,
1907 call.range,
1908 )
1909 }
1910 (TorchOp::Transpose(transpose), Some(base), _, _) => {
1911 let order_args = match &transpose {
1912 Transpose::Permute => {
1913 match get_arg(call, "dims", offset) {
1914 Some(Expr::Tuple(ExprTuple { elts, .. }))
1915 | Some(Expr::List(ExprList { elts, .. })) => elts.iter().collect(),
1916 _ if matches!(op_kind, Method) => call.args.iter().collect(),
1919 _ => Vec::new(),
1920 }
1921 }
1922 Transpose::Explicit => call.args.iter().skip(offset).take(2).collect(),
1924 Transpose::T => Vec::new(),
1926 };
1927 let base_hint = infer_expr_shape(
1928 base,
1929 vars,
1930 func_map,
1931 imports,
1932 class_map,
1933 call_stack,
1934 diagnostics,
1935 hover_entries,
1936 false,
1937 source,
1938 module_cache.as_deref_mut(),
1939 module_path,
1940 );
1941 infer_permute(
1942 base,
1943 &order_args,
1944 transpose,
1945 base_hint,
1946 vars,
1947 diagnostics,
1948 hover_entries,
1949 record_hovers,
1950 source,
1951 call.range,
1952 )
1953 }
1954 (TorchOp::Unsqueeze, Some(base), _, _) => {
1955 let dim_arg = get_arg(call, "dim", offset);
1956 infer_unsqueeze(
1957 base,
1958 dim_arg,
1959 vars,
1960 func_map,
1961 imports,
1962 class_map,
1963 call_stack,
1964 diagnostics,
1965 hover_entries,
1966 record_hovers,
1967 source,
1968 module_cache.as_deref_mut(),
1969 module_path,
1970 )
1971 }
1972 (TorchOp::Conv(d), Some(base), _, Function) => {
1973 let kernel = lookup_shape(get_arg(call, "weight", 1)?, vars, hover_entries, record_hovers, source)?;
1974 lookup_shape(base, vars, hover_entries, record_hovers, source).and_then(|shape| infer_conv(shape, kernel.dims, call, d, diagnostics, source))
1975 }
1976 (TorchOp::Repeat, Some(base), _, Method) => {
1977 lookup_shape(base, vars, hover_entries, record_hovers, source).and_then(|shape| infer_repeat(shape, call, vars, diagnostics, source))
1978 }
1979 (TorchOp::RepeatInterleave, Some(base), _, _) => {
1980 lookup_shape(base, vars, hover_entries, record_hovers, source).and_then(|shape| {
1981 infer_repeat_interleave(shape, call, offset, diagnostics, source)
1982 })
1983 }
1984 (TorchOp::Flatten, Some(base), _, _) => {
1985 lookup_shape(base, vars, hover_entries, record_hovers, source).and_then(|shape| infer_flatten(shape, offset, call, vars, diagnostics, source))
1987 }
1988 (TorchOp::Unknown, _, _, Function) => {
1989 None
1992 }
1993 _ => None,
1995 }
1996}
1997
1998fn infer_tuple_elements(
1999 expr: &Expr,
2000 vars: &HashMap<Identifier, VarState>,
2001 func_map: &FuncMap,
2002 imports: &Imports,
2003 class_map: &ClassMap,
2004 call_stack: &mut Vec<Identifier>,
2005 diagnostics: &mut Vec<Diagnostic>,
2006 hover_entries: &mut Vec<(Range, HoverInfo)>,
2007 record_hovers: bool,
2008 source: &str,
2009 mut module_cache: Option<&mut ModuleCache>,
2010 module_path: Option<&Path>,
2011) -> Option<Vec<Option<Shape>>> {
2012 match expr {
2013 Expr::Tuple(tuple) => Some(
2014 tuple
2015 .elts
2016 .iter()
2017 .map(|elt| {
2018 infer_expr_shape(
2019 elt,
2020 vars,
2021 func_map,
2022 imports,
2023 class_map,
2024 call_stack,
2025 diagnostics,
2026 hover_entries,
2027 record_hovers,
2028 source,
2029 module_cache.as_deref_mut(),
2030 module_path,
2031 )
2032 })
2033 .collect(),
2034 ),
2035 Expr::Call(call) => infer_defined_call_return(
2036 call,
2037 vars,
2038 func_map,
2039 imports,
2040 class_map,
2041 call_stack,
2042 diagnostics,
2043 hover_entries,
2044 record_hovers,
2045 source,
2046 module_cache,
2047 module_path,
2048 )
2049 .and_then(|ret| ret.tuple().map(|tuple| tuple.to_vec())),
2050 _ => None,
2051 }
2052}
2053
2054fn method_param_offset(args: &Arguments) -> usize {
2055 match args.args.first() {
2056 Some(param) if param.def.arg.as_str() == "self" => 1,
2057 _ => 0,
2058 }
2059}
2060
2061fn class_ref_from_constructor_call(
2062 call_expr: &Expr,
2063 class_map: &ClassMap,
2064 imports: &Imports,
2065 mut module_cache: Option<&mut ModuleCache>,
2066 module_path: Option<&Path>,
2067) -> Option<ClassRef> {
2068 let Expr::Call(call) = call_expr else {
2069 return None;
2070 };
2071 match call.func.as_ref() {
2072 Expr::Name(name) => {
2073 if class_map.contains_key(&name.id) {
2074 return Some(ClassRef {
2075 name: name.id.clone(),
2076 module: None,
2077 });
2078 }
2079 if let Some((module_name, original)) = imports.from_imports.get(&name.id)
2080 && module_name != "torch"
2081 && !module_name.starts_with("torch.")
2082 && let (Some(cache), Some(cur_path)) = (module_cache.as_deref_mut(), module_path)
2083 && let Some(module) = cache.get_module(module_name, cur_path)
2084 && module.class_map.contains_key(original)
2085 {
2086 return Some(ClassRef {
2087 name: original.clone(),
2088 module: Some(module_name.to_string()),
2089 });
2090 }
2091 }
2092 Expr::Attribute(attr) => {
2093 if let Expr::Name(module_ident) = attr.value.as_ref()
2094 && let Some(module_name) = imports.module_aliases.get(&module_ident.id)
2095 && module_name != "torch"
2096 && !module_name.starts_with("torch.")
2097 && let (Some(cache), Some(cur_path)) = (module_cache, module_path)
2098 && let Some(module) = cache.get_module(module_name, cur_path)
2099 && module.class_map.contains_key(&attr.attr)
2100 {
2101 return Some(ClassRef {
2102 name: attr.attr.clone(),
2103 module: Some(module_name.to_string()),
2104 });
2105 }
2106 }
2107 _ => {}
2108 }
2109 None
2110}
2111
2112fn class_ref_from_annotation(
2113 ann: &Expr,
2114 imports: &Imports,
2115 class_map: &ClassMap,
2116) -> Option<ClassRef> {
2117 if let Expr::Name(name) = ann {
2118 if class_map.contains_key(&name.id) {
2119 return Some(ClassRef {
2120 name: name.id.clone(),
2121 module: None,
2122 });
2123 }
2124 if let Some((module_name, original)) = imports.from_imports.get(&name.id) {
2125 return Some(ClassRef {
2126 name: original.clone(),
2127 module: Some(module_name.clone()),
2128 });
2129 }
2130 }
2131 None
2132}
2133
2134fn union_members<'a>(expr: &'a Expr, out: &mut Vec<&'a Expr>) {
2135 if let Expr::BinOp(ExprBinOp {
2136 left, op, right, ..
2137 }) = expr
2138 && matches!(op, Operator::BitOr)
2139 {
2140 union_members(left, out);
2141 union_members(right, out);
2142 } else {
2143 out.push(expr);
2144 }
2145}
2146
2147fn shape_or_class_from_union(
2148 ann: &Expr,
2149 imports: &Imports,
2150 class_map: &ClassMap,
2151) -> (Option<Shape>, Option<ClassRef>) {
2152 let mut members = Vec::new();
2153 union_members(ann, &mut members);
2154 members
2155 .into_iter()
2156 .find_map(|member| {
2157 parse_shape_annotation(member)
2158 .map(|shape| (Some(shape), None))
2159 .or_else(|| {
2160 class_ref_from_annotation(member, imports, class_map)
2161 .map(|class_ref| (None, Some(class_ref)))
2162 })
2163 })
2164 .unwrap_or((None, None))
2165}
2166
2167fn tuple_shapes_from_annotation(
2168 ann: &Expr,
2169 imports: &Imports,
2170 class_map: &ClassMap,
2171) -> Option<Vec<Option<Shape>>> {
2172 let elements: Vec<&Expr> = match ann {
2173 Expr::Tuple(t) => t.elts.iter().collect(),
2174 Expr::Subscript(sub) => {
2175 let is_tuple = name_like(&sub.value)
2176 .map(|n| n.eq_ignore_ascii_case("tuple"))
2177 .unwrap_or(false);
2178 if !is_tuple {
2179 return None;
2180 }
2181 match &*sub.slice {
2182 Expr::Tuple(t) => t.elts.iter().collect(),
2183 other => vec![other],
2184 }
2185 }
2186 _ => return None,
2187 };
2188 let tuple_shapes: Vec<Option<Shape>> = elements
2189 .iter()
2190 .map(|elt| {
2191 parse_shape_annotation(elt)
2192 .or_else(|| shape_or_class_from_union(elt, imports, class_map).0)
2193 })
2194 .collect();
2195 if tuple_shapes.iter().any(|s| s.is_some()) {
2196 Some(tuple_shapes)
2197 } else {
2198 None
2199 }
2200}
2201
2202fn infer_call_return_from_info(
2204 call: &ExprCall<TextRange>,
2205 callee_name: &Identifier,
2206 callee_info: &FunctionInfo,
2207 callee_source: &str,
2208 callee_func_map: &FuncMap,
2209 callee_imports: &Imports,
2210 callee_class_map: &ClassMap,
2211 vars: &HashMap<Identifier, VarState>,
2212 func_map: &FuncMap,
2213 imports: &Imports,
2214 class_map: &ClassMap,
2215 call_stack: &mut Vec<Identifier>,
2216 diagnostics: &mut Vec<Diagnostic>,
2217 hover_entries: &mut Vec<(Range, HoverInfo)>,
2218 record_hovers: bool,
2219 source: &str,
2220 mut module_cache: Option<&mut ModuleCache>,
2221 module_path: Option<&Path>,
2222 param_offset: usize,
2223 emit_body_diagnostics: bool,
2224) -> Option<ReturnValue> {
2225 if call_stack.iter().any(|id| id == callee_name) {
2226 return None;
2227 }
2228 let mut arg_shapes: HashMap<Identifier, VarState> = HashMap::new();
2229 for (idx, param) in callee_info.args.args.iter().enumerate().skip(param_offset) {
2230 let call_idx = idx.saturating_sub(param_offset);
2231 if let Some(arg_expr) = call.args.get(call_idx) {
2232 if let Some(shape) = infer_expr_shape(
2233 arg_expr,
2234 vars,
2235 func_map,
2236 imports,
2237 class_map,
2238 call_stack,
2239 diagnostics,
2240 hover_entries,
2241 record_hovers,
2242 source,
2243 module_cache.as_deref_mut(),
2244 module_path,
2245 ) {
2246 if let Some(ann) = param
2247 .def
2248 .annotation
2249 .as_deref()
2250 .and_then(parse_shape_annotation)
2251 {
2252 let dims_match = ann.dims.len() == shape.dims.len()
2253 && ann.dims.iter().zip(shape.dims.iter()).all(|(left, right)| {
2254 match (left.parse::<i32>().is_ok(), right.parse::<i32>().is_ok()) {
2258 (true, true) => left == right,
2259 _ => true,
2260 }
2261 });
2262 if !dims_match {
2263 diagnostics.push(Diagnostic {
2264 range: text_range_to_lsp(expr_text_range(arg_expr), source),
2265 severity: Some(DiagnosticSeverity::ERROR),
2266 code: None,
2267 code_description: None,
2268 source: Some("shapels".into()),
2269 message: format!(
2270 "Shape mismatch: annotation {} vs inferred {}",
2271 ann.render(),
2272 shape.render()
2273 ),
2274 related_information: None,
2275 tags: None,
2276 data: None,
2277 });
2278 }
2279 }
2280 arg_shapes.insert(
2281 param.def.arg.clone(),
2282 VarState {
2283 annotated: None,
2284 inferred: Some(shape),
2285 class_ref: None,
2286 },
2287 );
2288 } else if let Expr::Name(name) = arg_expr
2289 && let Some(class_ref) = vars.get(&name.id).and_then(|v| v.class_ref.clone())
2290 {
2291 arg_shapes.insert(
2292 param.def.arg.clone(),
2293 VarState {
2294 annotated: None,
2295 inferred: None,
2296 class_ref: Some(class_ref),
2297 },
2298 );
2299 }
2300 }
2301 }
2302 if let Some(ret_ann) = callee_info.returns.as_deref() {
2303 let annotated_return = if let Some(ret_shape) = parse_shape_annotation(ret_ann) {
2304 Some(ReturnValue::from_shape(Some(ret_shape)))
2305 } else if let Some(tuple_shapes) = tuple_shapes_from_annotation(ret_ann, imports, class_map)
2306 {
2307 Some(ReturnValue::from_tuple(tuple_shapes))
2308 } else {
2309 let (shape_union, _) = shape_or_class_from_union(ret_ann, imports, class_map);
2310 shape_union.map(|shape| ReturnValue::from_shape(Some(shape)))
2311 };
2312 if let Some(ret) = annotated_return {
2313 if emit_body_diagnostics {
2314 call_stack.push(callee_name.clone());
2315 let (mut diag, mut hovers, _) = simulate_function(
2316 callee_info.args.as_ref(),
2317 &callee_info.body,
2318 callee_source,
2319 callee_func_map,
2320 callee_imports,
2321 callee_class_map,
2322 call_stack,
2323 arg_shapes,
2324 record_hovers,
2325 module_cache.as_deref_mut(),
2326 module_path,
2327 );
2328 diagnostics.append(&mut diag);
2329 if record_hovers {
2330 hover_entries.append(&mut hovers);
2331 }
2332 call_stack.pop();
2333 }
2334 return Some(ret);
2335 }
2336 }
2337 call_stack.push(callee_name.clone());
2338 let (mut diag, mut hovers, ret_value) = simulate_function(
2339 callee_info.args.as_ref(),
2340 &callee_info.body,
2341 callee_source,
2342 callee_func_map,
2343 callee_imports,
2344 callee_class_map,
2345 call_stack,
2346 arg_shapes,
2347 false,
2348 module_cache,
2349 module_path,
2350 );
2351 diagnostics.append(&mut diag);
2352 if record_hovers {
2353 hover_entries.append(&mut hovers);
2354 }
2355 call_stack.pop();
2356 Some(ret_value)
2357}
2358
2359fn infer_class_call_return(
2360 call: &ExprCall<TextRange>,
2361 class_ref: &ClassRef,
2362 vars: &HashMap<Identifier, VarState>,
2363 func_map: &FuncMap,
2364 imports: &Imports,
2365 class_map: &ClassMap,
2366 call_stack: &mut Vec<Identifier>,
2367 diagnostics: &mut Vec<Diagnostic>,
2368 hover_entries: &mut Vec<(Range, HoverInfo)>,
2369 record_hovers: bool,
2370 source: &str,
2371 mut module_cache: Option<&mut ModuleCache>,
2372 module_path: Option<&Path>,
2373) -> Option<ReturnValue> {
2374 with_class_info(
2375 class_ref,
2376 source,
2377 func_map,
2378 imports,
2379 class_map,
2380 &mut module_cache,
2381 module_path,
2382 |class_info,
2383 callee_source,
2384 callee_func_map,
2385 callee_imports,
2386 callee_class_map,
2387 callee_path,
2388 module_cache| {
2389 if !class_info.is_torch_module {
2390 return None;
2391 }
2392 let forward = class_info.forward.as_ref()?;
2393 let param_offset = method_param_offset(&forward.args);
2394 infer_call_return_from_info(
2395 call,
2396 &class_ref.name,
2397 forward,
2398 callee_source,
2399 callee_func_map,
2400 callee_imports,
2401 callee_class_map,
2402 vars,
2403 func_map,
2404 imports,
2405 class_map,
2406 call_stack,
2407 diagnostics,
2408 hover_entries,
2409 record_hovers,
2410 source,
2411 module_cache.as_deref_mut(),
2412 callee_path,
2413 param_offset,
2414 false,
2415 )
2416 },
2417 )
2418 .and_then(|ret| ret)
2419}
2420
2421fn infer_class_method_call_return(
2422 call: &ExprCall<TextRange>,
2423 class_ref: &ClassRef,
2424 method_name: &Identifier,
2425 vars: &HashMap<Identifier, VarState>,
2426 func_map: &FuncMap,
2427 imports: &Imports,
2428 class_map: &ClassMap,
2429 call_stack: &mut Vec<Identifier>,
2430 diagnostics: &mut Vec<Diagnostic>,
2431 hover_entries: &mut Vec<(Range, HoverInfo)>,
2432 record_hovers: bool,
2433 source: &str,
2434 mut module_cache: Option<&mut ModuleCache>,
2435 module_path: Option<&Path>,
2436) -> Option<ReturnValue> {
2437 let callee_name = Identifier::from(format!("{}::{}", class_ref.name, method_name));
2438 with_class_info(
2439 class_ref,
2440 source,
2441 func_map,
2442 imports,
2443 class_map,
2444 &mut module_cache,
2445 module_path,
2446 |class_info,
2447 callee_source,
2448 callee_func_map,
2449 callee_imports,
2450 callee_class_map,
2451 callee_path,
2452 module_cache| {
2453 let method_info = class_info.methods.get(method_name)?;
2454 let param_offset = method_param_offset(&method_info.args);
2455 infer_call_return_from_info(
2456 call,
2457 &callee_name,
2458 method_info,
2459 callee_source,
2460 callee_func_map,
2461 callee_imports,
2462 callee_class_map,
2463 vars,
2464 func_map,
2465 imports,
2466 class_map,
2467 call_stack,
2468 diagnostics,
2469 hover_entries,
2470 record_hovers,
2471 source,
2472 module_cache.as_deref_mut(),
2473 callee_path,
2474 param_offset,
2475 false,
2476 )
2477 },
2478 )
2479 .and_then(|ret| ret)
2480}
2481
2482fn infer_defined_call_return(
2483 call: &ExprCall<TextRange>,
2484 vars: &HashMap<Identifier, VarState>,
2485 func_map: &FuncMap,
2486 imports: &Imports,
2487 class_map: &ClassMap,
2488 call_stack: &mut Vec<Identifier>,
2489 diagnostics: &mut Vec<Diagnostic>,
2490 hover_entries: &mut Vec<(Range, HoverInfo)>,
2491 record_hovers: bool,
2492 source: &str,
2493 mut module_cache: Option<&mut ModuleCache>,
2494 module_path: Option<&Path>,
2495) -> Option<ReturnValue> {
2496 if let Expr::Name(func_name) = call.func.as_ref() {
2497 if let Some(class_ref) = vars.get(&func_name.id).and_then(|v| v.class_ref.as_ref()) {
2498 return infer_class_call_return(
2499 call,
2500 class_ref,
2501 vars,
2502 func_map,
2503 imports,
2504 class_map,
2505 call_stack,
2506 diagnostics,
2507 hover_entries,
2508 record_hovers,
2509 source,
2510 module_cache.as_deref_mut(),
2511 module_path,
2512 );
2513 }
2514 if let Some((module_name, original)) = imports.from_imports.get(&func_name.id)
2515 && let (Some(cache), Some(cur_path)) = (module_cache.as_deref_mut(), module_path)
2516 && let Some(module) = cache.get_module(module_name, cur_path)
2517 && let Some(callee_info) = module.func_map.get(original)
2518 {
2519 return infer_call_return_from_info(
2520 call,
2521 &func_name.id,
2522 callee_info,
2523 &module.source,
2524 &module.func_map,
2525 &module.imports,
2526 &module.class_map,
2527 vars,
2528 func_map,
2529 imports,
2530 class_map,
2531 call_stack,
2532 diagnostics,
2533 hover_entries,
2534 record_hovers,
2535 source,
2536 module_cache.as_deref_mut(),
2537 Some(module.path.as_path()),
2538 0,
2539 false,
2540 );
2541 }
2542 if let Some(callee_info) = func_map.get(&func_name.id) {
2543 return infer_call_return_from_info(
2544 call,
2545 &func_name.id,
2546 callee_info,
2547 source,
2548 func_map,
2549 imports,
2550 class_map,
2551 vars,
2552 func_map,
2553 imports,
2554 class_map,
2555 call_stack,
2556 diagnostics,
2557 hover_entries,
2558 record_hovers,
2559 source,
2560 module_cache.as_deref_mut(),
2561 module_path,
2562 0,
2563 false,
2564 );
2565 }
2566 }
2567 if let Expr::Attribute(attr) = call.func.as_ref()
2568 && let Expr::Name(base_name) = attr.value.as_ref()
2569 && let Some(class_ref) = vars.get(&base_name.id).and_then(|v| v.class_ref.as_ref())
2570 {
2571 return infer_class_method_call_return(
2572 call,
2573 class_ref,
2574 &attr.attr,
2575 vars,
2576 func_map,
2577 imports,
2578 class_map,
2579 call_stack,
2580 diagnostics,
2581 hover_entries,
2582 record_hovers,
2583 source,
2584 module_cache.as_deref_mut(),
2585 module_path,
2586 );
2587 }
2588 if let Expr::Attribute(attr) = call.func.as_ref()
2589 && let Expr::Name(module_ident) = attr.value.as_ref()
2590 && let Some(module_name) = imports.module_aliases.get(&module_ident.id)
2591 && module_name != "torch"
2592 && let (Some(cache), Some(cur_path)) = (module_cache.as_deref_mut(), module_path)
2593 && let Some(module) = cache.get_module(module_name, cur_path)
2594 && let Some(callee_info) = module.func_map.get(&attr.attr)
2595 {
2596 return infer_call_return_from_info(
2597 call,
2598 &attr.attr,
2599 callee_info,
2600 &module.source,
2601 &module.func_map,
2602 &module.imports,
2603 &module.class_map,
2604 vars,
2605 func_map,
2606 imports,
2607 class_map,
2608 call_stack,
2609 diagnostics,
2610 hover_entries,
2611 record_hovers,
2612 source,
2613 module_cache,
2614 Some(module.path.as_path()),
2615 0,
2616 false,
2617 );
2618 }
2619 None
2620}
2621
2622fn tensor_or_shape_as_arg(
2623 is_size: bool,
2624 vars: &HashMap<Identifier, VarState>,
2625 func_map: &FuncMap,
2626 imports: &Imports,
2627 class_map: &ClassMap,
2628 call_stack: &mut Vec<Identifier>,
2629 diagnostics: &mut Vec<Diagnostic>,
2630 hover_entries: &mut Vec<(Range, HoverInfo)>,
2631 record_hovers: bool,
2632 source: &str,
2633 module_cache: &mut Option<&mut ModuleCache>,
2634 module_path: Option<&Path>,
2635 call: &ExprCall,
2636) -> Option<Shape> {
2637 match call.args.first() {
2638 Some(arg0 @ Expr::Attribute(_)) if is_size => infer_expr_shape(
2640 arg0,
2641 vars,
2642 func_map,
2643 imports,
2644 class_map,
2645 call_stack,
2646 diagnostics,
2647 hover_entries,
2648 record_hovers,
2649 source,
2650 module_cache.as_deref_mut(),
2651 module_path,
2652 ),
2653 Some(arg0 @ Expr::Name(_)) if !is_size => infer_expr_shape(
2655 arg0,
2656 vars,
2657 func_map,
2658 imports,
2659 class_map,
2660 call_stack,
2661 diagnostics,
2662 hover_entries,
2663 record_hovers,
2664 source,
2665 module_cache.as_deref_mut(),
2666 module_path,
2667 ),
2668 _ => None,
2669 }
2670}
2671
2672fn lookup_shape(
2673 expr: &Expr,
2674 vars: &HashMap<Identifier, VarState>,
2675 hover_entries: &mut Vec<(Range, HoverInfo)>,
2676 record_hovers: bool,
2677 source: &str,
2678) -> Option<Shape> {
2679 match expr {
2680 Expr::Name(expr_name) => {
2681 let shape = vars
2682 .get(&expr_name.id)
2683 .and_then(|v| v.annotated.clone().or_else(|| v.inferred.clone()));
2684 if record_hovers && let Some(s) = shape.clone() {
2685 let range = text_range_to_lsp(expr_text_range(expr), source);
2686 hover_entries.push((range, HoverInfo { shape: Some(s) }));
2687 }
2688 shape
2689 }
2690 _ => None,
2691 }
2692}
2693
/// How a torch operation was invoked, which determines where its operands are found.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TorchOpKind {
    /// Called as a free function on a torch namespace (`torch.op(x, ...)` or via an alias);
    /// the tensor operand is the first positional argument.
    Function,
    /// Called as a method on a value (`x.op(...)`); the tensor operand is the receiver.
    Method,
}
2702
2703fn function_or_method<R>(expr: &Expr<R>, imports: &Imports) -> TorchOpKind {
2704 match expr {
2705 Expr::Name(n)
2706 if n.id.as_str() == "torch"
2707 || imports.torch_aliases.contains(&n.id)
2708 || imports.torch_nn_functional_aliases.contains(&n.id) =>
2709 {
2710 TorchOpKind::Function
2711 }
2712 _ => TorchOpKind::Method,
2713 }
2714}
2715
2716fn is_alias_of(canonical: &str, ident: &Identifier, imports: &Imports) -> bool {
2717 imports
2718 .func_aliases
2719 .get(canonical)
2720 .map(|set| set.contains(ident))
2721 .unwrap_or(false)
2722}
2723
2724fn assignment_shape_checks(
2725 target: &Expr,
2726 value: &Expr,
2727 vars: &mut HashMap<Identifier, VarState>,
2728 diagnostics: &mut Vec<Diagnostic>,
2729 hover_entries: &mut Vec<(Range, HoverInfo)>,
2730 record_hovers: bool,
2731 source: &str,
2732) -> bool {
2733 let Expr::Tuple(tup) = target else {
2734 return false;
2735 };
2736 let dims: Vec<Identifier> = tup.elts.iter().filter_map(name_from_expr).collect();
2737 if dims.len() != tup.elts.len() || dims.is_empty() {
2738 return false;
2739 }
2740 let Expr::Attribute(attr) = value else {
2741 return false;
2742 };
2743 if attr.attr.as_str() != "shape" {
2744 return false;
2745 }
2746 let Some(base_id) = name_from_expr(&attr.value) else {
2747 return false;
2748 };
2749
2750 let range = text_range_to_lsp(expr_text_range(value), source);
2751 let existing_state = vars.get(&base_id);
2752 let existing_shape = existing_state.and_then(|v| v.annotated.clone().or(v.inferred.clone()));
2753
2754 if let Some(shape) = existing_shape {
2755 if shape.dims.len() == dims.len() {
2756 let mut new_shape = shape.clone();
2757 new_shape.dims = dims.iter().map(|d| d.to_string()).collect();
2758 vars.insert(
2759 base_id.clone(),
2760 VarState {
2761 annotated: None,
2762 inferred: Some(new_shape.clone()),
2763 class_ref: None,
2764 },
2765 );
2766 if record_hovers {
2767 let hrange = text_range_to_lsp(expr_text_range(&attr.value), source);
2768 hover_entries.push((
2769 hrange,
2770 HoverInfo {
2771 shape: Some(new_shape),
2772 },
2773 ));
2774 }
2775 } else {
2776 diagnostics.push(Diagnostic {
2777 range,
2778 severity: Some(DiagnosticSeverity::ERROR),
2779 code: None,
2780 code_description: None,
2781 source: Some("shapels".into()),
2782 message: "Cannot unroll shape with different rank".into(),
2783 related_information: None,
2784 tags: None,
2785 data: None,
2786 });
2787 }
2788 } else {
2789 let new_shape = Shape {
2790 dtype: None,
2791 dims: dims.iter().map(|d| d.to_string()).collect(),
2792 };
2793 vars.insert(
2794 base_id.clone(),
2795 VarState {
2796 annotated: None,
2797 inferred: Some(new_shape.clone()),
2798 class_ref: None,
2799 },
2800 );
2801 if record_hovers {
2802 let hrange = text_range_to_lsp(expr_text_range(&attr.value), source);
2803 hover_entries.push((
2804 hrange,
2805 HoverInfo {
2806 shape: Some(new_shape),
2807 },
2808 ));
2809 }
2810 }
2811 true
2812}
2813
2814fn parse_shape_annotation(expr: &Expr) -> Option<Shape> {
2815 if let Expr::Subscript(sub) = expr {
2816 let dtype = name_like(&sub.value);
2817 let components: Vec<&Expr> = match &*sub.slice {
2818 ast::Expr::Tuple(t) => t.elts.iter().collect(),
2819 other => vec![other],
2820 };
2821 if components.len() >= 2
2822 && let Some(raw) = string_literal_value(components[1])
2823 {
2824 let dims = normalize_shape_tokens(&raw);
2825 return Some(Shape { dtype, dims });
2826 }
2827 }
2828 None
2829}
2830
2831fn name_like(expr: &Expr) -> Option<String> {
2832 match expr {
2833 Expr::Name(name) => Some(name.id.to_string()),
2834 Expr::Attribute(attr) => {
2835 let mut base = name_like(&attr.value)?;
2836 base.push('.');
2837 base.push_str(attr.attr.as_ref());
2838 Some(base)
2839 }
2840 _ => None,
2841 }
2842}
2843
2844fn string_literal_value(expr: &Expr) -> Option<String> {
2845 match expr {
2846 Expr::Constant(c) => {
2847 if let ast::Constant::Str(s) = &c.value {
2848 Some(s.clone())
2849 } else {
2850 None
2851 }
2852 }
2853 Expr::JoinedStr(js) => {
2854 let mut buf = String::new();
2855 for val in &js.values {
2856 if let Expr::Constant(c) = val
2857 && let ast::Constant::Str(s) = &c.value
2858 {
2859 buf.push_str(s);
2860 }
2861 }
2862 if buf.is_empty() { None } else { Some(buf) }
2863 }
2864 _ => None,
2865 }
2866}
2867
/// Splits a shape string on whitespace into dimension names, removing any leading or
/// trailing `"` characters from each token.
fn normalize_shape_tokens(raw: &str) -> Vec<String> {
    raw.split_whitespace()
        .map(|token| token.trim_matches('"'))
        .map(str::to_owned)
        .collect()
}
2873
2874fn name_from_expr(expr: &Expr) -> Option<Identifier> {
2875 match expr {
2876 Expr::Name(n) => Some(n.id.clone()),
2877 _ => None,
2878 }
2879}
2880
2881fn seed_args_from_annotations(
2882 args: &Arguments,
2883 source: &str,
2884 vars: &mut HashMap<Identifier, VarState>,
2885 hover_entries: &mut Vec<(Range, HoverInfo)>,
2886 provided: Option<&mut HashMap<Identifier, VarState>>,
2887 record_hovers: bool,
2888 imports: &Imports,
2889 class_map: &ClassMap,
2890) {
2891 for arg in &args.args {
2892 let ann_shape = arg
2893 .def
2894 .annotation
2895 .as_ref()
2896 .and_then(|expr| parse_shape_annotation(expr.as_ref()));
2897 let (union_shape, union_class) = arg
2898 .def
2899 .annotation
2900 .as_ref()
2901 .map(|ann| shape_or_class_from_union(ann.as_ref(), imports, class_map))
2902 .unwrap_or((None, None));
2903 let range = text_range_to_lsp(arg.def.range, source);
2904 let provided_state = provided.as_ref().and_then(|p| p.get(&arg.def.arg));
2905 let state = VarState {
2906 annotated: ann_shape.clone().or(union_shape),
2907 inferred: provided_state.and_then(|s| s.inferred.clone()),
2908 class_ref: provided_state
2909 .and_then(|s| s.class_ref.clone())
2910 .or_else(|| {
2911 arg.def
2912 .annotation
2913 .as_ref()
2914 .and_then(|ann| class_ref_from_annotation(ann.as_ref(), imports, class_map))
2915 })
2916 .or(union_class),
2917 };
2918
2919 if state.annotated.is_some() || state.inferred.is_some() || state.class_ref.is_some() {
2920 vars.insert(arg.def.arg.clone(), state.clone());
2921 if record_hovers {
2922 hover_entries.push((
2923 range,
2924 HoverInfo {
2925 shape: state.annotated.or(state.inferred),
2926 },
2927 ));
2928 }
2929 }
2930 }
2931}
2932
2933fn text_range_to_lsp(range: TextRange, source: &str) -> Range {
2934 Range {
2935 start: offset_to_position(source, range.start().to_usize()),
2936 end: offset_to_position(source, range.end().to_usize()),
2937 }
2938}
2939
2940fn expr_text_range(expr: &Expr) -> TextRange {
2941 match expr {
2942 Expr::Name(n) => n.range,
2943 Expr::BoolOp(b) => b.range,
2944 Expr::NamedExpr(n) => n.range,
2945 Expr::BinOp(b) => b.range,
2946 Expr::UnaryOp(u) => u.range,
2947 Expr::Lambda(l) => l.range,
2948 Expr::IfExp(i) => i.range,
2949 Expr::Dict(d) => d.range,
2950 Expr::Set(s) => s.range,
2951 Expr::ListComp(l) => l.range,
2952 Expr::SetComp(s) => s.range,
2953 Expr::DictComp(d) => d.range,
2954 Expr::GeneratorExp(g) => g.range,
2955 Expr::Await(a) => a.range,
2956 Expr::Yield(y) => y.range,
2957 Expr::YieldFrom(y) => y.range,
2958 Expr::Compare(c) => c.range,
2959 Expr::Call(c) => c.range,
2960 Expr::FormattedValue(f) => f.range,
2961 Expr::Subscript(s) => s.range,
2962 Expr::Attribute(a) => a.range,
2963 Expr::Starred(s) => s.range,
2964 Expr::Constant(c) => c.range,
2965 Expr::JoinedStr(j) => j.range,
2966 Expr::List(l) => l.range,
2967 Expr::Tuple(t) => t.range,
2968 Expr::Slice(s) => s.range,
2969 }
2970}
2971
2972fn offset_to_position(source: &str, offset: usize) -> Position {
2973 let mut line = 0u32;
2974 let mut col = 0u32;
2975 let mut count = 0usize;
2976 for ch in source.chars() {
2977 if count == offset {
2978 break;
2979 }
2980 if ch == '\n' {
2981 line += 1;
2982 col = 0;
2983 } else {
2984 col += 1;
2985 }
2986 count += ch.len_utf8();
2987 }
2988 Position {
2989 line,
2990 character: col,
2991 }
2992}
2993
2994fn default_range() -> Range {
2995 Range {
2996 start: Position {
2997 line: 0,
2998 character: 0,
2999 },
3000 end: Position {
3001 line: 0,
3002 character: 1,
3003 },
3004 }
3005}
3006
3007fn get_arg<'a, R>(
3008 call: &'a ExprCall<R>,
3009 name_arg: &str,
3010 as_positional: usize,
3011) -> Option<&'a Expr<R>> {
3012 (call.args.get(as_positional)).or(call
3014 .keywords
3015 .iter()
3016 .find(|kw| kw.arg.as_deref() == Some(name_arg))
3017 .map(|kw| &kw.value))
3018}
3019
3020fn get_dtype<'expr, R>(dtype_expr: &'expr Expr<R>, imports: &Imports) -> Option<&'expr str> {
3021 match dtype_expr {
3022 Expr::Constant(constant) => match &constant.value {
3023 Constant::Str(string_dtype) if TORCH_DTYPES.contains(string_dtype.as_str()) => {
3024 Some(string_dtype.as_str())
3025 }
3026 _ => Some("Float"),
3027 },
3028 Expr::Name(name) => Some(name.id.as_str()),
3029 Expr::Attribute(attr)
3030 if matches!(
3031 function_or_method(attr.value.as_ref(), imports),
3032 TorchOpKind::Function
3033 ) && TORCH_DTYPES.contains(attr.attr.as_str()) =>
3034 {
3035 Some(attr.attr.as_str())
3036 }
3037 _ => Some("Float"),
3038 }
3039}