1use std::path::{Path, PathBuf};
9
10use image::DynamicImage;
11use ndarray::Array3;
12
13use crate::error::{InferenceError, Result};
14
/// An input source for inference: single images, batches of images, or
/// frame-producing video sources.
#[derive(Debug, Clone)]
pub enum Source {
    /// A single image file on disk.
    Image(PathBuf),
    /// An already-decoded image held in memory.
    ImageBuffer(DynamicImage),
    /// Raw pixel data as a 3-D `u8` array; converted at iteration time by
    /// `crate::utils::array_to_image` (layout is whatever that helper expects).
    Array(Array3<u8>),
    /// An HTTP(S) URL pointing directly at an image file.
    ImageUrl(String),
    /// An explicit list of image files.
    ImageList(Vec<PathBuf>),
    /// A video file on disk.
    Video(PathBuf),
    /// A webcam device index.
    Webcam(u32),
    /// A network stream URL (RTSP/RTMP, or HTTP that is not an image).
    Stream(String),
    /// A directory whose image files are iterated in sorted order.
    Directory(PathBuf),
    /// A glob-like pattern (e.g. `images/*.jpg`).
    Glob(String),
}
39
40impl Source {
41 #[must_use]
47 pub const fn is_image(&self) -> bool {
48 matches!(
49 self,
50 Self::Image(_) | Self::ImageBuffer(_) | Self::Array(_) | Self::ImageUrl(_)
51 )
52 }
53
54 #[must_use]
60 pub const fn is_video(&self) -> bool {
61 matches!(self, Self::Video(_) | Self::Webcam(_) | Self::Stream(_))
62 }
63
64 #[must_use]
70 pub const fn is_batch(&self) -> bool {
71 matches!(
72 self,
73 Self::Directory(_) | Self::Glob(_) | Self::ImageList(_)
74 )
75 }
76
77 #[must_use]
83 pub fn path(&self) -> Option<&Path> {
84 match self {
85 Self::Image(p) | Self::Video(p) | Self::Directory(p) => Some(p),
86 _ => None,
87 }
88 }
89
90 fn is_image_url(url: &str) -> bool {
92 let url_lower = url.to_lowercase();
93 let path_part = url_lower.split('?').next().unwrap_or(&url_lower);
95
96 std::path::Path::new(path_part)
97 .extension()
98 .is_some_and(|ext| {
99 let s = ext.to_string_lossy();
100 s.eq_ignore_ascii_case("jpg")
101 || s.eq_ignore_ascii_case("jpeg")
102 || s.eq_ignore_ascii_case("png")
103 || s.eq_ignore_ascii_case("bmp")
104 || s.eq_ignore_ascii_case("gif")
105 || s.eq_ignore_ascii_case("webp")
106 || s.eq_ignore_ascii_case("tiff")
107 || s.eq_ignore_ascii_case("tif")
108 })
109 }
110}
111
112impl From<&str> for Source {
114 fn from(s: &str) -> Self {
115 if let Ok(idx) = s.parse::<u32>() {
117 return Self::Webcam(idx);
118 }
119
120 if s.starts_with("http://") || s.starts_with("https://") {
122 if Self::is_image_url(s) {
124 return Self::ImageUrl(s.to_string());
125 }
126 return Self::Stream(s.to_string());
128 }
129
130 if s.starts_with("rtsp://") || s.starts_with("rtmp://") {
132 return Self::Stream(s.to_string());
133 }
134
135 if s.contains('*') {
137 return Self::Glob(s.to_string());
138 }
139
140 let path = PathBuf::from(s)
141 .canonicalize()
142 .unwrap_or_else(|_| PathBuf::from(s));
143
144 if path.is_dir() {
146 return Self::Directory(path);
147 }
148
149 if let Some(ext) = path.extension() {
151 let ext = ext.to_string_lossy().to_lowercase();
152 if matches!(
153 ext.as_str(),
154 "mp4" | "avi" | "mov" | "mkv" | "wmv" | "flv" | "webm" | "m4v" | "mpeg" | "mpg"
155 ) {
156 return Self::Video(path);
157 }
158 }
159
160 Self::Image(path)
162 }
163}
164
impl From<String> for Source {
    /// Delegates to the `&str` classification logic.
    fn from(s: String) -> Self {
        Self::from(s.as_str())
    }
}
170
impl From<PathBuf> for Source {
    /// Classifies the path via the `&str` logic (lossy UTF-8 conversion).
    fn from(path: PathBuf) -> Self {
        Self::from(path.to_string_lossy().as_ref())
    }
}
176
impl From<&Path> for Source {
    /// Classifies the path via the `&str` logic (lossy UTF-8 conversion).
    fn from(path: &Path) -> Self {
        Self::from(path.to_string_lossy().as_ref())
    }
}
182
impl From<DynamicImage> for Source {
    /// Wraps an already-decoded image as an in-memory source.
    fn from(img: DynamicImage) -> Self {
        Self::ImageBuffer(img)
    }
}
188
impl From<Array3<u8>> for Source {
    /// Wraps raw pixel data; the layout is interpreted later by
    /// `crate::utils::array_to_image` during iteration.
    fn from(arr: Array3<u8>) -> Self {
        Self::Array(arr)
    }
}
194
impl From<u32> for Source {
    /// Treats the integer as a webcam device index.
    fn from(idx: u32) -> Self {
        Self::Webcam(idx)
    }
}
200
201impl From<i32> for Source {
202 fn from(idx: i32) -> Self {
203 #[allow(clippy::cast_sign_loss)]
204 Self::Webcam(idx as u32)
205 }
206}
207
/// Per-frame metadata accompanying each decoded image.
#[derive(Debug, Clone)]
pub struct SourceMeta {
    /// Zero-based index of this frame within the source.
    pub frame_idx: usize,
    /// Total number of frames when known; `None` for live/unbounded sources.
    pub total_frames: Option<usize>,
    /// Human-readable origin (file path, URL, or device description).
    pub path: String,
    /// Frames per second for video sources, when known.
    pub fps: Option<f32>,
}
220
impl Default for SourceMeta {
    /// Metadata for a single stand-alone frame: frame 0 of exactly 1,
    /// with no source path and no frame rate.
    fn default() -> Self {
        Self {
            frame_idx: 0,
            total_frames: Some(1),
            path: String::new(),
            fps: None,
        }
    }
}
231
232#[cfg(feature = "video")]
233use video_rs::ffmpeg;
234
/// Decodes video files/streams via ffmpeg, converting each frame to RGB24
/// using bilinear scaling.
#[cfg(feature = "video")]
struct BilinearVideoDecoder {
    /// Demuxer for the opened input container.
    input_ctx: ffmpeg::format::context::Input,
    /// Decoder for the selected video stream.
    decoder: ffmpeg::decoder::Video,
    /// Pixel-format converter, created lazily from the first decoded frame.
    scaler: Option<ffmpeg::software::scaling::context::Context>,
    /// Index of the video stream inside the container.
    stream_index: usize,
    /// Estimated frame count; `None` when duration or fps is unknown.
    total_frames: Option<usize>,
    /// Average frame rate reported by the container (may be 0 for streams).
    fps: f32,
}
252
253#[cfg(feature = "video")]
254impl BilinearVideoDecoder {
    /// Opens `path` with ffmpeg and prepares a decoder for its best video
    /// stream.
    ///
    /// # Errors
    ///
    /// Returns [`InferenceError::VideoError`] if ffmpeg cannot be
    /// initialized, the input cannot be opened, no video stream exists, or
    /// the codec context/decoder cannot be created.
    fn new(path: &Path) -> Result<Self> {
        ffmpeg::init().map_err(|e| InferenceError::VideoError(format!("FFmpeg init: {e}")))?;

        let input_ctx = ffmpeg::format::input(path).map_err(|e| {
            InferenceError::VideoError(format!("Cannot open {}: {e}", path.display()))
        })?;

        let stream = input_ctx
            .streams()
            .best(ffmpeg::media::Type::Video)
            .ok_or_else(|| InferenceError::VideoError("No video stream found".into()))?;

        let stream_index = stream.index();

        #[allow(clippy::cast_possible_truncation)]
        let fps = f64::from(stream.avg_frame_rate()) as f32;
        // Estimate total frames as container duration x average fps. This is
        // approximate, and unavailable (None) when the container reports no
        // positive duration or frame rate (e.g. live streams).
        #[allow(clippy::cast_precision_loss)]
        let duration_secs = input_ctx.duration() as f64 / f64::from(ffmpeg::ffi::AV_TIME_BASE);
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let total_frames = if duration_secs > 0.0 && fps > 0.0 {
            Some((duration_secs * f64::from(fps)) as usize)
        } else {
            None
        };

        let context_decoder = ffmpeg::codec::context::Context::from_parameters(stream.parameters())
            .map_err(|e| InferenceError::VideoError(format!("Codec context: {e}")))?;
        let decoder = context_decoder
            .decoder()
            .video()
            .map_err(|e| InferenceError::VideoError(format!("Video decoder: {e}")))?;

        Ok(Self {
            input_ctx,
            decoder,
            // Scaler is created on the first frame, once its format is known.
            scaler: None,
            stream_index,
            total_frames,
            fps,
        })
    }
297
    /// Decodes and returns the next frame, or `None` at end of stream.
    ///
    /// ffmpeg decoding is pull-based: frames are received from the decoder,
    /// and packets are fed in whenever the decoder needs more data.
    fn decode_next(&mut self) -> Option<Result<DynamicImage>> {
        let mut decoded = ffmpeg::util::frame::video::Video::empty();

        loop {
            // Drain any frame the decoder already has buffered.
            if self.decoder.receive_frame(&mut decoded).is_ok() {
                return Some(self.frame_to_image(&decoded));
            }

            // Feed exactly one packet from our video stream. Packets whose
            // `send_packet` fails are skipped rather than aborting decode.
            let mut found_packet = false;
            for (stream, packet) in self.input_ctx.packets() {
                if stream.index() == self.stream_index {
                    if self.decoder.send_packet(&packet).is_err() {
                        continue;
                    }
                    found_packet = true;
                    break;
                }
            }

            if !found_packet {
                // Input exhausted: flush the decoder (repeated send_eof
                // errors are deliberately ignored) and return one buffered
                // frame per call; later calls drain the rest via the
                // receive_frame at the top of the loop.
                let _ = self.decoder.send_eof();
                return if self.decoder.receive_frame(&mut decoded).is_ok() {
                    Some(self.frame_to_image(&decoded))
                } else {
                    None
                };
            }

            // The packet we just sent may immediately yield a frame.
            if self.decoder.receive_frame(&mut decoded).is_ok() {
                return Some(self.frame_to_image(&decoded));
            }
        }
    }
336
    /// Converts a decoded ffmpeg frame into an RGB8 [`DynamicImage`].
    ///
    /// The scaler is created lazily from the first frame (same dimensions,
    /// output format RGB24, bilinear filtering) and then reused.
    /// NOTE(review): if a stream changed resolution or pixel format
    /// mid-stream the cached scaler would be stale — assumed fixed-format
    /// input, confirm.
    fn frame_to_image(
        &mut self,
        decoded: &ffmpeg::util::frame::video::Video,
    ) -> Result<DynamicImage> {
        if self.scaler.is_none() {
            self.scaler = Some(
                ffmpeg::software::scaling::context::Context::get(
                    decoded.format(),
                    decoded.width(),
                    decoded.height(),
                    ffmpeg::format::Pixel::RGB24,
                    decoded.width(),
                    decoded.height(),
                    ffmpeg::software::scaling::flag::Flags::BILINEAR,
                )
                .map_err(|e| InferenceError::VideoError(format!("Scaler init: {e}")))?,
            );
        }

        let mut rgb_frame = ffmpeg::util::frame::video::Video::empty();
        self.scaler
            .as_mut()
            // Cannot be None: initialized just above when absent.
            .unwrap()
            .run(decoded, &mut rgb_frame)
            .map_err(|e| InferenceError::VideoError(format!("Scale: {e}")))?;

        let width = rgb_frame.width();
        let height = rgb_frame.height();
        let data = rgb_frame.data(0);
        let stride = rgb_frame.stride(0);

        // Copy row by row: ffmpeg rows may be padded, so `stride` can exceed
        // `width * 3`; only the visible pixel bytes are kept.
        let mut rgb_data = Vec::with_capacity((width * height * 3) as usize);
        for y in 0..height as usize {
            let row = &data[y * stride..y * stride + (width as usize) * 3];
            rgb_data.extend_from_slice(row);
        }

        let img_buffer = image::RgbImage::from_raw(width, height, rgb_data).ok_or_else(|| {
            InferenceError::ImageError("Failed to create image from video frame".into())
        })?;
        Ok(DynamicImage::ImageRgb8(img_buffer))
    }
382}
383
384pub struct SourceIterator {
386 source: Source,
387 current_frame: usize,
388 image_paths: Vec<PathBuf>,
389 #[cfg(feature = "video")]
390 decoder: Option<BilinearVideoDecoder>,
391 #[cfg(feature = "video")]
392 webcam_decoder: Option<(ffmpeg::format::context::Input, ffmpeg::decoder::Video)>,
393 #[cfg(feature = "video")]
394 webcam_stream_index: usize,
395 #[cfg(feature = "video")]
396 total_frames: Option<usize>,
397 #[cfg(feature = "video")]
398 webcam_init_failed: bool,
399 #[cfg(feature = "video")]
400 video_init_failed: bool,
401}
402
403impl SourceIterator {
404 pub fn new(source: Source) -> Result<Self> {
418 let image_paths = match &source {
419 Source::Directory(path) => Self::collect_images_from_dir(path)?,
420 Source::Glob(pattern) => Self::collect_images_from_glob(pattern)?,
421 Source::Image(path) => vec![path.clone()],
422 Source::ImageList(paths) => paths.clone(),
424 _ => vec![],
425 };
426
427 Ok(Self {
428 source,
429 current_frame: 0,
430 image_paths,
431 #[cfg(feature = "video")]
432 decoder: None,
433 #[cfg(feature = "video")]
434 webcam_decoder: None,
435 #[cfg(feature = "video")]
436 webcam_stream_index: 0,
437 #[cfg(feature = "video")]
438 total_frames: None,
439 #[cfg(feature = "video")]
440 webcam_init_failed: false,
441 #[cfg(feature = "video")]
442 video_init_failed: false,
443 })
444 }
445
446 fn collect_images_from_dir(dir: &Path) -> Result<Vec<PathBuf>> {
448 if !dir.is_dir() {
449 return Err(InferenceError::ImageError(format!(
450 "Not a directory: {}",
451 dir.display()
452 )));
453 }
454
455 let mut paths: Vec<PathBuf> = std::fs::read_dir(dir)?
456 .filter_map(std::result::Result::ok)
457 .map(|entry| entry.path())
458 .filter(|path| Self::is_image_file(path))
459 .collect();
460
461 paths.sort();
462 Ok(paths)
463 }
464
    /// Expands a simplified glob pattern into a sorted list of image paths.
    ///
    /// Only a limited subset of glob syntax is supported: everything before
    /// the first `*` is treated as a directory, and a `*.ext` tail filters
    /// by extension (case-insensitive). A pattern with no `*` is returned
    /// as a single literal path without checking that it exists.
    ///
    /// NOTE(review): a `*` in the middle of a file name (e.g. `img_*.png`)
    /// is NOT supported — the `img_` prefix would be treated as a directory
    /// name and fail with "Directory not found". Confirm callers only pass
    /// `dir/*.ext`-style patterns.
    ///
    /// # Errors
    ///
    /// Returns an error if the derived directory does not exist or cannot
    /// be read.
    fn collect_images_from_glob(pattern: &str) -> Result<Vec<PathBuf>> {
        if let Some(star_pos) = pattern.find('*') {
            // Everything before the `*` names the directory to scan;
            // empty prefix means the current directory.
            let dir_part = &pattern[..star_pos];
            let dir = if dir_part.is_empty() {
                Path::new(".")
            } else {
                Path::new(dir_part.trim_end_matches('/').trim_end_matches('\\'))
            };

            // A `*.ext` tail restricts matches to that extension; any other
            // tail falls back to the generic image-extension filter.
            let ext_filter: Option<String> = pattern[star_pos..]
                .strip_prefix("*.")
                .map(str::to_lowercase);

            if !dir.is_dir() {
                return Err(InferenceError::ImageError(format!(
                    "Directory not found: {}",
                    dir.display()
                )));
            }

            let mut paths: Vec<PathBuf> = std::fs::read_dir(dir)?
                .filter_map(std::result::Result::ok)
                .map(|entry| entry.path())
                .filter(|path| {
                    ext_filter.as_ref().map_or_else(
                        || Self::is_image_file(path),
                        |ext| {
                            path.extension()
                                .is_some_and(|e| e.to_string_lossy().to_lowercase() == *ext)
                        },
                    )
                })
                .collect();

            // Sort for deterministic iteration order across platforms.
            paths.sort();
            Ok(paths)
        } else {
            // No wildcard: treat the pattern as a literal file path.
            Ok(vec![PathBuf::from(pattern)])
        }
    }
513
514 fn is_image_file(path: &Path) -> bool {
516 path.extension().is_some_and(|ext| {
517 let ext = ext.to_string_lossy().to_lowercase();
518 matches!(
519 ext.as_str(),
520 "jpg" | "jpeg" | "png" | "bmp" | "gif" | "webp" | "tiff" | "tif"
521 )
522 })
523 }
524
    /// Downloads an image over HTTP(S) and decodes it in memory.
    ///
    /// This is a blocking call; the whole response body is buffered before
    /// decoding.
    ///
    /// # Errors
    ///
    /// Returns an error if the request fails, the body cannot be read, or
    /// the bytes are not a decodable image.
    fn download_image(url: &str) -> Result<DynamicImage> {
        let mut response = ureq::get(url)
            .call()
            .map_err(|e| InferenceError::ImageError(format!("Failed to download {url}: {e}")))?
            .into_body();

        let bytes = response.read_to_vec().map_err(|e| {
            InferenceError::ImageError(format!("Failed to read response from {url}: {e}"))
        })?;

        image::load_from_memory(&bytes).map_err(|e| {
            InferenceError::ImageError(format!("Failed to decode image from {url}: {e}"))
        })
    }
540
541 fn next_image_url(&mut self, url: &str) -> Option<Result<(DynamicImage, SourceMeta)>> {
543 if self.current_frame > 0 {
544 return None;
545 }
546
547 self.current_frame = 1;
548 let meta = SourceMeta {
549 frame_idx: 0,
550 total_frames: Some(1),
551 path: url.to_string(),
552 fps: None,
553 };
554
555 match Self::download_image(url) {
556 Ok(img) => Some(Ok((img, meta))),
557 Err(e) => Some(Err(e)),
558 }
559 }
560
561 fn next_image(&mut self) -> Option<Result<(DynamicImage, SourceMeta)>> {
563 if self.current_frame >= self.image_paths.len() {
564 return None;
565 }
566
567 let path = &self.image_paths[self.current_frame];
568 let meta = SourceMeta {
569 frame_idx: self.current_frame,
570 total_frames: Some(self.image_paths.len()),
571 path: path.to_string_lossy().to_string(),
572 fps: None,
573 };
574
575 self.current_frame += 1;
576
577 match image::open(path) {
578 Ok(img) => Some(Ok((img, meta))),
579 Err(e) => Some(Err(InferenceError::ImageError(format!(
580 "Failed to load {}: {e}",
581 path.display()
582 )))),
583 }
584 }
585
    /// Yields the next decoded frame from a video file, stream, or webcam.
    ///
    /// Webcam capture (device discovery + decoding) is handled inline here;
    /// file and network-stream decoding is delegated to
    /// [`BilinearVideoDecoder`]. After any initialization failure the
    /// corresponding `*_init_failed` flag is set so subsequent calls return
    /// `None` instead of retrying.
    #[cfg(feature = "video")]
    #[allow(unsafe_code, clippy::too_many_lines)]
    fn next_video_frame(&mut self) -> Option<Result<(DynamicImage, SourceMeta)>> {
        if let Source::Webcam(idx) = &self.source {
            if self.webcam_init_failed {
                return None;
            }

            // Lazily open the capture device on first use.
            if self.webcam_decoder.is_none() {
                ffmpeg::init().ok();

                // Pick the OS-specific ffmpeg capture backend.
                let input_format_name = if cfg!(target_os = "macos") {
                    "avfoundation"
                } else if cfg!(target_os = "linux") {
                    "video4linux2"
                } else if cfg!(target_os = "windows") {
                    "dshow"
                } else {
                    self.webcam_init_failed = true;
                    return Some(Err(InferenceError::VideoError(
                        "Unsupported OS for webcam".to_string(),
                    )));
                };

                let c_name = std::ffi::CString::new(input_format_name).unwrap();
                // SAFETY: `c_name` is a valid NUL-terminated C string that
                // stays alive for the duration of the call.
                #[allow(unsafe_code)]
                let ptr = unsafe { video_rs::ffmpeg::ffi::av_find_input_format(c_name.as_ptr()) };

                let input_format = if ptr.is_null() {
                    self.webcam_init_failed = true;
                    return Some(Err(InferenceError::VideoError(format!(
                        "Input format '{input_format_name}' not found"
                    ))));
                } else {
                    // SAFETY: `ptr` was checked non-null above. It is
                    // assumed to reference an ffmpeg-registered input format
                    // that `wrap` does not take ownership of — TODO confirm.
                    #[allow(unsafe_code, clippy::ptr_cast_constness)]
                    unsafe {
                        ffmpeg::format::Input::wrap(ptr.cast_mut())
                    }
                };

                // Device naming differs per capture backend.
                let device_name = if cfg!(target_os = "macos") {
                    idx.to_string()
                } else if cfg!(target_os = "linux") {
                    format!("/dev/video{idx}")
                } else if cfg!(target_os = "windows") {
                    format!("video={idx}")
                } else {
                    self.webcam_init_failed = true;
                    return Some(Err(InferenceError::VideoError(
                        "Unsupported OS for webcam device name".to_string(),
                    )));
                };

                let mut options = ffmpeg::Dictionary::new();
                options.set("framerate", "30");

                match ffmpeg::format::open_with(
                    &PathBuf::from(&device_name),
                    &ffmpeg::Format::Input(input_format),
                    options,
                ) {
                    #[allow(clippy::single_match_else)]
                    Ok(ctx) => match ctx {
                        ffmpeg::format::context::Context::Input(ictx) => {
                            let input =
                                ictx.streams()
                                    .best(ffmpeg::media::Type::Video)
                                    .ok_or_else(|| {
                                        InferenceError::VideoError(
                                            "No video stream found in webcam".to_string(),
                                        )
                                    });

                            match input {
                                Ok(stream) => {
                                    let stream_index = stream.index();
                                    self.webcam_stream_index = stream_index;
                                    // NOTE(review): `unwrap()` here panics if
                                    // the codec context cannot be built from
                                    // the stream parameters.
                                    let context_decoder =
                                        ffmpeg::codec::context::Context::from_parameters(
                                            stream.parameters(),
                                        )
                                        .unwrap();
                                    match context_decoder.decoder().video() {
                                        Ok(decoder) => {
                                            self.webcam_decoder = Some((ictx, decoder));
                                        }
                                        Err(e) => {
                                            self.webcam_init_failed = true;
                                            return Some(Err(InferenceError::VideoError(format!(
                                                "Failed to create webcam decoder: {e}"
                                            ))));
                                        }
                                    }
                                }
                                Err(e) => {
                                    self.webcam_init_failed = true;
                                    return Some(Err(e));
                                }
                            }
                        }
                        ffmpeg::format::context::Context::Output(_) => {
                            self.webcam_init_failed = true;
                            return Some(Err(InferenceError::VideoError(
                                "Opened context is not an input context".to_string(),
                            )));
                        }
                    },
                    Err(e) => {
                        self.webcam_init_failed = true;
                        return Some(Err(InferenceError::VideoError(format!(
                            "Failed to open webcam: {e}"
                        ))));
                    }
                }
            }

            if let Some((ictx, decoder)) = &mut self.webcam_decoder {
                let mut decoded = ffmpeg::util::frame::video::Video::empty();

                // Read packets until one from our video stream decodes into a
                // frame; packets that fail send/receive are skipped (the &&
                // chain short-circuits, consuming the packet either way).
                for (stream, packet) in ictx.packets() {
                    if stream.index() == self.webcam_stream_index
                        && decoder.send_packet(&packet).is_ok()
                        && decoder.receive_frame(&mut decoded).is_ok()
                    {
                        // NOTE(review): the RGB scaler is rebuilt for every
                        // frame and `unwrap()`s on failure — consider caching
                        // it as `BilinearVideoDecoder` does.
                        let mut rgb_frame = ffmpeg::util::frame::video::Video::empty();
                        let mut scaler = ffmpeg::software::scaling::context::Context::get(
                            decoded.format(),
                            decoded.width(),
                            decoded.height(),
                            ffmpeg::format::Pixel::RGB24,
                            decoded.width(),
                            decoded.height(),
                            ffmpeg::software::scaling::flag::Flags::BILINEAR,
                        )
                        .unwrap();

                        // A scaling failure is ignored; rgb_frame stays empty.
                        scaler.run(&decoded, &mut rgb_frame).ok();

                        let width = rgb_frame.width();
                        let height = rgb_frame.height();
                        let data = rgb_frame.data(0);
                        let stride = rgb_frame.stride(0);

                        // Copy row by row: rows may be stride-padded beyond
                        // the `width * 3` visible bytes.
                        let mut rgb_data = Vec::with_capacity((width * height * 3) as usize);
                        for y in 0..height as usize {
                            let row = &data[y * stride..y * stride + (width as usize) * 3];
                            rgb_data.extend_from_slice(row);
                        }

                        let img_buffer =
                            image::RgbImage::from_raw(width, height, rgb_data).unwrap();
                        let img = DynamicImage::ImageRgb8(img_buffer);

                        let meta = SourceMeta {
                            frame_idx: self.current_frame,
                            // Live capture has no known length or fps here.
                            total_frames: None,
                            path: format!("Webcam {idx}"),
                            fps: None,
                        };
                        self.current_frame += 1;
                        return Some(Ok((img, meta)));
                    }
                }
                // Packet source exhausted: capture has ended.
                return None;
            }
            return None;
        }

        // File or network stream: lazily construct the decoder once.
        if self.decoder.is_none() {
            if self.video_init_failed {
                return None;
            }

            let path_str = match &self.source {
                Source::Video(p) => Some(p.to_string_lossy().to_string()),
                Source::Stream(s) => Some(s.clone()),
                _ => None,
            };

            if let Some(path_str) = path_str {
                match BilinearVideoDecoder::new(Path::new(&path_str)) {
                    Ok(d) => {
                        self.total_frames = d.total_frames;
                        self.decoder = Some(d);
                    }
                    Err(e) => {
                        self.video_init_failed = true;
                        return Some(Err(InferenceError::VideoError(format!(
                            "Failed to create decoder: {e}"
                        ))));
                    }
                }
            }
        }

        if let Some(decoder) = &mut self.decoder {
            match decoder.decode_next() {
                Some(Ok(img)) => {
                    let meta = SourceMeta {
                        frame_idx: self.current_frame,
                        total_frames: self.total_frames,
                        path: self
                            .source
                            .path()
                            .map(|p| p.to_string_lossy().to_string())
                            .unwrap_or_default(),
                        fps: Some(decoder.fps),
                    };
                    self.current_frame += 1;
                    Some(Ok((img, meta)))
                }
                Some(Err(e)) => Some(Err(e)),
                None => None,
            }
        } else {
            None
        }
    }
819
    /// Fallback when the `video` feature is disabled: reports that video
    /// sources require building with `--features video`.
    ///
    /// NOTE(review): every call returns the same `Some(Err(..))`, so a
    /// caller looping until `None` would never terminate — confirm callers
    /// stop on the first error.
    #[cfg(not(feature = "video"))]
    #[allow(
        clippy::unused_self,
        clippy::unnecessary_wraps,
        clippy::needless_pass_by_ref_mut
    )]
    fn next_video_frame(&mut self) -> Option<Result<(DynamicImage, SourceMeta)>> {
        Some(Err(InferenceError::FeatureNotEnabled(
            "Video support requires '--features video'".to_string(),
        )))
    }
