1use convert_case::{Case, Casing as _};
4
5use crate::{
6 parse_utils,
7 types::{
8 DiscriminatedUnionType, Field, RecordType, TopLevelDocs,
9 discriminated_union_type::DiscriminatedUnionVariant,
10 },
11};
12
13pub fn parse(commands_md: &str) -> impl Iterator<Item = Result<CommandResponse, String>> {
14 let mut parser = Parser::default();
15
16 commands_md
17 .split("---")
18 .skip(1)
19 .filter_map(|s| {
20 let trimmed = s.trim();
21 (!trimmed.is_empty()).then_some(trimmed)
22 })
23 .map(move |blk| parser.parse_block(blk))
24}
25
/// One parsed command block: the command's request record plus the union of its
/// documented responses.
pub struct CommandResponse {
    /// The command record, named in PascalCase from the block's typename heading.
    pub command: RecordType,
    /// The response union, named `<CommandName>Response`, with one variant per documented
    /// response.
    pub response: DiscriminatedUnionType,
}
30
/// Generator for one client trait method. It is rendered as Rust source through its
/// `Display` implementation.
pub struct CommandResponseTraitMethod<'a> {
    /// The command record the generated method sends.
    pub command: &'a RecordType,
    /// The response union the generated method deserializes the reply into.
    pub response: &'a DiscriminatedUnionType,
    /// Record shapes of the response variants. After `ChatCmdError` is filtered out on both
    /// sides, they are paired positionally with the variants of `response`.
    pub shapes: &'a [RecordType],
}
55
56impl<'a> CommandResponseTraitMethod<'a> {
57 pub fn new(
58 command: &'a RecordType,
59 response: &'a DiscriminatedUnionType,
60 shapes: &'a [RecordType],
61 ) -> Self {
62 Self {
63 command,
64 response,
65 shapes,
66 }
67 }
68}
69
70impl<'a> CommandResponseTraitMethod<'a> {
71 pub fn response_wrapper(&self) -> Option<ResponseWrapperFmt> {
77 if self.can_inline_response().is_some() {
78 return None;
79 }
80
81 Some(ResponseWrapperFmt(DiscriminatedUnionType::new(
82 self.response_wrapper_name(),
83 self.valid_responses()
84 .cloned()
85 .zip(self.valid_response_shapes())
86 .map(|(mut resp, shape)| {
87 if shape.fields.len() == 1 {
88 resp.fields[0] = Field {
89 api_name: String::new(),
90 rust_name: String::new(),
91 typ: shape.fields[0].typ.clone(),
92 }
93 }
94
95 resp
96 })
97 .collect(),
98 )))
99 }
100
101 fn can_inline_args(&self) -> bool {
118 !self
119 .command
120 .fields
121 .iter()
122 .any(|f| f.is_optional() || f.is_bool())
123 }
124
125 fn can_inline_response(&self) -> Option<&DiscriminatedUnionVariant> {
128 if self.valid_responses().count() == 1 {
129 self.valid_responses().next()
130 } else {
131 None
132 }
133 }
134
135 fn can_inline_response_shape(&self) -> Option<&Field> {
138 if self.valid_response_shapes().count() != 1 {
139 return None;
140 }
141
142 let shape = self.valid_response_shapes().next().unwrap();
143
144 if shape.fields.len() == 1 {
145 Some(&shape.fields[0])
146 } else {
147 None
148 }
149 }
150
151 fn valid_responses(&self) -> impl Iterator<Item = &'_ DiscriminatedUnionVariant> {
152 self.response
153 .variants
154 .iter()
155 .filter(|x| x.rust_name != "ChatCmdError")
156 }
157
158 fn valid_response_shapes(&self) -> impl Iterator<Item = &'_ RecordType> {
159 self.shapes.iter().filter(|x| x.name != "ChatCmdError")
160 }
161
162 fn response_wrapper_name(&self) -> String {
163 format!("{}s", self.response.name)
164 }
165}
166
impl<'a> std::fmt::Display for CommandResponseTraitMethod<'a> {
    /// Writes the generated trait method in three parts:
    /// - the command docs;
    /// - a `fn <snake_case command name>(&self, ...)` signature returning
    ///   `impl Future<Output = Result<RET, Self::Error>> + Send`;
    /// - an `async move` body that sends `command.to_command_string()` through `send_raw`,
    ///   deserializes the JSON reply into the response enum, and maps each variant to
    ///   `Ok(..)` or `Err(..)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.command.write_docs_fmt(f)?;
        write!(
            f,
            " fn {}(&self",
            self.command.name.remove_empty().to_case(Case::Snake)
        )?;

