Commit 2aedc6d5 authored by Jay's avatar Jay
Browse files

Modularizes and adds retries to overlap placement and subpixel registration

parent 267e4229
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@ class Parent:
        self.session = Session()
        self.session.begin()

def retry(max_retries=3, wait_time=300):
def retry(max_retries=5, wait_time=300):
    def decorator(func):
        def wrapper(*args, **kwargs):
            retries = 0
@@ -29,7 +29,7 @@ def retry(max_retries=3, wait_time=300):
        return wrapper
    return decorator

@retry(max_retries=5)
@retry
def new_connection(dbconfig):
    """
    Using the user supplied config create a NullPool database connection.
+2 −0
Original line number Diff line number Diff line
@@ -23,6 +23,7 @@ from shapely.geometry import Point

from autocnet.transformation.spatial import reproject, og2oc
from autocnet.utils.serializers import JsonEncoder
from autocnet.io.db.connection import retry

log = logging.getLogger(__name__)

@@ -388,6 +389,7 @@ class Overlay(BaseMixin, Base):
    def geom(self, geom):
        self._geom = from_shape(geom, srid=self.latitudinal_srid)

    @retry(max_retries=5)
    @classmethod
    def overlapping_larger_than(cls, size_threshold, session):
        """
+91 −0
Original line number Diff line number Diff line
from contextlib import nullcontext
import logging
import warnings

from sqlalchemy.sql.expression import bindparam

from autocnet.io.db.connection import retry
from autocnet.io.db.model import Images, Overlay, Points, Measures
from autocnet.graph.node import NetworkNode

@retry
def update_measures(ncg, session, measures_iterable_to_update):
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
        stmt = Measures.__table__.update().\
                        where(Measures.__table__.c.id == bindparam('_id')).\
                        values({'weight':bindparam('weight'),
                                'measureIgnore':bindparam('ignore'),
                                'templateMetric':bindparam('template_metric'),
                                'templateShift':bindparam('template_shift'),
                                        'line': bindparam('line'),
                                        'sample':bindparam('sample'),
                                        'ChooserName':bindparam('choosername')})
        session.execute(stmt, measures_iterable_to_update)
        return

@retry
def ignore_measures(ncg, session, measures_iterable_to_ignore, chooser):
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
        measures_to_set_false = [{'_id':i} for i in measures_to_set_false]
        # Set ignore=True measures that failed
        stmt = Measures.__table__.update().\
                                where(Measures.__table__.c.id == bindparam('_id')).\
                                values({'measureIgnore':True,
                                        'ChooserName':chooser})
        session.execute(stmt, measures_to_set_false)

@retry
def get_nodes_for_overlap(ncg, session, overlap):
    # If an NCG is passed, instantiate a session off the NCG, else just pass the session through
    nodes = []
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
        for id in overlap.intersections:
            try:
                res = session.query(Images).filter(Images.id == id).one()
            except Exception as e:
                warnings.warn(f'Unable to instantiate image with id: {id} with error: {e}')
                continue
            nn = NetworkNode(node_id=id, 
                             image_path=res.path, 
                             cam_type=res.cam_type,
                             dem=res.dem,
                             dem_type=res.dem_type)
            nodes.append(nn)

@retry
def get_nodes_for_measures(ncg, session, measures):
        nodes = {}
        with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
            for measure in measures:
                res = session.query(Images).filter(Images.id == measure.imageid).one()
                logging.debug(f'Node instantiation image query result: {res.path, res.cam_type, res.dem, res.dem_type}')
                nn = NetworkNode(node_id=measure.imageid, 
                                image_path=res.path,
                                cam_type=res.cam_type,
                                dem=res.dem,
                                dem_type=res.dem_type)
                nodes[measure.imageid] = nn
            session.expunge_all()  
        return nodes  

@retry
def get_overlap(ncg, session, overlapid):
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
        overlap = session.query(Overlay).filter(Overlay.id == overlapid).one()
        session.expunge_all()
    return overlap

@retry
def get_point(ncg, session, pointid):
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
        point = session.query(Points).filter(Points.id == pointid).one()
        session.expunge_all()
    return point



