# Source code for dgenerate.imageprocessors.u_sam

# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
#    the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import contextlib
import os
import re

import PIL.Image
import PIL.ImageDraw
import PIL.ImageFont
import PIL.ImageOps
import PIL.ImageStat
import cv2
import numpy
import torch
import dgenerate.image as _image
from ultralytics import SAM as _SAM
from ultralytics.models.sam.build import sam_model_map as _sam_model_map_u

import dgenerate.messages as _messages
import dgenerate.textprocessing as _textprocessing
import dgenerate.types as _types
import dgenerate.webcache as _webcache
from dgenerate.imageprocessors import imageprocessor as _imageprocessor

# sam_h.pt is excluded because it is not actually published as an ultralytics asset.
_sam_model_names = [name for name in _sam_model_map_u if name != 'sam_h.pt']

# Base URL that SAM asset file names are appended to when downloading.
_sam_assets_url = "https://github.com/ultralytics/assets/releases/download/v8.3.0/"


@contextlib.contextmanager
def _ultralytics_download_patch(local_files_only: bool):
    """
    Temporarily replace the asset downloader used by the ultralytics SAM builder.

    While the context is active, ``ultralytics.models.sam.build.attempt_download_asset``
    resolves asset file names through the dgenerate web cache instead. The original
    function is put back when the context exits, even if an exception was raised.

    :param local_files_only: forwarded to the web cache as ``local_files_only``
    """
    import ultralytics.models.sam.build as sam_build

    def _cached_download(file) -> str:
        # The web cache returns a pair, only the second element (the path) is needed.
        url = f'{_sam_assets_url}{file}'
        _, cached_path = _webcache.create_web_cache_file(
            url,
            local_files_only=local_files_only)
        return cached_path

    saved = sam_build.attempt_download_asset
    sam_build.attempt_download_asset = _cached_download

    try:
        yield
    finally:
        sam_build.attempt_download_asset = saved


