wifi_ctrl/sta/
mod.rs

1use super::*;
2
3use tokio::time::Duration;
4
5mod types;
6pub use types::*;
7
8mod client;
9pub use client::*;
10
11mod setup;
12pub use setup::*;
13
14mod event_socket;
15use event_socket::*;
16
/// Filesystem path of a `wlan2` socket under `/var/run/wpa_supplicant`.
///
/// Not referenced in this part of the file; presumably the default control-socket path
/// applied by the `setup` module — TODO confirm.
const PATH_DEFAULT_SERVER: &str = "/var/run/wpa_supplicant/wlan2";
18
19/// Instance that runs the Wifi process
20pub struct WifiStation {
21    /// Path to the socket
22    socket_path: std::path::PathBuf,
23    /// Channel for receiving requests
24    request_receiver: mpsc::Receiver<Request>,
25    #[allow(unused)]
26    /// Channel for broadcasting alerts
27    broadcast_sender: broadcast::Sender<Broadcast>,
28    /// Channel for sending requests to itself
29    self_sender: mpsc::Sender<Request>,
30    /// Timeout duration in case no valid select response is received
31    select_timeout: Duration,
32}
33
34impl WifiStation {
35    pub async fn run(mut self) -> Result {
36        info!("Starting Wifi Station process");
37        let (socket_handle, mut deferred_requests) = SocketHandle::open(
38            &self.socket_path,
39            "mapper_wpa_ctrl_sync.sock",
40            &mut self.request_receiver,
41        )
42        .await?;
43        // We start up a separate socket for receiving the "unexpected" events that
44        // gets forwarded to us via the unsolicited_receiver
45        let (unsolicited_receiver, next_deferred_requests, unsolicited) =
46            EventSocket::new(&self.socket_path, &mut self.request_receiver).await?;
47        deferred_requests.extend(next_deferred_requests);
48        for request in deferred_requests {
49            let _ = self.self_sender.send(request).await;
50        }
51        self.broadcast_sender.send(Broadcast::Ready)?;
52        tokio::select!(
53            resp = unsolicited.run() => resp,
54            resp = self.run_internal(unsolicited_receiver, socket_handle) => resp,
55        )
56    }
57
58    async fn run_internal(
59        mut self,
60        mut unsolicited_receiver: EventReceiver,
61        mut socket_handle: SocketHandle<10240>,
62    ) -> Result {
63        // We will collect scan requests and batch respond to them when results are ready
64        let mut scan_requests = Vec::new();
65        let mut select_request = None;
66        loop {
67            enum EventOrRequest {
68                Event(Option<Event>),
69                Request(Option<Request>),
70            }
71
72            let event_or_request = tokio::select!(
73                unsolicited_msg = unsolicited_receiver.recv() => {
74                    EventOrRequest::Event(unsolicited_msg)
75                },
76                request = self.request_receiver.recv() => {
77                    EventOrRequest::Request(request)
78                },
79            );
80
81            match event_or_request {
82                EventOrRequest::Event(event) => match event {
83                    Some(unsolicited_msg) => {
84                        debug!("Unsolicited event: {unsolicited_msg:?}");
85                        Self::handle_event(
86                            &mut socket_handle,
87                            unsolicited_msg,
88                            &mut scan_requests,
89                            &mut select_request,
90                            &mut self.broadcast_sender,
91                        )
92                        .await?
93                    }
94                    None => return Err(error::Error::WifiStationEventChannelClosed),
95                },
96                EventOrRequest::Request(request) => match request {
97                    Some(Request::Shutdown) => return Ok(()),
98                    Some(request) => {
99                        self.handle_request(
100                            &mut socket_handle,
101                            request,
102                            &mut scan_requests,
103                            &mut select_request,
104                        )
105                        .await?;
106                    }
107                    None => return Err(error::Error::WifiStationRequestChannelClosed),
108                },
109            }
110        }
111    }
112
113    async fn handle_event<const N: usize>(
114        socket_handle: &mut SocketHandle<N>,
115        event: Event,
116        scan_requests: &mut Vec<oneshot::Sender<Result<Arc<Vec<ScanResult>>>>>,
117        select_request: &mut Option<SelectRequest>,
118        broadcast_sender: &mut broadcast::Sender<Broadcast>,
119    ) -> Result {
120        match event {
121            Event::ScanComplete => {
122                let _n = socket_handle.socket.send(b"SCAN_RESULTS").await?;
123                let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?;
124                let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?;
125                let mut scan_results = ScanResult::vec_from_str(data_str)?;
126                scan_results.sort_by(|a, b| a.signal.cmp(&b.signal));
127
128                let results = Arc::new(scan_results);
129                while let Some(scan_request) = scan_requests.pop() {
130                    if scan_request.send(Ok(results.clone())).is_err() {
131                        error!("Scan request response channel closed before response sent");
132                    }
133                }
134            }
135            Event::Connected => {
136                broadcast_sender.send(Broadcast::Connected)?;
137                if let Some(sender) = select_request.take() {
138                    sender.send(Ok(SelectResult::Success));
139                }
140            }
141            Event::Disconnected => {
142                broadcast_sender.send(Broadcast::Disconnected)?;
143            }
144            Event::NetworkNotFound => {
145                broadcast_sender.send(Broadcast::NetworkNotFound)?;
146                if let Some(sender) = select_request.take() {
147                    sender.send(Ok(SelectResult::NotFound));
148                }
149            }
150            Event::WrongPsk => {
151                broadcast_sender.send(Broadcast::WrongPsk)?;
152                if let Some(sender) = select_request.take() {
153                    sender.send(Ok(SelectResult::WrongPsk));
154                }
155            }
156            Event::Unknown(msg) => {
157                broadcast_sender.send(Broadcast::Unknown(msg))?;
158            }
159        }
160        Ok(())
161    }
162
163    async fn get_status<const N: usize>(socket_handle: &mut SocketHandle<N>) -> Result<Status> {
164        let _n = socket_handle.socket.send(b"STATUS").await?;
165        let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?;
166        let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end();
167        parse_status(data_str)
168    }
169
170    async fn handle_request<const N: usize>(
171        &self,
172        socket_handle: &mut SocketHandle<N>,
173        request: Request,
174        scan_requests: &mut Vec<oneshot::Sender<Result<Arc<Vec<ScanResult>>>>>,
175        select_request: &mut Option<SelectRequest>,
176    ) -> Result {
177        debug!("Handling request: {request:?}");
178        match request {
179            Request::Custom(custom, response_channel) => {
180                let _n = socket_handle.socket.send(custom.as_bytes()).await?;
181                let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?;
182                let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end();
183                debug!("Custom request response: {data_str}");
184                if response_channel.send(Ok(data_str.into())).is_err() {
185                    error!("Custom request response channel closed before response sent");
186                }
187            }
188            Request::SelectTimeout => {
189                if let Some(sender) = select_request.take() {
190                    sender.send(Ok(SelectResult::Timeout));
191                }
192            }
193            Request::Scan(response_channel) => {
194                scan_requests.push(response_channel);
195                if let Err(e) = socket_handle.command(b"SCAN").await {
196                    debug!("Error while requesting SCAN: {e}");
197                }
198            }
199            Request::Networks(response_channel) => {
200                let _n = socket_handle.socket.send(b"LIST_NETWORKS").await?;
201                let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?;
202                let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end();
203                let network_list =
204                    NetworkResult::vec_from_str(data_str, &mut socket_handle.socket).await?;
205                if response_channel.send(Ok(network_list)).is_err() {
206                    error!("Scan request response channel closed before response sent");
207                }
208            }
209            Request::Status(response_channel) => {
210                let status = Self::get_status(socket_handle).await;
211                if response_channel.send(status).is_err() {
212                    error!("Scan request response channel closed before response sent");
213                }
214            }
215            Request::AddNetwork(response_channel) => {
216                let _n = socket_handle.socket.send(b"ADD_NETWORK").await?;
217                let n = socket_handle.socket.recv(&mut socket_handle.buffer).await?;
218                let data_str = std::str::from_utf8(&socket_handle.buffer[..n])?.trim_end();
219                let network_id = usize::from_str(data_str)?;
220                if response_channel.send(Ok(network_id)).is_err() {
221                    error!("Scan request response channel closed before response sent");
222                } else {
223                    debug!("wpa_ctrl created network {network_id}");
224                }
225            }
226            Request::SetNetwork(id, param, response) => {
227                let cmd = format!(
228                    "SET_NETWORK {id} {}",
229                    match param {
230                        SetNetwork::Ssid(ssid) => format!("ssid \"{ssid}\""),
231                        SetNetwork::Bssid(bssid) => format!("bssid \"{bssid}\""),
232                        SetNetwork::Psk(psk) => format!("psk \"{psk}\""),
233                        SetNetwork::KeyMgmt(mgmt) => format!("key_mgmt {}", mgmt),
234                    }
235                );
236                debug!("wpa_ctrl \"{cmd}\"");
237                let bytes = cmd.into_bytes();
238                if let Err(e) = socket_handle.command(&bytes).await {
239                    warn!("Error while setting network parameter: {e}");
240                }
241                let _ = response.send(Ok(()));
242            }
243            Request::SaveConfig(response) => {
244                if let Err(e) = socket_handle.command(b"SAVE_CONFIG").await {
245                    warn!("Error while saving config: {e}");
246                }
247                debug!("wpa_ctrl config saved");
248                let _ = response.send(Ok(()));
249            }
250            Request::RemoveNetwork(id, response) => {
251                let cmd = format!("REMOVE_NETWORK {id}");
252                let bytes = cmd.into_bytes();
253                if let Err(e) = socket_handle.command(&bytes).await {
254                    warn!("Error while removing network {id}: {e}");
255                }
256                debug!("wpa_ctrl removed network {id}");
257                let _ = response.send(Ok(()));
258            }
259            Request::RemoveAllNetworks(response) => {
260                if let Err(e) = socket_handle.command(b"REMOVE_NETWORK all").await {
261                    warn!("Error while removing network all: {e}");
262                }
263                debug!("wpa_ctrl removed network all");
264                let _ = response.send(Ok(()));
265            }
266            Request::SelectNetwork(id, response_sender) => {
267                let response_sender = match select_request {
268                    None => {
269                        let cmd = format!("SELECT_NETWORK {id}");
270                        let bytes = cmd.into_bytes();
271                        if let Err(e) = socket_handle.command(&bytes).await {
272                            warn!("Error while selecting network {id}: {e}");
273                            let _ = response_sender.send(Ok(SelectResult::InvalidNetworkId));
274                            None
275                        } else {
276                            debug!("wpa_ctrl selected network {id}");
277                            let status = Self::get_status(socket_handle).await?;
278                            if let Some(current_id) = status.get("id") {
279                                if current_id == &id.to_string() {
280                                    let _ =
281                                        response_sender.send(Ok(SelectResult::AlreadyConnected));
282                                    None
283                                } else {
284                                    Some(response_sender)
285                                }
286                            } else {
287                                Some(response_sender)
288                            }
289                        }
290                    }
291                    Some(_) => {
292                        warn!("Select request already pending! Dropping this one.");
293                        let _ = response_sender.send(Ok(SelectResult::PendingSelect));
294                        debug!("wpa_ctrl removed network {id}");
295                        None
296                    }
297                };
298                if let Some(response_sender) = response_sender {
299                    *select_request = Some(SelectRequest::new(
300                        self.self_sender.clone(),
301                        response_sender,
302                        self.select_timeout,
303                    ));
304                }
305            }
306            Request::Shutdown => (), //shutdown is handled at the scope above
307        }
308        Ok(())
309    }
310}
311
/// A network selection that is still waiting for its outcome.
struct SelectRequest {
    /// One-shot channel on which the requester receives the selection result.
    response: oneshot::Sender<Result<SelectResult>>,
    /// Handle of the spawned timer task that posts `Request::SelectTimeout`; aborted once
    /// a result is sent.
    timeout: tokio::task::JoinHandle<()>,
}
316
317impl SelectRequest {
318    fn new(
319        sender: mpsc::Sender<Request>,
320        response: oneshot::Sender<Result<SelectResult>>,
321        timeout: Duration,
322    ) -> Self {
323        Self {
324            response,
325            timeout: tokio::task::spawn(async move {
326                tokio::time::sleep(timeout).await;
327                let _ = sender.send(Request::SelectTimeout).await;
328            }),
329        }
330    }
331
332    fn send(self, result: Result<SelectResult>) {
333        self.timeout.abort();
334        let _ = self.response.send(result);
335    }
336}