1use std::{
2    collections::HashMap,
3    hash::{BuildHasherDefault, DefaultHasher},
4    ops::Deref,
5    sync::Mutex,
6    time::Instant,
7};
8
9use log::*;
10use thiserror::Error;
11use tokio::process::Command;
12
13use crate::os::{ID, OS};
14
15mod info;
16
17#[cfg_attr(windows, path = "os/windows/mod.rs")]
18#[cfg_attr(target_os = "linux", path = "os/linux/mod.rs")]
19mod os;
20
21pub use info::*;
22
/// Process-wide cache of the maximum bandwidth per channel, keyed first by
/// interface `ID` and then by channel number.
///
/// Entries are filled in by `interface_list`, read by `channel_max_bandwidth`
/// and lowered to HT20 by `downcast_channel_max_bandwidth`.
///
/// The hasher is spelled out explicitly so the outer map can be constructed
/// with `HashMap::with_hasher` inside this `static` initializer.
static MAX_BANDWIDTH: Mutex<
    HashMap<ID, HashMap<usize, BandWidth>, BuildHasherDefault<DefaultHasher>>,
> = Mutex::new(HashMap::with_hasher(BuildHasherDefault::new()));
26
/// Errors returned by the Wi-Fi control functions in this module.
#[derive(Error, Debug)]
pub enum WiFiError {
    /// A system-level failure described by a message; `std::io::Error`
    /// values are converted into this variant.
    #[error("System error: {0}")]
    System(String),
    /// The requested functionality is not available; `check_command` uses
    /// this variant when a command cannot be run.
    #[error("Not support: {0}")]
    NotSupport(String),
}
34
35impl WiFiError {
36    fn new_system<E: Deref<Target = str>>(e: E) -> Self {
37        WiFiError::System(e.to_string())
38    }
39}
40
41impl From<std::io::Error> for WiFiError {
42    fn from(e: std::io::Error) -> Self {
43        WiFiError::new_system(e.to_string())
44    }
45}
46
/// Result alias used throughout this module; the success type defaults to `()`
/// and the error type is always [`WiFiError`].
pub type WiFiResult<T = ()> = std::result::Result<T, WiFiError>;
48
/// Handle to a Wi-Fi interface as returned by `interface_list`.
#[derive(Debug, Clone)]
pub struct Interface {
    /// Platform-specific identifier passed to every backend call.
    pub id: ID,
    /// Modes associated with this interface.
    /// NOTE(review): presumably the modes the backend reports as supported;
    /// confirm against the platform implementation.
    pub support_mode: Vec<Mode>,
}
54
55impl Interface {
56    pub async fn set_mode(&self, mode: Mode) -> WiFiResult {
57        let start = Instant::now();
58        OS::set_mode(&self.id, mode).await?;
59        debug!(
60            "Set mode for interface [{}] to {:?} took {:?}",
61            self.id,
62            mode,
63            start.elapsed()
64        );
65        Ok(())
66    }
67
68    pub async fn set_channel(
69        &self,
70        channel: usize,
71        band_width: Option<BandWidth>,
72        second: Option<SecondChannel>,
73    ) -> WiFiResult {
74        if let Err(e) = self.try_set_chennel(channel, band_width, second).await {
75            warn!(
76                "interface `{}` set channel {channel} {band_width:?} fail, try downcast, err: {}",
77                self.id, e
78            );
79            downcast_channel_max_bandwidth(&self.id, channel);
80            self.try_set_chennel(channel, None, None).await
81        } else {
82            Ok(())
83        }
84    }
85    pub async fn set_frequency(
86        &self,
87        freq_mhz: usize,
88        band_width: Option<BandWidth>,
89        second: Option<SecondChannel>,
90    ) -> WiFiResult {
91        let channel = freq_mhz_to_channel(freq_mhz);
92        self.set_channel(channel, band_width, second).await
93    }
94
95    async fn try_set_chennel(
96        &self,
97        channel: usize,
98        chennel: Option<BandWidth>,
99        second: Option<SecondChannel>,
100    ) -> WiFiResult {
101        let start = Instant::now();
102        let band_width = adapt_channel_max_bandwidth(&self.id, channel, chennel, second);
103        OS::set_channel(&self.id, channel, band_width).await?;
104        let band_width_str = band_width
105            .map(|bw| format!(" bandwidth {}", bw))
106            .unwrap_or_default();
107
108        debug!(
109            "Set interface [{}] to channel {channel} {band_width_str} took {:?}",
110            self.id,
111            start.elapsed()
112        );
113        Ok(())
114    }
115
116    pub async fn ifup(&self) -> WiFiResult {
117        OS::ifup(&self.id).await
118    }
119    pub async fn ifdown(&self) -> WiFiResult {
120        OS::ifdown(&self.id).await
121    }
122}
123
/// Operating mode of a Wi-Fi interface.
///
/// The textual forms used by `Mode::cmd` and `TryFrom<&str>` are `"managed"`
/// and `"monitor"`.
#[derive(Debug, Clone, Copy)]
pub enum Mode {
    /// Managed mode (textual form `"managed"`).
    Managed,
    /// Monitor mode (textual form `"monitor"`).
    Monitor,
}
129impl Mode {
130    fn cmd(&self) -> &str {
131        match self {
132            Mode::Monitor => "monitor",
133            Mode::Managed => "managed",
134        }
135    }
136}
137
138impl TryFrom<&str> for Mode {
139    type Error = ();
140
141    fn try_from(value: &str) -> Result<Self, Self::Error> {
142        match value.trim() {
143            "managed" => Ok(Mode::Managed),
144            "monitor" => Ok(Mode::Monitor),
145            _ => Err(()),
146        }
147    }
148}
149
/// Operations a platform backend provides; the public functions and
/// `Interface` methods in this module call them through `OS`.
trait Impl {
    /// Backend check behind the public `check_environment` function.
    async fn check_environment() -> WiFiResult;
    /// Lists the Wi-Fi interfaces known to the backend.
    async fn interface_list() -> Result<Vec<Interface>, WiFiError>;
    /// Switches interface `id` to `mode`.
    async fn set_mode(id: &ID, mode: Mode) -> WiFiResult;