class USAMProcessor(_imageprocessor.ImageProcessor):
    """
    Image processor that runs an Ultralytics SAM model with point and / or box prompts.

    Depending on the ``masks`` argument it returns either an annotated preview of
    the generated masks, or a single composite mask image.
    """

    @staticmethod
    def help(loaded_by_name: str):
        """
        Return the user facing help text for this processor.

        :param loaded_by_name: name the processor was loaded by (unused here)
        :return: help text
        """
        models = ('\n' + ' ' * 12 + '* ').join(_sam_model_names)
        # the indentation level of this here string is important
        # to the template, it is at level 8, plus 4 extra (doc indent), star, space

        return \
            f"""
        Process the input image with Ultralytics SAM (Segment Anything Model) using point or bounding box prompts.

        This processor operates in two distinct modes:

        Preview Mode (default, masks=False): Returns the original image with generated masks outlined and
        labeled with prompt indices. The colors of the outlines and text are automatically chosen to
        contrast with the background for optimal visibility.

        Mask Mode (masks=True): Returns a single composite mask image containing all generated masks
        combined together. This is useful for inpainting, outpainting, or other mask-based image
        processing operations.

        -----

        The "asset" argument specifies which SAM model asset to use. This should be the name of an
        Ultralytics SAM model asset, loading arbitrary checkpoints is not supported.

        This argument may be one of:

        NOWRAP!
            * {models}

        You may exclude the `.pt` suffix if desired.

        The "local-files-only" argument specifies that dgenerate should not attempt to download any
        model files, and to only look for them locally in the cache or otherwise.

        The "points" argument specifies point prompts as a list of coordinates. Each point can be specified as either:

        NOWRAP!
            - Single point: [x,y] or x,y or "x,y" or 50x50 or "50x50"
            - Single point: [x,y,label] or x,y,label or "x,y,label" or 50x50xLabel or "50x50xLabel"
            - Nested list/tuple literal: [[x,y], ...] or [[x,y,label], ...]
            - String format: ["x,y", ...] or ["x,y,label", ...] or "x,y","x,y,label"
            - Token list format: 25x25,50x50xLabel

        Where label is 1 for foreground, 0 for background. If no label is provided, it defaults to 1 (foreground).

        Note that for string format, comma is interchangeable and mixable with the character "x",
        as the quotes delimit the bounds of the point or box value.
        lists / tuple literals may not contain space.

        NOWRAP!
        Examples:
            points=[100,100]                     # Single point
            points=100,100                       # Single point
            points=100x100                       # Single point
            points=[100,100,1]                   # Single point (label)
            points=100,100,1                     # Single point (label)
            points=100x100x1                     # Single point (label)
            points=[[100,100],[200,200,0]]       # Nested list format
            points=["100,100","200,200,0"]       # String format
            points="100,100","200,200,0"         # String format
            points=["100x100","200x200x0"]       # String format
            points="100x100","200x200x0"         # String format
            points=100x100,200x200x0             # Token format

        The "boxes" argument specifies bounding box prompts as a list of coordinates. Each box can be specified as either:

        NOWRAP!
            - Single box: [x1,y1,x2,y2] or x1,y1,x2,y2 or "x1,y1,x2,y2"
            - Nested list/tuple: [[x1,y1,x2,y2], ...]
            - String format: ["x1,y1,x2,y2", ...]
            - Token list format: 50x50x100x100,200x200x400x400

        NOWRAP!
        Examples:
            boxes=[50,50,150,150]                          # Single box
            boxes=50,50,150,150                            # Single box
            boxes=50x50x150x150                            # Single box
            boxes=[[50,50,150,150],[200,200,300,300]]      # Nested list format
            boxes=["50,50,150,150","200,200,300,300"]      # String format
            boxes="50,50,150,150","200,200,300,300"        # String format
            boxes="50x50x150x150","200x200x300x300"        # String format
            boxes=50x50x150x150,200x200x300x300            # Token format

        The "boxes-mask" argument specifies a black and white mask image where white areas will be
        automatically converted to bounding box prompts. This is useful for integrating with YOLO
        detection results or other object detection masks. The mask will be resized to match the
        input image dimensions before processing.

        The "boxes-mask-processors" argument allows you to pre-process the boxes mask with an image
        processor chain before extracting bounding boxes. This is useful for applying filters,
        transforms, or other modifications to the mask.

        Note: You may use python tuple syntax as well as list syntax, additionally something
        such as: (100,100),(100,100) will be interpreted as a tuple of tuples, and:
        [100,100],[100,100] a tuple of lists.

        The "font-size" argument determines the size of the label text. If not specified, it will be
        automatically calculated based on the image dimensions.

        The "line-width" argument controls the thickness of the mask outline lines. If not specified,
        it will be automatically calculated based on the image dimensions.

        The "line-color" argument overrides the color for mask outlines and text label backgrounds.
        This should be specified as a HEX color code, e.g. "#FFFFFF" or "#FFF". If not specified,
        colors are automatically chosen to contrast with the background. The text color will always
        be automatically chosen to contrast with the background for optimal readability.

        The "masks" argument enables mask generation mode. When True, the processor returns a
        composite mask image instead of the annotated preview image. This defaults to False.

        The "outpaint" argument inverts the generated masks, creating inverted masks suitable for
        outpainting operations. This only has an effect when "masks" is True. This defaults to False.

        The "pre-resize" argument determines if the processing occurs before or after dgenerate
        resizes the image. This defaults to False, meaning the image is processed after dgenerate
        is done resizing it.
        """

    NAMES = ['u-sam']

    OPTION_ARGS = {
        'asset': list(_sam_model_names),
    }

    FILE_ARGS = {
        'boxes-mask': {'mode': 'in', 'filetypes': [('Images', _imageprocessor.ImageProcessor.image_in_filetypes())]}
    }

    @staticmethod
    def _match_hex_color(color):
        """Return ``True`` if ``color`` is exactly ``#RGB`` or ``#RRGGBB``."""
        pattern = r'#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})'
        # fullmatch (rather than match + "$") also rejects a trailing newline,
        # which _hex_to_rgb cannot parse.
        return bool(re.fullmatch(pattern, color))

    @staticmethod
    def _hex_to_rgb(hex_color):
        """Convert hex color to RGB tuple."""
        hex_color = hex_color.lstrip('#')
        if len(hex_color) == 3:
            # Expand shorthand, e.g. "F0A" -> "FF00AA"
            hex_color = ''.join([c * 2 for c in hex_color])
        return tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4))

    @staticmethod
    def _parse_points(points_input):
        """
        Parse point coordinates from nested lists/tuples or string format.

        :param points_input: a single point, or a sequence of points, where each
            point is a list / tuple of numbers or a string such as ``"x,y"``,
            ``"x,y,label"``, ``"10x20"`` or a comma separated token list

        :raise ValueError: if a point definition is malformed

        :return: list of ``[x, y, label]``, label defaults to ``1`` (foreground)
        """
        if not points_input:
            return []

        # A string, or a flat sequence containing a bare number, is one singular point.
        # Floats are accepted here as well as ints, coordinates are parsed as floats.
        if isinstance(points_input, str) or any(isinstance(x, (int, float)) for x in points_input):
            points_input = [points_input]

        points = []

        for point in points_input:
            if isinstance(point, (list, tuple)):
                # Already parsed nested structure: [x, y] or [x, y, label]
                if len(point) < 2:
                    raise ValueError(f"Point must have at least x,y coordinates: {point}")
                elif len(point) == 2:
                    x, y = map(float, point)
                    points.append([x, y, 1])  # Default to foreground
                elif len(point) == 3:
                    x, y, label = map(float, point)
                    points.append([x, y, int(label)])
                else:
                    raise ValueError(f"Point should have 2 or 3 coordinates: {point}")
            elif isinstance(point, str):
                # String format for backward compatibility: "x,y" or "x,y,label"
                # And: "0x0"
                # And: 0x0,0x0x0
                coords = re.split(r'[x,]', point)
                if len(coords) < 2:
                    raise ValueError(f"Point must have at least x,y coordinates: {point}")
                elif len(coords) == 2:
                    x, y = map(float, coords)
                    points.append([x, y, 1])  # Default to foreground
                elif len(coords) == 3:
                    x, y, label = map(float, coords)
                    points.append([x, y, int(label)])
                else:
                    # Too many values for one point, try splitting by comma for multiple points
                    comma_split = point.split(',')
                    if len(comma_split) == 1:
                        # No commas found, this is a malformed single point
                        raise ValueError(f"Invalid point format: {point}")

                    # Multiple points separated by commas
                    for c in comma_split:
                        p = USAMProcessor._parse_points(c)
                        if p:
                            points.extend(p)
                        else:
                            raise ValueError(f'Missing point definition in: "{point}", stray comma?')
            else:
                raise ValueError(f"Point must be a list/tuple or string, got: {type(point).__name__}")

        return points

    @staticmethod
    def _parse_boxes(boxes_input):
        """
        Parse bounding box coordinates from nested lists/tuples or string format.

        :param boxes_input: a single box, or a sequence of boxes, where each box
            is a list / tuple of four numbers or a string such as ``"x1,y1,x2,y2"``,
            ``"0x0x10x10"`` or a comma separated token list

        :raise ValueError: if a box definition is malformed

        :return: list of ``[x1, y1, x2, y2]`` floats
        """
        if not boxes_input:
            return []

        # A string, or a flat sequence containing a bare number, is one singular box.
        # Floats are accepted here as well as ints, coordinates are parsed as floats.
        if isinstance(boxes_input, str) or any(isinstance(x, (int, float)) for x in boxes_input):
            boxes_input = [boxes_input]

        boxes = []

        for box in boxes_input:
            if isinstance(box, (list, tuple)):
                # Already parsed nested structure: [x1, y1, x2, y2]
                if len(box) != 4:
                    raise ValueError(f"Box must have x1,y1,x2,y2 coordinates: {box}")
                x1, y1, x2, y2 = map(float, box)
                boxes.append([x1, y1, x2, y2])
            elif isinstance(box, str):
                # String format for backward compatibility: "x1,y1,x2,y2"
                # or: 0x0x0x0,0x0x0x0
                coords = re.split(r'[x,]', box)
                if len(coords) > 4:
                    # Try splitting by comma for multiple boxes
                    comma_split = box.split(',')
                    if len(comma_split) == 1:
                        # No commas found, this is a malformed single box
                        raise ValueError(f"Invalid box format - too many coordinates: {box}")

                    # Multiple boxes separated by commas
                    for c in comma_split:
                        b = USAMProcessor._parse_boxes(c)
                        if b:
                            boxes.extend(b)
                        else:
                            raise ValueError(f'Missing box definition in: "{box}", stray comma?')
                elif len(coords) < 4:
                    raise ValueError(f'Box must have x1,y1,x2,y2 coordinates: {box}')
                else:
                    x1, y1, x2, y2 = map(float, coords)
                    boxes.append([x1, y1, x2, y2])
            else:
                raise ValueError(f"Box must be a list/tuple or string, got: {type(box).__name__}")

        return boxes

    def __init__(self,
                 asset: str,
                 points: str | list | tuple | None = None,
                 boxes: str | list | tuple | None = None,
                 boxes_mask: str | None = None,
                 boxes_mask_processors: str | None = None,
                 font_size: int | None = None,
                 line_width: int | None = None,
                 line_color: str | None = None,
                 masks: bool = False,
                 outpaint: bool = False,
                 pre_resize: bool = False,
                 **kwargs):
        """
        :param asset: SAM model asset to use, an Ultralytics asset name
        :param points: list of point prompts - can be nested lists [[x,y], [x,y,label]] or strings ["x,y", "x,y,label"]
        :param boxes: list of bounding box prompts - can be nested lists [[x1,y1,x2,y2]] or strings ["x1,y1,x2,y2"]
        :param boxes_mask: path or URL to a black and white mask image where white areas will be converted to bounding boxes
        :param boxes_mask_processors: image processor chain to apply to the boxes mask before extracting bounding boxes
        :param font_size: size of label text, if None will be calculated based on image dimensions
        :param line_width: thickness of mask outline lines, if None will be calculated based on image dimensions
        :param line_color: override color for mask outlines and text label backgrounds as hex color code (e.g. "#FF0000" or "#F00")
        :param masks: generate mask images instead of preview, default is ``False``
        :param outpaint: invert generated masks for outpainting, only effective when masks is ``True``, default is ``False``
        :param pre_resize: process the image before it is resized, or after? default is ``False`` (after).
        :param kwargs: forwarded to base class
        """
        super().__init__(**kwargs)

        if line_width is not None and line_width < 1:
            raise self.argument_error('Argument "line-width" must be at least 1.')

        if font_size is not None and font_size < 8:
            raise self.argument_error('Argument "font-size" must be at least 8.')

        # Validate color arguments
        if line_color is not None and not self._match_hex_color(line_color):
            raise self.argument_error('line-color must be a HEX color code, e.g. #FFFFFF or #FFF')

        # Validate boxes-mask arguments
        if boxes_mask_processors and not boxes_mask:
            raise self.argument_error(
                'Cannot use "boxes-mask-processors" without specifying "boxes-mask"'
            )

        if not asset.endswith('.pt'):
            asset += '.pt'

        # get model path on disk
        self._model_path = self._get_model_path(asset)

        self._line_width = line_width
        self._font_size = font_size
        self._line_color = line_color
        self._masks = masks
        self._outpaint = outpaint
        self._pre_resize = pre_resize
        self._boxes_mask = boxes_mask
        self._boxes_mask_processors = boxes_mask_processors

        # Parse prompts
        try:
            self._points = self._parse_points(points or [])
            self._boxes = self._parse_boxes(boxes or [])
        except ValueError as e:
            raise self.argument_error(f'Error parsing prompts: {e}') from e

        if not self._points and not self._boxes and not self._boxes_mask:
            raise self.argument_error('At least one point, box, or boxes-mask prompt must be specified.')

        model_size = os.path.getsize(self._model_path)
        self.set_size_estimate(model_size)

        # Load the SAM model, with the ultralytics downloader redirected to the web cache
        with _ultralytics_download_patch(self.local_files_only):
            try:
                self._model = self.load_object_cached(
                    tag=self._model_path,
                    estimated_size=self.size_estimate,
                    method=lambda: _SAM(asset),
                )
                self.register_module(self._model.model)
            except Exception as e:
                raise self.argument_error(f'Failed to load SAM model: {e}') from e

    def _get_model_path(self, asset_name: str) -> str:
        """
        Resolve a SAM asset name to a file path through the web cache.

        :param asset_name: asset file name, including the ``.pt`` suffix
        :return: path of the cached model file
        """
        if asset_name not in _sam_model_names:
            raise self.argument_error(
                f'Unknown SAM model: {asset_name}, must be one of: '
                f'{_textprocessing.oxford_comma(_sam_model_names, "or")}')
        try:
            _, file = _webcache.create_web_cache_file(
                f'{_sam_assets_url}{asset_name}',
                local_files_only=self.local_files_only
            )
        except Exception as e:
            raise self.argument_error(
                f'Error downloading ultralytics asset "{asset_name}": {e}') from e
        return file

    def _run_image_processor(
            self,
            uri_chain_string,
            image,
            resize_resolution,
            aspect_correct,
            align,
    ):
        """Run an image processor from a URI chain string."""
        import dgenerate.imageprocessors as _imgp

        # Convert image to RGB mode for consistent processing
        if image.mode != 'RGB':
            image = image.convert('RGB')

        processor = _imgp.create_image_processor(
            _textprocessing.shell_parse(
                uri_chain_string,
                expand_home=False,
                expand_glob=False,
                expand_vars=False
            ),
            device=self.device,
            model_offload=self.model_offload,
        )

        try:
            return processor.process(
                image,
                resize_resolution=resize_resolution,
                aspect_correct=aspect_correct,
                align=align
            )
        finally:
            processor.to('cpu')

    def _extract_boxes_from_mask(self, target_size: _types.Size) -> list:
        """
        Extract bounding boxes from a black and white mask image.

        :param target_size: Size to resize mask to match input image
        :return: List of bounding boxes in format [[x1,y1,x2,y2], ...]
        """
        if not self._boxes_mask:
            return []

        try:
            # Handle URL downloads using webcache
            if _webcache.is_downloadable_url(self._boxes_mask):
                # Download and cache the URL
                _, mask_path = _webcache.create_web_cache_file(
                    self._boxes_mask,
                    mime_acceptable_desc='image files',
                    mimetype_is_supported=lambda m: m.startswith('image/'),
                    local_files_only=self.local_files_only
                )
            else:
                # Use local file path directly
                mask_path = self._boxes_mask

            # copy() loads the pixel data, so the file handle can be
            # closed deterministically as soon as the block exits
            with PIL.Image.open(mask_path) as mask_file:
                mask_image = mask_file.copy()

            # Apply processors if specified
            if self._boxes_mask_processors is not None:
                mask_image = self._run_image_processor(
                    self._boxes_mask_processors,
                    mask_image,
                    aspect_correct=False,
                    resize_resolution=None,
                    align=1
                )

            # Convert to grayscale if needed
            if mask_image.mode != 'L':
                mask_image = mask_image.convert('L')

            # Resize mask to match target image size
            if mask_image.size != target_size:
                old_size = mask_image.size
                mask_image = mask_image.resize(
                    target_size,
                    _image.best_pil_resampling(mask_image.size, target_size)
                )
                _messages.debug_log(f"Boxes mask resized from {old_size} to {target_size}")

            # Convert to numpy array for OpenCV processing
            mask_array = numpy.array(mask_image)

            # Threshold to ensure we have a proper binary mask
            # Values > 128 are considered white (areas of interest)
            _, binary_mask = cv2.threshold(mask_array, 128, 255, cv2.THRESH_BINARY)

            # Find contours of white areas
            contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            # Extract bounding boxes from contours
            boxes = []
            for contour in contours:
                # Get bounding rectangle for each contour
                x, y, w, h = cv2.boundingRect(contour)

                # Skip very small contours (likely noise)
                if w < 3 or h < 3:
                    continue

                # Convert to [x1, y1, x2, y2] format
                boxes.append([float(x), float(y), float(x + w), float(y + h)])

            _messages.debug_log(f"Extracted {len(boxes)} bounding boxes from boxes-mask")
            return boxes

        except Exception as e:
            raise self.argument_error(
                f'Failed to process argument "boxes-mask" "{self._boxes_mask}": {e}') from e

    def _get_contrasting_color(self, background_color):
        """
        Calculate the best contrasting color for text based on background color.
        Uses HSV color space to find a high-contrast complementary color.

        :param background_color: RGB tuple of the background color
        :return: RGB tuple of the contrasting color
        """
        import colorsys

        # Normalize RGB values to 0-1 range
        r, g, b = [c / 255.0 for c in background_color[:3]]

        # Convert to HSV
        h, s, v = colorsys.rgb_to_hsv(r, g, b)

        # Calculate complementary hue (opposite on color wheel)
        complementary_h = (h + 0.5) % 1.0

        # For high contrast, we want high saturation and appropriate value
        # If background is dark, use bright contrasting color
        # If background is bright, use darker contrasting color
        if v < 0.5:
            # Dark background
            contrast_s = min(1.0, s + 0.3)  # Increase saturation
            contrast_v = min(1.0, v + 0.6)  # Increase brightness
        else:
            # Bright background
            contrast_s = min(1.0, s + 0.2)  # Slightly increase saturation
            contrast_v = max(0.2, v - 0.5)  # Decrease brightness

        # Convert back to RGB
        contrast_r, contrast_g, contrast_b = colorsys.hsv_to_rgb(complementary_h, contrast_s, contrast_v)

        # Convert back to 0-255 range and return as integers
        return int(contrast_r * 255), int(contrast_g * 255), int(contrast_b * 255)

    def _sample_line_area_background_color(self, image, contours, line_width, extra_thickness=3):
        """
        Sample background color from the area where the outline will be drawn,
        including pixels both inside and outside the line area for better contrast.

        :param image: PIL Image to sample from
        :param contours: list of contours from cv2.findContours
        :param line_width: width of the line that will be drawn
        :param extra_thickness: additional pixels to sample beyond the line width
        :return: RGB sequence of the average background color around the line area
        """
        # Always sample in RGB so the averaged color has exactly three
        # channels, whatever mode the input image is in
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')

        # Create a mask for the line area
        line_mask = numpy.zeros((rgb_image.size[1], rgb_image.size[0]), dtype=numpy.uint8)

        # Draw the contours with the actual line width plus extra thickness
        # This gives us the area where the line will be plus some surrounding pixels
        sample_width = line_width + extra_thickness * 2
        cv2.drawContours(line_mask, contours, -1, 255, thickness=sample_width)

        # Sample colors from the line area
        line_pixels = numpy.array(rgb_image)[line_mask > 0]

        if len(line_pixels) > 0:
            # Calculate mean color from line area pixels
            return numpy.mean(line_pixels.reshape(-1, 3), axis=0)

        # Fallback to sampling from center of image if line area is empty
        center_x, center_y = rgb_image.size[0] // 2, rgb_image.size[1] // 2
        bg_sample_area = rgb_image.crop((center_x - 25, center_y - 25, center_x + 25, center_y + 25))
        return PIL.ImageStat.Stat(bg_sample_area).mean

    def _calculate_line_width_font_size(self, image_size):
        """
        Calculate appropriate line width and font size based on image dimensions.

        :param image_size: tuple of (width, height)
        :return: tuple of (line_width, font_size, text_padding)
        """
        # Use the larger dimension to calculate sizes
        max_dim = max(image_size)

        # Calculate line width as 0.3% of max dimension, with min of 1
        if self._line_width is None:
            line_width = max(1, int(0.003 * max_dim))
        else:
            line_width = self._line_width

        # Calculate font size as 1.5% of max dimension, with min of 10
        if self._font_size is None:
            font_size = max(10, int(0.015 * max_dim))
        else:
            font_size = self._font_size

        # Calculate text padding as 0.3% of max dimension, with min of 2
        text_padding = max(2, int(0.003 * max_dim))

        return line_width, font_size, text_padding

    def _empty_result(self, image):
        """Return the result used when no masks exist: a blank mask in mask mode, else a copy of the image."""
        if self._masks:
            empty_color = 0 if not self._outpaint else 255
            return PIL.Image.new("RGB", image.size, (empty_color, empty_color, empty_color))
        return image.copy()

    def _run_prompts(self, input_image, all_boxes):
        """
        Run the SAM model once for the point prompts and once for the box prompts.

        :param input_image: image as a numpy array
        :param all_boxes: user supplied boxes followed by boxes extracted from the boxes mask
        :return: list of ``(result, prompt_type, prompt_idx, mask_idx)`` tuples
        """
        results = []

        if self._points:
            # each point is [x, y, label]
            coords = [[p[0], p[1]] for p in self._points]
            labels = [int(p[2]) for p in self._points]

            sam_result = self._model(input_image, points=coords, labels=labels)[0]
            if sam_result.masks is not None and len(sam_result.masks) > 0:
                results.extend(
                    (sam_result, 'point', i, i) for i in range(len(sam_result.masks)))
            else:
                _messages.debug_log("SAM mask: No masks generated for point prompts")

        if all_boxes:
            sam_result = self._model(input_image, bboxes=all_boxes)[0]
            if sam_result.masks is not None and len(sam_result.masks) > 0:
                # Box prompts are numbered after the point prompts
                first_idx = len(self._points)
                user_box_count = len(self._boxes)
                for i in range(len(sam_result.masks)):
                    # Determine box type (original vs mask-extracted)
                    box_type = 'box' if i < user_box_count else 'mask-box'
                    results.append((sam_result, box_type, first_idx + i, i))
            else:
                _messages.debug_log("SAM mask: No masks generated for box prompts")

        return results

    @staticmethod
    def _mask_to_array(result, mask_idx, size):
        """Return mask ``mask_idx`` of ``result`` as a uint8 array resized to ``size``, or ``None`` if absent."""
        if result.masks is None or mask_idx >= len(result.masks.data):
            return None

        mask_np = result.masks.data[mask_idx].cpu().numpy()
        mask_img = PIL.Image.fromarray((mask_np * 255).astype(numpy.uint8), mode="L")

        # Resize to match original image size
        return numpy.array(mask_img.resize(size, PIL.Image.LANCZOS))

    def _composite_masks(self, image, results):
        """Combine every generated mask into one RGB mask image, inverted when outpainting."""
        composite = numpy.zeros((image.size[1], image.size[0]), dtype=numpy.uint8)

        for result, _prompt_type, _prompt_idx, mask_idx in results:
            mask_array = self._mask_to_array(result, mask_idx, image.size)
            if mask_array is not None:
                # Combine with composite mask (logical OR)
                composite = numpy.maximum(composite, mask_array)

        composite_mask = PIL.Image.fromarray(composite, mode="L")

        _messages.debug_log(f"SAM mask: Generated composite mask from {len(results)} prompts.")

        if self._outpaint:
            # Invert the composite mask for outpainting
            composite_mask = PIL.ImageOps.invert(composite_mask)
            _messages.debug_log("SAM mask: Inverted composite mask for outpainting.")

        return composite_mask.convert('RGB')

    @staticmethod
    def _load_label_font(font_size):
        """Load a font for the labels, falling back to the PIL default font."""
        try:
            return PIL.ImageFont.truetype("arial.ttf", font_size)
        except OSError:
            try:
                return PIL.ImageFont.truetype(PIL.ImageFont.load_default().path, font_size)
            except Exception:
                # The default font may not expose a usable path, use it unscaled
                return PIL.ImageFont.load_default()

    def _draw_preview(self, image, results, line_width, font_size, text_padding):
        """Return a copy of ``image`` with every mask outlined and labeled."""
        output_image = image.copy()
        draw = PIL.ImageDraw.Draw(output_image)
        font = self._load_label_font(font_size)

        for result, prompt_type, prompt_idx, mask_idx in results:
            mask_array = self._mask_to_array(result, mask_idx, image.size)
            if mask_array is None:
                continue

            # Find mask contours first
            contours, _ = cv2.findContours(mask_array, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            # Sample background color from the area where the line will be drawn
            bg_color = self._sample_line_area_background_color(image, contours, line_width)

            # Determine colors
            if self._line_color is not None:
                line_color = self._hex_to_rgb(self._line_color)
            else:
                line_color = self._get_contrasting_color(bg_color)

            text_bg_color = line_color
            text_color = self._get_contrasting_color(text_bg_color)

            # Draw mask contours
            for contour in contours:
                # Flatten the contour to [x0, y0, x1, y1, ...] as PIL expects
                outline = []
                for vertex in contour:
                    outline.extend([int(vertex[0][0]), int(vertex[0][1])])

                if len(outline) >= 6:  # Need at least 3 points (6 coordinates) for a polygon
                    draw.polygon(outline, outline=line_color, width=line_width)

            # Draw label
            label = f"{prompt_idx}: {prompt_type}"

            # textbbox returns (left, top, right, bottom) including ascent/descent
            bbox = draw.textbbox((0, 0), label, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]

            # The text baseline offset (negative value of top coordinate)
            text_offset_y = -bbox[1]

            # Find a good position for the label (top-left of the mask)
            mask_coords = numpy.where(mask_array > 128)
            if len(mask_coords[0]) == 0:
                continue

            min_y = int(numpy.min(mask_coords[0]))
            min_x = int(numpy.min(mask_coords[1]))

            # Prefer placing the label box just above the mask
            box_x = min_x
            box_y = min_y - text_height - text_padding * 2

            # If box would go above the image, place it below the top of the mask.
            # (this must be checked before any clamping, or it can never trigger)
            if box_y < 0:
                box_y = min_y + text_padding

            # Ensure box doesn't go off the right edge
            if box_x + text_width + text_padding * 2 > image.size[0]:
                box_x = max(0, image.size[0] - text_width - text_padding * 2)

            # Draw text background box
            box_right = box_x + text_width + text_padding * 2
            box_bottom = box_y + text_height + text_padding * 2
            draw.rectangle([box_x, box_y, box_right, box_bottom], fill=text_bg_color)

            # Draw text in the box with proper baseline adjustment
            text_x = box_x + text_padding
            text_y = box_y + text_padding + text_offset_y
            draw.text((text_x, text_y), label, fill=text_color, font=font)

        _messages.debug_log(f"SAM mask: Drew mask outlines for {len(results)} prompts.")
        return output_image

    @torch.no_grad()
    def _process(self, image):
        """
        Run SAM on ``image`` with the configured prompts.

        :param image: PIL image to process
        :return: composite mask image when ``masks`` is enabled, otherwise
            an annotated preview image
        """
        # Convert PIL image to numpy array for SAM
        input_image = numpy.array(image)

        # Calculate dynamic sizes based on image dimensions
        line_width, font_size, text_padding = self._calculate_line_width_font_size(image.size)

        # Extract boxes from mask if provided and combine with existing boxes
        extracted_boxes = self._extract_boxes_from_mask(image.size)
        all_boxes = list(self._boxes) + extracted_boxes

        if not self._points and not all_boxes:
            _messages.debug_log("SAM mask: No prompts were specified.")
            return self._empty_result(image)

        # Inference is best-effort, a failure is logged and yields an empty result
        try:
            results = self._run_prompts(input_image, all_boxes)
        except Exception as e:
            _messages.debug_log(f"SAM mask: Error processing prompts: {e}")
            results = []

        if not results:
            _messages.debug_log("SAM mask: No masks were generated from prompts.")
            return self._empty_result(image)

        # If masks mode is enabled, return composite mask
        if self._masks:
            return self._composite_masks(image, results)

        # Preview mode - return annotated image
        return self._draw_preview(image, results, line_width, font_size, text_padding)

    def impl_pre_resize(self, image: PIL.Image.Image, resize_resolution: _types.OptionalSize):
        """
        Pre resize, SAM mask processing may or may not occur here depending
        on the boolean value of the processor argument "pre-resize"

        :param image: image to process
        :param resize_resolution: purely informational, is unused by this processor
        :return: possibly a SAM mask processed image, or the input image
        """
        if self._pre_resize:
            return self._process(image)
        return image

    def impl_post_resize(self, image: PIL.Image.Image):
        """
        Post resize, SAM mask processing may or may not occur here depending
        on the boolean value of the processor argument "pre-resize"

        :param image: image to process
        :return: possibly a SAM mask processed image, or the input image
        """
        if not self._pre_resize:
            return self._process(image)
        return image
# Public names exported by this module, computed by the project helper
# (presumably derived from the names defined above - confirm in dgenerate.types).
__all__ = _types.module_all()