1#![warn(missing_docs)]
18#[cfg(feature = "mkl")]
19extern crate intel_mkl_src;
20
21#[cfg(feature = "accelerate")]
22extern crate accelerate_src;
23
24use candle_core::DType;
25use candle_core::{Device, Tensor};
26use candle_nn::VarBuilder;
27use candle_transformers::models::segment_anything::sam::{self, Sam};
28use image::{DynamicImage, GenericImage, GenericImageView, ImageBuffer, Rgba};
29
30#[derive(Default)]
32pub struct SegmentAnythingBuilder {
33 source: SegmentAnythingSource,
34}
35
36impl SegmentAnythingBuilder {
37 pub fn source(mut self, source: SegmentAnythingSource) -> Self {
39 self.source = source;
40 self
41 }
42
43 pub fn build(self) -> Result<SegmentAnything, LoadSegmentAnythingError> {
45 SegmentAnything::new(self)
46 }
47}
48
/// Where to download Segment Anything weights from, and which architecture they use.
#[derive(Debug, Clone)]
pub struct SegmentAnythingSource {
    /// Hugging Face Hub repository id, e.g. `lmz/candle-sam`.
    model: String,
    /// Safetensors file inside `model` that holds the weights.
    filename: String,
    /// `true` to build the tiny (MobileSAM) architecture, `false` for the full SAM encoder.
    tiny: bool,
}
55
56impl SegmentAnythingSource {
57 pub fn new(model: impl Into<String>, filename: impl Into<String>) -> Self {
59 Self {
60 model: model.into(),
61 filename: filename.into(),
62 tiny: false,
63 }
64 }
65
66 pub fn tiny() -> Self {
68 let mut self_ = Self::new("lmz/candle-sam", "mobile_sam-tiny-vitt.safetensors");
69 self_.tiny = true;
70 self_
71 }
72
73 pub fn medium() -> Self {
75 Self::new("lmz/candle-sam", "sam_vit_b_01ec64.safetensors")
76 }
77}
78
79impl Default for SegmentAnythingSource {
80 fn default() -> Self {
81 Self::tiny()
82 }
83}
84
85pub struct SegmentAnythingInferenceSettings {
87 threshold: f32,
88
89 goal_points: Vec<(f64, f64)>,
91
92 avoid_points: Vec<(f64, f64)>,
94
95 image: ImageBuffer<image::Rgba<u8>, Vec<u8>>,
96}
97
98impl SegmentAnythingInferenceSettings {
99 pub fn new<I: GenericImageView<Pixel = Rgba<u8>>>(input: I) -> Self {
101 let mut image = ImageBuffer::new(input.width(), input.height());
102 image.copy_from(&input, 0, 0).unwrap();
103 Self {
104 threshold: 0.,
105 goal_points: Vec::new(),
106 avoid_points: Vec::new(),
107 image,
108 }
109 }
110
111 pub fn set_threshold(mut self, threshold: f32) -> Self {
115 self.threshold = threshold;
116 self
117 }
118
119 pub fn add_goal_point(mut self, x: impl Into<f64>, y: impl Into<f64>) -> Self {
121 self.goal_points.push((x.into(), y.into()));
122 self
123 }
124
125 pub fn set_goal_points(mut self, points: Vec<(f64, f64)>) -> Self {
127 self.goal_points = points;
128 self
129 }
130
131 pub fn add_avoid_points(mut self, x: impl Into<f64>, y: impl Into<f64>) -> Self {
133 self.avoid_points.push((x.into(), y.into()));
134 self
135 }
136
137 pub fn set_avoid_points(mut self, points: Vec<(f64, f64)>) -> Self {
139 self.avoid_points = points;
140 self
141 }
142
143 pub fn set_image<I: GenericImageView<Pixel = Rgba<u8>>>(
145 mut self,
146 image: I,
147 ) -> Result<Self, image::ImageError> {
148 self.image = ImageBuffer::new(image.width(), image.height());
149 self.image.copy_from(&image, 0, 0)?;
150 Ok(self)
151 }
152}
153
154#[derive(Debug, thiserror::Error)]
156pub enum LoadSegmentAnythingError {
157 #[error("Failed to load model into device: {0}")]
159 LoadModel(#[from] candle_core::Error),
160 #[error("Failed to download model from Hugging Face: {0}")]
162 DownloadModel(#[from] hf_hub::api::sync::ApiError),
163}
164
165#[derive(Debug, thiserror::Error)]
167pub enum SegmentAnythingInferenceError {
168 #[error("Failed to run model: {0}")]
170 RunModel(#[from] candle_core::Error),
171 #[error("Failed to merge masks")]
173 MergeMasks,
174}
175
176pub struct SegmentAnything {
178 device: Device,
179 sam: Sam,
180}
181
182impl SegmentAnything {
183 pub fn builder() -> SegmentAnythingBuilder {
185 SegmentAnythingBuilder::default()
186 }
187
188 fn new(settings: SegmentAnythingBuilder) -> Result<Self, LoadSegmentAnythingError> {
189 let SegmentAnythingBuilder { source } = settings;
190 let model = {
191 let api = hf_hub::api::sync::Api::new()?;
192 let api = api.model(source.model);
193 api.get(&source.filename)?
194 };
195 let device = Device::Cpu;
198 let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
199 let sam = if source.tiny {
200 sam::Sam::new_tiny(vb)? } else {
202 sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? };
204 Ok(Self { device, sam })
205 }
206
207 pub fn segment_from_points(
224 &self,
225 settings: SegmentAnythingInferenceSettings,
226 ) -> Result<DynamicImage, SegmentAnythingInferenceError> {
227 let SegmentAnythingInferenceSettings {
228 threshold,
229 goal_points,
230 avoid_points,
231 image,
232 } = settings;
233
234 let image = image::DynamicImage::ImageRgba8(image);
235 let image_width = image.width();
236 let image_height = image.height();
237
238 let image_tensor = self.image_to_tensor(image)?;
239
240 let points = {
241 let mut points = Vec::new();
242 for (x, y) in goal_points {
243 points.push((x, y, true));
244 }
245 for (x, y) in avoid_points {
246 points.push((x, y, false));
247 }
248 points
249 };
250
251 let (mask, _iou_predictions) = self.sam.forward(&image_tensor, &points, false)?;
252
253 let mask = (mask.ge(threshold)? * 255.)?;
254 let (_one, h, w) = mask.dims3()?;
255 let mask = mask.expand((3, h, w))?;
256
257 let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
258 let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
259 image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels)
260 .ok_or(SegmentAnythingInferenceError::MergeMasks)?;
261
262 Ok(image::DynamicImage::from(mask_img).resize_to_fill(
263 image_width,
264 image_height,
265 image::imageops::FilterType::CatmullRom,
266 ))
267 }
268
269 fn image_to_tensor(&self, image: DynamicImage) -> candle_core::Result<Tensor> {
270 let image = {
271 let resize_longest = sam::IMAGE_SIZE;
272 let (height, width) = (image.height(), image.width());
273 let resize_longest = resize_longest as u32;
274 let (height, width) = if height < width {
275 let h = (resize_longest * height) / width;
276 (h, resize_longest)
277 } else {
278 let w = (resize_longest * width) / height;
279 (resize_longest, w)
280 };
281 image.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
282 };
283 let (height, width) = (image.height() as usize, image.width() as usize);
284 let img = image.to_rgb8();
285 let data = img.into_raw();
286 let image = Tensor::from_vec(data, (height, width, 3), &self.device)?.permute((2, 0, 1))?;
287
288 let image = image.to_device(&self.device)?;
289
290 Ok(image)
291 }
292
293 pub fn segment_everything(
308 &self,
309 image: DynamicImage,
310 ) -> Result<Vec<DynamicImage>, SegmentAnythingInferenceError> {
311 let image = self.image_to_tensor(image)?;
312
313 let bboxes = self.sam.generate_masks(&image, 32, 0, 512. / 1500., 1)?;
314 let mut masks = Vec::new();
315 for bbox in bboxes {
316 let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
317 let (h, w) = mask.dims2()?;
318 let mask = mask.broadcast_as((3, h, w))?;
319 let (channel, height, width) = mask.dims3()?;
320 if channel != 3 {
321 return Err(candle_core::Error::Msg(
322 "save_image expects an input of shape (3, height, width)".to_string(),
323 )
324 .into());
325 }
326 let mask = mask.permute((1, 2, 0))?.flatten_all()?;
327 let pixels = mask.to_vec1::<u8>()?;
328 let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
329 image::ImageBuffer::from_raw(width as u32, height as u32, pixels)
330 .ok_or(SegmentAnythingInferenceError::MergeMasks)?;
331 let image = image::DynamicImage::from(image);
332 let image =
333 image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
334 masks.push(image);
335 }
336
337 Ok(masks)
338 }
339}