segment_anything_rs/
lib.rs

1//! # Segment Anything RS
//! A Rust wrapper for [Segment Anything](https://segment-anything.com/)
3//!
4//! ## Usage
5//!
6//! ```rust, no_run
7//! use segment_anything_rs::*;
8//!
9//! let model = SegmentAnything::builder().build().unwrap();
10//! let image = image::open("examples/landscape.jpg").unwrap();
11//! let images = model.segment_everything(image).unwrap();
12//! for (i, img) in images.iter().enumerate() {
13//!     img.save(&format!("{}.png", i)).unwrap();
14//! }
15//! ```
16
17#![warn(missing_docs)]
18#[cfg(feature = "mkl")]
19extern crate intel_mkl_src;
20
21#[cfg(feature = "accelerate")]
22extern crate accelerate_src;
23
24use candle_core::DType;
25use candle_core::{Device, Tensor};
26use candle_nn::VarBuilder;
27use candle_transformers::models::segment_anything::sam::{self, Sam};
28use image::{DynamicImage, GenericImage, GenericImageView, ImageBuffer, Rgba};
29
/// A builder for [`SegmentAnything`].
///
/// Obtain one with [`SegmentAnything::builder`] (or [`Default::default`]), optionally choose the
/// weights with [`SegmentAnythingBuilder::source`], then call [`SegmentAnythingBuilder::build`].
#[derive(Default)]
pub struct SegmentAnythingBuilder {
    /// Which weights to fetch and which model variant to instantiate from them.
    source: SegmentAnythingSource,
}
35
36impl SegmentAnythingBuilder {
37    /// Sets the source of the model.
38    pub fn source(mut self, source: SegmentAnythingSource) -> Self {
39        self.source = source;
40        self
41    }
42
43    /// Builds the [`SegmentAnything`] model.
44    pub fn build(self) -> Result<SegmentAnything, LoadSegmentAnythingError> {
45        SegmentAnything::new(self)
46    }
47}
48
/// The source of the model: where the weights are fetched from and which variant they belong to.
///
/// Use [`SegmentAnythingSource::tiny`] or [`SegmentAnythingSource::medium`] for the bundled
/// presets, or [`SegmentAnythingSource::new`] for a custom repository and file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SegmentAnythingSource {
    /// Hugging Face Hub repository id the weights are downloaded from (e.g. `lmz/candle-sam`).
    model: String,
    /// Name of the weights file inside that repository.
    filename: String,
    /// `true` when the weights are for the tiny (ViT-T) variant, `false` for the ViT-B variant.
    tiny: bool,
}
55
56impl SegmentAnythingSource {
57    /// Creates a new [`SegmentAnythingSource`].
58    pub fn new(model: impl Into<String>, filename: impl Into<String>) -> Self {
59        Self {
60            model: model.into(),
61            filename: filename.into(),
62            tiny: false,
63        }
64    }
65
66    /// Create the tiny SAM model source.
67    pub fn tiny() -> Self {
68        let mut self_ = Self::new("lmz/candle-sam", "mobile_sam-tiny-vitt.safetensors");
69        self_.tiny = true;
70        self_
71    }
72
73    /// Create a normal sized model source.
74    pub fn medium() -> Self {
75        Self::new("lmz/candle-sam", "sam_vit_b_01ec64.safetensors")
76    }
77}
78
79impl Default for SegmentAnythingSource {
80    fn default() -> Self {
81        Self::tiny()
82    }
83}
84
/// Settings for running inference on [`SegmentAnything`].
///
/// Created from an image with [`SegmentAnythingInferenceSettings::new`] and refined with the
/// chained setter methods.
pub struct SegmentAnythingInferenceSettings {
    /// Cut-off applied to the raw mask: values greater than or equal to it are kept.
    /// Starts at 0.
    threshold: f32,

    /// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image).
    /// These points are handed to the model flagged `true`.
    goal_points: Vec<(f64, f64)>,

    /// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image).
    /// These points are handed to the model flagged `false`.
    avoid_points: Vec<(f64, f64)>,