@retry
def bulk_commit(ncg, session, iterable_of_objs_to_commit):
    with ncg.session_scope() if ncg else nullcontext(session) as session:
        session.add_all(iterable_of_objs_to_commit)
        session.commit()
+40 −101
Original line number Diff line number Diff line
@@ -15,7 +15,6 @@ from skimage import registration
from skimage import filters
from scipy import fftpack
from scipy.spatial import distance_matrix
from sqlalchemy.sql.expression import bindparam
from sqlalchemy import inspect
from matplotlib import pyplot as plt

@@ -23,7 +22,7 @@ from autocnet.matcher.naive_template import pattern_match
from autocnet.matcher.mutual_information import mutual_information
from autocnet.io.geodataset import AGeoDataset
from autocnet.io.db.model import Measures, Points, Images, JsonEncoder
from autocnet.io.db.connection import retry
from autocnet.io.db.utils import get_point, get_nodes_for_measures, update_measures, ignore_measures
from autocnet.graph.node import NetworkNode
from autocnet.transformation import roi
from autocnet.transformation.affine import estimate_local_affine
@@ -799,7 +798,6 @@ def fourier_mellen(ref_image, moving_image, affine=tf.AffineTransform(), verbose

    return subpixel_affine, error, diffphase

@retry(max_retries=5)
def subpixel_register_point_smart(point,
                                  session=None,
                                  cost_func=lambda x,y: 1/x**2 * y,
@@ -837,16 +835,14 @@ def subpixel_register_point_smart(point,
                 {'match_kwargs': {'image_size':(151,151), 'template_size':(67,67)}},
                 {'match_kwargs': {'image_size':(181,181), 'template_size':(73,73)}}]
    """
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
    if not isinstance(point, Points):
            point = session.query(Points).filter(Points.id == point).one()
        point = get_point(ncg, session, point)
    pointid = point.id

        # Order by is important here because the measures get ids in sequential order when initially placed
        # and the reference_index is positionally linked to the ordered vector of measures.
        measures = session.query(Measures).filter(Measures.pointid == pointid).order_by(Measures.id).all()
    measures = point.measures
    reference_index = point.reference_index
    
    # measures = session.query(Measures).filter(Measures.pointid == pointid).order_by(Measures.id).all()

    # Get the reference measure to instantiate the source node. All other measures will
    # match to the source node.
    source = measures[reference_index]
@@ -856,21 +852,11 @@ def subpixel_register_point_smart(point,

    # Build a node cache so that this is an encapsulated database call. Then nodes
    # can be pulled from the lookup sans database.
        nodes = {}
        for measure in measures:
            res = session.query(Images).filter(Images.id == measure.imageid).one()
            logging.debug(f'Node instantiation image query result: {res.path, res.cam_type, res.dem, res.dem_type}')
            nn = NetworkNode(node_id=measure.imageid, 
                            image_path=res.path,
                            cam_type=res.cam_type,
                            dem=res.dem,
                            dem_type=res.dem_type)
            nodes[measure.imageid] = nn
            session.expunge_all()
    nodes_cache = get_nodes_for_measures(ncg, session, measures)

    log.info(f'Attempting to subpixel register {len(measures)-1} measures for point {pointid}')
    # Set the reference image
    source_node = nodes[reference_index_id]
    source_node = nodes_cache[reference_index_id]

    log.info(f'Source: sample: {source.sample} | line: {source.line}')
    updated_measures = []
@@ -880,7 +866,7 @@ def subpixel_register_point_smart(point,
            continue

        cost = None
        destination_node = nodes[measure.imageid]
        destination_node = nodes_cache[measure.imageid]
        log.info(f'Registering measure {measure.id} (image: {measure.imageid}, serial: {measure.serial})')

        reference_roi = roi.Roi(source_node.geodata, 
@@ -993,12 +979,13 @@ def subpixel_register_point_smart(point,
                'template_metric':maxcorr,
                'template_shift':dist,
                'mi_metric': 0,
                'status': True}
                'status': True,
                'imageid':measure.imageid}
            log.info(f'METRIC: {maxcorr}| SAMPLE: {new_x} | LINE: {new_y} | MI: 0')
            
            updated_measures.append([baseline_mi, baseline_corr, m])
    # Baseline MI, Baseline Correlation, updated measures to select from
    return updated_measures
    return updated_measures, nodes_cache

def check_for_shift_consensus(shifts, tol=0.1):
    """
@@ -1116,9 +1103,9 @@ def decider(measures, tol=0.6):

    return measures_to_update, measures_to_set_false

def validate_candidate_measure(measure_to_register,
                               session=None,
                               ncg=None,
def validate_candidate_measure(point,
                               measure_to_register,
                               node_cache,
                               parameters=[],
                               func=pattern_match,
                               **kwargs):
@@ -1158,45 +1145,18 @@ def validate_candidate_measure(measure_to_register,
            Of reprojection distances for each parameter set.
    """

    measure_to_register_id = measure_to_register['id']
    
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
        # Get the measure to be registered
        measure = session.query(Measures).filter(Measures.id == measure_to_register_id).order_by(Measures.id).one()

        # Get the references measure
        point = measure.point
    # Get the source (being validation) and destination (current reference measure)
    # from the node_cache
    source_node = node_cache[measure_to_register['imageid']]
    reference_index = point.reference_index
    reference_measure = point.measures[reference_index]


        # Match the reference measure to the measure_to_register - this is the inverse of the first match attempt
        # Source is the image that we are seeking to validate, destination is the reference measure.
        # This is the inverse of other functions as this is a validator.

        source_imageid = measure.imageid
        source_image = session.query(Images).filter(Images.id == source_imageid).one()
        source_node = NetworkNode(node_id=source_imageid, 
                                  image_path=source_image.path, 
                                  cam_type=source_image.cam_type,
                                  dem=source_image.dem,
                                  dem_type=source_image.dem_type)
        source_node.parent = ncg

        destination_imageid = reference_measure.imageid
        destination_image = session.query(Images).filter(Images.id == destination_imageid).one()
        destination_node = NetworkNode(node_id=destination_imageid, 
                                       image_path=destination_image.path,
                                       cam_type=source_image.cam_type,
                                       dem=destination_image.dem,
                                       dem_type=destination_image.dem_type)
        destination_node.parent = ncg
        session.expunge_all()
    destination_node = node_cache[reference_measure.imageid]

    sample = measure_to_register['sample']
    line = measure_to_register['line']

    log.info(f'Validating measure: {measure_to_register_id} on image: {source_imageid}')
    log.info(f'Validating measure: {measure_to_register['id']} on image: {source_imageid}')

    reference_roi = roi.Roi(source_node.geodata, 
                            sample, 
@@ -1240,7 +1200,6 @@ def validate_candidate_measure(measure_to_register,
        dists.append(dist)
    return dists

@retry(max_retries=5)
def smart_register_point(point, 
                         session=None,
                         parameters=[], 
@@ -1292,17 +1251,15 @@ def smart_register_point(point,
                            building approach

    """
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
    if not isinstance(point, Points):
            point = session.query(Points).filter(Points.id == point).one()
        session.expunge_all()
        point = get_point(ncg, session, point)
        
    measure_results = subpixel_register_point_smart(point, session, parameters=parameters, **shared_kwargs)
    measure_results, node_cache = subpixel_register_point_smart(point, session, parameters=parameters, **shared_kwargs)
    measures_to_update, measures_to_set_false = decider(measure_results)
    log.info(f'Found {len(measures_to_update)} measures that found subpixel registration consensus. Running validation now...')
    # Validate that the new position has consensus
    for measure in measures_to_update:
        reprojection_distances = validate_candidate_measure(measure, session, parameters=parameters, ncg=ncg, **shared_kwargs)
        reprojection_distances = validate_candidate_measure(point, measure, node_cache, parameters=parameters, **shared_kwargs)
        log.info(f'Validation Distance Boolean: {np.array(reprojection_distances) < valid_reprojection_distance}')
        if np.sum(np.array(reprojection_distances) < valid_reprojection_distance) < 2:
            log.info(f"Measure {measure['id']} failed validation. Setting ignore=True for this measure.")
@@ -1312,27 +1269,9 @@ def smart_register_point(point,
        measure['_id'] = measure.pop('id', None)


    # Update the measures that passed registration
    if measures_to_update:
        stmt = Measures.__table__.update().\
                                where(Measures.__table__.c.id == bindparam('_id')).\
                                values({'weight':bindparam('weight'),
                                        'measureIgnore':bindparam('ignore'),
                                        'templateMetric':bindparam('template_metric'),
                                        'templateShift':bindparam('template_shift'),
                                        'line': bindparam('line'),
                                        'sample':bindparam('sample'),
                                        'ChooserName':bindparam('choosername')})
        session.execute(stmt, measures_to_update)

    if measures_to_set_false:
        measures_to_set_false = [{'_id':i} for i in measures_to_set_false]
        # Set ignore=True measures that failed
        stmt = Measures.__table__.update().\
                                where(Measures.__table__.c.id == bindparam('_id')).\
                                values({'measureIgnore':True,
                                        'ChooserName':shared_kwargs['chooser']})
        session.execute(stmt, measures_to_set_false)
    # Update the measures that passed and failed registration
    update_measures(ncg, session, measures_to_update)
    ignore_measures(ncg, session, measures_to_set_false, shared_kwargs['chooser'])

    log.info(f'Updated measures: {json.dumps(measures_to_update, indent=2, cls=JsonEncoder)}')
    log.info(f'Ignoring measures: {measures_to_set_false}')
+8 −23
Original line number Diff line number Diff line
@@ -10,11 +10,13 @@ from subprocess import CalledProcessError
from autocnet.cg import cg as compgeom
from autocnet.graph.node import NetworkNode
from autocnet.io.db.model import Images, Measures, Overlay, Points, JsonEncoder
from autocnet.io.db.connection import retry
from autocnet.io.db.utils import get_nodes_for_overlap, get_overlap, bulk_commit
from autocnet.transformation import roi
from autocnet.matcher.cpu_extractor import extract_most_interesting
from autocnet.matcher.validation import is_valid_lroc_image



# set up the logger file
log = logging.getLogger(__name__)

@@ -144,7 +146,6 @@ def find_interesting_point(nodes, lon, lat, size=71, **kwargs):
    log.debug(f'Current reference index: {reference_index}.')
    return reference_index, shapely.geometry.Point(sample, line)

@retry(max_retries=5)
def place_points_in_overlap(overlap,
                            identifier="place_points_in_overlaps",
                            interesting_func=find_interesting_point,
@@ -215,9 +216,7 @@ def place_points_in_overlap(overlap,
    """
    t1 = time.time()
    if not isinstance(overlap, Overlay):
        with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
            overlap = session.query(Overlay).filter(Overlay.id == overlap).one()
            session.expunge_all()
        overlap = get_overlap(ncg, session, overlap)
    
    # Determine the point distribution in the overlap geom
    geom = overlap.geom
@@ -228,21 +227,8 @@ def place_points_in_overlap(overlap,
        return []
    log.info(f'Have {len(candidate_points)} potential points to place in overlap {overlap.id}.')
    
    # If an NCG is passed, instantiate a session off the NCG, else just pass the session through
    nodes = []
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
        for id in overlap.intersections:
            try:
                res = session.query(Images).filter(Images.id == id).one()
            except Exception as e:
                warnings.warn(f'Unable to instantiate image with id: {id} with error: {e}')
                continue
            nn = NetworkNode(node_id=id, 
                             image_path=res.path, 
                             cam_type=res.cam_type,
                             dem=res.dem,
                             dem_type=res.dem_type)
            nodes.append(nn)
    nodes = get_nodes_for_overlap(ncg, session, overlap)

    points_to_commit = []
    for valid in candidate_points:
        log.debug(f'Valid point: {valid}')
@@ -281,11 +267,10 @@ def place_points_in_overlap(overlap,
        else:
            if len(point.measures) >= 2:
                points_to_commit.append(point)
    
    log.debug(f'Committing: {points_to_commit}')
    if points_to_commit:
        with ncg.session_scope() if ncg else nullcontext(session) as session:
            session.add_all(points_to_commit)
            session.commit()
        bulk_commit(ncg, session, points_to_commit)
    t2 = time.time()
    log.info(f'Placed {len(candidate_points)} in {t2-t1} seconds.')