1use std::net::{IpAddr, Ipv4Addr};
2
3use netlink_packet_core::{
4 NLM_F_ACK, NLM_F_CREATE, NLM_F_DUMP, NLM_F_EXCL, NLM_F_REQUEST, NetlinkMessage, NetlinkPayload,
5};
6use netlink_packet_route::address::{AddressAttribute, AddressMessage};
7use netlink_packet_route::{AddressFamily, RouteNetlinkMessage};
8
9use crate::network::traits::{Client, NetlinkMessageHandler};
10use crate::network::wrapper::ClientWrapper;
11use crate::network::{NetlinkResponse, NetworkError, Result};
12
/// Netlink payload handler that turns `RouteNetlinkMessage::NewAddress` replies into
/// [`AddressMessage`]s, optionally keeping only those that belong to one interface index.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AddressMessageHandler {
    /// When `Some(i)`, only addresses whose `header.index == i` are reported; `None` accepts all.
    target_index: Option<u32>,
}
20
21impl AddressMessageHandler {
22 pub fn new() -> Self {
23 Self { target_index: None }
24 }
25
26 pub fn with_index(index: u32) -> Self {
27 Self {
28 target_index: Some(index),
29 }
30 }
31}
32
33impl Default for AddressMessageHandler {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl NetlinkMessageHandler for AddressMessageHandler {
40 type Response = AddressMessage;
41
42 fn handle_payload(
43 &self,
44 payload: NetlinkPayload<RouteNetlinkMessage>,
45 ) -> Result<NetlinkResponse<Self::Response>> {
46 match payload {
47 NetlinkPayload::InnerMessage(RouteNetlinkMessage::NewAddress(addr)) => {
48 if let Some(target_index) = self.target_index {
49 if addr.header.index == target_index {
50 Ok(NetlinkResponse::Success(addr))
51 } else {
52 Ok(NetlinkResponse::None)
53 }
54 } else {
55 Ok(NetlinkResponse::Success(addr))
56 }
57 }
58 NetlinkPayload::Error(e) => match e.code {
59 None => Ok(NetlinkResponse::Success(AddressMessage::default())),
60 Some(code) => Ok(NetlinkResponse::Error(code.get())),
61 },
62 NetlinkPayload::Done(_) => Ok(NetlinkResponse::Done),
63 _ => Err(NetworkError::IO(std::io::Error::other(format!(
64 "Unexpected message type: {:?}",
65 payload
66 )))),
67 }
68 }
69}
70
/// Issues address-related netlink requests (dump addresses of an interface, add an address)
/// through a [`ClientWrapper`].
pub struct AddressClient {
    /// Underlying netlink transport; tests substitute the `ClientWrapper::Fake` variant.
    client: ClientWrapper,
}
78
79impl AddressClient {
80 pub fn new(client: ClientWrapper) -> Result<Self> {
86 Ok(Self { client })
87 }
88
89 pub fn get_by_index(&mut self, index: u32) -> Result<Vec<AddressMessage>> {
99 let mut message = AddressMessage::default();
100 message.header.index = index;
101 let mut req = NetlinkMessage::from(RouteNetlinkMessage::GetAddress(message));
102 req.header.flags = NLM_F_REQUEST | NLM_F_DUMP;
105 req.finalize();
106
107 let handler = AddressMessageHandler::with_index(index);
108
109 self.client.send_and_receive_multiple(&req, handler)
110 }
111
112 pub fn add(&mut self, index: u32, address: IpAddr, prefix_len: u8) -> Result<()> {
124 let message = self.create_address_request(index, address, prefix_len)?;
125
126 let mut req = NetlinkMessage::from(RouteNetlinkMessage::NewAddress(message));
127 req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
132 req.finalize();
133
134 let handler = AddressMessageHandler::new();
135
136 self.client.send_and_receive(&req, handler)?;
137 Ok(())
138 }
139
140 fn create_address_request(
153 &self,
154 index: u32,
155 address: IpAddr,
156 prefix_len: u8,
157 ) -> Result<AddressMessage> {
158 let mut message = AddressMessage::default();
159 message.header.prefix_len = prefix_len;
160 message.header.index = index;
161 message.header.family = match address {
162 IpAddr::V4(_) => AddressFamily::Inet,
163 IpAddr::V6(_) => AddressFamily::Inet6,
164 };
165
166 if address.is_multicast() {
167 if let IpAddr::V6(a) = address {
168 message.attributes.push(AddressAttribute::Multicast(a));
169 }
170 } else {
171 message.attributes.push(AddressAttribute::Address(address));
172 message.attributes.push(AddressAttribute::Local(address));
173
174 if let IpAddr::V4(a) = address {
175 if prefix_len == 32 {
176 message.attributes.push(AddressAttribute::Broadcast(a));
177 } else {
178 let ip_addr = u32::from(a);
179 let brd =
180 Ipv4Addr::from(((0xffff_ffff_u32) >> u32::from(prefix_len)) | ip_addr);
181 message.attributes.push(AddressAttribute::Broadcast(brd));
182 };
183 }
184 }
185
186 Ok(message)
187 }
188
189 #[cfg(test)]
190 pub fn get_send_calls(
192 &self,
193 ) -> Option<&[netlink_packet_core::NetlinkMessage<RouteNetlinkMessage>]> {
194 if let ClientWrapper::Fake(fake_client) = &self.client {
195 Some(fake_client.get_send_calls())
196 } else {
197 None
198 }
199 }
200}
201
#[cfg(test)]
mod tests {
    use serial_test::serial;

    use super::*;
    use crate::network::fake::FakeNetlinkClient;
    use crate::network::wrapper::create_network_client;

    /// Returns the first `Address` attribute carried by `msg`, if any.
    fn address_attr(msg: &AddressMessage) -> Option<IpAddr> {
        msg.attributes.iter().find_map(|attr| match attr {
            AddressAttribute::Address(ip) => Some(*ip),
            _ => None,
        })
    }

    /// Returns the first `Broadcast` attribute carried by `msg`, if any.
    fn broadcast_attr(msg: &AddressMessage) -> Option<Ipv4Addr> {
        msg.attributes.iter().find_map(|attr| match attr {
            AddressAttribute::Broadcast(ip) => Some(*ip),
            _ => None,
        })
    }

    /// Maps an IP address to the netlink address family the client is expected to use.
    fn family_of(address: IpAddr) -> AddressFamily {
        match address {
            IpAddr::V4(_) => AddressFamily::Inet,
            IpAddr::V6(_) => AddressFamily::Inet6,
        }
    }

    /// Builds an `AddressMessage` for interface `index` carrying a single IPv4 `Address`.
    fn new_address_msg(index: u32, ip: Ipv4Addr) -> AddressMessage {
        let mut msg = AddressMessage::default();
        msg.header.index = index;
        msg.attributes.push(AddressAttribute::Address(IpAddr::V4(ip)));
        msg
    }

    #[test]
    #[serial]
    fn test_address_message_handler_success() {
        let handler = AddressMessageHandler::new();
        let addr_msg = new_address_msg(1, Ipv4Addr::new(192, 168, 1, 1));
        let payload = NetlinkPayload::InnerMessage(RouteNetlinkMessage::NewAddress(addr_msg));

        match handler.handle_payload(payload) {
            Ok(NetlinkResponse::Success(response)) => {
                assert_eq!(response.header.index, 1);
                assert_eq!(response.attributes.len(), 1);
            }
            _ => panic!("Expected Success response"),
        }
    }

    #[test]
    #[serial]
    fn test_address_message_handler_matching_index() {
        let handler = AddressMessageHandler::with_index(7);
        let payload = NetlinkPayload::InnerMessage(RouteNetlinkMessage::NewAddress(
            new_address_msg(7, Ipv4Addr::new(10, 0, 0, 7)),
        ));

        match handler.handle_payload(payload) {
            Ok(NetlinkResponse::Success(response)) => assert_eq!(response.header.index, 7),
            _ => panic!("Expected Success response for a matching index"),
        }
    }

    #[test]
    #[serial]
    fn test_address_message_handler_filters_other_index() {
        let handler = AddressMessageHandler::with_index(7);
        let payload = NetlinkPayload::InnerMessage(RouteNetlinkMessage::NewAddress(
            new_address_msg(1, Ipv4Addr::new(10, 0, 0, 1)),
        ));

        match handler.handle_payload(payload) {
            Ok(NetlinkResponse::None) => {}
            _ => panic!("Expected None response for a non-matching index"),
        }
    }

    #[test]
    #[serial]
    fn test_address_message_handler_errorcode_zero() {
        // An error message without a (non-zero) code is an acknowledgement.
        let handler = AddressMessageHandler::new();
        let mut error_msg = netlink_packet_core::ErrorMessage::default();
        error_msg.code = std::num::NonZeroI32::new(0);

        match handler.handle_payload(NetlinkPayload::Error(error_msg)) {
            Ok(NetlinkResponse::Success(_)) => {}
            _ => panic!("Expected Success response"),
        }
    }

    #[test]
    #[serial]
    fn test_address_message_handler_error() {
        let handler = AddressMessageHandler::new();
        let mut error_msg = netlink_packet_core::ErrorMessage::default();
        error_msg.code = std::num::NonZeroI32::new(1);

        match handler.handle_payload(NetlinkPayload::Error(error_msg)) {
            Ok(NetlinkResponse::Error(code)) => assert_eq!(code, 1),
            _ => panic!("Expected Error response"),
        }
    }

    #[test]
    #[serial]
    fn test_address_message_handler_done() {
        let handler = AddressMessageHandler::new();
        let done_payload = NetlinkPayload::Done(netlink_packet_core::DoneMessage::default());

        match handler.handle_payload(done_payload) {
            Ok(NetlinkResponse::Done) => {}
            _ => panic!("Expected Done response"),
        }
    }

    #[test]
    #[serial]
    fn test_address_message_handler_unexpected() {
        let handler = AddressMessageHandler::new();
        let unexpected_payload = NetlinkPayload::InnerMessage(RouteNetlinkMessage::NewLink(
            netlink_packet_route::link::LinkMessage::default(),
        ));

        assert!(handler.handle_payload(unexpected_payload).is_err());
    }

    #[test]
    #[serial]
    fn test_address_client_new() {
        assert!(AddressClient::new(create_network_client()).is_ok());
    }

    #[test]
    #[serial]
    fn test_address_client_get_by_index_failure() {
        let mut fake_client = FakeNetlinkClient::new();
        fake_client.set_failure("Get by index failed".to_string());

        let mut addr_client = AddressClient::new(ClientWrapper::Fake(fake_client)).unwrap();

        assert!(addr_client.get_by_index(1).is_err());
    }

    #[test]
    #[serial]
    fn test_address_client_get_by_index_without_response() {
        let fake_client = FakeNetlinkClient::new();
        let mut addr_client = AddressClient::new(ClientWrapper::Fake(fake_client)).unwrap();

        assert!(addr_client.get_by_index(1).is_ok());
    }

    #[test]
    #[serial]
    fn test_address_client_get_by_index_with_multiple_responses() {
        let mut fake_client = FakeNetlinkClient::new();
        fake_client.set_expected_responses(vec![
            RouteNetlinkMessage::NewAddress(new_address_msg(1, Ipv4Addr::new(192, 168, 1, 1))),
            RouteNetlinkMessage::NewAddress(new_address_msg(1, Ipv4Addr::new(192, 168, 1, 2))),
        ]);

        let mut addr_client = AddressClient::new(ClientWrapper::Fake(fake_client)).unwrap();
        let responses = addr_client
            .get_by_index(1)
            .expect("get_by_index should succeed");

        assert_eq!(responses.len(), 2);
        assert!(responses.iter().all(|r| r.header.index == 1));
    }

    #[test]
    #[serial]
    fn test_address_client_get_by_index_success() {
        let mut fake_client = FakeNetlinkClient::new();
        fake_client.set_expected_responses(vec![RouteNetlinkMessage::NewAddress(
            AddressMessage::default(),
        )]);

        let mut addr_client = AddressClient::new(ClientWrapper::Fake(fake_client)).unwrap();
        assert!(addr_client.get_by_index(42).is_ok());

        let Some(send_calls) = addr_client.get_send_calls() else {
            panic!("Expected Fake client");
        };
        assert_eq!(send_calls.len(), 1);

        let NetlinkPayload::InnerMessage(RouteNetlinkMessage::GetAddress(addr)) =
            &send_calls[0].payload
        else {
            panic!("Expected GetAddress message");
        };
        assert_eq!(addr.header.index, 42);
        assert_eq!(send_calls[0].header.flags, NLM_F_REQUEST | NLM_F_DUMP);
    }

    #[test]
    #[serial]
    fn test_address_client_add_failure() {
        let mut fake_client = FakeNetlinkClient::new();
        fake_client.set_failure("Add address failed".to_string());

        let mut addr_client = AddressClient::new(ClientWrapper::Fake(fake_client)).unwrap();
        let result = addr_client.add(1, IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 24);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn test_address_client_add_success() {
        let mut fake_client = FakeNetlinkClient::new();
        fake_client.set_expected_responses(vec![RouteNetlinkMessage::NewAddress(
            AddressMessage::default(),
        )]);

        let mut addr_client = AddressClient::new(ClientWrapper::Fake(fake_client)).unwrap();
        let address = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        assert!(addr_client.add(42, address, 16).is_ok());

        let Some(send_calls) = addr_client.get_send_calls() else {
            panic!("Expected Fake client");
        };
        assert_eq!(send_calls.len(), 1);

        let NetlinkPayload::InnerMessage(RouteNetlinkMessage::NewAddress(addr)) =
            &send_calls[0].payload
        else {
            panic!("Expected NewAddress message");
        };
        assert_eq!(addr.header.index, 42);
        assert_eq!(addr.header.prefix_len, 16);
        assert_eq!(addr.header.family, AddressFamily::Inet);
        // Address, Local and Broadcast.
        assert_eq!(addr.attributes.len(), 3);
        assert_eq!(
            address_attr(addr),
            Some(address),
            "Address attribute not found"
        );
        assert_eq!(broadcast_attr(addr), Some(Ipv4Addr::new(10, 0, 255, 255)));

        assert_eq!(
            send_calls[0].header.flags,
            NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE
        );
    }

    #[test]
    #[serial]
    fn test_address_client_add_with_different_parameters() {
        let test_cases: [(u32, IpAddr, u8); 3] = [
            (1, IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 24),
            (10, IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 16),
            (
                100,
                IpAddr::V6(std::net::Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)),
                64,
            ),
        ];

        let mut fake_client = FakeNetlinkClient::new();
        let responses: Vec<_> = test_cases
            .iter()
            .map(|&(index, address, prefix_len)| {
                let mut msg = AddressMessage::default();
                msg.header.index = index;
                msg.header.prefix_len = prefix_len;
                msg.header.family = family_of(address);
                msg.attributes.push(AddressAttribute::Address(address));
                msg.attributes.push(AddressAttribute::Local(address));
                if let IpAddr::V4(a) = address {
                    msg.attributes.push(AddressAttribute::Broadcast(a));
                }
                RouteNetlinkMessage::NewAddress(msg)
            })
            .collect();
        fake_client.set_expected_responses(responses);

        let mut addr_client = AddressClient::new(ClientWrapper::Fake(fake_client)).unwrap();

        for &(index, address, prefix_len) in &test_cases {
            let result = addr_client.add(index, address, prefix_len);
            assert!(
                result.is_ok(),
                "add failed for index {}, address {:?}, prefix_len {}",
                index,
                address,
                prefix_len
            );
        }

        let Some(send_calls) = addr_client.get_send_calls() else {
            panic!("Expected Fake client");
        };
        assert_eq!(send_calls.len(), test_cases.len());

        for (call, &(index, address, prefix_len)) in send_calls.iter().zip(&test_cases) {
            let NetlinkPayload::InnerMessage(RouteNetlinkMessage::NewAddress(addr)) =
                &call.payload
            else {
                panic!("Expected NewAddress message for index {}", index);
            };
            assert_eq!(addr.header.index, index);
            assert_eq!(addr.header.prefix_len, prefix_len);
            assert_eq!(addr.header.family, family_of(address));
            assert_eq!(
                address_attr(addr),
                Some(address),
                "Address attribute not found for index {}",
                index
            );
        }
    }

    #[test]
    #[serial]
    fn test_create_address_request_ipv4() {
        let addr_client = AddressClient::new(create_network_client()).unwrap();
        let message = addr_client
            .create_address_request(1, IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 24)
            .expect("valid IPv4 request");

        assert_eq!(message.header.index, 1);
        assert_eq!(message.header.prefix_len, 24);
        assert_eq!(message.header.family, AddressFamily::Inet);
        // Address, Local and Broadcast.
        assert_eq!(message.attributes.len(), 3);
        assert_eq!(
            broadcast_attr(&message),
            Some(Ipv4Addr::new(192, 168, 1, 255))
        );
    }

    #[test]
    #[serial]
    fn test_create_address_request_ipv4_host_prefix() {
        // A /32 has no host bits, so the broadcast address is the address itself.
        let addr_client = AddressClient::new(create_network_client()).unwrap();
        let ip = Ipv4Addr::new(10, 0, 0, 1);
        let message = addr_client
            .create_address_request(1, IpAddr::V4(ip), 32)
            .expect("valid /32 request");

        assert_eq!(message.header.prefix_len, 32);
        assert_eq!(broadcast_attr(&message), Some(ip));
    }

    #[test]
    #[serial]
    fn test_create_address_request_ipv6() {
        let addr_client = AddressClient::new(create_network_client()).unwrap();
        let message = addr_client
            .create_address_request(
                1,
                IpAddr::V6(std::net::Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)),
                64,
            )
            .expect("valid IPv6 request");

        assert_eq!(message.header.index, 1);
        assert_eq!(message.header.prefix_len, 64);
        assert_eq!(message.header.family, AddressFamily::Inet6);
        // Address and Local only; IPv6 has no broadcast.
        assert_eq!(message.attributes.len(), 2);
        assert_eq!(broadcast_attr(&message), None);
    }
}