diff --git a/python/otbtf.py b/python/otbtf.py
index b28a1cc406e72d2a56e134ee490b2a20f6f0ecb4..d7e1a0b037d927ef4e09dd79e367a0edf660d534 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -34,6 +34,19 @@ import tensorflow as tf
 from osgeo import gdal
 from tqdm import tqdm
 
+# --------------------------------------------- GDAL to numpy types ----------------------------------------------------
+
+
+GDAL_TO_NP_TYPES = {1: 'uint8',
+                    2: 'uint16',
+                    3: 'int16',
+                    4: 'uint32',
+                    5: 'int32',
+                    6: 'float32',
+                    7: 'float64',
+                    10: 'complex64',
+                    11: 'complex128'}
+
 
 # ----------------------------------------------------- Helpers --------------------------------------------------------
 
@@ -58,8 +71,9 @@ def read_as_np_arr(gdal_ds, as_patches=True):
         False, the shape is (1, psz_y, psz_x, nb_channels)
     :return: Numpy array of dim 4
     """
-    buffer = gdal_ds.ReadAsArray()
+    gdal_type = gdal_ds.GetRasterBand(1).DataType
     size_x = gdal_ds.RasterXSize
+    buffer = gdal_ds.ReadAsArray().astype(GDAL_TO_NP_TYPES[gdal_type])
     if len(buffer.shape) == 3:
         buffer = np.transpose(buffer, axes=(1, 2, 0))
     if not as_patches:
@@ -68,7 +82,7 @@ def read_as_np_arr(gdal_ds, as_patches=True):
     else:
         n_elems = int(gdal_ds.RasterYSize / size_x)
         size_y = size_x
-    return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount)))
+    return buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount))
 
 
 # -------------------------------------------------- Buffer class ------------------------------------------------------
@@ -246,6 +260,7 @@ class PatchesImagesReader(PatchesReaderBase):
     def _read_extract_as_np_arr(gdal_ds, offset):
         assert gdal_ds is not None
         psz = gdal_ds.RasterXSize
+        gdal_type = gdal_ds.GetRasterBand(1).DataType
         yoff = int(offset * psz)
         assert yoff + psz <= gdal_ds.RasterYSize
         buffer = gdal_ds.ReadAsArray(0, yoff, psz, psz)
@@ -254,7 +269,7 @@ class PatchesImagesReader(PatchesReaderBase):
         else:  # single-band raster
             buffer = np.expand_dims(buffer, axis=2)
 
-        return np.float32(buffer)
+        return buffer.astype(GDAL_TO_NP_TYPES[gdal_type])
 
     def get_sample(self, index):
         """
@@ -613,8 +628,8 @@ class TFRecords:
             """
             data_converted = {}
 
-            for k, d in data.items():
-                data_converted[k] = d.name
+            for key, value in data.items():
+                data_converted[key] = value.name
 
             return data_converted
 
@@ -629,7 +644,7 @@ class TFRecords:
 
             filepath = os.path.join(self.dirpath, f"{i}.records")
             with tf.io.TFRecordWriter(filepath) as writer:
-                for s in range(nb_sample):
+                for _ in range(nb_sample):
                     sample = dataset.read_one_sample()
                     serialized_sample = {name: tf.io.serialize_tensor(fea) for name, fea in sample.items()}
                     features = {name: self._bytes_feature(serialized_tensor) for name, serialized_tensor in
@@ -646,8 +661,8 @@ class TFRecords:
         :param filepath: Output file name
         """
 
-        with open(filepath, 'w') as f:
-            json.dump(data, f, indent=4)
+        with open(filepath, 'w') as file:
+            json.dump(data, file, indent=4)
 
     @staticmethod
     def load(filepath):
@@ -655,8 +670,8 @@ class TFRecords:
         Return data from pickle format.
         :param filepath: Input file name
         """
-        with open(filepath, 'r') as f:
-            return json.load(f)
+        with open(filepath, 'r') as file:
+            return json.load(file)
 
     def convert_dataset_output_shapes(self, dataset):
         """
@@ -665,8 +680,8 @@ class TFRecords:
         """
         output_shapes = {}
 
-        for key in dataset.output_shapes.keys():
-            output_shapes[key] = (None,) + dataset.output_shapes[key]
+        for key, value in dataset.output_shapes.keys():
+            output_shapes[key] = (None,) + value
 
         self.save(output_shapes, self.output_shape_file)