831}
832
impl Iterator for SourceIterator {
    type Item = Result<(DynamicImage, SourceMeta)>;

    /// Dispatches to the per-source iteration strategy.
    fn next(&mut self) -> Option<Self::Item> {
        match &self.source {
            // Batch-style sources walk the pre-collected `image_paths` list.
            Source::Image(_) | Source::Directory(_) | Source::Glob(_) | Source::ImageList(_) => {
                self.next_image()
            }
            Source::ImageUrl(url) => {
                // Clone ends the borrow of `self.source` so the `&mut self`
                // helper can be called.
                let url = url.clone();
                self.next_image_url(&url)
            }
            // In-memory sources yield exactly one frame, then end.
            Source::ImageBuffer(img) => {
                if self.current_frame == 0 {
                    self.current_frame = 1;
                    let meta = SourceMeta::default();
                    Some(Ok((img.clone(), meta)))
                } else {
                    None
                }
            }
            Source::Array(arr) => {
                if self.current_frame == 0 {
                    self.current_frame = 1;
                    let meta = SourceMeta::default();
                    // Conversion happens lazily; a bad array surfaces here.
                    match crate::utils::array_to_image(arr) {
                        Ok(img) => Some(Ok((img, meta))),
                        Err(e) => Some(Err(e)),
                    }
                } else {
                    None
                }
            }
            Source::Video(_) | Source::Webcam(_) | Source::Stream(_) => self.next_video_frame(),
        }
    }
}
871
#[cfg(test)]
mod tests {
    use super::*;

