1use std::collections::HashMap;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12
13use crate::core::errors::{Result, ValknutError};
14
// Unit tests are compiled only for test builds and are loaded from `featureset_tests.rs`.
#[cfg(test)]
#[path = "featureset_tests.rs"]
mod tests;

/// Identifier of a [`CodeEntity`]; also the key type of
/// [`ExtractionContext::entity_index`].
pub type EntityId = String;
21
/// Static description of a single numeric feature: its identity, optional
/// bounds, fallback value and polarity.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FeatureDefinition {
    /// Feature name; definitions are looked up by this key.
    pub name: String,

    /// Free-form description of the feature.
    pub description: String,

    /// Name of the value's data type (`FeatureDefinition::new` sets `"f64"`).
    pub data_type: String,

    /// Optional inclusive lower bound; `None` means unbounded below.
    pub min_value: Option<f64>,

    /// Optional inclusive upper bound; `None` means unbounded above.
    pub max_value: Option<f64>,

    /// Value returned by `clamp_value` for NaN/infinite inputs, and by
    /// `BaseFeatureExtractor::safe_extract` when extraction fails.
    pub default_value: f64,

    /// Polarity flag: `true` when larger values are considered worse.
    pub higher_is_worse: bool,
}
46
47impl FeatureDefinition {
49 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
51 Self {
52 name: name.into(),
53 description: description.into(),
54 data_type: "f64".to_string(),
55 min_value: None,
56 max_value: None,
57 default_value: 0.0,
58 higher_is_worse: true,
59 }
60 }
61
62 pub fn with_range(mut self, min_value: f64, max_value: f64) -> Self {
64 self.min_value = Some(min_value);
65 self.max_value = Some(max_value);
66 self
67 }
68
69 pub fn with_default(mut self, default_value: f64) -> Self {
71 self.default_value = default_value;
72 self
73 }
74
75 pub fn with_polarity(mut self, higher_is_worse: bool) -> Self {
77 self.higher_is_worse = higher_is_worse;
78 self
79 }
80
81 pub fn is_valid_value(&self, value: f64) -> bool {
83 if value.is_nan() || value.is_infinite() {
84 return false;
85 }
86
87 if let Some(min) = self.min_value {
88 if value < min {
89 return false;
90 }
91 }
92
93 if let Some(max) = self.max_value {
94 if value > max {
95 return false;
96 }
97 }
98
99 true
100 }
101
102 pub fn clamp_value(&self, value: f64) -> f64 {
104 if value.is_nan() || value.is_infinite() {
105 return self.default_value;
106 }
107
108 let mut clamped = value;
109
110 if let Some(min) = self.min_value {
111 if clamped < min {
112 clamped = min;
113 }
114 }
115
116 if let Some(max) = self.max_value {
117 if clamped > max {
118 clamped = max;
119 }
120 }
121
122 clamped
123 }
124}
125
/// Feature values computed for one entity, plus metadata and refactoring
/// suggestions attached to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureVector {
    /// Entity these features describe.
    pub entity_id: EntityId,

    /// Feature values keyed by feature name (filled via `add_feature`).
    pub features: HashMap<String, f64>,

    /// Normalized feature values keyed by feature name. `new` leaves this empty
    /// and no method here fills it; presumably populated by a later
    /// normalization stage — TODO confirm.
    pub normalized_features: HashMap<String, f64>,

    /// Arbitrary JSON metadata keyed by name.
    pub metadata: HashMap<String, serde_json::Value>,

    /// Refactoring suggestions attached to this entity.
    pub refactoring_suggestions: Vec<RefactoringSuggestion>,
}
144
145impl FeatureVector {
147 pub fn new(entity_id: impl Into<EntityId>) -> Self {
149 Self {
150 entity_id: entity_id.into(),
151 features: HashMap::new(),
152 normalized_features: HashMap::new(),
153 metadata: HashMap::new(),
154 refactoring_suggestions: Vec::new(),
155 }
156 }
157
158 pub fn add_feature(&mut self, name: impl Into<String>, value: f64) -> &mut Self {
160 self.features.insert(name.into(), value);
161 self
162 }
163
164 pub fn get_feature(&self, name: &str) -> Option<f64> {
166 self.features.get(name).copied()
167 }
168
169 pub fn get_normalized_feature(&self, name: &str) -> Option<f64> {
171 self.normalized_features.get(name).copied()
172 }
173
174 pub fn add_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) -> &mut Self {
176 self.metadata.insert(key.into(), value);
177 self
178 }
179
180 pub fn add_suggestion(&mut self, suggestion: RefactoringSuggestion) -> &mut Self {
182 self.refactoring_suggestions.push(suggestion);
183 self
184 }
185
186 pub fn feature_count(&self) -> usize {
188 self.features.len()
189 }
190
191 pub fn has_feature(&self, name: &str) -> bool {
193 self.features.contains_key(name)
194 }
195
196 pub fn feature_names(&self) -> impl Iterator<Item = &String> {
198 self.features.keys()
199 }
200
201 pub fn l2_norm(&self) -> f64 {
203 self.features.values().map(|v| v * v).sum::<f64>().sqrt()
204 }
205
206 pub fn cosine_similarity(&self, other: &Self) -> f64 {
208 let mut dot_product = 0.0;
209 let mut norm_self_squared = 0.0;
210 let mut norm_other_squared = 0.0;
211
212 for (name, &value_a) in &self.features {
214 norm_self_squared += value_a * value_a;
215
216 if let Some(&value_b) = other.features.get(name) {
217 dot_product += value_a * value_b;
218 }
219 }
220
221 for &value_b in other.features.values() {
222 norm_other_squared += value_b * value_b;
223 }
224
225 let denominator = (norm_self_squared * norm_other_squared).sqrt();
226 if denominator == 0.0 {
227 0.0
228 } else {
229 dot_product / denominator
230 }
231 }
232}
233
/// A proposed refactoring, scored by priority and confidence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefactoringSuggestion {
    /// Kind of refactoring, as a free-form string.
    pub refactoring_type: String,

    /// Explanation of the suggested change.
    pub description: String,

    /// Priority score. `RefactoringSuggestion::new` clamps it into `[0.0, 1.0]`;
    /// the field is public, so direct assignment bypasses that.
    pub priority: f64,

    /// Confidence score. `RefactoringSuggestion::new` clamps it into `[0.0, 1.0]`;
    /// the field is public, so direct assignment bypasses that.
    pub confidence: f64,

    /// Optional location as arbitrary JSON (its schema is not defined here).
    pub location: Option<serde_json::Value>,

    /// Optional free-form context text.
    pub context: Option<String>,
}
255
256impl RefactoringSuggestion {
258 pub fn new(
260 refactoring_type: impl Into<String>,
261 description: impl Into<String>,
262 priority: f64,
263 confidence: f64,
264 ) -> Self {
265 Self {
266 refactoring_type: refactoring_type.into(),
267 description: description.into(),
268 priority: priority.clamp(0.0, 1.0),
269 confidence: confidence.clamp(0.0, 1.0),
270 location: None,
271 context: None,
272 }
273 }
274
275 pub fn with_location(mut self, location: serde_json::Value) -> Self {
277 self.location = Some(location);
278 self
279 }
280
281 pub fn with_context(mut self, context: impl Into<String>) -> Self {
283 self.context = Some(context.into());
284 self
285 }
286
287 pub fn is_high_priority(&self) -> bool {
289 self.priority >= 0.7
290 }
291
292 pub fn is_high_confidence(&self) -> bool {
294 self.confidence >= 0.8
295 }
296}
297
298#[async_trait]
303pub trait FeatureExtractor: Send + Sync {
304 fn name(&self) -> &str;
306
307 fn features(&self) -> &[FeatureDefinition];
309
310 async fn extract(
312 &self,
313 entity: &CodeEntity,
314 context: &ExtractionContext,
315 ) -> Result<HashMap<String, f64>>;
316
317 fn supports_entity(&self, entity: &CodeEntity) -> bool {
319 true
321 }
322
323 fn get_feature_definition(&self, name: &str) -> Option<&FeatureDefinition> {
325 self.features().iter().find(|f| f.name == name)
326 }
327
328 fn validate_features(&self, features: &HashMap<String, f64>) -> Result<()> {
330 for (name, &value) in features {
331 if let Some(definition) = self.get_feature_definition(name) {
332 if !definition.is_valid_value(value) {
333 return Err(ValknutError::validation(format!(
334 "Feature '{}' value {} is out of range",
335 name, value
336 )));
337 }
338 }
339 }
340 Ok(())
341 }
342}
343
/// A piece of source code that features are computed for.
#[derive(Debug, Clone, PartialEq)]
pub struct CodeEntity {
    /// Identifier; `ExtractionContext::add_entity` keys entities by it, so
    /// entities sharing an id overwrite one another there.
    pub id: EntityId,

