rust_iptables/
iptables.rs

1use lazy_static::lazy_static;
2use nix::fcntl::{flock, FlockArg};
3use regex::{Match, Regex};
4use std::convert::From;
5use std::error::Error;
6use std::ffi::OsStr;
7use std::fmt;
8use std::fs::File;
9use std::os::unix::io::AsRawFd;
10use std::process::{Command, Output};
11use std::vec::Vec;
12
13lazy_static! {
14    static ref RULE_SPLIT: Regex = Regex::new(r#"["'].+?["']|[^ ]+"#).unwrap();
15}
16
17trait SplitQuoted {
18    fn split_quoted(&self) -> Vec<&str>;
19}
20
21impl SplitQuoted for str {
22    fn split_quoted(&self) -> Vec<&str> {
23        RULE_SPLIT
24            .find_iter(self)
25            .map(|m| Match::as_str(&m))
26            .map(|s| s.trim_matches(|c| c == '"' || c == '\''))
27            .collect::<Vec<_>>()
28    }
29}
30
/// Wraps a plain message in a boxed, type-erased error.
fn error_from_str(msg: &str) -> Box<dyn Error> {
    Box::<dyn Error>::from(msg)
}
34
35fn output_to_result(output: Output) -> Result<(), Box<dyn Error>> {
36    if !output.status.success() {
37        return Err(Box::new(IptablesError::from(output)));
38    }
39    Ok(())
40}
41
/// Failure reported by an iptables command that exited unsuccessfully.
///
/// `code` is the process exit code, or `-1` when the process ended without
/// one (for example, when it was terminated by a signal). `msg` holds the
/// command's stderr, decoded as UTF-8 with invalid bytes replaced.
#[derive(Debug)]
pub struct IptablesError {
    pub code: i32,
    pub msg: String,
}

impl fmt::Display for IptablesError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let IptablesError { code, msg } = self;
        write!(f, "code: {}, msg: {}", code, msg)
    }
}

impl From<Output> for IptablesError {
    fn from(output: Output) -> Self {
        let code = match output.status.code() {
            Some(exit_code) => exit_code,
            None => -1,
        };
        let msg = String::from_utf8_lossy(&output.stderr).into_owned();
        IptablesError { code, msg }
    }
}

