1mod error;
68pub use error::{Result, YcbError};
69
70#[cfg(feature = "blocking")]
73pub mod blocking;
74
75use futures_util::stream::{self, StreamExt, TryStreamExt};
76use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
77use reqwest::Client;
78use serde::Deserialize;
79use std::fs::{self, File};
80use std::io::{BufWriter, Write};
81use std::path::Path;
82
/// Root URL of the YCB benchmark data bucket; [`get_tgz_url`] builds every archive URL under it.
pub const BASE_URL: &str = "https://ycb-benchmarks.s3.amazonaws.com/data/";

/// JSON index of all YCB object names, downloaded by [`fetch_objects`] to resolve
/// [`Subset::All`].
///
/// NOTE: this repeats the [`BASE_URL`] prefix; keep the two in sync if the host ever changes.
pub const OBJECTS_URL: &str = "https://ycb-benchmarks.s3.amazonaws.com/data/objects.json";

/// Path of the `google_16k` textured mesh relative to an object's directory
/// (joined by [`object_mesh_path`]).
pub const GOOGLE_16K_MESH_RELATIVE: &str = "google_16k/textured.obj";

/// Path of the `google_16k` texture image relative to an object's directory
/// (joined by [`object_texture_path`]).
pub const GOOGLE_16K_TEXTURE_RELATIVE: &str = "google_16k/texture_map.png";

/// Objects in [`Subset::Representative`], the default subset.
pub const REPRESENTATIVE_OBJECTS: &[&str] =
    &["003_cracker_box", "004_sugar_box", "005_tomato_soup_can"];

/// Objects in [`Subset::TbpStandard`].
pub const TBP_STANDARD_OBJECTS: &[&str] = &[
    "025_mug",
    "024_bowl",
    "010_potted_meat_can",
    "031_spoon",
    "012_strawberry",
    "006_mustard_bottle",
    "062_dice",
    "058_golf_ball",
    "073-c_lego_duplo",
    "011_banana",
];

/// Objects in [`Subset::TbpSimilar`].
pub const TBP_SIMILAR_OBJECTS: &[&str] = &[
    "003_cracker_box",
    "004_sugar_box",
    "009_gelatin_box",
    "021_bleach_cleanser",
    "036_wood_block",
    "039_key",
    "040_large_marker",
    "051_large_clamp",
    "052_extra_large_clamp",
    "061_foam_brick",
];
139
/// A named group of YCB objects to download.
///
/// The declaration order of the variants is significant: it defines the derived
/// `PartialOrd`/`Ord` ordering.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Subset {
    /// The objects in `REPRESENTATIVE_OBJECTS`; this is the default.
    #[default]
    Representative,
    /// The objects in `TBP_STANDARD_OBJECTS`.
    TbpStandard,
    /// The objects in `TBP_SIMILAR_OBJECTS`.
    TbpSimilar,
    /// Every object listed by the remote index at `OBJECTS_URL` (fetched at run time).
    All,
}
155
/// Knobs controlling how YCB archives are fetched and unpacked.
///
/// Build it with struct-update syntax on top of [`Default`], e.g.
/// `DownloadOptions { concurrency: 4, ..DownloadOptions::default() }`.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct DownloadOptions {
    /// Re-download and re-extract even when the extracted mesh or an earlier archive is
    /// already present.
    pub overwrite: bool,
    /// Also fetch the `berkeley_processed` meshes, not only `google_16k`.
    pub full: bool,
    /// Render indicatif progress bars while downloading.
    pub show_progress: bool,
    /// Remove each `.tgz` archive once it has been extracted.
    pub delete_archives: bool,
    /// Maximum number of archives processed at once; values below 1 behave like 1.
    pub concurrency: usize,
    /// Before reusing an archive left by an earlier run, compare its size with the
    /// server-reported size (one HEAD request) and re-download on mismatch.
    pub verify_integrity: bool,
}

