Skip to main content

simploxide_bindgen/
commands.rs

//! Turns the COMMANDS.md file into an iterator of [`crate::commands::CommandResponse`].
2
3use convert_case::{Case, Casing as _};
4
5use crate::{
6    parse_utils,
7    types::{
8        DiscriminatedUnionType, Field, RecordType, TopLevelDocs,
9        discriminated_union_type::DiscriminatedUnionVariant,
10    },
11};
12
13pub fn parse(commands_md: &str) -> impl Iterator<Item = Result<CommandResponse, String>> {
14    let mut parser = Parser::default();
15
16    commands_md
17        .split("---")
18        .skip(1)
19        .filter_map(|s| {
20            let trimmed = s.trim();
21            (!trimmed.is_empty()).then_some(trimmed)
22        })
23        .map(move |blk| parser.parse_block(blk))
24}
25
26pub struct CommandResponse {
27    pub command: RecordType,
28    pub response: DiscriminatedUnionType,
29}
30
31/// Generates the provided trait method for the ClientApi trait.
32///
33/// The ClientApi trait definition itself must be generated by the client and should look like this
34///
35/// ```ignore
36/// pub trait ClientApi {
37///     type Error;
38///
39///     fn send_raw(&self, command: String) -> impl Future<Output = Result<Arc<{}>, Self::Error>> + Send;
40///
41///     //..
42/// }
43/// ```
44///
45/// Then the provided methods could be inserted.
46///
47/// For methods that return multiple kinds a valid responses the additional wrapper type should be
48/// generated. You can get this type definition using the
49/// [`CommandResponseTraitMethod::response_wrapper`] method.
50pub struct CommandResponseTraitMethod<'a> {
51    pub command: &'a RecordType,
52    pub response: &'a DiscriminatedUnionType,
53    pub shapes: &'a [RecordType],
54}
55
56impl<'a> CommandResponseTraitMethod<'a> {
57    pub fn new(
58        command: &'a RecordType,
59        response: &'a DiscriminatedUnionType,
60        shapes: &'a [RecordType],
61    ) -> Self {
62        Self {
63            command,
64            response,
65            shapes,
66        }
67    }
68}
69
70impl<'a> CommandResponseTraitMethod<'a> {
71    /// If some method has multiple valid responses a helper type representing valid response
72    /// variants must be generated.
73    ///
74    /// If only one possible valid response is possible it can get inlined without extra helper
75    /// types and for this case this method returns `None`.
76    pub fn response_wrapper(&self) -> Option<ResponseWrapperFmt> {
77        if self.can_inline_response().is_some() {
78            return None;
79        }
80
81        Some(ResponseWrapperFmt(DiscriminatedUnionType::new(
82            self.response_wrapper_name(),
83            self.valid_responses()
84                .cloned()
85                .zip(self.valid_response_shapes())
86                .map(|(mut resp, shape)| {
87                    if shape.fields.len() == 1 {
88                        resp.fields[0] = Field {
89                            api_name: String::new(),
90                            rust_name: String::new(),
91                            typ: shape.fields[0].typ.clone(),
92                        }
93                    }
94
95                    resp
96                })
97                .collect(),
98        )))
99    }
100
101    /// Instead of accepting a command type directly we can accept its arguments and construct it
102    /// internally.
103    ///
104    /// ```ignore
105    ///     fn api_show_my_address(cmd: ApiShowMyAddressCommand) -> ...
106    /// ```
107    ///
108    /// turns into
109    ///
110    /// ```ignore
111    ///     fn api_show_my_address(user_id: i64) -> ... {
112    ///         let cmd = ApiShowMyAddressCommand { user_id };
113    ///     }
114    /// ```
115    ///
116    /// This condition determines when such transformation takes place
117    fn can_inline_args(&self) -> bool {
118        !self
119            .command
120            .fields
121            .iter()
122            .any(|f| f.is_optional() || f.is_bool())
123    }
124
125    /// If response consists only of a single valid variant this variant's inner struct can be
126    /// used directly as a return value of the API method.
127    fn can_inline_response(&self) -> Option<&DiscriminatedUnionVariant> {
128        if self.valid_responses().count() == 1 {
129            self.valid_responses().next()
130        } else {
131            None
132        }
133    }
134
135    /// If underlying struct of the response contains only a single documented field this field can be directly
136    /// returned instead of returning the whole response struct.
137    fn can_inline_response_shape(&self) -> Option<&Field> {
138        if self.valid_response_shapes().count() != 1 {
139            return None;
140        }
141
142        let shape = self.valid_response_shapes().next().unwrap();
143
144        if shape.fields.len() == 1 {
145            Some(&shape.fields[0])
146        } else {
147            None
148        }
149    }
150
151    fn valid_responses(&self) -> impl Iterator<Item = &'_ DiscriminatedUnionVariant> {
152        self.response
153            .variants
154            .iter()
155            .filter(|x| x.rust_name != "ChatCmdError")
156    }
157
158    fn valid_response_shapes(&self) -> impl Iterator<Item = &'_ RecordType> {
159        self.shapes.iter().filter(|x| x.name != "ChatCmdError")
160    }
161
162    fn response_wrapper_name(&self) -> String {
163        format!("{}s", self.response.name)
164    }
165}
166
167impl<'a> std::fmt::Display for CommandResponseTraitMethod<'a> {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        self.command.write_docs_fmt(f)?;
170        write!(
171            f,
172            "    fn {}(&self",
173            self.command.name.remove_empty().to_case(Case::Snake)
174        )?;
175
176        let (ret_type, unwrapped_response_typename) =
177            if let Some(inlined_variant) = self.can_inline_response() {
178                let typename = if let Some(field) = self.can_inline_response_shape() {
179                    field.typ.clone()
180                } else {
181                    inlined_variant.fields[0].typ.clone()
182                };
183
184                (format!("Arc<{typename}>"), typename)
185            } else {
186                let typename = self.response_wrapper_name();
187                (typename.clone(), typename)
188            };
189
190        if self.can_inline_args() {
191            for field in self.command.fields.iter() {
192                write!(f, ", {}: {}", field.rust_name, field.typ)?;
193            }
194
195            writeln!(
196                f,
197                ") -> impl Future<Output = Result<{ret_type}, Self::Error>> + Send {{ async move {{",
198            )?;
199            write!(f, "        let command = {} {{", self.command.name)?;
200
201            for (ix, field) in self.command.fields.iter().enumerate() {
202                if ix > 0 {
203                    write!(f, ", ")?;
204                }
205
206                write!(f, "{}", field.rust_name)?;
207            }
208            writeln!(f, "}};")?;
209        } else {
210            writeln!(
211                f,
212                ", command: {}) -> impl Future<Output = Result<{ret_type}, Self::Error>> + Send {{ async move {{",
213                self.command.name,
214            )?;
215        }
216
217        writeln!(
218            f,
219            "        let json = self.send_raw(command.to_command_string()).await?;"
220        )?;
221        writeln!(
222            f,
223            "        // Safe to unwrap because unrecognized JSON goes to undocumented variant"
224        )?;
225        writeln!(
226            f,
227            "        let response = serde_json::from_value(json).unwrap();"
228        )?;
229        writeln!(f, "        match response {{")?;
230
231        if let Some(variant) = self.can_inline_response() {
232            if let Some(field) = self.can_inline_response_shape() {
233                writeln!(
234                    f,
235                    "            {}::{}(resp) => Ok(Arc::new(resp.{})),",
236                    self.response.name, variant.rust_name, field.rust_name,
237                )?;
238            } else {
239                writeln!(
240                    f,
241                    "            {}::{}(resp) => Ok(Arc::new(resp)),",
242                    self.response.name, variant.rust_name
243                )?;
244            }
245        } else {
246            for (variant, shape) in self.valid_responses().zip(self.valid_response_shapes()) {
247                if shape.fields.len() == 1 {
248                    writeln!(
249                        f,
250                        "            {resp_name}::{var_name}(resp) => Ok({typename}::{var_name}(Arc::new(resp.{field}))),",
251                        resp_name = self.response.name,
252                        typename = unwrapped_response_typename,
253                        var_name = variant.rust_name,
254                        field = shape.fields[0].rust_name,
255                    )?;
256                } else {
257                    writeln!(
258                        f,
259                        "            {}::{var_name}(resp) => Ok({}::{var_name}(Arc::new(resp))),",
260                        self.response.name,
261                        unwrapped_response_typename,
262                        var_name = variant.rust_name,
263                    )?;
264                }
265            }
266        }
267
268        writeln!(
269            f,
270            "            {}::ChatCmdError(resp) => Err(BadResponseError::ChatCmdError(Arc::new(resp.chat_error)).into()),",
271            self.response.name,
272        )?;
273        writeln!(
274            f,
275            "            {}::Undocumented(resp) => Err(BadResponseError::Undocumented(resp).into()),",
276            self.response.name,
277        )?;
278        writeln!(f, "        }}")?;
279
280        writeln!(f, "    }}")?;
281        writeln!(f, "    }}")
282    }
283}
284
285/// Use this formatter for command types instead of the standard std::fmt::Display impl of the
286/// [`RecordType`]. This impl strips down all serialization attributes and undocumented fields.
287pub struct CommandFmt<'a>(pub &'a RecordType);
288
289impl std::fmt::Display for CommandFmt<'_> {
290    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291        self.0.write_docs_fmt(f)?;
292
293        writeln!(f, "#[derive(Debug, Clone, PartialEq)]")?;
294        writeln!(f, "#[cfg_attr(feature = \"bon\", derive(::bon::Builder))]")?;
295
296        writeln!(f, "pub struct {} {{", self.0.name)?;
297
298        for field in self.0.fields.iter() {
299            writeln!(f, "    pub {}: {},", field.rust_name, field.typ)?;
300        }
301
302        writeln!(f, "}}")
303    }
304}
305
306pub struct ResponseWrapperFmt(pub DiscriminatedUnionType);
307
308impl std::fmt::Display for ResponseWrapperFmt {
309    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310        writeln!(
311            f,
312            "#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]"
313        )?;
314        writeln!(f, "#[serde(tag = \"type\")]")?;
315        writeln!(f, "pub enum {} {{", self.0.name)?;
316
317        for variant in &self.0.variants {
318            for comment_line in &variant.doc_comments {
319                writeln!(f, "    /// {}", comment_line)?;
320            }
321            writeln!(f, "    #[serde(rename = \"{}\")]", variant.api_name)?;
322            writeln!(
323                f,
324                "    {}(Arc<{}>),",
325                variant.rust_name, variant.fields[0].typ
326            )?;
327        }
328        writeln!(f, "}}\n")?;
329
330        // Gen helper getters
331        writeln!(f, "impl {} {{", self.0.name)?;
332
333        for var in self.0.variants.iter() {
334            assert_eq!(var.fields.len(), 1, "Discriminated union is not disjointed");
335            assert!(
336                var.fields[0].rust_name.is_empty(),
337                "Discriminated union is not disjointed"
338            );
339
340            writeln!(
341                f,
342                "    pub fn {}(&self) -> Option<&{}> {{",
343                var.rust_name.remove_empty().to_case(Case::Snake),
344                var.fields[0].typ
345            )?;
346
347            writeln!(f, "        if let Self::{}(ret) = self {{", var.rust_name)?;
348            writeln!(f, "            Some(ret)",)?;
349            writeln!(f, "        }} else {{ None }}",)?;
350            writeln!(f, "    }}\n")?;
351        }
352
353        writeln!(f, "}}")
354    }
355}
356
357#[derive(Default)]
358struct Parser {
359    current_doc_section: Option<DocSection>,
360}
361
362impl Parser {
363    pub fn parse_block(&mut self, block: &str) -> Result<CommandResponse, String> {
364        self.parser(block.lines().map(str::trim))
365            .map_err(|e| format!("{e} in block\n```\n{block}\n```"))
366    }
367
368    fn parser<'a>(
369        &mut self,
370        mut lines: impl Iterator<Item = &'a str>,
371    ) -> Result<CommandResponse, String> {
372        const DOC_SECTION_PAT: &str = parse_utils::H2;
373        const TYPENAME_PAT: &str = parse_utils::H3;
374        const TYPEKINDS_PAT: &str = parse_utils::BOLD;
375
376        let mut next =
377            parse_utils::skip_empty(&mut lines).ok_or_else(|| "Got an empty block".to_owned())?;
378
379        let mut command_docs: Vec<String> = Vec::new();
380
381        let (typename, mut typekind) = loop {
382            if let Some(section_name) = next.strip_prefix(DOC_SECTION_PAT) {
383                let mut doc_section = DocSection::new(section_name.to_owned());
384
385                next = parse_utils::parse_doc_lines(&mut lines, &mut doc_section.contents, |s| {
386                    s.starts_with(TYPENAME_PAT)
387                })
388                .ok_or_else(|| format!("Failed to find a typename by pattern {TYPENAME_PAT:?} after the doc section"))?;
389
390                self.current_doc_section.replace(doc_section);
391            } else if let Some(name) = next.strip_prefix(TYPENAME_PAT) {
392                next = parse_utils::parse_doc_lines(&mut lines, &mut command_docs, |s| {
393                    s.starts_with(TYPEKINDS_PAT)
394                })
395                .map(|s| s.strip_prefix(TYPEKINDS_PAT).unwrap())
396                .ok_or_else(|| format!("Failed to find a typekind by pattern {TYPEKINDS_PAT:?} after the inner docs "))?;
397
398                break (name, next);
399            }
400        };
401
402        let command_name = typename.to_case(Case::Pascal);
403        let mut command = RecordType::new(command_name.clone(), vec![]);
404
405        loop {
406            if typekind.starts_with("Parameters") {
407                typekind = parse_utils::parse_record_fields(
408                    &mut lines,
409                    &mut command.fields,
410                    |s| s.starts_with(TYPEKINDS_PAT),
411                )?
412                .map(|s| s.strip_prefix(TYPEKINDS_PAT).unwrap())
413                .ok_or_else(|| format!(
414                    "Failed to find a command syntax after parameters by pattern {TYPENAME_PAT:?}"
415                ))?;
416            } else if typekind.starts_with("Syntax") {
417                parse_utils::parse_syntax(&mut lines, &mut command.syntax)?;
418                break;
419            }
420        }
421
422        let mut response_variants: Vec<DiscriminatedUnionVariant> = Vec::with_capacity(4);
423
424        parse_utils::skip_while(&mut lines, |s| !s.starts_with("**Response")).ok_or_else(|| {
425            "Failed to find responses section by pattern \"**Response\"".to_owned()
426        })?;
427
428        let mut variant_docline = Vec::new();
429
430        while let Some(docline) = parse_utils::skip_empty(&mut lines) {
431            if docline.starts_with(TYPEKINDS_PAT) {
432                break;
433            } else {
434                variant_docline.push(docline.to_owned());
435            }
436
437            let (mut variant, next) = parse_utils::parse_discriminated_union_variant(&mut lines)?;
438            assert!(next.map(|s| s.is_empty()).unwrap_or(true));
439            variant.doc_comments = std::mem::take(&mut variant_docline);
440            response_variants.push(variant);
441        }
442
443        let response =
444            DiscriminatedUnionType::new(format!("{command_name}Response"), response_variants);
445
446        if let Some(ref outer_docs) = self.current_doc_section {
447            command
448                .doc_comments
449                .push(format!("### {}", outer_docs.header.clone()));
450
451            command.doc_comments.push(String::new());
452
453            command
454                .doc_comments
455                .extend(outer_docs.contents.iter().cloned());
456
457            command.doc_comments.push(String::new());
458            command.doc_comments.push("----".to_owned());
459            command.doc_comments.push(String::new());
460        }
461
462        command.doc_comments.extend(command_docs);
463        Ok(CommandResponse { command, response })
464    }
465}
466
/// A documentation section of `COMMANDS.md`, introduced by a `parse_utils::H2` header.
#[derive(Default, Clone)]
struct DocSection {
    /// The section title, i.e. the header line without the H2 marker.
    header: String,
    /// Doc lines collected between the section title and the next command name.
    contents: Vec<String>,
}

impl DocSection {
    /// Creates a section titled `header` that has no contents yet.
    fn new(header: String) -> Self {
        Self {
            header,
            ..Self::default()
        }
    }
}