rust_enum_derive/
lib.rs

1extern crate toml;
2#[macro_use]
3extern crate log;
4extern crate env_logger;
5extern crate regex;
6
7use std::cmp::Ordering;
8use std::fs::{self, File, OpenOptions};
9use std::io::prelude::*;
10use std::io::{BufReader, BufWriter, Error, ErrorKind, Result};
11use std::path::PathBuf;
12
/// Arguments for how to process() an input file.
///
/// `FileArgs::default()` yields a value with every option switched off:
/// both `Option` fields are `None` and every flag is `false`.
#[derive(Debug, Default)]
pub struct FileArgs {
    /// the enum name (Name if not specified)
    pub name: Option<String>,
    /// Which traits to derive. Ex: "Debug, PartialEq"
    pub derive: Option<String>,
    /// parse C #define input instead of enum
    pub define: bool,
    /// implement the Default trait with the first value
    pub default: bool,
    /// implement the std::fmt::Display trait
    pub display: bool,
    /// implement the num::traits::FromPrimitive trait
    pub fromprimative: bool,
    /// implement the std::str::FromStr trait
    pub fromstr: bool,
    /// hexadecimal output
    pub hex: bool,
    /// implement pretty_fmt()
    pub pretty_fmt: bool,
}
42
/// One parsed entry: an integer value together with its identifier.
///
/// Equality and ordering compare the integer `i` only; the name `s` never
/// takes part, so two entries with the same value but different names are
/// considered equal.
#[derive(Debug)]
struct CEnum {
    /// the numeric value of the entry
    i: i32,
    /// the identifier of the entry
    s: String,
}
impl CEnum {
    /// Builds an entry from a value and a borrowed name (the name is copied).
    fn new(i: i32, s: &str) -> CEnum {
        CEnum { i: i, s: String::from(s) }
    }
}
impl ::std::cmp::Eq for CEnum {}
impl ::std::cmp::PartialEq for CEnum {
    fn eq(&self, other: &Self) -> bool {
        self.i == other.i
    }
}
impl ::std::cmp::PartialOrd for CEnum {
    // Delegates to `cmp` so the partial and total orders cannot drift apart.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl ::std::cmp::Ord for CEnum {
    fn cmp(&self, other: &Self) -> Ordering {
        self.i.cmp(&other.i)
    }
}
84
85trait FormatOutput {
86    fn write(&self, w: &mut Write, name: &String, hex: bool, vec: &Vec<CEnum>) -> Result<()>;
87}
88
89struct FormatOutputFromPrimative;
90impl FormatOutput for FormatOutputFromPrimative {
91    fn write(&self, w: &mut Write, name: &String, hex: bool, vec: &Vec<CEnum>) -> Result<()> {
92        try!(write!(w, "impl ::num::traits::FromPrimitive for {} {{\n", name));
93        try!(write!(w, "    #[allow(dead_code)]\n"));
94        try!(write!(w, "    fn from_i64(n: i64) -> Option<Self> {{\n"));
95        try!(write!(w, "        match n {{\n"));
96        for v in vec {
97            if hex {
98                try!(write!(w, "            0x{:X} => Some({}::{}),\n", v.i, name, v.s));
99            }
100            else {
101                try!(write!(w, "            {} => Some({}::{}),\n", v.i, name, v.s));
102            }
103        }
104        try!(write!(w, "            _ => None\n"));
105        try!(write!(w, "        }}\n"));
106        try!(write!(w, "    }}\n"));
107        try!(write!(w, "    #[allow(dead_code)]\n"));
108        try!(write!(w, "    fn from_u64(n: u64) -> Option<Self> {{\n"));
109        try!(write!(w, "        match n {{\n"));
110        for v in vec {
111            if hex {
112                try!(write!(w, "            0x{:X} => Some({}::{}),\n", v.i, name, v.s));
113            }
114            else {
115                try!(write!(w, "            {} => Some({}::{}),\n", v.i, name, v.s));
116            }
117        }
118        try!(write!(w, "            _ => None\n"));
119        try!(write!(w, "        }}\n"));
120        try!(write!(w, "    }}\n"));
121        try!(write!(w, "}}\n"));
122        Ok(())
123    }
124}
125
126struct FormatOutputPrettyFmt;
127impl FormatOutput for FormatOutputPrettyFmt {
128    #[allow(unused_variables)]
129    fn write(&self, w: &mut Write, name: &String, hex: bool, vec: &Vec<CEnum>) -> Result<()> {
130        try!(write!(w, "impl {} {{\n", name));
131        try!(write!(w, "    fn pretty_fmt(f: &mut ::std::fmt::Formatter, flags: u32) -> ::std::fmt::Result {{\n"));
132        try!(write!(w, "        let mut shift: u32 = 0;\n"));
133        try!(write!(w, "        let mut result: u32 = 1<<shift;\n"));
134        try!(write!(w, "        let mut found = false;\n"));
135        // This should never fail because we check in main() to make sure that
136        // it isn't empty.
137        try!(write!(w, "        while result <= {}::{} as u32 {{\n", name, vec.last().unwrap().s));
138        try!(write!(w, "            let tmp = result & flags;\n"));
139        try!(write!(w, "            if tmp > 0 {{\n"));
140        try!(write!(w, "                if found {{\n"));
141        try!(write!(w, "                    try!(write!(f, \"|\"));\n"));
142        try!(write!(w, "                }}\n"));
143        try!(write!(w, "                let flag = {}::from_u32(tmp).unwrap();\n", name));
144        try!(write!(w, "                try!(write!(f, \"{{}}\", flag));\n"));
145        try!(write!(w, "                found = true;\n"));
146        try!(write!(w, "            }}\n"));
147        try!(write!(w, "            shift += 1;\n"));
148        try!(write!(w, "            result = 1<<shift;\n"));
149        try!(write!(w, "        }}\n"));
150        try!(write!(w, "        write!(f, \"\")\n"));
151        try!(write!(w, "    }}\n"));
152        try!(write!(w, "}}\n"));
153        Ok(())
154    }
155}
156
157struct FormatOutputDefault;
158impl FormatOutput for FormatOutputDefault {
159    #[allow(unused_variables)]
160    fn write(&self, w: &mut Write, name: &String, hex: bool, vec: &Vec<CEnum>) -> Result<()> {
161        try!(write!(w, "impl Default for {} {{\n", name));
162        try!(write!(w, "    fn default() -> {} {{\n", name));
163        try!(write!(w, "        {}::{}\n", name, vec[0].s));
164        try!(write!(w, "    }}\n"));
165        try!(write!(w, "}}\n"));
166        Ok(())
167    }
168}
169
170struct FormatOutputDisplay;
171impl FormatOutput for FormatOutputDisplay {
172    #[allow(unused_variables)]
173    fn write(&self, w: &mut Write, name: &String, hex: bool, vec: &Vec<CEnum>) -> Result<()> {
174        try!(write!(w, "impl ::std::fmt::Display for {} {{\n", name));
175        try!(write!(w, "    #[allow(dead_code)]\n"));
176        try!(write!(w, "    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {{\n"));
177        try!(write!(w, "        match *self {{\n"));
178        for v in vec {
179            try!(write!(w, "            {}::{} => write!(f, \"{}\"),\n", name, v.s, v.s));
180        }
181        try!(write!(w, "        }}\n"));
182        try!(write!(w, "    }}\n"));
183        try!(write!(w, "}}\n"));
184        Ok(())
185    }
186}
187
188struct FormatOutputFromStr;
189impl FormatOutput for FormatOutputFromStr {
190    #[allow(unused_variables)]
191    fn write(&self, w: &mut Write, name: &String, hex: bool, vec: &Vec<CEnum>) -> Result<()> {
192        try!(write!(w, "impl ::std::str::FromStr for {} {{\n", name));
193        try!(write!(w, "    type Err = ();\n"));
194        try!(write!(w, "    #[allow(dead_code)]\n"));
195        try!(write!(w, "    fn from_str(s: &str) -> Result<Self, Self::Err> {{\n"));
196        try!(write!(w, "        match s {{\n"));
197        for v in vec {
198            try!(write!(w, "            \"{}\" => Ok({}::{}),\n", v.s, name, v.s));
199        }
200        try!(write!(w, "            _ => Err( () )\n"));
201        try!(write!(w, "        }}\n"));
202        try!(write!(w, "    }}\n"));
203        try!(write!(w, "}}\n"));
204        Ok(())
205    }
206}
207
208struct FormatOutputEnum;
209impl FormatOutputEnum {
210    fn write(&self, w: &mut Write, name: &String, derive: Option<&String>, hex: bool, vec: &Vec<CEnum>) -> Result<()> {
211        try!(write!(w, "#[allow(dead_code, non_camel_case_types)]\n"));
212        match derive
213        {
214            Some(s) => try!(write!(w, "#[derive({})]\n", s)),
215            None => (),
216        }
217        try!(write!(w, "pub enum {} {{\n", name));
218
219        for v in vec {
220            if hex {
221                try!(write!(w, "    {} = 0x{:X},\n", v.s, v.i));
222            }
223            else {
224                try!(write!(w, "    {} = {},\n", v.s, v.i));
225            }
226        }
227
228        try!(write!(w, "}}\n"));
229        Ok(())
230    }
231}
232
// A macro to copy an optional string entry out of a toml::Table.
// $t - Table to look the key up in
// $a - struct whose field `$v` receives Some(String)
// $v - the key name in the toml; also the name of the field in $a
//
// A missing key leaves the field untouched. A key whose value is not a
// string makes the *enclosing function* return an io::Error.
macro_rules! get_key_string {
    ($t:ident, $a:ident, $v:ident) => {
        if let Some(value) = $t.get(stringify!($v)) {
            match value.as_str() {
                Some(text) => $a.$v = Some(String::from(text)),
                None => return Err(Error::new(ErrorKind::Other,
                                              format!("{} wasn't available as str",
                                                      stringify!($v)))),
            }
        }
    }
}
252
// Same as get_key_string, except for bool instead of str/string: the field
// `$v` of `$a` is assigned the bool directly rather than wrapped in Some.
macro_rules! get_key_bool {
    ($t:ident, $a:ident, $v:ident) => {
        if let Some(value) = $t.get(stringify!($v)) {
            match value.as_bool() {
                Some(flag) => $a.$v = flag,
                None => return Err(Error::new(ErrorKind::Other,
                                              format!("{} wasn't available as bool",
                                                      stringify!($v)))),
            }
        }
    }
}
268
269fn parse_toml(path: &PathBuf) -> Result<FileArgs>
270{
271    let mut fa = FileArgs::default();
272    let mut f = try!(File::open(&path));
273
274    let mut s = String::new();
275    try!(f.read_to_string(&mut s));
276    let table = toml::Parser::new(&s).parse();
277    if table.is_none() {
278        return Err(Error::new(ErrorKind::Other,
279                              format!("failed to parse {}", path.display())))
280    }
281    let table = table.unwrap();
282
283    let rust_enum_derive = table.get("rust-enum-derive");
284    if rust_enum_derive.is_none() {
285        return Err(Error::new(ErrorKind::Other,
286                              format!("couldn't find a rust-enum-derive table in {}",
287                                      path.display())))
288    }
289    let rust_enum_derive = rust_enum_derive.unwrap();
290    let rust_enum_derive = rust_enum_derive.as_table();
291    if rust_enum_derive.is_none() {
292        return Err(Error::new(ErrorKind::Other,
293                              format!("rust-enum-derive wasn't a table")))
294    }
295    let rust_enum_derive = rust_enum_derive.unwrap();
296
297    get_key_string!(rust_enum_derive, fa, name);
298    get_key_string!(rust_enum_derive, fa, derive);
299    get_key_bool!(rust_enum_derive, fa, define);
300    get_key_bool!(rust_enum_derive, fa, default);
301    get_key_bool!(rust_enum_derive, fa, display);
302    get_key_bool!(rust_enum_derive, fa, fromstr);
303    get_key_bool!(rust_enum_derive, fa, fromprimative);
304    get_key_bool!(rust_enum_derive, fa, hex);
305    get_key_bool!(rust_enum_derive, fa, pretty_fmt);
306    debug!("fa = {:?}", fa);
307
308    Ok(fa)
309}
310
/// Parses the textual value of an enumerator or `#define` into an `i32`.
///
/// Accepted forms:
///
/// * decimal digits, e.g. `16`
/// * `0x` / `0X` followed by hexadecimal digits, e.g. `0x1F`
/// * `<decimal> << <decimal>` with optional ASCII whitespace around `<<`,
///   e.g. `1<<4`
///
/// # Panics
///
/// Panics with `couldn't parse '<s>' as int` when `s` has none of these
/// forms, and panics when the digits do not fit into an `i32`.
fn get_num(s: &str) -> i32 {
    use std::str::FromStr;

    // True when `t` is non-empty and consists only of ASCII digits of `radix`.
    fn all_digits(t: &str, radix: u32) -> bool {
        !t.is_empty() && t.chars().all(|c| c.is_digit(radix))
    }
    // Drops leading ASCII whitespace (space, \t, \n, \v, \f, \r).
    fn skip_spaces(t: &str) -> &str {
        match t.find(|c: char| !(c == ' ' || ('\t' <= c && c <= '\r'))) {
            Some(start) => &t[start..],
            None => "",
        }
    }

    // Plain decimal literal. A leading zero is still read as decimal.
    if all_digits(s, 10) {
        return i32::from_str_radix(s, 10).unwrap();
    }

    // Hexadecimal literal, letters included.
    if s.starts_with("0x") || s.starts_with("0X") {
        let digits = &s[2..];
        if all_digits(digits, 16) {
            return i32::from_str_radix(digits, 16).unwrap();
        }
    }

    // Shift expression: leading decimal digits, `<<`, trailing decimal digits.
    let lhs_len = s.find(|c: char| !c.is_digit(10)).unwrap_or(s.len());
    let rest = skip_spaces(&s[lhs_len..]);
    if lhs_len > 0 && rest.starts_with("<<") {
        let rhs = skip_spaces(&rest[2..]);
        if all_digits(rhs, 10) {
            let l: i32 = FromStr::from_str(&s[..lhs_len]).unwrap();
            let r: i32 = FromStr::from_str(rhs).unwrap();
            return l << r;
        }
    }

    panic!("couldn't parse '{}' as int", s)
}
336
337/// Return a sorted Vec of CEnum structs
338fn parse_buff<T: BufRead>(read: T, parse_enum: bool) -> Vec<CEnum> {
339    use regex::Regex;
340    let re = match parse_enum {
341        true => Regex::new(r"^[:space:]*([[:alnum:]_]+)([:space:]*=[:space:]*([:graph:]+))?[:space:]*,").unwrap(),
342        false => Regex::new(r"^#define[:space:]+([:graph:]+)[:space:]+([:graph:]+)").unwrap(),
343    };
344    let mut v: Vec<CEnum> = Vec::new();
345
346    let mut num: i32 = 0;
347    for line in read.lines() {
348        let s = line.unwrap();
349        for cap in re.captures_iter(&s) {
350            let i: i32 = match parse_enum {
351                true => match cap.at(3) {
352                    Some(s) => get_num(s),
353                    None => num,
354                },
355                false => get_num(cap.at(2).unwrap()),
356            };
357            num = i + 1;
358            v.push(CEnum::new(i, cap.at(1).unwrap()));
359        }
360    }
361
362    v.sort();
363    v
364}
365
366fn get_input(file_path: Option<&PathBuf>, file_args: &FileArgs) -> Vec<CEnum> {
367    match file_path {
368        Some(ref s) => {
369            // remove this unwrap as soon as expect is stabalized
370            let f = File::open(s).unwrap();
371            let r = BufReader::new(f);
372            parse_buff(r, !file_args.define)
373        }
374        None => {
375            let r = BufReader::new(std::io::stdin());
376            parse_buff(r, !file_args.define)
377        }
378    }
379}
380
381fn write_factory(file_path: Option<&PathBuf>) -> Result<Box<Write>> {
382    match file_path {
383        Some(s) => {
384            try!(std::fs::create_dir_all(s.parent().unwrap()));
385            let f = try!(OpenOptions::new().write(true)
386                                           .create(true)
387                                           .truncate(true)
388                                           .open(s));
389            let w = BufWriter::new(f);
390            Ok(Box::new(w))
391        }
392        None => {
393            let w = BufWriter::new(std::io::stdout());
394            Ok(Box::new(w))
395        }
396    }
397}
398
399/// This is the function that you call to process one file (Enum) worth of data.
400///
401/// * `file_path_in` - The file input path to read from (or stdin if None)
402/// * `file_path_out` - The file output path to write to (or stdout from None)
403/// * `file_argsfile_args` - The arguments for how to process the input
404pub fn process(file_path_in: Option<&PathBuf>, file_path_out: Option<&PathBuf>,
405               file_args: &FileArgs) -> Result<()> {
406    let mut fov: Vec<Box<FormatOutput>> = Vec::new();
407    if file_args.fromstr { fov.push(Box::new(FormatOutputFromStr)); }
408    if file_args.default { fov.push(Box::new(FormatOutputDefault)); }
409    if file_args.display { fov.push(Box::new(FormatOutputDisplay)); }
410    if file_args.fromprimative { fov.push(Box::new(FormatOutputFromPrimative)); }
411    if file_args.pretty_fmt { fov.push(Box::new(FormatOutputPrettyFmt)); }
412
413    let vi = get_input(file_path_in, &file_args);
414    if vi.len() < 1 {
415        let input = match file_path_in {
416            Some(pb) => pb.to_string_lossy().into_owned(),
417            None => String::from("standard in"),
418        };
419        return Err(Error::new(ErrorKind::Other,
420                              format!("couldn't parse any input from {}.",
421                                      input)))
422    }
423
424    let mut w = try!(write_factory(file_path_out));
425    let name = match file_args.name
426    {
427        Some(ref s) => s.clone(),
428        None => String::from("Name"),
429    };
430
431    let derive = file_args.derive.as_ref();
432    try!(FormatOutputEnum.write(&mut w, &name, derive, file_args.hex, &vi));
433    for vw in fov {
434        try!(vw.write(&mut w, &name, file_args.hex, &vi));
435    }
436
437    Ok(())
438}
439
440fn traverse_dir_impl(base_input_dir: &PathBuf,
441                     base_output_dir: &PathBuf,
442                     sub_dir: &PathBuf) -> Result<()> {
443    let mut dir = PathBuf::new();
444    dir.push(base_input_dir);
445    dir.push(sub_dir);
446
447    if !fs::metadata(&dir).unwrap().is_dir() {
448        return Err(Error::new(ErrorKind::Other,
449                              format!("{} is not a directory", dir.display())))
450    }
451
452    // TODO: revisit. This follows symlinks, is that what we want?
453    // If no we could use fs::symlink_metadata() treats symbolic links as
454    // files, or DirEntry::file_type() which returns a FileType which we could
455    // use to tell if this was a symbolic link or not?
456    for entry in try!(fs::read_dir(dir)) {
457        let entry = try!(entry);
458        if fs::metadata(entry.path()).unwrap().is_dir() {
459            let mut new_sub_dir = PathBuf::new();
460            new_sub_dir.push(sub_dir);
461            new_sub_dir.push(entry.file_name());
462            try!(traverse_dir_impl(base_input_dir, base_output_dir, &new_sub_dir));
463        } else {
464            let path = entry.path();
465            if path.extension().is_some() {
466                let extension = path.extension().unwrap();
467                let extension = extension.to_string_lossy();
468                let extension = extension.to_lowercase();
469                if extension == "toml" {
470                    let args = try!(parse_toml(&path));
471
472                    let path = entry.path();
473                    let base = path.file_stem().unwrap();
474
475                    let mut input_file_path = PathBuf::new();
476                    input_file_path.push(base_input_dir);
477                    input_file_path.push(sub_dir);
478                    input_file_path.push(base);
479                    input_file_path.set_extension("in");
480
481                    let mut output_file_path = PathBuf::new();
482                    output_file_path.push(base_output_dir);
483                    output_file_path.push(sub_dir);
484                    output_file_path.push(base);
485                    output_file_path.set_extension("rs");
486
487                    try!(process(Some(&input_file_path), Some(&output_file_path), &args));
488                }
489            }
490        }
491    } // for entry in try!(fs::read_dir(dir))
492
493    Ok(())
494}
495
496/// This is the function that you call to process a whole directory heirarcy full of files.
497///
498/// * `input_dir` - The input path of the directory to read from
499/// * `output_dir` - The output path of the directory to write to
500/// * `file_argsfile_args` - The arguments for how to process the input
501pub fn traverse_dir(input_dir: &PathBuf, output_dir: &PathBuf) -> Result<()> {
502    traverse_dir_impl(input_dir, &output_dir, &PathBuf::new())
503}
504
// CEnum values compare by their integer only.
#[test]
fn test_CENum_order() {
    let ascending = [CEnum::new(0, ""), CEnum::new(1, ""), CEnum::new(2, "")];
    // Every earlier element is smaller than every later one, and vice versa.
    for lo in 0..ascending.len() {
        for hi in (lo + 1)..ascending.len() {
            assert!(ascending[lo] < ascending[hi]);
            assert!(ascending[hi] > ascending[lo]);
        }
    }
    let same_value = CEnum::new(0, "");
    assert!(ascending[0] == same_value);
}
519
// `#define NAME value` lines are parsed into (value, name) entries.
#[test]
fn test_parse_buff() {
    use std::io::Cursor;
    let s = "#define NETLINK_ROUTE 0\n\
    #define NETLINK_UNUSED 1\n\
    #define NETLINK_FIREWALL 3\n\
    #define NETLINK_SOCK_DIAG 4\n\
    #define NETLINK_GENERIC 16";

    let parsed = parse_buff(Cursor::new(s.as_bytes()), false);

    let expected = [
        (0, "NETLINK_ROUTE"),
        (1, "NETLINK_UNUSED"),
        (3, "NETLINK_FIREWALL"),
        (4, "NETLINK_SOCK_DIAG"),
        (16, "NETLINK_GENERIC"),
    ];
    for (idx, &(value, name)) in expected.iter().enumerate() {
        assert!(parsed[idx].i == value);
        assert!(parsed[idx].s == name);
    }
}
539
// Enum body lines: explicit values are honoured, implicit ones continue from
// the previous entry, and interleaved `#define` lines are ignored.
#[test]
fn test_parse_buff_enum() {
    use std::io::Cursor;
    let s = "RTM_NEWLINK    = 16,\n\
             #define RTM_NEWLINK    RTM_NEWLINK\n\
                 RTM_DELLINK,\n\
             #define RTM_DELLINK    RTM_DELLINK\n\
                 RTM_GETLINK,\n\
             #define RTM_GETLINK    RTM_GETLINK\n\
                 RTM_SETLINK,\n\
             #define RTM_SETLINK    RTM_SETLINK\n\n\
                 RTM_NEWADDR    = 20,\n\
             #define RTM_NEWADDR    RTM_NEWADDR\n\
                 RTM_DELADDR,";

    let parsed = parse_buff(Cursor::new(s.as_bytes()), true);

    let expected = [
        (16, "RTM_NEWLINK"),
        (17, "RTM_DELLINK"),
        (18, "RTM_GETLINK"),
        (19, "RTM_SETLINK"),
        (20, "RTM_NEWADDR"),
        (21, "RTM_DELADDR"),
    ];
    for (idx, &(value, name)) in expected.iter().enumerate() {
        assert!(parsed[idx].i == value);
        assert!(parsed[idx].s == name);
    }
}