scirs2_text/information_extraction/document.rs

use super::entities::{Entity, EntityCluster, EntityType};
use super::pipeline::AdvancedExtractionPipeline;
use super::relations::{Event, Relation};
use crate::error::Result;
use std::collections::HashMap;

/// Per-document summary produced during structured information extraction.
#[derive(Debug, Clone)]
pub struct DocumentSummary {
    /// Index of the document in the input collection.
    pub document_index: usize,
    /// Number of entities extracted from the document.
    pub entity_count: usize,
    /// Number of relations extracted from the document.
    pub relation_count: usize,
    /// Key phrases with their relevance scores.
    pub key_phrases: Vec<(String, f64)>,
    /// Average confidence over the document's entities.
    pub confidence_score: f64,
}

/// A topic shared across multiple documents, identified from key phrases.
#[derive(Debug, Clone)]
pub struct Topic {
    /// Topic name (the representative key phrase).
    pub name: String,
    /// Key phrases associated with the topic.
    pub key_phrases: Vec<String>,
    /// Indices of the documents in which the topic occurs.
    pub document_indices: Vec<usize>,
    /// Confidence score for the topic assignment.
    pub confidence: f64,
}

/// Aggregated, structured information extracted from a document collection.
#[derive(Debug)]
pub struct StructuredDocumentInformation {
    /// Per-document summaries.
    pub documents: Vec<DocumentSummary>,
    /// Clusters of similar entities across documents.
    pub entity_clusters: Vec<EntityCluster>,
    /// All relations extracted from the collection.
    pub relations: Vec<Relation>,
    /// Events inferred from groups of co-occurring relations.
    pub events: Vec<Event>,
    /// Topics shared across documents.
    pub topics: Vec<Topic>,
    /// Total number of entities extracted.
    pub total_entities: usize,
    /// Total number of relations extracted.
    pub total_relations: usize,
}

/// Extracts structured information (entities, relations, events, topics)
/// from a collection of documents using an extraction pipeline.
pub struct DocumentInformationExtractor {
    /// Minimum key-phrase score for a phrase to contribute to a topic.
    topic_threshold: f64,
    /// Minimum similarity for two entities to be merged into one cluster.
    similarity_threshold: f64,
    /// Maximum number of topics returned by `identify_topics`.
    max_topics: usize,
}

impl Default for DocumentInformationExtractor {
    fn default() -> Self {
        Self::new()
    }
}

impl DocumentInformationExtractor {
    /// Creates an extractor with default thresholds.
    pub fn new() -> Self {
        Self {
            topic_threshold: 0.3,
            similarity_threshold: 0.7,
            max_topics: 10,
        }
    }

    /// Runs the extraction pipeline over each document and aggregates the
    /// results into entity clusters, events, and topics.
    pub fn extract_structured_information(
        &self,
        documents: &[String],
        pipeline: &AdvancedExtractionPipeline,
    ) -> Result<StructuredDocumentInformation> {
        let mut all_entities = Vec::new();
        let mut all_relations = Vec::new();
        let mut document_summaries = Vec::new();

        for (doc_idx, document) in documents.iter().enumerate() {
            let info = pipeline.extract_advanced(document)?;

            // Apply a mild confidence discount to entities aggregated across documents.
            let mut doc_entities = info.entities;
            for entity in &mut doc_entities {
                entity.confidence *= 0.9;
            }

            let doc_summary = DocumentSummary {
                document_index: doc_idx,
                entity_count: doc_entities.len(),
                relation_count: info.relations.len(),
                key_phrases: info.key_phrases.clone(),
                confidence_score: self.calculate_document_confidence(&doc_entities),
            };

            all_entities.extend(doc_entities);
            all_relations.extend(info.relations);
            document_summaries.push(doc_summary);
        }

        // Aggregate across documents: cluster entities, then derive events and topics.
        let entity_clusters = self.cluster_entities(&all_entities)?;
        let events = self.extract_events(&all_relations, &all_entities)?;
        let topics = self.identify_topics(&document_summaries)?;

        let total_relations = all_relations.len();
        Ok(StructuredDocumentInformation {
            documents: document_summaries,
            entity_clusters,
            relations: all_relations,
            events,
            topics,
            total_entities: all_entities.len(),
            total_relations,
        })
    }

    /// Average entity confidence for a document; 0.0 when there are no entities.
    fn calculate_document_confidence(&self, entities: &[Entity]) -> f64 {
        if entities.is_empty() {
            return 0.0;
        }

        let sum: f64 = entities.iter().map(|e| e.confidence).sum();
        sum / entities.len() as f64
    }

    /// Greedily clusters entities of the same type whose surface forms are
    /// similar (above `similarity_threshold`).
    pub fn cluster_entities(&self, entities: &[Entity]) -> Result<Vec<EntityCluster>> {
        let mut clusters = Vec::new();
        let mut used = vec![false; entities.len()];

        for (i, entity) in entities.iter().enumerate() {
            if used[i] {
                continue;
            }

            // Seed a new cluster with the first unused entity.
            let mut cluster = EntityCluster {
                representative: entity.clone(),
                members: vec![entity.clone()],
                entity_type: entity.entity_type.clone(),
                confidence: entity.confidence,
            };

            used[i] = true;

            for (j, other) in entities.iter().enumerate().skip(i + 1) {
                if used[j] || other.entity_type != entity.entity_type {
                    continue;
                }

                let similarity = self.calculate_entity_similarity(entity, other);
                if similarity > self.similarity_threshold {
                    cluster.members.push(other.clone());
                    // Running average of member confidences.
                    cluster.confidence = (cluster.confidence + other.confidence) / 2.0;
                    used[j] = true;
                }
            }

            clusters.push(cluster);
        }

        clusters.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
        Ok(clusters)
    }

    /// Similarity in `[0, 1]` between two entities of the same type, based on
    /// normalized edit distance over their lowercased text.
    pub fn calculate_entity_similarity(&self, entity1: &Entity, entity2: &Entity) -> f64 {
        if entity1.entity_type != entity2.entity_type {
            return 0.0;
        }

        let text1 = entity1.text.to_lowercase();
        let text2 = entity2.text.to_lowercase();

        if text1 == text2 {
            return 1.0;
        }

        // Normalize by character count so multi-byte text matches the
        // character-level distance below.
        let max_len = text1.chars().count().max(text2.chars().count());
        if max_len == 0 {
            return 1.0;
        }

        let distance = self.levenshtein_distance(&text1, &text2);
        1.0 - (distance as f64 / max_len as f64)
    }

    /// Character-level Levenshtein (edit) distance between two strings.
    pub fn levenshtein_distance(&self, s1: &str, s2: &str) -> usize {
        let s1_chars: Vec<char> = s1.chars().collect();
        let s2_chars: Vec<char> = s2.chars().collect();
        // Size the DP matrix by character counts (not byte lengths) so that
        // multi-byte UTF-8 input is handled correctly.
        let len1 = s1_chars.len();
        let len2 = s2_chars.len();
        let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];

        #[allow(clippy::needless_range_loop)]
        for i in 0..=len1 {
            matrix[i][0] = i;
        }
        for j in 0..=len2 {
            matrix[0][j] = j;
        }

        for (i, &c1) in s1_chars.iter().enumerate() {
            for (j, &c2) in s2_chars.iter().enumerate() {
                let cost = if c1 == c2 { 0 } else { 1 };
                matrix[i + 1][j + 1] = std::cmp::min(
                    std::cmp::min(matrix[i][j + 1] + 1, matrix[i + 1][j] + 1),
                    matrix[i][j] + cost,
                );
            }
        }

        matrix[len1][len2]
    }

    /// Groups relations that occur in the same region of text and promotes
    /// groups of two or more relations to events.
    pub fn extract_events(
        &self,
        relations: &[Relation],
        entities: &[Entity],
    ) -> Result<Vec<Event>> {
        let mut events = Vec::new();

        // Bucket relations by the 100-character windows of their subject and object offsets.
        let mut relation_groups: HashMap<String, Vec<&Relation>> = HashMap::new();

        for relation in relations {
            let context_key = format!(
                "{}_{}",
                relation.subject.start / 100,
                relation.object.start / 100
            );
            relation_groups
                .entry(context_key)
                .or_default()
                .push(relation);
        }

        for group_relations in relation_groups.into_values() {
            if group_relations.len() >= 2 {
                let event = Event {
                    event_type: self.infer_event_type(&group_relations),
                    participants: self.extract_participants(&group_relations),
                    location: self.extract_location(&group_relations, entities),
                    time: self.extract_time(&group_relations, entities),
                    description: self.generate_event_description(&group_relations),
                    confidence: self.calculate_event_confidence(&group_relations),
                };
                events.push(event);
            }
        }

        Ok(events)
    }

    /// Picks the most frequent relation type in the group as the event type.
    fn infer_event_type(&self, relations: &[&Relation]) -> String {
        let relation_types: HashMap<String, usize> =
            relations
                .iter()
                .fold(HashMap::new(), |mut acc, rel| {
                    *acc.entry(rel.relation_type.clone()).or_insert(0) += 1;
                    acc
                });

        relation_types
            .into_iter()
            .max_by_key(|(_, count)| *count)
            .map(|(rel_type, _)| rel_type)
            .unwrap_or_else(|| "unknown".to_string())
    }

    /// Collects the unique subject and object entities of the grouped relations.
    fn extract_participants(&self, relations: &[&Relation]) -> Vec<Entity> {
        let mut participants = Vec::new();
        for relation in relations {
            participants.push(relation.subject.clone());
            participants.push(relation.object.clone());
        }

        // Deduplicate by surface text; dedup requires the list to be sorted first.
        participants.sort_by_key(|e| e.text.clone());
        participants.dedup_by_key(|e| e.text.clone());
        participants
    }

    /// Finds a location entity near the grouped relations, if any.
    fn extract_location(&self, relations: &[&Relation], entities: &[Entity]) -> Option<Entity> {
        for relation in relations {
            for entity in entities {
                if matches!(entity.entity_type, EntityType::Location) {
                    let relation_span = relation.subject.start..relation.object.end;
                    let entity_span = entity.start..entity.end;

                    // Accept the entity if the spans overlap or lie within 100 characters.
                    if relation_span.contains(&entity.start)
                        || entity_span.contains(&relation.subject.start)
                        || (entity.start as i32 - relation.subject.start as i32).abs() < 100
                    {
                        return Some(entity.clone());
                    }
                }
            }
        }
        None
    }

    /// Finds a date or time entity near the grouped relations, if any.
    fn extract_time(&self, relations: &[&Relation], entities: &[Entity]) -> Option<Entity> {
        for relation in relations {
            for entity in entities {
                if matches!(entity.entity_type, EntityType::Date | EntityType::Time) {
                    let relation_span = relation.subject.start..relation.object.end;
                    let entity_span = entity.start..entity.end;

                    if relation_span.contains(&entity.start)
                        || entity_span.contains(&relation.subject.start)
                        || (entity.start as i32 - relation.subject.start as i32).abs() < 100
                    {
                        return Some(entity.clone());
                    }
                }
            }
        }
        None
    }

    /// Uses the longest relation context in the group as the event description.
    fn generate_event_description(&self, relations: &[&Relation]) -> String {
        if relations.is_empty() {
            return "Unknown event".to_string();
        }

        let contexts: Vec<String> = relations.iter().map(|r| r.context.clone()).collect();

        contexts
            .into_iter()
            .max_by_key(|s| s.len())
            .unwrap_or_else(|| "Event description unavailable".to_string())
    }

    /// Average relation confidence, scaled down by a fixed factor of 0.8.
    fn calculate_event_confidence(&self, relations: &[&Relation]) -> f64 {
        if relations.is_empty() {
            return 0.0;
        }

        let sum: f64 = relations.iter().map(|r| r.confidence).sum();
        (sum / relations.len() as f64) * 0.8
    }

    /// Derives topics from key phrases that clear `topic_threshold` in at
    /// least two documents, keeping at most `max_topics` topics.
    pub fn identify_topics(&self, summaries: &[DocumentSummary]) -> Result<Vec<Topic>> {
        let mut topics = Vec::new();
        let mut topic_phrases: HashMap<String, Vec<usize>> = HashMap::new();

        // Map each qualifying key phrase to the documents it appears in.
        for summary in summaries {
            for (phrase, score) in &summary.key_phrases {
                if *score > self.topic_threshold {
                    topic_phrases
                        .entry(phrase.clone())
                        .or_default()
                        .push(summary.document_index);
                }
            }
        }

        // A phrase shared by at least two documents becomes a topic.
        for (phrase, doc_indices) in topic_phrases {
            if doc_indices.len() >= 2 {
                let topic = Topic {
                    name: phrase.clone(),
                    key_phrases: vec![phrase],
                    document_indices: doc_indices,
                    confidence: 0.8,
                };
                topics.push(topic);
            }
        }

        topics.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
        topics.truncate(self.max_topics);

        Ok(topics)
    }
}
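
// A minimal test sketch (added here for illustration, assuming the default
// thresholds set in `new()` and the standard Rust test harness): it checks the
// character-level edit distance and the phrase-sharing rule used by
// `identify_topics`. The sample phrase "machine learning" is hypothetical.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn levenshtein_distance_counts_edits() {
        let extractor = DocumentInformationExtractor::new();
        // Classic example: kitten -> sitting needs three edits.
        assert_eq!(extractor.levenshtein_distance("kitten", "sitting"), 3);
        assert_eq!(extractor.levenshtein_distance("", "abc"), 3);
    }

    #[test]
    fn shared_key_phrase_becomes_topic() {
        let extractor = DocumentInformationExtractor::new();
        let summary = |idx| DocumentSummary {
            document_index: idx,
            entity_count: 0,
            relation_count: 0,
            key_phrases: vec![("machine learning".to_string(), 0.9)],
            confidence_score: 0.0,
        };
        // The phrase clears the topic threshold in two documents, so exactly
        // one topic named after it should be produced.
        let topics = extractor
            .identify_topics(&[summary(0), summary(1)])
            .unwrap();
        assert_eq!(topics.len(), 1);
        assert_eq!(topics[0].name, "machine learning");
    }
}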