1use std::cmp::Ordering;
2use std::collections::BTreeMap;
3use std::fmt::Write as _;
4use std::fs::File;
5use std::io;
6use std::io::Read;
7use std::path::{Path, PathBuf};
8
9use memmap2::MmapOptions;
10use reqwest::blocking::{Client, Response};
11use reqwest::header::{ACCEPT_ENCODING, AUTHORIZATION, CONTENT_RANGE, RANGE};
12use reqwest::{StatusCode, Url};
13use safetensors::{tensor::Metadata, SafeTensorError, SafeTensors};
14
/// Number of bytes in the little-endian `u64` length prefix that is fetched
/// first from a remote file and that announces the size of the JSON header.
const HEADER_PREFIX_LEN: u64 = std::mem::size_of::<u64>() as u64;

/// Largest JSON header, in bytes, that a remote file may declare before the
/// inspection is aborted with `HeaderTooLarge`.
const MAX_HEADER_SIZE: usize = 100_000_000;
17
18#[derive(Debug)]
19pub enum InspectError {
20 FileNotFound(PathBuf),
21 CannotRead {
22 path: PathBuf,
23 source: io::Error,
24 },
25 InvalidSafetensors {
26 path: String,
27 source: SafeTensorError,
28 },
29 Overflow {
30 tensor: String,
31 },
32 MissingTensorInfo {
33 tensor: String,
34 },
35 UnsupportedUrlScheme(String),
36 HttpClient(reqwest::Error),
37 HttpRequest {
38 url: String,
39 source: reqwest::Error,
40 },
41 RangeUnsupported(String),
42 InvalidRemoteResponse {
43 url: String,
44 reason: String,
45 },
46}
47
48impl std::fmt::Display for InspectError {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 match self {
51 Self::FileNotFound(path) => write!(f, "file not found: {}", path.display()),
52 Self::CannotRead { path, source } => {
53 write!(f, "failed to read file: {} ({source})", path.display())
54 }
55 Self::InvalidSafetensors { path, source } => {
56 write!(f, "invalid safetensors file: {path} ({source})")
57 }
58 Self::Overflow { tensor } => {
59 write!(f, "tensor is too large to summarize safely: {tensor}")
60 }
61 Self::MissingTensorInfo { tensor } => {
62 write!(f, "missing tensor metadata for: {tensor}")
63 }
64 Self::UnsupportedUrlScheme(scheme) => {
65 write!(f, "unsupported URL scheme: {scheme}")
66 }
67 Self::HttpClient(source) => write!(f, "failed to initialize HTTP client ({source})"),
68 Self::HttpRequest { url, source } => {
69 write!(f, "failed to fetch remote file: {url} ({source})")
70 }
71 Self::RangeUnsupported(url) => {
72 write!(
73 f,
74 "remote server does not support byte range requests: {url}"
75 )
76 }
77 Self::InvalidRemoteResponse { url, reason } => {
78 write!(f, "invalid remote response for {url} ({reason})")
79 }
80 }
81 }
82}
83
84impl std::error::Error for InspectError {}
85
86#[derive(Debug)]
87pub struct Report {
88 file_path: String,
89 file_size: u64,
90 tensor_count: usize,
91 total_parameters: u128,
92 total_tensor_bytes: u128,
93 metadata: Vec<(String, String)>,
94 dtype_breakdown: Vec<(String, u128)>,
95 tensors: Vec<TensorSummary>,
96}
97
/// Details of a single tensor as declared in the file header.
#[derive(Debug)]
pub struct TensorSummary {
    /// Tensor name as it appears in the header.
    name: String,
    /// Textual name of the element type.
    dtype: String,
    /// Dimensions of the tensor; empty for a scalar.
    shape: Vec<usize>,
    /// Number of elements, i.e. the product of all dimensions.
    numel: u128,
    /// Length in bytes of the tensor's data range.
    bytes: u128,
}
106
107pub fn inspect_input(input: &str) -> Result<Report, InspectError> {
108 match classify_input(input)? {
109 Input::LocalPath(path) => inspect_local_file(input, path),
110 Input::HttpUrl(url) => inspect_remote_file(input, &url),
111 }
112}
113
114pub fn render_report(report: &Report) -> String {
115 let mut output = String::new();
116
117 writeln!(&mut output, "File: {}", report.file_path).unwrap();
118 writeln!(&mut output, "Size: {} bytes", report.file_size).unwrap();
119 writeln!(&mut output, "Tensors: {}", report.tensor_count).unwrap();
120 writeln!(&mut output, "Parameters: {}", report.total_parameters).unwrap();
121 writeln!(&mut output, "Tensor-Bytes: {}", report.total_tensor_bytes).unwrap();
122 writeln!(&mut output).unwrap();
123
124 writeln!(&mut output, "Metadata:").unwrap();
125 if report.metadata.is_empty() {
126 writeln!(&mut output, " (none)").unwrap();
127 } else {
128 for (key, value) in &report.metadata {
129 writeln!(&mut output, " {key} = {value}").unwrap();
130 }
131 }
132 writeln!(&mut output).unwrap();
133
134 writeln!(&mut output, "DType Breakdown:").unwrap();
135 if report.dtype_breakdown.is_empty() {
136 writeln!(&mut output, " (none)").unwrap();
137 } else {
138 for (dtype, bytes) in &report.dtype_breakdown {
139 writeln!(&mut output, " {dtype}: {bytes} bytes").unwrap();
140 }
141 }
142 writeln!(&mut output).unwrap();
143
144 writeln!(&mut output, "Tensors:").unwrap();
145 if report.tensors.is_empty() {
146 writeln!(&mut output, " (none)").unwrap();
147 return output;
148 }
149
150 for (index, tensor) in report.tensors.iter().enumerate() {
151 if index > 0 {
152 writeln!(&mut output).unwrap();
153 }
154
155 writeln!(&mut output, " {}", tensor.name).unwrap();
156 writeln!(&mut output, " dtype: {}", tensor.dtype).unwrap();
157 writeln!(&mut output, " shape: {}", format_shape(&tensor.shape)).unwrap();
158 writeln!(&mut output, " numel: {}", tensor.numel).unwrap();
159 writeln!(&mut output, " bytes: {}", tensor.bytes).unwrap();
160 }
161
162 output
163}
164
165#[derive(Debug)]
166enum Input<'a> {
167 LocalPath(&'a Path),
168 HttpUrl(Url),
169}
170
171fn classify_input(input: &str) -> Result<Input<'_>, InspectError> {
172 if !input.contains("://") {
173 return Ok(Input::LocalPath(Path::new(input)));
174 }
175
176 match Url::parse(input) {
177 Ok(url) => match url.scheme() {
178 "http" | "https" => Ok(Input::HttpUrl(url)),
179 scheme => Err(InspectError::UnsupportedUrlScheme(scheme.to_owned())),
180 },
181 Err(_) => Err(InspectError::InvalidRemoteResponse {
182 url: input.to_owned(),
183 reason: "malformed URL".to_owned(),
184 }),
185 }
186}
187
188fn inspect_local_file(input: &str, path: &Path) -> Result<Report, InspectError> {
189 let file = File::open(path).map_err(|source| match source.kind() {
190 io::ErrorKind::NotFound => InspectError::FileNotFound(path.to_path_buf()),
191 _ => InspectError::CannotRead {
192 path: path.to_path_buf(),
193 source,
194 },
195 })?;
196
197 let file_size = file
198 .metadata()
199 .map_err(|source| InspectError::CannotRead {
200 path: path.to_path_buf(),
201 source,
202 })?
203 .len();
204
205 let mmap =
206 unsafe { MmapOptions::new().map(&file) }.map_err(|source| InspectError::CannotRead {
207 path: path.to_path_buf(),
208 source,
209 })?;
210
211 let (_, metadata) =
212 SafeTensors::read_metadata(&mmap).map_err(|source| InspectError::InvalidSafetensors {
213 path: input.to_owned(),
214 source,
215 })?;
216
217 build_report(input, file_size, &metadata)
218}
219
220fn inspect_remote_file(input: &str, url: &Url) -> Result<Report, InspectError> {
221 let client = build_http_client()?;
222 let (file_size, header_len) = fetch_header_prefix(&client, url)?;
223 if header_len > MAX_HEADER_SIZE {
224 return Err(InspectError::InvalidSafetensors {
225 path: input.to_owned(),
226 source: SafeTensorError::HeaderTooLarge,
227 });
228 }
229
230 let header_bytes = fetch_header_bytes(&client, url, header_len)?;
231 let metadata: Metadata = serde_json::from_slice(&header_bytes).map_err(|source| {
232 InspectError::InvalidSafetensors {
233 path: input.to_owned(),
234 source: SafeTensorError::InvalidHeaderDeserialization(source),
235 }
236 })?;
237
238 let expected_size = HEADER_PREFIX_LEN
239 .checked_add(header_len as u64)
240 .and_then(|value| value.checked_add(metadata.data_len() as u64))
241 .ok_or_else(|| InspectError::InvalidSafetensors {
242 path: input.to_owned(),
243 source: SafeTensorError::ValidationOverflow,
244 })?;
245
246 if expected_size != file_size {
247 return Err(InspectError::InvalidSafetensors {
248 path: input.to_owned(),
249 source: SafeTensorError::MetadataIncompleteBuffer,
250 });
251 }
252
253 build_report(input, file_size, &metadata)
254}
255
256fn build_http_client() -> Result<Client, InspectError> {
257 Client::builder()
258 .redirect(reqwest::redirect::Policy::limited(10))
259 .user_agent(format!("stprobe/{}", env!("CARGO_PKG_VERSION")))
260 .build()
261 .map_err(InspectError::HttpClient)
262}
263
264fn fetch_header_prefix(client: &Client, url: &Url) -> Result<(u64, usize), InspectError> {
265 let response = ranged_get(client, url, 0, HEADER_PREFIX_LEN - 1)?;
266 if response.status() != StatusCode::PARTIAL_CONTENT {
267 return Err(InspectError::RangeUnsupported(url.to_string()));
268 }
269
270 let file_size = parse_total_size(&response, url)?;
271 let bytes = read_response_bytes(response, HEADER_PREFIX_LEN as usize, url)?;
272 let header_len = u64::from_le_bytes(
273 bytes[..HEADER_PREFIX_LEN as usize]
274 .try_into()
275 .expect("slice length is checked by read_response_bytes"),
276 );
277
278 let header_len = header_len
279 .try_into()
280 .map_err(|_| InspectError::InvalidSafetensors {
281 path: url.to_string(),
282 source: SafeTensorError::HeaderTooLarge,
283 })?;
284
285 Ok((file_size, header_len))
286}
287
288fn fetch_header_bytes(
289 client: &Client,
290 url: &Url,
291 header_len: usize,
292) -> Result<Vec<u8>, InspectError> {
293 let start = HEADER_PREFIX_LEN;
294 let end = start
295 .checked_add(header_len as u64)
296 .and_then(|value| value.checked_sub(1))
297 .ok_or_else(|| InspectError::InvalidRemoteResponse {
298 url: url.to_string(),
299 reason: "invalid header range".to_owned(),
300 })?;
301
302 let response = ranged_get(client, url, start, end)?;
303 if response.status() != StatusCode::PARTIAL_CONTENT {
304 return Err(InspectError::RangeUnsupported(url.to_string()));
305 }
306
307 read_response_bytes(response, header_len, url)
308}
309
310fn ranged_get(client: &Client, url: &Url, start: u64, end: u64) -> Result<Response, InspectError> {
311 let mut request = client
312 .get(url.clone())
313 .header(RANGE, format!("bytes={start}-{end}"))
314 .header(ACCEPT_ENCODING, "identity");
315
316 if is_hugging_face_url(url) {
317 if let Ok(token) = std::env::var("HF_TOKEN") {
318 if !token.is_empty() {
319 request = request.header(AUTHORIZATION, format!("Bearer {token}"));
320 }
321 }
322 }
323
324 request.send().map_err(|source| InspectError::HttpRequest {
325 url: url.to_string(),
326 source,
327 })
328}
329
330fn is_hugging_face_url(url: &Url) -> bool {
331 matches!(
332 url.host_str(),
333 Some("huggingface.co") | Some("www.huggingface.co")
334 )
335}
336
337fn parse_total_size(response: &Response, url: &Url) -> Result<u64, InspectError> {
338 let content_range = response
339 .headers()
340 .get(CONTENT_RANGE)
341 .ok_or_else(|| InspectError::InvalidRemoteResponse {
342 url: url.to_string(),
343 reason: "missing Content-Range header".to_owned(),
344 })?
345 .to_str()
346 .map_err(|_| InspectError::InvalidRemoteResponse {
347 url: url.to_string(),
348 reason: "invalid Content-Range header".to_owned(),
349 })?;
350
351 parse_total_size_from_content_range(content_range).map_err(|reason| {
352 InspectError::InvalidRemoteResponse {
353 url: url.to_string(),
354 reason,
355 }
356 })
357}
358
/// Parses the total size out of a `Content-Range` value such as
/// `bytes 0-7/17246524772`.
///
/// The total is the text after the last `/` and must consist solely of ASCII
/// digits.
///
/// # Errors
///
/// Returns `"malformed Content-Range header"` when the value contains no `/`,
/// and `"invalid total size in Content-Range header"` when the part after it
/// is empty, is not purely numeric (for example `*`), or does not fit in `u64`.
fn parse_total_size_from_content_range(content_range: &str) -> Result<u64, String> {
    let (_, total) = content_range
        .rsplit_once('/')
        .ok_or_else(|| "malformed Content-Range header".to_owned())?;

    // `str::parse::<u64>` would also accept a leading `+`, which the header
    // grammar does not allow, so the digits are checked explicitly first.
    let digits_only = !total.is_empty() && total.bytes().all(|byte| byte.is_ascii_digit());
    if !digits_only {
        return Err("invalid total size in Content-Range header".to_owned());
    }

    total
        .parse::<u64>()
        .map_err(|_| "invalid total size in Content-Range header".to_owned())
}
369
370fn read_response_bytes(
371 mut response: Response,
372 expected_len: usize,
373 url: &Url,
374) -> Result<Vec<u8>, InspectError> {
375 let mut bytes = Vec::with_capacity(expected_len);
376 response
377 .read_to_end(&mut bytes)
378 .map_err(|source| InspectError::InvalidRemoteResponse {
379 url: url.to_string(),
380 reason: format!("failed reading response body ({source})"),
381 })?;
382
383 if bytes.len() != expected_len {
384 return Err(InspectError::InvalidRemoteResponse {
385 url: url.to_string(),
386 reason: format!("expected {expected_len} bytes, got {}", bytes.len()),
387 });
388 }
389
390 Ok(bytes)
391}
392
393fn build_report(input: &str, file_size: u64, metadata: &Metadata) -> Result<Report, InspectError> {
394 let mut total_parameters = 0_u128;
395 let mut total_tensor_bytes = 0_u128;
396 let mut tensors = Vec::new();
397 let mut dtype_breakdown = BTreeMap::<String, u128>::new();
398
399 for name in metadata.offset_keys() {
400 let info = metadata
401 .info(&name)
402 .ok_or_else(|| InspectError::MissingTensorInfo {
403 tensor: name.clone(),
404 })?;
405
406 let numel = numel(&info.shape, &name)?;
407 let bytes = (info.data_offsets.1 - info.data_offsets.0) as u128;
408 let dtype = info.dtype.to_string();
409
410 total_parameters =
411 total_parameters
412 .checked_add(numel)
413 .ok_or_else(|| InspectError::Overflow {
414 tensor: name.clone(),
415 })?;
416 total_tensor_bytes =
417 total_tensor_bytes
418 .checked_add(bytes)
419 .ok_or_else(|| InspectError::Overflow {
420 tensor: name.clone(),
421 })?;
422 *dtype_breakdown.entry(dtype.clone()).or_insert(0) += bytes;
423
424 tensors.push(TensorSummary {
425 name,
426 dtype,
427 shape: info.shape.clone(),
428 numel,
429 bytes,
430 });
431 }
432
433 tensors.sort_by(|left, right| compare_tensor_names(&left.name, &right.name));
434
435 let mut metadata_entries = metadata
436 .metadata()
437 .as_ref()
438 .map(|entries| {
439 entries
440 .iter()
441 .map(|(key, value)| (key.clone(), value.clone()))
442 .collect::<Vec<_>>()
443 })
444 .unwrap_or_default();
445 metadata_entries.sort_by(|left, right| left.0.cmp(&right.0));
446
447 Ok(Report {
448 file_path: input.to_owned(),
449 file_size,
450 tensor_count: tensors.len(),
451 total_parameters,
452 total_tensor_bytes,
453 metadata: metadata_entries,
454 dtype_breakdown: dtype_breakdown.into_iter().collect(),
455 tensors,
456 })
457}
458
459fn numel(shape: &[usize], tensor_name: &str) -> Result<u128, InspectError> {
460 shape.iter().try_fold(1_u128, |acc, &dim| {
461 acc.checked_mul(dim as u128)
462 .ok_or_else(|| InspectError::Overflow {
463 tensor: tensor_name.to_owned(),
464 })
465 })
466}
467
468fn compare_tensor_names(left: &str, right: &str) -> Ordering {
469 let mut left_parts = left.split('.');
470 let mut right_parts = right.split('.');
471
472 loop {
473 match (left_parts.next(), right_parts.next()) {
474 (Some(left_part), Some(right_part)) => {
475 let ordering = compare_natural_str(left_part, right_part);
476 if ordering != Ordering::Equal {
477 return ordering;
478 }
479 }
480 (Some(_), None) => return Ordering::Greater,
481 (None, Some(_)) => return Ordering::Less,
482 (None, None) => return Ordering::Equal,
483 }
484 }
485}
486
487fn compare_natural_str(left: &str, right: &str) -> Ordering {
488 let mut left_chunks = ChunkIter::new(left);
489 let mut right_chunks = ChunkIter::new(right);
490
491 loop {
492 match (left_chunks.next(), right_chunks.next()) {
493 (Some(left_chunk), Some(right_chunk)) => {
494 let ordering = match (left_chunk, right_chunk) {
495 (Chunk::Digits(left_digits), Chunk::Digits(right_digits)) => {
496 compare_digit_chunks(left_digits, right_digits)
497 }
498 (Chunk::Text(left_text), Chunk::Text(right_text)) => left_text.cmp(right_text),
499 (Chunk::Digits(_), Chunk::Text(_)) => Ordering::Less,
500 (Chunk::Text(_), Chunk::Digits(_)) => Ordering::Greater,
501 };
502
503 if ordering != Ordering::Equal {
504 return ordering;
505 }
506 }
507 (Some(_), None) => return Ordering::Greater,
508 (None, Some(_)) => return Ordering::Less,
509 (None, None) => return Ordering::Equal,
510 }
511 }
512}
513
/// Compares two runs of ASCII digits by numeric value without parsing them.
///
/// Leading zeros are ignored for the value comparison: the run with more
/// significant digits is larger, and runs of equal length compare
/// lexicographically. Runs with the same value are ordered by their original
/// length, so the one with fewer leading zeros sorts first.
fn compare_digit_chunks(left: &str, right: &str) -> Ordering {
    /// Strips leading zeros, keeping a single `0` for an all-zero run.
    fn significant(digits: &str) -> &str {
        match digits.trim_start_matches('0') {
            "" => "0",
            trimmed => trimmed,
        }
    }

    let lhs = significant(left);
    let rhs = significant(right);

    // Tuples compare field by field, giving the length / digits / padding priority.
    (lhs.len(), lhs, left.len()).cmp(&(rhs.len(), rhs, right.len()))
}
534
/// A maximal run of characters that are either all ASCII digits or all non-digits.
#[derive(Clone, Copy)]
enum Chunk<'a> {
    /// A run made up only of ASCII digits.
    Digits(&'a str),
    /// A run that contains no ASCII digit.
    Text(&'a str),
}

/// Iterator that splits a string into alternating digit and text runs.
struct ChunkIter<'a> {
    /// The string being split.
    input: &'a str,
    /// Byte offset of the first character that has not been yielded yet.
    index: usize,
}

