1use serde::{Deserialize, Serialize};
52use std::collections::HashSet;
53use std::sync::Arc;
54
55use parking_lot::RwLock;
56
57use super::node_identity::{NodeId, NodeIdentity};
58use super::rejection::{KeyspaceRegion, TargetRegion};
59use crate::Result;
60
/// Tuning knobs for [`IdentityTargeter`]'s candidate search and scoring.
///
/// The three `*_weight` fields scale the score components computed for each
/// candidate identity; the default weights sum to 1.0.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TargetingConfig {
    /// Maximum number of candidate identities generated during the scored
    /// search of one [`IdentityTargeter::generate_targeted_identity`] call.
    pub max_generation_attempts: u32,

    /// XOR bit distance from the nearest rejected prefix at which a candidate
    /// earns the full avoidance score; smaller distances score proportionally.
    pub min_distance_from_rejected: u8,

    /// Weight of the score component derived from recently rejected node IDs.
    pub distance_weight: f64,

    /// Weight of the score component derived from the target keyspace region.
    pub target_weight: f64,

    /// Weight of the score component derived from rejected prefixes.
    pub avoidance_weight: f64,

    /// Number of candidates generated per search round before the
    /// "good enough" early-exit check runs. Should be at least 1.
    pub candidates_per_round: usize,
}
82
83impl Default for TargetingConfig {
84 fn default() -> Self {
85 Self {
86 max_generation_attempts: 100,
87 min_distance_from_rejected: 4,
88 distance_weight: 0.4,
89 target_weight: 0.3,
90 avoidance_weight: 0.3,
91 candidates_per_round: 10,
92 }
93 }
94}
95
96struct IdentityCandidate {
98 identity: NodeIdentity,
100
101 score: f64,
103}
104
/// A keyspace prefix that candidate node IDs should avoid: the first
/// `prefix_len` bits of `prefix`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct RejectedPrefix {
    /// Prefix bytes, most significant bit first. Bits beyond `prefix_len`
    /// are not compared.
    pub prefix: Vec<u8>,

    /// Number of leading bits of `prefix` that are significant.
    pub prefix_len: u8,
}
114
115impl RejectedPrefix {
116 #[must_use]
118 pub fn new(prefix: Vec<u8>, prefix_len: u8) -> Self {
119 Self { prefix, prefix_len }
120 }
121
122 #[must_use]
124 pub fn matches(&self, node_id: &NodeId) -> bool {
125 let node_bytes = node_id.to_bytes();
126 let full_bytes = self.prefix_len as usize / 8;
127 let remaining_bits = self.prefix_len as usize % 8;
128
129 let prefix_slice = &self.prefix[..full_bytes.min(self.prefix.len())];
131 let node_slice = &node_bytes[..full_bytes.min(node_bytes.len())];
132 for (p, n) in prefix_slice.iter().zip(node_slice.iter()) {
133 if p != n {
134 return false;
135 }
136 }
137
138 if remaining_bits > 0 && full_bytes < self.prefix.len() && full_bytes < node_bytes.len() {
140 let mask = 0xFF << (8 - remaining_bits);
141 if (self.prefix[full_bytes] & mask) != (node_bytes[full_bytes] & mask) {
142 return false;
143 }
144 }
145
146 true
147 }
148
149 #[must_use]
151 pub fn xor_distance_bits(&self, node_id: &NodeId) -> u32 {
152 let node_bytes = node_id.to_bytes();
153 let mut distance = 0u32;
154
155 let full_bytes = self.prefix_len as usize / 8;
156 let remaining_bits = self.prefix_len as usize % 8;
157
158 let prefix_slice = &self.prefix[..full_bytes.min(self.prefix.len())];
160 let node_slice = &node_bytes[..full_bytes.min(node_bytes.len())];
161 for (p, n) in prefix_slice.iter().zip(node_slice.iter()) {
162 let xor = p ^ n;
163 distance += xor.count_ones();
164 }
165
166 if remaining_bits > 0 && full_bytes < self.prefix.len() && full_bytes < node_bytes.len() {
168 let mask = 0xFF << (8 - remaining_bits);
169 let xor = (self.prefix[full_bytes] ^ node_bytes[full_bytes]) & mask;
170 distance += xor.count_ones();
171 }
172
173 distance
174 }
175}
176
177#[derive(Default)]
179struct TargeterState {
180 rejected_prefixes: HashSet<RejectedPrefix>,
182
183 rejected_node_ids: Vec<NodeId>,
185
186 last_target: Option<TargetRegion>,
188
189 total_attempts: u64,
191
192 successful_generations: u64,
194}
195
196pub struct IdentityTargeter {
198 config: TargetingConfig,
200
201 state: RwLock<TargeterState>,
203}
204
205impl IdentityTargeter {
206 #[must_use]
208 pub fn new(config: TargetingConfig) -> Self {
209 Self {
210 config,
211 state: RwLock::new(TargeterState::default()),
212 }
213 }
214
215 pub fn add_rejected_prefix(&self, prefix: Vec<u8>, prefix_len: u8) {
217 let rejected = RejectedPrefix::new(prefix, prefix_len);
218 self.state.write().rejected_prefixes.insert(rejected);
219 }
220
221 pub fn add_rejected_prefixes(&self, prefixes: &[Vec<u8>], prefix_len: u8) {
223 let mut state = self.state.write();
224 for prefix in prefixes {
225 let rejected = RejectedPrefix::new(prefix.clone(), prefix_len);
226 state.rejected_prefixes.insert(rejected);
227 }
228 }
229
230 pub fn record_rejected_node_id(&self, node_id: NodeId) {
232 let mut state = self.state.write();
233 state.rejected_node_ids.push(node_id);
234
235 if state.rejected_node_ids.len() > 100 {
237 state.rejected_node_ids.remove(0);
238 }
239 }
240
241 pub fn set_target(&self, target: Option<TargetRegion>) {
243 self.state.write().last_target = target;
244 }
245
246 pub fn clear_rejected(&self) {
248 let mut state = self.state.write();
249 state.rejected_prefixes.clear();
250 state.rejected_node_ids.clear();
251 }
252
253 pub fn generate_targeted_identity(
258 &self,
259 suggested_target: Option<&TargetRegion>,
260 ) -> Result<NodeIdentity> {
261 let mut state = self.state.write();
262 state.total_attempts += 1;
263
264 if let Some(target) = suggested_target {
266 state.last_target = Some(target.clone());
267 }
268
269 let target = state.last_target.clone();
270 let rejected_prefixes: Vec<_> = state.rejected_prefixes.iter().cloned().collect();
271 let rejected_ids: Vec<_> = state.rejected_node_ids.clone();
272
273 drop(state); let mut best_candidate: Option<IdentityCandidate> = None;
276 let mut attempts = 0u32;
277
278 while attempts < self.config.max_generation_attempts {
279 let candidates: Vec<_> = (0..self.config.candidates_per_round)
281 .filter_map(|_| {
282 attempts += 1;
283 if attempts > self.config.max_generation_attempts {
284 return None;
285 }
286 NodeIdentity::generate().ok()
287 })
288 .collect();
289
290 for identity in candidates {
292 let node_id = identity.node_id();
293
294 let matches_rejected = rejected_prefixes.iter().any(|p| p.matches(node_id));
296 if matches_rejected {
297 continue;
298 }
299
300 let score = self.score_candidate(
302 node_id,
303 target.as_ref(),
304 &rejected_prefixes,
305 &rejected_ids,
306 );
307
308 match &best_candidate {
310 None => {
311 best_candidate = Some(IdentityCandidate { identity, score });
312 }
313 Some(best) if score > best.score => {
314 best_candidate = Some(IdentityCandidate { identity, score });
315 }
316 _ => {}
317 }
318
319 if score > 0.9 {
321 break;
322 }
323 }
324
325 if let Some(ref best) = best_candidate
327 && best.score > 0.7
328 {
329 break;
330 }
331 }
332
333 let identity = match best_candidate {
335 Some(candidate) => {
336 let mut state = self.state.write();
337 state.successful_generations += 1;
338 candidate.identity
339 }
340 None => {
341 NodeIdentity::generate()?
343 }
344 };
345
346 Ok(identity)
347 }
348
349 fn score_candidate(
351 &self,
352 node_id: &NodeId,
353 target: Option<&TargetRegion>,
354 rejected_prefixes: &[RejectedPrefix],
355 rejected_ids: &[NodeId],
356 ) -> f64 {
357 let mut score = 0.0;
358
359 let avoidance_score = if rejected_prefixes.is_empty() {
361 1.0
362 } else {
363 let min_distance = rejected_prefixes
364 .iter()
365 .map(|p| p.xor_distance_bits(node_id))
366 .min()
367 .unwrap_or(u32::MAX);
368
369 let threshold = u32::from(self.config.min_distance_from_rejected);
371 (min_distance as f64 / threshold as f64).min(1.0)
372 };
373 score += avoidance_score * self.config.avoidance_weight;
374
375 let target_score = if let Some(target) = target {
377 if target.region.contains(node_id) {
378 target.confidence
379 } else {
380 let distance = self.xor_distance_to_region(node_id, &target.region);
382 (1.0 - (distance as f64 / 256.0)).max(0.0) * target.confidence
383 }
384 } else {
385 0.5 };
387 score += target_score * self.config.target_weight;
388
389 let id_distance_score = if rejected_ids.is_empty() {
391 1.0
392 } else {
393 let min_distance = rejected_ids
394 .iter()
395 .map(|id| self.leading_zero_distance(node_id, id))
396 .min()
397 .unwrap_or(256);
398
399 (min_distance as f64 / 32.0).min(1.0)
401 };
402 score += id_distance_score * self.config.distance_weight;
403
404 score
405 }
406
407 fn xor_distance_to_region(&self, node_id: &NodeId, region: &KeyspaceRegion) -> u32 {
409 let node_bytes = node_id.to_bytes();
410 let mut distance = 0u32;
411
412 let full_bytes = region.prefix_len as usize / 8;
413 let remaining_bits = region.prefix_len as usize % 8;
414
415 let prefix_slice = ®ion.prefix[..full_bytes.min(region.prefix.len())];
417 let node_slice = &node_bytes[..full_bytes.min(node_bytes.len())];
418 for (p, n) in prefix_slice.iter().zip(node_slice.iter()) {
419 let xor = p ^ n;
420 distance += xor.count_ones();
421 }
422
423 if remaining_bits > 0 && full_bytes < region.prefix.len() && full_bytes < node_bytes.len() {
425 let mask = 0xFF << (8 - remaining_bits);
426 let xor = (region.prefix[full_bytes] ^ node_bytes[full_bytes]) & mask;
427 distance += xor.count_ones();
428 }
429
430 distance
431 }
432
433 fn leading_zero_distance(&self, a: &NodeId, b: &NodeId) -> u32 {
435 let distance = a.xor_distance(b);
436
437 let mut leading_zeros = 0u32;
438 for byte in &distance {
439 if *byte == 0 {
440 leading_zeros += 8;
441 } else {
442 leading_zeros += byte.leading_zeros();
443 break;
444 }
445 }
446
447 leading_zeros
448 }
449
450 #[must_use]
452 pub fn stats(&self) -> TargetingStats {
453 let state = self.state.read();
454 TargetingStats {
455 total_attempts: state.total_attempts,
456 successful_generations: state.successful_generations,
457 rejected_prefix_count: state.rejected_prefixes.len(),
458 rejected_id_count: state.rejected_node_ids.len(),
459 }
460 }
461
462 #[must_use]
464 pub fn has_rejected_prefixes(&self) -> bool {
465 !self.state.read().rejected_prefixes.is_empty()
466 }
467
468 #[must_use]
470 pub fn rejected_prefix_count(&self) -> usize {
471 self.state.read().rejected_prefixes.len()
472 }
473}
474
/// Point-in-time counters reported by [`IdentityTargeter::stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TargetingStats {
    /// Number of `generate_targeted_identity` calls made.
    pub total_attempts: u64,

    /// Calls that returned a scored candidate instead of the fallback identity.
    pub successful_generations: u64,

    /// Distinct rejected prefixes currently stored.
    pub rejected_prefix_count: usize,

    /// Rejected node IDs currently remembered.
    pub rejected_id_count: usize,
}
490
491impl TargetingStats {
492 #[must_use]
494 pub fn success_rate(&self) -> f64 {
495 if self.total_attempts == 0 {
496 1.0
497 } else {
498 self.successful_generations as f64 / self.total_attempts as f64
499 }
500 }
501}
502
/// An [`IdentityTargeter`] behind an [`Arc`], for handing one targeter to
/// several owners.
pub type SharedIdentityTargeter = Arc<IdentityTargeter>;
505
506pub struct IdentityTargeterBuilder {
508 config: TargetingConfig,
509 initial_rejected: Vec<RejectedPrefix>,
510}
511
512impl IdentityTargeterBuilder {
513 #[must_use]
515 pub fn new() -> Self {
516 Self {
517 config: TargetingConfig::default(),
518 initial_rejected: Vec::new(),
519 }
520 }
521
522 #[must_use]
524 pub fn max_attempts(mut self, max: u32) -> Self {
525 self.config.max_generation_attempts = max;
526 self
527 }
528
529 #[must_use]
531 pub fn min_distance_from_rejected(mut self, bits: u8) -> Self {
532 self.config.min_distance_from_rejected = bits;
533 self
534 }
535
536 #[must_use]
538 pub fn weights(mut self, distance: f64, target: f64, avoidance: f64) -> Self {
539 self.config.distance_weight = distance;
540 self.config.target_weight = target;
541 self.config.avoidance_weight = avoidance;
542 self
543 }
544
545 #[must_use]
547 pub fn reject_prefix(mut self, prefix: Vec<u8>, prefix_len: u8) -> Self {
548 self.initial_rejected
549 .push(RejectedPrefix::new(prefix, prefix_len));
550 self
551 }
552
553 #[must_use]
555 pub fn build(self) -> IdentityTargeter {
556 let targeter = IdentityTargeter::new(self.config);
557
558 for rejected in self.initial_rejected {
560 targeter.state.write().rejected_prefixes.insert(rejected);
561 }
562
563 targeter
564 }
565}
566
567impl Default for IdentityTargeterBuilder {
568 fn default() -> Self {
569 Self::new()
570 }
571}
572
#[cfg(test)]
mod tests {
    use super::*;

