Skip to main content

tlua/
rust_tables.rs

1use crate::{
2    ffi,
3    lua_tables::LuaTable,
4    tuples::TuplePushError::{self, First, Other},
5    AsLua, LuaRead, LuaState, Push, PushGuard, PushInto, PushOne, PushOneInto, ReadResult, Void,
6    WrongType,
7};
8
9use std::collections::{BTreeMap, HashMap, HashSet};
10use std::fmt::{self, Debug};
11use std::hash::Hash;
12use std::iter;
13use std::num::NonZeroI32;
14
/// Creates a new Lua table and fills it with the items of `iterator`.
///
/// Each item is pushed onto the Lua stack with `PushInto`. The number of
/// values that push produced decides how the item is stored:
/// - 0 values: the item is skipped. Its position still consumes an index
///   (`zip(1..)` advances regardless), so the table gets a gap at that key.
/// - 1 value: stored as `table[index]`, where `index` is the item's 1-based
///   position in `iterator`.
/// - 2 values: treated as a key/value pair and stored as `table[key] = value`.
/// - any other count: rejected with [`PushIterError::TooManyValues`].
///
/// On success, returns a guard for the single stack slot holding the table.
///
/// # Errors
///
/// - [`PushIterError::ValuePushError`] if pushing an item fails.
/// - [`PushIterError::TooManyValues`] if an item pushes an unsupported number
///   of values.
///
/// In both cases, the table is released before returning by dropping a
/// `PushGuard` over it. For `TooManyValues`, that guard also covers the values
/// just pushed. `lua` is handed back together with the error.
#[inline]
pub(crate) fn push_iter<L, I>(lua: L, iterator: I) -> Result<PushGuard<L>, (PushIterErrorOf<I>, L)>
where
    L: AsLua,
    I: Iterator,
    <I as Iterator>::Item: PushInto<LuaState>,
{
    // creating empty table
    // SAFETY: pushes exactly one value (the new table) onto the stack of `lua`.
    // That slot is either wrapped in the returned `PushGuard` or released on
    // the error paths below.
    unsafe { ffi::lua_newtable(lua.as_lua()) };

    for (elem, index) in iterator.zip(1..) {
        // `forget_internal` presumably disarms the guard, leaving the values on
        // the stack, and reports how many there are. The error path assumes a
        // failed push leaves nothing on the stack, since only the table (1 slot)
        // is released there.
        let size = match elem.push_into_lua(lua.as_lua()) {
            Ok(pushed) => pushed.forget_internal(),
            Err((err, _)) => unsafe {
                // TODO(gmoshkin): return an error capturing this push guard
                // drop the lua table
                drop(PushGuard::new(lua.as_lua(), 1));
                return Err((PushIterError::ValuePushError(err), lua));
            },
        };

        match size {
            // Nothing was pushed: skip the item. Its index is not reused.
            0 => continue,
            // Stack: [.., table, value]
            1 => {
                // Stack: [.., table, value, index]
                lua.as_lua().push_one(index).forget_internal();
                // Move the index below the value: [.., table, index, value]
                unsafe { ffi::lua_insert(lua.as_lua(), -2) }
                // table[index] = value; pops index and value: [.., table]
                unsafe { ffi::lua_settable(lua.as_lua(), -3) }
            }
            // Stack: [.., table, key, value] -> table[key] = value; pops both.
            2 => unsafe { ffi::lua_settable(lua.as_lua(), -3) },
            n => unsafe {
                // TODO(gmoshkin): return an error capturing this push guard
                // n + 1 == n values from the recent push + lua table
                drop(PushGuard::new(lua.as_lua(), n + 1));
                return Err((PushIterError::TooManyValues(n), lua));
            },
        }
    }

    // SAFETY: every item's values have been consumed by `lua_settable`, so only
    // the table pushed at the start remains; hand that one slot to the caller.
    unsafe { Ok(PushGuard::new(lua, 1)) }
}

/// Error type of `push_iter` for an iterator `I`: a [`PushIterError`] wrapping
/// the `PushInto<LuaState>` error of `I`'s item type.
pub type PushIterErrorOf<I> = PushIterError<<<I as Iterator>::Item as PushInto<LuaState>>::Err>;
57
/// Failure while turning the items of an iterator into the entries of a new
/// Lua table.
#[derive(Debug, PartialEq, Eq)]
pub enum PushIterError<E> {
    /// An item pushed this many values onto the Lua stack. Only 0 (item
    /// skipped), 1 (array element) or 2 (key and value) are accepted.
    TooManyValues(i32),
    /// Pushing an individual item failed with the wrapped error.
    ValuePushError(E),
}

