//! Dataset snapshot building and publishing (trace_share_core/snapshot.rs).
use std::{
    collections::{BTreeMap, HashSet},
    fs,
    io::{BufRead, BufReader, Write},
    path::{Path, PathBuf},
};

use anyhow::{Context, Result, bail};
use chrono::Utc;
use rand::{Rng, thread_rng};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tokio::time::{Duration, sleep};

use crate::{
    config::AppConfig,
    episode::{EpisodeRecord, derive_sft, derive_tooltrace},
};
18
/// Summary returned by [`build_snapshot`] after writing all artifacts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotBuildResult {
    /// Version label; also names the `dataset-{version}` output directory.
    pub version: String,
    /// Number of records written to train.jsonl.zst.
    pub train_count: usize,
    /// Number of records written to val.jsonl.zst.
    pub val_count: usize,
    /// Directory containing every generated snapshot artifact.
    pub out_dir: PathBuf,
    /// SHA-256 hex digest of the generated manifest.json.
    pub manifest_hash: String,
}
27
/// Outcome of [`publish_snapshot`] (validation plus optional upload).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotPublishResult {
    /// Version label of the published snapshot.
    pub version: String,
    /// Local directory the artifacts were read from.
    pub snapshot_dir: PathBuf,
    /// Object prefix reported by the worker; `None` on dry runs.
    pub object_prefix: Option<String>,
    /// True when the snapshot pointer was indexed (i.e. not a dry run).
    pub indexed: bool,
}
35
/// JSON body returned by the worker's snapshot publish endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerSnapshotPublishResponse {
    /// Version echoed back by the worker.
    pub version: String,
    /// Storage object prefix assigned by the worker, if any.
    pub object_prefix: Option<String>,
    /// Public URL for the snapshot, when the worker exposes one.
    pub public_url: Option<String>,
}
42
/// Per-file digest entry included in the worker publish payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SnapshotFileEntry {
    /// Bare file name (no directory component).
    name: String,
    /// File size in bytes.
    bytes: u64,
    /// SHA-256 hex digest of the file contents.
    sha256: String,
}
49
50pub fn build_snapshot(
51    version: &str,
52    input: &Path,
53    out_root: &Path,
54    split_seed: &str,
55    revoked_ids: &HashSet<String>,
56) -> Result<SnapshotBuildResult> {
57    let episodes = read_episode_records(input)?;
58    if episodes.is_empty() {
59        bail!("no episode records found in {}", input.display());
60    }
61
62    let mut seen_hash = HashSet::new();
63    let mut filtered = Vec::new();
64    for ep in episodes {
65        if revoked_ids.contains(&ep.id) {
66            continue;
67        }
68        if !matches!(ep.license.as_str(), "CC0-1.0" | "CC-BY-4.0") {
69            continue;
70        }
71        if !(ep.consent.public_searchable && ep.consent.trainable) {
72            continue;
73        }
74        if seen_hash.insert(ep.content_hash.clone()) {
75            filtered.push(ep);
76        }
77    }
78
79    let out_dir = out_root.join(format!("dataset-{version}"));
80    fs::create_dir_all(&out_dir)?;
81
82    let train_path = out_dir.join("train.jsonl.zst");
83    let val_path = out_dir.join("val.jsonl.zst");
84    let mut train_writer = zstd::stream::write::Encoder::new(fs::File::create(&train_path)?, 3)?;
85    let mut val_writer = zstd::stream::write::Encoder::new(fs::File::create(&val_path)?, 3)?;
86
87    let mut train_count = 0usize;
88    let mut val_count = 0usize;
89    let mut license_breakdown: BTreeMap<String, usize> = BTreeMap::new();
90
91    for ep in &filtered {
92        *license_breakdown.entry(ep.license.clone()).or_default() += 1;
93        let bucket = split_bucket(&ep.id, split_seed);
94        let line = serde_json::to_string(ep)? + "\n";
95        if bucket < 98 {
96            train_writer.write_all(line.as_bytes())?;
97            train_count += 1;
98        } else {
99            val_writer.write_all(line.as_bytes())?;
100            val_count += 1;
101        }
102    }
103    train_writer.finish()?;
104    val_writer.finish()?;
105
106    let sft_path = out_dir.join("sft.jsonl.zst");
107    let tooltrace_path = out_dir.join("tooltrace.jsonl.zst");
108    let mut sft_writer = zstd::stream::write::Encoder::new(fs::File::create(&sft_path)?, 3)?;
109    let mut tt_writer = zstd::stream::write::Encoder::new(fs::File::create(&tooltrace_path)?, 3)?;
110    for ep in &filtered {
111        let sft = derive_sft(ep);
112        let tt = derive_tooltrace(ep);
113        sft_writer.write_all((serde_json::to_string(&sft)? + "\n").as_bytes())?;
114        tt_writer.write_all((serde_json::to_string(&tt)? + "\n").as_bytes())?;
115    }
116    sft_writer.finish()?;
117    tt_writer.finish()?;
118
119    let manifest = serde_json::json!({
120        "version": version,
121        "total_records": filtered.len(),
122        "train_count": train_count,
123        "val_count": val_count,
124        "split_rule": "blake3(id|seed)%100 < 98 => train",
125        "license_breakdown": license_breakdown,
126        "files": [
127            "train.jsonl.zst",
128            "val.jsonl.zst",
129            "sft.jsonl.zst",
130            "tooltrace.jsonl.zst"
131        ]
132    });
133    let manifest_path = out_dir.join("manifest.json");
134    fs::write(&manifest_path, serde_json::to_vec_pretty(&manifest)?)?;
135    let manifest_hash = sha256_file(&manifest_path)?;
136
137    let checksums = checksums(&[
138        &train_path,
139        &val_path,
140        &sft_path,
141        &tooltrace_path,
142        &manifest_path,
143    ])?;
144    fs::write(out_dir.join("CHECKSUMS.txt"), checksums)?;
145
146    let datacard = format!(
147        "# DATA_CARD\n\nVersion: {version}\n\nTotal: {}\nTrain: {train_count}\nVal: {val_count}\n\nSanitized traces suitable for SFT and tool-use training.\n",
148        filtered.len()
149    );
150    fs::write(out_dir.join("DATA_CARD.md"), datacard)?;
151
152    Ok(SnapshotBuildResult {
153        version: version.to_string(),
154        train_count,
155        val_count,
156        out_dir,
157        manifest_hash,
158    })
159}
160
161pub async fn publish_snapshot(
162    config: &AppConfig,
163    version: &str,
164    snapshot_path: &Path,
165    dry_run: bool,
166) -> Result<SnapshotPublishResult> {
167    let snapshot_dir = resolve_snapshot_dir(version, snapshot_path)?;
168    let required = required_snapshot_files(&snapshot_dir)?;
169    let manifest_path = snapshot_dir.join("manifest.json");
170    let checksums_path = snapshot_dir.join("CHECKSUMS.txt");
171    let data_card_path = snapshot_dir.join("DATA_CARD.md");
172
173    let manifest: serde_json::Value = serde_json::from_slice(
174        &fs::read(&manifest_path)
175            .with_context(|| format!("failed to read {}", manifest_path.display()))?,
176    )
177    .context("invalid manifest.json")?;
178    let checksums = fs::read_to_string(&checksums_path)
179        .with_context(|| format!("failed to read {}", checksums_path.display()))?;
180    let data_card = fs::read_to_string(&data_card_path)
181        .with_context(|| format!("failed to read {}", data_card_path.display()))?;
182
183    let mut file_entries = Vec::new();
184    for path in required {
185        let name = path
186            .file_name()
187            .and_then(|v| v.to_str())
188            .unwrap_or("unknown")
189            .to_string();
190        let md = fs::metadata(&path)?;
191        file_entries.push(SnapshotFileEntry {
192            name,
193            bytes: md.len(),
194            sha256: sha256_file(&path)?,
195        });
196    }
197    file_entries.sort_by(|a, b| a.name.cmp(&b.name));
198
199    let mut object_prefix = None;
200    if !dry_run {
201        object_prefix = publish_snapshot_to_worker(
202            config,
203            version,
204            &manifest,
205            &checksums,
206            &file_entries,
207            &data_card,
208        )
209        .await?
210        .object_prefix;
211
212        index_snapshot_pointer(config, version, &manifest, object_prefix.as_deref()).await?;
213    }
214
215    Ok(SnapshotPublishResult {
216        version: version.to_string(),
217        snapshot_dir,
218        object_prefix,
219        indexed: !dry_run,
220    })
221}
222
223fn split_bucket(id: &str, seed: &str) -> u8 {
224    let value = format!("{id}|{seed}");
225    let hash = blake3::hash(value.as_bytes());
226    hash.as_bytes()[0] % 100
227}
228
229fn read_episode_records(path: &Path) -> Result<Vec<EpisodeRecord>> {
230    let mut out = Vec::new();
231    if path.is_file() {
232        parse_episode_file(path, &mut out)?;
233        return Ok(out);
234    }
235
236    if path.is_dir() {
237        for entry in ignore::WalkBuilder::new(path)
238            .hidden(false)
239            .git_ignore(false)
240            .build()
241        {
242            let entry = match entry {
243                Ok(v) => v,
244                Err(_) => continue,
245            };
246            if entry.file_type().map(|f| f.is_file()).unwrap_or(false) {
247                parse_episode_file(entry.path(), &mut out)?;
248            }
249        }
250        return Ok(out);
251    }
252
253    Ok(out)
254}
255
256fn parse_episode_file(path: &Path, out: &mut Vec<EpisodeRecord>) -> Result<()> {
257    let file = fs::File::open(path)?;
258    let reader = BufReader::new(file);
259    for line in reader.lines() {
260        let line = line?;
261        if line.trim().is_empty() {
262            continue;
263        }
264        if let Ok(ep) = serde_json::from_str::<EpisodeRecord>(&line) {
265            out.push(ep);
266        }
267    }
268    Ok(())
269}
270
271fn checksums(paths: &[&Path]) -> Result<String> {
272    let mut lines = Vec::new();
273    for path in paths {
274        let bytes = fs::read(path)?;
275        let mut hasher = Sha256::new();
276        hasher.update(bytes);
277        let digest = hasher.finalize();
278        let digest_hex = format!("{:x}", digest);
279        let name = path
280            .file_name()
281            .and_then(|s| s.to_str())
282            .unwrap_or("unknown");
283        lines.push(format!("{}  {}", digest_hex, name));
284    }
285    lines.sort();
286    Ok(lines.join("\n") + "\n")
287}
288
289fn resolve_snapshot_dir(version: &str, snapshot_path: &Path) -> Result<PathBuf> {
290    let expected_name = format!("dataset-{version}");
291    if snapshot_path
292        .file_name()
293        .and_then(|v| v.to_str())
294        .map(|v| v == expected_name)
295        .unwrap_or(false)
296    {
297        if snapshot_path.is_dir() {
298            return Ok(snapshot_path.to_path_buf());
299        }
300        bail!(
301            "snapshot path is not a directory: {}",
302            snapshot_path.display()
303        );
304    }
305
306    let candidate = snapshot_path.join(expected_name);
307    if candidate.is_dir() {
308        return Ok(candidate);
309    }
310
311    bail!(
312        "snapshot directory not found: expected {} or {}",
313        snapshot_path.display(),
314        candidate.display()
315    )
316}
317
318fn required_snapshot_files(snapshot_dir: &Path) -> Result<Vec<PathBuf>> {
319    let names = [
320        "train.jsonl.zst",
321        "val.jsonl.zst",
322        "sft.jsonl.zst",
323        "tooltrace.jsonl.zst",
324        "manifest.json",
325        "CHECKSUMS.txt",
326        "DATA_CARD.md",
327    ];
328    let mut out = Vec::new();
329    for name in names {
330        let path = snapshot_dir.join(name);
331        if !path.is_file() {
332            bail!("missing required snapshot artifact: {}", path.display());
333        }
334        out.push(path);
335    }
336    Ok(out)
337}
338
/// POST the snapshot payload to the worker's `/v1/snapshots` endpoint.
///
/// Retries on transport failures (timeout / connect / request errors)
/// and on retryable HTTP statuses (429 or any 5xx, per
/// `should_retry_status`) with exponential backoff plus jitter, up to
/// 4 retries (5 total attempts).
///
/// # Errors
/// Fails when the worker base URL is unconfigured, the HTTP client
/// cannot be built, or the request still fails after retries.
async fn publish_snapshot_to_worker(
    config: &AppConfig,
    version: &str,
    manifest: &serde_json::Value,
    checksums: &str,
    files: &[SnapshotFileEntry],
    data_card: &str,
) -> Result<WorkerSnapshotPublishResponse> {
    let base_url = config
        .worker
        .base_url
        .as_ref()
        .context("missing TRACE_SHARE_WORKER_BASE_URL")?;
    let endpoint = format!("{}/v1/snapshots", base_url.trim_end_matches('/'));
    // Clamp the configured timeout to at least 5 seconds.
    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(
            config.worker.timeout_seconds.max(5),
        ))
        .build()?;

    let payload = serde_json::json!({
        "version": version,
        "created_at": Utc::now().to_rfc3339(),
        "manifest": manifest,
        "checksums": checksums,
        "files": files,
        "data_card": data_card,
    });

    let mut attempt: u32 = 0;
    loop {
        // Rebuild the request each attempt; bearer auth is optional and
        // only attached when a token is configured.
        let mut req = client.post(&endpoint).json(&payload);
        if let Some(token) = config.worker.api_token.as_ref() {
            req = req.bearer_auth(token);
        }

        let res = req.send().await;
        match res {
            Ok(resp) => {
                let status = resp.status();
                if status.is_success() {
                    return Ok(resp.json::<WorkerSnapshotPublishResponse>().await?);
                }
                let body = resp.text().await.unwrap_or_default();
                // Non-retryable status, or retry budget exhausted: fail now.
                if !should_retry_status(status) || attempt >= 4 {
                    anyhow::bail!(
                        "worker snapshot publish failed: status={} body={}",
                        status,
                        body
                    );
                }
            }
            Err(e) => {
                let retryable_transport = e.is_timeout() || e.is_connect() || e.is_request();
                if !retryable_transport || attempt >= 4 {
                    return Err(e).context("worker snapshot publish request failed after retries");
                }
            }
        }

        attempt += 1;
        // Backoff doubles each retry: ~400ms, ~800ms, ~1.6s, ~3.2s,
        // plus 50-299ms of random jitter.
        let jitter: u64 = thread_rng().gen_range(50..300);
        let wait_ms = (2u64.pow(attempt) * 200) + jitter;
        sleep(Duration::from_millis(wait_ms)).await;
    }
}
405
/// Upsert a pointer record for the snapshot into the Upstash vector
/// index via its `/upsert-data` REST endpoint.
///
/// Uses the same retry policy as `publish_snapshot_to_worker`:
/// transport failures and retryable statuses (429 / 5xx) are retried
/// with exponential backoff plus jitter, up to 4 retries.
///
/// # Errors
/// Fails when the Upstash URL or token is unconfigured, or when the
/// request still fails after retries.
async fn index_snapshot_pointer(
    config: &AppConfig,
    version: &str,
    manifest: &serde_json::Value,
    object_prefix: Option<&str>,
) -> Result<()> {
    let rest_url = config
        .upstash
        .rest_url
        .as_ref()
        .context("missing UPSTASH_VECTOR_REST_URL")?;
    let token = config
        .upstash
        .rest_token
        .as_ref()
        .context("missing UPSTASH_VECTOR_REST_TOKEN")?;
    let endpoint = format!("{}/upsert-data", rest_url.trim_end_matches('/'));
    let client = reqwest::Client::new();

    // Pointer metadata mirrors the manifest counts and records where
    // the published objects live (object_prefix is None on failure to
    // publish — the worker response decides it upstream).
    let payload = serde_json::json!({
        "vectors": [
            {
                "id": format!("snapshot:{version}"),
                "data": format!("trace-share dataset snapshot {version}"),
                "metadata": {
                    "kind": "dataset_snapshot",
                    "snapshot_version": version,
                    "total_records": manifest.get("total_records"),
                    "train_count": manifest.get("train_count"),
                    "val_count": manifest.get("val_count"),
                    "pointer": {
                        "storage": "r2",
                        "object_key": object_prefix,
                        "snapshot_version": version
                    }
                }
            }
        ]
    });

    let mut attempt: u32 = 0;
    loop {
        let res = client
            .post(&endpoint)
            .bearer_auth(token)
            .json(&payload)
            .send()
            .await;
        match res {
            Ok(resp) => {
                let status = resp.status();
                if status.is_success() {
                    return Ok(());
                }
                let body = resp.text().await.unwrap_or_default();
                // Non-retryable status, or retry budget exhausted: fail now.
                if !should_retry_status(status) || attempt >= 4 {
                    anyhow::bail!(
                        "upstash snapshot pointer index failed: status={} body={}",
                        status,
                        body
                    );
                }
            }
            Err(e) => {
                let retryable_transport = e.is_timeout() || e.is_connect() || e.is_request();
                if !retryable_transport || attempt >= 4 {
                    return Err(e).context("upstash snapshot pointer request failed after retries");
                }
            }
        }

        attempt += 1;
        // Backoff doubles each retry: ~400ms, ~800ms, ~1.6s, ~3.2s,
        // plus 50-299ms of random jitter.
        let jitter: u64 = thread_rng().gen_range(50..300);
        let wait_ms = (2u64.pow(attempt) * 200) + jitter;
        sleep(Duration::from_millis(wait_ms)).await;
    }
}
483
484fn should_retry_status(status: reqwest::StatusCode) -> bool {
485    status == reqwest::StatusCode::TOO_MANY_REQUESTS || status.is_server_error()
486}
487
488fn sha256_file(path: &Path) -> Result<String> {
489    let bytes = fs::read(path)?;
490    let mut hasher = Sha256::new();
491    hasher.update(bytes);
492    Ok(format!("{:x}", hasher.finalize()))
493}