Commit e14d7674 authored by Jay's avatar Jay Committed by jay
Browse files

Fixes suppression for all cases.

parent 325db1f7
Loading
Loading
Loading
Loading
+22 −8
Original line number Diff line number Diff line
@@ -284,10 +284,9 @@ class Edge(dict, MutableMapping):
            # Get the template and search window
            s_template = sp.clip_roi(s_img, s_keypoint, template_size)
            d_search = sp.clip_roi(d_img, d_keypoint, search_size)

            try:
                x_off, y_off, strength = sp.subpixel_offset(s_template, d_search, upsampling=upsampling)
                self.matches.loc[idx, ('x_offset', 'y_offset', 'correlation')] = [x_off, y_off, strength]
                x_offset, y_offset, strength = sp.subpixel_offset(s_template, d_search, upsampling=upsampling)
                self.matches.loc[idx, ('x_offset', 'y_offset', 'correlation')] = [x_offset, y_offset, strength]
            except:
                warnings.warn('Template-Search size mismatch, failing for this correspondence point.')
                continue
@@ -309,6 +308,22 @@ class Edge(dict, MutableMapping):
        self.masks = ('subpixel', mask)

    def suppress(self, func=spf.correlation, clean_keys=[], **kwargs):
        """
        Apply a disc based suppression algorithm to get a good spatial
        distribution of high quality points, where the user defines some
        function to be used as the quality metric.

        Parameters
        ----------
        func : object
               A function that returns a scalar value to be used
               as the strength of a given row in the matches data
               frame.

        clean_keys : list
                     of mask keys to be used to reduce the total size
                     of the matches dataframe.
        """
        if not hasattr(self, 'matches'):
            raise AttributeError('This edge does not yet have any matches computed.')

@@ -317,17 +332,16 @@ class Edge(dict, MutableMapping):
            matches, mask = self._clean(clean_keys)
        else:
            matches = self.matches

        domain = self.source.handle.raster_size

        # Massage the dataframe into the correct structure
        coords = self.source.keypoints.loc[matches['source_idx']][['x', 'y']]
        matches = matches.merge(coords, left_on=['source_idx'], right_index=True)
        matches['strength'] = self.matches.apply(func, axis=1)
        coords = self.source.keypoints[['x', 'y']]
        merged = matches.merge(coords, left_on=['source_idx'], right_index=True)
        merged['strength'] = merged.apply(func, axis=1)

        if not hasattr(self, 'suppression'):
            # Instantiate the suppression object and suppress matches
            self.suppression = od.SpatialSuppression(matches, domain, **kwargs)
            self.suppression = od.SpatialSuppression(merged, domain, **kwargs)
            self.suppression.suppress()
        else:
            for k, v in kwargs.items():
+54 −39
Original line number Diff line number Diff line
@@ -128,13 +128,13 @@ class SpatialSuppression(Observable):
             The (x,y) extent of the input domain
    """

    def __init__(self, df, domain, min_radius=1, k=250, error_k=0.05):
    def __init__(self, df, domain, min_radius=2, k=250, error_k=0.1):
        columns = df.columns
        for i in ['x', 'y', 'strength']:
            if i not in columns:
                raise ValueError('The dataframe is missing a {} column.'.format(i))
        self._df = df.sort_values(by=['strength'], ascending=False).copy()
        self.max_radius = min(domain)
        self.max_radius = max(domain)
        self.min_radius = min_radius
        self.domain = domain
        self.mask = None
@@ -186,60 +186,75 @@ class SpatialSuppression(Observable):
        error_k : float
                  [0,1) The acceptable epsilon
        """

        df = self.df
        if self.k > len(df):
           raise ValueError('Only {} valid points, but {} points requested'.format(len(df), self.k))
        min_radius = self.min_radius
        max_radius = self.max_radius
        if self.k > len(self.df):
           raise ValueError('Only {} valid points, but {} points requested'.format(len(self.df), self.k))
        search_space = np.linspace(self.min_radius, self.max_radius / 16, 250)
        cell_sizes = (search_space / math.sqrt(2)).astype(np.int)
        min_idx = 0
        max_idx = len(search_space) - 1
        while True:
            r = (min_radius + max_radius) / 2
            cell_size = int(r / math.sqrt(2))
            mid_idx = int((min_idx + max_idx) / 2)
            r = search_space[mid_idx]
            cell_size = cell_sizes[mid_idx]
            n_x_cells = int(self.domain[0] / cell_size)
            n_y_cells = int(self.domain[1] / cell_size)
            grid = np.zeros((n_x_cells, n_y_cells), dtype=np.bool)

            # Setup to store results
            result = []

            # Compute the bin edges and assign points to the appropriate bins
            x_edges = np.arange(0,self.domain[0],
                                self.domain[0] / cell_size)
            y_edges = np.arange(0,self.domain[1],
                                self.domain[1] / cell_size)
            grid = np.zeros((len(y_edges), len(x_edges)), dtype=np.bool)

            # Bin assignment
            xbins = np.digitize(df['x'], bins=x_edges)
            ybins = np.digitize(df['y'], bins=y_edges)
            bounds = True
            for i, (idx, p) in enumerate(df.iterrows()):
            # Assign all points to bins
            x_edges = np.linspace(0, self.domain[0], n_x_cells)
            y_edges = np.linspace(0, self.domain[1], n_y_cells)
            xbins = np.digitize(self.df['x'], bins=x_edges)
            ybins = np.digitize(self.df['y'], bins=y_edges)

            # Convert bins to cells
            xbins -= 1
            ybins -= 1
            pts = []
            for i, (idx, p) in enumerate(self.df.iterrows()):
                x_center = xbins[i]
                y_center = ybins[i]
                cell = grid[y_center - 1 , x_center - 1]
                cell = grid[y_center, x_center]

                if cell == False:
                    result.append(idx)
                    if len(result) > self.k - self.k * self.error_k:
                        # Search the lower half, the radius is too small
                        max_radius = r
                        bounds = False
                        continue
                    pts.append((p[['x', 'y']]))
                    if len(result) > self.k + self.k * self.error_k:
                        # Too many points, break
                        min_idx = mid_idx
                        break

                    y_min = y_center - 5
                    if y_min < 0:
                        y_min = 0

                    x_min = x_center - 5
                    if x_min < 0:
                        x_min = 0

                    y_max = y_center + 5
                    if y_max > grid.shape[0]:
                        y_max = grid.shape[0]

                    x_max = x_center + 5
                    if x_max > grid.shape[1]:
                        x_max = grid.shape[1]

                    # Cover the necessary cells
                    grid[y_center - 3: y_center + 3,
                         x_center - 3: x_center + 3] = True
            if bounds is False:
                continue
                    grid[y_min: y_max,
                         x_min: x_max] = True

            #  Check break conditions
            if self.k - self.k * self.error_k < len(result) < self.k + self.k * self.error_k:
                break
            elif abs(max_radius - min_radius) < 5:
            if self.k - self.k * self.error_k <= len(result) <= self.k + self.k * self.error_k:
                break
            elif len(result) < self.k:
                # Search the upper half, the radius is too small
                min_radius = r
                # The radius is too large
                max_idx = mid_idx

        self.mask = pd.Series(False, self.df.index)
        self.mask.loc[np.array(result)] = True

        self.mask.loc[list(result)] = True
        state_package = {'mask':self.mask,
                         'k': self.k,
                         'error_k': self.error_k}