        // RET depends on how many valid responses there are.
        // - Exactly one: RET is `Arc<T>`. `T` is the type of the response shape's only field
        //   when that shape has exactly one field; otherwise it is the type of the variant's
        //   first field.
        // - Several: RET is the wrapper enum named by `response_wrapper_name()`.
        let (ret_type, unwrapped_response_typename) =
            if let Some(inlined_variant) = self.can_inline_response() {
                let typename = if let Some(field) = self.can_inline_response_shape() {
                    field.typ.clone()
                } else {
                    inlined_variant.fields[0].typ.clone()
                };

                (format!("Arc<{typename}>"), typename)
            } else {
                let typename = self.response_wrapper_name();
                (typename.clone(), typename)
            };

        if self.can_inline_args() {
            // No optional or bool fields: each field becomes a positional parameter, and the
            // generated body rebuilds the command struct with field-init shorthand.
            for field in self.command.fields.iter() {
                write!(f, ", {}: {}", field.rust_name, field.typ)?;
            }

            writeln!(
                f,
                ") -> impl Future<Output = Result<{ret_type}, Self::Error>> + Send {{ async move {{",
            )?;
            write!(f, " let command = {} {{", self.command.name)?;

            for (ix, field) in self.command.fields.iter().enumerate() {
                if ix > 0 {
                    write!(f, ", ")?;
                }

                write!(f, "{}", field.rust_name)?;
            }
            writeln!(f, "}};")?;
        } else {
            // Otherwise the caller passes a prebuilt command struct as `command`.
            writeln!(
                f,
                ", command: {}) -> impl Future<Output = Result<{ret_type}, Self::Error>> + Send {{ async move {{",
                self.command.name,
            )?;
        }

        writeln!(
            f,
            " let json = self.send_raw(command.to_command_string()).await?;"
        )?;
        writeln!(
            f,
            " // Safe to unwrap because unrecognized JSON goes to undocumented variant"
        )?;
        writeln!(
            f,
            " let response = serde_json::from_value(json).unwrap();"
        )?;
        writeln!(f, " match response {{")?;

        // Success arms. With a single valid response, its payload (or the payload's only
        // field) is returned in an `Arc`. Otherwise each valid variant is re-wrapped into the
        // matching variant of the wrapper enum, unwrapping single-field shapes the same way
        // `response_wrapper` does.
        if let Some(variant) = self.can_inline_response() {
            if let Some(field) = self.can_inline_response_shape() {
                writeln!(
                    f,
                    " {}::{}(resp) => Ok(Arc::new(resp.{})),",
                    self.response.name, variant.rust_name, field.rust_name,
                )?;
            } else {
                writeln!(
                    f,
                    " {}::{}(resp) => Ok(Arc::new(resp)),",
                    self.response.name, variant.rust_name
                )?;
            }
        } else {
            for (variant, shape) in self.valid_responses().zip(self.valid_response_shapes()) {
                if shape.fields.len() == 1 {
                    writeln!(
                        f,
                        " {resp_name}::{var_name}(resp) => Ok({typename}::{var_name}(Arc::new(resp.{field}))),",
                        resp_name = self.response.name,
                        typename = unwrapped_response_typename,
                        var_name = variant.rust_name,
                        field = shape.fields[0].rust_name,
                    )?;
                } else {
                    writeln!(
                        f,
                        " {}::{var_name}(resp) => Ok({}::{var_name}(Arc::new(resp))),",
                        self.response.name,
                        unwrapped_response_typename,
                        var_name = variant.rust_name,
                    )?;
                }
            }
        }