    /// Kind of entity, as a free-form string (vocabulary not defined here).
    pub entity_type: String,

    /// Entity name.
    pub name: String,

    /// Path of the file containing the entity.
    pub file_path: String,

    /// Optional `(start, end)` line span; `line_count` derives a count from it.
    pub line_range: Option<(usize, usize)>,

    /// Source text of the entity; empty unless set via `with_source_code`.
    pub source_code: String,

    /// Arbitrary JSON properties keyed by name.
    pub properties: HashMap<String, serde_json::Value>,
}
369
370impl CodeEntity {
372 pub fn new(
374 id: impl Into<EntityId>,
375 entity_type: impl Into<String>,
376 name: impl Into<String>,
377 file_path: impl Into<String>,
378 ) -> Self {
379 Self {
380 id: id.into(),
381 entity_type: entity_type.into(),
382 name: name.into(),
383 file_path: file_path.into(),
384 line_range: None,
385 source_code: String::new(),
386 properties: HashMap::new(),
387 }
388 }
389
390 pub fn with_line_range(mut self, start: usize, end: usize) -> Self {
392 self.line_range = Some((start, end));
393 self
394 }
395
396 pub fn with_source_code(mut self, source_code: impl Into<String>) -> Self {
398 self.source_code = source_code.into();
399 self
400 }
401
402 pub fn add_property(&mut self, key: impl Into<String>, value: serde_json::Value) {
404 self.properties.insert(key.into(), value);
405 }
406
407 pub fn line_count(&self) -> usize {
409 if let Some((start, end)) = self.line_range {
410 (end - start).max(1)
411 } else {
412 self.source_code.lines().count()
413 }
414 }
415}
416
/// Shared inputs made available to feature extractors during extraction.
#[derive(Debug)]
pub struct ExtractionContext {
    /// Analysis configuration, shared via `Arc`.
    pub config: Arc<crate::core::config::ValknutConfig>,