    fn test_node_id() -> NodeId {
        NodeId([0x42; 32])
    }

    #[test]
    fn test_rejected_prefix_matches() {
        let prefix = RejectedPrefix::new(vec![0xAB], 8);

        let matching_id = NodeId([0xAB; 32]);
        assert!(prefix.matches(&matching_id));

        let non_matching_id = NodeId([0x12; 32]);
        assert!(!prefix.matches(&non_matching_id));
    }

    #[test]
    fn test_rejected_prefix_partial_byte() {
        let prefix = RejectedPrefix::new(vec![0xF0], 4);

        let matching_id = NodeId([0xFF; 32]);
        assert!(prefix.matches(&matching_id));

        let non_matching_id = NodeId([0x00; 32]);
        assert!(!prefix.matches(&non_matching_id));
    }

    #[test]
    fn test_rejected_prefix_xor_distance() {
        let prefix = RejectedPrefix::new(vec![0xFF], 8);

        let same = NodeId([0xFF; 32]);
        assert_eq!(prefix.xor_distance_bits(&same), 0);

        let opposite = NodeId([0x00; 32]);
        assert_eq!(prefix.xor_distance_bits(&opposite), 8);

        let half = NodeId([0xF0; 32]);
        assert_eq!(prefix.xor_distance_bits(&half), 4);
    }

