// rust_network_mgr/network.rs

use std::collections::HashMap;
use std::net::IpAddr;

use futures::stream::{StreamExt, TryStreamExt};
// Message types come straight from the netlink packet crates.
use netlink_packet_core::{NetlinkMessage, NetlinkPayload};
use netlink_packet_route::{
    address::{AddressAttribute, AddressMessage},
    link::{LinkAttribute, LinkMessage},
    RouteNetlinkMessage,
};
// `rtnetlink::sys` re-exports `netlink_sys`; its socket types are needed to join
// the multicast groups that carry link/address change notifications.
use rtnetlink::{
    new_connection,
    sys::{AsyncSocket, SocketAddr},
};

use crate::types::{AppError, NetworkEvent, NetworkEventSender, Result};
20/// Monitors network interface and address changes using rtnetlink.
21pub struct NetworkMonitor {
22    event_sender: NetworkEventSender,
23    // Store interface index to name mapping for easier lookup
24    if_index_to_name: HashMap<u32, String>,
25    // Store current IPs per interface index
26    current_ips: HashMap<u32, Vec<IpAddr>>,
27}
29impl NetworkMonitor {
30    pub fn new(event_sender: NetworkEventSender) -> Self {
31        NetworkMonitor {
32            event_sender,
33            if_index_to_name: HashMap::new(),
34            current_ips: HashMap::new(),
35        }
36    }
37
38    /// Starts the monitoring loop.
39    /// This function will run indefinitely until an error occurs or the stream ends.
40    pub async fn start(mut self) -> Result<()> {
41        tracing::info!("Starting network monitor...");
42
43        // Use new_connection for simple setup
44        let (connection, handle, mut messages) = new_connection().map_err(|e| {
45             AppError::Init(format!("Failed to create netlink connection: {}", e))
46        })?;
47        tokio::spawn(connection);
48        
49        tracing::info!("Listening for netlink address and link events...");
50
51        // --- Initial State Population ---
52        tracing::debug!("Gathering initial network state...");
53
54        // 1. Get Interfaces to map index to name
55        let mut links = handle.link().get().execute();
56        while let Some(link) = links.try_next().await? {
57            let mut name = None;
58            for nla in link.attributes.iter() {
59                if let netlink_packet_route::link::LinkAttribute::IfName(if_name) = nla {
60                    name = Some(if_name.clone());
61                    break;
62                }
63            }
64            if let Some(name) = name {
65                tracing::debug!("Found interface: index={}, name={}", link.header.index, name);
66                self.if_index_to_name.insert(link.header.index, name);
67            }
68        }
69        tracing::debug!("Interface map populated: {:?}", self.if_index_to_name);
70
71        // 2. Get Addresses for initial state
72        let mut addresses = handle.address().get().execute();
73        while let Some(msg) = addresses.try_next().await? {
74            let if_index = msg.header.index;
75            if let Some(if_name) = self.if_index_to_name.get(&if_index) {
76                for nla in msg.attributes.iter() {
77                    if let netlink_packet_route::address::AddressAttribute::Address(ip_addr) = nla {
78                        let ip = ip_addr;
79                        
80                        tracing::info!(
81                            "Initial state: Found IP {} for interface {} ({})",
82                            ip, if_name, if_index
83                        );
84                        let ips = self.current_ips.entry(if_index).or_default();
85                        if !ips.contains(&ip) {
86                            ips.push(*ip);
87                            // Optionally send initial state as events
88                            // self.send_event(NetworkEvent::IpAdded(if_name.clone(), *ip)).await?;
89                        }
90                    }
91                }
92            }
93        }
94         tracing::debug!("Initial IP state populated: {:?}", self.current_ips);
95
96        // --- Listen for Events ---
97        loop {
98            match messages.next().await {
99                Some((message, _addr)) => {
100                    if let Err(e) = self.handle_netlink_message(message).await {
101                        tracing::error!("Error handling netlink message: {}", e);
102                    }
103                }
104                None => {
105                     tracing::warn!("Netlink message stream ended unexpectedly.");
106                     break;
107                }
108            }
109        }
110
111        Ok(())
112    }
113
114    async fn handle_netlink_message(&mut self, message: NetlinkMessage<RouteNetlinkMessage>) -> Result<()> {
115         match message.payload {
116            NetlinkPayload::InnerMessage(RouteNetlinkMessage::NewAddress(msg)) => {
117                self.handle_address_change(msg, true).await?;
118            }
119            NetlinkPayload::InnerMessage(RouteNetlinkMessage::DelAddress(msg)) => {
120                self.handle_address_change(msg, false).await?;
121            }
122             NetlinkPayload::InnerMessage(RouteNetlinkMessage::NewLink(msg)) => {
123                self.handle_link_change(msg, true).await?;
124            }
125            NetlinkPayload::InnerMessage(RouteNetlinkMessage::DelLink(msg)) => {
126                self.handle_link_change(msg, false).await?;
127            }
128            NetlinkPayload::Error(err) => {
129                tracing::error!("Received netlink error message: {:?}", err);
130            }
131            _ => {
132                // tracing::trace!("Ignoring other netlink message type: {:?}", message.payload);
133            }
134         }
135         Ok(())
136    }
137
138    async fn handle_address_change(&mut self, msg: AddressMessage, is_add: bool) -> Result<()> {
139        let if_index = msg.header.index;
140
141        if let Some(if_name) = self.if_index_to_name.get(&if_index).cloned() {
142            for nla in msg.attributes.iter() {
143                if let netlink_packet_route::address::AddressAttribute::Address(ip_addr) = nla {
144                    let ip = ip_addr;
145
146                    if is_add {
147                        tracing::info!("Detected IP Added: {} on {}", ip, if_name);
148                        let ips = self.current_ips.entry(if_index).or_default();
149                        if !ips.contains(&ip) {
150                             ips.push(*ip);
151                             self.send_event(NetworkEvent::IpAdded(if_name.clone(), *ip)).await?;
152                        }
153                    } else {
154                        tracing::info!("Detected IP Removed: {} from {}", ip, if_name);
155                        if let Some(ips) = self.current_ips.get_mut(&if_index) {
156                             if let Some(pos) = ips.iter().position(|&x| x == *ip) {
157                                ips.remove(pos);
158                                self.send_event(NetworkEvent::IpRemoved(if_name.clone(), *ip)).await?;
159                             }
160                        }
161                    }
162                }
163            }
164        } else {
165            tracing::warn!("Received address event for unknown interface index: {}", if_index);
166        }
167        Ok(())
168    }
169
170     async fn handle_link_change(&mut self, msg: LinkMessage, is_add: bool) -> Result<()> {
171        let if_index = msg.header.index;
172
173        if is_add {
174            let mut name = None;
175            for nla in msg.attributes.iter() {
176                if let netlink_packet_route::link::LinkAttribute::IfName(if_name) = nla {
177                    name = Some(if_name.clone());
178                    break;
179                }
180            }
181            if let Some(name) = name {
182                 tracing::info!("Detected Interface Added/Updated: index={}, name={}", if_index, name);
183                 self.if_index_to_name.insert(if_index, name);
184            }
185        } else {
186            if let Some(removed_name) = self.if_index_to_name.remove(&if_index) {
187                 tracing::info!("Detected Interface Removed: index={}, name={}", if_index, removed_name);
188                 if let Some(removed_ips) = self.current_ips.remove(&if_index) {
189                     for ip in removed_ips {
190                          self.send_event(NetworkEvent::IpRemoved(removed_name.clone(), ip)).await?;
191                     }
192                 }
193            } else {
194                tracing::debug!("Ignoring DelLink for unknown index: {}", if_index);
195            }
196        }
197         Ok(())
198     }
199
200    async fn send_event(&self, event: NetworkEvent) -> Result<()> {
201        self.event_sender
202            .send(event.clone())
203            .await
204            .map_err(|e| AppError::ChannelSend(format!("Failed to send network event {:?}: {}", event, e)))
205    }
206}
// Note: exercising this module against real rtnetlink still requires specific setup
// (e.g. a dedicated network namespace) or root privileges.