diff --git a/python/otbtf.py b/python/otbtf.py
index b28a1cc406e72d2a56e134ee490b2a20f6f0ecb4..d7e1a0b037d927ef4e09dd79e367a0edf660d534 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -34,6 +34,19 @@ import tensorflow as tf
 from osgeo import gdal
 from tqdm import tqdm
 
+# --------------------------------------------- GDAL to numpy types ----------------------------------------------------
+
+
+GDAL_TO_NP_TYPES = {1: 'uint8',
+                    2: 'uint16',
+                    3: 'int16',
+                    4: 'uint32',
+                    5: 'int32',
+                    6: 'float32',
+                    7: 'float64',
+                    10: 'complex64',
+                    11: 'complex128'}
+
 # ----------------------------------------------------- Helpers --------------------------------------------------------
 
 
@@ -58,8 +71,9 @@ def read_as_np_arr(gdal_ds, as_patches=True):
         False, the shape is (1, psz_y, psz_x, nb_channels)
     :return: Numpy array of dim 4
     """
-    buffer = gdal_ds.ReadAsArray()
+    gdal_type = gdal_ds.GetRasterBand(1).DataType
     size_x = gdal_ds.RasterXSize
+    buffer = gdal_ds.ReadAsArray().astype(GDAL_TO_NP_TYPES[gdal_type])
     if len(buffer.shape) == 3:
         buffer = np.transpose(buffer, axes=(1, 2, 0))
     if not as_patches:
@@ -68,7 +82,7 @@ def read_as_np_arr(gdal_ds, as_patches=True):
     else:
         n_elems = int(gdal_ds.RasterYSize / size_x)
         size_y = size_x
-    return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount)))
+    return buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount))
 
 
 # -------------------------------------------------- Buffer class ------------------------------------------------------
@@ -246,6 +260,7 @@ class PatchesImagesReader(PatchesReaderBase):
     def _read_extract_as_np_arr(gdal_ds, offset):
         assert gdal_ds is not None
         psz = gdal_ds.RasterXSize
+        gdal_type = gdal_ds.GetRasterBand(1).DataType
         yoff = int(offset * psz)
         assert yoff + psz <= gdal_ds.RasterYSize
         buffer = gdal_ds.ReadAsArray(0, yoff, psz, psz)
@@ -254,7 +269,7 @@ class PatchesImagesReader(PatchesReaderBase):
         else:
             # single-band raster
             buffer = np.expand_dims(buffer, axis=2)
-        return np.float32(buffer)
+        return buffer.astype(GDAL_TO_NP_TYPES[gdal_type])
 
     def get_sample(self, index):
         """
@@ -613,8 +628,8 @@ class TFRecords:
         """
         data_converted = {}
 
-        for k, d in data.items():
-            data_converted[k] = d.name
+        for key, value in data.items():
+            data_converted[key] = value.name
 
         return data_converted
 
@@ -629,7 +644,7 @@ class TFRecords:
             filepath = os.path.join(self.dirpath, f"{i}.records")
             with tf.io.TFRecordWriter(filepath) as writer:
-                for s in range(nb_sample):
+                for _ in range(nb_sample):
                     sample = dataset.read_one_sample()
                     serialized_sample = {name: tf.io.serialize_tensor(fea) for name, fea in sample.items()}
                     features = {name: self._bytes_feature(serialized_tensor) for name, serialized_tensor in
                                 serialized_sample.items()}
@@ -646,8 +661,8 @@ class TFRecords:
         :param filepath: Output file name
         """
 
-        with open(filepath, 'w') as f:
-            json.dump(data, f, indent=4)
+        with open(filepath, 'w') as file:
+            json.dump(data, file, indent=4)
 
     @staticmethod
     def load(filepath):
@@ -655,8 +670,8 @@ class TFRecords:
         Return data from pickle format.
         :param filepath: Input file name
         """
-        with open(filepath, 'r') as f:
-            return json.load(f)
+        with open(filepath, 'r') as file:
+            return json.load(file)
 
     def convert_dataset_output_shapes(self, dataset):
         """
@@ -665,7 +680,7 @@ class TFRecords:
         """
         output_shapes = {}
 
-        for key in dataset.output_shapes.keys():
-            output_shapes[key] = (None,) + dataset.output_shapes[key]
+        for key, value in dataset.output_shapes.items():
+            output_shapes[key] = (None,) + value
 
         self.save(output_shapes, self.output_shape_file)