scirs2_text/information_extraction/
document.rs1use super::entities::{Entity, EntityCluster, EntityType};
4use super::pipeline::AdvancedExtractionPipeline;
5use super::relations::{Event, Relation};
6use crate::error::Result;
7use std::collections::HashMap;
8
/// Per-document summary produced by
/// [`DocumentInformationExtractor::extract_structured_information`].
#[derive(Debug, Clone, PartialEq)]
pub struct DocumentSummary {
    /// Position of the document in the slice handed to the extractor.
    pub document_index: usize,
    /// Number of entities the pipeline reported for this document.
    pub entity_count: usize,
    /// Number of relations the pipeline reported for this document.
    pub relation_count: usize,
    /// Key phrases reported by the pipeline, each paired with its score.
    pub key_phrases: Vec<(String, f64)>,
    /// Mean confidence of the document's entities (`0.0` when there are none).
    pub confidence_score: f64,
}
23
/// A topic shared by several documents, as produced by
/// [`DocumentInformationExtractor::identify_topics`].
#[derive(Debug, Clone, PartialEq)]
pub struct Topic {
    /// Display name of the topic (the key phrase it was derived from).
    pub name: String,
    /// Key phrases that make up the topic.
    pub key_phrases: Vec<String>,
    /// Indices of the documents in which the topic was found.
    pub document_indices: Vec<usize>,
    /// Confidence assigned to the topic.
    pub confidence: f64,
}
36
37#[derive(Debug)]
39pub struct StructuredDocumentInformation {
40 pub documents: Vec<DocumentSummary>,
42 pub entity_clusters: Vec<EntityCluster>,
44 pub relations: Vec<Relation>,
46 pub events: Vec<Event>,
48 pub topics: Vec<Topic>,
50 pub total_entities: usize,
52 pub total_relations: usize,
54}
55
/// Combines per-document extraction results into corpus-level structure
/// (entity clusters, events and topics).
#[derive(Debug, Clone)]
pub struct DocumentInformationExtractor {
    /// A key phrase only counts towards a topic when its score is strictly
    /// greater than this value.
    topic_threshold: f64,
    /// Two entities are clustered together when their similarity is strictly
    /// greater than this value.
    similarity_threshold: f64,
    /// Upper bound on the number of topics returned.
    max_topics: usize,
}
62
63impl Default for DocumentInformationExtractor {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69impl DocumentInformationExtractor {
70 pub fn new() -> Self {
72 Self {
73 topic_threshold: 0.3,
74 similarity_threshold: 0.7,
75 max_topics: 10,
76 }
77 }
78
79 pub fn extract_structured_information(
81 &self,
82 documents: &[String],
83 pipeline: &AdvancedExtractionPipeline,
84 ) -> Result<StructuredDocumentInformation> {
85 let mut all_entities = Vec::new();
86 let mut all_relations = Vec::new();
87 let mut document_summaries = Vec::new();
88
89 for (doc_idx, document) in documents.iter().enumerate() {
91 let info = pipeline.extract_advanced(document)?;
92
93 let mut doc_entities = info.entities;
95 for entity in &mut doc_entities {
96 entity.confidence *= 0.9; }
98
99 let doc_summary = DocumentSummary {
100 document_index: doc_idx,
101 entity_count: doc_entities.len(),
102 relation_count: info.relations.len(),
103 key_phrases: info.key_phrases.clone(),
104 confidence_score: self.calculate_document_confidence(&doc_entities),
105 };
106
107 all_entities.extend(doc_entities);
108 all_relations.extend(info.relations);
109 document_summaries.push(doc_summary);
110 }
111
112 let entity_clusters = self.cluster_entities(&all_entities)?;
114
115 let events = self.extract_events(&all_relations, &all_entities)?;
117
118 let topics = self.identify_topics(&document_summaries)?;
120
121 let total_relations = all_relations.len();
122 Ok(StructuredDocumentInformation {
123 documents: document_summaries,
124 entity_clusters,
125 relations: all_relations,
126 events,
127 topics,
128 total_entities: all_entities.len(),
129 total_relations,
130 })
131 }
132
133 fn calculate_document_confidence(&self, entities: &[Entity]) -> f64 {
135 if entities.is_empty() {
136 return 0.0;
137 }
138
139 let sum: f64 = entities.iter().map(|e| e.confidence).sum();
140 sum / entities.len() as f64
141 }
142
143 pub fn cluster_entities(&self, entities: &[Entity]) -> Result<Vec<EntityCluster>> {
145 let mut clusters = Vec::new();
146 let mut used = vec![false; entities.len()];
147
148 for (i, entity) in entities.iter().enumerate() {
149 if used[i] {
150 continue;
151 }
152
153 let mut cluster = EntityCluster {
154 representative: entity.clone(),
155 members: vec![entity.clone()],
156 entity_type: entity.entity_type.clone(),
157 confidence: entity.confidence,
158 };
159
160 used[i] = true;
161
162 for (j, other) in entities.iter().enumerate().skip(i + 1) {
164 if used[j] || other.entity_type != entity.entity_type {
165 continue;
166 }
167
168 let similarity = self.calculate_entity_similarity(entity, other);
169 if similarity > self.similarity_threshold {
170 cluster.members.push(other.clone());
171 cluster.confidence = (cluster.confidence + other.confidence) / 2.0;
172 used[j] = true;
173 }
174 }
175
176 clusters.push(cluster);
177 }
178
179 clusters.sort_by(|a, b| {
180 b.confidence
181 .partial_cmp(&a.confidence)
182 .expect("Operation failed")
183 });
184 Ok(clusters)
185 }
186
187 pub fn calculate_entity_similarity(&self, entity1: &Entity, entity2: &Entity) -> f64 {
189 if entity1.entity_type != entity2.entity_type {
190 return 0.0;
191 }
192
193 let text1 = entity1.text.to_lowercase();
195 let text2 = entity2.text.to_lowercase();
196
197 if text1 == text2 {
198 return 1.0;
199 }
200
201 let max_len = text1.len().max(text2.len());
203 if max_len == 0 {
204 return 1.0;
205 }
206
207 let distance = self.levenshtein_distance(&text1, &text2);
208 1.0 - (distance as f64 / max_len as f64)
209 }
210
211 pub fn levenshtein_distance(&self, s1: &str, s2: &str) -> usize {
213 let len1 = s1.len();
214 let len2 = s2.len();
215 let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];
216
217 #[allow(clippy::needless_range_loop)]
218 for i in 0..=len1 {
219 matrix[i][0] = i;
220 }
221 for j in 0..=len2 {
222 matrix[0][j] = j;
223 }
224
225 let s1_chars: Vec<char> = s1.chars().collect();
226 let s2_chars: Vec<char> = s2.chars().collect();
227
228 for (i, &c1) in s1_chars.iter().enumerate() {
229 for (j, &c2) in s2_chars.iter().enumerate() {
230 let cost = if c1 == c2 { 0 } else { 1 };
231 matrix[i + 1][j + 1] = std::cmp::min(
232 std::cmp::min(matrix[i][j + 1] + 1, matrix[i + 1][j] + 1),
233 matrix[i][j] + cost,
234 );
235 }
236 }
237
238 matrix[len1][len2]
239 }
240
241 pub fn extract_events(
243 &self,
244 relations: &[Relation],
245 entities: &[Entity],
246 ) -> Result<Vec<Event>> {
247 let mut events = Vec::new();
248
249 let mut relation_groups: std::collections::HashMap<String, Vec<&Relation>> =
251 std::collections::HashMap::new();
252
253 for relation in relations {
254 let context_key = format!(
255 "{}_{}",
256 relation.subject.start / 100, relation.object.start / 100
258 );
259 relation_groups
260 .entry(context_key)
261 .or_default()
262 .push(relation);
263 }
264
265 for (_, group_relations) in relation_groups {
267 if group_relations.len() >= 2 {
268 let event = Event {
269 event_type: self.infer_event_type(&group_relations),
270 participants: self.extract_participants(&group_relations),
271 location: self.extract_location(&group_relations, entities),
272 time: self.extract_time(&group_relations, entities),
273 description: self.generate_event_description(&group_relations),
274 confidence: self.calculate_event_confidence(&group_relations),
275 };
276 events.push(event);
277 }
278 }
279
280 Ok(events)
281 }
282
283 fn infer_event_type(&self, relations: &[&Relation]) -> String {
285 let relation_types: std::collections::HashMap<String, usize> =
286 relations
287 .iter()
288 .fold(std::collections::HashMap::new(), |mut acc, rel| {
289 *acc.entry(rel.relation_type.clone()).or_insert(0) += 1;
290 acc
291 });
292
293 relation_types
294 .into_iter()
295 .max_by_key(|(_, count)| *count)
296 .map(|(rel_type_, _)| rel_type_)
297 .unwrap_or_else(|| "unknown".to_string())
298 }
299
300 fn extract_participants(&self, relations: &[&Relation]) -> Vec<Entity> {
302 let mut participants = Vec::new();
303 for relation in relations {
304 participants.push(relation.subject.clone());
305 participants.push(relation.object.clone());
306 }
307
308 participants.sort_by_key(|e| e.text.clone());
310 participants.dedup_by_key(|e| e.text.clone());
311 participants
312 }
313
314 fn extract_location(&self, relations: &[&Relation], entities: &[Entity]) -> Option<Entity> {
316 for relation in relations {
317 for entity in entities {
318 if matches!(entity.entity_type, EntityType::Location) {
319 let relation_span = relation.subject.start..relation.object.end;
320 let entity_span = entity.start..entity.end;
321
322 if relation_span.contains(&entity.start)
324 || entity_span.contains(&relation.subject.start)
325 || (entity.start as i32 - relation.subject.start as i32).abs() < 100
326 {
327 return Some(entity.clone());
328 }
329 }
330 }
331 }
332 None
333 }
334
335 fn extract_time(&self, relations: &[&Relation], entities: &[Entity]) -> Option<Entity> {
337 for relation in relations {
338 for entity in entities {
339 if matches!(entity.entity_type, EntityType::Date | EntityType::Time) {
340 let relation_span = relation.subject.start..relation.object.end;
341 let entity_span = entity.start..entity.end;
342
343 if relation_span.contains(&entity.start)
345 || entity_span.contains(&relation.subject.start)
346 || (entity.start as i32 - relation.subject.start as i32).abs() < 100
347 {
348 return Some(entity.clone());
349 }
350 }
351 }
352 }
353 None
354 }
355
356 fn generate_event_description(&self, relations: &[&Relation]) -> String {
358 if relations.is_empty() {
359 return "Unknown event".to_string();
360 }
361
362 let contexts: Vec<String> = relations.iter().map(|r| r.context.clone()).collect();
363
364 contexts
366 .into_iter()
367 .max_by_key(|s| s.len())
368 .unwrap_or_else(|| "Event description unavailable".to_string())
369 }
370
371 fn calculate_event_confidence(&self, relations: &[&Relation]) -> f64 {
373 if relations.is_empty() {
374 return 0.0;
375 }
376
377 let sum: f64 = relations.iter().map(|r| r.confidence).sum();
378 (sum / relations.len() as f64) * 0.8 }
380
381 pub fn identify_topics(&self, summaries: &[DocumentSummary]) -> Result<Vec<Topic>> {
383 let mut topics = Vec::new();
384 let mut topic_phrases: std::collections::HashMap<String, Vec<usize>> =
385 std::collections::HashMap::new();
386
387 for summary in summaries {
389 for (phrase, score) in &summary.key_phrases {
390 if *score > self.topic_threshold {
391 topic_phrases
392 .entry(phrase.clone())
393 .or_default()
394 .push(summary.document_index);
395 }
396 }
397 }
398
399 for (phrase, doc_indices) in topic_phrases {
401 if doc_indices.len() >= 2 {
402 let topic = Topic {
404 name: phrase.clone(),
405 key_phrases: vec![phrase],
406 document_indices: doc_indices,
407 confidence: 0.8,
408 };
409 topics.push(topic);
410 }
411 }
412
413 topics.sort_by(|a, b| {
415 b.confidence
416 .partial_cmp(&a.confidence)
417 .expect("Operation failed")
418 });
419 topics.truncate(self.max_topics);
420
421 Ok(topics)
422 }
423}