impl Error for IptablesError {}
64
/// Handle for driving the iptables command-line tools.
///
/// `new` / `new_with_protocol` fill every field from the installed tool's
/// `--version` output; the struct can also be built by hand since all fields
/// are public.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IPTables {
    /// Main binary invoked for rule operations (`iptables` or `ip6tables`).
    pub cmd: &'static str,
    /// Binary used to dump rules (`iptables-save` or `ip6tables-save`).
    pub save_cmd: &'static str,
    /// Binary used to load rules (`iptables-restore` or `ip6tables-restore`).
    pub restore_cmd: &'static str,
    /// Whether the `-C` (check rule) command can be used.
    pub has_check: bool,
    /// Whether the `--wait` option can be passed to `cmd`.
    pub has_wait: bool,

    /// Major version number reported by `cmd --version`.
    pub v_major: isize,
    /// Minor version number reported by `cmd --version`.
    pub v_minor: isize,
    /// Patch version number reported by `cmd --version`.
    pub v_patch: isize,
}
76
77#[cfg(target_os = "linux")]
78pub fn new_with_protocol(is_ipv6: bool) -> Result<IPTables, Box<dyn Error>> {
79    let cmd = if is_ipv6 { "ip6tables" } else { "iptables" };
80    let save_cmd = if is_ipv6 {
81        "ip6tables-save"
82    } else {
83        "iptables-save"
84    };
85    let restore_cmd = if is_ipv6 {
86        "ip6tables-restore"
87    } else {
88        "iptables-restore"
89    };
90
91    let version_output = Command::new(cmd).arg("--version").output()?;
92    let re = Regex::new(r"v(\d+)\.(\d+)\.(\d+)")?;
93    let version_string = String::from_utf8_lossy(version_output.stdout.as_slice());
94    let versions = re
95        .captures(&version_string)
96        .ok_or("invalid version number")?;
97    let v_major = versions
98        .get(1)
99        .ok_or("unable to get major version number")?
100        .as_str()
101        .parse::<i32>()?;
102    let v_minor = versions
103        .get(2)
104        .ok_or("unable to get minor version number")?
105        .as_str()
106        .parse::<i32>()?;
107    let v_patch = versions
108        .get(3)
109        .ok_or("unable to get patch version number")?
110        .as_str()
111        .parse::<i32>()?;
112
113    Ok(IPTables {
114        cmd,
115        save_cmd,
116        restore_cmd,
117        has_check: (v_major > 1)
118            || (v_major == 1 && v_minor > 4)
119            || (v_major == 1 && v_minor == 4 && v_patch > 10),
120        has_wait: (v_major > 1)
121            || (v_major == 1 && v_minor > 4)
122            || (v_major == 1 && v_minor == 4 && v_patch > 19),
123        v_major: v_major as isize,
124        v_minor: v_minor as isize,
125        v_patch: v_patch as isize,
126    })
127}
128
129#[cfg(not(target_os = "linux"))]
130pub fn new() -> Result<IPTables, Box<dyn Error>> {
131    Err(error_from_str("iptables only works on Linux"))
132}
133
134#[cfg(target_os = "linux")]
135pub fn new() -> Result<IPTables, Box<dyn Error>> {
136    new_with_protocol(false)
137}
138
139impl IPTables {
140    pub fn save_table(&self, table: &str, target: &str) -> Result<Output, Box<dyn Error>> {
141        let cmd = format!("{} -t {} > {}", self.save_cmd, table, target);
142        let output = Command::new("sh").arg("-c").arg(cmd).output()?;
143        Ok(output)
144    }
145
146    pub fn save_all(&self, target: &str) -> Result<Output, Box<dyn Error>> {
147        let cmd = format!("{} > {}", self.save_cmd, target);
148        let output = Command::new("sh").arg("-c").arg(cmd).output()?;
149        Ok(output)
150    }
151
152    pub fn restore_table(&self, table: &str, target: &str) -> Result<Output, Box<dyn Error>> {
153        let cmd = format!("{} -t {} < {}", self.restore_cmd, table, target);
154        let output = Command::new("sh").arg("-c").arg(cmd).output()?;
155        Ok(output)
156    }
157
158    pub fn restore_all(&self, target: &str) -> Result<Output, Box<dyn Error>> {
159        let cmd = format!("{} < {}", self.restore_cmd, target);
160        let output = Command::new("sh").arg("-c").arg(cmd).output()?;
161        Ok(output)
162    }
163
164    fn run<S: AsRef<OsStr>>(&self, args: &[S]) -> Result<Output, Box<dyn Error>> {
165        let mut file_lock = None;
166
167        let mut output_cmd = Command::new(self.cmd);
168        let output;
169
170        if self.has_wait {
171            output = output_cmd.args(args).arg("--wait").output()?;
172        } else {
173            file_lock = Some(File::create("/var/run/xtables_old.lock")?);
174
175            let mut need_retry = true;
176            let mut limit = 10;
177            while need_retry {
178                match flock(
179                    file_lock.as_ref().unwrap().as_raw_fd(),
180                    FlockArg::LockExclusiveNonblock,
181                ) {
182                    Ok(_) => need_retry = false,
183                    Err(nix::Error::Sys(en)) if en == nix::errno::Errno::EAGAIN => {
184                        if limit > 0 {
185                            need_retry = true;
186                            limit -= 1;
187                        } else {
188                            return Err(error_from_str("get lock failed"));
189                        }
190                    }
191                    Err(e) => {
192                        return Err(Box::new(e));
193                    }
194                }
195            }
196            output = output_cmd.args(args).output()?;
197        }
198
199        if let Some(f) = file_lock {
200            drop(f)
201        }
202        Ok(output)
203    }
204
205    fn exists_old_version(
206        &self,
207        table: &str,
208        chain: &str,
209        rule: &str,
210    ) -> Result<bool, Box<dyn Error>> {
211        self.run(&["-t", table, "-S"]).map(|output| {
212            String::from_utf8_lossy(&output.stdout).contains(&format!("-A {} {}", chain, rule))
213        })
214    }
215
216    fn get_list<S: AsRef<OsStr>>(&self, args: &[S]) -> Result<Vec<String>, Box<dyn Error>> {
217        let stdout = self.run(args)?.stdout;
218        Ok(String::from_utf8_lossy(stdout.as_slice())
219            .trim()
220            .split('\n')
221            .map(String::from)
222            .collect())
223    }
224}
225
226impl IPTables {
227    #[cfg(target_os = "linux")]
228    pub fn exists(&self, table: &str, chain: &str, rule: &str) -> Result<bool, Box<dyn Error>> {
229        if !self.has_check {
230            return self.exists_old_version(table, chain, rule);
231        }
232
233        self.run(&[&["-t", table, "-C", chain], rule.split_quoted().as_slice()].concat())
234            .map(|output| output.status.success())
235    }
236
237    pub fn insert(
238        &self,
239        table: &str,
240        chain: &str,
241        position: i32,
242        rule: &str,
243    ) -> Result<(), Box<dyn Error>> {
244        self.run(
245            &[
246                &["-t", table, "-I", chain, &position.to_string()],
247                rule.split_quoted().as_slice(),
248            ]
249            .concat(),
250        )
251        .and_then(output_to_result)
252    }
253
254    pub fn append(&self, table: &str, chain: &str, rule: &str) -> Result<(), Box<dyn Error>> {
255        self.run(&[&["-t", table, "-A", chain], rule.split_quoted().as_slice()].concat())
256            .and_then(output_to_result)
257    }
258
259    pub fn append_unique(
260        &self,
261        table: &str,
262        chain: &str,
263        rule: &str,
264    ) -> Result<(), Box<dyn Error>> {
265        if self.exists(table, chain, rule)? {
266            return Err(error_from_str("the rule exists in the table/chain"));
267        }
268
269        self.append(table, chain, rule)
270    }
271
272    pub fn delete(&self, table: &str, chain: &str, rule: &str) -> Result<(), Box<dyn Error>> {
273        self.run(&[&["-t", table, "-D", chain], rule.split_quoted().as_slice()].concat())
274            .and_then(output_to_result)
275    }
276
277    pub fn delete_if_exsits(
278        &self,
279        table: &str,
280        chain: &str,
281        rule: &str,
282    ) -> Result<(), Box<dyn Error>> {
283        while self.exists(table, chain, rule)? {
284            self.delete(table, chain, rule)?;
285        }
286
287        Ok(())
288    }
289
290    pub fn list(&self, table: &str, chain: &str) -> Result<Vec<String>, Box<dyn Error>> {
291        self.get_list(&["-t", table, "-S", chain])
292    }
293
294    pub fn list_with_counters(
295        &self,
296        table: &str,
297        chain: &str,
298    ) -> Result<Vec<String>, Box<dyn Error>> {
299        self.get_list(&["-t", table, "-v", "-S", chain])
300    }
301
302    pub fn list_chains(&self, table: &str) -> Result<Vec<String>, Box<dyn Error>> {
303        let mut list = Vec::new();
304        let stdout = self.run(&["-t", table, "-S"])?.stdout;
305        let output = String::from_utf8_lossy(stdout.as_slice());
306        for item in output.trim().split('\n') {
307            let fields = item.split(' ').collect::<Vec<&str>>();
308            if fields.len() > 1 && (fields[0] == "-P" || fields[0] == "-N") {
309                list.push(fields[1].to_string());
310            }
311        }
312        Ok(list)
313    }
314
315    pub fn chain_exists(&self, table: &str, chain: &str) -> Result<bool, Box<dyn Error>> {
316        self.run(&["-t", table, "-L", chain])
317            .map(|output| output.status.success())
318    }
319
320    pub fn new_chain(&self, table: &str, chain: &str) -> Result<(), Box<dyn Error>> {
321        self.run(&["-t", table, "-N", chain])
322            .and_then(output_to_result)
323    }
324
325    pub fn flush_chain(&self, table: &str, chain: &str) -> Result<(), Box<dyn Error>> {
326        self.run(&["-t", table, "-F", chain])
327            .and_then(output_to_result)
328    }
329
330    pub fn rename_chain(
331        &self,
332        table: &str,
333        old_chain: &str,
334        new_chain: &str,
335    ) -> Result<(), Box<dyn Error>> {
336        self.run(&["-t", table, "-E", old_chain, new_chain])
337            .and_then(output_to_result)
338    }
339
340    pub fn delete_chain(&self, table: &str, chain: &str) -> Result<(), Box<dyn Error>> {
341        self.run(&["-t", table, "-X", chain])
342            .and_then(output_to_result)
343    }
344
345    pub fn flush_and_delete_chain(&self, table: &str, chain: &str) -> Result<(), Box<dyn Error>> {
346        while self.chain_exists(table, chain)? {
347            match self.flush_chain(table, chain) {
348                Ok(_) => {
349                    return self.delete_chain(table, chain);
350                }
351                Err(e) => {
352                    return Err(e);
353                }
354            }
355        }
356
357        Ok(())
358    }
359
360    pub fn flush_table(&self, table: &str) -> Result<(), Box<dyn Error>> {
361        self.run(&["-t", table, "-F"]).and_then(output_to_result)
362    }
363
364    pub fn delete_table(&self, table: &str) -> Result<(), Box<dyn Error>> {
365        self.run(&["-t", table, "-X"]).and_then(output_to_result)
366    }
367
368    pub fn flush_all(&self) -> Result<(), Box<dyn Error>> {
369        self.run(&["-F"]).and_then(output_to_result)
370    }
371
372    pub fn delete_all(&self) -> Result<(), Box<dyn Error>> {
373        self.run(&["-X"]).and_then(output_to_result)
374    }
375
376    pub fn change_policy(
377        &self,
378        table: &str,
379        chain: &str,
380        target: &str,
381    ) -> Result<(), Box<dyn Error>> {
382        self.run(&["-t", table, "-P", chain, target])
383            .and_then(output_to_result)
384    }
385
386    pub fn get_iptables_version(self) -> (isize, isize, isize) {
387        (self.v_major, self.v_minor, self.v_patch)
388    }
389}