shapels/
lib.rs

1#![allow(clippy::too_many_arguments)]
2
3use lsp_types::{Diagnostic, DiagnosticSeverity, Position, Range};
4use rustpython_parser::Parse;
5use rustpython_parser::ast::{
6    self, Arguments, Constant, Expr, ExprBinOp, ExprCall, ExprCompare, ExprList, ExprTuple,
7    Identifier, Operator, Stmt, Suite,
8};
9use rustpython_parser::text_size::{TextRange, TextSize};
10use std::borrow::Cow;
11use std::collections::HashMap;
12use std::fs;
13use std::path::{Path, PathBuf};
14mod infer;
15pub mod op_groups;
16use crate::infer::{
17    ShapeOrExpr, Transpose, infer_broadcastable_poswise, infer_conv, infer_creation_size,
18    infer_flatten, infer_index, infer_matmul_shapes, infer_noop, infer_permute, infer_range_size,
19    infer_repeat, infer_repeat_interleave, infer_squeeze, infer_to, infer_unsqueeze,
20    infer_view_like, shape_dims_equal,
21};
22pub use crate::op_groups::AGGR_ALIASES;
23use crate::op_groups::{BroadcastOp, Imports, TORCH_DTYPES, TorchOp, collect_imports};
24
/// A tensor shape as tracked by the analyzer: an optional dtype plus symbolic dims.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Shape {
    pub dtype: Option<String>,
    pub dims: Vec<String>,
}

impl Shape {
    /// Render the dims wrapped in brackets, e.g. `[b c]`.
    pub fn render(&self) -> String {
        let inner = self.dim_string();
        format!("[{}]", inner)
    }

    /// Render the dims space-separated without brackets, e.g. `b c`.
    pub fn dim_string(&self) -> String {
        let separator = " ";
        self.dims.join(separator)
    }
}
39
/// Payload attached to a source range for LSP hover requests.
#[derive(Debug, Clone)]
pub struct HoverInfo {
    /// Shape shown on hover; `None` when no shape is known for the range.
    pub shape: Option<Shape>,
}
44
45#[derive(Debug, Default)]
46pub struct Analysis {
47    pub diagnostics: Vec<Diagnostic>,
48    pub hover_entries: Vec<(Range, HoverInfo)>,
49}
50
51impl Analysis {
52    pub fn hover(&self, position: Position) -> Option<&HoverInfo> {
53        // Prefer exact containment with smallest span.
54        if let Some((_, info)) = self
55            .hover_entries
56            .iter()
57            .filter(|(range, _)| within(range, &position))
58            .min_by_key(|(range, _)| range_span(range, &position))
59        {
60            return Some(info);
61        }
62
63        // Fallback: nearest entry on the same line to the left.
64        if let Some((_, info)) = self
65            .hover_entries
66            .iter()
67            .filter(|(range, _)| range.start.line == position.line)
68            .min_by_key(|(range, _)| {
69                let a = range.start.character as i64;
70                let b = position.character as i64;
71                (a - b).abs()
72            })
73        {
74            return Some(info);
75        }
76        None
77    }
78}
79
80fn within(range: &Range, pos: &Position) -> bool {
81    (pos.line > range.start.line
82        || (pos.line == range.start.line && pos.character >= range.start.character))
83        && (pos.line < range.end.line
84            || (pos.line == range.end.line && pos.character <= range.end.character))
85}
86
87fn range_span(range: &Range, _pos: &Position) -> u32 {
88    // Prioritize entries that wrap the position tightly.
89    let line_span = (range.end.line as i64 - range.start.line as i64).unsigned_abs() as u32;
90    let char_span = if range.start.line == range.end.line {
91        (range.end.character as i64 - range.start.character as i64).unsigned_abs() as u32
92    } else {
93        1000
94    };
95    line_span * 1000 + char_span
96}
97
/// What the simulator knows about one local variable.
#[derive(Debug, Clone, Default)]
struct VarState {
    /// Shape parsed from an explicit annotation (e.g. `x: Float[...] = ...`).
    annotated: Option<Shape>,
    /// Shape inferred from the assigned expression.
    inferred: Option<Shape>,
    /// Set when the value was built by a constructor call of a known class.
    class_ref: Option<ClassRef>,
}
104
105/// Output of [`simulate_function`], such that return type
106/// can be recorded and matched against tuple destructuring on
107/// assignment.
108#[derive(Debug, Clone, Default)]
109enum ReturnValue {
110    Single(Shape),
111    Tuple(Vec<Option<Shape>>),
112    #[default]
113    None,
114}
115
116impl ReturnValue {
117    fn from_shape(shape: Option<Shape>) -> Self {
118        shape.map(Self::Single).unwrap_or(Self::None)
119    }
120
121    fn from_tuple(shapes: Vec<Option<Shape>>) -> Self {
122        Self::Tuple(shapes)
123    }
124
125    fn first(&self) -> Option<&Shape> {
126        match self {
127            Self::Single(shape) => Some(shape),
128            Self::Tuple(tuple) => tuple.first().and_then(|x| x.as_ref()),
129            Self::None => None,
130        }
131    }
132
133    fn tuple(&self) -> Option<&[Option<Shape>]> {
134        match self {
135            Self::Tuple(tuple) => Some(tuple.as_ref()),
136            _ => None,
137        }
138    }
139
140    fn is_some(&self) -> bool {
141        match self {
142            Self::Single(_) => true,
143            Self::Tuple(tup) => !tup.is_empty(),
144            Self::None => false,
145        }
146    }
147}
148
/// Owned copy of a Python `def`'s signature and body, kept so calls can be simulated.
#[derive(Clone)]
struct FunctionInfo {
    args: Box<Arguments>,
    body: Vec<Stmt>,
    /// Return annotation expression, if any.
    returns: Option<Box<Expr>>,
}

/// Top-level functions of a module, keyed by name.
type FuncMap = HashMap<Identifier, FunctionInfo>;

/// Reference to a class by name, optionally qualified by the module that defines it.
#[derive(Debug, Clone)]
struct ClassRef {
    name: Identifier,
    /// `None`: look the class up in the current module's class map.
    /// `Some(m)`: load module `m` through the module cache and look it up there.
    module: Option<String>,
}

/// Per-class information collected by `collect_class_defs`.
#[derive(Clone)]
struct ClassInfo {
    /// True when the class derives from `torch.nn.Module`, directly or via
    /// another class defined in the same module.
    is_torch_module: bool,
    /// The method named `forward`, if the class defines one.
    forward: Option<FunctionInfo>,
    /// Every method defined directly in the class body, keyed by name.
    methods: HashMap<Identifier, FunctionInfo>,
}

/// Classes defined at the top level of a module, keyed by name.
type ClassMap = HashMap<Identifier, ClassInfo>;
172
/// A parsed imported module, as stored in [`ModuleCache`].
#[derive(Clone)]
pub(crate) struct CachedModule {
    /// Resolved file path of the module.
    path: PathBuf,
    /// Module text after `normalize_return_annotations`; AST ranges refer to this text.
    source: String,
    func_map: FuncMap,
    imports: Imports,
    class_map: ClassMap,
}
181
182pub(crate) struct ModuleCache {
183    modules: HashMap<String, CachedModule>,
184    project_root: Option<PathBuf>,
185}
186
187impl ModuleCache {
188    fn new(current_file: &Path) -> Self {
189        let project_root = find_project_root(current_file);
190        Self {
191            modules: HashMap::new(),
192            project_root,
193        }
194    }
195
196    fn project_root(&self) -> Option<&Path> {
197        self.project_root.as_deref()
198    }
199
200    /// Return a cloned module entry, loading and parsing it if necessary.
201    fn get_module(&mut self, module_name: &str, current_file: &Path) -> Option<CachedModule> {
202        if let Some(cached) = self.modules.get(module_name) {
203            return Some(cached.clone());
204        }
205        let path = resolve_module_path(module_name, current_file, self.project_root.as_deref())?;
206        let source = fs::read_to_string(&path).ok()?;
207        let normalized = normalize_return_annotations(&source);
208        let source = normalized.into_owned();
209        let module = Suite::parse(&source, module_name).ok()?;
210        let imports = collect_imports(&module, Some(&path), self.project_root());
211        let mut func_map: FuncMap = HashMap::new();
212        collect_function_defs(&module, &mut func_map);
213        let class_map = collect_class_defs(&module, &imports);
214        let cached = CachedModule {
215            path,
216            source,
217            func_map,
218            imports,
219            class_map,
220        };
221        self.modules.insert(module_name.to_string(), cached.clone());
222        Some(cached)
223    }
224}
225
/// Normalize return annotations so `-> A, B` parses as `-> (A, B)`.
///
/// Python accepts only one expression after `->`, but shape-annotated code often
/// writes `def f(x) -> A, B:` for several return values. Such lines are
/// rewritten to `def f(x) -> (A, B):`.
///
/// A line is rewritten only when all of the following hold (see
/// [`tuple_return_annotation`]):
/// - a `->` appears outside string literals and comments;
/// - the text before it ends with `)`, i.e. it closes a `def` signature,
///   including the last line of a multi-line one;
/// - the annotation ends at a `:` outside strings and brackets;
/// - the annotation has a comma at bracket depth 0.
///
/// Every other byte, including `\n` / `\r\n` terminators and a trailing newline,
/// is copied verbatim, so line numbers stay aligned with the input. Rewritten
/// lines grow by two characters. Returns [`Cow::Borrowed`] when nothing changed.
fn normalize_return_annotations<'a>(source: &'a str) -> Cow<'a, str> {
    let mut out: Option<String> = None;
    // Byte offset up to which `source` has already been emitted into `out`.
    let mut copied_upto = 0;
    // Byte offset of the current line's start within `source`.
    let mut offset = 0;
    for raw_line in source.split_inclusive('\n') {
        let line = raw_line.strip_suffix('\n').unwrap_or(raw_line);
        let line = line.strip_suffix('\r').unwrap_or(line);
        if let Some((arrow_end, colon_idx)) = tuple_return_annotation(line) {
            let buf = out.get_or_insert_with(|| String::with_capacity(source.len() + 16));
            buf.push_str(&source[copied_upto..offset]);
            buf.push_str(&line[..arrow_end]);
            buf.push_str(" (");
            buf.push_str(line[arrow_end..colon_idx].trim());
            buf.push(')');
            buf.push_str(&line[colon_idx..]);
            // The line terminator is left uncopied; it goes out with the next chunk.
            copied_upto = offset + line.len();
        }
        offset += raw_line.len();
    }
    match out {
        Some(mut buf) => {
            buf.push_str(&source[copied_upto..]);
            Cow::Owned(buf)
        }
        None => Cow::Borrowed(source),
    }
}