    #[test]
    fn test_rejected_prefix_matches_iff_zero_distance() {
        let prefixes = [
            RejectedPrefix::new(vec![0xAB], 8),
            RejectedPrefix::new(vec![0xF0], 4),
            RejectedPrefix::new(vec![0xAB, 0xC0], 12),
            RejectedPrefix::new(vec![], 0),
        ];
        for byte in [0x00u8, 0x42, 0xAB, 0xF0, 0xFF] {
            let id = NodeId([byte; 32]);
            for prefix in &prefixes {
                assert_eq!(prefix.matches(&id), prefix.xor_distance_bits(&id) == 0);
            }
        }
    }

    #[test]
    fn test_empty_prefix_matches_everything() {
        let prefix = RejectedPrefix::new(vec![], 0);
        assert!(prefix.matches(&test_node_id()));
        assert_eq!(prefix.xor_distance_bits(&test_node_id()), 0);
    }

    #[test]
    fn test_identity_targeter_creation() {
        let targeter = IdentityTargeter::new(TargetingConfig::default());

        assert!(!targeter.has_rejected_prefixes());
        assert_eq!(targeter.rejected_prefix_count(), 0);
    }

    #[test]
    fn test_add_rejected_prefix() {
        let targeter = IdentityTargeter::new(TargetingConfig::default());

        targeter.add_rejected_prefix(vec![0xAB], 8);

        assert!(targeter.has_rejected_prefixes());
        assert_eq!(targeter.rejected_prefix_count(), 1);
    }

