Skip to main content

rustpython_vm/builtins/
interpolation.rs

1use super::{
2    PyStr, PyStrRef, PyTupleRef, PyType, PyTypeRef, genericalias::PyGenericAlias,
3    tuple::IntoPyTuple,
4};
5use crate::{
6    AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
7    class::PyClassImpl,
8    common::hash::PyHash,
9    convert::ToPyObject,
10    function::{OptionalArg, PyComparisonValue},
11    types::{Comparable, Constructor, Hashable, PyComparisonOp, Representable},
12};
13use itertools::Itertools;
14use rustpython_common::wtf8::Wtf8Buf;
15
/// Interpolation object for t-strings (PEP 750).
///
/// Represents an interpolated expression within a template string.
/// Exposed to Python as `string.templatelib.Interpolation`.
///
/// Both constructors in this file (`PyInterpolation::new` and the Python-level
/// `py_new`) validate `conversion`; because the fields are `pub`, Rust code that
/// builds the struct literally bypasses that validation.
#[pyclass(module = "string.templatelib", name = "Interpolation")]
#[derive(Debug, Clone)]
pub struct PyInterpolation {
    /// The interpolated value; may be any Python object.
    pub value: PyObjectRef,
    /// Source text of the interpolated expression. The Python constructor
    /// defaults it to the empty string when not supplied.
    pub expression: PyStrRef,
    /// Conversion flag: either Python `None` or one of the one-character
    /// strings `'s'`, `'r'`, `'a'`.
    pub conversion: PyObjectRef, // None or 's', 'r', 'a'
    /// Format specification text. The Python constructor defaults it to the
    /// empty string when not supplied.
    pub format_spec: PyStrRef,
}
27
/// Binds `PyInterpolation` payloads to the `interpolation_type` entry of the
/// context's builtin type table.
impl PyPayload for PyInterpolation {
    #[inline]
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.interpolation_type
    }
}
34
35impl PyInterpolation {
36    pub fn new(
37        value: PyObjectRef,
38        expression: PyStrRef,
39        conversion: PyObjectRef,
40        format_spec: PyStrRef,
41        vm: &VirtualMachine,
42    ) -> PyResult<Self> {
43        // Validate conversion like _PyInterpolation_Build does
44        let is_valid = vm.is_none(&conversion)
45            || conversion
46                .downcast_ref::<PyStr>()
47                .is_some_and(|s| matches!(s.to_str(), Some("s") | Some("r") | Some("a")));
48        if !is_valid {
49            return Err(vm.new_exception_msg(
50                vm.ctx.exceptions.system_error.to_owned(),
51                "Interpolation() argument 'conversion' must be one of 's', 'a' or 'r'".into(),
52            ));
53        }
54        Ok(Self {
55            value,
56            expression,
57            conversion,
58            format_spec,
59        })
60    }
61}
62
63impl Constructor for PyInterpolation {
64    type Args = InterpolationArgs;
65
66    fn py_new(_cls: &Py<PyType>, args: Self::Args, vm: &VirtualMachine) -> PyResult<Self> {
67        let conversion: PyObjectRef = if let Some(s) = args.conversion {
68            let has_flag = s
69                .as_bytes()
70                .iter()
71                .exactly_one()
72                .ok()
73                .is_some_and(|s| matches!(*s, b's' | b'r' | b'a'));
74            if !has_flag {
75                return Err(vm.new_value_error(
76                    "Interpolation() argument 'conversion' must be one of 's', 'a' or 'r'",
77                ));
78            }
79            s.into()
80        } else {
81            vm.ctx.none()
82        };
83
84        let expression = args
85            .expression
86            .unwrap_or_else(|| vm.ctx.empty_str.to_owned());
87        let format_spec = args
88            .format_spec
89            .unwrap_or_else(|| vm.ctx.empty_str.to_owned());
90
91        Ok(PyInterpolation {
92            value: args.value,
93            expression,
94            conversion,
95            format_spec,
96        })
97    }
98}
99
/// Arguments accepted by the Python-level `Interpolation(...)` constructor.
///
/// NOTE: field order is significant — the `FromArgs` derive uses it as the
/// positional-argument order, so do not reorder these fields.
#[derive(FromArgs)]
pub struct InterpolationArgs {
    /// The interpolated value; required, passed positionally.
    #[pyarg(positional)]
    value: PyObjectRef,
    /// Expression source text; may be passed positionally or by keyword.
    /// `py_new` substitutes the empty string when it is missing.
    #[pyarg(any, optional)]
    expression: OptionalArg<PyStrRef>,
    /// Conversion flag; may be passed positionally or by keyword. Missing or
    /// `None` yields `None` here; a non-str value is rejected with the
    /// `error_msg` below. `py_new` further restricts it to `'s'`, `'r'`, `'a'`.
    #[pyarg(
        any,
        optional,
        error_msg = "Interpolation() argument 'conversion' must be str or None"
    )]
    conversion: Option<PyStrRef>,
    /// Format specification; may be passed positionally or by keyword.
    /// `py_new` substitutes the empty string when it is missing.
    #[pyarg(any, optional)]
    format_spec: OptionalArg<PyStrRef>,
}
115
116#[pyclass(with(Constructor, Comparable, Hashable, Representable))]
117impl PyInterpolation {
118    #[pyattr]
119    fn __match_args__(ctx: &Context) -> PyTupleRef {
120        ctx.new_tuple(vec![
121            ctx.intern_str("value").to_owned().into(),
122            ctx.intern_str("expression").to_owned().into(),
123            ctx.intern_str("conversion").to_owned().into(),
124            ctx.intern_str("format_spec").to_owned().into(),
125        ])
126    }
127
128    #[pygetset]
129    fn value(&self) -> PyObjectRef {
130        self.value.clone()
131    }
132
133    #[pygetset]
134    fn expression(&self) -> PyStrRef {
135        self.expression.clone()
136    }
137
138    #[pygetset]
139    fn conversion(&self) -> PyObjectRef {
140        self.conversion.clone()
141    }
142
143    #[pygetset]
144    fn format_spec(&self) -> PyStrRef {
145        self.format_spec.clone()
146    }
147
148    #[pyclassmethod]
149    fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
150        PyGenericAlias::from_args(cls, args, vm)
151    }
152
153    #[pymethod]
154    fn __reduce__(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyTupleRef {
155        let cls = zelf.class().to_owned();
156        let args = (
157            zelf.value.clone(),
158            zelf.expression.clone(),
159            zelf.conversion.clone(),
160            zelf.format_spec.clone(),
161        );
162        (cls, args.to_pyobject(vm)).into_pytuple(vm)
163    }
164}
165
166impl Comparable for PyInterpolation {
167    fn cmp(
168        zelf: &Py<Self>,
169        other: &PyObject,
170        op: PyComparisonOp,
171        vm: &VirtualMachine,
172    ) -> PyResult<PyComparisonValue> {
173        op.eq_only(|| {
174            let other = class_or_notimplemented!(Self, other);
175
176            let eq = vm.bool_eq(&zelf.value, &other.value)?
177                && vm.bool_eq(zelf.expression.as_object(), other.expression.as_object())?
178                && vm.bool_eq(&zelf.conversion, &other.conversion)?
179                && vm.bool_eq(zelf.format_spec.as_object(), other.format_spec.as_object())?;
180
181            Ok(eq.into())
182        })
183    }
184}
185
186impl Hashable for PyInterpolation {
187    fn hash(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyHash> {
188        // Hash based on (value, expression, conversion, format_spec)
189        let value_hash = zelf.value.hash(vm)?;
190        let expr_hash = zelf.expression.as_object().hash(vm)?;
191        let conv_hash = zelf.conversion.hash(vm)?;
192        let spec_hash = zelf.format_spec.as_object().hash(vm)?;
193
194        // Combine hashes
195        Ok(value_hash
196            .wrapping_add(expr_hash.wrapping_mul(3))
197            .wrapping_add(conv_hash.wrapping_mul(5))
198            .wrapping_add(spec_hash.wrapping_mul(7)))
199    }
200}
201
202impl Representable for PyInterpolation {
203    #[inline]
204    fn repr_wtf8(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<Wtf8Buf> {
205        let value_repr = zelf.value.repr(vm)?;
206        let expr_repr = zelf.expression.repr(vm)?;
207        let spec_repr = zelf.format_spec.repr(vm)?;
208
209        let mut result = Wtf8Buf::from("Interpolation(");
210        result.push_wtf8(value_repr.as_wtf8());
211        result.push_str(", ");
212        result.push_str(&expr_repr);
213        result.push_str(", ");
214        if vm.is_none(&zelf.conversion) {
215            result.push_str("None");
216        } else {
217            result.push_wtf8(zelf.conversion.repr(vm)?.as_wtf8());
218        }
219        result.push_str(", ");
220        result.push_str(&spec_repr);
221        result.push_char(')');
222
223        Ok(result)
224    }
225}
226
/// Installs the `#[pyclass]` methods, getters and attributes of
/// `PyInterpolation` onto `context.types.interpolation_type`.
///
/// NOTE(review): presumably invoked once while the VM context's builtin
/// types are being initialized — confirm against the caller.
pub fn init(context: &'static Context) {
    PyInterpolation::extend_class(context, context.types.interpolation_type);
}