Skip to main content

stprobe/
lib.rs

1use std::cmp::Ordering;
2use std::collections::BTreeMap;
3use std::fmt::Write as _;
4use std::fs::File;
5use std::io;
6use std::io::Read;
7use std::path::{Path, PathBuf};
8
9use memmap2::MmapOptions;
10use reqwest::blocking::{Client, Response};
11use reqwest::header::{ACCEPT_ENCODING, AUTHORIZATION, CONTENT_RANGE, RANGE};
12use reqwest::{StatusCode, Url};
13use safetensors::{tensor::Metadata, SafeTensorError, SafeTensors};
14
/// Width in bytes of the little-endian `u64` header-length prefix that starts a
/// safetensors file.
const HEADER_PREFIX_LEN: u64 = std::mem::size_of::<u64>() as u64;
/// Largest header length (in bytes) accepted from a remote file before the
/// header itself is downloaded.
const MAX_HEADER_SIZE: usize = 100 * 1_000_000;
17
18#[derive(Debug)]
19pub enum InspectError {
20    FileNotFound(PathBuf),
21    CannotRead {
22        path: PathBuf,
23        source: io::Error,
24    },
25    InvalidSafetensors {
26        path: String,
27        source: SafeTensorError,
28    },
29    Overflow {
30        tensor: String,
31    },
32    MissingTensorInfo {
33        tensor: String,
34    },
35    UnsupportedUrlScheme(String),
36    HttpClient(reqwest::Error),
37    HttpRequest {
38        url: String,
39        source: reqwest::Error,
40    },
41    RangeUnsupported(String),
42    InvalidRemoteResponse {
43        url: String,
44        reason: String,
45    },
46}
47
48impl std::fmt::Display for InspectError {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        match self {
51            Self::FileNotFound(path) => write!(f, "file not found: {}", path.display()),
52            Self::CannotRead { path, source } => {
53                write!(f, "failed to read file: {} ({source})", path.display())
54            }
55            Self::InvalidSafetensors { path, source } => {
56                write!(f, "invalid safetensors file: {path} ({source})")
57            }
58            Self::Overflow { tensor } => {
59                write!(f, "tensor is too large to summarize safely: {tensor}")
60            }
61            Self::MissingTensorInfo { tensor } => {
62                write!(f, "missing tensor metadata for: {tensor}")
63            }
64            Self::UnsupportedUrlScheme(scheme) => {
65                write!(f, "unsupported URL scheme: {scheme}")
66            }
67            Self::HttpClient(source) => write!(f, "failed to initialize HTTP client ({source})"),
68            Self::HttpRequest { url, source } => {
69                write!(f, "failed to fetch remote file: {url} ({source})")
70            }
71            Self::RangeUnsupported(url) => {
72                write!(
73                    f,
74                    "remote server does not support byte range requests: {url}"
75                )
76            }
77            Self::InvalidRemoteResponse { url, reason } => {
78                write!(f, "invalid remote response for {url} ({reason})")
79            }
80        }
81    }
82}
83
84impl std::error::Error for InspectError {}
85
/// Summary of one safetensors file's header. Produced by [`inspect_input`] and
/// rendered by [`render_report`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Report {
    file_path: String,
    file_size: u64,
    tensor_count: usize,
    total_parameters: u128,
    total_tensor_bytes: u128,
    metadata: Vec<(String, String)>,
    dtype_breakdown: Vec<(String, u128)>,
    tensors: Vec<TensorSummary>,
}

impl Report {
    /// The input exactly as the caller passed it (a local path or a URL).
    pub fn file_path(&self) -> &str {
        &self.file_path
    }

    /// Size of the inspected file in bytes. This is the local file length, or
    /// the total reported by the server's `Content-Range` header.
    pub fn file_size(&self) -> u64 {
        self.file_size
    }

    /// Number of tensors listed in the header.
    pub fn tensor_count(&self) -> usize {
        self.tensor_count
    }

    /// Sum of the element counts of all tensors.
    pub fn total_parameters(&self) -> u128 {
        self.total_parameters
    }

    /// Sum of the byte spans of all tensors, taken from their data offsets.
    pub fn total_tensor_bytes(&self) -> u128 {
        self.total_tensor_bytes
    }