    #[test]
    fn test_add_rejected_prefixes_deduplicates() {
        let targeter = IdentityTargeter::new(TargetingConfig::default());

        targeter.add_rejected_prefixes(&[vec![0xAB], vec![0xAB], vec![0xCD]], 8);

        assert_eq!(targeter.rejected_prefix_count(), 2);
    }

    #[test]
    fn test_generate_targeted_identity() {
        let targeter = IdentityTargeter::new(TargetingConfig::default());

        let identity = targeter.generate_targeted_identity(None);
        assert!(identity.is_ok());
    }

    #[test]
    fn test_generate_targeted_identity_with_rejected() {
        let config = TargetingConfig {
            max_generation_attempts: 50,
            ..TargetingConfig::default()
        };
        let targeter = IdentityTargeter::new(config);

        // Reject IDs whose first nibble is 0x0 or 0x1.
        targeter.add_rejected_prefix(vec![0x00], 4);
        targeter.add_rejected_prefix(vec![0x10], 4);

        let identity = targeter
            .generate_targeted_identity(None)
            .expect("identity generation should succeed");
        let id_bytes = identity.node_id().to_bytes();

        // Candidates matching a rejected prefix are discarded; the fallback is
        // reached only if all 50 random candidates match (probability ~8^-50).
        let first_nibble = id_bytes[0] >> 4;
        assert!(
            first_nibble >= 2,
            "generated ID starts with rejected nibble {first_nibble:#x}"
        );
    }

