Segment Anything is a new model from Facebook that can segment arbitrary objects in images.
The model runs in two steps:
- The “embedding” step, which computes an embedding of the image, much as the language embeddings that were foundational to modern LLMs do for text. This step can take some time, but it can be accelerated with GPUs.
- The “prompting” step, again inspired by LLMs, which respond to user-input “prompts.” In the case of Segment Anything, the prompt can take multiple forms, such as a single point on the image or a bounding box. When prompted, the model returns one or more masks for the objects it detects at the prompted location.
In their paper, Facebook reports impressive zero-shot segmentation performance, in some cases beating models that were trained specifically for the kind of images being segmented.
Facebook has made Segment Anything available in Python, building on their work on PyTorch. In this post, we’ll run Segment Anything in a Jupyter Notebook and then deploy it to a REST endpoint.
To follow along, check out our Colab notebook with all this code in a runnable, interactive form!
Some Building Blocks For Working With Images
The code in this blog post works with images. For convenience, we’ve defined a few helper functions that we’ll use later on. The bounding box code is from yours truly here at Modelbit; the rest is courtesy of Facebook Research.
{%CODE python%}
import numpy as np
import matplotlib.pyplot as plt

# Given the pixels of an image mask, return the mask's bounding box
# as (x_min, y_min, x_max, y_max); all four are None for an empty mask
def mask2boundingbox(mask):
    x_min = None
    x_max = None
    y_min = None
    y_max = None
    for y, row in enumerate(mask):
        for x, val in enumerate(row):
            if val:
                if x_min is None or x_min > x:
                    x_min = x
                if y_min is None or y_min > y:
                    y_min = y
                if x_max is None or x_max < x:
                    x_max = x
                if y_max is None or y_max < y:
                    y_max = y
    return x_min, y_min, x_max, y_max

# Render a mask in matplotlib as a translucent overlay
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

# Render points as stars in matplotlib (green for label 1, red for label 0)
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)

# Render a box given as (x_min, y_min, x_max, y_max) in matplotlib
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='red', facecolor=(0, 0, 0, 0), lw=2))

# Render an image in matplotlib, optionally overlaying a prompt point, a mask, and a box
def show_image(img, points=None, mask=None, box=()):
    plt.figure(figsize=(10, 10))
    plt.imshow(img)
    if points:
        show_points(np.array([[points[0], points[1]]]), np.array([1]), plt.gca())
    # Only draw the mask if one was passed and it has at least one set pixel
    if mask is not None and mask.any():
        show_mask(mask, plt.gca())
    if box:
        show_box(box, plt.gca())
    plt.axis('on')
    plt.show()
{%/CODE%}
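If the pure-Python loop in mask2boundingbox feels slow on full-resolution masks, NumPy can do the same scan in one pass. Here’s a minimal sketch; mask2boundingbox_np is our own hypothetical name, and it assumes a 2-D mask like the ones SAM returns, with x as the column index and y as the row index:

{%CODE python%}
import numpy as np

def mask2boundingbox_np(mask):
    """Vectorized take on mask2boundingbox: returns (x_min, y_min, x_max, y_max)."""
    # np.nonzero on a 2-D array returns the row (y) and column (x) indices of truthy pixels
    ys, xs = np.nonzero(np.asarray(mask))
    if xs.size == 0:
        # Empty mask: match the loop version, which leaves every coordinate as None
        return None, None, None, None
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())

# A 5x6 mask with a filled rectangle covering rows 2..4 and columns 1..3
demo = np.zeros((5, 6), dtype=bool)
demo[2:5, 1:4] = True
print(mask2boundingbox_np(demo))  # (1, 2, 3, 4)
{%/CODE%}

Both versions return plain Python ints, so the result stays JSON-serializable if you later return it from an API.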
Getting Segment Anything Running
First, we’ll need to install the Segment Anything Python package. Like some other Facebook Research projects (e.g. Detectron2), Segment Anything is installed straight from the Facebook Research GitHub repo. We’ll install it like so:
{%CODE bash%}
!pip install git+https://github.com/facebookresearch/segment-anything.git
{%/CODE%}
Segment Anything then needs to be configured with a checkpoint file. Facebook supplies three options, each with its own size (and therefore speed) vs. accuracy tradeoff. For this example, we’ll use their default ViT-H model, the largest option: it has 636M parameters, and its checkpoint clocks in at 2.4GB. Let’s get the file:
{%CODE bash%}
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
{%/CODE%}
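If you’d rather trade some accuracy for speed, the other two checkpoints plug in the same way. Here’s a small sketch mapping each sam_model_registry key to a checkpoint URL. The ViT-L and ViT-B file names are as listed in the segment-anything README at the time of writing, so double-check them there; checkpoint_url is a hypothetical helper of ours:

{%CODE python%}
SAM_CHECKPOINT_BASE = "https://dl.fbaipublicfiles.com/segment_anything/"

# Checkpoint file per registry key (assumed from the segment-anything README)
SAM_CHECKPOINTS = {
    "vit_h": "sam_vit_h_4b8939.pth",  # largest and most accurate
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",  # smallest and fastest
}

def checkpoint_url(model_type):
    """Hypothetical helper: return the download URL for a sam_model_registry key."""
    # In sam_model_registry, "default" builds the same ViT-H model as "vit_h"
    key = "vit_h" if model_type == "default" else model_type
    if key not in SAM_CHECKPOINTS:
        raise ValueError(f"unknown SAM model type: {model_type!r}")
    return SAM_CHECKPOINT_BASE + SAM_CHECKPOINTS[key]

print(checkpoint_url("vit_b"))
{%/CODE%}

For example, after downloading checkpoint_url("vit_b") you would load it with sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth"), keeping the registry key and checkpoint in sync.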
From here, we can initialize our model:
{%CODE python%}
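import torch
from segment_anything import sam_model_registry, SamPredictor
# "default" is ViT-H, matching our checkpoint; use a GPU if one is available
sam = sam_model_registry["default"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device="cuda" if torch.cuda.is_available() else "cpu")
predictor = SamPredictor(sam)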
{%/CODE%}
To test our model, we just need an image! For fun, we found this image of big cats from the Monterey Zoo:
You can find it here: https://montereyzoo.org/wp-content/uploads/2017/04/big-cats-6.jpg
Let’s download the image into our Notebook:
{%CODE python%}
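import cv2
import numpy as np
import requests

resp = requests.get("https://montereyzoo.org/wp-content/uploads/2017/04/big-cats-6.jpg")
resp.raise_for_status()
# Decode the JPEG bytes into a 3-channel image; OpenCV uses BGR channel order
img = cv2.imdecode(np.frombuffer(resp.content, dtype=np.uint8), cv2.IMREAD_COLOR)
# SAM expects RGB, so swap the channel order
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)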
{%/CODE%}
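SamPredictor.set_image expects an H x W x 3 uint8 array in RGB order, which is why we convert away from OpenCV’s BGR default. As a sanity check, here’s a minimal sketch of what that conversion and a shape check amount to in plain NumPy; check_sam_image and bgr_to_rgb are hypothetical helpers, not part of Segment Anything:

{%CODE python%}
import numpy as np

def check_sam_image(img):
    """Hypothetical check: raise unless img is an H x W x 3 uint8 array; return (H, W)."""
    if not isinstance(img, np.ndarray) or img.ndim != 3 or img.shape[2] != 3:
        raise ValueError(f"expected an H x W x 3 array, got shape {getattr(img, 'shape', None)}")
    if img.dtype != np.uint8:
        raise ValueError(f"expected uint8 pixels in [0, 255], got dtype {img.dtype}")
    return img.shape[0], img.shape[1]

def bgr_to_rgb(img):
    # Reversing the channel axis swaps B and R, like cv2.COLOR_BGR2RGB on 3-channel input
    return np.ascontiguousarray(img[..., ::-1])

pixel = np.array([[[255, 0, 0]]], dtype=np.uint8)  # one pure-blue pixel in BGR order
print(bgr_to_rgb(pixel)[0, 0].tolist())  # [0, 0, 255]
print(check_sam_image(bgr_to_rgb(pixel)))  # (1, 1)
{%/CODE%}

If you’d rather skip the conversion entirely, set_image also accepts image_format="BGR".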
From here, we’re ready to take that first “embedding” step. Let’s get the model loaded up with embeddings from this image. This will take a few moments if you’re on CPU, but is quite quick on a GPU. In either case, it’s one simple line of Python:
{%CODE python%}
predictor.set_image(img)
{%/CODE%}
Now that the image embeddings are ready, the model just needs to be prompted! Since we’re rolling with the big cats, let’s make a cat detector. We’ll supply the coordinates of a point, and the model will return masks for the object at that point (hopefully a cat!), which we’ll turn into the cat’s bounding box.
Here’s the code:
{%CODE python%}
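def find_cat(x_coord, y_coord):
    # Prompt SAM with a single foreground point (label 1) and ask for multiple candidate masks
    masks, scores, logits = predictor.predict(
        point_coords=np.array([[x_coord, y_coord]]),
        point_labels=np.array([1]),
        multimask_output=True,
    )
    # Keep the candidate mask with the highest score
    top_score = 0
    best_mask = None
    for i, score in enumerate(scores):
        if score > top_score:
            top_score = score
            best_mask = masks[i]
    # Convert the mask to (x_min, y_min, x_max, y_max), draw everything, and return the box
    bbox = mask2boundingbox(best_mask)
    show_image(img, (x_coord, y_coord), best_mask, bbox)
    return bbox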
{%/CODE%}
It’s all pretty simple! Because we pass multimask_output=True, predictor.predict returns three candidate masks, along with a score for each. We iterate through them to find the highest-scoring mask, then convert it to a bounding box and return it.
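The loop in find_cat keeps the first mask with the strictly highest score. NumPy’s argmax expresses the same selection in one call. Here’s a minimal sketch with made-up masks and scores standing in for predictor.predict output; pick_best_mask is a hypothetical helper:

{%CODE python%}
import numpy as np

def pick_best_mask(masks, scores):
    """Hypothetical helper: return (mask, score) for the highest-scoring candidate."""
    # np.argmax returns the index of the largest score, taking the first one on ties
    best = int(np.argmax(scores))
    return masks[best], float(scores[best])

# Toy stand-ins for predict() output: three 4x4 masks and their scores
masks = np.zeros((3, 4, 4), dtype=bool)
masks[1, 1:3, 1:3] = True
scores = np.array([0.71, 0.93, 0.88], dtype=np.float32)

mask, score = pick_best_mask(masks, scores)
print(int(mask.sum()), round(score, 2))  # 4 0.93
{%/CODE%}

One small difference: the loop starts from top_score = 0, so it would return None if every score were 0 or below, while argmax always picks a mask.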
Along the way, we use our handy image functions to render out the image, the point of the prompt, and the object mask and bounding box. Here it is all put together:
{%CODE python%}
bounding_box = find_cat(225, 150)
bounding_box
{%/CODE%}
Pretty slick! By specifying (225, 150) as our coordinates, we find the lion and its bounding box. Not bad at all.
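A single point worked well here, but predict takes its point prompt as two parallel arrays: point_coords, an N x 2 array of (x, y) pixel coordinates, and point_labels, with 1 for a foreground point and 0 for a background point. Adding a background point can help steer the model away from a neighboring object. Here’s a minimal sketch of building both arrays; make_point_prompt is a hypothetical helper, and the background coordinate below is made up for illustration:

{%CODE python%}
import numpy as np

def make_point_prompt(foreground, background=()):
    """Hypothetical helper: build point_coords / point_labels for SamPredictor.predict."""
    foreground = list(foreground)
    background = list(background)
    # point_coords: N x 2 array of (x, y) pixel coordinates, foreground points first
    coords = np.array(foreground + background, dtype=np.float32).reshape(-1, 2)
    # point_labels: 1 marks a foreground point, 0 a background point
    labels = np.array([1] * len(foreground) + [0] * len(background), dtype=np.int64)
    return coords, labels

# The lion point from above plus a hypothetical background point
coords, labels = make_point_prompt([(225, 150)], background=[(600, 300)])
print(coords.shape, labels.tolist())  # (2, 2) [1, 0]
{%/CODE%}

The resulting arrays drop straight into predictor.predict(point_coords=coords, point_labels=labels, multimask_output=True).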
Deploying Segment Anything to a REST API
With all this in hand, we’re ready for production. As the Modelbit founding team, we’d be remiss if we didn’t use Modelbit to make it happen! We’ll start by installing Modelbit’s Python package, then importing it and logging in:
{%CODE bash%}
!pip install --upgrade modelbit
{%/CODE%}
{%CODE python%}
import modelbit
mb = modelbit.login(branch="dev")
{%/CODE%}
Now, we can go ahead and use Modelbit to deploy our model to production:
{%CODE python%}
mb.deploy(
    find_cat,
    python_packages=[
        "git+https://github.com/facebookresearch/segment-anything.git",
        "opencv-python==4.7.0.72",
        "pycocotools==2.0.6",
        "matplotlib==3.7.1",
        "numpy==1.22.4",
        "torch==2.0.1+cu118",
        "torchvision==0.15.2+cu118",
    ],
    system_packages=["python3-opencv", "build-essential"],
)
{%/CODE%}
Since we already have our find_cat function, we can just pass it right into Modelbit’s deploy function. Modelbit will handle finding all of its dependencies, including the other Python functions and variables that find_cat depends on.
Modelbit will typically detect the necessary Python and system packages as well, but since we’re depending on a package from a specific git URL, as well as very specific versions of torch and torchvision, we went ahead and specified them all. This is mostly user preference.
From here, Modelbit will provision a REST API! It takes in our two coordinates and returns our bounding box:
{%CODE bash%}
% curl -s -XPOST "https://harrys-house.app.modelbit.com/v1/find_cat/latest" -d '{"data": [225, 150]}' | json_pp
{
"data" : [
182,
11,
559,
327
]
}
{%/CODE%}
Now our REST API is returning the same bounding box coordinates for our lion that we saw in our notebook! We can hand this API off to our product or web application so it can be integrated into production.
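To call the endpoint from Python instead of curl, the standard library is enough. This is a minimal sketch: the URL is the example endpoint from the curl call above (swap in your own), build_request, parse_bbox, and find_cat_remote are hypothetical names, and the Content-Type header is our assumption, since the curl example doesn’t set one:

{%CODE python%}
import json
import urllib.request

# Example endpoint from the curl call above; replace with your own deployment's URL
ENDPOINT = "https://harrys-house.app.modelbit.com/v1/find_cat/latest"

def build_request(x, y, url=ENDPOINT):
    # Same JSON body the curl call sends: {"data": [x, y]}
    body = json.dumps({"data": [x, y]}).encode("utf-8")
    return urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}, method="POST"
    )

def parse_bbox(response_text):
    # The response wraps find_cat's return value in a "data" field
    x_min, y_min, x_max, y_max = json.loads(response_text)["data"]
    return x_min, y_min, x_max, y_max

def find_cat_remote(x, y, timeout=60):
    """Hypothetical client: POST the point and return the bounding box tuple."""
    with urllib.request.urlopen(build_request(x, y), timeout=timeout) as resp:
        return parse_bbox(resp.read().decode("utf-8"))

# Parsing the response shown above, without making a network call
print(parse_bbox('{"data": [182, 11, 559, 327]}'))  # (182, 11, 559, 327)
{%/CODE%}

Keeping request building and response parsing separate from the network call makes both easy to test offline.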