    /// Known entities keyed by id (populated via `add_entity`).
    pub entity_index: HashMap<EntityId, CodeEntity>,

    /// Language identifier, as a string.
    pub language: String,

    /// Arbitrary JSON data keyed by name.
    pub context_data: HashMap<String, serde_json::Value>,

    /// Optional shared mapping from an entity id to a list of entity ids.
    /// Presumably candidate peers for comparison — TODO confirm semantics.
    pub candidate_partitions: Option<Arc<HashMap<EntityId, Vec<EntityId>>>>,
}
435
436impl ExtractionContext {
438 pub fn new(
440 config: Arc<crate::core::config::ValknutConfig>,
441 language: impl Into<String>,
442 ) -> Self {
443 Self {
444 config,
445 entity_index: HashMap::new(),
446 language: language.into(),
447 context_data: HashMap::new(),
448 candidate_partitions: None,
449 }
450 }
451
452 pub fn add_entity(&mut self, entity: CodeEntity) {
454 self.entity_index.insert(entity.id.clone(), entity);
455 }
456
457 pub fn get_entity(&self, id: &str) -> Option<&CodeEntity> {
459 self.entity_index.get(id)
460 }
461
462 pub fn add_context_data(&mut self, key: impl Into<String>, value: serde_json::Value) {
464 self.context_data.insert(key.into(), value);
465 }
466
467 pub fn with_candidate_partitions(
469 mut self,
470 partitions: Arc<HashMap<EntityId, Vec<EntityId>>>,
471 ) -> Self {
472 self.candidate_partitions = Some(partitions);
473 self
474 }
475}
476
477pub struct BaseFeatureExtractor {
479 name: String,
481
482 feature_definitions: Vec<FeatureDefinition>,
484}
485
486impl BaseFeatureExtractor {
488 pub fn new(name: impl Into<String>) -> Self {
490 Self {
491 name: name.into(),
492 feature_definitions: Vec::new(),
493 }
494 }
495
496 pub fn add_feature(&mut self, definition: FeatureDefinition) {
498 self.feature_definitions.push(definition);
499 }
500
501 pub fn safe_extract<F>(&self, feature_name: &str, extraction_func: F) -> f64
503 where
504 F: FnOnce() -> Result<f64>,
505 {
506 match extraction_func() {
507 Ok(value) => {
508 if let Some(definition) = self.get_feature_definition(feature_name) {
510 definition.clamp_value(value)
511 } else {
512 value
513 }
514 }
515 Err(_) => {
516 self.get_feature_definition(feature_name)
518 .map(|def| def.default_value)
519 .unwrap_or(0.0)
520 }
521 }
522 }
523}
524
525#[async_trait]
527impl FeatureExtractor for BaseFeatureExtractor {
528 fn name(&self) -> &str {
530 &self.name
531 }
532
533 fn features(&self) -> &[FeatureDefinition] {
535 &self.feature_definitions
536 }
537
538 async fn extract(
540 &self,
541 _entity: &CodeEntity,
542 _context: &ExtractionContext,
543 ) -> Result<HashMap<String, f64>> {
544 Ok(HashMap::new())
546 }
547}
548
549#[derive(Default)]
551pub struct FeatureExtractorRegistry {
552 extractors: HashMap<String, Arc<dyn FeatureExtractor>>,
554
555 feature_definitions: HashMap<String, FeatureDefinition>,
557}
558
559impl FeatureExtractorRegistry {
561 pub fn new() -> Self {
563 Self::default()
564 }
565
566 pub fn register(&mut self, extractor: Arc<dyn FeatureExtractor>) {
568 let name = extractor.name().to_string();
569
570 for feature_def in extractor.features() {
572 self.feature_definitions
573 .insert(feature_def.name.clone(), feature_def.clone());
574 }
575
576 self.extractors.insert(name, extractor);
577 }
578
579 pub fn get_extractor(&self, name: &str) -> Option<Arc<dyn FeatureExtractor>> {
581 self.extractors.get(name).cloned()
582 }
583
584 pub fn get_all_extractors(&self) -> impl Iterator<Item = &Arc<dyn FeatureExtractor>> {
586 self.extractors.values()
587 }
588
589 pub fn get_compatible_extractors(&self, entity: &CodeEntity) -> Vec<Arc<dyn FeatureExtractor>> {
591 self.extractors
592 .values()
593 .filter(|extractor| extractor.supports_entity(entity))
594 .cloned()
595 .collect()
596 }
597
598 pub fn get_feature_definition(&self, name: &str) -> Option<&FeatureDefinition> {
600 self.feature_definitions.get(name)
601 }
602
603 pub fn get_all_feature_definitions(&self) -> impl Iterator<Item = &FeatureDefinition> {
605 self.feature_definitions.values()
606 }
607
608 pub async fn extract_all_features(
610 &self,
611 entity: &CodeEntity,
612 context: &ExtractionContext,
613 ) -> Result<FeatureVector> {
614 let mut feature_vector = FeatureVector::new(entity.id.clone());
615
616 let extractors = self.get_compatible_extractors(entity);
618
619 for extractor in extractors {
621 match extractor.extract(entity, context).await {
622 Ok(features) => {
623 for (name, value) in features {
624 feature_vector.add_feature(name, value);
625 }
626 }
627 Err(e) => {
628 tracing::warn!(
630 "Feature extraction failed for extractor '{}' on entity '{}': {}",
631 extractor.name(),
632 entity.id,
633 e
634 );
635 }
636 }
637 }
638
639 Ok(feature_vector)
640 }
641}