Skip to main content

spark_bert/
dataset.rs

1use anyhow::{Context, Result};
2use serde::Deserialize;
3use std::{
4    collections::HashMap,
5    fs::File,
6    io::{BufRead, BufReader},
7    path::{Path, PathBuf},
8};
9
10#[derive(Debug, Deserialize)]
11pub struct CorpusDoc {
12    pub title: Option<String>,
13    pub text: String,
14    #[serde(flatten)]
15    pub meta: HashMap<String, serde_json::Value>,
16}
17
18impl CorpusDoc {
19    pub fn as_text(&self) -> String {
20        let sep = " ";
21        let title = self.title.as_deref().unwrap_or("");
22        let mut combined = String::with_capacity(title.len() + sep.len() + self.text.len());
23        combined.push_str(title);
24        combined.push_str(sep);
25        combined.push_str(&self.text);
26        let trimmed = combined.trim();
27        if trimmed.len() == combined.len() {
28            combined
29        } else {
30            trimmed.to_owned()
31        }
32    }
33}
34
35#[derive(Debug, Deserialize)]
36pub struct Query {
37    pub text: String,
38}
39
40pub type Corpus = HashMap<String, CorpusDoc>;
41pub type Queries = HashMap<String, Query>;
42pub type Qrels = HashMap<String, HashMap<String, i32>>;
43
44const DATA_DIR: &str = "datasets/scifact";
45
46fn read_jsonl<T: for<'de> Deserialize<'de>>(path: impl AsRef<Path>) -> Result<HashMap<String, T>> {
47    let file = BufReader::new(File::open(&path)?);
48    let mut map = HashMap::new();
49
50    for line in file.lines() {
51        let line = line?;
52        let mut obj: serde_json::Map<String, serde_json::Value> = serde_json::from_str(&line)?;
53        let id = obj
54            .remove("_id")
55            .context("_id field missing")?
56            .as_str()
57            .unwrap()
58            .to_owned();
59
60        let val: T = serde_json::from_value(serde_json::Value::Object(obj))?;
61        map.insert(id, val);
62    }
63    Ok(map)
64}
65
66fn read_qrels(path: impl AsRef<Path>) -> Result<Qrels> {
67    let file = BufReader::new(File::open(&path)?);
68    let mut map: Qrels = HashMap::new();
69
70    for line in file.lines() {
71        let line = line?;
72        let parts: Vec<_> = line.split_whitespace().collect();
73        if parts[0] == "query-id" {
74            // Skip first line
75            continue;
76        } else if parts.len() != 3 {
77            println!("Wrong len: {}", parts.len());
78            continue;
79        }
80        let (q, d, rel) = (parts[0], parts[1], parts[2].parse::<i32>()?);
81        map.entry(q.to_owned())
82            .or_default()
83            .insert(d.to_owned(), rel);
84    }
85    Ok(map)
86}
87
88pub fn load_scifact(split: &str) -> Result<(Corpus, Queries, Qrels)> {
89    let base = PathBuf::from(DATA_DIR);
90
91    let corpus = read_jsonl(base.join("corpus.jsonl")).with_context(|| "reading corpus.jsonl")?;
92
93    let mut queries =
94        read_jsonl(base.join("queries.jsonl")).with_context(|| "reading queries.jsonl")?;
95
96    let qrels = read_qrels(base.join(format!("qrels/{split}.tsv")))
97        .with_context(|| format!("reading qrels/{split}.tsv"))?;
98
99    queries.retain(|qid, _| qrels.contains_key(qid));
100
101    Ok((corpus, queries, qrels))
102}
103
104#[cfg(test)]
105mod tests {
106    use super::*;
107
108    #[test]
109    fn should_return_as_text_with_title() {
110        let doc = CorpusDoc {
111            title: Some(" Title".to_string()),
112            text: "Body ".to_string(),
113            meta: HashMap::new(),
114        };
115        assert_eq!(doc.as_text(), "Title Body");
116    }
117
118    #[test]
119    fn should_return_as_text_without_title() {
120        let doc = CorpusDoc {
121            title: None,
122            text: " Body ".to_string(),
123            meta: HashMap::new(),
124        };
125        assert_eq!(doc.as_text(), "Body");
126    }
127}