Skip to main content

xet_data/processing/migration_tool/
migrate.rs

1use std::sync::Arc;
2
3use http::header;
4use tracing::{Instrument, Span, info_span, instrument};
5use xet_client::cas_client::auth::TokenRefresher;
6use xet_client::hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo};
7use xet_core_structures::metadata_shard::file_structs::MDBFileInfo;
8use xet_runtime::core::XetRuntime;
9use xet_runtime::core::par_utils::run_constrained;
10
11use super::super::data_client::{clean_file, default_config};
12use super::super::{FileUploadSession, Sha256Policy, XetFileInfo};
13use super::hub_client_token_refresher::HubClientTokenRefresher;
14use crate::error::{DataError, Result};
15
16const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
17
18/// Migrate files to the Hub with external async runtime.
19/// How to use:
20/// ```no_run
21/// let file_paths = vec!["/path/to/file1".to_string(), "/path/to/file2".to_string()];
22/// let hub_endpoint = "https://huggingface.co";
23/// let hub_token = "your_token";
24/// let repo_type = "model";
25/// let repo_id = "your_repo_id";
26/// migrate_with_external_runtime(file_paths, hub_endpoint, hub_token, repo_type, repo_id).await?;
27/// ```
28pub async fn migrate_with_external_runtime(
29    file_paths: Vec<String>,
30    sha256s: Option<Vec<String>>,
31    hub_endpoint: &str,
32    cas_endpoint: Option<String>,
33    hub_token: &str,
34    repo_type: &str,
35    repo_id: &str,
36) -> Result<()> {
37    let cred_helper = BearerCredentialHelper::new(hub_token.to_owned(), "");
38    let mut headers = header::HeaderMap::new();
39    headers.insert(header::USER_AGENT, header::HeaderValue::from_static(USER_AGENT));
40    let hub_client = HubClient::new(
41        hub_endpoint,
42        RepoInfo::try_from(repo_type, repo_id)?,
43        Some("main".to_owned()),
44        "",
45        Some(cred_helper),
46        Some(headers),
47    )?;
48
49    migrate_files_impl(file_paths, sha256s, false, hub_client, cas_endpoint, false).await?;
50
51    Ok(())
52}
53
/// Result of [`migrate_files_impl`], as a tuple of:
///
/// - `.0`: the MDB file info reported by the upload session. It is only populated for dry runs
///   and is an empty vector otherwise.
/// - `.1`: one entry per input file. Each entry holds the file's cleaned [`XetFileInfo`] and the
///   `new_bytes` count from that file's clean metrics.
/// - `.2`: the session's `total_bytes_uploaded` metric, read after finalization.
pub type MigrationInfo = (Vec<MDBFileInfo>, Vec<(XetFileInfo, u64)>, u64);
56
57#[instrument(skip_all, name = "migrate_files", fields(session_id = tracing::field::Empty, num_files = file_paths.len()))]
58pub async fn migrate_files_impl(
59    file_paths: Vec<String>,
60    sha256s: Option<Vec<String>>,
61    sequential: bool,
62    hub_client: HubClient,
63    cas_endpoint: Option<String>,
64    dry_run: bool,
65) -> Result<MigrationInfo> {
66    let operation = Operation::Upload;
67    let jwt_info = hub_client.get_cas_jwt(operation).await?;
68    let token_refresher = Arc::new(HubClientTokenRefresher {
69        operation,
70        client: Arc::new(hub_client),
71    }) as Arc<dyn TokenRefresher>;
72    let cas = cas_endpoint.unwrap_or(jwt_info.cas_url);
73
74    // Create headers with USER_AGENT
75    let mut headers = http::HeaderMap::new();
76    headers.insert(http::header::USER_AGENT, http::HeaderValue::from_static(USER_AGENT));
77
78    let config = default_config(
79        cas,
80        Some((jwt_info.access_token, jwt_info.exp)),
81        Some(token_refresher),
82        Some(Arc::new(headers)),
83    )?;
84    Span::current().record("session_id", &config.session.session_id);
85
86    let num_workers = if sequential {
87        1
88    } else {
89        XetRuntime::current().num_worker_threads()
90    };
91    let processor = if dry_run {
92        FileUploadSession::dry_run(config.into()).await?
93    } else {
94        FileUploadSession::new(config.into()).await?
95    };
96
97    let sha256_policies: Vec<Sha256Policy> = match sha256s {
98        Some(v) => {
99            if v.len() != file_paths.len() {
100                return Err(DataError::ParameterError(
101                    "mismatched length of the file list and the sha256 list".to_string(),
102                ));
103            }
104            v.iter().map(|s| Sha256Policy::from_hex(s)).collect()
105        },
106        None => vec![Sha256Policy::Compute; file_paths.len()],
107    };
108
109    let clean_futs = file_paths.into_iter().zip(sha256_policies).map(|(file_path, policy)| {
110        let proc = processor.clone();
111        async move {
112            let (pf, metrics) = clean_file(proc, file_path, policy).await?;
113            Ok::<(XetFileInfo, u64), DataError>((pf, metrics.new_bytes))
114        }
115        .instrument(info_span!("clean_file"))
116    });
117    let clean_ret = run_constrained(clean_futs, num_workers).await?;
118
119    if dry_run {
120        let (metrics, all_file_info) = processor.finalize_with_file_info().await?;
121        Ok((all_file_info, clean_ret, metrics.total_bytes_uploaded))
122    } else {
123        let metrics = processor.finalize().await?;
124        Ok((vec![], clean_ret, metrics.total_bytes_uploaded as u64))
125    }
126}