Commit dc04d3df authored by jlaura's avatar jlaura
Browse files

Merge pull request #8 from jcwbacker/master

Updated io_gdal. Fixes JIRA AUTOCONG-58
parents e6569a4b c12eadbd
Loading
Loading
Loading
Loading
+222 −134
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ import numpy as np
from osgeo import gdal
from osgeo import osr

from autocnet.fileio import extract_metadata as em
from autocnet.fileio import extract_metadata

NP2GDAL_CONVERSION = {
  "uint8": 1,
@@ -26,91 +26,139 @@ for k, v in iter(NP2GDAL_CONVERSION.items()):

GDAL2NP_CONVERSION[1] = 'int8'

class GeoDataSet(object):
class GeoDataset(object):
    """
    Geospatial dataset object
    Geospatial dataset object that represents.

    Parameters
    ----------
    filename : str
               The path to the file
    file_name : str
                The name of the input image, including its full path.

    Attributes
    ----------

    basename : str
               The basename extracted from the full path
    base_name : str
                The base name of the input image, extracted from the full path.

    geotransform : object
                   OGR geotransformation object
                   Geotransform reference OGR object as an array of size 6 containing the affine 
                   transformation coefficients for transforming from raw sample/line to projected x/y.
                   xproj = geotransform[0] + sample * geotransform[1] + line * geotransform[2]
                   yproj = geotransform[3] + sample * geotransform[4] + line * geotransform[5]

    standardparallels : list
                        of the standard parallels
    geospatial_coordinate_system : object
                                   Geospatial coordinate system OSR object.

    unittype : str
               Name of the unit, e.g. 'm' or 'ft' used by the raster
    latlon_extent : list
                    of two tuples containing the latitide/longitude boundaries. 
                    This list is in the form [(lowerlat, lowerlon), (upperlat, upperlon)].

    spatialreference : object
                       OSR spatial reference object
    pixel_width : float
                  The width of the image pixels (i.e. displacement in the x-direction).
                  Note: This is the second value geotransform array.

    geospatial_coordinate_system : object
                                   OSR geospatial coordinate reference object
    pixel_height : float
                   The height of the image pixels (i.e. displacement in the y-direction).
                   Note: This is the sixth (last) value geotransform array.

    latlon_extent : list
                    of tuples in the form (llat, llon), (ulat, ulon)
    spatial_reference : object
                        Spatial reference system OSR object.

    standard_parallels : list
                         of the standard parallels used by the map projection found in the metadata
                         using the spatial reference for this GeoDataset.

    unit_type : str
                Name of the unit used by the raster, e.g. 'm' or 'ft'.

    x_rotation : float
                The geotransform coefficient that represents the rotation about the x-axis.
                Note: This is the third value geotransform array.

    xy_extent : list
                of two tuples containing the sample/line boundaries. 
                The first value is the upper left corner of the upper left pixel and 
                the second value is the lower right corner of the lower right pixel. 
                This list is in the form [(minx, miny), (maxx, maxy)].

    y_rotation : float
                 The geotransform coefficient that represents the rotation about the y-axis.
                 Note: This is the fifth value geotransform array.

    coordinate_transformation : object
                                The coordinate transformation from the spatial reference system to 
                                the geospatial coordinate system.
        
    inverse_coordinate_transformation : object
                                        The coordinate transformation from the geospatial 
                                        coordinate system to the spatial reference system.
        
    scale : tuple
            The name and value of the linear projection units of the spatial reference system. 
            This tuple is of type string/float of the form (unit name, value).
            To transform a linear distance to meters, multiply by this value.
            If no units are available ("Meters", 1) will be returned.
                 
    extent : list
             of tuples in the form (minx, miny), (maxx, maxy)
    spheroid : tuple
               The spheroid found in the metadata using the spatial reference system. 
               This is of the form (semi-major, semi-minor, inverse flattening).

    xpixelsize : float
                 Size of the x-pixel
    raster_size : tuple
                  The dimensions of the raster, i.e. (number of samples, number of lines).
        
    ypixelsize : float
                 Size of the y-pixel
    central_meridian : float
                       The central meridian of the map projection from the metadata.

    xrotation : float
                Rotation of the x-axis
    no_data_value : float
                    Special value used to indicate pixels that are not valid.

    yrotation : float
                Rotation of the y-axis
    """
    def __init__(self, file_name):
        """
        Initialization method to set the file name and open the file using GDAL.

        Parameters
        ----------
        file_name : str
                   The file name to set and open.

        """
    def __init__(self, filename):
        self.filename = filename
        self.ds = gdal.Open(filename)
        self.file_name = file_name
        self.dataset = gdal.Open(file_name)
    
    def __repr__(self):
        return os.path.basename(self.filename)
        return os.path.basename(self.file_name)

    @property
    def basename(self):
        if not getattr(self, '_basename', None):
            self._basename = os.path.splitext(os.path.basename(self.filename))[0]
        return self._basename
    def base_name(self):
        if not getattr(self, '_base_name', None):
            self._base_name = os.path.splitext(os.path.basename(self.file_name))[0]
        return self._base_name

    @property
    def geotransform(self):
        if not getattr(self, '_geotransform', None):
            self._geotransform = self.ds.GetGeoTransform()
            self._geotransform = self.dataset.GetGeoTransform()
        return self._geotransform

    @property
    def standardparallels(self):
        if not getattr(self, '_standardparallels', None):
            self._standardparallels = em.get_standard_parallels(self.spatialreference)
        return self._standardparallels
    def standard_parallels(self):
        if not getattr(self, '_standard_parallels', None):
            self._standard_parallels = extract_metadata.get_standard_parallels(self.spatial_reference)
        return self._standard_parallels

    @property
    def unittype(self):
        if not getattr(self, '_unittype', None):
            self._unittype = self.ds.GetRasterBand(1).GetUnitType()
        return self._unittype
    def unit_type(self):
        if not getattr(self, '_unit_type', None):
            self._unit_type = self.dataset.GetRasterBand(1).GetUnitType()
        return self._unit_type

    @property
    def spatialreference(self):
    def spatial_reference(self):
        if not getattr(self, '_srs', None):
            self._srs = osr.SpatialReference()
            self._srs.ImportFromWkt(self.ds.GetProjection())
            self._srs.ImportFromWkt(self.dataset.GetProjection())
            try:
                self._srs.MorphToESRI()
                self._srs.MorphFromESRI()
@@ -123,67 +171,60 @@ class GeoDataSet(object):
    @property
    def geospatial_coordinate_system(self):
        if not getattr(self, '_gcs', None):
            self._gcs = self.spatialreference.CloneGeogCS()
            self._gcs = self.spatial_reference.CloneGeogCS()
        return self._gcs

    @property
    def latlon_extent(self):
        if not getattr(self, '_latlonextent', None):
            ext = self.extent
            llat, llon = self.pixel_to_latlon(ext[0][0], ext[0][1])
            ulat, ulon = self.pixel_to_latlon(ext[1][0], ext[1][1])
            self._latlonextent = [(llat, llon), (ulat, ulon)]
        return self._latlonextent
        if not getattr(self, '_latlon_extent', None):
            xy_extent = self.xy_extent
            lowerlat, lowerlon = self.pixel_to_latlon(xy_extent[0][0], xy_extent[0][1])
            upperlat, upperlon = self.pixel_to_latlon(xy_extent[1][0], xy_extent[1][1])
            self._latlon_extent = [(lowerlat, lowerlon), (upperlat, upperlon)]
        return self._latlon_extent

    @property
    def extent(self):
        if not getattr(self, '_extent', None):
            gt = self.geotransform
            minx = gt[0]
            maxy = gt[3]
    def xy_extent(self):
        if not getattr(self, '_xy_extent', None):
            geotransform = self.geotransform
            minx = geotransform[0]
            maxy = geotransform[3]

            maxx = minx + gt[1] * self.ds.RasterXSize
            miny = maxy + gt[5] * self.ds.RasterYSize
            maxx = minx + geotransform[1] * self.dataset.RasterXSize
            miny = maxy + geotransform[5] * self.dataset.RasterYSize

            self._extent = [(minx, miny), (maxx, maxy)]
            self._xy_extent = [(minx, miny), (maxx, maxy)]

        return self._extent
        return self._xy_extent

    @property
    def xpixelsize(self):
        """
        Get the pixel size of the input data
        """
        if not getattr(self, '_xpixelsize', None):
            self._xpixelsize = self.geotransform[1]
        return self._xpixelsize
    def pixel_width(self):
        if not getattr(self, '_pixel_width', None):
            self._pixel_width = self.geotransform[1]
        return self._pixel_width

    @property
    def ypixelsize(self):
        """
        The y-pixel size of the input data
        """

        if not getattr(self, '_ypixelsize', None):
            self._ypixelsize = self.geotransform[5]
        return self._ypixelsize
    def pixel_height(self):
        if not getattr(self, '_pixel_height', None):
            self._pixel_height = self.geotransform[5]
        return self._pixel_height

    @property
    def xrotation(self):
        if not getattr(self, '_xrotation', None):
            self._xrotation = self.geotransform[2]
        return self._xrotation
    def x_rotation(self):
        if not getattr(self, '_x_rotation', None):
            self._x_rotation = self.geotransform[2]
        return self._x_rotation

    @property
    def yrotation(self):
        if not getattr(self, '_yrotation', None):
            self._yrotation = self.geotransform[4]
        return self._yrotation
    def y_rotation(self):
        if not getattr(self, '_y_rotation', None):
            self._y_rotation = self.geotransform[4]
        return self._y_rotation

    @property
    def coordinate_transformation(self):
        if not getattr(self, '_ct', None):
            self._ct = osr.CoordinateTransformation(self.spatialreference,
            self._ct = osr.CoordinateTransformation(self.spatial_reference,
                                                    self.geospatial_coordinate_system)
        return self._ct

@@ -191,85 +232,109 @@ class GeoDataSet(object):
    def inverse_coordinate_transformation(self):
        if not getattr(self, '_ict', None):
                       self._ict = osr.CoordinateTransformation(self.geospatial_coordinate_system,
                                                                self.spatialreference)
                                                                self.spatial_reference)
        return self._ict

    @property
    def ndv(self, band=1):
        if not getattr(self, '_ndv', None):
            self._ndv = self.ds.GetRasterBand(band).GetNoDataValue()
        return self._ndv
    def no_data_value(self):
        if not getattr(self, '_no_data_value', None):
            self._no_data_value = self.dataset.GetRasterBand(1).GetNoDataValue()
        return self._no_data_value

    @property
    def scale(self):
        if not getattr(self, '_scale', None):
            unitname = self.spatialreference.GetLinearUnitsName()
            value = self.spatialreference.GetLinearUnits()
            unitname = self.spatial_reference.GetLinearUnitsName()
            value = self.spatial_reference.GetLinearUnits()
            self._scale = (unitname, value)
        return self._scale

    @property
    def spheroid(self):
        if not getattr(self, '_spheroid', None):
            self._spheroid = em.get_spheroid(self.spatialreference)
            self._spheroid = extract_metadata.get_spheroid(self.spatial_reference)
        return self._spheroid

    @property
    def rastersize(self):
        if not getattr(self, '_rastersize', None):
            self._rastersize = (self.ds.RasterXSize, self.ds.RasterYSize)
        return self._rastersize
    def raster_size(self):
        if not getattr(self, '_raster_size', None):
            self._raster_size = (self.dataset.RasterXSize, self.dataset.RasterYSize)
        return self._raster_size

    @property
    def central_meridian(self):
        if not getattr(self, '_central_meridian', None):
            self._central_meridian = em.get_central_meridian(self.spatialreference)
            self._central_meridian = extract_metadata.get_central_meridian(self.spatial_reference)
        return self._central_meridian

    def pixel_to_latlon(self, x, y):
        """
        Convert from pixel space to lat/lon space
        Convert from pixel space (i.e. sample/line) to lat/lon space.

        Parameters
        ----------
        x : float
            x-coordinate
            x-coordinate to be transformed.
        y : float
            y-coordinates
            y-coordinate to be transformed.

        Returns
        -------
        lat, lon : tuple
                   Latitude, Longitude
                   (Latitude, Longitude) corresponding to the given (x,y).
        
        """
        gt = self.geotransform
        x = gt[0] + (x * gt[1]) + (y * gt[2])
        y = gt[3] + (x * gt[4]) + (y * gt[5])
        geotransform = self.geotransform
        x = geotransform[0] + (x * geotransform[1]) + (y * geotransform[2])
        y = geotransform[3] + (x * geotransform[4]) + (y * geotransform[5])
        lon, lat, _ = self.coordinate_transformation.TransformPoint(x, y)

        return lat, lon

    def latlon_to_pixel(self, lat, lon):
        gt = self.geotransform
        ulat, ulon, _ = self.inverse_coordinate_transformation.TransformPoint(lon, lat)
        x = (ulat - gt[0]) / gt[1]
        y = (ulon - gt[3]) / gt[5]
        """
        Convert from lat/lon space to pixel space (i.e. sample/line).

        Parameters
        ----------
        lat: float
             Latitude to be transformed.
        lon : float
              Longitude to be transformed.
        Returns
        -------
        x, y : tuple
               (Sample, line) position corresponding to the given (latitude, longitude).
        
        """
        geotransform = self.geotransform
        upperlat, upperlon, _ = self.inverse_coordinate_transformation.TransformPoint(lon, lat)
        x = (upperlat - geotransform[0]) / geotransform[1]
        y = (upperlon - geotransform[3]) / geotransform[5]
        return x, y

    def readarray(self, band=1, pixels=None, dtype='float32'):
    def read_array(self, band=1, pixels=None, dtype='float32'):
        """
        Extract the required data as a numpy array
        Extract the required data as a NumPy array

        Parameters
        ----------
        band : int
               The image band number to be extracted as a NumPy array. Default band=1.

        pixels : list
		  [start, ystart, xstop, ystop]
                 [start, ystart, xstop, ystop]. Default pixels=None.

        dtype : str
                numpy dtype, e.g. float32
                The NumPy dtype for the output array. Default dtype='float32'.

        Returns
        -------
        array : ndarray
                The dataset for the specified band.

        """
        band = self.ds.GetRasterBand(band)
        band = self.dataset.GetRasterBand(band)

        dtype = getattr(np, dtype)

@@ -284,9 +349,32 @@ class GeoDataSet(object):
                                          xextent, yextent).astype(dtype)
        return array

