Skip to main content

rustpython_vm/builtins/
template.rs

1use super::{
2    PyStr, PyTupleRef, PyType, PyTypeRef, genericalias::PyGenericAlias,
3    interpolation::PyInterpolation,
4};
5use crate::{
6    AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
7    atomic_func,
8    class::PyClassImpl,
9    common::lock::LazyLock,
10    function::{FuncArgs, PyComparisonValue},
11    protocol::{PyIterReturn, PySequenceMethods},
12    types::{
13        AsSequence, Comparable, Constructor, IterNext, Iterable, PyComparisonOp, Representable,
14        SelfIter,
15    },
16};
17use rustpython_common::wtf8::{Wtf8Buf, wtf8_concat};
18
19/// Template object for t-strings (PEP 750).
20///
21/// Represents a template string with interpolated expressions.
22#[pyclass(module = "string.templatelib", name = "Template")]
23#[derive(Debug, Clone)]
24pub struct PyTemplate {
25    pub strings: PyTupleRef,
26    pub interpolations: PyTupleRef,
27}
28
29impl PyPayload for PyTemplate {
30    #[inline]
31    fn class(ctx: &Context) -> &'static Py<PyType> {
32        ctx.types.template_type
33    }
34}
35
36impl PyTemplate {
37    pub fn new(strings: PyTupleRef, interpolations: PyTupleRef) -> Self {
38        Self {
39            strings,
40            interpolations,
41        }
42    }
43}
44
45impl Constructor for PyTemplate {
46    type Args = FuncArgs;
47
48    fn py_new(_cls: &Py<PyType>, args: Self::Args, vm: &VirtualMachine) -> PyResult<Self> {
49        if !args.kwargs.is_empty() {
50            return Err(vm.new_type_error("Template.__new__ only accepts *args arguments"));
51        }
52
53        let mut strings: Vec<PyObjectRef> = Vec::new();
54        let mut interpolations: Vec<PyObjectRef> = Vec::new();
55        let mut last_was_str = false;
56
57        for item in args.args.iter() {
58            if let Ok(s) = item.clone().downcast::<PyStr>() {
59                if last_was_str {
60                    // Concatenate adjacent strings
61                    if let Some(last) = strings.last_mut() {
62                        let last_str = last.downcast_ref::<PyStr>().unwrap();
63                        let mut buf = last_str.as_wtf8().to_owned();
64                        buf.push_wtf8(s.as_wtf8());
65                        *last = vm.ctx.new_str(buf).into();
66                    }
67                } else {
68                    strings.push(s.into());
69                }
70                last_was_str = true;
71            } else if item.class().is(vm.ctx.types.interpolation_type) {
72                if !last_was_str {
73                    // Add empty string before interpolation
74                    strings.push(vm.ctx.empty_str.to_owned().into());
75                }
76                interpolations.push(item.clone());
77                last_was_str = false;
78            } else {
79                return Err(vm.new_type_error(format!(
80                    "Template.__new__ *args need to be of type 'str' or 'Interpolation', got {}",
81                    item.class().name()
82                )));
83            }
84        }
85
86        if !last_was_str {
87            // Add trailing empty string
88            strings.push(vm.ctx.empty_str.to_owned().into());
89        }
90
91        Ok(PyTemplate {
92            strings: vm.ctx.new_tuple(strings),
93            interpolations: vm.ctx.new_tuple(interpolations),
94        })
95    }
96}
97
98#[pyclass(with(Constructor, Comparable, Iterable, Representable, AsSequence))]
99impl PyTemplate {
100    #[pygetset]
101    fn strings(&self) -> PyTupleRef {
102        self.strings.clone()
103    }
104
105    #[pygetset]
106    fn interpolations(&self) -> PyTupleRef {
107        self.interpolations.clone()
108    }
109
110    #[pygetset]
111    fn values(&self, vm: &VirtualMachine) -> PyTupleRef {
112        let values: Vec<PyObjectRef> = self
113            .interpolations
114            .iter()
115            .map(|interp| {
116                interp
117                    .downcast_ref::<PyInterpolation>()
118                    .map(|i| i.value.clone())
119                    .unwrap_or_else(|| interp.clone())
120            })
121            .collect();
122        vm.ctx.new_tuple(values)
123    }
124
125    fn concat(&self, other: &PyObject, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
126        let other = other.downcast_ref::<PyTemplate>().ok_or_else(|| {
127            vm.new_type_error(format!(
128                "can only concatenate Template (not '{}') to Template",
129                other.class().name()
130            ))
131        })?;
132
133        // Concatenate the two templates
134        let mut new_strings: Vec<PyObjectRef> = Vec::new();
135        let mut new_interps: Vec<PyObjectRef> = Vec::new();
136
137        // Add all strings from self except the last one
138        let self_strings_len = self.strings.len();
139        for i in 0..self_strings_len.saturating_sub(1) {
140            new_strings.push(self.strings.get(i).unwrap().clone());
141        }
142
143        // Add all interpolations from self
144        for interp in self.interpolations.iter() {
145            new_interps.push(interp.clone());
146        }
147
148        // Concatenate last string of self with first string of other
149        let mut buf = Wtf8Buf::new();
150        if let Some(s) = self
151            .strings
152            .get(self_strings_len.saturating_sub(1))
153            .and_then(|s| s.downcast_ref::<PyStr>())
154        {
155            buf.push_wtf8(s.as_wtf8());
156        }
157        if let Some(s) = other
158            .strings
159            .first()
160            .and_then(|s| s.downcast_ref::<PyStr>())
161        {
162            buf.push_wtf8(s.as_wtf8());
163        }
164        new_strings.push(vm.ctx.new_str(buf).into());
165
166        // Add remaining strings from other (skip first)
167        for i in 1..other.strings.len() {
168            new_strings.push(other.strings.get(i).unwrap().clone());
169        }
170
171        // Add all interpolations from other
172        for interp in other.interpolations.iter() {
173            new_interps.push(interp.clone());
174        }
175
176        let template = PyTemplate {
177            strings: vm.ctx.new_tuple(new_strings),
178            interpolations: vm.ctx.new_tuple(new_interps),
179        };
180
181        Ok(template.into_ref(&vm.ctx))
182    }
183
184    fn __add__(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
185        self.concat(&other, vm)
186    }
187
188    #[pyclassmethod]
189    fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
190        PyGenericAlias::from_args(cls, args, vm)
191    }
192
193    #[pymethod]
194    fn __reduce__(&self, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
195        // Import string.templatelib._template_unpickle
196        // We need to import string first, then get templatelib from it,
197        // because import("string.templatelib", 0) with empty from_list returns the top-level module
198        let string_mod = vm.import("string.templatelib", 0)?;
199        let templatelib = string_mod.get_attr("templatelib", vm)?;
200        let unpickle_func = templatelib.get_attr("_template_unpickle", vm)?;
201
202        // Return (func, (strings, interpolations))
203        let args = vm.ctx.new_tuple(vec![
204            self.strings.clone().into(),
205            self.interpolations.clone().into(),
206        ]);
207        Ok(vm.ctx.new_tuple(vec![unpickle_func, args.into()]))
208    }
209}
210
211impl AsSequence for PyTemplate {
212    fn as_sequence() -> &'static PySequenceMethods {
213        static AS_SEQUENCE: LazyLock<PySequenceMethods> = LazyLock::new(|| PySequenceMethods {
214            concat: atomic_func!(|seq, other, vm| {
215                let zelf = PyTemplate::sequence_downcast(seq);
216                zelf.concat(other, vm).map(|t| t.into())
217            }),
218            ..PySequenceMethods::NOT_IMPLEMENTED
219        });
220        &AS_SEQUENCE
221    }
222}
223
224impl Comparable for PyTemplate {
225    fn cmp(
226        zelf: &Py<Self>,
227        other: &PyObject,
228        op: PyComparisonOp,
229        vm: &VirtualMachine,
230    ) -> PyResult<PyComparisonValue> {
231        op.eq_only(|| {
232            let other = class_or_notimplemented!(Self, other);
233
234            let eq = vm.bool_eq(zelf.strings.as_object(), other.strings.as_object())?
235                && vm.bool_eq(
236                    zelf.interpolations.as_object(),
237                    other.interpolations.as_object(),
238                )?;
239
240            Ok(eq.into())
241        })
242    }
243}
244
245impl Iterable for PyTemplate {
246    fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
247        Ok(PyTemplateIter::new(zelf).into_pyobject(vm))
248    }
249}
250
251impl Representable for PyTemplate {
252    #[inline]
253    fn repr_wtf8(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<Wtf8Buf> {
254        let strings_repr = zelf.strings.as_object().repr(vm)?;
255        let interp_repr = zelf.interpolations.as_object().repr(vm)?;
256        Ok(wtf8_concat!(
257            "Template(strings=",
258            strings_repr.as_wtf8(),
259            ", interpolations=",
260            interp_repr.as_wtf8(),
261            ')',
262        ))
263    }
264}
265
266/// Iterator for Template objects
267#[pyclass(module = "string.templatelib", name = "TemplateIter")]
268#[derive(Debug)]
269pub struct PyTemplateIter {
270    template: PyRef<PyTemplate>,
271    index: core::sync::atomic::AtomicUsize,
272    from_strings: core::sync::atomic::AtomicBool,
273}
274
275impl PyPayload for PyTemplateIter {
276    #[inline]
277    fn class(ctx: &Context) -> &'static Py<PyType> {
278        ctx.types.template_iter_type
279    }
280}
281
282impl PyTemplateIter {
283    fn new(template: PyRef<PyTemplate>) -> Self {
284        Self {
285            template,
286            index: core::sync::atomic::AtomicUsize::new(0),
287            from_strings: core::sync::atomic::AtomicBool::new(true),
288        }
289    }
290}
291
292#[pyclass(with(IterNext, Iterable))]
293impl PyTemplateIter {}
294
295impl SelfIter for PyTemplateIter {}
296
297impl IterNext for PyTemplateIter {
298    fn next(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<PyIterReturn> {
299        use core::sync::atomic::Ordering;
300
301        loop {
302            let from_strings = zelf.from_strings.load(Ordering::SeqCst);
303            let index = zelf.index.load(Ordering::SeqCst);
304
305            if from_strings {
306                if index < zelf.template.strings.len() {
307                    let item = zelf.template.strings.get(index).unwrap();
308                    zelf.from_strings.store(false, Ordering::SeqCst);
309
310                    // Skip empty strings
311                    if let Some(s) = item.downcast_ref::<PyStr>()
312                        && s.as_wtf8().is_empty()
313                    {
314                        continue;
315                    }
316                    return Ok(PyIterReturn::Return(item.clone()));
317                } else {
318                    return Ok(PyIterReturn::StopIteration(None));
319                }
320            } else if index < zelf.template.interpolations.len() {
321                let item = zelf.template.interpolations.get(index).unwrap();
322                zelf.index.fetch_add(1, Ordering::SeqCst);
323                zelf.from_strings.store(true, Ordering::SeqCst);
324                return Ok(PyIterReturn::Return(item.clone()));
325            } else {
326                return Ok(PyIterReturn::StopIteration(None));
327            }
328        }
329    }
330}
331
332pub fn init(context: &'static Context) {
333    PyTemplate::extend_class(context, context.types.template_type);
334    PyTemplateIter::extend_class(context, context.types.template_iter_type);
335}