    /// RGBA copy of the image to segment.
    image: ImageBuffer<image::Rgba<u8>, Vec<u8>>,
}
97
98impl SegmentAnythingInferenceSettings {
99    /// Creates a new [`SegmentAnythingInferenceSettings`] from an image.
100    pub fn new<I: GenericImageView<Pixel = Rgba<u8>>>(input: I) -> Self {
101        let mut image = ImageBuffer::new(input.width(), input.height());
102        image.copy_from(&input, 0, 0).unwrap();
103        Self {
104            threshold: 0.,
105            goal_points: Vec::new(),
106            avoid_points: Vec::new(),
107            image,
108        }
109    }
110
111    /// Sets the detection threshold for the mask, 0 is the default value.
112    /// - A negative values makes the model return a larger mask.
113    /// - A positive makes the model return a smaller mask.
114    pub fn set_threshold(mut self, threshold: f32) -> Self {
115        self.threshold = threshold;
116        self
117    }
118
119    /// Add a point to the list of points to segment.
120    pub fn add_goal_point(mut self, x: impl Into<f64>, y: impl Into<f64>) -> Self {
121        self.goal_points.push((x.into(), y.into()));
122        self
123    }
124
125    /// Set the list of points to segment.
126    pub fn set_goal_points(mut self, points: Vec<(f64, f64)>) -> Self {
127        self.goal_points = points;
128        self
129    }
130
131    /// Add a point to the list of points to avoid.
132    pub fn add_avoid_points(mut self, x: impl Into<f64>, y: impl Into<f64>) -> Self {
133        self.avoid_points.push((x.into(), y.into()));
134        self
135    }
136
137    /// Set the list of points to avoid.
138    pub fn set_avoid_points(mut self, points: Vec<(f64, f64)>) -> Self {
139        self.avoid_points = points;
140        self
141    }
142
143    /// Set the image to segment.
144    pub fn set_image<I: GenericImageView<Pixel = Rgba<u8>>>(
145        mut self,
146        image: I,
147    ) -> Result<Self, image::ImageError> {
148        self.image = ImageBuffer::new(image.width(), image.height());
149        self.image.copy_from(&image, 0, 0)?;
150        Ok(self)
151    }
152}
153
/// An error that can occur when loading a [`SegmentAnything`] model.
///
/// This is the error type returned by [`SegmentAnythingBuilder::build`].
#[derive(Debug, thiserror::Error)]
pub enum LoadSegmentAnythingError {
    /// An error that can occur when trying to load a [`SegmentAnything`] model into a device.
    /// Wraps the underlying [`candle_core::Error`].
    #[error("Failed to load model into device: {0}")]
    LoadModel(#[from] candle_core::Error),
    /// An error that can occur when downloading a [`SegmentAnything`] model from Hugging Face.
    /// Wraps the error reported by the `hf_hub` client.
    #[error("Failed to download model from Hugging Face: {0}")]
    DownloadModel(#[from] hf_hub::api::sync::ApiError),
}
164
/// An error that can occur when running a [`SegmentAnything`] model.
///
/// Returned by [`SegmentAnything::segment_from_points`] and
/// [`SegmentAnything::segment_everything`].
#[derive(Debug, thiserror::Error)]
pub enum SegmentAnythingInferenceError {
    /// An error that can occur when trying to run a [`SegmentAnything`] model.
    /// Wraps the underlying [`candle_core::Error`].
    #[error("Failed to run model: {0}")]
    RunModel(#[from] candle_core::Error),
    /// An error that can occur when converting the result of a [`SegmentAnything`] model to an image.
    /// Raised when the mask pixels cannot be assembled into an image buffer of the expected size.
    #[error("Failed to merge masks")]
    MergeMasks,
}
175
/// The [segment anything](https://segment-anything.com/) model.
///
/// Create one through [`SegmentAnything::builder`].
pub struct SegmentAnything {
    /// Device the input tensors are created on (currently always the CPU).
    device: Device,
    /// The loaded segment anything network.
    sam: Sam,
}
181
182impl SegmentAnything {
183    /// Creates a new [`SegmentAnythingBuilder`].
184    pub fn builder() -> SegmentAnythingBuilder {
185        SegmentAnythingBuilder::default()
186    }
187
188    fn new(settings: SegmentAnythingBuilder) -> Result<Self, LoadSegmentAnythingError> {
189        let SegmentAnythingBuilder { source } = settings;
190        let model = {
191            let api = hf_hub::api::sync::Api::new()?;
192            let api = api.model(source.model);
193            api.get(&source.filename)?
194        };
195        // Currently, candle doesn't support some operations that are required for segment anything
196        // let device = kalosm_common::accelerated_device_if_available()?;
197        let device = Device::Cpu;
198        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
199        let sam = if source.tiny {
200            sam::Sam::new_tiny(vb)? // tiny vit_t
201        } else {
202            sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
203        };
204        Ok(Self { device, sam })
205    }
206
207    /// Segment an image from a list of points. Returns a [`DynamicImage`] mask.
208    ///
209    /// # Example
210    /// ```rust, no_run
211    /// use segment_anything_rs::*;
212    ///
213    /// let model = SegmentAnything::builder().build().unwrap();
214    /// let image = image::open("examples/landscape.jpg").unwrap();
215    /// let x = image.width() / 2;
216    /// let y = image.height() / 4;
217    /// let images = model
218    ///     .segment_from_points(SegmentAnythingInferenceSettings::new(image).add_goal_point(x, y))
219    ///     .unwrap();
220    ///
221    /// images.save("out.png").unwrap();
222    /// ```
223    pub fn segment_from_points(
224        &self,
225        settings: SegmentAnythingInferenceSettings,
226    ) -> Result<DynamicImage, SegmentAnythingInferenceError> {
227        let SegmentAnythingInferenceSettings {
228            threshold,
229            goal_points,
230            avoid_points,
231            image,
232        } = settings;
233
234        let image = image::DynamicImage::ImageRgba8(image);
235        let image_width = image.width();
236        let image_height = image.height();
237
238        let image_tensor = self.image_to_tensor(image)?;
239
240        let points = {
241            let mut points = Vec::new();
242            for (x, y) in goal_points {
243                points.push((x, y, true));
244            }
245            for (x, y) in avoid_points {
246                points.push((x, y, false));
247            }
248            points
249        };
250
251        let (mask, _iou_predictions) = self.sam.forward(&image_tensor, &points, false)?;
252
253        let mask = (mask.ge(threshold)? * 255.)?;
254        let (_one, h, w) = mask.dims3()?;
255        let mask = mask.expand((3, h, w))?;
256
257        let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
258        let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
259            image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels)
260                .ok_or(SegmentAnythingInferenceError::MergeMasks)?;
261
262        Ok(image::DynamicImage::from(mask_img).resize_to_fill(
263            image_width,
264            image_height,
265            image::imageops::FilterType::CatmullRom,
266        ))
267    }
268
269    fn image_to_tensor(&self, image: DynamicImage) -> candle_core::Result<Tensor> {
270        let image = {
271            let resize_longest = sam::IMAGE_SIZE;
272            let (height, width) = (image.height(), image.width());
273            let resize_longest = resize_longest as u32;
274            let (height, width) = if height < width {
275                let h = (resize_longest * height) / width;
276                (h, resize_longest)
277            } else {
278                let w = (resize_longest * width) / height;
279                (resize_longest, w)
280            };
281            image.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
282        };
283        let (height, width) = (image.height() as usize, image.width() as usize);
284        let img = image.to_rgb8();
285        let data = img.into_raw();
286        let image = Tensor::from_vec(data, (height, width, 3), &self.device)?.permute((2, 0, 1))?;
287
288        let image = image.to_device(&self.device)?;
289
290        Ok(image)
291    }
292
293    /// Segment everything in an image. Returns a list of [`DynamicImage`] masks.
294    ///
295    /// # Example
296    ///
297    /// ```rust, no_run
298    /// use segment_anything_rs::*;
299    ///
300    /// let model = SegmentAnything::builder().build().unwrap();
301    /// let image = image::open("examples/landscape.jpg").unwrap();
302    /// let images = model.segment_everything(image).unwrap();
303    /// for (i, img) in images.iter().enumerate() {
304    ///     img.save(&format!("{}.png", i)).unwrap();
305    /// }
306    /// ```
307    pub fn segment_everything(
308        &self,
309        image: DynamicImage,
310    ) -> Result<Vec<DynamicImage>, SegmentAnythingInferenceError> {
311        let image = self.image_to_tensor(image)?;
312
313        let bboxes = self.sam.generate_masks(&image, 32, 0, 512. / 1500., 1)?;
314        let mut masks = Vec::new();
315        for bbox in bboxes {
316            let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
317            let (h, w) = mask.dims2()?;
318            let mask = mask.broadcast_as((3, h, w))?;
319            let (channel, height, width) = mask.dims3()?;
320            if channel != 3 {
321                return Err(candle_core::Error::Msg(
322                    "save_image expects an input of shape (3, height, width)".to_string(),
323                )
324                .into());
325            }
326            let mask = mask.permute((1, 2, 0))?.flatten_all()?;
327            let pixels = mask.to_vec1::<u8>()?;
328            let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
329                image::ImageBuffer::from_raw(width as u32, height as u32, pixels)
330                    .ok_or(SegmentAnythingInferenceError::MergeMasks)?;
331            let image = image::DynamicImage::from(image);
332            let image =
333                image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
334            masks.push(image);
335        }
336
337        Ok(masks)
338    }
339}