    /// Tunes interface `id` to `channel`; `band_width` is `None` when no
    /// bandwidth argument should be passed.
    async fn set_channel(id: &ID, channel: usize, band_width: Option<BandWidthArg>) -> WiFiResult;
    /// Backend call behind `Interface::ifup`.
    async fn ifup(id: &ID) -> WiFiResult;
    /// Backend call behind `Interface::ifdown`.
    async fn ifdown(id: &ID) -> WiFiResult;
    /// Returns the maximum bandwidth per frequency for interface `id`; the
    /// keys are converted to channels with `freq_mhz_to_channel` by the caller.
    async fn freq_max_bandwidth(id: &ID) -> WiFiResult<HashMap<usize, BandWidth>>;
}
160
/// Runs the platform backend's environment check and returns its result
/// unchanged.
pub async fn check_environment() -> WiFiResult {
    OS::check_environment().await
}
164
165pub async fn interface_list() -> Result<Vec<Interface>, WiFiError> {
166    let mut out = vec![];
167    for one in OS::interface_list().await? {
168        let id = one.id.clone();
169        out.push(one);
170        #[allow(clippy::map_entry)]
171        if !MAX_BANDWIDTH.lock().unwrap().contains_key(&id) {
172            let mut map = HashMap::new();
173            let max_bandwidth = OS::freq_max_bandwidth(&id).await?;
174            for (freq, bandwidth) in max_bandwidth {
175                let channel = freq_mhz_to_channel(freq);
176                map.insert(channel, bandwidth);
177            }
178            MAX_BANDWIDTH.lock().unwrap().insert(id, map);
179        }
180    }
181
182    Ok(out)
183}
184
185#[allow(unused)]
186async fn check_command(cmd: &str) -> WiFiResult {
187    Command::new(cmd)
188        .arg("--help")
189        .output()
190        .await
191        .map_err(|e| WiFiError::NotSupport(format!("command [{}] fail: {:?}", cmd, e)))?;
192    Ok(())
193}
194
/// Extension trait adding a run-and-check helper to a command type.
// NOTE(review): presumably unused on some platform backends, hence the allow.
#[allow(unused)]
trait CommandExt {
    /// Runs the command; `expect` is a short description of the action that
    /// is included in the error message on failure.
    async fn execute<T: AsRef<str>>(&mut self, expect: T) -> WiFiResult;
}
199
200impl CommandExt for Command {
201    async fn execute<T: AsRef<str>>(&mut self, expect: T) -> WiFiResult {
202        let program = self.as_std().get_program().to_os_string();
203        let program = program.to_string_lossy();
204        let expect = expect.as_ref();
205
206        let status = self.status().await.map_err(|e| {
207            WiFiError::new_system(format!("{expect} failed, program `{program}`: {e}"))
208        })?;
209        if !status.success() {
210            return Err(WiFiError::new_system(format!(
211                "{expect} failed, program `{program}`"
212            )));
213        }
214        Ok(())
215    }
216}
217
/// Converts a Wi-Fi channel number to its center frequency in MHz.
///
/// * Channels below 14 map to `2407 + channel * 5` (2.4 GHz band).
/// * Channel 14 is the special case at 2484 MHz; it does not follow the
///   5 MHz spacing of the other 2.4 GHz channels.
/// * Every other channel is treated as a 5 GHz channel: `5000 + channel * 5`.
pub fn channel_to_freq_mhz(channel: usize) -> usize {
    match channel {
        14 => 2484,
        c if c < 14 => 2407 + c * 5,
        c => 5000 + c * 5,
    }
}
225
/// Converts a center frequency in MHz to a Wi-Fi channel number.
///
/// * 2484 MHz is channel 14, which does not follow the 5 MHz spacing of the
///   other 2.4 GHz channels.
/// * Frequencies above 5000 MHz map to `(freq_mhz - 5000) / 5`.
/// * Everything else maps to `(freq_mhz - 2407) / 5`; frequencies below
///   2407 MHz saturate to channel 0 instead of underflowing.
pub fn freq_mhz_to_channel(freq_mhz: usize) -> usize {
    match freq_mhz {
        2484 => 14,
        f if f > 5000 => (f - 5000) / 5,
        f => f.saturating_sub(2407) / 5,
    }
}
232
233fn adapt_channel_max_bandwidth(
234    id: &ID,
235    channel: usize,
236    bandwidth: Option<BandWidth>,
237    second: Option<SecondChannel>,
238) -> Option<BandWidthArg> {
239    let mut bandwidth = bandwidth?;
240    if let Some(max_bandwidth) = channel_max_bandwidth(id, channel) {
241        if bandwidth > max_bandwidth {
242            debug!(
243                "Channel {} supports max bandwidth: {:?}, using it",
244                channel, max_bandwidth
245            );
246            bandwidth = max_bandwidth;
247        }
248    } else {
249        debug!("channel {} not found in max bandwidth map", channel);
250    }
251
252    let out = match bandwidth {
253        BandWidth::HT40 => {
254            if let Some(second) = second {
255                match second {
256                    SecondChannel::Above => BandWidthArg::HT40Above,
257                    SecondChannel::Below => BandWidthArg::HT40Below,
258                }
259            } else {
260                match channel {
261                    1..=6 => BandWidthArg::HT40Above,
262                    7..=13 => BandWidthArg::HT40Below,
263                    _ => {
264                        warn!(
265                            "Channel {} is not in the range of 1-13, defaulting to HT40Above",
266                            channel
267                        );
268                        BandWidthArg::HT40Above
269                    }
270                }
271            }
272        }
273        BandWidth::HT20 => BandWidthArg::HT20,
274        BandWidth::MHz80 => BandWidthArg::MHz80,
275        BandWidth::MHz160 => BandWidthArg::MHz160,
276    };
277
278    Some(out)
279}
280
281fn downcast_channel_max_bandwidth(id: &ID, freq: usize) -> Option<()> {
282    let mut max_bandwidth = MAX_BANDWIDTH.lock().unwrap();
283    let map = max_bandwidth.get_mut(id)?;
284    map.insert(freq, BandWidth::HT20);
285    Some(())
286}
287
288fn channel_max_bandwidth(id: &ID, channel: usize) -> Option<BandWidth> {
289    let max_bandwidth = MAX_BANDWIDTH.lock().unwrap();
290    max_bandwidth.get(id).and_then(|m| m.get(&channel)).cloned()
291}
292
// Tests that call `interface_list` go through the platform backend; those that
// take the first interface panic when no interface is returned.
#[cfg(test)]
mod tests {
    use super::*;

