xet_data/processing/migration_tool/
migrate.rs1use std::sync::Arc;
2
3use http::header;
4use tracing::{Instrument, Span, info_span, instrument};
5use xet_client::cas_client::auth::TokenRefresher;
6use xet_client::hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo};
7use xet_core_structures::metadata_shard::file_structs::MDBFileInfo;
8use xet_runtime::core::XetRuntime;
9use xet_runtime::core::par_utils::run_constrained;
10
11use super::super::data_client::{clean_file, default_config};
12use super::super::{FileUploadSession, Sha256Policy, XetFileInfo};
13use super::hub_client_token_refresher::HubClientTokenRefresher;
14use crate::error::{DataError, Result};
15
/// `User-Agent` value (`<crate name>/<crate version>`) built at compile time from Cargo's
/// `CARGO_PKG_NAME` and `CARGO_PKG_VERSION`. It is attached to the headers given to the
/// `HubClient` and to the CAS upload configuration in this module.
const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));

18pub async fn migrate_with_external_runtime(
29 file_paths: Vec<String>,
30 sha256s: Option<Vec<String>>,
31 hub_endpoint: &str,
32 cas_endpoint: Option<String>,
33 hub_token: &str,
34 repo_type: &str,
35 repo_id: &str,
36) -> Result<()> {
37 let cred_helper = BearerCredentialHelper::new(hub_token.to_owned(), "");
38 let mut headers = header::HeaderMap::new();
39 headers.insert(header::USER_AGENT, header::HeaderValue::from_static(USER_AGENT));
40 let hub_client = HubClient::new(
41 hub_endpoint,
42 RepoInfo::try_from(repo_type, repo_id)?,
43 Some("main".to_owned()),
44 "",
45 Some(cred_helper),
46 Some(headers),
47 )?;
48
49 migrate_files_impl(file_paths, sha256s, false, hub_client, cas_endpoint, false).await?;
50
51 Ok(())
52}
53
/// Outcome of [`migrate_files_impl`], as a tuple of:
///
/// 1. The [`MDBFileInfo`] records obtained from `finalize_with_file_info`. These are populated
///    only on a dry run; a real upload returns an empty vector.
/// 2. One `(XetFileInfo, u64)` pair per cleaned file, holding the file info from `clean_file`
///    and the `new_bytes` value from that file's cleaning metrics.
/// 3. The session's `total_bytes_uploaded`, as reported at finalization.
pub type MigrationInfo = (Vec<MDBFileInfo>, Vec<(XetFileInfo, u64)>, u64);

57#[instrument(skip_all, name = "migrate_files", fields(session_id = tracing::field::Empty, num_files = file_paths.len()))]
58pub async fn migrate_files_impl(
59 file_paths: Vec<String>,
60 sha256s: Option<Vec<String>>,
61 sequential: bool,
62 hub_client: HubClient,
63 cas_endpoint: Option<String>,
64 dry_run: bool,
65) -> Result<MigrationInfo> {
66 let operation = Operation::Upload;
67 let jwt_info = hub_client.get_cas_jwt(operation).await?;
68 let token_refresher = Arc::new(HubClientTokenRefresher {
69 operation,
70 client: Arc::new(hub_client),
71 }) as Arc<dyn TokenRefresher>;
72 let cas = cas_endpoint.unwrap_or(jwt_info.cas_url);
73
74 let mut headers = http::HeaderMap::new();
76 headers.insert(http::header::USER_AGENT, http::HeaderValue::from_static(USER_AGENT));
77
78 let config = default_config(
79 cas,
80 Some((jwt_info.access_token, jwt_info.exp)),
81 Some(token_refresher),
82 Some(Arc::new(headers)),
83 )?;
84 Span::current().record("session_id", &config.session.session_id);
85
86 let num_workers = if sequential {
87 1
88 } else {
89 XetRuntime::current().num_worker_threads()
90 };
91 let processor = if dry_run {
92 FileUploadSession::dry_run(config.into()).await?
93 } else {
94 FileUploadSession::new(config.into()).await?
95 };
96
97 let sha256_policies: Vec<Sha256Policy> = match sha256s {
98 Some(v) => {
99 if v.len() != file_paths.len() {
100 return Err(DataError::ParameterError(
101 "mismatched length of the file list and the sha256 list".to_string(),
102 ));
103 }
104 v.iter().map(|s| Sha256Policy::from_hex(s)).collect()
105 },
106 None => vec![Sha256Policy::Compute; file_paths.len()],
107 };
108
109 let clean_futs = file_paths.into_iter().zip(sha256_policies).map(|(file_path, policy)| {
110 let proc = processor.clone();
111 async move {
112 let (pf, metrics) = clean_file(proc, file_path, policy).await?;
113 Ok::<(XetFileInfo, u64), DataError>((pf, metrics.new_bytes))
114 }
115 .instrument(info_span!("clean_file"))
116 });
117 let clean_ret = run_constrained(clean_futs, num_workers).await?;
118
119 if dry_run {
120 let (metrics, all_file_info) = processor.finalize_with_file_info().await?;
121 Ok((all_file_info, clean_ret, metrics.total_bytes_uploaded))
122 } else {
123 let metrics = processor.finalize().await?;
124 Ok((vec![], clean_ret, metrics.total_bytes_uploaded as u64))
125 }
126}