def array_to_raster(array, filename, projection=None,
def array_to_raster(array, file_name, projection=None,
                    geotransform=None, outformat='GTiff',
                    ndv=None):
    """
    Converts the given NumPy array to a raster format using the GeoDataset class.

    Parameters
    ----------
    array : ndarray
            

    file_name : str 

    projection : 
                 Default projection=None.

    geotransform : object 
                   Default geotransform=None.

    outformat : str
                Default outformat='GTiff'.

    ndv : float
          The no data value for the given band. See no_data_value(). Default ndv=None.

    """
    driver = gdal.GetDriverByName(outformat)
    try:
        y, x, bands = array.shape
@@ -297,27 +385,27 @@ def array_to_raster(array, filename, projection=None,
        single = True

    #This is a crappy hard code to 32bit.
    ds = driver.Create(filename, x, y, bands, gdal.GDT_Float64)
    dataset = driver.Create(file_name, x, y, bands, gdal.GDT_Float64)

    if geotransform:
        ds.SetGeoTransform(geotransform)
        dataset.SetGeoTransform(geotransform)

    if projection:
        if isinstance(projection, str):
            ds.SetProjection(projection)
            dataset.SetProjection(projection)
        else:
            ds.SetProjection(projection.ExportToWkt())
            dataset.SetProjection(projection.ExportToWkt())

    if single == True:
        bnd = ds.GetRasterBand(1)
        bnd = dataset.GetRasterBand(1)
        if ndv != None:
            bnd.SetNoDataValue(ndv)
        bnd.WriteArray(array)
        ds.FlushCache()
        dataset.FlushCache()
    else:
        for i in range(1, bands + 1):
            bnd = ds.GetRasterBand(i)
            bnd = dataset.GetRasterBand(i)
            if ndv != None:
                bnd.SetNoDataValue(ndv)
            bnd.WriteArray(array[:,:,i - 1])
            ds.FlushCache()
            dataset.FlushCache()
+84 −81

File changed.

Preview size limit exceeded, changes collapsed.