impl Default for DownloadOptions {
    /// Sequential, verified downloads with progress bars, deleting archives after extraction
    /// and fetching only the `google_16k` files.
    fn default() -> Self {
        DownloadOptions {
            concurrency: 1,
            show_progress: true,
            delete_archives: true,
            verify_integrity: true,
            overwrite: false,
            full: false,
        }
    }
}
201
/// JSON body served at [`OBJECTS_URL`]; only its `objects` array of object names is read.
#[derive(Deserialize, Debug)]
struct ObjectsResponse {
    /// Object directory names, e.g. `"003_cracker_box"`.
    objects: Vec<String>,
}
207
208pub(crate) async fn selected_objects_for_subset(
209 subset: Subset,
210 client: &Client,
211) -> Result<Vec<String>> {
212 match get_subset_objects(subset) {
213 Some(objects) => Ok(objects),
214 None => fetch_objects(client).await,
215 }
216}
217
/// File types (archive kinds) to fetch for every object.
///
/// `google_16k` is always included; a `full` download additionally fetches the
/// `berkeley_processed` meshes, listed first.
fn download_file_types(full: bool) -> &'static [&'static str] {
    const GOOGLE_ONLY: &[&str] = &["google_16k"];
    const WITH_BERKELEY: &[&str] = &["berkeley_processed", "google_16k"];

    match full {
        true => WITH_BERKELEY,
        false => GOOGLE_ONLY,
    }
}
225
226fn local_artifact_exists(output_dir: &Path, object: &str, file_type: &str) -> bool {
227 match file_type {
228 "google_16k" => object_mesh_path(output_dir, object).exists(),
229 _ => false,
230 }
231}
232
233pub async fn fetch_objects(client: &Client) -> Result<Vec<String>> {
250 let response = client.get(OBJECTS_URL).send().await?;
251 let status = response.status();
252 if !status.is_success() {
253 return Err(YcbError::HttpStatus {
254 status: status.as_u16(),
255 url: OBJECTS_URL.to_string(),
256 });
257 }
258 let body = response.text().await?;
259 let objects_response: ObjectsResponse = serde_json::from_str(&body)
260 .map_err(|e| YcbError::InvalidResponse(format!("YCB objects index: {e}")))?;
261 Ok(objects_response.objects)
262}
263
264pub fn get_tgz_url(object: &str, file_type: &str) -> String {
280 if file_type == "berkeley_rgbd" || file_type == "berkeley_rgb_highres" {
281 format!(
282 "{}berkeley/{}/{}_{}.tgz",
283 BASE_URL, object, object, file_type
284 )
285 } else if file_type == "berkeley_processed" {
286 format!(
287 "{}berkeley/{}/{}_berkeley_meshes.tgz",
288 BASE_URL, object, object
289 )
290 } else {
291 format!("{}google/{}_{}.tgz", BASE_URL, object, file_type)
292 }
293}
294
295pub async fn download_file(
320 client: &Client,
321 url: &str,
322 dest_path: &Path,
323 show_progress: bool,
324) -> Result<()> {
325 download_file_inner(client, url, dest_path, show_progress, None).await
326}
327
328async fn download_file_inner(
329 client: &Client,
330 url: &str,
331 dest_path: &Path,
332 show_progress: bool,
333 multi: Option<&MultiProgress>,
334) -> Result<()> {
335 let res = client.get(url).send().await?;
336 let status = res.status();
337 if !status.is_success() {
338 return Err(YcbError::HttpStatus {
339 status: status.as_u16(),
340 url: url.to_string(),
341 });
342 }
343 let total_size = res.content_length().unwrap_or(0);
344 let filename = dest_path
345 .file_name()
346 .map(|n| n.to_string_lossy().to_string())
347 .unwrap_or_else(|| "unknown".to_string());
348
349 let pb = if show_progress {
350 let pb = ProgressBar::new(total_size);
351 pb.set_style(
352 ProgressStyle::default_bar()
353 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta}) {msg}")
354 .expect("Invalid progress bar template - this is a bug")
355 .progress_chars("#>-"),
356 );
357 pb.set_message(format!("Downloading {}", filename));
358 Some(match multi {
359 Some(m) => m.add(pb),
360 None => pb,
361 })
362 } else {
363 None
364 };
365
366 let mut file = BufWriter::new(File::create(dest_path)?);
367 let mut stream = res.bytes_stream();
368
369 while let Some(item) = stream.next().await {
370 let chunk = item?;
371 file.write_all(&chunk)?;
372 if let Some(ref pb) = pb {
373 pb.inc(chunk.len() as u64);
374 }
375 }
376
377 file.flush()?;
378
379 if let Some(pb) = pb {
380 pb.finish_with_message("Done");
381 }
382 Ok(())
383}
384
385pub fn extract_tgz(tgz_path: &Path, output_dir: &Path, delete_archive: bool) -> Result<()> {
411 let tgz_str = tgz_path.display().to_string();
412 fs::create_dir_all(output_dir)?;
413 let canonical_output = output_dir
414 .canonicalize()
415 .unwrap_or_else(|_| output_dir.to_path_buf());
416
417 let tar_gz = File::open(tgz_path)?;
418 let tar = flate2::read::GzDecoder::new(tar_gz);
419 let mut archive = tar::Archive::new(tar);
420
421 let entries = archive
422 .entries()
423 .map_err(|e| YcbError::extraction(&tgz_str, e))?;
424
425 for entry in entries {
426 let mut entry = entry.map_err(|e| YcbError::extraction(&tgz_str, e))?;
427 let path = entry
428 .path()
429 .map_err(|e| YcbError::extraction(&tgz_str, e))?
430 .to_path_buf();
431
432 if path
433 .components()
434 .any(|c| matches!(c, std::path::Component::ParentDir))
435 {
436 return Err(YcbError::UnsafeArchive(format!(
437 "archive entry contains '..': {}",
438 path.display()
439 )));
440 }
441
442 let dest = output_dir.join(&path);
443
444 if let Ok(canonical_dest) = dest.canonicalize() {
445 if !canonical_dest.starts_with(&canonical_output) {
446 return Err(YcbError::UnsafeArchive(format!(
447 "archive entry escapes output dir: {}",
448 dest.display()
449 )));
450 }
451 }
452
453 if let Some(parent) = dest.parent() {
454 fs::create_dir_all(parent)?;
455 }
456
457 entry
458 .unpack(&dest)
459 .map_err(|e| YcbError::extraction(&tgz_str, e))?;
460 }
461
462 if delete_archive {
463 fs::remove_file(tgz_path)?;
464 }
465 Ok(())
466}
467
468pub async fn url_exists(client: &Client, url: &str) -> Result<bool> {
479 let response = client.head(url).send().await?;
480 Ok(response.status().is_success())
481}
482
483async fn fetch_content_length(client: &Client, url: &str) -> Result<Option<u64>> {
489 let response = client.head(url).send().await?;
490 if !response.status().is_success() {
491 return Ok(None);
492 }
493 Ok(response.content_length())
494}
495
496pub async fn download_ycb(
524 subset: Subset,
525 output_dir: &Path,
526 options: DownloadOptions,
527) -> Result<()> {
528 let client = Client::new();
529 let selected_objects = selected_objects_for_subset(subset, &client).await?;
530 let refs: Vec<&str> = selected_objects.iter().map(String::as_str).collect();
531 download_objects(&refs, output_dir, options).await
532}
533
534async fn process_work_item(
535 client: &Client,
536 output_dir: &Path,
537 options: &DownloadOptions,
538 multi: Option<&MultiProgress>,
539 object: &str,
540 file_type: &'static str,
541) -> Result<()> {
542 if !options.overwrite && local_artifact_exists(output_dir, object, file_type) {
547 return Ok(());
548 }
549
550 let filename = format!("{}_{}.tgz", object, file_type);
551 let dest_path = output_dir.join(&filename);
552 let url = get_tgz_url(object, file_type);
553
554 let mut have_valid_archive = false;
555 if !options.overwrite && dest_path.exists() {
556 if options.verify_integrity {
557 match fetch_content_length(client, &url).await? {
558 Some(expected) => {
559 let actual = std::fs::metadata(&dest_path)?.len();
560 if actual == expected {
561 have_valid_archive = true;
562 } else {
563 let _ = std::fs::remove_file(&dest_path);
565 }
566 }
567 None => {
568 have_valid_archive = true;
569 }
570 }
571 } else {
572 have_valid_archive = true;
573 }
574 }
575
576 if !options.overwrite && have_valid_archive {
577 return Ok(());
578 }
579
580 match download_file_inner(client, &url, &dest_path, options.show_progress, multi).await {
581 Ok(()) => {}
582 Err(YcbError::HttpStatus { status: 404, .. }) => return Ok(()),
583 Err(err) => return Err(err),
584 }
585
586 extract_tgz(&dest_path, output_dir, options.delete_archives)?;
587 Ok(())
588}
589
590pub async fn download_objects(
618 objects: &[&str],
619 output_dir: &Path,
620 options: DownloadOptions,
621) -> Result<()> {
622 if objects.is_empty() {
623 return Ok(());
624 }
625
626 let client = Client::new();
627 fs::create_dir_all(output_dir).map_err(YcbError::Io)?;
628
629 let file_types = download_file_types(options.full);
630 let concurrency = options.concurrency.max(1);
631 let multi = if options.show_progress && concurrency > 1 {
632 Some(MultiProgress::new())
633 } else {
634 None
635 };
636
637 let work: Vec<(&str, &'static str)> = objects
638 .iter()
639 .flat_map(|o| file_types.iter().map(move |ft| (*o, *ft)))
640 .collect();
641
642 stream::iter(work)
643 .map(|(object, file_type)| {
644 let client = &client;
645 let multi = multi.as_ref();
646 let options = &options;
647 async move {
648 process_work_item(client, output_dir, options, multi, object, file_type).await
649 }
650 })
651 .buffer_unordered(concurrency)
652 .try_for_each(|_| async { Ok::<(), YcbError>(()) })
653 .await?;
654
655 Ok(())
656}
657
658pub fn get_subset_objects(subset: Subset) -> Option<Vec<String>> {
678 match subset {
679 Subset::Representative => Some(
680 REPRESENTATIVE_OBJECTS
681 .iter()
682 .map(|s| s.to_string())
683 .collect(),
684 ),
685 Subset::TbpStandard => Some(TBP_STANDARD_OBJECTS.iter().map(|s| s.to_string()).collect()),
686 Subset::TbpSimilar => Some(TBP_SIMILAR_OBJECTS.iter().map(|s| s.to_string()).collect()),
687 Subset::All => None,
688 }
689}
690
691pub fn object_mesh_path(ycb_dir: &Path, object: &str) -> std::path::PathBuf {
707 ycb_dir.join(object).join(GOOGLE_16K_MESH_RELATIVE)
708}
709
710pub fn object_texture_path(ycb_dir: &Path, object: &str) -> std::path::PathBuf {
712 ycb_dir.join(object).join(GOOGLE_16K_TEXTURE_RELATIVE)
713}
714
/// Presence report for one object's `google_16k` files, as produced by `validate_objects`.
#[derive(Clone, Debug)]
pub struct ObjectValidation {
    /// Object name exactly as it was passed in.
    pub name: String,
    /// Whether the textured mesh file exists.
    pub mesh_present: bool,
    /// Whether the texture image file exists.
    pub texture_present: bool,
}

impl ObjectValidation {
    /// `true` only when both the mesh and the texture are present.
    pub fn is_complete(&self) -> bool {
        matches!((self.mesh_present, self.texture_present), (true, true))
    }
}
732
733pub fn validate_objects(ycb_dir: &Path, objects: &[&str]) -> Vec<ObjectValidation> {
747 objects
748 .iter()
749 .map(|name| ObjectValidation {
750 name: name.to_string(),
751 mesh_present: object_mesh_path(ycb_dir, name).exists(),
752 texture_present: object_texture_path(ycb_dir, name).exists(),
753 })
754 .collect()
755}
756
#[cfg(test)]
mod tests {
    use super::*;

    // --- get_tgz_url: one test per archive URL layout ---

    #[test]
    fn test_get_tgz_url_google_16k() {
        let url = get_tgz_url("003_cracker_box", "google_16k");
        assert_eq!(
            url,
            "https://ycb-benchmarks.s3.amazonaws.com/data/google/003_cracker_box_google_16k.tgz"
        );
    }

    #[test]
    fn test_get_tgz_url_berkeley_processed() {
        // `berkeley_processed` maps to the `_berkeley_meshes.tgz` archive name.
        let url = get_tgz_url("003_cracker_box", "berkeley_processed");
        assert_eq!(
            url,
            "https://ycb-benchmarks.s3.amazonaws.com/data/berkeley/003_cracker_box/003_cracker_box_berkeley_meshes.tgz"
        );
    }

    #[test]
    fn test_get_tgz_url_berkeley_rgbd() {
        let url = get_tgz_url("003_cracker_box", "berkeley_rgbd");
        assert_eq!(
            url,
            "https://ycb-benchmarks.s3.amazonaws.com/data/berkeley/003_cracker_box/003_cracker_box_berkeley_rgbd.tgz"
        );
    }

    #[test]
    fn test_get_tgz_url_berkeley_rgb_highres() {
        let url = get_tgz_url("003_cracker_box", "berkeley_rgb_highres");
        assert_eq!(
            url,
            "https://ycb-benchmarks.s3.amazonaws.com/data/berkeley/003_cracker_box/003_cracker_box_berkeley_rgb_highres.tgz"
        );
    }

    #[test]
    fn test_get_tgz_url_different_objects() {
        // The object name must be interpolated, not hard-coded.
        let url1 = get_tgz_url("004_sugar_box", "google_16k");
        assert!(url1.contains("004_sugar_box"));

        let url2 = get_tgz_url("005_tomato_soup_can", "google_16k");
        assert!(url2.contains("005_tomato_soup_can"));
    }

    // --- defaults and built-in subset lists ---

    #[test]
    fn test_subset_default() {
        let subset = Subset::default();
        assert_eq!(subset, Subset::Representative);
    }

    #[test]
    fn test_download_options_default() {
        let options = DownloadOptions::default();
        assert!(!options.overwrite);
        assert!(!options.full);
        assert!(options.show_progress);
        assert!(options.delete_archives);
        assert_eq!(options.concurrency, 1);
        assert!(options.verify_integrity);
    }

    #[test]
    fn test_get_subset_objects_representative() {
        let objects = get_subset_objects(Subset::Representative);
        assert_eq!(objects.unwrap().len(), 3);
    }

    #[test]
    fn test_get_subset_objects_tbp_standard() {
        let objects = get_subset_objects(Subset::TbpStandard);
        assert_eq!(objects.unwrap().len(), 10);
    }

    #[test]
    fn test_get_subset_objects_tbp_similar() {
        let objects = get_subset_objects(Subset::TbpSimilar);
        assert_eq!(objects.unwrap().len(), 10);
    }

    #[test]
    fn test_get_subset_objects_all() {
        // `All` has no built-in list; it is resolved from the remote index instead.
        let objects = get_subset_objects(Subset::All);
        assert!(objects.is_none());
    }

    // --- local filesystem helpers (no network) ---

    #[test]
    fn test_local_artifact_exists_for_google_16k_mesh() {
        let dir = tempfile::tempdir().unwrap();
        let mesh_path = object_mesh_path(dir.path(), "003_cracker_box");
        fs::create_dir_all(mesh_path.parent().unwrap()).unwrap();
        File::create(&mesh_path).unwrap();

        assert!(local_artifact_exists(
            dir.path(),
            "003_cracker_box",
            "google_16k"
        ));
        // Other file types have no detectable extracted artifact, so they always report false.
        assert!(!local_artifact_exists(
            dir.path(),
            "003_cracker_box",
            "berkeley_processed"
        ));
    }

    #[test]
    fn test_path_consts_compose_with_object_helpers() {
        let root = Path::new("ycb-root");
        let object = "006_mustard_bottle";

        assert_eq!(
            object_mesh_path(root, object),
            root.join(object).join(GOOGLE_16K_MESH_RELATIVE)
        );
        assert_eq!(
            object_texture_path(root, object),
            root.join(object).join(GOOGLE_16K_TEXTURE_RELATIVE)
        );
    }

    #[test]
    fn test_path_consts_have_expected_values() {
        assert_eq!(GOOGLE_16K_MESH_RELATIVE, "google_16k/textured.obj");
        assert_eq!(GOOGLE_16K_TEXTURE_RELATIVE, "google_16k/texture_map.png");
    }

    // --- download_objects paths that must complete without any network access ---

    #[tokio::test]
    async fn test_download_objects_empty_slice_is_noop() {
        let dir = tempfile::tempdir().unwrap();
        let result = download_objects(&[], dir.path(), DownloadOptions::default()).await;
        assert!(result.is_ok());
        let entries = fs::read_dir(dir.path()).unwrap().count();
        // The empty-slice early return happens before anything is created.
        assert_eq!(entries, 0);
    }

    #[tokio::test]
    async fn test_download_objects_skips_when_mesh_present() {
        let dir = tempfile::tempdir().unwrap();
        let mesh_path = object_mesh_path(dir.path(), "003_cracker_box");
        fs::create_dir_all(mesh_path.parent().unwrap()).unwrap();
        File::create(&mesh_path).unwrap();

        // With the mesh on disk and `overwrite` off, the work item returns before any request;
        // progress bars are disabled to keep test output quiet.
        let options = DownloadOptions {
            show_progress: false,
            ..DownloadOptions::default()
        };
        let result = download_objects(&["003_cracker_box"], dir.path(), options).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_download_objects_mesh_skip_bypasses_head_even_with_archive_present() {
        let dir = tempfile::tempdir().unwrap();
        // A junk archive plus a present mesh: if the integrity HEAD check ran, the size
        // mismatch would delete the archive. Its survival proves the mesh check comes first.
        let object = "003_cracker_box";

        let archive_path = dir.path().join(format!("{object}_google_16k.tgz"));
        // Deliberately not a valid tarball.
        let mut f = File::create(&archive_path).unwrap();
        f.write_all(b"not a real archive").unwrap();

        let mesh_path = object_mesh_path(dir.path(), object);
        // Mark the object as already extracted.
        fs::create_dir_all(mesh_path.parent().unwrap()).unwrap();
        File::create(&mesh_path).unwrap();

        let options = DownloadOptions {
            show_progress: false,
            verify_integrity: true,
            ..DownloadOptions::default()
        };
        let result = download_objects(&[object], dir.path(), options).await;
        assert!(result.is_ok());
        assert!(archive_path.exists());
    }

    #[tokio::test]
    async fn test_download_objects_concurrent_skips_when_all_meshes_present() {
        let dir = tempfile::tempdir().unwrap();
        // Every object already has its mesh, so the concurrent path must finish without any
        // network traffic.
        for object in TBP_STANDARD_OBJECTS {
            let mesh_path = object_mesh_path(dir.path(), object);
            fs::create_dir_all(mesh_path.parent().unwrap()).unwrap();
            File::create(&mesh_path).unwrap();
        }

        let options = DownloadOptions {
            show_progress: false,
            concurrency: 4,
            ..DownloadOptions::default()
        };
        let refs: Vec<&str> = TBP_STANDARD_OBJECTS.to_vec();
        let result = download_objects(&refs, dir.path(), options).await;
        assert!(result.is_ok());
    }

    // --- error type interoperability ---

    #[test]
    fn test_ycb_error_converts_to_anyhow() {
        // Converting into anyhow requires YcbError to be a std Error (+ Send + Sync + 'static);
        // the status code must appear in its Display output.
        let y = YcbError::HttpStatus {
            status: 404,
            url: "https://example.com".into(),
        };
        let a: anyhow::Error = y.into();
        assert!(a.to_string().contains("404"));
    }

    #[cfg(feature = "blocking")]
    #[test]
    fn test_blocking_download_objects_empty_slice() {
        let dir = tempfile::tempdir().unwrap();
        let result =
            crate::blocking::download_objects_blocking(&[], dir.path(), DownloadOptions::default());
        assert!(result.is_ok());
    }
}