    // Prints every interface returned by the backend.
    #[tokio::test]
    async fn test_get_wifi_adapter_names() {
        for one in interface_list().await.unwrap() {
            println!("{one:?}");
        }
    }

    // Switches the first listed interface to monitor mode.
    #[tokio::test]
    async fn test_set_mode() {
        let interface = interface_list().await.unwrap().remove(0);
        interface.set_mode(Mode::Monitor).await.unwrap();
    }

    // Pure conversion check: 2.4 GHz channels and one 5 GHz channel.
    #[test]
    fn test_channel_to_freq_mhz() {
        assert_eq!(channel_to_freq_mhz(1), 2412);
        assert_eq!(channel_to_freq_mhz(6), 2437);
        assert_eq!(channel_to_freq_mhz(13), 2472);

        assert_eq!(channel_to_freq_mhz(36), 5180);
    }

    // Inverse of the conversion check above.
    #[test]
    fn test_freq_mhz_to_channel() {
        assert_eq!(freq_mhz_to_channel(2412), 1);
        assert_eq!(freq_mhz_to_channel(2437), 6);
        assert_eq!(freq_mhz_to_channel(2472), 13);
        assert_eq!(freq_mhz_to_channel(5180), 36);
    }

    // Requests 160 MHz on 2.4 GHz channels with and without a secondary
    // channel hint, on the first listed interface in monitor mode.
    #[tokio::test]
    async fn test_set_channel() {
        env_logger::builder()
            .filter_level(log::LevelFilter::Debug)
            .is_test(true)
            .init();
        let interface = interface_list().await.unwrap().remove(0);
        interface.set_mode(Mode::Monitor).await.unwrap();

        interface
            .set_channel(13, Some(BandWidth::MHz160), Some(SecondChannel::Below))
            .await
            .unwrap();
        interface
            .set_channel(2, Some(BandWidth::MHz160), None)
            .await
            .unwrap();
        interface
            .set_channel(2, Some(BandWidth::MHz160), Some(SecondChannel::Above))
            .await
            .unwrap();

        }
}