# -*- coding: utf-8 -*-
import numpy as np
from abc import abstractmethod

from .base import ImageAugmentor
from .transform import TransformFactory

# Names exported by ``from <this module> import *``.
__all__ = ['CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller',
           'RandomPaste']

class BackgroundFiller(object):
    """ Base class for all BackgroundFiller"""

    def fill(self, background_shape, img):
        """
        Return a proper background image of background_shape, given img.

        Args:
            background_shape (tuple): a shape (h, w)
            img: an image

        Returns:
            a background image
        """
        # Normalize to a tuple so that `_fill` always receives the same type
        # (e.g. it can be extended with a channel dimension by tuple
        # concatenation), whatever sequence type the caller passed in.
        background_shape = tuple(background_shape)
        return self._fill(background_shape, img)

    # NOTE: this class derives from `object` and does not use ABCMeta, so
    # `abstractmethod` only marks intent here; instantiating the base class
    # is not blocked and the default implementation returns None.
    @abstractmethod
    def _fill(self, background_shape, img):
        """
        Produce the background. Subclasses are expected to override this.

        Args:
            background_shape (tuple): a shape (h, w), already converted
                to a tuple by :meth:`fill`.
            img: an image

        Returns:
            a background image
        """
        pass
class ConstantBackgroundFiller(BackgroundFiller):
    """ Fill the background by a constant """

    def __init__(self, value):
        """
        Args:
            value (float): the value to fill the background.
        """
        self.value = value

    def _fill(self, background_shape, img):
        """
        Build a background of ``background_shape`` filled with ``self.value``.

        Args:
            background_shape (tuple): a shape (h, w).
            img: an array with ``ndim`` 2 or 3; when it is 3-dimensional,
                the size of its last axis is appended to the background shape.

        Returns:
            an array of zeros created with ``img.dtype``, plus ``self.value``.
        """
        assert img.ndim in [3, 2], \
            "img must be 2-D or 3-D, got ndim={}".format(img.ndim)
        if img.ndim == 3:
            # Keep the same size along the last axis as the image.
            return_shape = background_shape + (img.shape[2],)
        else:
            return_shape = background_shape
        # NOTE: the addition goes through NumPy's type promotion, so the
        # result dtype may differ from img.dtype (e.g. a fractional value
        # combined with an integer image).
        return np.zeros(return_shape, dtype=img.dtype) + self.value
# NOTE:
# apply_coords should be implemented in the paste transform, but this is not yet done.
class CenterPaste(ImageAugmentor):
    """
    Paste the image onto the center of a background canvas.
    """

    def __init__(self, background_shape, background_filler=None):
        """
        Args:
            background_shape (tuple): shape of the background canvas.
            background_filler (BackgroundFiller): How to fill the background. Defaults to zero-filler.
        """
        if background_filler is None:
            background_filler = ConstantBackgroundFiller(0)
        # `_init` comes from the base class (not visible here); presumably it
        # stores the arguments as attributes, since the methods below read
        # self.background_shape and self.background_filler.
        # locals() is passed as-is, so do not introduce other local
        # variables before this call.
        self._init(locals())

    def get_transform(self, _):
        """
        Build the transform that pastes an image onto the canvas center.

        The argument is ignored: the paste location is computed from each
        image at the time the transform is applied.
        """
        return TransformFactory(name=str(self), apply_image=self._impl)

    def _impl(self, img):
        """
        Paste ``img`` at the center of a freshly filled background.

        Args:
            img: an image whose first two dimensions (h, w) must not exceed
                ``self.background_shape``.

        Returns:
            the background image with ``img`` written into its center.
        """
        img_shape = img.shape[:2]
        assert self.background_shape[0] >= img_shape[0] \
            and self.background_shape[1] >= img_shape[1], \
            "image shape {} does not fit in background shape {}".format(
                img_shape, self.background_shape)

        background = self.background_filler.fill(
            self.background_shape, img)
        # Top-left corner of the pasted region; the leftover margin is split
        # in half and truncated to an integer.
        y0 = int((self.background_shape[0] - img_shape[0]) * 0.5)
        x0 = int((self.background_shape[1] - img_shape[1]) * 0.5)
        background[y0:y0 + img_shape[0], x0:x0 + img_shape[1]] = img
        return background
class RandomPaste(CenterPaste):
    """
    Randomly paste the image onto a background canvas.
    """

    def get_transform(self, img):
        """
        Sample a paste location for ``img`` and build the matching transform.

        Args:
            img: the image to be pasted; its first two dimensions (h, w)
                must be strictly smaller than ``self.background_shape``.

        Returns:
            a transform whose ``apply_image`` pastes an image at the
            sampled location.
        """
        img_shape = img.shape[:2]
        # NOTE(review): this check is strict (>), so an image exactly as large
        # as the canvas is rejected here although CenterPaste accepts it.
        # Whether it can be relaxed depends on `_rand_range`, which is
        # defined outside this file section -- confirm before changing.
        assert self.background_shape[0] > img_shape[0] \
            and self.background_shape[1] > img_shape[1], \
            "image shape {} must be smaller than background shape {}".format(
                img_shape, self.background_shape)

        y0 = self._rand_range(self.background_shape[0] - img_shape[0])
        x0 = self._rand_range(self.background_shape[1] - img_shape[1])
        # The location is stored as (x, y); `_impl` unpacks it in that order.
        loc = int(x0), int(y0)
        return TransformFactory(name=str(self), apply_image=lambda im: self._impl(im, loc))

    def _impl(self, img, loc):
        """
        Paste ``img`` onto a freshly filled background at ``loc``.

        Args:
            img: the image to paste.
            loc (tuple): integer (x0, y0) of the top-left corner of the
                pasted region.

        Returns:
            the background image with ``img`` written at ``loc``.
        """
        x0, y0 = loc
        img_shape = img.shape[:2]
        background = self.background_filler.fill(
            self.background_shape, img)
        background[y0:y0 + img_shape[0], x0:x0 + img_shape[1]] = img
        return background