Skip to main content

rustpython_vm/
import.rs

1//! Import mechanics
2
3use crate::{
4    AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult,
5    builtins::{PyCode, PyStr, PyUtf8Str, PyUtf8StrRef, traceback::PyTraceback},
6    exceptions::types::PyBaseException,
7    scope::Scope,
8    vm::{VirtualMachine, resolve_frozen_alias, thread},
9};
10
11pub(crate) fn check_pyc_magic_number_bytes(buf: &[u8]) -> bool {
12    buf.starts_with(&crate::version::PYC_MAGIC_NUMBER_BYTES[..2])
13}
14
15pub(crate) fn init_importlib_base(vm: &mut VirtualMachine) -> PyResult<PyObjectRef> {
16    flame_guard!("init importlib");
17
18    // importlib_bootstrap needs these and it inlines checks to sys.modules before calling into
19    // import machinery, so this should bring some speedup
20    #[cfg(all(feature = "threading", not(target_os = "wasi")))]
21    import_builtin(vm, "_thread")?;
22    import_builtin(vm, "_warnings")?;
23    import_builtin(vm, "_weakref")?;
24
25    let importlib = thread::enter_vm(vm, || {
26        let bootstrap = import_frozen(vm, "_frozen_importlib")?;
27        let install = bootstrap.get_attr("_install", vm)?;
28        let imp = import_builtin(vm, "_imp")?;
29        install.call((vm.sys_module.clone(), imp), vm)?;
30        Ok(bootstrap)
31    })?;
32    vm.import_func = importlib.get_attr(identifier!(vm, __import__), vm)?;
33    vm.importlib = importlib.clone();
34    Ok(importlib)
35}
36
/// Installs the external (path-based) importers and, best-effort, the `zipimport` path hook.
///
/// Calls `importlib._install_external_importers()` after eagerly importing the builtin
/// modules listed below, then tries to put `zipimport.zipimporter` at the front of
/// `sys.path_hooks`. A failure in the zipimport step is only logged, never returned.
///
/// # Errors
/// Propagates exceptions from the eager imports and from `_install_external_importers`.
#[cfg(feature = "host_env")]
pub(crate) fn init_importlib_package(vm: &VirtualMachine, importlib: PyObjectRef) -> PyResult<()> {
    use crate::{TryFromObject, builtins::PyListRef};

    thread::enter_vm(vm, || {
        flame_guard!("install_external");

        // Eagerly import the builtins before installing the external importers.
        import_builtin(vm, crate::stdlib::os::MODULE_NAME)?;
        #[cfg(windows)]
        import_builtin(vm, "winreg")?;
        for preloaded in ["_io", "marshal"] {
            import_builtin(vm, preloaded)?;
        }

        importlib
            .get_attr("_install_external_importers", vm)?
            .call((), vm)?;

        // Registering zipimport is optional: any error is downgraded to a log message.
        let register_zipimport = || -> PyResult<()> {
            let zipimporter = vm.import("zipimport", 0)?.get_attr("zipimporter", vm)?;
            let hooks = vm.sys_module.get_attr("path_hooks", vm)?;
            PyListRef::try_from_object(vm, hooks)?.insert(0, zipimporter);
            Ok(())
        };
        if register_zipimport().is_err() {
            warn!("couldn't init zipimport");
        }
        Ok(())
    })
}
67
68pub fn make_frozen(vm: &VirtualMachine, name: &str) -> PyResult<PyRef<PyCode>> {
69    let frozen = vm.state.frozen.get(name).ok_or_else(|| {
70        vm.new_import_error(
71            format!("No such frozen object named {name}"),
72            vm.ctx.new_utf8_str(name),
73        )
74    })?;
75    Ok(vm.ctx.new_code(frozen.code))
76}
77
78pub fn import_frozen(vm: &VirtualMachine, module_name: &str) -> PyResult {
79    let frozen = vm.state.frozen.get(module_name).ok_or_else(|| {
80        vm.new_import_error(
81            format!("No such frozen object named {module_name}"),
82            vm.ctx.new_utf8_str(module_name),
83        )
84    })?;
85    let module = import_code_obj(vm, module_name, vm.ctx.new_code(frozen.code), false)?;
86    debug_assert!(module.get_attr(identifier!(vm, __name__), vm).is_ok());
87    let origname = resolve_frozen_alias(module_name);
88    module.set_attr("__origname__", vm.ctx.new_utf8_str(origname), vm)?;
89    Ok(module)
90}
91
92pub fn import_builtin(vm: &VirtualMachine, module_name: &str) -> PyResult {
93    let sys_modules = vm.sys_module.get_attr("modules", vm)?;
94
95    // Check if already in sys.modules (handles recursive imports)
96    if let Ok(module) = sys_modules.get_item(module_name, vm) {
97        return Ok(module);
98    }
99
100    // Try multi-phase init first (preferred for modules that import other modules)
101    if let Some(&def) = vm.state.module_defs.get(module_name) {
102        // Phase 1: Create and initialize module
103        let module = def.create_module(vm)?;
104
105        // Add to sys.modules BEFORE exec (critical for circular import handling)
106        sys_modules.set_item(module_name, module.clone().into(), vm)?;
107
108        // Phase 2: Call exec slot (can safely import other modules now)
109        // If exec fails, remove the partially-initialized module from sys.modules
110        if let Err(e) = def.exec_module(vm, &module) {
111            let _ = sys_modules.del_item(module_name, vm);
112            return Err(e);
113        }
114
115        return Ok(module.into());
116    }
117
118    // Module not found in module_defs
119    Err(vm.new_import_error(
120        format!("Cannot import builtin module {module_name}"),
121        vm.ctx.new_utf8_str(module_name),
122    ))
123}
124
/// Compiles `content` in exec mode with `file_path` as its source path and imports the
/// result as module `module_name`, setting `__file__` on the new module.
///
/// # Errors
/// Returns a `SyntaxError` when compilation fails and propagates any exception raised while
/// executing the module.
#[cfg(feature = "rustpython-compiler")]
pub fn import_file(
    vm: &VirtualMachine,
    module_name: &str,
    file_path: String,
    content: &str,
) -> PyResult {
    let compiled = vm.compile_with_opts(
        content,
        crate::compiler::Mode::Exec,
        file_path,
        vm.compile_opts(),
    );
    let code = compiled.map_err(|err| vm.new_syntax_error(&err, Some(content)))?;
    import_code_obj(vm, module_name, code, true)
}
142
/// Compiles `content` in exec mode under the placeholder path `<source>` and imports the
/// result as module `module_name`; no `__file__` attribute is set.
///
/// # Errors
/// Returns a `SyntaxError` when compilation fails and propagates any exception raised while
/// executing the module.
#[cfg(feature = "rustpython-compiler")]
pub fn import_source(vm: &VirtualMachine, module_name: &str, content: &str) -> PyResult {
    let compiled = vm.compile_with_opts(
        content,
        crate::compiler::Mode::Exec,
        "<source>".to_owned(),
        vm.compile_opts(),
    );
    let code = compiled.map_err(|err| vm.new_syntax_error(&err, Some(content)))?;
    import_code_obj(vm, module_name, code, false)
}
155
156/// If `__spec__._initializing` is true, wait for the module to finish
157/// initializing by calling `_lock_unlock_module`.
158fn import_ensure_initialized(
159    module: &PyObjectRef,
160    name: &str,
161    vm: &VirtualMachine,
162) -> PyResult<()> {
163    let initializing = match vm.get_attribute_opt(module.clone(), vm.ctx.intern_str("__spec__"))? {
164        Some(spec) => match vm.get_attribute_opt(spec, vm.ctx.intern_str("_initializing"))? {
165            Some(v) => v.try_to_bool(vm)?,
166            None => false,
167        },
168        None => false,
169    };
170    if initializing {
171        let lock_unlock = vm.importlib.get_attr("_lock_unlock_module", vm)?;
172        lock_unlock.call((vm.ctx.new_utf8_str(name),), vm)?;
173    }
174    Ok(())
175}
176
177pub fn import_code_obj(
178    vm: &VirtualMachine,
179    module_name: &str,
180    code_obj: PyRef<PyCode>,
181    set_file_attr: bool,
182) -> PyResult {
183    let attrs = vm.ctx.new_dict();
184    attrs.set_item(
185        identifier!(vm, __name__),
186        vm.ctx.new_utf8_str(module_name).into(),
187        vm,
188    )?;
189    if set_file_attr {
190        attrs.set_item(
191            identifier!(vm, __file__),
192            code_obj.source_path().to_object(),
193            vm,
194        )?;
195    }
196    let module = vm.new_module(module_name, attrs.clone(), None);
197
198    // Store module in cache to prevent infinite loop with mutual importing libs:
199    let sys_modules = vm.sys_module.get_attr("modules", vm)?;
200    sys_modules.set_item(module_name, module.clone().into(), vm)?;
201
202    // Execute main code in module:
203    let scope = Scope::with_builtins(None, attrs, vm);
204    vm.run_code_obj(code_obj, scope)?;
205    Ok(module.into())
206}
207
208fn remove_importlib_frames_inner(
209    vm: &VirtualMachine,
210    tb: Option<PyRef<PyTraceback>>,
211    always_trim: bool,
212) -> (Option<PyRef<PyTraceback>>, bool) {
213    let traceback = if let Some(tb) = tb {
214        tb
215    } else {
216        return (None, false);
217    };
218
219    let file_name = traceback.frame.code.source_path().as_str();
220
221    let (inner_tb, mut now_in_importlib) =
222        remove_importlib_frames_inner(vm, traceback.next.lock().clone(), always_trim);
223    if file_name == "_frozen_importlib" || file_name == "_frozen_importlib_external" {
224        if traceback.frame.code.obj_name.as_str() == "_call_with_frames_removed" {
225            now_in_importlib = true;
226        }
227        if always_trim || now_in_importlib {
228            return (inner_tb, now_in_importlib);
229        }
230    } else {
231        now_in_importlib = false;
232    }
233
234    (
235        Some(
236            PyTraceback::new(
237                inner_tb,
238                traceback.frame.clone(),
239                traceback.lasti,
240                traceback.lineno,
241            )
242            .into_ref(&vm.ctx),
243        ),
244        now_in_importlib,
245    )
246}
247
248// TODO: This function should do nothing on verbose mode.
249// TODO: Fix this function after making PyTraceback.next mutable
250pub fn remove_importlib_frames(vm: &VirtualMachine, exc: &Py<PyBaseException>) {
251    if vm.state.config.settings.verbose != 0 {
252        return;
253    }
254
255    let always_trim = exc.fast_isinstance(vm.ctx.exceptions.import_error);
256
257    if let Some(tb) = exc.__traceback__() {
258        let trimmed_tb = remove_importlib_frames_inner(vm, Some(tb), always_trim).0;
259        exc.set_traceback_typed(trimmed_tb);
260    }
261}
262
263/// Get origin path from a module spec, checking has_location first.
264pub(crate) fn get_spec_file_origin(
265    spec: &Option<PyObjectRef>,
266    vm: &VirtualMachine,
267) -> Option<String> {
268    let spec = spec.as_ref()?;
269    let has_location = spec
270        .get_attr("has_location", vm)
271        .ok()
272        .and_then(|v| v.try_to_bool(vm).ok())
273        .unwrap_or(false);
274    if !has_location {
275        return None;
276    }
277    spec.get_attr("origin", vm).ok().and_then(|origin| {
278        if vm.is_none(&origin) {
279            None
280        } else {
281            origin
282                .downcast_ref::<PyStr>()
283                .and_then(|s| s.to_str().map(|s| s.to_owned()))
284        }
285    })
286}
287
288/// Check if a module file possibly shadows another module of the same name.
289/// Compares the module's directory with the original sys.path[0] (derived from sys.argv[0]).
290pub(crate) fn is_possibly_shadowing_path(origin: &str, vm: &VirtualMachine) -> bool {
291    use std::path::Path;
292
293    if vm.state.config.settings.safe_path {
294        return false;
295    }
296
297    let origin_path = Path::new(origin);
298    let parent = match origin_path.parent() {
299        Some(p) => p,
300        None => return false,
301    };
302    // For packages (__init__.py), look one directory further up
303    let root = if origin_path.file_name() == Some("__init__.py".as_ref()) {
304        parent.parent().unwrap_or(Path::new(""))
305    } else {
306        parent
307    };
308
309    // Compute original sys.path[0] from sys.argv[0] (the script path).
310    // See: config->sys_path_0, which is set once
311    // at initialization and never changes even if sys.path is modified.
312    let sys_path_0 = (|| -> Option<String> {
313        let argv = vm.sys_module.get_attr("argv", vm).ok()?;
314        let argv0 = argv.get_item(&0usize, vm).ok()?;
315        let argv0_str = argv0.downcast_ref::<PyUtf8Str>()?;
316        let s = argv0_str.as_str();
317
318        // For -c and REPL, original sys.path[0] is ""
319        if s == "-c" || s.is_empty() {
320            return Some(String::new());
321        }
322        // For scripts, original sys.path[0] is dirname(argv[0])
323        Some(
324            Path::new(s)
325                .parent()
326                .and_then(|p| p.to_str())
327                .unwrap_or("")
328                .to_owned(),
329        )
330    })();
331
332    let sys_path_0 = match sys_path_0 {
333        Some(p) => p,
334        None => return false,
335    };
336
337    let cmp_path = if sys_path_0.is_empty() {
338        match std::env::current_dir() {
339            Ok(d) => d.to_string_lossy().to_string(),
340            Err(_) => return false,
341        }
342    } else {
343        sys_path_0
344    };
345
346    root.to_str() == Some(cmp_path.as_str())
347}
348
349/// Check if a module name is in sys.stdlib_module_names.
350/// Takes the original __name__ object to preserve str subclass behavior.
351/// Propagates errors (e.g. TypeError for unhashable str subclass).
352pub(crate) fn is_stdlib_module_name(name: &PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
353    let stdlib_names = match vm.sys_module.get_attr("stdlib_module_names", vm) {
354        Ok(names) => names,
355        Err(_) => return Ok(false),
356    };
357    if !stdlib_names.class().fast_issubclass(vm.ctx.types.set_type)
358        && !stdlib_names
359            .class()
360            .fast_issubclass(vm.ctx.types.frozenset_type)
361    {
362        return Ok(false);
363    }
364    let result = vm.call_method(&stdlib_names, "__contains__", (name.clone(),))?;
365    result.try_to_bool(vm)
366}
367
368/// PyImport_ImportModuleLevelObject
369pub(crate) fn import_module_level(
370    name: &Py<PyStr>,
371    globals: Option<PyObjectRef>,
372    fromlist: Option<PyObjectRef>,
373    level: i32,
374    vm: &VirtualMachine,
375) -> PyResult {
376    if level < 0 {
377        return Err(vm.new_value_error("level must be >= 0"));
378    }
379
380    let name_str = match name.to_str() {
381        Some(s) => s,
382        None => {
383            // Name contains surrogates. Like CPython, try sys.modules
384            // lookup with the Python string key directly.
385            if level == 0 {
386                let sys_modules = vm.sys_module.get_attr("modules", vm)?;
387                return sys_modules.get_item(name, vm).map_err(|_| {
388                    vm.new_import_error(format!("No module named '{}'", name), name.to_owned())
389                });
390            }
391            return Err(vm.new_import_error(format!("No module named '{}'", name), name.to_owned()));
392        }
393    };
394
395    // Resolve absolute name
396    let abs_name = if level > 0 {
397        // When globals is not provided (Rust None), raise KeyError
398        // matching resolve_name() where globals==NULL
399        if globals.is_none() {
400            return Err(vm.new_key_error(vm.ctx.new_str("'__name__' not in globals").into()));
401        }
402        let globals_ref = globals.as_ref().unwrap();
403        // When globals is Python None, treat like empty mapping
404        let empty_dict_obj;
405        let globals_ref = if vm.is_none(globals_ref) {
406            empty_dict_obj = vm.ctx.new_dict().into();
407            &empty_dict_obj
408        } else {
409            globals_ref
410        };
411        let package = calc_package(Some(globals_ref), vm)?;
412        if package.is_empty() {
413            return Err(vm.new_import_error(
414                "attempted relative import with no known parent package",
415                vm.ctx.new_utf8_str(""),
416            ));
417        }
418        resolve_name(name_str, &package, level as usize, vm)?
419    } else {
420        if name_str.is_empty() {
421            return Err(vm.new_value_error("Empty module name"));
422        }
423        name_str.to_owned()
424    };
425
426    // import_get_module + import_find_and_load
427    let sys_modules = vm.sys_module.get_attr("modules", vm)?;
428    let module = match sys_modules.get_item(&*abs_name, vm) {
429        Ok(m) if !vm.is_none(&m) => {
430            import_ensure_initialized(&m, &abs_name, vm)?;
431            m
432        }
433        _ => {
434            let find_and_load = vm.importlib.get_attr("_find_and_load", vm)?;
435            let abs_name_obj = vm.ctx.new_utf8_str(&*abs_name);
436            find_and_load.call((abs_name_obj, vm.import_func.clone()), vm)?
437        }
438    };
439
440    // Handle fromlist
441    let has_from = match fromlist.as_ref().filter(|fl| !vm.is_none(fl)) {
442        Some(fl) => fl.clone().try_to_bool(vm)?,
443        None => false,
444    };
445
446    if has_from {
447        let fromlist = fromlist.unwrap();
448        // Only call _handle_fromlist if the module looks like a package
449        // (has __path__). Non-module objects without __name__/__path__ would
450        // crash inside _handle_fromlist; IMPORT_FROM handles per-attribute
451        // errors with proper ImportError conversion.
452        let has_path = vm
453            .get_attribute_opt(module.clone(), vm.ctx.intern_str("__path__"))?
454            .is_some();
455        if has_path {
456            let handle_fromlist = vm.importlib.get_attr("_handle_fromlist", vm)?;
457            handle_fromlist.call((module, fromlist, vm.import_func.clone()), vm)
458        } else {
459            Ok(module)
460        }
461    } else if level == 0 || !name_str.is_empty() {
462        match name_str.find('.') {
463            None => Ok(module),
464            Some(dot) => {
465                let to_return = if level == 0 {
466                    name_str[..dot].to_owned()
467                } else {
468                    let cut_off = name_str.len() - dot;
469                    abs_name[..abs_name.len() - cut_off].to_owned()
470                };
471                match sys_modules.get_item(&*to_return, vm) {
472                    Ok(m) => Ok(m),
473                    Err(_) if level == 0 => {
474                        // For absolute imports (level 0), try importing the
475                        // parent. Matches _bootstrap.__import__ behavior.
476                        let find_and_load = vm.importlib.get_attr("_find_and_load", vm)?;
477                        let to_return_obj = vm.ctx.new_utf8_str(&*to_return);
478                        find_and_load.call((to_return_obj, vm.import_func.clone()), vm)
479                    }
480                    Err(_) => {
481                        // For relative imports (level > 0), raise KeyError
482                        let to_return_obj: PyObjectRef = vm
483                            .ctx
484                            .new_utf8_str(format!("'{to_return}' not in sys.modules as expected"))
485                            .into();
486                        Err(vm.new_key_error(to_return_obj))
487                    }
488                }
489            }
490        }
491    } else {
492        Ok(module)
493    }
494}
495
496/// resolve_name in import.c - resolve relative import name
497fn resolve_name(name: &str, package: &str, level: usize, vm: &VirtualMachine) -> PyResult<String> {
498    // Python: bits = package.rsplit('.', level - 1)
499    // Rust: rsplitn(level, '.') gives maxsplit=level-1
500    let parts: Vec<&str> = package.rsplitn(level, '.').collect();
501    if parts.len() < level {
502        return Err(vm.new_import_error(
503            "attempted relative import beyond top-level package",
504            vm.ctx.new_utf8_str(name),
505        ));
506    }
507    // rsplitn returns parts right-to-left, so last() is the leftmost (base)
508    let base = parts.last().unwrap();
509    if name.is_empty() {
510        Ok(base.to_string())
511    } else {
512        Ok(format!("{base}.{name}"))
513    }
514}
515
516/// _calc___package__ - calculate package from globals for relative imports
517fn calc_package(globals: Option<&PyObjectRef>, vm: &VirtualMachine) -> PyResult<String> {
518    let globals = globals.ok_or_else(|| {
519        vm.new_import_error(
520            "attempted relative import with no known parent package",
521            vm.ctx.new_utf8_str(""),
522        )
523    })?;
524
525    let package = globals.get_item("__package__", vm).ok();
526    let spec = globals.get_item("__spec__", vm).ok();
527
528    if let Some(ref pkg) = package
529        && !vm.is_none(pkg)
530    {
531        let pkg_str: PyUtf8StrRef = pkg
532            .clone()
533            .downcast()
534            .map_err(|_| vm.new_type_error("package must be a string"))?;
535        // Warn if __package__ != __spec__.parent
536        if let Some(ref spec) = spec
537            && !vm.is_none(spec)
538            && let Ok(parent) = spec.get_attr("parent", vm)
539            && !pkg_str.is(&parent)
540            && pkg_str
541                .as_object()
542                .rich_compare_bool(&parent, crate::types::PyComparisonOp::Ne, vm)
543                .unwrap_or(false)
544        {
545            let parent_repr = parent
546                .repr_utf8(vm)
547                .map(|s| s.as_str().to_owned())
548                .unwrap_or_default();
549            let msg = format!(
550                "__package__ != __spec__.parent ('{}' != {})",
551                pkg_str.as_str(),
552                parent_repr
553            );
554            let warn = vm
555                .import("_warnings", 0)
556                .and_then(|w| w.get_attr("warn", vm));
557            if let Ok(warn_fn) = warn {
558                let _ = warn_fn.call(
559                    (
560                        vm.ctx.new_str(msg),
561                        vm.ctx.exceptions.deprecation_warning.to_owned(),
562                    ),
563                    vm,
564                );
565            }
566        }
567        return Ok(pkg_str.as_str().to_owned());
568    } else if let Some(ref spec) = spec
569        && !vm.is_none(spec)
570        && let Ok(parent) = spec.get_attr("parent", vm)
571        && !vm.is_none(&parent)
572    {
573        let parent_str: PyUtf8StrRef = parent
574            .downcast()
575            .map_err(|_| vm.new_type_error("package set to non-string"))?;
576        return Ok(parent_str.as_str().to_owned());
577    }
578
579    // Fall back to __name__ and __path__
580    let warn = vm
581        .import("_warnings", 0)
582        .and_then(|w| w.get_attr("warn", vm));
583    if let Ok(warn_fn) = warn {
584        let _ = warn_fn.call(
585            (
586                vm.ctx.new_str("can't resolve package from __spec__ or __package__, falling back on __name__ and __path__"),
587                vm.ctx.exceptions.import_warning.to_owned(),
588            ),
589            vm,
590        );
591    }
592
593    let mod_name = globals.get_item("__name__", vm).map_err(|_| {
594        vm.new_import_error(
595            "attempted relative import with no known parent package",
596            vm.ctx.new_utf8_str(""),
597        )
598    })?;
599    let mod_name_str: PyUtf8StrRef = mod_name
600        .downcast()
601        .map_err(|_| vm.new_type_error("__name__ must be a string"))?;
602    let mut package = mod_name_str.as_str().to_owned();
603    // If not a package (no __path__), strip last component.
604    // Uses rpartition('.')[0] semantics: returns empty string when no dot.
605    if globals.get_item("__path__", vm).is_err() {
606        package = match package.rfind('.') {
607            Some(dot) => package[..dot].to_owned(),
608            None => String::new(),
609        };
610    }
611    Ok(package)
612}