"""
This module implements the co-registration transformers.
Copyright (c) 2017- Sinergise and contributors
For the full list of contributors, see the CREDITS file in the root directory of this source tree.
This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree.
"""

from __future__ import annotations

import logging
import warnings

import cv2
import numpy as np

from eolearn.core import EOPatch, EOTask, FeatureType
from eolearn.core.exceptions import EORuntimeWarning
from eolearn.core.types import Feature, FeaturesSpecification

LOGGER = logging.getLogger(__name__)


class ECCRegistrationTask(EOTask):
    """Multi-temporal image co-registration using the OpenCV Enhanced Correlation Coefficient (ECC) method

    The task uses a temporal stack of images of the same location (i.e. a temporal-spatial feature in `EOPatch`)
    and a reference timeless feature to estimate a transformation that aligns each frame of the temporal stack
    to the reference feature.

    Each transformation is estimated using only a single channel of the images. If a feature containing masks of
    valid pixels is specified, it is used during the estimation of the transformation. The estimated transformations
    are then applied to each of the specified features.
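
    A minimal usage sketch (the feature names below are illustrative, not part of the task)::

        task = ECCRegistrationTask(
            registration_feature=(FeatureType.DATA, "BANDS"),
            reference_feature=(FeatureType.DATA_TIMELESS, "REFERENCE"),
            channel=0,
            valid_mask_feature=(FeatureType.MASK, "IS_DATA"),
        )
        registered_eopatch = task.execute(eopatch)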
"""

    def __init__(
self,
registration_feature: Feature,
reference_feature: Feature,
channel: int,
valid_mask_feature: Feature | None = None,
apply_to_features: FeaturesSpecification = ...,
interpolation_mode: int = cv2.INTER_LINEAR,
warp_mode: int = cv2.MOTION_TRANSLATION,
max_iter: int = 100,
gauss_kernel_size: int = 1,
border_mode: int = cv2.BORDER_REPLICATE,
border_value: float = 0,
num_threads: int = 1,
max_translation: float = 5.0,
):
"""
:param registration_feature: Feature in EOPatch holding the multi-temporal stack to register to the
reference. Needs to be of FeatureType.DATA.
:param reference_feature: Feature in EOPatch used as reference frame for registration.
:param channel: Defines the index of the stack and reference feature to use during registration.
:param valid_mask_feature: Optional feature in EOPatch that defines which pixels should be used for
registration.
:param apply_to_features: List of temporal features in EOPatch to which applied the estimated
transformation.
:param interpolation_mode: Interpolation type used when transforming the stack of images.
:param warp_mode: Defines the transformation model used to match the stack and the reference.
Examples include TRANSLATION, RIGID_MOTION, AFFINE.
:param max_iter: Maximum number of iterations used during optimization of algorithm.
:param gauss_kernel_size: Size of Gaussian kernel used to smooth images prior to registration.
:param border_mode: Defines the padding strategy when transforming the images with estimated
transformation.
:param border_value: Value used for padding when border mode is set to CONSTANT.
:param num_threads: Number of threads used for optimization of the algorithm.
:param max_translation: Estimated transformations are considered incorrect when the norm of the
translation component is larger than this parameter.
"""
self.registration_feature = self.parse_feature(registration_feature, allowed_feature_types=[FeatureType.DATA])
self.reference_feature = self.parse_feature(
reference_feature, allowed_feature_types=[FeatureType.DATA_TIMELESS]
)
self.channel = channel
self.valid_mask_feature = (
None
if valid_mask_feature is None
else self.parse_feature(valid_mask_feature, allowed_feature_types=[FeatureType.MASK])
)
self.apply_features_parser = self.get_feature_parser(
apply_to_features, allowed_feature_types=[FeatureType.DATA, FeatureType.MASK]
)
self.warp_mode = warp_mode
self.interpolation_mode = interpolation_mode
self.max_iter = max_iter
self.gauss_kernel_size = gauss_kernel_size
self.border_mode = border_mode
self.border_value = border_value
self.num_threads = num_threads
self.max_translation = max_translation

    def register(
self,
src: np.ndarray,
trg: np.ndarray,
valid_mask: np.ndarray | None = None,
warp_mode: int = cv2.MOTION_TRANSLATION,
) -> np.ndarray:
"""Method that estimates the transformation between source and target image"""
criteria = (cv2.TERM_CRITERIA_COUNT, self.max_iter, 0)
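        # Initialize the warp to the identity transform: a 3x3 matrix for a homography,
        # a 2x3 matrix for the affine-family motion models.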
warp_matrix_size = (3, 3) if warp_mode == cv2.MOTION_HOMOGRAPHY else (2, 3)
warp_matrix: np.ndarray = np.eye(*warp_matrix_size, dtype=np.float32)
try:
cv2.setNumThreads(self.num_threads)
_, warp_matrix = cv2.findTransformECC(
src.astype(np.float32),
trg.astype(np.float32),
warp_matrix,
warp_mode,
criteria,
valid_mask,
self.gauss_kernel_size,
)
except cv2.error as cv2err: # pylint: disable=catching-non-exception
warnings.warn(f"Could not calculate the warp matrix: {cv2err}", EORuntimeWarning)
return warp_matrix

    def execute(self, eopatch: EOPatch) -> EOPatch:
"""Method that estimates registrations and warps EOPatch objects"""
multi_temp_stack = eopatch[self.registration_feature][..., self.channel]
time_frames = multi_temp_stack.shape[0]
valid_mask = None
if self.valid_mask_feature is not None:
valid_mask = eopatch[self.valid_mask_feature].squeeze(axis=-1)
valid_mask = valid_mask.astype(np.uint8)
reference_image = eopatch[self.reference_feature][..., self.channel]
new_eopatch = EOPatch(bbox=eopatch.bbox, timestamps=eopatch.timestamps)
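        # Warped results are written into a new EOPatch so that the input patch is left unchanged.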
for feature_type, feature_name in self.apply_features_parser.get_features(eopatch):
new_eopatch[feature_type][feature_name] = np.zeros_like(eopatch[feature_type][feature_name])
warp_matrices = {}
for idx in range(time_frames):
valid_mask_ = None if valid_mask is None else valid_mask[idx]
warp_matrix = self.register(
reference_image, multi_temp_stack[idx], valid_mask=valid_mask_, warp_mode=self.warp_mode
)
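            # Guard against implausible estimates: if the translation component is larger than
            # `max_translation`, fall back to the identity transform.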
if self.is_translation_large(warp_matrix):
warp_matrix = np.eye(2, 3)
warp_matrices[idx] = warp_matrix.tolist()
# Apply transformation to every given feature
for feature_type, feature_name in self.apply_features_parser.get_features(eopatch):
new_eopatch[feature_type][feature_name][idx] = self.warp_feature(
eopatch[feature_type][feature_name][idx], warp_matrix
)
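        # Keep the estimated per-frame warp matrices in meta-info so they can be inspected downstream.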
new_eopatch[FeatureType.META_INFO, "warp_matrices"] = warp_matrices
return new_eopatch

    def warp(self, img: np.ndarray, warp_matrix: np.ndarray, shape: tuple[int, int], flags: int) -> np.ndarray:
"""Transform the target image with the estimated transformation matrix"""
if warp_matrix.shape == (3, 3):
return cv2.warpPerspective( # type: ignore[call-overload]
img.astype(np.float32),
warp_matrix,
shape,
flags=flags,
borderMode=self.border_mode,
borderValue=self.border_value,
)
return cv2.warpAffine( # type: ignore[call-overload]
img.astype(np.float32),
warp_matrix,
shape,
flags=flags,
borderMode=self.border_mode,
borderValue=self.border_value,
)

    def warp_feature(self, img: np.ndarray, warp_matrix: np.ndarray) -> np.ndarray:
"""Function to warp input image given an estimated 2D linear transformation"""
height, width = img.shape[:2]
warped_img = np.zeros_like(img, dtype=np.float32)
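        # `cv2.findTransformECC` estimates the warp that maps coordinates of the reference onto the
        # frame, so `cv2.WARP_INVERSE_MAP` makes `warpAffine`/`warpPerspective` resample the frame
        # back into the reference grid.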
flags = self.interpolation_mode + cv2.WARP_INVERSE_MAP
        # Check if the image to warp is 2D or 3D; if 3D, loop over channels
if img.ndim == 2:
warped_img = self.warp(img, warp_matrix, (width, height), flags=flags)
elif img.ndim == 3:
for idx in range(img.shape[-1]):
warped_img[..., idx] = self.warp(img[..., idx], warp_matrix, (width, height), flags=flags)
else:
raise ValueError(f"Image has incorrect number of dimensions: {img.ndim}. Correct number is either 2 or 3.")
return warped_img.astype(img.dtype)

    def is_translation_large(self, warp_matrix: np.ndarray) -> bool:
        """Method that checks whether the estimated linear translation could be implausible.

        This function checks whether the norm of the estimated translation in pixels exceeds a predefined value.
        """
return np.linalg.norm(warp_matrix[:, 2]).astype(float) > self.max_translation


def get_gradient(src: np.ndarray) -> np.ndarray:
    """Method that calculates and returns the gradients of the input image, which are
    better suited for co-registration
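
    Example of intended use (illustrative; ``image_2d`` stands for any single-channel 2D array)::

        grad = get_gradient(image_2d)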
"""
# Calculate the x and y gradients using Sobel operator
src = src.astype(np.float32)
grad_x = cv2.Sobel(src, cv2.CV_32F, 1, 0, ksize=3)
grad_y = cv2.Sobel(src, cv2.CV_32F, 0, 1, ksize=3)
# Combine and return the two gradients
return cv2.addWeighted(np.absolute(grad_x), 0.5, np.absolute(grad_y), 0.5, 0)