1use anyhow::{Context, Result, bail};
2use chrono::Utc;
3use rand::{Rng, thread_rng};
4use serde::{Deserialize, Serialize};
5use sha2::{Digest, Sha256};
6use std::{
7 collections::{BTreeMap, HashSet},
8 fs,
9 io::{BufRead, BufReader, Write},
10 path::{Path, PathBuf},
11};
12use tokio::time::{Duration, sleep};
13
14use crate::{
15 config::AppConfig,
16 episode::{EpisodeRecord, derive_sft, derive_tooltrace},
17};
18
/// Summary of a locally built dataset snapshot, returned by [`build_snapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotBuildResult {
    /// Snapshot version string (also embedded in the output directory name).
    pub version: String,
    /// Number of records written to the train split.
    pub train_count: usize,
    /// Number of records written to the validation split.
    pub val_count: usize,
    /// Directory containing all snapshot artifacts (`dataset-{version}`).
    pub out_dir: PathBuf,
    /// SHA-256 hex digest of the written `manifest.json`.
    pub manifest_hash: String,
}
27
/// Outcome of publishing a snapshot, returned by [`publish_snapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotPublishResult {
    /// Snapshot version that was published.
    pub version: String,
    /// Local directory the snapshot artifacts were read from.
    pub snapshot_dir: PathBuf,
    /// Object-store prefix reported by the worker; `None` on dry runs
    /// or when the worker response omits it.
    pub object_prefix: Option<String>,
    /// Whether the snapshot pointer was indexed (false for dry runs).
    pub indexed: bool,
}
35
/// JSON response body from the worker's `POST /v1/snapshots` endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerSnapshotPublishResponse {
    /// Version echoed back by the worker.
    pub version: String,
    /// Object-store prefix where the worker stored the snapshot, if any.
    pub object_prefix: Option<String>,
    /// Publicly reachable URL for the snapshot, if the worker exposes one.
    pub public_url: Option<String>,
}
42
/// Per-file metadata (name, size, digest) sent to the worker alongside
/// the snapshot manifest.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SnapshotFileEntry {
    /// Bare file name (no directory component).
    name: String,
    /// File size in bytes.
    bytes: u64,
    /// SHA-256 hex digest of the file contents.
    sha256: String,
}
49
50pub fn build_snapshot(
51 version: &str,
52 input: &Path,
53 out_root: &Path,
54 split_seed: &str,
55 revoked_ids: &HashSet<String>,
56) -> Result<SnapshotBuildResult> {
57 let episodes = read_episode_records(input)?;
58 if episodes.is_empty() {
59 bail!("no episode records found in {}", input.display());
60 }
61
62 let mut seen_hash = HashSet::new();
63 let mut filtered = Vec::new();
64 for ep in episodes {
65 if revoked_ids.contains(&ep.id) {
66 continue;
67 }
68 if !matches!(ep.license.as_str(), "CC0-1.0" | "CC-BY-4.0") {
69 continue;
70 }
71 if !(ep.consent.public_searchable && ep.consent.trainable) {
72 continue;
73 }
74 if seen_hash.insert(ep.content_hash.clone()) {
75 filtered.push(ep);
76 }
77 }
78
79 let out_dir = out_root.join(format!("dataset-{version}"));
80 fs::create_dir_all(&out_dir)?;
81
82 let train_path = out_dir.join("train.jsonl.zst");
83 let val_path = out_dir.join("val.jsonl.zst");
84 let mut train_writer = zstd::stream::write::Encoder::new(fs::File::create(&train_path)?, 3)?;
85 let mut val_writer = zstd::stream::write::Encoder::new(fs::File::create(&val_path)?, 3)?;
86
87 let mut train_count = 0usize;
88 let mut val_count = 0usize;
89 let mut license_breakdown: BTreeMap<String, usize> = BTreeMap::new();
90
91 for ep in &filtered {
92 *license_breakdown.entry(ep.license.clone()).or_default() += 1;
93 let bucket = split_bucket(&ep.id, split_seed);
94 let line = serde_json::to_string(ep)? + "\n";
95 if bucket < 98 {
96 train_writer.write_all(line.as_bytes())?;
97 train_count += 1;
98 } else {
99 val_writer.write_all(line.as_bytes())?;
100 val_count += 1;
101 }
102 }
103 train_writer.finish()?;
104 val_writer.finish()?;
105
106 let sft_path = out_dir.join("sft.jsonl.zst");
107 let tooltrace_path = out_dir.join("tooltrace.jsonl.zst");
108 let mut sft_writer = zstd::stream::write::Encoder::new(fs::File::create(&sft_path)?, 3)?;
109 let mut tt_writer = zstd::stream::write::Encoder::new(fs::File::create(&tooltrace_path)?, 3)?;
110 for ep in &filtered {
111 let sft = derive_sft(ep);
112 let tt = derive_tooltrace(ep);
113 sft_writer.write_all((serde_json::to_string(&sft)? + "\n").as_bytes())?;
114 tt_writer.write_all((serde_json::to_string(&tt)? + "\n").as_bytes())?;
115 }
116 sft_writer.finish()?;
117 tt_writer.finish()?;
118
119 let manifest = serde_json::json!({
120 "version": version,
121 "total_records": filtered.len(),
122 "train_count": train_count,
123 "val_count": val_count,
124 "split_rule": "blake3(id|seed)%100 < 98 => train",
125 "license_breakdown": license_breakdown,
126 "files": [
127 "train.jsonl.zst",
128 "val.jsonl.zst",
129 "sft.jsonl.zst",
130 "tooltrace.jsonl.zst"
131 ]
132 });
133 let manifest_path = out_dir.join("manifest.json");
134 fs::write(&manifest_path, serde_json::to_vec_pretty(&manifest)?)?;
135 let manifest_hash = sha256_file(&manifest_path)?;
136
137 let checksums = checksums(&[
138 &train_path,
139 &val_path,
140 &sft_path,
141 &tooltrace_path,
142 &manifest_path,
143 ])?;
144 fs::write(out_dir.join("CHECKSUMS.txt"), checksums)?;
145
146 let datacard = format!(
147 "# DATA_CARD\n\nVersion: {version}\n\nTotal: {}\nTrain: {train_count}\nVal: {val_count}\n\nSanitized traces suitable for SFT and tool-use training.\n",
148 filtered.len()
149 );
150 fs::write(out_dir.join("DATA_CARD.md"), datacard)?;
151
152 Ok(SnapshotBuildResult {
153 version: version.to_string(),
154 train_count,
155 val_count,
156 out_dir,
157 manifest_hash,
158 })
159}
160
161pub async fn publish_snapshot(
162 config: &AppConfig,
163 version: &str,
164 snapshot_path: &Path,
165 dry_run: bool,
166) -> Result<SnapshotPublishResult> {
167 let snapshot_dir = resolve_snapshot_dir(version, snapshot_path)?;
168 let required = required_snapshot_files(&snapshot_dir)?;
169 let manifest_path = snapshot_dir.join("manifest.json");
170 let checksums_path = snapshot_dir.join("CHECKSUMS.txt");
171 let data_card_path = snapshot_dir.join("DATA_CARD.md");
172
173 let manifest: serde_json::Value = serde_json::from_slice(
174 &fs::read(&manifest_path)
175 .with_context(|| format!("failed to read {}", manifest_path.display()))?,
176 )
177 .context("invalid manifest.json")?;
178 let checksums = fs::read_to_string(&checksums_path)
179 .with_context(|| format!("failed to read {}", checksums_path.display()))?;
180 let data_card = fs::read_to_string(&data_card_path)
181 .with_context(|| format!("failed to read {}", data_card_path.display()))?;
182
183 let mut file_entries = Vec::new();
184 for path in required {
185 let name = path
186 .file_name()
187 .and_then(|v| v.to_str())
188 .unwrap_or("unknown")
189 .to_string();
190 let md = fs::metadata(&path)?;
191 file_entries.push(SnapshotFileEntry {
192 name,
193 bytes: md.len(),
194 sha256: sha256_file(&path)?,
195 });
196 }
197 file_entries.sort_by(|a, b| a.name.cmp(&b.name));
198
199 let mut object_prefix = None;
200 if !dry_run {
201 object_prefix = publish_snapshot_to_worker(
202 config,
203 version,
204 &manifest,
205 &checksums,
206 &file_entries,
207 &data_card,
208 )
209 .await?
210 .object_prefix;
211
212 index_snapshot_pointer(config, version, &manifest, object_prefix.as_deref()).await?;
213 }
214
215 Ok(SnapshotPublishResult {
216 version: version.to_string(),
217 snapshot_dir,
218 object_prefix,
219 indexed: !dry_run,
220 })
221}
222
223fn split_bucket(id: &str, seed: &str) -> u8 {
224 let value = format!("{id}|{seed}");
225 let hash = blake3::hash(value.as_bytes());
226 hash.as_bytes()[0] % 100
227}
228
229fn read_episode_records(path: &Path) -> Result<Vec<EpisodeRecord>> {
230 let mut out = Vec::new();
231 if path.is_file() {
232 parse_episode_file(path, &mut out)?;
233 return Ok(out);
234 }
235
236 if path.is_dir() {
237 for entry in ignore::WalkBuilder::new(path)
238 .hidden(false)
239 .git_ignore(false)
240 .build()
241 {
242 let entry = match entry {
243 Ok(v) => v,
244 Err(_) => continue,
245 };
246 if entry.file_type().map(|f| f.is_file()).unwrap_or(false) {
247 parse_episode_file(entry.path(), &mut out)?;
248 }
249 }
250 return Ok(out);
251 }
252
253 Ok(out)
254}
255
256fn parse_episode_file(path: &Path, out: &mut Vec<EpisodeRecord>) -> Result<()> {
257 let file = fs::File::open(path)?;
258 let reader = BufReader::new(file);
259 for line in reader.lines() {
260 let line = line?;
261 if line.trim().is_empty() {
262 continue;
263 }
264 if let Ok(ep) = serde_json::from_str::<EpisodeRecord>(&line) {
265 out.push(ep);
266 }
267 }
268 Ok(())
269}
270
271fn checksums(paths: &[&Path]) -> Result<String> {
272 let mut lines = Vec::new();
273 for path in paths {
274 let bytes = fs::read(path)?;
275 let mut hasher = Sha256::new();
276 hasher.update(bytes);
277 let digest = hasher.finalize();
278 let digest_hex = format!("{:x}", digest);
279 let name = path
280 .file_name()
281 .and_then(|s| s.to_str())
282 .unwrap_or("unknown");
283 lines.push(format!("{} {}", digest_hex, name));
284 }
285 lines.sort();
286 Ok(lines.join("\n") + "\n")
287}
288
289fn resolve_snapshot_dir(version: &str, snapshot_path: &Path) -> Result<PathBuf> {
290 let expected_name = format!("dataset-{version}");
291 if snapshot_path
292 .file_name()
293 .and_then(|v| v.to_str())
294 .map(|v| v == expected_name)
295 .unwrap_or(false)
296 {
297 if snapshot_path.is_dir() {
298 return Ok(snapshot_path.to_path_buf());
299 }
300 bail!(
301 "snapshot path is not a directory: {}",
302 snapshot_path.display()
303 );
304 }
305
306 let candidate = snapshot_path.join(expected_name);
307 if candidate.is_dir() {
308 return Ok(candidate);
309 }
310
311 bail!(
312 "snapshot directory not found: expected {} or {}",
313 snapshot_path.display(),
314 candidate.display()
315 )
316}
317
318fn required_snapshot_files(snapshot_dir: &Path) -> Result<Vec<PathBuf>> {
319 let names = [
320 "train.jsonl.zst",
321 "val.jsonl.zst",
322 "sft.jsonl.zst",
323 "tooltrace.jsonl.zst",
324 "manifest.json",
325 "CHECKSUMS.txt",
326 "DATA_CARD.md",
327 ];
328 let mut out = Vec::new();
329 for name in names {
330 let path = snapshot_dir.join(name);
331 if !path.is_file() {
332 bail!("missing required snapshot artifact: {}", path.display());
333 }
334 out.push(path);
335 }
336 Ok(out)
337}
338
/// Uploads the snapshot metadata (manifest, checksums, file entries, data
/// card) to the worker's `POST /v1/snapshots` endpoint.
///
/// Retries transient failures — 429/5xx responses and timeout/connect/
/// request transport errors — up to 5 total attempts, with exponential
/// backoff (2^attempt * 200ms) plus 50–300ms of random jitter.
///
/// # Errors
/// Fails when the worker base URL is unconfigured, when the response is a
/// non-retryable error status, or when retries are exhausted.
async fn publish_snapshot_to_worker(
    config: &AppConfig,
    version: &str,
    manifest: &serde_json::Value,
    checksums: &str,
    files: &[SnapshotFileEntry],
    data_card: &str,
) -> Result<WorkerSnapshotPublishResponse> {
    let base_url = config
        .worker
        .base_url
        .as_ref()
        .context("missing TRACE_SHARE_WORKER_BASE_URL")?;
    let endpoint = format!("{}/v1/snapshots", base_url.trim_end_matches('/'));
    // Per-request timeout; floor of 5s guards against a misconfigured zero.
    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(
            config.worker.timeout_seconds.max(5),
        ))
        .build()?;

    let payload = serde_json::json!({
        "version": version,
        "created_at": Utc::now().to_rfc3339(),
        "manifest": manifest,
        "checksums": checksums,
        "files": files,
        "data_card": data_card,
    });

    let mut attempt: u32 = 0;
    loop {
        // The request builder is not reusable across sends, so rebuild it
        // each attempt (bearer auth only when a token is configured).
        let mut req = client.post(&endpoint).json(&payload);
        if let Some(token) = config.worker.api_token.as_ref() {
            req = req.bearer_auth(token);
        }

        let res = req.send().await;
        match res {
            Ok(resp) => {
                let status = resp.status();
                if status.is_success() {
                    return Ok(resp.json::<WorkerSnapshotPublishResponse>().await?);
                }
                let body = resp.text().await.unwrap_or_default();
                // Bail immediately on non-retryable statuses, or once the
                // 5th attempt (attempt index 4) has failed.
                if !should_retry_status(status) || attempt >= 4 {
                    anyhow::bail!(
                        "worker snapshot publish failed: status={} body={}",
                        status,
                        body
                    );
                }
            }
            Err(e) => {
                let retryable_transport = e.is_timeout() || e.is_connect() || e.is_request();
                if !retryable_transport || attempt >= 4 {
                    return Err(e).context("worker snapshot publish request failed after retries");
                }
            }
        }

        // Exponential backoff with jitter before the next attempt.
        attempt += 1;
        let jitter: u64 = thread_rng().gen_range(50..300);
        let wait_ms = (2u64.pow(attempt) * 200) + jitter;
        sleep(Duration::from_millis(wait_ms)).await;
    }
}
405
406async fn index_snapshot_pointer(
407 config: &AppConfig,
408 version: &str,
409 manifest: &serde_json::Value,
410 object_prefix: Option<&str>,
411) -> Result<()> {
412 let rest_url = config
413 .upstash
414 .rest_url
415 .as_ref()
416 .context("missing UPSTASH_VECTOR_REST_URL")?;
417 let token = config
418 .upstash
419 .rest_token
420 .as_ref()
421 .context("missing UPSTASH_VECTOR_REST_TOKEN")?;
422 let endpoint = format!("{}/upsert-data", rest_url.trim_end_matches('/'));
423 let client = reqwest::Client::new();
424
425 let payload = serde_json::json!({
426 "vectors": [
427 {
428 "id": format!("snapshot:{version}"),
429 "data": format!("trace-share dataset snapshot {version}"),
430 "metadata": {
431 "kind": "dataset_snapshot",
432 "snapshot_version": version,
433 "total_records": manifest.get("total_records"),
434 "train_count": manifest.get("train_count"),
435 "val_count": manifest.get("val_count"),
436 "pointer": {
437 "storage": "r2",
438 "object_key": object_prefix,
439 "snapshot_version": version
440 }
441 }
442 }
443 ]
444 });
445
446 let mut attempt: u32 = 0;
447 loop {
448 let res = client
449 .post(&endpoint)
450 .bearer_auth(token)
451 .json(&payload)
452 .send()
453 .await;
454 match res {
455 Ok(resp) => {
456 let status = resp.status();
457 if status.is_success() {
458 return Ok(());
459 }
460 let body = resp.text().await.unwrap_or_default();
461 if !should_retry_status(status) || attempt >= 4 {
462 anyhow::bail!(
463 "upstash snapshot pointer index failed: status={} body={}",
464 status,
465 body
466 );
467 }
468 }
469 Err(e) => {
470 let retryable_transport = e.is_timeout() || e.is_connect() || e.is_request();
471 if !retryable_transport || attempt >= 4 {
472 return Err(e).context("upstash snapshot pointer request failed after retries");
473 }
474 }
475 }
476
477 attempt += 1;
478 let jitter: u64 = thread_rng().gen_range(50..300);
479 let wait_ms = (2u64.pow(attempt) * 200) + jitter;
480 sleep(Duration::from_millis(wait_ms)).await;
481 }
482}
483
484fn should_retry_status(status: reqwest::StatusCode) -> bool {
485 status == reqwest::StatusCode::TOO_MANY_REQUESTS || status.is_server_error()
486}
487
488fn sha256_file(path: &Path) -> Result<String> {
489 let bytes = fs::read(path)?;
490 let mut hasher = Sha256::new();
491 hasher.update(bytes);
492 Ok(format!("{:x}", hasher.finalize()))
493}