    #[test]
    fn test_targeting_stats() {
        let targeter = IdentityTargeter::new(TargetingConfig::default());

        let stats = targeter.stats();
        assert_eq!(stats.total_attempts, 0);
        assert_eq!(stats.successful_generations, 0);

        for _ in 0..3 {
            let _ = targeter.generate_targeted_identity(None);
        }

        let stats = targeter.stats();
        assert_eq!(stats.total_attempts, 3);
        assert!(stats.successful_generations <= 3);
    }

    #[test]
    fn test_record_rejected_node_id() {
        let targeter = IdentityTargeter::new(TargetingConfig::default());

        targeter.record_rejected_node_id(test_node_id());

        let stats = targeter.stats();
        assert_eq!(stats.rejected_id_count, 1);
    }

    #[test]
    fn test_record_rejected_node_id_is_bounded() {
        let targeter = IdentityTargeter::new(TargetingConfig::default());

        for i in 0..150u8 {
            targeter.record_rejected_node_id(NodeId([i; 32]));
        }

        assert_eq!(targeter.stats().rejected_id_count, 100);
        // The oldest 50 entries were evicted.
        let state = targeter.state.read();
        let oldest = state.rejected_node_ids.first().map(|id| id.to_bytes()[0]);
        assert_eq!(oldest, Some(50));
    }

    #[test]
    fn test_clear_rejected() {
        let targeter = IdentityTargeter::new(TargetingConfig::default());

        targeter.add_rejected_prefix(vec![0xAB], 8);
        targeter.record_rejected_node_id(test_node_id());

        assert!(targeter.has_rejected_prefixes());

        targeter.clear_rejected();

        assert!(!targeter.has_rejected_prefixes());
        assert_eq!(targeter.stats().rejected_id_count, 0);
    }

    #[test]
    fn test_builder() {
        let targeter = IdentityTargeterBuilder::new()
            .max_attempts(50)
            .min_distance_from_rejected(6)
            .weights(0.5, 0.3, 0.2)
            .reject_prefix(vec![0xAB], 8)
            .build();

        assert!(targeter.has_rejected_prefixes());
        assert_eq!(targeter.rejected_prefix_count(), 1);
        assert_eq!(targeter.config.max_generation_attempts, 50);
        assert_eq!(targeter.config.min_distance_from_rejected, 6);
        assert!((targeter.config.distance_weight - 0.5).abs() < f64::EPSILON);
        assert!((targeter.config.target_weight - 0.3).abs() < f64::EPSILON);
        assert!((targeter.config.avoidance_weight - 0.2).abs() < f64::EPSILON);
    }

    #[test]
    fn test_targeting_stats_success_rate() {
        let stats = TargetingStats {
            total_attempts: 10,
            successful_generations: 8,
            rejected_prefix_count: 2,
            rejected_id_count: 5,
        };

        assert!((stats.success_rate() - 0.8).abs() < f64::EPSILON);
    }

    #[test]
    fn test_targeting_stats_zero_attempts() {
        let stats = TargetingStats {
            total_attempts: 0,
            successful_generations: 0,
            rejected_prefix_count: 0,
            rejected_id_count: 0,
        };

        assert!((stats.success_rate() - 1.0).abs() < f64::EPSILON);
    }
}