Skip to content
Open
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
]
dependencies = [
"timm>=1.0.17",
"numpy==1.26",
"numpy>=1.26",
"tqdm",
"ftfy==6.1.1",
"regex",
Expand Down Expand Up @@ -59,7 +59,7 @@ notebooks = [
"ipycanvas",
"ipympl",
"pycocotools",
"decord",
"decord2",
"opencv-python",
"einops",
"scikit-image",
Expand Down
20 changes: 14 additions & 6 deletions sam3/eval/postprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,13 @@ def _process_masks(self, target_sizes, pred_masks, consistent=True, keep=None):
if pred_masks is None:
return None
if self.always_interpolate_masks_on_gpu:
gpu_device = target_sizes.device
assert gpu_device.type == "cuda"
pred_masks = pred_masks.to(device=gpu_device)
device = target_sizes.device
if device.type == "cpu":
logging.warning(
"always_interpolate_masks_on_gpu=True but data is on CPU; "
"falling back to CPU interpolation"
)
pred_masks = pred_masks.to(device=device)
if consistent:
assert keep is None, "TODO: implement?"
# All masks should have the same shape, expected when processing a batch of size 1
Expand Down Expand Up @@ -454,9 +458,13 @@ def process_results(
] # [P,Q,...] --> [K,...]
meta_td = meta_td[tracked_obj_ids_idx[PROMPT_AXIS].cpu()]
if self.always_interpolate_masks_on_gpu:
gpu_device = meta_td["original_size"].device
assert gpu_device.type == "cuda"
tracked_objs_outs_td = tracked_objs_outs_td.to(device=gpu_device)
device = meta_td["original_size"].device
if device.type == "cpu":
logging.warning(
"always_interpolate_masks_on_gpu=True but data is on CPU; "
"falling back to CPU interpolation"
)
tracked_objs_outs_td = tracked_objs_outs_td.to(device=device)
frame_results_td = self(
tracked_objs_outs_td.unsqueeze(1),
(
Expand Down
9 changes: 8 additions & 1 deletion sam3/model/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
gen_sineembed_for_position,
get_activation_fn,
get_clones,
get_default_device,
inverse_sigmoid,
MLP,
)
Expand Down Expand Up @@ -277,8 +278,9 @@ def __init__(

if resolution is not None and stride is not None:
feat_size = resolution // stride
device = get_default_device()
coords_h, coords_w = self._get_coords(
feat_size, feat_size, device="cuda"
feat_size, feat_size, device=device
)
self.compilable_cord_cache = (coords_h, coords_w)
self.compilable_stored_size = (feat_size, feat_size)
Expand Down Expand Up @@ -342,6 +344,11 @@ def _get_rpb_matrix(self, reference_boxes, feat_size):
):
# good, hitting the cache, will be compilable
coords_h, coords_w = self.compilable_cord_cache
# Ensure cache is on the same device as input (handles model.to(device) calls)
if coords_h.device != reference_boxes.device:
coords_h = coords_h.to(reference_boxes.device)
coords_w = coords_w.to(reference_boxes.device)
self.compilable_cord_cache = (coords_h, coords_w)
else:
# cache miss, will create compilation issue
# In case we're not compiling, we'll still rely on the dict-based cache
Expand Down
15 changes: 10 additions & 5 deletions sam3/model/geometry_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy

from .model_misc import get_clones
from .model_misc import get_clones, tensor_to_device


def is_right_padded(mask):
Expand Down Expand Up @@ -44,8 +44,13 @@ def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False
assert seq1_length == mask1.size(1)
assert seq2_length == mask2.size(1)

torch._assert_async(is_right_padded(mask1))
torch._assert_async(is_right_padded(mask2))
# Use regular assert on non-CUDA devices (torch._assert_async is CUDA-only)
if mask1.is_cuda:
torch._assert_async(is_right_padded(mask1))
torch._assert_async(is_right_padded(mask2))
else:
assert is_right_padded(mask1), "mask1 must be right-padded"
assert is_right_padded(mask2), "mask2 must be right-padded"

actual_seq1_lengths = (~mask1).sum(dim=-1)
actual_seq2_lengths = (~mask2).sum(dim=-1)
Expand Down Expand Up @@ -606,7 +611,7 @@ def _encode_points(self, points, points_mask, points_labels, img_feats):
assert points_embed is None
points_embed = proj

if self.points_pool_project is not None:
if self.points_pool_project is not None and n_points > 0:
# points are [Num_points, bs, 2], normalized in [0, 1]
# the grid needs to be [Bs, H_out, W_out, 2] normalized in [-1,1]
# We will take H_out = num_points and W_out = 1
Expand Down Expand Up @@ -656,7 +661,7 @@ def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats):
# We need to denormalize, and convert to [x, y, x, y]
boxes_xyxy = box_cxcywh_to_xyxy(boxes)
scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype)
scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True)
scale = tensor_to_device(scale, boxes_xyxy.device)
scale = scale.view(1, 1, 4)
boxes_xyxy = boxes_xyxy * scale
sampled = torchvision.ops.roi_align(
Expand Down
Loading