impl<'a> ChunkIter<'a> {
    /// Starts chunking `input` from its beginning.
    fn new(input: &'a str) -> Self {
        ChunkIter { input, index: 0 }
    }
}

impl<'a> Iterator for ChunkIter<'a> {
    type Item = Chunk<'a>;

    /// Yields the next run, or `None` once the input is exhausted.
    fn next(&mut self) -> Option<Chunk<'a>> {
        let rest = self.input.get(self.index..)?;
        let leading = rest.chars().next()?;
        let run_is_digits = leading.is_ascii_digit();

        // The run ends at the first character of the other kind, or at the end of input.
        let run_len = rest
            .find(|ch: char| ch.is_ascii_digit() != run_is_digits)
            .unwrap_or(rest.len());
        let (run, _) = rest.split_at(run_len);
        self.index += run_len;

        if run_is_digits {
            Some(Chunk::Digits(run))
        } else {
            Some(Chunk::Text(run))
        }
    }
}
583
/// Formats a tensor shape as a bracketed, comma-separated list such as `[2, 3, 4]`.
///
/// An empty shape is rendered as `[]`.
fn format_shape(shape: &[usize]) -> String {
    let mut rendered = String::from("[");

    for (position, dim) in shape.iter().enumerate() {
        if position != 0 {
            rendered.push_str(", ");
        }
        rendered.push_str(&dim.to_string());
    }

    rendered.push(']');
    rendered
}
592
#[cfg(test)]
mod tests {
    use super::*;