/// Locate a tuple-style return annotation on one line (terminator already stripped).
///
/// Returns `(arrow_end, colon_idx)` such that `line[arrow_end..colon_idx]` is the
/// annotation text. Returns `None` when the line has no `-> A, B:` annotation
/// needing parentheses. The scan is byte-based; every marker it looks for is
/// ASCII, so all returned indices are UTF-8 char boundaries.
fn tuple_return_annotation(line: &str) -> Option<(usize, usize)> {
    let bytes = line.as_bytes();
    let mut quote: Option<u8> = None;

    // Find the first `->` outside string literals; stop at a comment.
    let mut i = 0;
    let arrow_idx = loop {
        let b = *bytes.get(i)?;
        match quote {
            Some(q) => {
                if b == b'\\' {
                    i += 1; // skip the escaped character
                } else if b == q {
                    quote = None;
                }
            }
            None => match b {
                b'"' | b'\'' => quote = Some(b),
                b'#' => return None,
                b'-' if bytes.get(i + 1) == Some(&b'>') => break i,
                _ => {}
            },
        }
        i += 1;
    };

    // Only signature lines qualify: `def f(...) ->` or a continuation `) ->`.
    if !line[..arrow_idx].trim_end().ends_with(')') {
        return None;
    }

    // Find the `:` that ends the annotation, and note any comma at depth 0.
    // Commas inside brackets (e.g. `Dict[str, int]`) do not form a tuple.
    let arrow_end = arrow_idx + 2;
    let mut depth: usize = 0;
    let mut top_level_comma = false;
    let mut j = arrow_end;
    let colon_idx = loop {
        let b = *bytes.get(j)?;
        match quote {
            Some(q) => {
                if b == b'\\' {
                    j += 1;
                } else if b == q {
                    quote = None;
                }
            }
            None => match b {
                b'"' | b'\'' => quote = Some(b),
                b'(' | b'[' | b'{' => depth += 1,
                b')' | b']' | b'}' => depth = depth.saturating_sub(1),
                b',' if depth == 0 => top_level_comma = true,
                b':' if depth == 0 => break j,
                b'#' => return None,
                _ => {}
            },
        }
        j += 1;
    };

    top_level_comma.then_some((arrow_end, colon_idx))
}
260
/// Walk up from the directory containing `start` looking for a project marker.
///
/// A directory counts as the root if it contains `.git`, `pyproject.toml`,
/// `setup.py` or `setup.cfg`, checked in that order. The walk stops after
/// checking `$HOME` (if set), so the search never climbs above the user's home.
fn find_project_root(start: &Path) -> Option<PathBuf> {
    const MARKERS: [&str; 4] = [".git", "pyproject.toml", "setup.py", "setup.cfg"];
    let home = std::env::var_os("HOME").map(PathBuf::from);

    for dir in start.ancestors().skip(1) {
        if MARKERS.iter().any(|marker| dir.join(marker).exists()) {
            return Some(dir.to_path_buf());
        }
        if home.as_deref() == Some(dir) {
            break;
        }
    }
    None
}
279
/// Resolve a class reference across local/module context and run a callback with its info.
///
/// If `class_ref.module` is `Some`, that module is loaded via `module_cache`,
/// resolved relative to `module_path`. `f` then receives that module's source,
/// function map, imports, class map and file path, so nested lookups happen in
/// the defining module's context. If it is `None`, the class is looked up in
/// `class_map` and `f` receives the caller's own context unchanged.
///
/// Returns `None` in these cases:
/// - the class is not found;
/// - a module-qualified reference is given without both a cache and a current path;
/// - the module fails to load.
fn with_class_info<R, F>(
    class_ref: &ClassRef,
    source: &str,
    func_map: &FuncMap,
    imports: &Imports,
    class_map: &ClassMap,
    module_cache: &mut Option<&mut ModuleCache>,
    module_path: Option<&Path>,
    f: F,
) -> Option<R>
where
    F: FnOnce(
        &ClassInfo,
        &str,
        &FuncMap,
        &Imports,
        &ClassMap,
        Option<&Path>,
        &mut Option<&mut ModuleCache>,
    ) -> R,
{
    match &class_ref.module {
        Some(module_name) => {
            // The mutable reborrow of the cache ends with this block. `get_module`
            // returns an owned clone, so `module_cache` is free to pass to `f` below.
            let module = {
                let (Some(cache), Some(cur_path)) = (module_cache.as_deref_mut(), module_path)
                else {
                    return None;
                };
                cache.get_module(module_name, cur_path)?
            };
            let class_info = module.class_map.get(&class_ref.name)?;
            Some(f(
                class_info,
                &module.source,
                &module.func_map,
                &module.imports,
                &module.class_map,
                Some(module.path.as_path()),
                module_cache,
            ))
        }
        None => {
            // Class defined in the current module: reuse the caller's context.
            let class_info = class_map.get(&class_ref.name)?;
            Some(f(
                class_info,
                source,
                func_map,
                imports,
                class_map,
                module_path,
                module_cache,
            ))
        }
    }
}
336
/// Resolve a dotted Python module name (e.g. `pkg.sub.mod`) to a source file.
///
/// Search roots, in priority order:
/// 1. the directory containing `current_file`;
/// 2. `project_root`, then `project_root/src` if it exists;
/// 3. virtual environments found through `PATH` (see [`venv_search_roots`]).
///
/// For each root, `<root>/<parts...>.py` is tried before
/// `<root>/<parts...>/__init__.py`. If the root directory itself is named after
/// the first module component (and the name has several components), the lookup
/// is retried without that component. This avoids requiring `pkg/pkg/mod.py`.
///
/// Returns `None` when no candidate file exists.
fn resolve_module_path(
    module: &str,
    current_file: &Path,
    project_root: Option<&Path>,
) -> Option<PathBuf> {
    let mut search_roots: Vec<PathBuf> = Vec::new();
    if let Some(parent) = current_file.parent() {
        search_roots.push(parent.to_path_buf());
    }
    if let Some(prj) = project_root {
        search_roots.push(prj.to_path_buf());
        let src_dir = prj.join("src");
        if src_dir.exists() {
            search_roots.push(src_dir);
        }
    }
    search_roots.extend(venv_search_roots());

    let parts: Vec<&str> = module.split('.').collect();
    for root in &search_roots {
        if let Some(found) = module_file_under(root, &parts) {
            return Some(found);
        }
        let root_is_first_part =
            parts.len() > 1 && root.file_name() == Some(std::ffi::OsStr::new(parts[0]));
        if root_is_first_part {
            if let Some(found) = module_file_under(root, &parts[1..]) {
                return Some(found);
            }
        }
    }
    None
}

/// Try `<root>/<parts...>.py`, then `<root>/<parts...>/__init__.py`.
fn module_file_under(root: &Path, parts: &[&str]) -> Option<PathBuf> {
    let mut base = root.to_path_buf();
    for part in parts {
        base.push(part);
    }
    let file_candidate = base.with_extension("py");
    if file_candidate.exists() {
        return Some(file_candidate);
    }
    let init_candidate = base.join("__init__.py");
    init_candidate.exists().then_some(init_candidate)
}