    /// The header's free-form `__metadata__` entries as `(key, value)` pairs.
    pub fn metadata(&self) -> &[(String, String)] {
        &self.metadata
    }

    /// Total tensor bytes per dtype, as `(dtype, bytes)` pairs.
    pub fn dtype_breakdown(&self) -> &[(String, u128)] {
        &self.dtype_breakdown
    }

    /// Per-tensor summaries.
    pub fn tensors(&self) -> &[TensorSummary] {
        &self.tensors
    }
}

/// Summary of a single tensor entry in a safetensors header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorSummary {
    name: String,
    dtype: String,
    shape: Vec<usize>,
    numel: u128,
    bytes: u128,
}

impl TensorSummary {
    /// Tensor name as it appears in the header.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// The dtype's display name.
    pub fn dtype(&self) -> &str {
        &self.dtype
    }

    /// Dimensions of the tensor. An empty slice denotes a scalar.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Number of elements, i.e. the product of the shape.
    pub fn numel(&self) -> u128 {
        self.numel
    }

    /// Byte span of the tensor's data, computed as end offset minus start offset.
    pub fn bytes(&self) -> u128 {
        self.bytes
    }
}
106
107pub fn inspect_input(input: &str) -> Result<Report, InspectError> {
108    match classify_input(input)? {
109        Input::LocalPath(path) => inspect_local_file(input, path),
110        Input::HttpUrl(url) => inspect_remote_file(input, &url),
111    }
112}
113
114pub fn render_report(report: &Report) -> String {
115    let mut output = String::new();
116
117    writeln!(&mut output, "File: {}", report.file_path).unwrap();
118    writeln!(&mut output, "Size: {} bytes", report.file_size).unwrap();
119    writeln!(&mut output, "Tensors: {}", report.tensor_count).unwrap();
120    writeln!(&mut output, "Parameters: {}", report.total_parameters).unwrap();
121    writeln!(&mut output, "Tensor-Bytes: {}", report.total_tensor_bytes).unwrap();
122    writeln!(&mut output).unwrap();
123
124    writeln!(&mut output, "Metadata:").unwrap();
125    if report.metadata.is_empty() {
126        writeln!(&mut output, "  (none)").unwrap();
127    } else {
128        for (key, value) in &report.metadata {
129            writeln!(&mut output, "  {key} = {value}").unwrap();
130        }
131    }
132    writeln!(&mut output).unwrap();
133
134    writeln!(&mut output, "DType Breakdown:").unwrap();
135    if report.dtype_breakdown.is_empty() {
136        writeln!(&mut output, "  (none)").unwrap();
137    } else {
138        for (dtype, bytes) in &report.dtype_breakdown {
139            writeln!(&mut output, "  {dtype}: {bytes} bytes").unwrap();
140        }
141    }
142    writeln!(&mut output).unwrap();
143
144    writeln!(&mut output, "Tensors:").unwrap();
145    if report.tensors.is_empty() {
146        writeln!(&mut output, "  (none)").unwrap();
147        return output;
148    }
149
150    for (index, tensor) in report.tensors.iter().enumerate() {
151        if index > 0 {
152            writeln!(&mut output).unwrap();
153        }
154
155        writeln!(&mut output, "  {}", tensor.name).unwrap();
156        writeln!(&mut output, "    dtype: {}", tensor.dtype).unwrap();
157        writeln!(&mut output, "    shape: {}", format_shape(&tensor.shape)).unwrap();
158        writeln!(&mut output, "    numel: {}", tensor.numel).unwrap();
159        writeln!(&mut output, "    bytes: {}", tensor.bytes).unwrap();
160    }
161
162    output
163}
164
165#[derive(Debug)]
166enum Input<'a> {
167    LocalPath(&'a Path),
168    HttpUrl(Url),
169}
170
171fn classify_input(input: &str) -> Result<Input<'_>, InspectError> {
172    if !input.contains("://") {
173        return Ok(Input::LocalPath(Path::new(input)));
174    }
175
176    match Url::parse(input) {
177        Ok(url) => match url.scheme() {
178            "http" | "https" => Ok(Input::HttpUrl(url)),
179            scheme => Err(InspectError::UnsupportedUrlScheme(scheme.to_owned())),
180        },
181        Err(_) => Err(InspectError::InvalidRemoteResponse {
182            url: input.to_owned(),
183            reason: "malformed URL".to_owned(),
184        }),
185    }
186}
187
188fn inspect_local_file(input: &str, path: &Path) -> Result<Report, InspectError> {
189    let file = File::open(path).map_err(|source| match source.kind() {
190        io::ErrorKind::NotFound => InspectError::FileNotFound(path.to_path_buf()),
191        _ => InspectError::CannotRead {
192            path: path.to_path_buf(),
193            source,
194        },
195    })?;
196
197    let file_size = file
198        .metadata()
199        .map_err(|source| InspectError::CannotRead {
200            path: path.to_path_buf(),
201            source,
202        })?
203        .len();
204
205    let mmap =
206        unsafe { MmapOptions::new().map(&file) }.map_err(|source| InspectError::CannotRead {
207            path: path.to_path_buf(),
208            source,
209        })?;
210
211    let (_, metadata) =
212        SafeTensors::read_metadata(&mmap).map_err(|source| InspectError::InvalidSafetensors {
213            path: input.to_owned(),
214            source,
215        })?;
216
217    build_report(input, file_size, &metadata)
218}
219
220fn inspect_remote_file(input: &str, url: &Url) -> Result<Report, InspectError> {
221    let client = build_http_client()?;
222    let (file_size, header_len) = fetch_header_prefix(&client, url)?;
223    if header_len > MAX_HEADER_SIZE {
224        return Err(InspectError::InvalidSafetensors {
225            path: input.to_owned(),
226            source: SafeTensorError::HeaderTooLarge,
227        });
228    }
229
230    let header_bytes = fetch_header_bytes(&client, url, header_len)?;
231    let metadata: Metadata = serde_json::from_slice(&header_bytes).map_err(|source| {
232        InspectError::InvalidSafetensors {
233            path: input.to_owned(),
234            source: SafeTensorError::InvalidHeaderDeserialization(source),
235        }
236    })?;
237
238    let expected_size = HEADER_PREFIX_LEN
239        .checked_add(header_len as u64)
240        .and_then(|value| value.checked_add(metadata.data_len() as u64))
241        .ok_or_else(|| InspectError::InvalidSafetensors {
242            path: input.to_owned(),
243            source: SafeTensorError::ValidationOverflow,
244        })?;
245
246    if expected_size != file_size {
247        return Err(InspectError::InvalidSafetensors {
248            path: input.to_owned(),
249            source: SafeTensorError::MetadataIncompleteBuffer,
250        });
251    }
252
253    build_report(input, file_size, &metadata)
254}
255
256fn build_http_client() -> Result<Client, InspectError> {
257    Client::builder()
258        .redirect(reqwest::redirect::Policy::limited(10))
259        .user_agent(format!("stprobe/{}", env!("CARGO_PKG_VERSION")))
260        .build()
261        .map_err(InspectError::HttpClient)
262}
263
264fn fetch_header_prefix(client: &Client, url: &Url) -> Result<(u64, usize), InspectError> {
265    let response = ranged_get(client, url, 0, HEADER_PREFIX_LEN - 1)?;
266    if response.status() != StatusCode::PARTIAL_CONTENT {
267        return Err(InspectError::RangeUnsupported(url.to_string()));
268    }
269
270    let file_size = parse_total_size(&response, url)?;
271    let bytes = read_response_bytes(response, HEADER_PREFIX_LEN as usize, url)?;
272    let header_len = u64::from_le_bytes(
273        bytes[..HEADER_PREFIX_LEN as usize]
274            .try_into()
275            .expect("slice length is checked by read_response_bytes"),
276    );
277
278    let header_len = header_len
279        .try_into()
280        .map_err(|_| InspectError::InvalidSafetensors {
281            path: url.to_string(),
282            source: SafeTensorError::HeaderTooLarge,
283        })?;
284
285    Ok((file_size, header_len))
286}
287
288fn fetch_header_bytes(
289    client: &Client,
290    url: &Url,
291    header_len: usize,
292) -> Result<Vec<u8>, InspectError> {
293    let start = HEADER_PREFIX_LEN;
294    let end = start
295        .checked_add(header_len as u64)
296        .and_then(|value| value.checked_sub(1))
297        .ok_or_else(|| InspectError::InvalidRemoteResponse {
298            url: url.to_string(),
299            reason: "invalid header range".to_owned(),
300        })?;
301
302    let response = ranged_get(client, url, start, end)?;
303    if response.status() != StatusCode::PARTIAL_CONTENT {
304        return Err(InspectError::RangeUnsupported(url.to_string()));
305    }
306
307    read_response_bytes(response, header_len, url)
308}
309
310fn ranged_get(client: &Client, url: &Url, start: u64, end: u64) -> Result<Response, InspectError> {
311    let mut request = client
312        .get(url.clone())
313        .header(RANGE, format!("bytes={start}-{end}"))
314        .header(ACCEPT_ENCODING, "identity");
315
316    if is_hugging_face_url(url) {
317        if let Ok(token) = std::env::var("HF_TOKEN") {
318            if !token.is_empty() {
319                request = request.header(AUTHORIZATION, format!("Bearer {token}"));
320            }
321        }
322    }
323
324    request.send().map_err(|source| InspectError::HttpRequest {
325        url: url.to_string(),
326        source,
327    })
328}
329
330fn is_hugging_face_url(url: &Url) -> bool {
331    matches!(
332        url.host_str(),
333        Some("huggingface.co") | Some("www.huggingface.co")
334    )
335}
336
337fn parse_total_size(response: &Response, url: &Url) -> Result<u64, InspectError> {
338    let content_range = response
339        .headers()
340        .get(CONTENT_RANGE)
341        .ok_or_else(|| InspectError::InvalidRemoteResponse {
342            url: url.to_string(),
343            reason: "missing Content-Range header".to_owned(),
344        })?
345        .to_str()
346        .map_err(|_| InspectError::InvalidRemoteResponse {
347            url: url.to_string(),
348            reason: "invalid Content-Range header".to_owned(),
349        })?;
350
351    parse_total_size_from_content_range(content_range).map_err(|reason| {
352        InspectError::InvalidRemoteResponse {
353            url: url.to_string(),
354            reason,
355        }
356    })
357}
358
/// Parses the complete-length field of a `Content-Range` value such as
/// `bytes 0-7/17246524772`.
///
/// # Errors
///
/// - `"malformed Content-Range header"` when the value contains no `/`.
/// - `"invalid total size in Content-Range header"` when the text after the last
///   `/` is not a `u64`. This includes the unknown-length form `*`.
fn parse_total_size_from_content_range(content_range: &str) -> Result<u64, String> {
    // `rsplit('/').next()` always yields an item, so it cannot detect a missing
    // separator. `rsplit_once` returns `None` in that case.
    let (_, total) = content_range
        .rsplit_once('/')
        .ok_or_else(|| "malformed Content-Range header".to_owned())?;

    total
        .parse::<u64>()
        .map_err(|_| "invalid total size in Content-Range header".to_owned())
}
369
370fn read_response_bytes(
371    mut response: Response,
372    expected_len: usize,
373    url: &Url,
374) -> Result<Vec<u8>, InspectError> {
375    let mut bytes = Vec::with_capacity(expected_len);
376    response
377        .read_to_end(&mut bytes)
378        .map_err(|source| InspectError::InvalidRemoteResponse {
379            url: url.to_string(),
380            reason: format!("failed reading response body ({source})"),
381        })?;
382
383    if bytes.len() != expected_len {
384        return Err(InspectError::InvalidRemoteResponse {
385            url: url.to_string(),
386            reason: format!("expected {expected_len} bytes, got {}", bytes.len()),
387        });
388    }
389
390    Ok(bytes)
391}
392
393fn build_report(input: &str, file_size: u64, metadata: &Metadata) -> Result<Report, InspectError> {
394    let mut total_parameters = 0_u128;
395    let mut total_tensor_bytes = 0_u128;
396    let mut tensors = Vec::new();
397    let mut dtype_breakdown = BTreeMap::<String, u128>::new();
398
399    for name in metadata.offset_keys() {
400        let info = metadata
401            .info(&name)
402            .ok_or_else(|| InspectError::MissingTensorInfo {
403                tensor: name.clone(),
404            })?;
405
406        let numel = numel(&info.shape, &name)?;
407        let bytes = (info.data_offsets.1 - info.data_offsets.0) as u128;
408        let dtype = info.dtype.to_string();
409
410        total_parameters =
411            total_parameters
412                .checked_add(numel)
413                .ok_or_else(|| InspectError::Overflow {
414                    tensor: name.clone(),
415                })?;
416        total_tensor_bytes =
417            total_tensor_bytes
418                .checked_add(bytes)
419                .ok_or_else(|| InspectError::Overflow {
420                    tensor: name.clone(),
421                })?;
422        *dtype_breakdown.entry(dtype.clone()).or_insert(0) += bytes;
423
424        tensors.push(TensorSummary {
425            name,
426            dtype,
427            shape: info.shape.clone(),
428            numel,
429            bytes,
430        });
431    }
432
433    tensors.sort_by(|left, right| compare_tensor_names(&left.name, &right.name));
434
435    let mut metadata_entries = metadata
436        .metadata()
437        .as_ref()
438        .map(|entries| {
439            entries
440                .iter()
441                .map(|(key, value)| (key.clone(), value.clone()))
442                .collect::<Vec<_>>()
443        })
444        .unwrap_or_default();
445    metadata_entries.sort_by(|left, right| left.0.cmp(&right.0));
446
447    Ok(Report {
448        file_path: input.to_owned(),
449        file_size,
450        tensor_count: tensors.len(),
451        total_parameters,
452        total_tensor_bytes,
453        metadata: metadata_entries,
454        dtype_breakdown: dtype_breakdown.into_iter().collect(),
455        tensors,
456    })
457}
458
459fn numel(shape: &[usize], tensor_name: &str) -> Result<u128, InspectError> {
460    shape.iter().try_fold(1_u128, |acc, &dim| {
461        acc.checked_mul(dim as u128)
462            .ok_or_else(|| InspectError::Overflow {
463                tensor: tensor_name.to_owned(),
464            })
465    })
466}
467
468fn compare_tensor_names(left: &str, right: &str) -> Ordering {
469    let mut left_parts = left.split('.');
470    let mut right_parts = right.split('.');
471
472    loop {
473        match (left_parts.next(), right_parts.next()) {
474            (Some(left_part), Some(right_part)) => {
475                let ordering = compare_natural_str(left_part, right_part);
476                if ordering != Ordering::Equal {
477                    return ordering;
478                }
479            }
480            (Some(_), None) => return Ordering::Greater,
481            (None, Some(_)) => return Ordering::Less,
482            (None, None) => return Ordering::Equal,
483        }
484    }
485}
486
487fn compare_natural_str(left: &str, right: &str) -> Ordering {
488    let mut left_chunks = ChunkIter::new(left);
489    let mut right_chunks = ChunkIter::new(right);
490
491    loop {
492        match (left_chunks.next(), right_chunks.next()) {
493            (Some(left_chunk), Some(right_chunk)) => {
494                let ordering = match (left_chunk, right_chunk) {
495                    (Chunk::Digits(left_digits), Chunk::Digits(right_digits)) => {
496                        compare_digit_chunks(left_digits, right_digits)
497                    }
498                    (Chunk::Text(left_text), Chunk::Text(right_text)) => left_text.cmp(right_text),
499                    (Chunk::Digits(_), Chunk::Text(_)) => Ordering::Less,
500                    (Chunk::Text(_), Chunk::Digits(_)) => Ordering::Greater,
501                };
502
503                if ordering != Ordering::Equal {
504                    return ordering;
505                }
506            }
507            (Some(_), None) => return Ordering::Greater,
508            (None, Some(_)) => return Ordering::Less,
509            (None, None) => return Ordering::Equal,
510        }
511    }
512}
513
/// Compares two runs of ASCII digits by numeric value without parsing them, so
/// arbitrarily long runs cannot overflow.
///
/// Leading zeros are ignored, and an all-zero run counts as `0`. Numerically
/// equal runs are then ordered by their original length, so `"7"` sorts before
/// `"007"`.
fn compare_digit_chunks(left: &str, right: &str) -> Ordering {
    /// Strips leading zeros, keeping a single `"0"` for an all-zero run.
    fn significant(digits: &str) -> &str {
        match digits.trim_start_matches('0') {
            "" => "0",
            rest => rest,
        }
    }

    let (a, b) = (significant(left), significant(right));
    // With no leading zeros, a longer run is a larger number. Equal lengths then
    // compare digit by digit.
    (a.len(), a)
        .cmp(&(b.len(), b))
        .then(left.len().cmp(&right.len()))
}
534
/// One maximal run of a string: either entirely ASCII digits or containing no
/// ASCII digits.
#[derive(Clone, Copy)]
enum Chunk<'a> {
    Digits(&'a str),
    Text(&'a str),
}

/// Iterator that splits a string into alternating [`Chunk`]s, borrowing from it.
struct ChunkIter<'a> {
    text: &'a str,
    // Byte offset of the next unread run; always on a `char` boundary.
    cursor: usize,
}