        // Error arms: `ChatCmdError` and the `Undocumented` fallback both become `Err`.
        writeln!(
            f,
            " {}::ChatCmdError(resp) => Err(BadResponseError::ChatCmdError(Arc::new(resp.chat_error)).into()),",
            self.response.name,
        )?;
        writeln!(
            f,
            " {}::Undocumented(resp) => Err(BadResponseError::Undocumented(resp).into()),",
            self.response.name,
        )?;
        // Closes the `match`.
        writeln!(f, " }}")?;

        // Closes the `async move` block and then the method body opened in the signature.
        writeln!(f, " }}")?;
        writeln!(f, " }}")
    }
}
284
/// Renders a command record as a public Rust struct definition through `Display`.
pub struct CommandFmt<'a>(pub &'a RecordType);
288
289impl std::fmt::Display for CommandFmt<'_> {
290 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291 self.0.write_docs_fmt(f)?;
292
293 writeln!(f, "#[derive(Debug, Clone, PartialEq)]")?;
294 writeln!(f, "#[cfg_attr(feature = \"bon\", derive(::bon::Builder))]")?;
295
296 writeln!(f, "pub struct {} {{", self.0.name)?;
297
298 for field in self.0.fields.iter() {
299 writeln!(f, " pub {}: {},", field.rust_name, field.typ)?;
300 }
301
302 writeln!(f, "}}")
303 }
304}
305
/// Renders a wrapper enum, plus one accessor method per variant, for commands that have
/// several valid responses.
pub struct ResponseWrapperFmt(pub DiscriminatedUnionType);
307
308impl std::fmt::Display for ResponseWrapperFmt {
309 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310 writeln!(
311 f,
312 "#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]"
313 )?;
314 writeln!(f, "#[serde(tag = \"type\")]")?;
315 writeln!(f, "pub enum {} {{", self.0.name)?;
316
317 for variant in &self.0.variants {
318 for comment_line in &variant.doc_comments {
319 writeln!(f, " /// {}", comment_line)?;
320 }
321 writeln!(f, " #[serde(rename = \"{}\")]", variant.api_name)?;
322 writeln!(
323 f,
324 " {}(Arc<{}>),",
325 variant.rust_name, variant.fields[0].typ
326 )?;
327 }
328 writeln!(f, "}}\n")?;
329
330 writeln!(f, "impl {} {{", self.0.name)?;
332
333 for var in self.0.variants.iter() {
334 assert_eq!(var.fields.len(), 1, "Discriminated union is not disjointed");
335 assert!(
336 var.fields[0].rust_name.is_empty(),
337 "Discriminated union is not disjointed"
338 );
339
340 writeln!(
341 f,
342 " pub fn {}(&self) -> Option<&{}> {{",
343 var.rust_name.remove_empty().to_case(Case::Snake),
344 var.fields[0].typ
345 )?;
346
347 writeln!(f, " if let Self::{}(ret) = self {{", var.rust_name)?;
348 writeln!(f, " Some(ret)",)?;
349 writeln!(f, " }} else {{ None }}",)?;
350 writeln!(f, " }}\n")?;
351 }
352
353 writeln!(f, "}}")
354 }
355}
356
/// Stateful parser for command markdown blocks.
#[derive(Default)]
struct Parser {
    /// The most recently seen doc section (an `H2` heading plus its lines). It is replaced
    /// whenever a block starts with a new section. While it is set, its docs are prepended
    /// to every command parsed afterwards.
    current_doc_section: Option<DocSection>,
}
361
362impl Parser {
363 pub fn parse_block(&mut self, block: &str) -> Result<CommandResponse, String> {
364 self.parser(block.lines().map(str::trim))
365 .map_err(|e| format!("{e} in block\n```\n{block}\n```"))
366 }
367
368 fn parser<'a>(
369 &mut self,
370 mut lines: impl Iterator<Item = &'a str>,
371 ) -> Result<CommandResponse, String> {
372 const DOC_SECTION_PAT: &str = parse_utils::H2;
373 const TYPENAME_PAT: &str = parse_utils::H3;
374 const TYPEKINDS_PAT: &str = parse_utils::BOLD;
375
376 let mut next =
377 parse_utils::skip_empty(&mut lines).ok_or_else(|| "Got an empty block".to_owned())?;
378
379 let mut command_docs: Vec<String> = Vec::new();
380
381 let (typename, mut typekind) = loop {
382 if let Some(section_name) = next.strip_prefix(DOC_SECTION_PAT) {
383 let mut doc_section = DocSection::new(section_name.to_owned());
384
385 next = parse_utils::parse_doc_lines(&mut lines, &mut doc_section.contents, |s| {
386 s.starts_with(TYPENAME_PAT)
387 })
388 .ok_or_else(|| format!("Failed to find a typename by pattern {TYPENAME_PAT:?} after the doc section"))?;
389
390 self.current_doc_section.replace(doc_section);
391 } else if let Some(name) = next.strip_prefix(TYPENAME_PAT) {
392 next = parse_utils::parse_doc_lines(&mut lines, &mut command_docs, |s| {
393 s.starts_with(TYPEKINDS_PAT)
394 })
395 .map(|s| s.strip_prefix(TYPEKINDS_PAT).unwrap())
396 .ok_or_else(|| format!("Failed to find a typekind by pattern {TYPEKINDS_PAT:?} after the inner docs "))?;
397
398 break (name, next);
399 }
400 };
401
402 let command_name = typename.to_case(Case::Pascal);
403 let mut command = RecordType::new(command_name.clone(), vec![]);
404
405 loop {
406 if typekind.starts_with("Parameters") {
407 typekind = parse_utils::parse_record_fields(
408 &mut lines,
409 &mut command.fields,
410 |s| s.starts_with(TYPEKINDS_PAT),
411 )?
412 .map(|s| s.strip_prefix(TYPEKINDS_PAT).unwrap())
413 .ok_or_else(|| format!(
414 "Failed to find a command syntax after parameters by pattern {TYPENAME_PAT:?}"
415 ))?;
416 } else if typekind.starts_with("Syntax") {
417 parse_utils::parse_syntax(&mut lines, &mut command.syntax)?;
418 break;
419 }
420 }
421
422 let mut response_variants: Vec<DiscriminatedUnionVariant> = Vec::with_capacity(4);
423
424 parse_utils::skip_while(&mut lines, |s| !s.starts_with("**Response")).ok_or_else(|| {
425 "Failed to find responses section by pattern \"**Response\"".to_owned()
426 })?;
427
428 let mut variant_docline = Vec::new();
429
430 while let Some(docline) = parse_utils::skip_empty(&mut lines) {
431 if docline.starts_with(TYPEKINDS_PAT) {
432 break;
433 } else {
434 variant_docline.push(docline.to_owned());
435 }
436
437 let (mut variant, next) = parse_utils::parse_discriminated_union_variant(&mut lines)?;
438 assert!(next.map(|s| s.is_empty()).unwrap_or(true));
439 variant.doc_comments = std::mem::take(&mut variant_docline);
440 response_variants.push(variant);
441 }
442
443 let response =
444 DiscriminatedUnionType::new(format!("{command_name}Response"), response_variants);
445
446 if let Some(ref outer_docs) = self.current_doc_section {
447 command
448 .doc_comments
449 .push(format!("### {}", outer_docs.header.clone()));
450
451 command.doc_comments.push(String::new());
452
453 command
454 .doc_comments
455 .extend(outer_docs.contents.iter().cloned());
456
457 command.doc_comments.push(String::new());
458 command.doc_comments.push("----".to_owned());
459 command.doc_comments.push(String::new());
460 }
461
462 command.doc_comments.extend(command_docs);
463 Ok(CommandResponse { command, response })
464 }
465}
466
/// A doc section introduced by a `parse_utils::H2` heading in the commands markdown.
#[derive(Default, Clone)]
struct DocSection {
    /// Heading text that follows the section prefix.
    header: String,
    /// Doc lines collected after the heading, up to the next typename heading.
    contents: Vec<String>,
}
472
473impl DocSection {
474 fn new(header: String) -> Self {
475 Self {
476 header,
477 contents: Vec::new(),
478 }
479 }
480}