sws_lua/
interop.rs

1use std::rc::{Rc, Weak};
2use std::sync::Arc;
3use std::{fs, thread};
4
5use crossbeam_channel::Sender;
6use mlua::{FromLua, MetaMethod, UserData, UserDataMethods};
7use sws_crawler::{CountedTx, CrawlingContext, PageLocation, ScrapingContext, Sitemap};
8use sws_scraper::CaseSensitivity;
9use sws_scraper::{element_ref::Select, ElementRef, Html, Selector};
10use texting_robots::Robot;
11
12use crate::ns::{globals, sws};
13
14pub struct LuaHtml(pub(crate) Html);
15
16impl UserData for LuaHtml {
17    fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
18        methods.add_meta_method(MetaMethod::ToString, |_, html, ()| {
19            Ok(format!("{:?}", html.0))
20        });
21
22        methods.add_method(sws::html::SELECT, |_, html, css_selector: String| {
23            let select = html.0.select(Selector::parse(&css_selector).map_err(|e| {
24                mlua::Error::RuntimeError(format!(
25                    "Invalid CSS selector {:?}: {:?}",
26                    css_selector, e
27                ))
28            })?);
29            Ok(LuaSelect(select))
30        });
31
32        methods.add_method(sws::html::ROOT, |_, html, ()| {
33            Ok(LuaElementRef(html.0.root_element()))
34        });
35    }
36}
37
38#[derive(Clone)]
39pub struct LuaSelect(pub(crate) Select);
40
41impl UserData for LuaSelect {
42    fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
43        methods.add_meta_method(MetaMethod::ToString, |_, sel, ()| {
44            Ok(format!("{:?}", sel.0))
45        });
46
47        methods.add_method(sws::select::ITER, |lua, select, ()| {
48            let mut select = select.clone();
49            let iterator =
50                lua.create_function_mut(move |_, ()| Ok(select.0.next().map(LuaElementRef)));
51
52            Ok(iterator)
53        });
54
55        methods.add_method(sws::select::ENUMERATE, |lua, select, ()| {
56            let mut select = select.clone();
57            let mut i = 0;
58            let iterator = lua.create_function_mut(move |_, ()| {
59                i += 1;
60                let next = select.0.next().map(LuaElementRef);
61                if next.is_some() {
62                    Ok((Some(i), next))
63                } else {
64                    Ok((None, None))
65                }
66            });
67            Ok(iterator)
68        });
69    }
70}
71
72pub struct LuaElementRef(pub(crate) ElementRef);
73
74impl UserData for LuaElementRef {
75    fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
76        methods.add_meta_method(MetaMethod::ToString, |_, elem, ()| {
77            Ok(format!("{:?}", elem.0))
78        });
79
80        methods.add_method(sws::elem_ref::SELECT, |_, elem, css_selector: String| {
81            let select = elem.0.select(Selector::parse(&css_selector).map_err(|e| {
82                mlua::Error::RuntimeError(format!(
83                    "Invalid CSS selector {:?}: {:?}",
84                    css_selector, e
85                ))
86            })?);
87            Ok(LuaSelect(select))
88        });
89
90        methods.add_method(sws::elem_ref::INNER_HTML, |_, elem, ()| {
91            Ok(elem.0.inner_html())
92        });
93
94        methods.add_method(sws::elem_ref::INNER_TEXT, |_, elem, ()| {
95            Ok(elem.0.inner_text())
96        });
97
98        methods.add_method(sws::elem_ref::NAME, |_, elem, ()| {
99            Ok(elem.0.map_value(|el| el.name().to_string()))
100        });
101
102        methods.add_method(sws::elem_ref::ID, |_, elem, ()| {
103            Ok(elem
104                .0
105                .map_value(|el| el.id().map(String::from))
106                .unwrap_or(None))
107        });
108
109        methods.add_method(sws::elem_ref::HAS_CLASS, |_, elem, class: String| {
110            Ok(elem
111                .0
112                .map_value(|el| el.has_class(&class, CaseSensitivity::AsciiCaseInsensitive)))
113        });
114
115        methods.add_method(sws::elem_ref::CLASSES, |lua, elem, ()| {
116            let classes = lua.create_table()?;
117            elem.0.map_value(|el| {
118                el.classes().enumerate().for_each(|(i, c)| {
119                    classes.set(i + 1, c).ok();
120                });
121            });
122            Ok(classes)
123        });
124
125        methods.add_method(sws::elem_ref::ATTR, |_, elem, attr: String| {
126            Ok(elem
127                .0
128                .map_value(|el| el.attr(&attr).map(String::from))
129                .unwrap_or(None))
130        });
131
132        methods.add_method(sws::elem_ref::ATTRS, |lua, elem, ()| {
133            let attrs = lua.create_table()?;
134            elem.0.map_value(|el| {
135                el.attrs().for_each(|(k, v)| {
136                    attrs.set(k, v).ok();
137                });
138            });
139            Ok(attrs)
140        });
141    }
142}
143
144#[derive(Debug)]
145pub struct LuaPageLocation(pub(crate) Weak<PageLocation>);
146
147impl UserData for LuaPageLocation {
148    fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
149        methods.add_meta_method(MetaMethod::ToString, |_, pl, ()| Ok(format!("{:?}", pl.0)));
150
151        methods.add_method(sws::page_location::KIND, |lua, pl, ()| {
152            if let Some(loc) = pl.0.upgrade() {
153                let location = match loc.as_ref() {
154                    PageLocation::Path(_) => sws::location::PATH,
155                    PageLocation::Url(_) => sws::location::URL,
156                };
157                lua.globals()
158                    .get::<_, mlua::Table>(globals::SWS)?
159                    .get::<_, mlua::Table>(sws::LOCATION)?
160                    .get::<_, String>(location)
161                    .map(Some)
162            } else {
163                Ok(None)
164            }
165        });
166
167        methods.add_method(sws::page_location::GET, |_, pl, ()| {
168            if let Some(loc) = pl.0.upgrade() {
169                let loc = match loc.as_ref() {
170                    PageLocation::Path(p) => format!("{}", fs::canonicalize(p)?.display()),
171                    PageLocation::Url(url) => url.to_string(),
172                };
173                Ok(Some(loc))
174            } else {
175                Ok(None)
176            }
177        });
178    }
179}
180
181#[derive(Clone, Default)]
182pub struct LuaStringRecord(pub(crate) csv::StringRecord);
183
184impl<'lua> FromLua<'lua> for LuaStringRecord {
185    fn from_lua(value: mlua::Value<'lua>, _: &'lua mlua::Lua) -> mlua::Result<Self> {
186        match value {
187            mlua::Value::UserData(ud) => Ok(ud.borrow::<Self>()?.clone()),
188            _ => unreachable!(),
189        }
190    }
191}
192
193impl UserData for LuaStringRecord {
194    fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
195        methods.add_meta_method(MetaMethod::ToString, |_, r, ()| Ok(format!("{:?}", r.0)));
196
197        methods.add_method_mut(sws::record::PUSH_FIELD, |_, record, field: String| {
198            record.0.push_field(&field);
199            Ok(())
200        });
201    }
202}
203
204pub struct LuaDate(pub(crate) chrono::NaiveDate);
205
206impl LuaDate {
207    pub fn new(d: &str, fmt: &str) -> mlua::Result<Self> {
208        Ok(Self(chrono::NaiveDate::parse_from_str(d, fmt).map_err(
209            |e| mlua::Error::RuntimeError(format!("Couldn't parse date {d} got: {e}")),
210        )?))
211    }
212}
213
214impl UserData for LuaDate {
215    fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
216        methods.add_meta_method(MetaMethod::ToString, |_, d, ()| Ok(format!("{:?}", d.0)));
217
218        methods.add_method(sws::date::FORMAT, |_, d, fmt: String| {
219            Ok(d.0.format(&fmt).to_string())
220        });
221    }
222}
223
224#[derive(Clone, Debug)]
225pub struct LuaRobot(pub(crate) Arc<Robot>);
226
227impl UserData for LuaRobot {
228    fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
229        methods.add_meta_method(MetaMethod::ToString, |_, r, ()| Ok(format!("{:?}", r.0)));
230
231        methods.add_method(sws::robot::ALLOWED, |_, r, url: String| {
232            Ok(r.0.allowed(&url))
233        });
234    }
235}
236
237#[derive(Clone, Debug)]
238pub struct LuaCrawlingContext {
239    sm: &'static str,
240    robot: Option<LuaRobot>,
241}
242
243impl<'lua> FromLua<'lua> for LuaCrawlingContext {
244    fn from_lua(value: mlua::Value<'lua>, _: &'lua mlua::Lua) -> mlua::Result<Self> {
245        match value {
246            mlua::Value::UserData(ud) => Ok(ud.borrow::<Self>()?.clone()),
247            _ => unreachable!(),
248        }
249    }
250}
251
252impl UserData for LuaCrawlingContext {
253    fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
254        methods.add_meta_method(MetaMethod::ToString, |_, ctx, ()| Ok(format!("{:?}", ctx)));
255
256        methods.add_method(sws::crawling_context::ROBOT, |_, ctx, ()| {
257            Ok(ctx.robot.clone())
258        });
259
260        methods.add_method(sws::crawling_context::SITEMAP, |_, ctx, ()| Ok(ctx.sm));
261    }
262}
263
264impl From<CrawlingContext> for LuaCrawlingContext {
265    fn from(ctx: CrawlingContext) -> Self {
266        Self {
267            sm: match ctx.sitemap() {
268                Sitemap::Index => sws::sitemap::INDEX,
269                Sitemap::Urlset => sws::sitemap::URL_SET,
270            },
271            robot: ctx.robot().map(LuaRobot),
272        }
273    }
274}
275
276#[derive(Clone)]
277pub struct LuaScrapingContext {
278    tx_writer: Sender<csv::StringRecord>,
279    page_location: Weak<PageLocation>,
280    tx_url: Option<CountedTx>,
281    robot: Option<Arc<Robot>>,
282}
283
284impl LuaScrapingContext {
285    pub fn new(tx_writer: Sender<csv::StringRecord>, ctx: ScrapingContext) -> Self {
286        Self {
287            tx_writer,
288            page_location: Rc::downgrade(&ctx.location()),
289            tx_url: ctx.tx_url(),
290            robot: ctx.robot(),
291        }
292    }
293}
294
295impl UserData for LuaScrapingContext {
296    fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
297        methods.add_method(sws::scraping_context::PAGE_LOCATION, |_, ctx, ()| {
298            Ok(LuaPageLocation(ctx.page_location.clone()))
299        });
300
301        methods.add_method(
302            sws::scraping_context::SEND_RECORD,
303            |_, ctx, record: LuaStringRecord| {
304                ctx.tx_writer.send(record.0).ok();
305                Ok(())
306            },
307        );
308
309        methods.add_method(sws::scraping_context::WORKER_ID, |_, _, ()| {
310            let id = thread::current()
311                .name()
312                .map(String::from)
313                .ok_or_else(|| mlua::Error::RuntimeError("Missing thread name".into()))?;
314            Ok(id)
315        });
316
317        methods.add_method(sws::scraping_context::SEND_URL, |_, ctx, url: String| {
318            if let Some(tx_url) = &ctx.tx_url {
319                tx_url.send(url);
320            } else {
321                log::warn!("Context not initalized, coudln't send URL {url}")
322            }
323            Ok(())
324        });
325
326        methods.add_method(sws::scraping_context::ROBOT, |_, ctx, ()| {
327            Ok(ctx.robot.clone().map(LuaRobot))
328        });
329    }
330}