/// Collect extra search roots from virtual environments referenced by `PATH`.
///
/// For each `PATH` entry mentioning `.venv`, the nearest ancestor directory whose
/// name contains `.venv` is taken as the venv. Two kinds of root are returned:
/// - the venv's parent directory (often the project root);
/// - every existing `lib/python*/site-packages` under the venv.
fn venv_search_roots() -> Vec<PathBuf> {
    let mut roots = Vec::new();
    // `var_os` + `split_paths` use the platform separator (`:` on Unix, `;` on
    // Windows), so drive letters like `C:\` are not split apart and a non-UTF-8
    // PATH is still scanned.
    let Some(path_var) = std::env::var_os("PATH") else {
        return roots;
    };
    for entry in std::env::split_paths(&path_var) {
        if !entry.to_string_lossy().contains(".venv") {
            continue;
        }
        let venv_dir = entry.ancestors().skip(1).find(|dir| {
            dir.file_name()
                .is_some_and(|name| name.to_string_lossy().contains(".venv"))
        });
        let Some(venv_dir) = venv_dir else { continue };
        if let Some(project_dir) = venv_dir.parent() {
            roots.push(project_dir.to_path_buf());
        }
        let Ok(lib_entries) = fs::read_dir(venv_dir.join("lib")) else {
            continue;
        };
        for lib_entry in lib_entries.flatten() {
            if lib_entry.file_name().to_string_lossy().starts_with("python") {
                let site_packages = lib_entry.path().join("site-packages");
                if site_packages.exists() {
                    roots.push(site_packages);
                }
            }
        }
    }
    roots
}
434
435pub fn analyze_source(source: &str) -> Analysis {
436    analyze_source_internal(source, None, None)
437}
438
439/// Analyze in-memory source but anchored at a file path so imports can resolve.
440pub fn analyze_source_at_path(source: &str, path: &Path) -> Analysis {
441    let mut cache = ModuleCache::new(path);
442    analyze_source_internal(source, Some(path), Some(&mut cache))
443}
444
445/// Analyze a python file with module resolution enabled.
446pub fn analyze_file(path: &Path) -> Analysis {
447    match fs::read_to_string(path) {
448        Ok(src) => {
449            let mut cache = ModuleCache::new(path);
450            analyze_source_internal(&src, Some(path), Some(&mut cache))
451        }
452        Err(err) => Analysis {
453            diagnostics: vec![Diagnostic {
454                range: default_range(),
455                severity: Some(DiagnosticSeverity::ERROR),
456                code: None,
457                code_description: None,
458                source: Some("shapels".into()),
459                message: format!("Failed to read file: {err}"),
460                related_information: None,
461                tags: None,
462                data: None,
463            }],
464            hover_entries: Vec::new(),
465        },
466    }
467}
468
469fn analyze_source_internal<'a>(
470    source: &'a str,
471    current_path: Option<&'a Path>,
472    mut module_cache: Option<&mut ModuleCache>,
473) -> Analysis {
474    let mut analysis = Analysis::default();
475    let normalized = normalize_return_annotations(source);
476    let source = normalized.as_ref();
477    let parse_name = current_path
478        .map(|p| p.to_string_lossy().to_string())
479        .unwrap_or_else(|| "<memory>".to_string());
480    match Suite::parse(source, &parse_name) {
481        Ok(module) => {
482            let imports = collect_imports(
483                &module,
484                current_path,
485                module_cache.as_deref().and_then(|c| c.project_root()),
486            );
487            // collect function definitions first
488            let mut func_map: FuncMap = HashMap::new();
489            collect_function_defs(&module, &mut func_map);
490            let class_map = collect_class_defs(&module, &imports);
491
492            // analyze top-level statements (outside functions)
493            let empty_args = Arguments {
494                range: ast::OptionalRange::from(TextRange::new(
495                    TextSize::from(0),
496                    TextSize::from(0),
497                )),
498                posonlyargs: Vec::new(),
499                args: Vec::new(),
500                vararg: None,
501                kwonlyargs: Vec::new(),
502                kwarg: None,
503            };
504            let (mut top_diags, mut top_hovers, _) = simulate_function(
505                &empty_args,
506                &module,
507                source,
508                &func_map,
509                &imports,
510                &class_map,
511                &mut Vec::new(),
512                HashMap::new(),
513                true,
514                module_cache.as_deref_mut(),
515                current_path,
516            );
517            analysis.diagnostics.append(&mut top_diags);
518            analysis.hover_entries.append(&mut top_hovers);
519
520            analyze_function_bodies(
521                &module,
522                source,
523                &func_map,
524                &imports,
525                &class_map,
526                &mut analysis,
527                module_cache,
528                current_path,
529            );
530        }
531        Err(err) => {
532            analysis.diagnostics.push(Diagnostic {
533                range: default_range(),
534                severity: Some(DiagnosticSeverity::ERROR),
535                code: None,
536                code_description: None,
537                source: Some("shapels".into()),
538                message: format!("Parse error: {err}"),
539                related_information: None,
540                tags: None,
541                data: None,
542            });
543        }
544    }
545    analysis
546}
547
548fn collect_function_defs(body: &[Stmt], func_map: &mut FuncMap) {
549    for stmt in body {
550        if let Stmt::FunctionDef(func) = stmt {
551            func_map.insert(
552                func.name.clone(),
553                FunctionInfo {
554                    args: func.args.clone(),
555                    body: func.body.clone(),
556                    returns: func.returns.clone(),
557                },
558            );
559        }
560    }
561}
562
/// Intermediate per-class data used by `collect_class_defs` before
/// `torch.nn.Module` inheritance has been propagated.
struct ClassDefInfo {
    /// The method named `forward`, if present.
    forward: Option<FunctionInfo>,
    /// All methods defined directly in the class body.
    methods: HashMap<Identifier, FunctionInfo>,
    /// Base classes written as bare names (e.g. `class A(B)`); other base
    /// expressions are not recorded here.
    base_names: Vec<Identifier>,
    /// True if some base expression is recognized as `torch.nn.Module` itself.
    direct_torch: bool,
}
569
570fn is_torch_nn_module_base(expr: &Expr, imports: &Imports) -> bool {
571    let Expr::Attribute(attr) = expr else {
572        return false;
573    };
574    if attr.attr.as_str() != "Module" {
575        return false;
576    }
577    match attr.value.as_ref() {
578        Expr::Attribute(nn_attr) if nn_attr.attr.as_str() == "nn" => {
579            if let Expr::Name(torch_name) = nn_attr.value.as_ref() {
580                return imports.torch_aliases.contains(&torch_name.id);
581            }
582            false
583        }
584        Expr::Name(nn_name) => imports
585            .module_aliases
586            .get(&nn_name.id)
587            .map(|module| module == "torch.nn")
588            .unwrap_or(false),
589        _ => false,
590    }
591}
592
593fn collect_class_defs(body: &[Stmt], imports: &Imports) -> ClassMap {
594    let mut defs: HashMap<Identifier, ClassDefInfo> = HashMap::new();
595    for stmt in body {
596        if let Stmt::ClassDef(class_def) = stmt {
597            let mut methods = HashMap::new();
598            let mut forward = None;
599            for stmt in class_def.body.iter() {
600                if let Stmt::FunctionDef(func) = stmt {
601                    let info = FunctionInfo {
602                        args: func.args.clone(),
603                        body: func.body.clone(),
604                        returns: func.returns.clone(),
605                    };
606                    if func.name.as_str() == "forward" {
607                        forward = Some(info.clone());
608                    }
609                    methods.insert(func.name.clone(), info);
610                }
611            }
612            let base_names = class_def
613                .bases
614                .iter()
615                .filter_map(|base| match base {
616                    Expr::Name(name) => Some(name.id.clone()),
617                    _ => None,
618                })
619                .collect();
620            let direct_torch = class_def
621                .bases
622                .iter()
623                .any(|base| is_torch_nn_module_base(base, imports));
624            defs.insert(
625                class_def.name.clone(),
626                ClassDefInfo {
627                    forward,
628                    methods,
629                    base_names,
630                    direct_torch,
631                },
632            );
633        }
634    }
635
636    let mut is_torch: HashMap<Identifier, bool> = defs
637        .iter()
638        .map(|(name, info)| (name.clone(), info.direct_torch))
639        .collect();
640    let mut changed = true;
641    while changed {
642        changed = false;
643        for (name, info) in defs.iter() {
644            if !is_torch.get(name).copied().unwrap_or(false)
645                && info
646                    .base_names
647                    .iter()
648                    .any(|base| is_torch.get(base).copied().unwrap_or(false))
649            {
650                is_torch.insert(name.clone(), true);
651                changed = true;
652            }
653        }
654    }
655
656    let mut class_map = HashMap::new();
657    for (name, info) in defs {
658        class_map.insert(
659            name.clone(),
660            ClassInfo {
661                is_torch_module: is_torch.get(&name).copied().unwrap_or(false),
662                forward: info.forward,
663                methods: info.methods,
664            },
665        );
666    }
667    class_map
668}
669
670/// Iterate over the function and classes of a module `body`.
671///
672/// The base case is a function, where shape inference is run. For classes,
673/// it's called recursively, treating the class as a module where each method is a function.
674fn analyze_function_bodies(
675    body: &[Stmt],
676    source: &str,
677    func_map: &FuncMap,
678    imports: &Imports,
679    class_map: &ClassMap,
680    analysis: &mut Analysis,
681    mut module_cache: Option<&mut ModuleCache>,
682    module_path: Option<&Path>,
683) {
684    for stmt in body {
685        match stmt {
686            Stmt::FunctionDef(func) => {
687                let mut func_analysis = analyze_function(
688                    &func.args,
689                    &func.body,
690                    source,
691                    func_map,
692                    imports,
693                    class_map,
694                    &mut Vec::new(),
695                    module_cache.as_deref_mut(),
696                    module_path,
697                );
698                analysis.diagnostics.append(&mut func_analysis.diagnostics);
699                analysis
700                    .hover_entries
701                    .append(&mut func_analysis.hover_entries);
702            }
703            Stmt::ClassDef(class_def) => analyze_function_bodies(
704                &class_def.body,
705                source,
706                func_map,
707                imports,
708                class_map,
709                analysis,
710                module_cache.as_deref_mut(),
711                module_path,
712            ),
713            _ => {}
714        }
715    }
716}
717
718fn analyze_function(
719    args: &Arguments,
720    body: &[Stmt],
721    source: &str,
722    func_map: &FuncMap,
723    imports: &Imports,
724    class_map: &ClassMap,
725    call_stack: &mut Vec<Identifier>,
726    module_cache: Option<&mut ModuleCache>,
727    module_path: Option<&Path>,
728) -> Analysis {
729    let (diagnostics, hover_entries, _) = simulate_function(
730        args,
731        body,
732        source,
733        func_map,
734        imports,
735        class_map,
736        call_stack,
737        HashMap::new(),
738        true,
739        module_cache,
740        module_path,
741    );
742
743    Analysis {
744        diagnostics,
745        hover_entries,
746    }
747}
748
749/// Initializes inputs of a function and wraps around [`simulate_block`]
750/// that may run recursively carrying the initialized inputs.
751fn simulate_function(
752    args: &Arguments,
753    body: &[Stmt],
754    source: &str,
755    func_map: &FuncMap,
756    imports: &Imports,
757    class_map: &ClassMap,
758    call_stack: &mut Vec<Identifier>,
759    mut initial_vars: HashMap<Identifier, VarState>,
760    record_hovers: bool,
761    module_cache: Option<&mut ModuleCache>,
762    module_path: Option<&Path>,
763) -> (Vec<Diagnostic>, Vec<(Range, HoverInfo)>, ReturnValue) {
764    let mut diagnostics = Vec::new();
765    let mut hover_entries = Vec::new();
766    let mut vars: HashMap<Identifier, VarState> = HashMap::new();
767
768    seed_args_from_annotations(
769        args,
770        source,
771        &mut vars,
772        &mut hover_entries,
773        Some(&mut initial_vars),
774        record_hovers,
775        imports,
776        class_map,
777    );
778
779    let mut return_value = ReturnValue::default();
780
781    let _ = simulate_block(
782        body,
783        &mut vars,
784        &mut diagnostics,
785        &mut hover_entries,
786        &mut return_value,
787        source,
788        func_map,
789        imports,
790        class_map,
791        call_stack,
792        record_hovers,
793        module_cache,
794        module_path,
795        false,
796    );
797
798    (diagnostics, hover_entries, return_value)
799}
800
/// Control-flow outcome returned by [`simulate_block`] for the statements it simulated.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum BlockFlow {
    /// Presumably: the block ran to its end without a jump.
    None,
    /// Presumably corresponds to a Python `break` (see `in_loop`) — confirm.
    Break,
    /// Presumably corresponds to a Python `continue` — confirm.
    Continue,
    /// Presumably corresponds to a `return` statement — confirm.
    Return,
}
808
809/// Run static shape inference on assignments and return types.
810fn simulate_block(
811    body: &[Stmt],
812    vars: &mut HashMap<Identifier, VarState>,
813    diagnostics: &mut Vec<Diagnostic>,
814    hover_entries: &mut Vec<(Range, HoverInfo)>,
815    return_value: &mut ReturnValue,
816    source: &str,
817    func_map: &FuncMap,
818    imports: &Imports,
819    class_map: &ClassMap,
820    call_stack: &mut Vec<Identifier>,
821    record_hovers: bool,
822    mut module_cache: Option<&mut ModuleCache>,
823    module_path: Option<&Path>,
824    in_loop: bool,
825) -> BlockFlow {
826    for stmt in body {
827        match stmt {
828            Stmt::AnnAssign(assign) => {
829                if assignment_shape_checks(
830                    &assign.target,
831                    assign.value.as_deref().unwrap_or(assign.target.as_ref()),
832                    vars,
833                    diagnostics,
834                    hover_entries,
835                    record_hovers,
836                    source,
837                ) {
838                    continue;
839                }
840                if let Some(name) = name_from_expr(&assign.target) {
841                    let ann_shape = parse_shape_annotation(&assign.annotation);
842                    let range = text_range_to_lsp(expr_text_range(&assign.target), source);
843                    let mut inferred = None;
844                    if let Some(val) = &assign.value {
845                        if assignment_shape_checks(
846                            val,
847                            val,
848                            vars,
849                            diagnostics,
850                            hover_entries,
851                            record_hovers,
852                            source,
853                        ) {
854                            inferred = vars
855                                .get(&name)
856                                .and_then(|v| v.annotated.clone().or(v.inferred.clone()));
857                        } else {
858                            inferred = infer_expr_shape(
859                                val,
860                                vars,
861                                func_map,
862                                imports,
863                                class_map,
864                                call_stack,
865                                diagnostics,
866                                hover_entries,
867                                record_hovers,
868                                source,
869                                module_cache.as_deref_mut(),
870                                module_path,
871                            );
872                        }
873                    }
874                    if let (Some(ann), Some(inf)) = (ann_shape.clone(), inferred.clone())
875                        && !shape_dims_equal(&ann, &inf)
876                    {
877                        // If annotation is a shape-unroll, treat it as a rename rather than mismatch.
878                        if let Expr::Subscript(_sub) = &*assign.annotation
879                            && ann.dims.len() == inf.dims.len()
880                            && matches!(
881                                assign.value.as_deref(),
882                                Some(Expr::Name(_)) | Some(Expr::Attribute(_))
883                            )
884                        {
885                            let mut renamed = inf.clone();
886                            renamed.dims = ann.dims.clone();
887                            vars.insert(
888                                name.clone(),
889                                VarState {
890                                    annotated: ann_shape.clone(),
891                                    inferred: Some(renamed.clone()),
892                                    class_ref: None,
893                                },
894                            );
895                            if record_hovers {
896                                hover_entries.push((
897                                    range,
898                                    HoverInfo {
899                                        shape: Some(renamed),
900                                    },
901                                ));
902                            }
903                            continue;
904                        }
905                        diagnostics.push(Diagnostic {
906                            range,
907                            severity: Some(DiagnosticSeverity::ERROR),
908                            code: None,
909                            code_description: None,
910                            source: Some("shapels".into()),
911                            message: format!(
912                                "Shape mismatch: annotation {} vs inferred {}",
913                                ann.render(),
914                                inf.render()
915                            ),
916                            related_information: None,
917                            tags: None,
918                            data: None,
919                        });
920                    }
921                    let chosen_shape = ann_shape.clone().or(inferred.clone());
922                    if let Some(shape) = chosen_shape {
923                        vars.insert(
924                            name.clone(),
925                            VarState {
926                                annotated: ann_shape.clone(),
927                                inferred,
928                                class_ref: None,
929                            },
930                        );
931                        if record_hovers {
932                            hover_entries.push((range, HoverInfo { shape: Some(shape) }));
933                        }
934                    } else if let Some(val) = &assign.value
935                        && let Some(class_ref) = class_ref_from_constructor_call(
936                            val,
937                            class_map,
938                            imports,
939                            module_cache.as_deref_mut(),
940                            module_path,
941                        )
942                    {
943                        vars.insert(
944                            name.clone(),
945                            VarState {
946                                annotated: ann_shape.clone(),
947                                inferred: None,
948                                class_ref: Some(class_ref),
949                            },
950                        );
951                    }
952                }
953            }
954            Stmt::Assign(assign) => {
955                // handle tuple destructuring of `.shape`
956                if assign.targets.len() == 1
957                    && assignment_shape_checks(
958                        &assign.targets[0],
959                        &assign.value,
960                        vars,
961                        diagnostics,
962                        hover_entries,
963                        record_hovers,
964                        source,
965                    )
966                {
967                    continue;
968                }
969                if assign.targets.len() == 1
970                    && matches!(assign.targets[0], Expr::Tuple(_))
971                    && let Expr::Tuple(target_tuple) = &assign.targets[0]
972                    && let Some(tuple_shapes) = infer_tuple_elements(
973                        &assign.value,
974                        vars,
975                        func_map,
976                        imports,
977                        class_map,
978                        call_stack,
979                        diagnostics,
980                        hover_entries,
981                        record_hovers,
982                        source,
983                        module_cache.as_deref_mut(),
984                        module_path,
985                    )
986                {
987                    for (target_expr, shape_opt) in
988                        target_tuple.elts.iter().zip(tuple_shapes.into_iter())
989                    {
990                        if let (Some(name), Some(shape)) = (name_from_expr(target_expr), shape_opt)
991                        {
992                            vars.insert(
993                                name.clone(),
994                                VarState {
995                                    annotated: None,
996                                    inferred: Some(shape.clone()),
997                                    class_ref: None,
998                                },
999                            );
1000                            if record_hovers {
1001                                let range = text_range_to_lsp(expr_text_range(target_expr), source);
1002                                hover_entries.push((range, HoverInfo { shape: Some(shape) }));
1003                            }
1004                        }
1005                    }
1006                    continue;
1007                }
1008                if assign.targets.len() == 1
1009                    && let Some(name) = name_from_expr(&assign.targets[0])
1010                {
1011                    let range = text_range_to_lsp(expr_text_range(&assign.targets[0]), source);
1012                    let shape = infer_expr_shape(
1013                        &assign.value,
1014                        vars,
1015                        func_map,
1016                        imports,
1017                        class_map,
1018                        call_stack,
1019                        diagnostics,
1020                        hover_entries,
1021                        record_hovers,
1022                        source,
1023                        module_cache.as_deref_mut(),
1024                        module_path,
1025                    );
1026                    if let Some(shape) = shape {
1027                        vars.insert(
1028                            name.clone(),
1029                            VarState {
1030                                annotated: None,
1031                                inferred: Some(shape.clone()),
1032                                class_ref: None,
1033                            },
1034                        );
1035                        if record_hovers {
1036                            hover_entries.push((range, HoverInfo { shape: Some(shape) }));
1037                        }
1038                    } else if let Some(class_ref) = class_ref_from_constructor_call(
1039                        &assign.value,
1040                        class_map,
1041                        imports,
1042                        module_cache.as_deref_mut(),
1043                        module_path,
1044                    ) {
1045                        vars.insert(
1046                            name.clone(),
1047                            VarState {
1048                                annotated: None,
1049                                inferred: None,
1050                                class_ref: Some(class_ref),
1051                            },
1052                        );
1053                    }
1054                }
1055            }
1056            Stmt::For(for_stmt) => {
1057                let flow = simulate_block(
1058                    &for_stmt.body,
1059                    vars,
1060                    diagnostics,
1061                    hover_entries,
1062                    return_value,
1063                    source,
1064                    func_map,
1065                    imports,
1066                    class_map,
1067                    call_stack,
1068                    record_hovers,
1069                    module_cache.as_deref_mut(),
1070                    module_path,
1071                    true,
1072                );
1073                if flow == BlockFlow::Return {
1074                    return BlockFlow::Return;
1075                }
1076                if flow != BlockFlow::Break {
1077                    let else_flow = simulate_block(
1078                        &for_stmt.orelse,
1079                        vars,
1080                        diagnostics,
1081                        hover_entries,
1082                        return_value,
1083                        source,
1084                        func_map,
1085                        imports,
1086                        class_map,
1087                        call_stack,
1088                        record_hovers,
1089                        module_cache.as_deref_mut(),
1090                        module_path,
1091                        true,
1092                    );
1093                    if else_flow == BlockFlow::Return {
1094                        return BlockFlow::Return;
1095                    }
1096                }
1097            }
1098            Stmt::While(while_stmt) => {
1099                let flow = simulate_block(
1100                    &while_stmt.body,
1101                    vars,
1102                    diagnostics,
1103                    hover_entries,
1104                    return_value,
1105                    source,
1106                    func_map,
1107                    imports,
1108                    class_map,
1109                    call_stack,
1110                    record_hovers,
1111                    module_cache.as_deref_mut(),
1112                    module_path,
1113                    true,
1114                );
1115                if flow == BlockFlow::Return {
1116                    return BlockFlow::Return;
1117                }
1118                if flow != BlockFlow::Break {
1119                    let else_flow = simulate_block(
1120                        &while_stmt.orelse,
1121                        vars,
1122                        diagnostics,
1123                        hover_entries,
1124                        return_value,
1125                        source,
1126                        func_map,
1127                        imports,
1128                        class_map,
1129                        call_stack,
1130                        record_hovers,
1131                        module_cache.as_deref_mut(),
1132                        module_path,
1133                        true,
1134                    );
1135                    if else_flow == BlockFlow::Return {
1136                        return BlockFlow::Return;
1137                    }
1138                }
1139            }
1140            Stmt::If(if_stmt) => {
1141                let mut body_vars = vars.clone();
1142                let body_flow = simulate_block(
1143                    &if_stmt.body,
1144                    &mut body_vars,
1145                    diagnostics,
1146                    hover_entries,
1147                    return_value,
1148                    source,
1149                    func_map,
1150                    imports,
1151                    class_map,
1152                    call_stack,
1153                    record_hovers,
1154                    module_cache.as_deref_mut(),
1155                    module_path,
1156                    in_loop,
1157                );
1158                let mut else_vars = vars.clone();
1159                let else_flow = simulate_block(
1160                    &if_stmt.orelse,
1161                    &mut else_vars,
1162                    diagnostics,
1163                    hover_entries,
1164                    return_value,
1165                    source,
1166                    func_map,
1167                    imports,
1168                    class_map,
1169                    call_stack,
1170                    record_hovers,
1171                    module_cache.as_deref_mut(),
1172                    module_path,
1173                    in_loop,
1174                );
1175                if body_flow == else_flow
1176                    && matches!(
1177                        body_flow,
1178                        BlockFlow::Break | BlockFlow::Continue | BlockFlow::Return
1179                    )
1180                {
1181                    return body_flow;
1182                }
1183            }
1184            Stmt::Break(_) => {
1185                if in_loop {
1186                    return BlockFlow::Break;
1187                }
1188            }
1189            Stmt::Continue(_) => {
1190                if in_loop {
1191                    return BlockFlow::Continue;
1192                }
1193            }
1194            Stmt::Return(ret) => {
1195                if let Some(val) = &ret.value {
1196                    let new_value = match val.as_ref() {
1197                        Expr::Tuple(tuple) => {
1198                            let tuple_shapes: Vec<Option<Shape>> = tuple
1199                                .elts
1200                                .iter()
1201                                .map(|elt| {
1202                                    infer_expr_shape(
1203                                        elt,
1204                                        vars,
1205                                        func_map,
1206                                        imports,
1207                                        class_map,
1208                                        call_stack,
1209                                        diagnostics,
1210                                        hover_entries,
1211                                        record_hovers,
1212                                        source,
1213                                        module_cache.as_deref_mut(),
1214                                        module_path,
1215                                    )
1216                                })
1217                                .collect();
1218                            ReturnValue::from_tuple(tuple_shapes)
1219                        }
1220                        _ => ReturnValue::from_shape(infer_expr_shape(
1221                            val,
1222                            vars,
1223                            func_map,
1224                            imports,
1225                            class_map,
1226                            call_stack,
1227                            diagnostics,
1228                            hover_entries,
1229                            record_hovers,
1230                            source,
1231                            module_cache.as_deref_mut(),
1232                            module_path,
1233                        )),
1234                    };
1235                    if record_hovers && let Some(shape) = new_value.first() {
1236                        let range = text_range_to_lsp(expr_text_range(val), source);
1237                        hover_entries.push((
1238                            range,
1239                            HoverInfo {
1240                                shape: Some(shape.clone()),
1241                            },
1242                        ));
1243                    }
1244                    if new_value.is_some() {
1245                        *return_value = new_value;
1246                    }
1247                }
1248                return BlockFlow::Return;
1249            }
1250            _ => {}
1251        }
1252    }
1253    BlockFlow::None
1254}
1255
1256/// Recursively run static shape inference on an [`Expr`].
1257///
1258/// This is the central function for inference, it calls the specialized
1259/// inference at src/infer.rs depending on type of the expression.
1260fn infer_expr_shape(
1261    expr: &Expr,
1262    vars: &HashMap<Identifier, VarState>,
1263    func_map: &FuncMap,
1264    imports: &Imports,
1265    class_map: &ClassMap,
1266    call_stack: &mut Vec<Identifier>,
1267    diagnostics: &mut Vec<Diagnostic>,
1268    hover_entries: &mut Vec<(Range, HoverInfo)>,
1269    record_hovers: bool,
1270    source: &str,
1271    mut module_cache: Option<&mut ModuleCache>,
1272    module_path: Option<&Path>,
1273) -> Option<Shape> {
1274    match expr {
1275        Expr::BinOp(ExprBinOp {
1276            left,
1277            op,
1278            right,
1279            range: expr_range,
1280        }) => match op {
1281            Operator::Mult | Operator::Add | Operator::Sub | Operator::Div => {
1282                infer_broadcastable_poswise(
1283                    &ShapeOrExpr::Expr(left),
1284                    right,
1285                    vars,
1286                    func_map,
1287                    imports,
1288                    class_map,
1289                    call_stack,
1290                    diagnostics,
1291                    hover_entries,
1292                    record_hovers,
1293                    source,
1294                    *expr_range,
1295                    module_cache.as_deref_mut(),
1296                    module_path,
1297                    BroadcastOp::Arithmetic,
1298                )
1299            }
1300            Operator::BitAnd
1301            | Operator::BitXor
1302            | Operator::BitOr
1303            | Operator::LShift
1304            | Operator::RShift => infer_broadcastable_poswise(
1305                &ShapeOrExpr::Expr(left),
1306                right,
1307                vars,
1308                func_map,
1309                imports,
1310                class_map,
1311                call_stack,
1312                diagnostics,
1313                hover_entries,
1314                record_hovers,
1315                source,
1316                *expr_range,
1317                module_cache.as_deref_mut(),
1318                module_path,
1319                BroadcastOp::Bitwise,
1320            ),
1321            Operator::MatMult => infer_matmul_shapes(
1322                left,
1323                right,
1324                vars,
1325                func_map,
1326                imports,
1327                class_map,
1328                call_stack,
1329                diagnostics,
1330                hover_entries,
1331                record_hovers,
1332                source,
1333                *expr_range,
1334                module_cache.as_deref_mut(),
1335                module_path,
1336            ),
1337            _ => None,
1338        },
1339        Expr::Compare(ExprCompare {
1340            left: init_left,
1341            // can be chained, so we need to compute the pairs left to right
1342            comparators,
1343            range: expr_range,
1344            ..
1345        }) => {
1346            // first, infer with two Expr
1347            let mut iter = comparators.iter();
1348            let first_right = iter.next()?;
1349            let init = infer_broadcastable_poswise(
1350                &ShapeOrExpr::Expr(init_left.as_ref()),
1351                first_right,
1352                vars,
1353                func_map,
1354                imports,
1355                class_map,
1356                call_stack,
1357                diagnostics,
1358                hover_entries,
1359                record_hovers,
1360                source,
1361                *expr_range,
1362                module_cache.as_deref_mut(),
1363                module_path,
1364                BroadcastOp::Eq,
1365            );
1366
1367            // second, fold with Shape and Expr
1368            iter.fold(init, |left, right| {
1369                infer_broadcastable_poswise(
1370                    &ShapeOrExpr::Shape(left.as_ref()),
1371                    right,
1372                    vars,
1373                    func_map,
1374                    imports,
1375                    class_map,
1376                    call_stack,
1377                    diagnostics,
1378                    hover_entries,
1379                    record_hovers,
1380                    source,
1381                    *expr_range,
1382                    module_cache.as_deref_mut(),
1383                    module_path,
1384                    BroadcastOp::Eq,
1385                )
1386            })
1387        }
1388        Expr::Call(call) => {
1389            if let Expr::Name(func_name) = call.func.as_ref() {
1390                if let Some(class_ref) = vars.get(&func_name.id).and_then(|v| v.class_ref.as_ref())
1391                    && let Some(ret) = infer_class_call_return(
1392                        call,
1393                        class_ref,
1394                        vars,
1395                        func_map,
1396                        imports,
1397                        class_map,
1398                        call_stack,
1399                        diagnostics,
1400                        hover_entries,
1401                        record_hovers,
1402                        source,
1403                        module_cache.as_deref_mut(),
1404                        module_path,
1405                    )
1406                {
1407                    return ret.first().cloned();
1408                }
1409                let torchop_shape = torch_op_to_shape(
1410                    vars,
1411                    func_map,
1412                    imports,
1413                    class_map,
1414                    call_stack,
1415                    diagnostics,
1416                    hover_entries,
1417                    record_hovers,
1418                    source,
1419                    &mut module_cache,
1420                    module_path,
1421                    call,
1422                    func_name.id.as_str(),
1423                    TorchOp::as_call(&func_name.id, imports),
1424                    TorchOpKind::Function,
1425                    None,
1426                );
1427                if torchop_shape.is_some() {
1428                    return torchop_shape;
1429                }
1430
1431                if let Some((module_name, original)) = imports.from_imports.get(&func_name.id)
1432                    && let (Some(cache), Some(cur_path)) =
1433                        (module_cache.as_deref_mut(), module_path)
1434                    && let Some(module) = cache.get_module(module_name, cur_path)
1435                    && let Some(callee_info) = module.func_map.get(original)
1436                    && let Some(ret) = infer_call_return_from_info(
1437                        call,
1438                        &func_name.id,
1439                        callee_info,
1440                        &module.source,
1441                        &module.func_map,
1442                        &module.imports,
1443                        &module.class_map,
1444                        vars,
1445                        func_map,
1446                        imports,
1447                        class_map,
1448                        call_stack,
1449                        diagnostics,
1450                        hover_entries,
1451                        record_hovers,
1452                        source,
1453                        module_cache.as_deref_mut(),
1454                        Some(module.path.as_path()),
1455                        0,
1456                        false,
1457                    )
1458                {
1459                    return ret.first().cloned();
1460                }
1461                if let Some(callee_info) = func_map.get(&func_name.id)
1462                    && let Some(ret) = infer_call_return_from_info(
1463                        call,
1464                        &func_name.id,
1465                        callee_info,
1466                        source,
1467                        func_map,
1468                        imports,
1469                        class_map,
1470                        vars,
1471                        func_map,
1472                        imports,
1473                        class_map,
1474                        call_stack,
1475                        diagnostics,
1476                        hover_entries,
1477                        record_hovers,
1478                        source,
1479                        module_cache.as_deref_mut(),
1480                        module_path,
1481                        0,
1482                        false,
1483                    )
1484                {
1485                    return ret.first().cloned();
1486                }
1487            }
1488            // methods are functions with attributes
1489            if let Expr::Attribute(attr) = call.func.as_ref() {
1490                let attr_name: &str = attr.attr.as_ref();
1491                if let Expr::Name(base_name) = attr.value.as_ref()
1492                    && let Some(class_ref) =
1493                        vars.get(&base_name.id).and_then(|v| v.class_ref.as_ref())
1494                    && let Some(ret) = infer_class_method_call_return(
1495                        call,
1496                        class_ref,
1497                        &attr.attr,
1498                        vars,
1499                        func_map,
1500                        imports,
1501                        class_map,
1502                        call_stack,
1503                        diagnostics,
1504                        hover_entries,
1505                        record_hovers,
1506                        source,
1507                        module_cache.as_deref_mut(),
1508                        module_path,
1509                    )
1510                {
1511                    return ret.first().cloned();
1512                }
1513                if let Expr::Name(module_ident) = attr.value.as_ref()
1514                    && let Some(module_name) = imports.module_aliases.get(&module_ident.id)
1515                    && module_name != "torch"
1516                    && let (Some(cache), Some(cur_path)) =
1517                        (module_cache.as_deref_mut(), module_path)
1518                    && let Some(module) = cache.get_module(module_name, cur_path)
1519                    && let Some(callee_info) = module.func_map.get(&attr.attr)
1520                    && let Some(ret) = infer_call_return_from_info(
1521                        call,
1522                        &attr.attr,
1523                        callee_info,
1524                        &module.source,
1525                        &module.func_map,
1526                        &module.imports,
1527                        &module.class_map,
1528                        vars,
1529                        func_map,
1530                        imports,
1531                        class_map,
1532                        call_stack,
1533                        diagnostics,
1534                        hover_entries,
1535                        record_hovers,
1536                        source,
1537                        module_cache.as_deref_mut(),
1538                        Some(module.path.as_path()),
1539                        0,
1540                        false,
1541                    )
1542                {
1543                    return ret.first().cloned();
1544                }
1545                // two cases: torch.ATTR_NAME(torch.Tensor, ...) or torch.Tensor.ATTR_NAME(...)
1546                let op_kind = function_or_method(&attr.value, imports);
1547                let aliased_shape = torch_op_to_shape(
1548                    vars,
1549                    func_map,
1550                    imports,
1551                    class_map,
1552                    call_stack,
1553                    diagnostics,
1554                    hover_entries,
1555                    record_hovers,
1556                    source,
1557                    &mut module_cache,
1558                    module_path,
1559                    call,
1560                    attr_name,
1561                    TorchOp::from_attr(attr_name),
1562                    op_kind,
1563                    Some(attr),
1564                );
1565                if aliased_shape.is_some() {
1566                    return aliased_shape;
1567                }
1568            }
1569            None
1570        }
1571        Expr::Subscript(sub) => infer_index(
1572            &sub.value,
1573            &sub.slice,
1574            vars,
1575            func_map,
1576            imports,
1577            class_map,
1578            call_stack,
1579            diagnostics,
1580            hover_entries,
1581            record_hovers,
1582            source,
1583            module_cache.as_deref_mut(),
1584            module_path,
1585        ),
1586        // attributes of a tensor, not a method!
1587        Expr::Attribute(attr) => {
1588            let attr_name: &str = attr.attr.as_ref();
1589            if attr_name == "T" {
1590                let base_hint =
1591                    lookup_shape(&attr.value, vars, hover_entries, record_hovers, source).or_else(
1592                        || {
1593                            infer_expr_shape(
1594                                &attr.value,
1595                                vars,
1596                                func_map,
1597                                imports,
1598                                class_map,
1599                                call_stack,
1600                                diagnostics,
1601                                hover_entries,
1602                                false,
1603                                source,
1604                                module_cache.as_deref_mut(),
1605                                module_path,
1606                            )
1607                        },
1608                    );
1609                let order_args: Vec<&Expr> = Vec::new();
1610                let res = infer_permute(
1611                    &attr.value,
1612                    &order_args,
1613                    Transpose::T,
1614                    base_hint,
1615                    vars,
1616                    diagnostics,
1617                    hover_entries,
1618                    record_hovers,
1619                    source,
1620                    attr.range,
1621                );
1622                if record_hovers && let Some(s) = res.clone() {
1623                    let range = text_range_to_lsp(attr.range, source);
1624                    hover_entries.push((range, HoverInfo { shape: Some(s) }));
1625                }
1626                return res;
1627            } else if attr_name == "shape" || attr_name == "dtype" {
1628                return lookup_shape(&attr.value, vars, hover_entries, record_hovers, source)
1629                    .or_else(|| {
1630                        infer_expr_shape(
1631                            &attr.value,
1632                            vars,
1633                            func_map,
1634                            imports,
1635                            class_map,
1636                            call_stack,
1637                            diagnostics,
1638                            hover_entries,
1639                            false,
1640                            source,
1641                            module_cache,
1642                            module_path,
1643                        )
1644                    });
1645            }
1646            None
1647        }
1648        Expr::Name(expr_name) => {
1649            let shape = vars
1650                .get(&expr_name.id)
1651                .and_then(|v| v.annotated.clone().or_else(|| v.inferred.clone()));
1652            if record_hovers && let Some(s) = shape.clone() {
1653                let range = text_range_to_lsp(expr_text_range(expr), source);
1654                hover_entries.push((range, HoverInfo { shape: Some(s) }));
1655            }
1656            shape
1657        }
1658        _ => None,
1659    }
1660}
1661
/// Match a `TorchOp` against all supported torch operations.
///
/// The op may be called as a function or as a method. Returns `None` if the
/// operation is not implemented OR if the arguments to the operation are
/// incorrect (e.g. a required argument is missing).
1667fn torch_op_to_shape(
1668    vars: &HashMap<Identifier, VarState>,
1669    func_map: &HashMap<Identifier, FunctionInfo>,
1670    imports: &Imports,
1671    class_map: &HashMap<Identifier, ClassInfo>,
1672    call_stack: &mut Vec<Identifier>,
1673    diagnostics: &mut Vec<Diagnostic>,
1674    hover_entries: &mut Vec<(Range, HoverInfo)>,
1675    record_hovers: bool,
1676    source: &str,
1677    module_cache: &mut Option<&mut ModuleCache>,
1678    module_path: Option<&Path>,
1679    call: &ExprCall,
1680    attr_name: &str,
1681    torch_op: TorchOp,
1682    op_kind: TorchOpKind,
1683    maybe_attr: Option<&ast::ExprAttribute>,
1684) -> Option<Shape> {
1685    use TorchOpKind::*;
1686
1687    let (may_arg0, may_arg1, offset) = match op_kind {
1688        Function => (call.args.first(), call.args.get(1), 1),
1689        Method => (maybe_attr.map(|x| x.value.as_ref()), call.args.first(), 0),
1690    };
1691    match (torch_op, may_arg0, may_arg1, &op_kind) {
1692        (TorchOp::MatMul, Some(arg0), Some(arg1), _) => infer_matmul_shapes(
1693            arg0,
1694            arg1,
1695            vars,
1696            func_map,
1697            imports,
1698            class_map,
1699            call_stack,
1700            diagnostics,
1701            hover_entries,
1702            record_hovers,
1703            source,
1704            call.range,
1705            module_cache.as_deref_mut(),
1706            module_path,
1707        ),
1708        (op @ (TorchOp::Squeeze | TorchOp::Aggr), Some(arg0), _, _) => infer_squeeze(
1709            arg0,
1710            get_arg(call, "dim", offset),
1711            vars,
1712            func_map,
1713            imports,
1714            class_map,
1715            call_stack,
1716            diagnostics,
1717            hover_entries,
1718            record_hovers,
1719            source,
1720            call.range,
1721            module_cache.as_deref_mut(),
1722            module_path,
1723            matches!(op, TorchOp::Squeeze),
1724            get_arg(call, "keepdim", offset + 1),
1725        ),
1726        (TorchOp::NoopDim, Some(base), _, _) => {
1727            let base_hint = infer_expr_shape(
1728                base,
1729                vars,
1730                func_map,
1731                imports,
1732                class_map,
1733                call_stack,
1734                diagnostics,
1735                hover_entries,
1736                false,
1737                source,
1738                module_cache.as_deref_mut(),
1739                module_path,
1740            );
1741            infer_noop(
1742                base_hint,
1743                get_arg(call, "dim", offset),
1744                vars,
1745                diagnostics,
1746                source,
1747                call.range,
1748                attr_name.contains("soft"), // HACK: dim not optional for softmax
1749            )
1750        }
1751        (TorchOp::Noop, Some(base), _, _) => infer_expr_shape(
1752            base,
1753            vars,
1754            func_map,
1755            imports,
1756            class_map,
1757            call_stack,
1758            diagnostics,
1759            hover_entries,
1760            false,
1761            source,
1762            module_cache.as_deref_mut(),
1763            module_path,
1764        ),
1765        (TorchOp::Creation { is_size }, _, _, Function) => {
1766            let shape_assign = tensor_or_shape_as_arg(
1767                is_size,
1768                vars,
1769                func_map,
1770                imports,
1771                class_map,
1772                call_stack,
1773                diagnostics,
1774                hover_entries,
1775                record_hovers,
1776                source,
1777                module_cache,
1778                module_path,
1779                call,
1780            );
1781            let dtype = get_arg(call, "dtype", 200).and_then(|expr| {
1782                let attr_dtype = if matches!(expr, Expr::Attribute(_)) {
1783                    infer_expr_shape(
1784                        expr,
1785                        vars,
1786                        func_map,
1787                        imports,
1788                        class_map,
1789                        call_stack,
1790                        diagnostics,
1791                        hover_entries,
1792                        record_hovers,
1793                        source,
1794                        module_cache.as_deref_mut(),
1795                        module_path,
1796                    )
1797                    .and_then(|shape| shape.dtype)
1798                } else {
1799                    None
1800                };
1801                attr_dtype.or_else(|| get_dtype(expr, imports).map(|x| x.to_string()))
1802            });
1803
1804            infer_creation_size(
1805                call,
1806                vars,
1807                diagnostics,
1808                source,
1809                shape_assign,
1810                dtype,
1811                is_size,
1812            )
1813        }
1814        (TorchOp::RangeOp(range_op), _, _, Function) => infer_range_size(
1815            call,
1816            range_op,
1817            vars,
1818            func_map,
1819            imports,
1820            class_map,
1821            call_stack,
1822            diagnostics,
1823            hover_entries,
1824            record_hovers,
1825            source,
1826            module_cache.as_deref_mut(),
1827            module_path,
1828        ),
1829        (TorchOp::NoArg { predef_dtype }, Some(base), _, _)
1830            // only `.to` (no predef type) is allowed as Function
1831            if predef_dtype.is_none() || matches!(op_kind, Method) =>
1832        {
1833            let base_expr = infer_expr_shape(
1834                base,
1835                vars,
1836                func_map,
1837                imports,
1838                class_map,
1839                call_stack,
1840                diagnostics,
1841                hover_entries,
1842                false,
1843                source,
1844                module_cache.as_deref_mut(),
1845                module_path,
1846            );
1847            let dtype_expr = if let Some(dtype) = predef_dtype {
1848                Some(&Expr::Constant(ast::ExprConstant {
1849                    range: expr_text_range(base),
1850                    value: Constant::Str(dtype.to_string()),
1851                    kind: None,
1852                }))
1853            } else {
1854                get_arg(call, "dtype", 0)
1855            };
1856            infer_to(base_expr, dtype_expr, diagnostics, source, imports)
1857        }
1858        (TorchOp::Broadcastable(broadcast_op), Some(left), Some(right), _) => {
1859            infer_broadcastable_poswise(
1860                &ShapeOrExpr::Expr(left),
1861                right,
1862                vars,
1863                func_map,
1864                imports,
1865                class_map,
1866                call_stack,
1867                diagnostics,
1868                hover_entries,
1869                record_hovers,
1870                source,
1871                call.range,
1872                module_cache.as_deref_mut(),
1873                module_path,
1874                broadcast_op,
1875            )
1876        }
1877        (op @ (TorchOp::View | TorchOp::Expand), Some(base), _, kind)
1878            if matches!(kind, Method) || matches!(op, TorchOp::View) =>
1879        {
1880            let base_hint =
1881                lookup_shape(base, vars, hover_entries, record_hovers, source).or_else(|| {
1882                    infer_expr_shape(
1883                        base,
1884                        vars,
1885                        func_map,
1886                        imports,
1887                        class_map,
1888                        call_stack,
1889                        diagnostics,
1890                        hover_entries,
1891                        false,
1892                        source,
1893                        module_cache.as_deref_mut(),
1894                        module_path,
1895                    )
1896                });
1897            infer_view_like(
1898                base,
1899                &call.args.iter().collect::<Vec<_>>(),
1900                &op,
1901                base_hint,
1902                vars,
1903                diagnostics,
1904                hover_entries,
1905                record_hovers,
1906                source,
1907                call.range,
1908            )
1909        }
1910        (TorchOp::Transpose(transpose), Some(base), _, _) => {
1911            let order_args = match &transpose {
1912                Transpose::Permute => {
1913                    match get_arg(call, "dims", offset) {
1914                        Some(Expr::Tuple(ExprTuple { elts, .. }))
1915                        | Some(Expr::List(ExprList { elts, .. })) => elts.iter().collect(),
1916                        // torch.permute does not accept variadic args for dims
1917                        // but the torch.Tensor.permute does
1918                        _ if matches!(op_kind, Method) => call.args.iter().collect(),
1919                        _ => Vec::new(),
1920                    }
1921                }
1922                // torch.transpose does not accept a size-like
1923                Transpose::Explicit => call.args.iter().skip(offset).take(2).collect(),
1924                // no args
1925                Transpose::T => Vec::new(),
1926            };
1927            let base_hint = infer_expr_shape(
1928                base,
1929                vars,
1930                func_map,
1931                imports,
1932                class_map,
1933                call_stack,
1934                diagnostics,
1935                hover_entries,
1936                false,
1937                source,
1938                module_cache.as_deref_mut(),
1939                module_path,
1940            );
1941            infer_permute(
1942                base,
1943                &order_args,
1944                transpose,
1945                base_hint,
1946                vars,
1947                diagnostics,
1948                hover_entries,
1949                record_hovers,
1950                source,
1951                call.range,
1952            )
1953        }
1954        (TorchOp::Unsqueeze, Some(base), _, _) => {
1955            let dim_arg = get_arg(call, "dim", offset);
1956            infer_unsqueeze(
1957                base,
1958                dim_arg,
1959                vars,
1960                func_map,
1961                imports,
1962                class_map,
1963                call_stack,
1964                diagnostics,
1965                hover_entries,
1966                record_hovers,
1967                source,
1968                module_cache.as_deref_mut(),
1969                module_path,
1970            )
1971        }
1972        (TorchOp::Conv(d), Some(base), _, Function) => {
1973            let kernel = lookup_shape(get_arg(call, "weight", 1)?, vars, hover_entries, record_hovers, source)?;
1974            lookup_shape(base, vars, hover_entries, record_hovers, source).and_then(|shape| infer_conv(shape, kernel.dims, call, d, diagnostics, source))
1975        }
1976        (TorchOp::Repeat, Some(base), _, Method) => {
1977            lookup_shape(base, vars, hover_entries, record_hovers, source).and_then(|shape| infer_repeat(shape, call, vars, diagnostics, source))
1978        }
1979        (TorchOp::RepeatInterleave, Some(base), _, _) => {
1980            lookup_shape(base, vars, hover_entries, record_hovers, source).and_then(|shape| {
1981                infer_repeat_interleave(shape, call, offset, diagnostics, source)
1982            })
1983        }
1984        (TorchOp::Flatten, Some(base), _, _) => {
1985            // TODO(carrascomj): ravel is function-only
1986            lookup_shape(base, vars, hover_entries, record_hovers, source).and_then(|shape| infer_flatten(shape, offset, call, vars, diagnostics, source))
1987        }
1988        (TorchOp::Unknown, _, _, Function) => {
1989            // TODO(carrascomj): check if emitting diagnostics here
1990            // is not too annoying
1991            None
1992        }
1993        // unsupported or not a torch tensor method, etc.
1994        _ => None,
1995    }
1996}
1997
1998fn infer_tuple_elements(
1999    expr: &Expr,
2000    vars: &HashMap<Identifier, VarState>,
2001    func_map: &FuncMap,
2002    imports: &Imports,
2003    class_map: &ClassMap,
2004    call_stack: &mut Vec<Identifier>,
2005    diagnostics: &mut Vec<Diagnostic>,
2006    hover_entries: &mut Vec<(Range, HoverInfo)>,
2007    record_hovers: bool,
2008    source: &str,
2009    mut module_cache: Option<&mut ModuleCache>,
2010    module_path: Option<&Path>,
2011) -> Option<Vec<Option<Shape>>> {
2012    match expr {
2013        Expr::Tuple(tuple) => Some(
2014            tuple
2015                .elts
2016                .iter()
2017                .map(|elt| {
2018                    infer_expr_shape(
2019                        elt,
2020                        vars,
2021                        func_map,
2022                        imports,
2023                        class_map,
2024                        call_stack,
2025                        diagnostics,
2026                        hover_entries,
2027                        record_hovers,
2028                        source,
2029                        module_cache.as_deref_mut(),
2030                        module_path,
2031                    )
2032                })
2033                .collect(),
2034        ),
2035        Expr::Call(call) => infer_defined_call_return(
2036            call,
2037            vars,
2038            func_map,
2039            imports,
2040            class_map,
2041            call_stack,
2042            diagnostics,
2043            hover_entries,
2044            record_hovers,
2045            source,
2046            module_cache,
2047            module_path,
2048        )
2049        .and_then(|ret| ret.tuple().map(|tuple| tuple.to_vec())),
2050        _ => None,
2051    }
2052}
2053
2054fn method_param_offset(args: &Arguments) -> usize {
2055    match args.args.first() {
2056        Some(param) if param.def.arg.as_str() == "self" => 1,
2057        _ => 0,
2058    }
2059}
2060
2061fn class_ref_from_constructor_call(
2062    call_expr: &Expr,
2063    class_map: &ClassMap,
2064    imports: &Imports,
2065    mut module_cache: Option<&mut ModuleCache>,
2066    module_path: Option<&Path>,
2067) -> Option<ClassRef> {
2068    let Expr::Call(call) = call_expr else {
2069        return None;
2070    };
2071    match call.func.as_ref() {
2072        Expr::Name(name) => {
2073            if class_map.contains_key(&name.id) {
2074                return Some(ClassRef {
2075                    name: name.id.clone(),
2076                    module: None,
2077                });
2078            }
2079            if let Some((module_name, original)) = imports.from_imports.get(&name.id)
2080                && module_name != "torch"
2081                && !module_name.starts_with("torch.")
2082                && let (Some(cache), Some(cur_path)) = (module_cache.as_deref_mut(), module_path)
2083                && let Some(module) = cache.get_module(module_name, cur_path)
2084                && module.class_map.contains_key(original)
2085            {
2086                return Some(ClassRef {
2087                    name: original.clone(),
2088                    module: Some(module_name.to_string()),
2089                });
2090            }
2091        }
2092        Expr::Attribute(attr) => {
2093            if let Expr::Name(module_ident) = attr.value.as_ref()
2094                && let Some(module_name) = imports.module_aliases.get(&module_ident.id)
2095                && module_name != "torch"
2096                && !module_name.starts_with("torch.")
2097                && let (Some(cache), Some(cur_path)) = (module_cache, module_path)
2098                && let Some(module) = cache.get_module(module_name, cur_path)
2099                && module.class_map.contains_key(&attr.attr)
2100            {
2101                return Some(ClassRef {
2102                    name: attr.attr.clone(),
2103                    module: Some(module_name.to_string()),
2104                });
2105            }
2106        }
2107        _ => {}
2108    }
2109    None
2110}
2111
2112fn class_ref_from_annotation(
2113    ann: &Expr,
2114    imports: &Imports,
2115    class_map: &ClassMap,
2116) -> Option<ClassRef> {
2117    if let Expr::Name(name) = ann {
2118        if class_map.contains_key(&name.id) {
2119            return Some(ClassRef {
2120                name: name.id.clone(),
2121                module: None,
2122            });
2123        }
2124        if let Some((module_name, original)) = imports.from_imports.get(&name.id) {
2125            return Some(ClassRef {
2126                name: original.clone(),
2127                module: Some(module_name.clone()),
2128            });
2129        }
2130    }
2131    None
2132}
2133
2134fn union_members<'a>(expr: &'a Expr, out: &mut Vec<&'a Expr>) {
2135    if let Expr::BinOp(ExprBinOp {
2136        left, op, right, ..
2137    }) = expr
2138        && matches!(op, Operator::BitOr)
2139    {
2140        union_members(left, out);
2141        union_members(right, out);
2142    } else {
2143        out.push(expr);
2144    }
2145}
2146
2147fn shape_or_class_from_union(
2148    ann: &Expr,
2149    imports: &Imports,
2150    class_map: &ClassMap,
2151) -> (Option<Shape>, Option<ClassRef>) {
2152    let mut members = Vec::new();
2153    union_members(ann, &mut members);
2154    members
2155        .into_iter()
2156        .find_map(|member| {
2157            parse_shape_annotation(member)
2158                .map(|shape| (Some(shape), None))
2159                .or_else(|| {
2160                    class_ref_from_annotation(member, imports, class_map)
2161                        .map(|class_ref| (None, Some(class_ref)))
2162                })
2163        })
2164        .unwrap_or((None, None))
2165}
2166
2167fn tuple_shapes_from_annotation(
2168    ann: &Expr,
2169    imports: &Imports,
2170    class_map: &ClassMap,
2171) -> Option<Vec<Option<Shape>>> {
2172    let elements: Vec<&Expr> = match ann {
2173        Expr::Tuple(t) => t.elts.iter().collect(),
2174        Expr::Subscript(sub) => {
2175            let is_tuple = name_like(&sub.value)
2176                .map(|n| n.eq_ignore_ascii_case("tuple"))
2177                .unwrap_or(false);
2178            if !is_tuple {
2179                return None;
2180            }
2181            match &*sub.slice {
2182                Expr::Tuple(t) => t.elts.iter().collect(),
2183                other => vec![other],
2184            }
2185        }
2186        _ => return None,
2187    };
2188    let tuple_shapes: Vec<Option<Shape>> = elements
2189        .iter()
2190        .map(|elt| {
2191            parse_shape_annotation(elt)
2192                .or_else(|| shape_or_class_from_union(elt, imports, class_map).0)
2193        })
2194        .collect();
2195    if tuple_shapes.iter().any(|s| s.is_some()) {
2196        Some(tuple_shapes)
2197    } else {
2198        None
2199    }
2200}
2201
2202/// Infer return shapes for a call, optionally simulating the callee body for diagnostics.
2203fn infer_call_return_from_info(
2204    call: &ExprCall<TextRange>,
2205    callee_name: &Identifier,
2206    callee_info: &FunctionInfo,
2207    callee_source: &str,
2208    callee_func_map: &FuncMap,
2209    callee_imports: &Imports,
2210    callee_class_map: &ClassMap,
2211    vars: &HashMap<Identifier, VarState>,
2212    func_map: &FuncMap,
2213    imports: &Imports,
2214    class_map: &ClassMap,
2215    call_stack: &mut Vec<Identifier>,
2216    diagnostics: &mut Vec<Diagnostic>,
2217    hover_entries: &mut Vec<(Range, HoverInfo)>,
2218    record_hovers: bool,
2219    source: &str,
2220    mut module_cache: Option<&mut ModuleCache>,
2221    module_path: Option<&Path>,
2222    param_offset: usize,
2223    emit_body_diagnostics: bool,
2224) -> Option<ReturnValue> {
2225    if call_stack.iter().any(|id| id == callee_name) {
2226        return None;
2227    }
2228    let mut arg_shapes: HashMap<Identifier, VarState> = HashMap::new();
2229    for (idx, param) in callee_info.args.args.iter().enumerate().skip(param_offset) {
2230        let call_idx = idx.saturating_sub(param_offset);
2231        if let Some(arg_expr) = call.args.get(call_idx) {
2232            if let Some(shape) = infer_expr_shape(
2233                arg_expr,
2234                vars,
2235                func_map,
2236                imports,
2237                class_map,
2238                call_stack,
2239                diagnostics,
2240                hover_entries,
2241                record_hovers,
2242                source,
2243                module_cache.as_deref_mut(),
2244                module_path,
2245            ) {
2246                if let Some(ann) = param
2247                    .def
2248                    .annotation
2249                    .as_deref()
2250                    .and_then(parse_shape_annotation)
2251                {
2252                    let dims_match = ann.dims.len() == shape.dims.len()
2253                        && ann.dims.iter().zip(shape.dims.iter()).all(|(left, right)| {
2254                            // less stringent equality (alpha-equivalence):
2255                            // behaves as a annotated aliasing; only concrete
2256                            // dimensions at both sides must match
2257                            match (left.parse::<i32>().is_ok(), right.parse::<i32>().is_ok()) {
2258                                (true, true) => left == right,
2259                                _ => true,
2260                            }
2261                        });
2262                    if !dims_match {
2263                        diagnostics.push(Diagnostic {
2264                            range: text_range_to_lsp(expr_text_range(arg_expr), source),
2265                            severity: Some(DiagnosticSeverity::ERROR),
2266                            code: None,
2267                            code_description: None,
2268                            source: Some("shapels".into()),
2269                            message: format!(
2270                                "Shape mismatch: annotation {} vs inferred {}",
2271                                ann.render(),
2272                                shape.render()
2273                            ),
2274                            related_information: None,
2275                            tags: None,
2276                            data: None,
2277                        });
2278                    }
2279                }
2280                arg_shapes.insert(
2281                    param.def.arg.clone(),
2282                    VarState {
2283                        annotated: None,
2284                        inferred: Some(shape),
2285                        class_ref: None,
2286                    },
2287                );
2288            } else if let Expr::Name(name) = arg_expr
2289                && let Some(class_ref) = vars.get(&name.id).and_then(|v| v.class_ref.clone())
2290            {
2291                arg_shapes.insert(
2292                    param.def.arg.clone(),
2293                    VarState {
2294                        annotated: None,
2295                        inferred: None,
2296                        class_ref: Some(class_ref),
2297                    },
2298                );
2299            }
2300        }
2301    }
2302    if let Some(ret_ann) = callee_info.returns.as_deref() {
2303        let annotated_return = if let Some(ret_shape) = parse_shape_annotation(ret_ann) {
2304            Some(ReturnValue::from_shape(Some(ret_shape)))
2305        } else if let Some(tuple_shapes) = tuple_shapes_from_annotation(ret_ann, imports, class_map)
2306        {
2307            Some(ReturnValue::from_tuple(tuple_shapes))
2308        } else {
2309            let (shape_union, _) = shape_or_class_from_union(ret_ann, imports, class_map);
2310            shape_union.map(|shape| ReturnValue::from_shape(Some(shape)))
2311        };
2312        if let Some(ret) = annotated_return {
2313            if emit_body_diagnostics {
2314                call_stack.push(callee_name.clone());
2315                let (mut diag, mut hovers, _) = simulate_function(
2316                    callee_info.args.as_ref(),
2317                    &callee_info.body,
2318                    callee_source,
2319                    callee_func_map,
2320                    callee_imports,
2321                    callee_class_map,
2322                    call_stack,
2323                    arg_shapes,
2324                    record_hovers,
2325                    module_cache.as_deref_mut(),
2326                    module_path,
2327                );
2328                diagnostics.append(&mut diag);
2329                if record_hovers {
2330                    hover_entries.append(&mut hovers);
2331                }
2332                call_stack.pop();
2333            }
2334            return Some(ret);
2335        }
2336    }
2337    call_stack.push(callee_name.clone());
2338    let (mut diag, mut hovers, ret_value) = simulate_function(
2339        callee_info.args.as_ref(),
2340        &callee_info.body,
2341        callee_source,
2342        callee_func_map,
2343        callee_imports,
2344        callee_class_map,
2345        call_stack,
2346        arg_shapes,
2347        false,
2348        module_cache,
2349        module_path,
2350    );
2351    diagnostics.append(&mut diag);
2352    if record_hovers {
2353        hover_entries.append(&mut hovers);
2354    }
2355    call_stack.pop();
2356    Some(ret_value)
2357}
2358
2359fn infer_class_call_return(
2360    call: &ExprCall<TextRange>,
2361    class_ref: &ClassRef,
2362    vars: &HashMap<Identifier, VarState>,
2363    func_map: &FuncMap,
2364    imports: &Imports,
2365    class_map: &ClassMap,
2366    call_stack: &mut Vec<Identifier>,
2367    diagnostics: &mut Vec<Diagnostic>,
2368    hover_entries: &mut Vec<(Range, HoverInfo)>,
2369    record_hovers: bool,
2370    source: &str,
2371    mut module_cache: Option<&mut ModuleCache>,
2372    module_path: Option<&Path>,
2373) -> Option<ReturnValue> {
2374    with_class_info(
2375        class_ref,
2376        source,
2377        func_map,
2378        imports,
2379        class_map,
2380        &mut module_cache,
2381        module_path,
2382        |class_info,
2383         callee_source,
2384         callee_func_map,
2385         callee_imports,
2386         callee_class_map,
2387         callee_path,
2388         module_cache| {
2389            if !class_info.is_torch_module {
2390                return None;
2391            }
2392            let forward = class_info.forward.as_ref()?;
2393            let param_offset = method_param_offset(&forward.args);
2394            infer_call_return_from_info(
2395                call,
2396                &class_ref.name,
2397                forward,
2398                callee_source,
2399                callee_func_map,
2400                callee_imports,
2401                callee_class_map,
2402                vars,
2403                func_map,
2404                imports,
2405                class_map,
2406                call_stack,
2407                diagnostics,
2408                hover_entries,
2409                record_hovers,
2410                source,
2411                module_cache.as_deref_mut(),
2412                callee_path,
2413                param_offset,
2414                false,
2415            )
2416        },
2417    )
2418    .and_then(|ret| ret)
2419}
2420
2421fn infer_class_method_call_return(
2422    call: &ExprCall<TextRange>,
2423    class_ref: &ClassRef,
2424    method_name: &Identifier,
2425    vars: &HashMap<Identifier, VarState>,
2426    func_map: &FuncMap,
2427    imports: &Imports,
2428    class_map: &ClassMap,
2429    call_stack: &mut Vec<Identifier>,
2430    diagnostics: &mut Vec<Diagnostic>,
2431    hover_entries: &mut Vec<(Range, HoverInfo)>,
2432    record_hovers: bool,
2433    source: &str,
2434    mut module_cache: Option<&mut ModuleCache>,
2435    module_path: Option<&Path>,
2436) -> Option<ReturnValue> {
2437    let callee_name = Identifier::from(format!("{}::{}", class_ref.name, method_name));
2438    with_class_info(
2439        class_ref,
2440        source,
2441        func_map,
2442        imports,
2443        class_map,
2444        &mut module_cache,
2445        module_path,
2446        |class_info,
2447         callee_source,
2448         callee_func_map,
2449         callee_imports,
2450         callee_class_map,
2451         callee_path,
2452         module_cache| {
2453            let method_info = class_info.methods.get(method_name)?;
2454            let param_offset = method_param_offset(&method_info.args);
2455            infer_call_return_from_info(
2456                call,
2457                &callee_name,
2458                method_info,
2459                callee_source,
2460                callee_func_map,
2461                callee_imports,
2462                callee_class_map,
2463                vars,
2464                func_map,
2465                imports,
2466                class_map,
2467                call_stack,
2468                diagnostics,
2469                hover_entries,
2470                record_hovers,
2471                source,
2472                module_cache.as_deref_mut(),
2473                callee_path,
2474                param_offset,
2475                false,
2476            )
2477        },
2478    )
2479    .and_then(|ret| ret)
2480}
2481
2482fn infer_defined_call_return(
2483    call: &ExprCall<TextRange>,
2484    vars: &HashMap<Identifier, VarState>,
2485    func_map: &FuncMap,
2486    imports: &Imports,
2487    class_map: &ClassMap,
2488    call_stack: &mut Vec<Identifier>,
2489    diagnostics: &mut Vec<Diagnostic>,
2490    hover_entries: &mut Vec<(Range, HoverInfo)>,
2491    record_hovers: bool,
2492    source: &str,
2493    mut module_cache: Option<&mut ModuleCache>,
2494    module_path: Option<&Path>,
2495) -> Option<ReturnValue> {
2496    if let Expr::Name(func_name) = call.func.as_ref() {
2497        if let Some(class_ref) = vars.get(&func_name.id).and_then(|v| v.class_ref.as_ref()) {
2498            return infer_class_call_return(
2499                call,
2500                class_ref,
2501                vars,
2502                func_map,
2503                imports,
2504                class_map,
2505                call_stack,
2506                diagnostics,
2507                hover_entries,
2508                record_hovers,
2509                source,
2510                module_cache.as_deref_mut(),
2511                module_path,
2512            );
2513        }
2514        if let Some((module_name, original)) = imports.from_imports.get(&func_name.id)
2515            && let (Some(cache), Some(cur_path)) = (module_cache.as_deref_mut(), module_path)
2516            && let Some(module) = cache.get_module(module_name, cur_path)
2517            && let Some(callee_info) = module.func_map.get(original)
2518        {
2519            return infer_call_return_from_info(
2520                call,
2521                &func_name.id,
2522                callee_info,
2523                &module.source,
2524                &module.func_map,
2525                &module.imports,
2526                &module.class_map,
2527                vars,
2528                func_map,
2529                imports,
2530                class_map,
2531                call_stack,
2532                diagnostics,
2533                hover_entries,
2534                record_hovers,
2535                source,
2536                module_cache.as_deref_mut(),
2537                Some(module.path.as_path()),
2538                0,
2539                false,
2540            );
2541        }
2542        if let Some(callee_info) = func_map.get(&func_name.id) {
2543            return infer_call_return_from_info(
2544                call,
2545                &func_name.id,
2546                callee_info,
2547                source,
2548                func_map,
2549                imports,
2550                class_map,
2551                vars,
2552                func_map,
2553                imports,
2554                class_map,
2555                call_stack,
2556                diagnostics,
2557                hover_entries,
2558                record_hovers,
2559                source,
2560                module_cache.as_deref_mut(),
2561                module_path,
2562                0,
2563                false,
2564            );
2565        }
2566    }
2567    if let Expr::Attribute(attr) = call.func.as_ref()
2568        && let Expr::Name(base_name) = attr.value.as_ref()
2569        && let Some(class_ref) = vars.get(&base_name.id).and_then(|v| v.class_ref.as_ref())
2570    {
2571        return infer_class_method_call_return(
2572            call,
2573            class_ref,
2574            &attr.attr,
2575            vars,
2576            func_map,
2577            imports,
2578            class_map,
2579            call_stack,
2580            diagnostics,
2581            hover_entries,
2582            record_hovers,
2583            source,
2584            module_cache.as_deref_mut(),
2585            module_path,
2586        );
2587    }
2588    if let Expr::Attribute(attr) = call.func.as_ref()
2589        && let Expr::Name(module_ident) = attr.value.as_ref()
2590        && let Some(module_name) = imports.module_aliases.get(&module_ident.id)
2591        && module_name != "torch"
2592        && let (Some(cache), Some(cur_path)) = (module_cache.as_deref_mut(), module_path)
2593        && let Some(module) = cache.get_module(module_name, cur_path)
2594        && let Some(callee_info) = module.func_map.get(&attr.attr)
2595    {
2596        return infer_call_return_from_info(
2597            call,
2598            &attr.attr,
2599            callee_info,
2600            &module.source,
2601            &module.func_map,
2602            &module.imports,
2603            &module.class_map,
2604            vars,
2605            func_map,
2606            imports,
2607            class_map,
2608            call_stack,
2609            diagnostics,
2610            hover_entries,
2611            record_hovers,
2612            source,
2613            module_cache,
2614            Some(module.path.as_path()),
2615            0,
2616            false,
2617        );
2618    }
2619    None
2620}
2621
2622fn tensor_or_shape_as_arg(
2623    is_size: bool,
2624    vars: &HashMap<Identifier, VarState>,
2625    func_map: &FuncMap,
2626    imports: &Imports,
2627    class_map: &ClassMap,
2628    call_stack: &mut Vec<Identifier>,
2629    diagnostics: &mut Vec<Diagnostic>,
2630    hover_entries: &mut Vec<(Range, HoverInfo)>,
2631    record_hovers: bool,
2632    source: &str,
2633    module_cache: &mut Option<&mut ModuleCache>,
2634    module_path: Option<&Path>,
2635    call: &ExprCall,
2636) -> Option<Shape> {
2637    match call.args.first() {
2638        // a torch.Tensor.shape might be the first argument
2639        Some(arg0 @ Expr::Attribute(_)) if is_size => infer_expr_shape(
2640            arg0,
2641            vars,
2642            func_map,
2643            imports,
2644            class_map,
2645            call_stack,
2646            diagnostics,
2647            hover_entries,
2648            record_hovers,
2649            source,
2650            module_cache.as_deref_mut(),
2651            module_path,
2652        ),
2653        // zeros_like, ones_like etc. accept a tensor as first arg
2654        Some(arg0 @ Expr::Name(_)) if !is_size => infer_expr_shape(
2655            arg0,
2656            vars,
2657            func_map,
2658            imports,
2659            class_map,
2660            call_stack,
2661            diagnostics,
2662            hover_entries,
2663            record_hovers,
2664            source,
2665            module_cache.as_deref_mut(),
2666            module_path,
2667        ),
2668        _ => None,
2669    }
2670}
2671
2672fn lookup_shape(
2673    expr: &Expr,
2674    vars: &HashMap<Identifier, VarState>,
2675    hover_entries: &mut Vec<(Range, HoverInfo)>,
2676    record_hovers: bool,
2677    source: &str,
2678) -> Option<Shape> {
2679    match expr {
2680        Expr::Name(expr_name) => {
2681            let shape = vars
2682                .get(&expr_name.id)
2683                .and_then(|v| v.annotated.clone().or_else(|| v.inferred.clone()));
2684            if record_hovers && let Some(s) = shape.clone() {
2685                let range = text_range_to_lsp(expr_text_range(expr), source);
2686                hover_entries.push((range, HoverInfo { shape: Some(s) }));
2687            }
2688            shape
2689        }
2690        _ => None,
2691    }
2692}
2693
2694/// An operation might be a function `torch.FUNCTION` (might be imported and
2695/// aliased) or a `torch.Tensor.METHOD`.
2696enum TorchOpKind {
2697    /// `torch.FUNCTION`
2698    Function,
2699    /// `torch.Tensor.METHOD`
2700    Method,
2701}
2702
2703fn function_or_method<R>(expr: &Expr<R>, imports: &Imports) -> TorchOpKind {
2704    match expr {
2705        Expr::Name(n)
2706            if n.id.as_str() == "torch"
2707                || imports.torch_aliases.contains(&n.id)
2708                || imports.torch_nn_functional_aliases.contains(&n.id) =>
2709        {
2710            TorchOpKind::Function
2711        }
2712        _ => TorchOpKind::Method,
2713    }
2714}
2715
2716fn is_alias_of(canonical: &str, ident: &Identifier, imports: &Imports) -> bool {
2717    imports
2718        .func_aliases
2719        .get(canonical)
2720        .map(|set| set.contains(ident))
2721        .unwrap_or(false)
2722}
2723
2724fn assignment_shape_checks(
2725    target: &Expr,
2726    value: &Expr,
2727    vars: &mut HashMap<Identifier, VarState>,
2728    diagnostics: &mut Vec<Diagnostic>,
2729    hover_entries: &mut Vec<(Range, HoverInfo)>,
2730    record_hovers: bool,
2731    source: &str,
2732) -> bool {
2733    let Expr::Tuple(tup) = target else {
2734        return false;
2735    };
2736    let dims: Vec<Identifier> = tup.elts.iter().filter_map(name_from_expr).collect();
2737    if dims.len() != tup.elts.len() || dims.is_empty() {
2738        return false;
2739    }
2740    let Expr::Attribute(attr) = value else {
2741        return false;
2742    };
2743    if attr.attr.as_str() != "shape" {
2744        return false;
2745    }
2746    let Some(base_id) = name_from_expr(&attr.value) else {
2747        return false;
2748    };
2749
2750    let range = text_range_to_lsp(expr_text_range(value), source);
2751    let existing_state = vars.get(&base_id);
2752    let existing_shape = existing_state.and_then(|v| v.annotated.clone().or(v.inferred.clone()));
2753
2754    if let Some(shape) = existing_shape {
2755        if shape.dims.len() == dims.len() {
2756            let mut new_shape = shape.clone();
2757            new_shape.dims = dims.iter().map(|d| d.to_string()).collect();
2758            vars.insert(
2759                base_id.clone(),
2760                VarState {
2761                    annotated: None,
2762                    inferred: Some(new_shape.clone()),
2763                    class_ref: None,
2764                },
2765            );
2766            if record_hovers {
2767                let hrange = text_range_to_lsp(expr_text_range(&attr.value), source);
2768                hover_entries.push((
2769                    hrange,
2770                    HoverInfo {
2771                        shape: Some(new_shape),
2772                    },
2773                ));
2774            }
2775        } else {
2776            diagnostics.push(Diagnostic {
2777                range,
2778                severity: Some(DiagnosticSeverity::ERROR),
2779                code: None,
2780                code_description: None,
2781                source: Some("shapels".into()),
2782                message: "Cannot unroll shape with different rank".into(),
2783                related_information: None,
2784                tags: None,
2785                data: None,
2786            });
2787        }
2788    } else {
2789        let new_shape = Shape {
2790            dtype: None,
2791            dims: dims.iter().map(|d| d.to_string()).collect(),
2792        };
2793        vars.insert(
2794            base_id.clone(),
2795            VarState {
2796                annotated: None,
2797                inferred: Some(new_shape.clone()),
2798                class_ref: None,
2799            },
2800        );
2801        if record_hovers {
2802            let hrange = text_range_to_lsp(expr_text_range(&attr.value), source);
2803            hover_entries.push((
2804                hrange,
2805                HoverInfo {
2806                    shape: Some(new_shape),
2807                },
2808            ));
2809        }
2810    }
2811    true
2812}
2813
2814fn parse_shape_annotation(expr: &Expr) -> Option<Shape> {
2815    if let Expr::Subscript(sub) = expr {
2816        let dtype = name_like(&sub.value);
2817        let components: Vec<&Expr> = match &*sub.slice {
2818            ast::Expr::Tuple(t) => t.elts.iter().collect(),
2819            other => vec![other],
2820        };
2821        if components.len() >= 2
2822            && let Some(raw) = string_literal_value(components[1])
2823        {
2824            let dims = normalize_shape_tokens(&raw);
2825            return Some(Shape { dtype, dims });
2826        }
2827    }
2828    None
2829}
2830
2831fn name_like(expr: &Expr) -> Option<String> {
2832    match expr {
2833        Expr::Name(name) => Some(name.id.to_string()),
2834        Expr::Attribute(attr) => {
2835            let mut base = name_like(&attr.value)?;
2836            base.push('.');
2837            base.push_str(attr.attr.as_ref());
2838            Some(base)
2839        }
2840        _ => None,
2841    }
2842}
2843
2844fn string_literal_value(expr: &Expr) -> Option<String> {
2845    match expr {
2846        Expr::Constant(c) => {
2847            if let ast::Constant::Str(s) = &c.value {
2848                Some(s.clone())
2849            } else {
2850                None
2851            }
2852        }
2853        Expr::JoinedStr(js) => {
2854            let mut buf = String::new();
2855            for val in &js.values {
2856                if let Expr::Constant(c) = val
2857                    && let ast::Constant::Str(s) = &c.value
2858                {
2859                    buf.push_str(s);
2860                }
2861            }
2862            if buf.is_empty() { None } else { Some(buf) }
2863        }
2864        _ => None,
2865    }
2866}
2867
2868fn normalize_shape_tokens(raw: &str) -> Vec<String> {
2869    raw.split_whitespace()
2870        .map(|s| s.trim_matches('"').to_string())
2871        .collect()
2872}
2873
2874fn name_from_expr(expr: &Expr) -> Option<Identifier> {
2875    match expr {
2876        Expr::Name(n) => Some(n.id.clone()),
2877        _ => None,
2878    }
2879}
2880
2881fn seed_args_from_annotations(
2882    args: &Arguments,
2883    source: &str,
2884    vars: &mut HashMap<Identifier, VarState>,
2885    hover_entries: &mut Vec<(Range, HoverInfo)>,
2886    provided: Option<&mut HashMap<Identifier, VarState>>,
2887    record_hovers: bool,
2888    imports: &Imports,
2889    class_map: &ClassMap,
2890) {
2891    for arg in &args.args {
2892        let ann_shape = arg
2893            .def
2894            .annotation
2895            .as_ref()
2896            .and_then(|expr| parse_shape_annotation(expr.as_ref()));
2897        let (union_shape, union_class) = arg
2898            .def
2899            .annotation
2900            .as_ref()
2901            .map(|ann| shape_or_class_from_union(ann.as_ref(), imports, class_map))
2902            .unwrap_or((None, None));
2903        let range = text_range_to_lsp(arg.def.range, source);
2904        let provided_state = provided.as_ref().and_then(|p| p.get(&arg.def.arg));
2905        let state = VarState {
2906            annotated: ann_shape.clone().or(union_shape),
2907            inferred: provided_state.and_then(|s| s.inferred.clone()),
2908            class_ref: provided_state
2909                .and_then(|s| s.class_ref.clone())
2910                .or_else(|| {
2911                    arg.def
2912                        .annotation
2913                        .as_ref()
2914                        .and_then(|ann| class_ref_from_annotation(ann.as_ref(), imports, class_map))
2915                })
2916                .or(union_class),
2917        };
2918
2919        if state.annotated.is_some() || state.inferred.is_some() || state.class_ref.is_some() {
2920            vars.insert(arg.def.arg.clone(), state.clone());
2921            if record_hovers {
2922                hover_entries.push((
2923                    range,
2924                    HoverInfo {
2925                        shape: state.annotated.or(state.inferred),
2926                    },
2927                ));
2928            }
2929        }
2930    }
2931}
2932
2933fn text_range_to_lsp(range: TextRange, source: &str) -> Range {
2934    Range {
2935        start: offset_to_position(source, range.start().to_usize()),
2936        end: offset_to_position(source, range.end().to_usize()),
2937    }
2938}
2939
2940fn expr_text_range(expr: &Expr) -> TextRange {
2941    match expr {
2942        Expr::Name(n) => n.range,
2943        Expr::BoolOp(b) => b.range,
2944        Expr::NamedExpr(n) => n.range,
2945        Expr::BinOp(b) => b.range,
2946        Expr::UnaryOp(u) => u.range,
2947        Expr::Lambda(l) => l.range,
2948        Expr::IfExp(i) => i.range,
2949        Expr::Dict(d) => d.range,
2950        Expr::Set(s) => s.range,
2951        Expr::ListComp(l) => l.range,
2952        Expr::SetComp(s) => s.range,
2953        Expr::DictComp(d) => d.range,
2954        Expr::GeneratorExp(g) => g.range,
2955        Expr::Await(a) => a.range,
2956        Expr::Yield(y) => y.range,
2957        Expr::YieldFrom(y) => y.range,
2958        Expr::Compare(c) => c.range,
2959        Expr::Call(c) => c.range,
2960        Expr::FormattedValue(f) => f.range,
2961        Expr::Subscript(s) => s.range,
2962        Expr::Attribute(a) => a.range,
2963        Expr::Starred(s) => s.range,
2964        Expr::Constant(c) => c.range,
2965        Expr::JoinedStr(j) => j.range,
2966        Expr::List(l) => l.range,
2967        Expr::Tuple(t) => t.range,
2968        Expr::Slice(s) => s.range,
2969    }
2970}
2971
2972fn offset_to_position(source: &str, offset: usize) -> Position {
2973    let mut line = 0u32;
2974    let mut col = 0u32;
2975    let mut count = 0usize;
2976    for ch in source.chars() {
2977        if count == offset {
2978            break;
2979        }
2980        if ch == '\n' {
2981            line += 1;
2982            col = 0;
2983        } else {
2984            col += 1;
2985        }
2986        count += ch.len_utf8();
2987    }
2988    Position {
2989        line,
2990        character: col,
2991    }
2992}
2993
2994fn default_range() -> Range {
2995    Range {
2996        start: Position {
2997            line: 0,
2998            character: 0,
2999        },
3000        end: Position {
3001            line: 0,
3002            character: 1,
3003        },
3004    }
3005}
3006
3007fn get_arg<'a, R>(
3008    call: &'a ExprCall<R>,
3009    name_arg: &str,
3010    as_positional: usize,
3011) -> Option<&'a Expr<R>> {
3012    // first check for positional argument, then named argument
3013    (call.args.get(as_positional)).or(call
3014        .keywords
3015        .iter()
3016        .find(|kw| kw.arg.as_deref() == Some(name_arg))
3017        .map(|kw| &kw.value))
3018}
3019
3020fn get_dtype<'expr, R>(dtype_expr: &'expr Expr<R>, imports: &Imports) -> Option<&'expr str> {
3021    match dtype_expr {
3022        Expr::Constant(constant) => match &constant.value {
3023            Constant::Str(string_dtype) if TORCH_DTYPES.contains(string_dtype.as_str()) => {
3024                Some(string_dtype.as_str())
3025            }
3026            _ => Some("Float"),
3027        },
3028        Expr::Name(name) => Some(name.id.as_str()),
3029        Expr::Attribute(attr)
3030            if matches!(
3031                function_or_method(attr.value.as_ref(), imports),
3032                TorchOpKind::Function
3033            ) && TORCH_DTYPES.contains(attr.attr.as_str()) =>
3034        {
3035            Some(attr.attr.as_str())
3036        }
3037        _ => Some("Float"),
3038    }
3039}