    /// `From<&str>` classifies strings into the expected variants.
    #[test]
    fn test_source_from_string() {
        assert!(matches!(Source::from("image.jpg"), Source::Image(_)));
        assert!(matches!(Source::from("video.mp4"), Source::Video(_)));
        assert!(matches!(
            Source::from("rtsp://example.com"),
            Source::Stream(_)
        ));
        assert!(matches!(Source::from("0"), Source::Webcam(0)));
        assert!(matches!(Source::from("*.jpg"), Source::Glob(_)));
        // HTTP URLs are classified by extension: direct image links become
        // ImageUrl, everything else is treated as a stream.
        assert!(matches!(
            Source::from("https://example.com/photo.png"),
            Source::ImageUrl(_)
        ));
        assert!(matches!(
            Source::from("https://example.com/live"),
            Source::Stream(_)
        ));
    }

    /// The `is_image`/`is_video`/`is_batch` predicates are mutually
    /// consistent across representative variants.
    #[test]
    fn test_source_checks() {
        let img = Source::Image(PathBuf::from("test.jpg"));
        assert!(img.is_image());
        assert!(!img.is_video());
        assert!(!img.is_batch());

        let vid = Source::Video(PathBuf::from("test.mp4"));
        assert!(!vid.is_image());
        assert!(vid.is_video());
        assert!(!vid.is_batch());

        let dir = Source::Directory(PathBuf::from("./images"));
        assert!(dir.is_batch());
        assert!(!dir.is_image());
        assert!(!dir.is_video());

        // The other batch variants behave like Directory.
        let glob = Source::Glob("*.png".to_string());
        assert!(glob.is_batch());

        let list = Source::ImageList(vec![PathBuf::from("a.jpg")]);
        assert!(list.is_batch());
    }
}