impl<E> PushIterError<E> {
    /// Transforms the wrapped item error with `f`.
    ///
    /// `TooManyValues` carries no item error and is passed through unchanged.
    pub fn map<F, R>(self, f: F) -> PushIterError<R>
    where
        F: FnOnce(E) -> R,
    {
        match self {
            PushIterError::TooManyValues(count) => PushIterError::TooManyValues(count),
            PushIterError::ValuePushError(inner) => PushIterError::ValuePushError(f(inner)),
        }
    }
}
75
76impl<E> fmt::Display for PushIterError<E>
77where
78    E: fmt::Display,
79{
80    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
81        match self {
82            Self::TooManyValues(n) => {
83                write!(
84                    fmt,
85                    "Can only push 1 or 2 values as lua table item, got {n} instead",
86                )
87            }
88            Self::ValuePushError(e) => {
89                write!(fmt, "Pushing iterable item failed: {e}")
90            }
91        }
92    }
93}
94
95// NOTE: only the following From<_> for Void implementations are correct,
96//       don't add other ones!
97
98// T::Err: Void => no error possible
99// NOTE: making this one generic would conflict with the below implementations.
100impl From<PushIterError<Void>> for Void {
101    fn from(_: PushIterError<Void>) -> Self {
102        unreachable!("no way to create instance of Void")
103    }
104}
105
106// T::Err: Void; (T,) => no error possible
107impl<T> From<PushIterError<TuplePushError<T, Void>>> for Void
108where
109    T: Into<Void>,
110{
111    fn from(_: PushIterError<TuplePushError<T, Void>>) -> Self {
112        unreachable!("no way to create instance of Void")
113    }
114}
115
116// K::Err: Void; V::Err: Void; (K, V) => no error possible
117impl<K, V> From<PushIterError<TuplePushError<K, TuplePushError<V, Void>>>> for Void
118where
119    K: Into<Void>,
120    V: Into<Void>,
121{
122    fn from(_: PushIterError<TuplePushError<K, TuplePushError<V, Void>>>) -> Self {
123        unreachable!("no way to create instance of Void")
124    }
125}
126
127////////////////////////////////////////////////////////////////////////////////
128// TableFromIter
129////////////////////////////////////////////////////////////////////////////////
130
/// A wrapper struct for converting arbitrary iterators into lua tables.
///
/// Use this instead of converting the iterator into a `Vec` to avoid
/// unnecessary allocations.
///
/// # Example
/// ```no_run
/// use std::io::BufRead;
/// let lua = tlua::Lua::new();
/// lua.set(
///     "foo",
///     // `map_while(Result::ok)` stops at the first read error, whereas
///     // `flatten()` would keep polling a reader that fails persistently
///     // and might never terminate.
///     tlua::TableFromIter(std::io::stdin().lock().lines().map_while(Result::ok)),
/// );
/// // Global variable 'foo' now contains an array of lines read from stdin
/// ```
pub struct TableFromIter<I>(pub I);
145
146impl<L, I> PushInto<L> for TableFromIter<I>
147where
148    L: AsLua,
149    I: Iterator,
150    <I as Iterator>::Item: PushInto<LuaState>,
151{
152    type Err = PushIterError<<I::Item as PushInto<LuaState>>::Err>;
153
154    fn push_into_lua(self, lua: L) -> crate::PushIntoResult<L, Self> {
155        push_iter(lua, self.0)
156    }
157}
158
159impl<L, I> PushOneInto<L> for TableFromIter<I>
160where
161    L: AsLua,
162    I: Iterator,
163    <I as Iterator>::Item: PushInto<LuaState>,
164{
165}
166
167////////////////////////////////////////////////////////////////////////////////
168// Vec
169////////////////////////////////////////////////////////////////////////////////
170
171impl<L, T> Push<L> for Vec<T>
172where
173    L: AsLua,
174    T: Push<LuaState>,
175{
176    type Err = PushIterError<T::Err>;
177
178    #[inline]
179    fn push_to_lua(&self, lua: L) -> Result<PushGuard<L>, (Self::Err, L)> {
180        push_iter(lua, self.iter())
181    }
182}
183
184impl<L, T> PushOne<L> for Vec<T>
185where
186    L: AsLua,
187    T: Push<LuaState>,
188{
189}
190
191impl<L, T> PushInto<L> for Vec<T>
192where
193    L: AsLua,
194    T: PushInto<LuaState>,
195{
196    type Err = PushIterError<T::Err>;
197
198    #[inline]
199    fn push_into_lua(self, lua: L) -> Result<PushGuard<L>, (Self::Err, L)> {
200        push_iter(lua, self.into_iter())
201    }
202}
203
204impl<L, T> PushOneInto<L> for Vec<T>
205where
206    L: AsLua,
207    T: PushInto<LuaState>,
208{
209}
210
211impl<L, T> LuaRead<L> for Vec<T>
212where
213    L: AsLua,
214    T: for<'a> LuaRead<PushGuard<&'a LuaTable<L>>>,
215    T: 'static,
216{
217    fn lua_read_at_position(lua: L, index: NonZeroI32) -> ReadResult<Self, L> {
218        // We need this as iteration order isn't guaranteed to match order of
219        // keys, even if they're numeric
220        // https://www.lua.org/manual/5.2/manual.html#pdf-next
221        let table = LuaTable::lua_read_at_position(lua, index)?;
222        let mut dict: BTreeMap<i32, T> = BTreeMap::new();
223
224        let mut max_key = i32::MIN;
225        let mut min_key = i32::MAX;
226
227        {
228            let mut iter = table.iter::<i32, T>();
229            while let Some(maybe_kv) = iter.next() {
230                let (key, value) = crate::unwrap_ok_or! { maybe_kv,
231                    Err(e) => {
232                        drop(iter);
233                        let lua = table.into_inner();
234                        let e = e.when("converting Lua table to Vec<_>")
235                            .expected_type::<Self>();
236                        return Err((lua, e))
237                    }
238                };
239                max_key = max_key.max(key);
240                min_key = min_key.min(key);
241                dict.insert(key, value);
242            }
243        }
244
245        if dict.is_empty() {
246            return Ok(vec![]);
247        }
248
249        if min_key != 1 {
250            // Rust doesn't support sparse arrays or arrays with negative
251            // indices
252            let e = WrongType::info("converting Lua table to Vec<_>")
253                .expected("indexes in range 1..N")
254                .actual(format!("value with index {min_key}"));
255            return Err((table.into_inner(), e));
256        }
257
258        let mut result = Vec::with_capacity(max_key as _);
259
260        // We expect to start with first element of table and have this
261        // be smaller that first key by one
262        let mut previous_key = 0;
263
264        // By this point, we actually iterate the map to move values to Vec
265        // and check that table represented non-sparse 1-indexed array
266        for (k, v) in dict {
267            if previous_key + 1 != k {
268                let e = WrongType::info("converting Lua table to Vec<_>")
269                    .expected("indexes in range 1..N")
270                    .actual(format!("Lua table with missing index {}", previous_key + 1));
271                return Err((table.into_inner(), e));
272            } else {
273                // We just push, thus converting Lua 1-based indexing
274                // to Rust 0-based indexing
275                result.push(v);
276                previous_key = k;
277            }
278        }
279
280        Ok(result)
281    }
282}
283
284////////////////////////////////////////////////////////////////////////////////
// [T]
286////////////////////////////////////////////////////////////////////////////////
287
288impl<L, T> Push<L> for [T]
289where
290    L: AsLua,
291    T: Push<LuaState>,
292{
293    type Err = PushIterError<T::Err>;
294
295    #[inline]
296    fn push_to_lua(&self, lua: L) -> Result<PushGuard<L>, (Self::Err, L)> {
297        push_iter(lua, self.iter())
298    }
299}
300
301impl<L, T> PushOne<L> for [T]
302where
303    L: AsLua,
304    T: Push<LuaState>,
305{
306}
307
308////////////////////////////////////////////////////////////////////////////////
309// [T; N]
310////////////////////////////////////////////////////////////////////////////////
311
312impl<L, T, const N: usize> Push<L> for [T; N]
313where
314    L: AsLua,
315    T: Push<LuaState>,
316{
317    type Err = PushIterError<T::Err>;
318
319    #[inline]
320    fn push_to_lua(&self, lua: L) -> Result<PushGuard<L>, (Self::Err, L)> {
321        push_iter(lua, self.iter())
322    }
323}
324
325impl<L, T, const N: usize> PushOne<L> for [T; N]
326where
327    L: AsLua,
328    T: Push<LuaState>,
329{
330}
331
332impl<L, T, const N: usize> PushInto<L> for [T; N]
333where
334    L: AsLua,
335    T: PushInto<LuaState>,
336{
337    type Err = PushIterError<T::Err>;
338
339    #[inline]
340    fn push_into_lua(self, lua: L) -> Result<PushGuard<L>, (Self::Err, L)> {
341        push_iter(lua, IntoIterator::into_iter(self))
342    }
343}
344
345impl<L, T, const N: usize> PushOneInto<L> for [T; N]
346where
347    L: AsLua,
348    T: PushInto<LuaState>,
349{
350}
351
impl<L, T, const N: usize> LuaRead<L> for [T; N]
where
    L: AsLua,
    T: for<'a> LuaRead<PushGuard<&'a LuaTable<L>>>,
    T: 'static,
{
    /// Reads a Lua table whose keys are exactly `1..=N` into a `[T; N]`.
    ///
    /// Element `i` of the result is the value stored at Lua key `i + 1`.
    ///
    /// # Errors
    ///
    /// Returns the Lua state together with a `WrongType` if:
    /// - reading a key (as `i32`) or a value (as `T`) fails;
    /// - a key lies outside `1..=N`;
    /// - some index in `1..=N` is missing.
    ///
    /// On failure, every element already read is dropped.
    fn lua_read_at_position(lua: L, index: NonZeroI32) -> ReadResult<Self, L> {
        let table = LuaTable::lua_read_at_position(lua, index)?;
        // Lua traversal order is unspecified, so slots are filled one by one as
        // keys arrive. The array storage therefore starts out uninitialized.
        let mut res = std::mem::MaybeUninit::uninit();
        // Pointer to the first element of the array stored in `res`.
        let ptr = &mut res as *mut _ as *mut [T; N] as *mut T;
        // Which slots of `res` hold an initialized `T`, so that exactly those
        // can be dropped if reading fails part-way.
        let mut was_assigned = [false; N];
        let mut err = None;

        for maybe_kv in table.iter::<i32, T>() {
            match maybe_kv {
                // `1 <= key` is checked first, so `key as usize` never sees a
                // negative value.
                Ok((key, value)) if 1 <= key && key as usize <= N => {
                    // Lua indices are 1-based.
                    let i = (key - 1) as usize;
                    // SAFETY: `i < N` by the guard above, so the write stays
                    // inside the array stored in `res`.
                    // NOTE(review): if two Lua keys ever decoded to the same
                    // `i32` (depends on how `i32: LuaRead` treats non-integral
                    // numbers; confirm), the earlier value would be overwritten
                    // without being dropped. That is a leak, not UB.
                    unsafe { std::ptr::write(ptr.add(i), value) }
                    was_assigned[i] = true;
                }
                Err(e) => {
                    err = Some(Error::Subtype(e));
                    break;
                }
                // Key outside `1..=N`; the value read with it is discarded.
                Ok((index, _)) => {
                    err = Some(Error::WrongIndex(index));
                    break;
                }
            }
        }

        // With no read error so far, report the first (1-based) index that
        // never received a value, if any.
        if err.is_none() {
            err = was_assigned
                .iter()
                .zip(1..)
                .find(|(&was_assigned, _)| !was_assigned)
                .map(|(_, i)| Error::MissingIndex(i));
        }

        // SAFETY: `err` is `None` only if every one of the `N` slots was
        // written above, so the whole array is initialized.
        let err = crate::unwrap_or! { err,
            return Ok(unsafe { res.assume_init() });
        };

        // Drop exactly the elements that were initialized before the failure.
        for i in IntoIterator::into_iter(was_assigned)
            .enumerate()
            .flat_map(|(i, was_assigned)| was_assigned.then_some(i))
        {
            // SAFETY: slot `i` was written by `ptr::write` above and has not
            // been moved out or dropped since.
            unsafe { std::ptr::drop_in_place(ptr.add(i)) }
        }

        let when = "converting Lua table to array";
        let e = match err {
            Error::Subtype(err) => err.when(when).expected_type::<Self>(),
            Error::WrongIndex(index) => WrongType::info(when)
                .expected(format!("indexes in range 1..={N}"))
                .actual(format!("value with index {index}")),
            Error::MissingIndex(index) => WrongType::info(when)
                .expected(format!("indexes in range 1..={N}"))
                .actual(format!("Lua table with missing index {index}")),
        };
        return Err((table.into_inner(), e));

        /// Why reading the table into an array failed.
        enum Error {
            /// Reading a key or value failed.
            Subtype(WrongType),
            /// A key outside `1..=N` was found.
            WrongIndex(i32),
            /// This 1-based index in `1..=N` had no value.
            MissingIndex(i32),
        }
    }
}
421
422////////////////////////////////////////////////////////////////////////////////
423// HashMap
424////////////////////////////////////////////////////////////////////////////////
425
426impl<L, K, V, S> LuaRead<L> for HashMap<K, V, S>
427where
428    L: AsLua,
429    K: 'static + Hash + Eq,
430    K: for<'k> LuaRead<&'k LuaTable<L>>,
431    V: 'static,
432    V: for<'v> LuaRead<PushGuard<&'v LuaTable<L>>>,
433    S: Default,
434    S: std::hash::BuildHasher,
435{
436    fn lua_read_at_position(lua: L, index: NonZeroI32) -> ReadResult<Self, L> {
437        let table = LuaTable::lua_read_at_position(lua, index)?;
438        let res: Result<_, _> = table.iter().collect();
439        res.map_err(|err| {
440            let l = table.into_inner();
441            let e = err
442                .when("converting Lua table to HashMap<_, _>")
443                .expected_type::<Self>();
444            (l, e)
445        })
446    }
447}
448
449macro_rules! push_hashmap_impl {
450    ($self:expr, $lua:expr) => {
451        push_iter($lua, $self.into_iter()).map_err(|(e, lua)| match e {
452            PushIterError::TooManyValues(_) => unreachable!("K and V implement PushOne"),
453            PushIterError::ValuePushError(First(e)) => (First(e), lua),
454            PushIterError::ValuePushError(Other(e)) => (Other(e.first()), lua),
455        })
456    };
457}
458
459impl<L, K, V, S> Push<L> for HashMap<K, V, S>
460where
461    L: AsLua,
462    K: PushOne<LuaState> + Eq + Hash + Debug,
463    V: PushOne<LuaState> + Debug,
464{
465    type Err = TuplePushError<K::Err, V::Err>;
466
467    #[inline]
468    fn push_to_lua(&self, lua: L) -> Result<PushGuard<L>, (Self::Err, L)> {
469        push_hashmap_impl!(self, lua)
470    }
471}
472
473impl<L, K, V, S> PushOne<L> for HashMap<K, V, S>
474where
475    L: AsLua,
476    K: PushOne<LuaState> + Eq + Hash + Debug,
477    V: PushOne<LuaState> + Debug,
478{
479}
480
481impl<L, K, V, S> PushInto<L> for HashMap<K, V, S>
482where
483    L: AsLua,
484    K: PushOneInto<LuaState> + Eq + Hash + Debug,
485    V: PushOneInto<LuaState> + Debug,
486{
487    type Err = TuplePushError<K::Err, V::Err>;
488
489    #[inline]
490    fn push_into_lua(self, lua: L) -> Result<PushGuard<L>, (Self::Err, L)> {
491        push_hashmap_impl!(self, lua)
492    }
493}
494
495impl<L, K, V, S> PushOneInto<L> for HashMap<K, V, S>
496where
497    L: AsLua,
498    K: PushOneInto<LuaState> + Eq + Hash + Debug,
499    V: PushOneInto<LuaState> + Debug,
500{
501}
502
503////////////////////////////////////////////////////////////////////////////////
504// HashSet
505////////////////////////////////////////////////////////////////////////////////
506
507macro_rules! push_hashset_impl {
508    ($self:expr, $lua:expr) => {
509        push_iter($lua, $self.into_iter().zip(iter::repeat(true))).map_err(|(e, lua)| match e {
510            PushIterError::TooManyValues(_) => unreachable!("K implements PushOne"),
511            PushIterError::ValuePushError(First(e)) => (e, lua),
512            PushIterError::ValuePushError(Other(_)) => {
513                unreachable!("no way to create instance of Void")
514            }
515        })
516    };
517}
518
519impl<L, K, S> Push<L> for HashSet<K, S>
520where
521    L: AsLua,
522    K: PushOne<LuaState> + Eq + Hash + Debug,
523{
524    type Err = K::Err;
525
526    #[inline]
527    fn push_to_lua(&self, lua: L) -> Result<PushGuard<L>, (K::Err, L)> {
528        push_hashset_impl!(self, lua)
529    }
530}
531
532impl<L, K, S> PushOne<L> for HashSet<K, S>
533where
534    L: AsLua,
535    K: PushOne<LuaState> + Eq + Hash + Debug,
536{
537}
538
539impl<L, K, S> PushInto<L> for HashSet<K, S>
540where
541    L: AsLua,
542    K: PushOneInto<LuaState> + Eq + Hash + Debug,
543{
544    type Err = K::Err;
545
546    #[inline]
547    fn push_into_lua(self, lua: L) -> Result<PushGuard<L>, (K::Err, L)> {
548        push_hashset_impl!(self, lua)
549    }
550}
551
552impl<L, K, S> PushOneInto<L> for HashSet<K, S>
553where
554    L: AsLua,
555    K: PushOneInto<LuaState> + Eq + Hash + Debug,
556{
557}