Unverified Commit c3039921 authored by Kelvin Rodriguez's avatar Kelvin Rodriguez Committed by GitHub
Browse files

passable matching functions (#513)

* added passable matchers

* fixed weird geom_match_classic merge issue

* updated doc strings
parent 80da2dcd
Loading
Loading
Loading
Loading
+14 −4
Original line number Diff line number Diff line
@@ -43,6 +43,7 @@ from shapely.geometry import Point

from plurmy import Slurm

from autocnet.matcher.subpixel import check_match_func
from autocnet.io.db.model import Images, Points, Measures, JsonEncoder
from autocnet.cg.cg import distribute_points_in_geom, xy_in_polygon
from autocnet.io.db.connection import new_connection
@@ -171,7 +172,8 @@ def propagate_point(Session,
                    samples,
                    size_x=40,
                    size_y=40,
                    template_kwargs={'image_size': (39, 39), 'template_size': (21, 21)},
                    match_func="classic",
                    match_kwargs={'image_size': (39, 39), 'template_size': (21, 21)},
                    verbose=False,
                    cost=lambda x, y: y == np.max(x)):
    """
@@ -245,6 +247,9 @@ def propagate_point(Session,
                   and cartesian) of successfully propagated points

    """

    match_func = check_match_func(match_func)

    session = Session()
    engine = session.get_bind()
    string = f"select * from images where ST_Intersects(geom, ST_SetSRID(ST_Point({lon}, {lat}), {config['spatial']['latitudinal_srid']}))"
@@ -278,7 +283,8 @@ def propagate_point(Session,
                print(f'prop point: dest_image: {dest_image}')
                print(f'prop point: (sx, sy): ({sx}, {sy})')
                x,y, dist, metrics, corrmap = geom_match_simple(base_image, dest_image, sx, sy, 16, 16, \
                        template_kwargs=template_kwargs, \
                        match_func = match_func, \
                        match_kwargs=match_kwargs, \
                        verbose=verbose)
            except Exception as e:
                raise Exception(e)
@@ -350,7 +356,8 @@ def propagate_control_network(Session,
        base_cnet,
        size_x=40,
        size_y=40,
        template_kwargs={'image_size': (39,39), 'template_size': (21,21)},
        match_func="classic",
        match_kwargs={'image_size': (39,39), 'template_size': (21,21)},
        verbose=False,
        cost=lambda x,y: y == np.max(x)):
    """
@@ -402,6 +409,8 @@ def propagate_control_network(Session,
    warnings.warn('This function is not well tested. No tests currently exist \
    in the test suite for this version of the function.')

    match_func = check_match_func(match_func)

    groups = base_cnet.groupby('pointid').groups

    # append CNET info into structured Python list
@@ -427,7 +436,8 @@ def propagate_control_network(Session,
                                      measures["sample"],
                                      size_x,
                                      size_y,
                                      template_kwargs,
                                      match_func,
                                      match_kwargs,
                                      verbose=verbose,
                                      cost=cost)

+83 −117
Original line number Diff line number Diff line
@@ -3,6 +3,8 @@ from math import modf, floor
import numpy as np
import warnings

import numbers

import sys
from skimage.feature import register_translation
from skimage import transform as tf
@@ -33,6 +35,41 @@ isis2np_types = {
        "Real" : "float64"
}


def check_geom_func(func):
    # TODO: Pain. Stick with one of these and delete this function along with
    # everything else
    geom_funcs = {
            "classic": geom_match_classic,
            "new": geom_match,
            "simple" : geom_match_simple,
    }

    if func in geom_funcs.values():
        return func

    if func in geom_funcs.keys():
        return match_funcs[func]

    raise Exception(f"{func} not a valid geometry function.")


def check_match_func(func):
    match_funcs = {
        "classic": subpixel_template_classic,
        "phase": iterative_phase,
        "template": subpixel_template
    }

    if func in match_funcs.values():
        return func

    if func in match_funcs.keys():
        return match_funcs[func]

    raise Exception(f"{func} not a valid matching function.")


# TODO: look into KeyPoint.size and perhaps use to determine an appropriately-sized search/template.
def _prep_subpixel(nmatches, nstrengths=2):
    """
@@ -86,8 +123,9 @@ def check_image_size(imagesize):
    imagesize : tuple
                in the form (size_x, size_y)
    """
    if isinstance(imagesize, int):
        imagesize = (imagesize, imagesize)
    if isinstance(imagesize, numbers.Number):
        imagesize = (int(imagesize), int(imagesize))


    x = imagesize[0]
    y = imagesize[1]
@@ -100,6 +138,7 @@ def check_image_size(imagesize):
    y = floor(y/2)
    return x,y


def clip_roi(img, center_x, center_y, size_x=200, size_y=200, dtype="uint64"):
    """
    Given an input image, clip a square region of interest
@@ -157,6 +196,7 @@ def clip_roi(img, center_x, center_y, size_x=200, size_y=200, dtype="uint64"):
            return None, 0, 0
    return subarray, axr, ayr


def subpixel_phase(sx, sy, dx, dy,
                   s_img, d_img,
                   image_size=(51, 51),
@@ -228,7 +268,8 @@ def subpixel_phase(sx, sy, dx, dy,
    dx = d_roi.x - shift_x
    dy = d_roi.y - shift_y

    return dx, dy, (error, diffphase)
    return dx, dy, error


def subpixel_transformed_template(sx, sy, dx, dy,
                                  s_img, d_img,
@@ -310,8 +351,8 @@ def subpixel_transformed_template(sx, sy, dx, dy,
    image_size = check_image_size(image_size)
    template_size = check_image_size(template_size)

    template_size_x = template_size[0] * transform.scale[0]
    template_size_y = template_size[1] * transform.scale[1]
    template_size_x = int(template_size[0] * transform.scale[0])
    template_size_y = int(template_size[1] * transform.scale[1])

    s_roi = roi.Roi(s_img, sx, sy, size_x=image_size[0], size_y=image_size[1])
    d_roi = roi.Roi(d_img, dx, dy, size_x=template_size_x, size_y=template_size_y)
@@ -402,6 +443,7 @@ def subpixel_transformed_template(sx, sy, dx, dy,

    return dx, dy, metrics, corrmap


def subpixel_template_classic(sx, sy, dx, dy,
                              s_img, d_img,
                              image_size=(251, 251),
@@ -463,7 +505,7 @@ def subpixel_template_classic(sx, sy, dx, dy,
    dx = d_roi.x - shift_x
    dy = d_roi.y - shift_y

    return dx, dy, metrics, corrmap
    return dx, dy, metrics


def subpixel_template(sx, sy, dx, dy,
@@ -580,6 +622,7 @@ def subpixel_template(sx, sy, dx, dy,

    return dx, dy, metrics, corrmap


def subpixel_ciratefi(sx, sy, dx, dy, s_img, d_img, search_size=251, template_size=51, **kwargs):
    """
    Uses a pattern-matcher on subsets of two images determined from the passed-in keypoints and optional sizes to
@@ -631,6 +674,7 @@ def subpixel_ciratefi(sx, sy, dx, dy, s_img, d_img, search_size=251, template_si
    dy += (y_offset + t_roi.ayr)
    return dx, dy, strength


def iterative_phase(sx, sy, dx, dy, s_img, d_img, size=(51, 51), reduction=11, convergence_threshold=1.0, max_dist=50, **kwargs):
    """
    Iteratively apply a subpixel phase matcher to source (s_img) and destination (d_img)
@@ -701,6 +745,7 @@ def iterative_phase(sx, sy, dx, dy, s_img, d_img, size=(51, 51), reduction=11, c
           break
    return dx, dy, metrics


def estimate_affine_transformation(destination_coordinates, source_coordinates):
    """
    Given a set of destination control points compute the affine transformation
@@ -787,8 +832,6 @@ def geom_match_simple(base_cube,
    autocnet.matcher.subpixel.subpixel_template: for list of kwargs that can be passed to the matcher
    autocnet.matcher.subpixel.subpixel_phase: for list of kwargs that can be passed to the matcher
    """
    print("in geommatch")
    print("subpixel kwargs", template_kwargs)

    if not isinstance(input_cube, GeoDataset):
        raise Exception("input cube must be a geodataset obj")
@@ -849,11 +892,9 @@ def geom_match_simple(base_cube,
                    "Real" : "float64"
    }

    #base_pixels = list(map(int, [base_corners[0][0], base_corners[0][1], size_x*2, size_y*2]))
    base_type = isis2np_types[pvl.load(base_cube.file_name)["IsisCube"]["Core"]["Pixels"]["Type"]]
    base_arr = base_cube.read_array(dtype=base_type)

    #dst_pixels = list(map(int, [start_x, start_y, stop_x-start_x, stop_y-start_y]))
    dst_type = isis2np_types[pvl.load(input_cube.file_name)["IsisCube"]["Core"]["Pixels"]["Type"]]
    dst_arr = input_cube.read_array(dtype=dst_type)

@@ -1087,94 +1128,6 @@ def geom_match_classic(base_cube,
    return sample, line, dist, metric, temp_corrmap


def geom_match(base_cube,
               input_cube,
               bcenter_x,
               bcenter_y,
               size_x=60,
               size_y=60,
               template_kwargs={"image_size":(59,59), "template_size":(31,31)},
               phase_kwargs=None,
               verbose=True):
    """
    Propagates a source measure into destination images and then perfroms subpixel registration.
    Measure creation is done by projecting the (lon, lat) associated with the source measure into the
    destination image. The created measure is then matched to the source measure using a quick projection
    of the destination image into source image space (using an affine transformation) and a naive
    template match with optional phase template match.
    Parameters
    ----------
    base_cube:  plio.io.io_gdal.GeoDataset
                source image
    input_cube: plio.io.io_gdal.GeoDataset
                destination image; gets matched to the source image
    bcenter_x:  int
                sample location of source measure in base_cube
    bcenter_y:  int
                line location of source measure in base_cube
    size_x:     int
                half-height of the subimage used in the affine transformation
    size_y:     int
                half-width of the subimage used in affine transformation
    template_kwargs: dict
                    contains keywords necessary for autocnet.matcher.subpixel.subpixel_template
    phase_kwargs:   dict
                    contains kwargs for autocnet.matcher.subpixel.subpixel_phase
    verbose: boolean
             indicates level of print out desired. If True, two subplots are output; the first subplot contains
             the source subimage and projected destination subimage, the second subplot contains the registered
             measure's location in the base subimage and the unprojected destination subimage with the corresponding
             template metric correlation map.

    Returns
    -------
    sample: int
            sample of new measure in destination image space
    line:   int
            line of new measures in destination image space
    dist:   np.float or tuple of np.float
            distance matching algorithm moved measure
            if template matcher only (default): returns dist_template
            if template and phase matcher:      returns (dist_template, dist_phase)
    metric: np.float or tuple of np.float
            matching metric output by the matcher
            if template matcher only (default): returns maxcorr
            if template and phase matcher:      returns (maxcorr, perror, pdiff)
    temp_corrmap: np.ndarray
                  correlation map of the naive template matcher
    See Also
    --------
    autocnet.matcher.subpixel.subpixel_template: for list of kwargs that can be passed to the matcher
    autocnet.matcher.subpixel.subpixel_phase: for list of kwargs that can be passed to the matcher
    """

    if not isinstance(input_cube, GeoDataset):
        raise Exception("input cube must be a geodataset obj")
    if not isinstance(base_cube, GeoDataset):
        raise Exception("match cube must be a geodataset obj")

    base_startx = int(bcenter_x - size_x)
    base_starty = int(bcenter_y - size_y)
    base_stopx = int(bcenter_x + size_x)
    base_stopy = int(bcenter_y + size_y)

    image_size = input_cube.raster_size
    match_size = base_cube.raster_size

    # for now, require the entire window resides inside both cubes.
    if base_stopx > match_size[0]:
        raise Exception(f"Window: {base_stopx} > {match_size[0]}, center: {bcenter_x},{bcenter_y}")
    if base_startx < 0:
        raise Exception(f"Window: {base_startx} < 0, center: {bcenter_x},{bcenter_y}")
    if base_stopy > match_size[1]:
        raise Exception(f"Window: {base_stopy} > {match_size[1]}, center: {bcenter_x},{bcenter_y} ")
    if base_starty < 0:
        raise Exception(f"Window: {base_starty} < 0, center: {bcenter_x},{bcenter_y}")

    # specifically not putting this in a try/except, this should never fail
    mlat, mlon = spatial.isis.image_to_ground(base_cube.file_name, bcenter_x, bcenter_y)
    center_x, center_y = spatial.isis.ground_to_image(input_cube.file_name, mlon, mlat)[::-1]

def geom_match(destination_cube,
               source_cube,
               bcenter_x,
@@ -1436,11 +1389,12 @@ def subpixel_register_measure(measureid,


def subpixel_register_point(pointid,
                            subpixel_template_kwargs={},
                            cost_func=lambda x,y: 1/x**2 * y,
                            threshold=0.005,
                            ncg=None,
                            version='new',
                            geom_func='simple',
                            match_func='classic',
                            match_kwargs={},
                            **kwargs):

    """
@@ -1452,10 +1406,7 @@ def subpixel_register_point(pointid,
    pointid : int or obj
              The identifier of the point in the DB or a Points object

    subpixel_template_kwargs : dict
                               Ay keyword arguments passed to the template matcher

    cost : func
    cost_func : func
                A generic cost function accepting two arguments (x,y), where x is the
                distance that a point has shifted from the original, sensor identified
                intersection, and y is the correlation coefficient coming out of the
@@ -1467,18 +1418,25 @@ def subpixel_register_point(pointid,
    ncg : obj
          the network candidate graph that the point is associated with; used for
          the DB session that is able to access the point.
    
    geom_func : callable
                function used to tranform the source and/or destination image before 
                running a matcher. 
    
    match_func : callable
                 subpixel matching function to use registering measures      
    """

    geom_func=geom_func.lower()
    match_func=match_func.lower()

    print(f"Using {geom_func} with the {match_func} matcher.")

    if not ncg.Session:
        raise BrokenPipeError('This func requires a database session from a NetworkCandidateGraph.')

    version = version.lower()
    geom_funcs = {"classic": geom_match_classic,
                "new": geom_match
                }
    if version not in geom_funcs.keys():
        raise Exception(f"{version} not a valid geom_match function version.")
    geom_func = geom_funcs[version]
    match_func = check_match_func(match_func)
    geom_func = check_geom_func(geom_func)

    if isinstance(pointid, Points):
        pointid = pointid.id
@@ -1513,9 +1471,16 @@ def subpixel_register_point(pointid,

            print('geom_match image:', res.path)
            try:
                # new geom_match has a incompatible API, until we devide on one, put in if.
                if (geom_func == geom_match):
                   new_x, new_y, dist, metric,  _ = geom_func(source_node.geodata, destination_node.geodata,
                                                        source.apriorisample, source.aprioriline,
                                                        template_kwargs=subpixel_template_kwargs)
                                                        template_kwargs=match_kwargs)
                else:
                    new_x, new_y, dist, metric,  _ = geom_func(source_node.geodata, destination_node.geodata,
                                                        source.apriorisample, source.aprioriline,
                                                        match_func=match_func,
                                                        match_kwargs=match_kwargs)
            except Exception as e:
                print(f'geom_match failed on measure {measure.id} with exception -> {e}')
                currentlog['status'] = f"geom_match failed on measure {measure.id}"
@@ -1568,6 +1533,7 @@ def subpixel_register_point(pointid,

    return resultlog


def subpixel_register_points(subpixel_template_kwargs={'image_size':(251,251)},
                             cost_kwargs={},
                             threshold=0.005,
+8 −6
Original line number Diff line number Diff line
@@ -103,9 +103,9 @@ def test_subpixel_transformed_template(apollo_subsets):
                                                b.shape[1]/2, b.shape[0]/2,
                                                a, b, transform, upsampling=16)

    assert strength >= 0.84
    assert nx == pytest.approx(51.18894)
    assert ny == pytest.approx(54.36261)
    assert strength >= 0.83
    assert nx == pytest.approx(50.576284)
    assert ny == pytest.approx(54.0081)


@pytest.mark.parametrize("loc, failure", [((0,4), True),
@@ -124,6 +124,8 @@ def test_subpixel_transformed_template_at_edge(apollo_subsets, loc, failure):
    with patch('autocnet.matcher.subpixel.clip_roi', side_effect=clip_side_effect):
        if failure:
            with pytest.warns(UserWarning, match=r'Maximum correlation \S+'):
                print(a.shape[1]/2, a.shape[0]/2,b.shape[1]/2, b.shape[0]/2,
                                                        a, b)
                nx, ny, strength, _ = sp.subpixel_transformed_template(a.shape[1]/2, a.shape[0]/2,
                                                        b.shape[1]/2, b.shape[0]/2,
                                                        a, b, transform, upsampling=16,
@@ -135,7 +137,7 @@ def test_subpixel_transformed_template_at_edge(apollo_subsets, loc, failure):
                                                        func=func)
            assert nx == 50.5

@pytest.mark.parametrize("convergence_threshold, expected", [(2.0, (50.49, 52.08, (0.039507, -9.5e-20)))])
@pytest.mark.parametrize("convergence_threshold, expected", [(2.0, (50.49, 52.08, -9.5e-20))])
def test_iterative_phase(apollo_subsets, convergence_threshold, expected):
    a = apollo_subsets[0]
    b = apollo_subsets[1]
@@ -148,8 +150,8 @@ def test_iterative_phase(apollo_subsets, convergence_threshold, expected):
    assert dx == expected[0]
    assert dy == expected[1]
    if expected[2] is not None:
        for i in range(len(strength)):
            assert pytest.approx(strength[i],6) == expected[2][i]
        # for i in range(len(strength)):
        assert pytest.approx(strength,6) == expected[2]

@pytest.mark.parametrize("data, expected", [
    ((21,21), (10, 10)),