Commit d6087ee4 authored by acpaquette's avatar acpaquette Committed by jlaura
Browse files

PyHAT Updates (#87)

* Small updates to get the scripts working.

* Fixed pradius calculations, and made some small changes.

* Removed coord transforms and made body_fix func more generic.

* Updated reproj doc string.

* General refactor to socet scripts and clean up. Simplfied input for both scripts.

* Updated conf to use conda prefix

* Fixed up __getattr__ func

* Corrected attribute error text

* Uploaded notebooks for remote access

* Updated notebooks with complete footprint function

* More or less final notebook

* Forgot to tab this in under the first except block

* Updated version.

* Removed unnecessary notebooks

* Removed previously removed module

* Removed old import

* Fixes issue indexing on the hcube

* Updated hcube

* Updated io_m3 and io_crism pyhat dependency

* hcube and indexing update to handle clipping and other operations
parent d4f55f50
Loading
Loading
Loading
Loading
+124 −1
Original line number Diff line number Diff line
@@ -2,6 +2,8 @@ import numpy as np
import gdal

from ..utils.indexing import _LocIndexer, _iLocIndexer
from libpyhat.transform.continuum import continuum_correction
from libpyhat.transform.continuum import polynomial, linear, regression


class HCube(object):
@@ -10,6 +12,11 @@ class HCube(object):
    to optionally add support for spectral labels, label
    based indexing, and lazy loading for reads.
    """
    def __init__(self, data = [], wavelengths = []):
        if len(data) != 0:
            self._data = data
        if len(wavelengths) != 0:
            self._wavelengths = wavelengths

    @property
    def wavelengths(self):
@@ -24,6 +31,21 @@ class HCube(object):
                self._wavelengths = []
        return self._wavelengths

    @property
    def data(self):
        if not hasattr(self, '_data'):
            try:
                key = (slice(None, None, None),
                       slice(None, None, None),
                       slice(None, None, None))
                data = self._read(key)
            except Exception as e:
                print(e)
                data = []
            self._data = data

        return self._data

    @property
    def tolerance(self):
        return getattr(self, '_tolerance', 2)
@@ -52,6 +74,104 @@ class HCube(object):
    def iloc(self):
        return _iLocIndexer(self)

    def reduce(self, how = np.mean, axis = (1, 2)):
        """
        Parameters
        ----------
        how : function
              Function to apply across along axises of the hcube

        axis : tuple
               List of axis to apply a given function along

        Returns
        -------
        new_hcube : Object
                    A new hcube object with the reduced data set
        """
        res = how(self.data, axis = axis)

        new_hcube = HCube(res, self.wavelengths)
        return new_hcube

    def continuum_correct(self, nodes, correction_nodes = np.array([]), correction = linear,
                          axis=0, adaptive=False, window=3, **kwargs):
        """
        Parameters
        ----------

        nodes : list
                A list of wavelengths for the continuum to be corrected along

        correction_nodes : list
                           A list of nodes to limit the correction between

        correction : function
                     Function specifying the type of correction to perform
                     along the continuum

        axis : int
               Axis to apply the continuum correction on

        adaptive : boolean
                   ?

        window : int
                 ?

        Returns
        -------

        new_hcube : Object
                    A new hcube object with the corrected dataset
        """

        continuum_data = continuum_correction(self.data, self.wavelengths, nodes = nodes,
                                              correction_nodes = correction_nodes, correction = correction,
                                              axis = axis, adaptive = adaptive,
                                              window = window, **kwargs)

        new_hcube = HCube(continuum_data[0], self.wavelengths)
        return new_hcube


    def clip_roi(self, x, y, band, tolerance=2):
        """
        Parameters
        ----------

        x : tuple
            Lower and upper bound along the x axis for clipping

        y : tuple
            Lower and upper bound along the y axis for clipping

        band : tuple
               Lower and upper band along the z axis for clipping

        tolerance : int
                    Tolerance given for trying to find wavelengths
                    between the upper and lower bound

        Returns
        -------

        new_hcube : Object
                    A new hcube object with the clipped dataset
        """
        wavelength_clip = []
        for wavelength in self.wavelengths:
            wavelength_upper = wavelength + tolerance
            wavelength_lower = wavelength - tolerance
            if wavelength_upper > band[0] and wavelength_lower < band[1]:
                wavelength_clip.append(wavelength)

        key = (wavelength_clip, slice(*x), slice(*y))
        data_clip = _LocIndexer(self)[key]

        new_hcube = HCube(np.copy(data_clip), np.array(wavelength_clip))
        return new_hcube

    def _read(self, key):
        ifnone = lambda a, b: b if a is None else a

@@ -76,7 +196,10 @@ class HCube(object):

        elif isinstance(key[0], slice):
            # Given some slice iterate over the bands and get the bands and pixel space requested
            return [self.read_array(i, pixels = pixels) for i in list(range(1, self.nbands + 1))[key[0]]]
            arrs = []
            for band in list(list(range(1, self.nbands + 1))[key[0]]):
                arrs.append(self.read_array(band, pixels = pixels))
            return np.stack(arrs)

        else:
            arrs = []
+6 −6
Original line number Diff line number Diff line
@@ -4,12 +4,12 @@ from .io_gdal import GeoDataset
from .hcube import HCube

try:
    from libpysat.derived import crism
    from libpysat.derived.utils import get_derived_funcs
    libpysat_enabled = True
    from libpyhat.derived import crism
    from libpyhat.derived.utils import get_derived_funcs
    libpyhat_enabled = True
except:
    print('No libpysat module. Unable to attach derived product functions')
    libpysat_enabled = False
    print('No libpyhat module. Unable to attach derived product functions')
    libpyhat_enabled = False

import gdal

@@ -25,7 +25,7 @@ class Crism(GeoDataset, HCube):

        self.derived_funcs = {}

        if libpysat_enabled:
        if libpyhat_enabled:
            self.derived_funcs = get_derived_funcs(crism)

    def __getattr__(self, name):
+6 −6
Original line number Diff line number Diff line
@@ -4,12 +4,12 @@ from .io_gdal import GeoDataset
from .hcube import HCube

try:
    from libpysat.derived import m3
    from libpysat.derived.utils import get_derived_funcs
    libpysat_enabled = True
    from libpyhat.derived import m3
    from libpyhat.derived.utils import get_derived_funcs
    libpyhat_enabled = True
except:
    print('No libpysat module. Unable to attach derived product functions')
    libpysat_enabled = False
    print('No libpyhat module. Unable to attach derived product functions')
    libpyhat_enabled = False

import gdal

@@ -25,7 +25,7 @@ class M3(GeoDataset, HCube):

        self.derived_funcs = {}

        if libpysat_enabled:
        if libpyhat_enabled:
            self.derived_funcs = get_derived_funcs(m3)

    def __getattr__(self, name):
+19 −9
Original line number Diff line number Diff line
@@ -42,13 +42,19 @@ class _LocIndexer(object):
        sl = key[0]
        ifnone = lambda a, b: b if a is None else a
        if isinstance(sl, slice):
            sl = list(range(ifnone(sl.start, 0), self.data_array.nbands, ifnone(sl.step, 1)))
            sl = list(range(ifnone(sl.start, 0),
                            ifnone(sl.stop, len(self.data_array.wavelengths)),
                            ifnone(sl.step, 1)))

        if isinstance(sl, (int, float)):
            idx = self._get_idx(sl)
        else:
            idx = [self._get_idx(s) for s in sl]
        key = (idx, key[1], key[2])

        if len(self.data_array.data) != 0:
            return self.data_array.data[key]

        return self.data_array._read(key)

    def _get_idx(self, value, tolerance=2):
@@ -69,8 +75,12 @@ class _iLocIndexer(object):
        ifnone = lambda a, b: b if a is None else a
        if isinstance(sl, slice):
            sl = list(range(ifnone(sl.start, 0),
                            ifnone(sl.stop, self.data_array.nbands),
                            ifnone(sl.stop, len(self.data_array.wavelengths)),
                            ifnone(sl.step, 1)))

        key = (key[0], key[1], key[2])

        if len(self.data_array.data) != 0:
            return self.data_array.data[key]

        return self.data_array._read(key)