    // Input classification.

    #[test]
    fn classifies_local_paths() {
        let path = match classify_input("model.safetensors").unwrap() {
            Input::LocalPath(path) => path,
            Input::HttpUrl(_) => panic!("expected local path"),
        };
        assert_eq!(path, Path::new("model.safetensors"));
    }

    #[test]
    fn classifies_https_urls() {
        let url = match classify_input("https://example.com/model.safetensors").unwrap() {
            Input::HttpUrl(url) => url,
            Input::LocalPath(_) => panic!("expected URL"),
        };
        assert_eq!(url.scheme(), "https");
        assert_eq!(url.host_str(), Some("example.com"));
    }

    #[test]
    fn rejects_unsupported_schemes() {
        let error = classify_input("hf://org/repo/file.safetensors").unwrap_err();
        match error {
            InspectError::UnsupportedUrlScheme(scheme) => assert_eq!(scheme, "hf"),
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn classifies_windows_drive_paths_as_local() {
        let path = match classify_input(r"C:\models\sample.safetensors").unwrap() {
            Input::LocalPath(path) => path,
            Input::HttpUrl(_) => panic!("expected local path"),
        };
        assert_eq!(path, Path::new(r"C:\models\sample.safetensors"));
    }

    // Content-Range parsing.

    #[test]
    fn parses_total_size_from_content_range() {
        let parsed = parse_total_size_from_content_range("bytes 0-7/17246524772");
        assert_eq!(parsed.unwrap(), 17_246_524_772);
    }

    #[test]
    fn rejects_malformed_content_range() {
        let parsed = parse_total_size_from_content_range("bytes 0-7/*");
        assert_eq!(
            parsed.unwrap_err(),
            "invalid total size in Content-Range header"
        );
    }

    // Report helpers.

    #[test]
    fn computes_numel() {
        let count = numel(&[2, 3, 4], "tensor").unwrap();
        assert_eq!(count, 24);
    }

    #[test]
    fn formats_shapes() {
        assert_eq!(format_shape(&[2, 3, 4]), "[2, 3, 4]");
        assert_eq!(format_shape(&[]), "[]");
    }

    #[test]
    fn sorts_tensor_names_naturally() {
        let expected = vec![
            "encoder.layer.1.output.weight",
            "encoder.layer.2.output.bias",
            "encoder.layer.2.output.weight",
            "encoder.layer.10.output.weight",
        ];

        let mut names = vec![
            "encoder.layer.10.output.weight",
            "encoder.layer.2.output.weight",
            "encoder.layer.1.output.weight",
            "encoder.layer.2.output.bias",
        ];
        names.sort_by(|left, right| compare_tensor_names(left, right));

        assert_eq!(names, expected);
    }
}