impl<'a> ChunkIter<'a> {
    fn new(input: &'a str) -> Self {
        Self {
            text: input,
            cursor: 0,
        }
    }
}

impl<'a> Iterator for ChunkIter<'a> {
    type Item = Chunk<'a>;

    fn next(&mut self) -> Option<Self::Item> {
        let rest = self.text.get(self.cursor..).filter(|rest| !rest.is_empty())?;
        let digits = rest.starts_with(|c: char| c.is_ascii_digit());
        // The run ends at the first character whose digit-ness differs from the first.
        let run_len = rest
            .find(|c: char| c.is_ascii_digit() != digits)
            .unwrap_or(rest.len());
        let (run, _) = rest.split_at(run_len);
        self.cursor += run_len;

        Some(if digits {
            Chunk::Digits(run)
        } else {
            Chunk::Text(run)
        })
    }
}
583
/// Formats a tensor shape as a bracketed, comma-separated list such as
/// `[2, 3, 4]`. A scalar (empty shape) renders as `[]`.
fn format_shape(shape: &[usize]) -> String {
    let mut rendered = String::from("[");
    for (position, dim) in shape.iter().enumerate() {
        if position > 0 {
            rendered.push_str(", ");
        }
        rendered.push_str(&dim.to_string());
    }
    rendered.push(']');
    rendered
}
592
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn classifies_local_paths() {
        match classify_input("model.safetensors").unwrap() {
            Input::LocalPath(path) => assert_eq!(path, Path::new("model.safetensors")),
            Input::HttpUrl(_) => panic!("expected local path"),
        }
    }

    #[test]
    fn classifies_https_urls() {
        match classify_input("https://example.com/model.safetensors").unwrap() {
            Input::HttpUrl(url) => {
                assert_eq!(url.scheme(), "https");
                assert_eq!(url.host_str(), Some("example.com"));
            }
            Input::LocalPath(_) => panic!("expected URL"),
        }
    }

    #[test]
    fn rejects_unsupported_schemes() {
        match classify_input("hf://org/repo/file.safetensors").unwrap_err() {
            InspectError::UnsupportedUrlScheme(scheme) => assert_eq!(scheme, "hf"),
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn rejects_malformed_urls() {
        // `http://` with no host is not a parseable URL.
        match classify_input("http://").unwrap_err() {
            InspectError::InvalidRemoteResponse { url, reason } => {
                assert_eq!(url, "http://");
                assert_eq!(reason, "malformed URL");
            }
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn classifies_windows_drive_paths_as_local() {
        match classify_input(r"C:\models\sample.safetensors").unwrap() {
            Input::LocalPath(path) => {
                assert_eq!(path, Path::new(r"C:\models\sample.safetensors"));
            }
            Input::HttpUrl(_) => panic!("expected local path"),
        }
    }

    #[test]
    fn parses_total_size_from_content_range() {
        let total = parse_total_size_from_content_range("bytes 0-7/17246524772").unwrap();
        assert_eq!(total, 17_246_524_772);
    }

    #[test]
    fn rejects_malformed_content_range() {
        let error = parse_total_size_from_content_range("bytes 0-7/*").unwrap_err();
        assert_eq!(error, "invalid total size in Content-Range header");
    }

    #[test]
    fn computes_numel() {
        assert_eq!(numel(&[2, 3, 4], "tensor").unwrap(), 24);
    }

    #[test]
    fn numel_handles_scalars_and_zero_dims() {
        assert_eq!(numel(&[], "scalar").unwrap(), 1);
        assert_eq!(numel(&[4, 0, 7], "empty").unwrap(), 0);
    }

    #[test]
    fn numel_reports_overflow() {
        // Five maximal dimensions overflow u128 on both 32- and 64-bit targets.
        match numel(&[usize::MAX; 5], "huge").unwrap_err() {
            InspectError::Overflow { tensor } => assert_eq!(tensor, "huge"),
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn formats_shapes() {
        assert_eq!(format_shape(&[2, 3, 4]), "[2, 3, 4]");
        assert_eq!(format_shape(&[]), "[]");
    }

    #[test]
    fn compares_digit_runs_by_value_then_padding() {
        assert_eq!(compare_digit_chunks("9", "10"), Ordering::Less);
        assert_eq!(compare_digit_chunks("007", "7"), Ordering::Greater);
        assert_eq!(compare_digit_chunks("000", "0"), Ordering::Greater);
        assert_eq!(compare_digit_chunks("42", "42"), Ordering::Equal);
    }

    #[test]
    fn sorts_tensor_names_naturally() {
        let mut names = vec![
            "encoder.layer.10.output.weight",
            "encoder.layer.2.output.weight",
            "encoder.layer.1.output.weight",
            "encoder.layer.2.output.bias",
        ];

        names.sort_by(|left, right| compare_tensor_names(left, right));

        assert_eq!(
            names,
            vec![
                "encoder.layer.1.output.weight",
                "encoder.layer.2.output.bias",
                "encoder.layer.2.output.weight",
                "encoder.layer.10.output.weight",
            ]
        );
    }

    #[test]
    fn sorts_numeric_segments_before_text_and_prefixes_first() {
        let mut names = vec![
            "model.norm.weight",
            "model.layers.0.weight",
            "model.0.weight",
            "model.0",
        ];

        names.sort_by(|left, right| compare_tensor_names(left, right));

        assert_eq!(
            names,
            vec![
                "model.0",
                "model.0.weight",
                "model.layers.0.weight",
                "model.norm.weight",
            ]
        );
    }

    #[test]
    fn renders_empty_report() {
        let report = Report {
            file_path: "empty.safetensors".to_owned(),
            file_size: 8,
            tensor_count: 0,
            total_parameters: 0,
            total_tensor_bytes: 0,
            metadata: Vec::new(),
            dtype_breakdown: Vec::new(),
            tensors: Vec::new(),
        };

        let expected = concat!(
            "File: empty.safetensors\n",
            "Size: 8 bytes\n",
            "Tensors: 0\n",
            "Parameters: 0\n",
            "Tensor-Bytes: 0\n",
            "\n",
            "Metadata:\n",
            "  (none)\n",
            "\n",
            "DType Breakdown:\n",
            "  (none)\n",
            "\n",
            "Tensors:\n",
            "  (none)\n",
        );
        assert_eq!(render_report(&report), expected);
    }

    #[test]
    fn renders_tensors_separated_by_blank_lines() {
        let tensor = |name: &str, shape: Vec<usize>, numel: u128| TensorSummary {
            name: name.to_owned(),
            dtype: "F32".to_owned(),
            shape,
            numel,
            bytes: numel * 4,
        };
        let report = Report {
            file_path: "model.safetensors".to_owned(),
            file_size: 128,
            tensor_count: 2,
            total_parameters: 7,
            total_tensor_bytes: 28,
            metadata: vec![("format".to_owned(), "pt".to_owned())],
            dtype_breakdown: vec![("F32".to_owned(), 28)],
            tensors: vec![tensor("bias", vec![1], 1), tensor("weight", vec![2, 3], 6)],
        };

        let expected = concat!(
            "File: model.safetensors\n",
            "Size: 128 bytes\n",
            "Tensors: 2\n",
            "Parameters: 7\n",
            "Tensor-Bytes: 28\n",
            "\n",
            "Metadata:\n",
            "  format = pt\n",
            "\n",
            "DType Breakdown:\n",
            "  F32: 28 bytes\n",
            "\n",
            "Tensors:\n",
            "  bias\n",
            "    dtype: F32\n",
            "    shape: [1]\n",
            "    numel: 1\n",
            "    bytes: 4\n",
            "\n",
            "  weight\n",
            "    dtype: F32\n",
            "    shape: [2, 3]\n",
            "    numel: 6\n",
            "    bytes: 24\n",
        );
        assert_eq!(render_report(&report), expected);
    }
}