1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::fs;
4use regex::Regex;
5use clang::{Clang, Index};
6use crate::debug_println;
7
8use super::annotations::{FunctionSignature, extract_annotations};
9use super::safety_annotations::{SafetyMode, parse_entity_safety};
10use super::external_annotations::ExternalAnnotations;
11
/// Cache of data collected from parsed C++ headers: function signatures,
/// per-function safety modes, and external annotations.
#[derive(Debug)]
pub struct HeaderCache {
    /// Function signatures keyed by the qualified name of the declaration
    /// they were extracted from.
    signatures: HashMap<String, FunctionSignature>,
    /// Safety mode per function. Keys are qualified names with template
    /// arguments stripped, or bare names for header-text annotations that
    /// matched no recorded qualified name.
    pub safety_annotations: HashMap<String, SafetyMode>,
    /// Headers already parsed; consulted to skip repeats and to stop include
    /// cycles.
    processed_headers: Vec<PathBuf>,
    /// Directories searched when resolving `#include` directives; also passed
    /// to clang as `-I` flags.
    include_paths: Vec<PathBuf>,
    /// External annotations parsed from the text of each processed header.
    pub external_annotations: ExternalAnnotations,
}
26
/// Returns `name` cut off at its first `<`, dropping template arguments.
///
/// Everything from the first `<` onward is discarded, including any `::`
/// segments after it: `rusty::Option<T>` becomes `rusty::Option`, and
/// `Option<T>::Option` also becomes `Option`. Names containing no `<` are
/// returned unchanged.
///
/// NOTE(review): because trailing segments are discarded, a member such as
/// `Foo<T>::bar` would collapse to `Foo`; confirm that is the intended key.
fn strip_template_params(name: &str) -> String {
    // `split` always yields at least one piece, so the fallback never fires.
    let before_template = name.split('<').next().unwrap_or(name);
    before_template.to_owned()
}
35
36impl HeaderCache {
37 pub fn new() -> Self {
38 Self {
39 signatures: HashMap::new(),
40 safety_annotations: HashMap::new(),
41 processed_headers: Vec::new(),
42 include_paths: Vec::new(),
43 external_annotations: ExternalAnnotations::new(),
44 }
45 }
46
47 pub fn set_include_paths(&mut self, paths: Vec<PathBuf>) {
49 self.include_paths = paths;
50 }
51
52 pub fn get_signature(&self, func_name: &str) -> Option<&FunctionSignature> {
54 self.signatures.get(func_name)
55 }
56
57 pub fn get_safety_annotation(&self, func_name: &str) -> Option<SafetyMode> {
59 self.safety_annotations.get(func_name).copied()
60 }
61
62 pub fn parse_header(&mut self, header_path: &Path) -> Result<(), String> {
64 debug_println!("DEBUG HEADER: Parsing header file: {}", header_path.display());
65
66 if self.processed_headers.iter().any(|p| p == header_path) {
68 debug_println!("DEBUG HEADER: Already processed, skipping");
69 return Ok(());
70 }
71
72 let mut unqualified_annotations = HashMap::new();
76 if let Ok(header_safety_context) = super::safety_annotations::parse_safety_annotations(header_path) {
77 for (func_sig, safety_mode) in &header_safety_context.function_overrides {
79 debug_println!("DEBUG HEADER: Found unqualified annotation for '{}': {:?}", func_sig.name, safety_mode);
80 unqualified_annotations.insert(func_sig.name.clone(), *safety_mode);
81 }
82 debug_println!("DEBUG HEADER: Parsed {} unqualified safety annotations from header file", header_safety_context.function_overrides.len());
83 }
84
85 if let Ok(content) = fs::read_to_string(header_path) {
87 if let Err(e) = self.external_annotations.parse_content(&content) {
90 debug_println!("DEBUG HEADER: Failed to parse external annotations: {}", e);
91 } else {
92 debug_println!("DEBUG HEADER: Parsed external annotations from header");
93 }
94 }
95
96 let clang = Clang::new()
98 .map_err(|e| format!("Failed to initialize Clang: {:?}", e))?;
99 let index = Index::new(&clang, false, false);
100
101 let mut args = vec![
103 "-std=c++17".to_string(),
104 "-xc++".to_string(),
105 "-fparse-all-comments".to_string(), ];
107 for include_path in &self.include_paths {
108 args.push(format!("-I{}", include_path.display()));
109 }
110
111 let tu = index
113 .parser(header_path)
114 .arguments(&args.iter().map(|s| s.as_str()).collect::<Vec<_>>())
115 .parse()
116 .map_err(|e| format!("Failed to parse header {}: {:?}", header_path.display(), e))?;
117
118 let root = tu.get_entity();
120 self.visit_entity_for_signatures(&root);
121
122 let mut simple_to_qualified: HashMap<String, Vec<String>> = HashMap::new();
125 for qualified_name in self.safety_annotations.keys() {
126 if let Some(simple_name) = qualified_name.split("::").last() {
128 simple_to_qualified
129 .entry(simple_name.to_string())
130 .or_insert_with(Vec::new)
131 .push(qualified_name.clone());
132 }
133 }
134
135 debug_println!("DEBUG HEADER: Qualifying {} unqualified annotations", unqualified_annotations.len());
137 for (simple_name, safety_mode) in &unqualified_annotations {
138 debug_println!("DEBUG HEADER: Processing unqualified '{}': {:?}", simple_name, safety_mode);
139 if let Some(qualified_names) = simple_to_qualified.get(simple_name) {
141 for qualified in qualified_names {
143 debug_println!("DEBUG HEADER: Qualifying '{}' -> '{}': {:?}",
144 simple_name, qualified, safety_mode);
145 self.safety_annotations.insert(qualified.clone(), *safety_mode);
147 }
148 } else {
149 debug_println!("DEBUG HEADER: Adding plain function annotation for '{}': {:?}",
152 simple_name, safety_mode);
153 self.safety_annotations.insert(simple_name.clone(), *safety_mode);
154 }
155 }
156
157 debug_println!("DEBUG HEADER: Found {} safety annotations in header (after qualification)", self.safety_annotations.len());
158 for (name, mode) in &self.safety_annotations {
159 debug_println!("DEBUG HEADER: - {} : {:?}", name, mode);
160 }
161
162 self.processed_headers.push(header_path.to_path_buf());
164
165 if let Ok(content) = fs::read_to_string(header_path) {
167 let (quoted_includes, angle_includes) = extract_includes(&content);
168
169 for include_path in quoted_includes {
171 if let Some(resolved) = self.resolve_include(&include_path, header_path, true) {
172 let _ = self.parse_header(&resolved);
174 }
175 }
176
177 for include_path in angle_includes {
179 if let Some(resolved) = self.resolve_include(&include_path, header_path, false) {
180 let _ = self.parse_header(&resolved);
182 }
183 }
184 }
185
186 Ok(())
187 }
188
189 pub fn parse_includes_from_source(&mut self, cpp_file: &Path) -> Result<(), String> {
191 let content = fs::read_to_string(cpp_file)
192 .map_err(|e| format!("Failed to read {}: {}", cpp_file.display(), e))?;
193
194 let (quoted_includes, angle_includes) = extract_includes(&content);
195
196 for include_path in quoted_includes {
198 if let Some(resolved) = self.resolve_include(&include_path, cpp_file, true) {
199 self.parse_header(&resolved)?;
200 }
201 }
202
203 for include_path in angle_includes {
205 if let Some(resolved) = self.resolve_include(&include_path, cpp_file, false) {
206 self.parse_header(&resolved)?;
207 }
208 }
209
210 Ok(())
211 }
212
213 fn resolve_include(&self, include_path: &str, source_file: &Path, search_source_dir: bool) -> Option<PathBuf> {
215 if search_source_dir {
217 if let Some(parent) = source_file.parent() {
218 let local_path = parent.join(include_path);
219 if local_path.exists() {
220 return Some(local_path);
221 }
222 }
223 }
224
225 for include_dir in &self.include_paths {
227 let full_path = include_dir.join(include_path);
228 if full_path.exists() {
229 return Some(full_path);
230 }
231 }
232
233 let path = PathBuf::from(include_path);
235 if path.exists() {
236 return Some(path);
237 }
238
239 None
240 }
241
242 fn visit_entity_for_signatures(&mut self, entity: &clang::Entity) {
243 self.visit_entity_with_context(entity, None, None);
244 }
245
246 fn visit_entity_with_context(
249 &mut self,
250 entity: &clang::Entity,
251 namespace_safety: Option<SafetyMode>,
252 class_safety: Option<SafetyMode>,
253 ) {
254 use clang::EntityKind;
255
256 let mut current_namespace_safety = namespace_safety;
258 let mut current_class_safety = class_safety;
259
260 if entity.get_kind() == EntityKind::Namespace {
262 if let Some(safety) = parse_entity_safety(entity) {
263 current_namespace_safety = Some(safety);
264 if let Some(name) = entity.get_name() {
265 debug_println!("DEBUG SAFETY: Found namespace '{}' with {:?} annotation", name, safety);
266 }
267 } else {
268 current_namespace_safety = None;
272 debug_println!("DEBUG SAFETY: Entering namespace {:?} without annotation - resetting namespace safety",
273 entity.get_name());
274 }
275 }
276
277 if entity.get_kind() == EntityKind::ClassDecl || entity.get_kind() == EntityKind::StructDecl {
279 if let Some(safety) = parse_entity_safety(entity) {
280 current_class_safety = Some(safety);
281 if let Some(name) = entity.get_name() {
282 debug_println!("DEBUG SAFETY: Found class '{}' with {:?} annotation in header", name, safety);
283 }
284 } else if current_namespace_safety.is_some() {
285 current_class_safety = None;
288 }
289 }
290
291 match entity.get_kind() {
292 EntityKind::FunctionDecl | EntityKind::Method | EntityKind::Constructor | EntityKind::FunctionTemplate => {
293
294 if let Some(mut sig) = extract_annotations(entity) {
296 let qualified_name = crate::parser::ast_visitor::get_qualified_name(entity);
299
300 sig.name = qualified_name.clone();
302 self.signatures.insert(qualified_name, sig);
303 }
304
305 let mut safety = parse_entity_safety(entity);
307
308 if safety.is_none() {
311 if current_class_safety.is_some() {
312 safety = current_class_safety;
313 debug_println!("DEBUG SAFETY: Method inheriting {:?} from class", safety);
314 } else {
315 safety = current_namespace_safety;
316 if safety.is_some() {
317 debug_println!("DEBUG SAFETY: Function inheriting {:?} from namespace", safety);
318 }
319 }
320 }
321
322 if let Some(safety_mode) = safety {
323 let raw_name = crate::parser::ast_visitor::get_qualified_name(entity);
326
327 let name = strip_template_params(&raw_name);
330
331 self.safety_annotations.insert(name.clone(), safety_mode);
332 debug_println!("DEBUG SAFETY: Found function '{}' with {:?} annotation in header", name, safety_mode);
333 }
334 }
335 _ => {}
336 }
337
338 let child_class_safety = if entity.get_kind() == EntityKind::ClassDecl || entity.get_kind() == EntityKind::StructDecl {
342 current_class_safety
343 } else {
344 class_safety };
346
347 for child in entity.get_children() {
348 self.visit_entity_with_context(&child, current_namespace_safety, child_class_safety);
349 }
350 }
351
352 pub fn has_signatures(&self) -> bool {
354 !self.signatures.is_empty()
355 }
356}
357
358fn extract_includes(content: &str) -> (Vec<String>, Vec<String>) {
360 let mut quoted_includes = Vec::new();
361 let mut angle_includes = Vec::new();
362
363 let quoted_re = Regex::new(r#"#include\s*"([^"]+)""#).unwrap();
365 for cap in quoted_re.captures_iter(content) {
366 if let Some(path) = cap.get(1) {
367 quoted_includes.push(path.as_str().to_string());
368 }
369 }
370
371 let angle_re = Regex::new(r#"#include\s*<([^>]+)>"#).unwrap();
373 for cap in angle_re.captures_iter(content) {
374 if let Some(path) = cap.get(1) {
375 angle_includes.push(path.as_str().to_string());
376 }
377 }
378
379 (quoted_includes, angle_includes)
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `strip_template_params` against each `(input, expected)` pair.
    fn assert_stripped(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(strip_template_params(input), expected, "input: {}", input);
        }
    }

    #[test]
    fn test_extract_includes() {
        let content = r#"
#include "user.h"
#include "data.h"
#include <iostream>
#include <vector>
#include "utils/helper.h"
    "#;

        let (quoted, angle) = extract_includes(content);
        assert_eq!(quoted, vec!["user.h", "data.h", "utils/helper.h"]);
        assert_eq!(angle, vec!["iostream", "vector"]);
    }

    #[test]
    fn test_strip_template_params_simple() {
        assert_stripped(&[
            ("Option<T>", "Option"),
            ("Vector<int>", "Vector"),
            ("Map<K, V>", "Map"),
        ]);
    }

    #[test]
    fn test_strip_template_params_nested() {
        assert_stripped(&[
            ("Option<Vector<int>>", "Option"),
            ("Map<string, Vector<int>>", "Map"),
        ]);
    }

    #[test]
    fn test_strip_template_params_qualified() {
        assert_stripped(&[
            ("rusty::Option<T>", "rusty::Option"),
            ("std::vector<int>", "std::vector"),
            ("ns::inner::Class<T, U>", "ns::inner::Class"),
        ]);
    }

    #[test]
    fn test_strip_template_params_no_template() {
        assert_stripped(&[
            ("Option", "Option"),
            ("rusty::Option", "rusty::Option"),
            ("some_function", "some_function"),
        ]);
    }

    #[test]
    fn test_strip_template_params_constructor() {
        // Everything from the first `<` is dropped, including `::Option`.
        assert_stripped(&[("Option<T>::Option", "Option")]);
    }
}