From fd264b44acbab66866a365e2c29b7c09e28c7a08 Mon Sep 17 00:00:00 2001
From: Narcon Nicolas <nicolas.narcon@inrae.fr>
Date: Thu, 21 Apr 2022 17:18:04 +0200
Subject: [PATCH 01/16] REFAC: rename python folder + split otbtf.py into
 several files

---
 {python => otbtf}/__init__.py                 |   0
 {python => otbtf}/ckpt2savedmodel.py          |   0
 python/otbtf.py => otbtf/dataset.py           | 249 +-----------------
 .../create_savedmodel_ienco-m3_patchbased.py  |   0
 .../create_savedmodel_maggiori17_fullyconv.py |   0
 .../create_savedmodel_pxs_fcn.py              |   0
 .../create_savedmodel_simple_cnn.py           |   0
 .../create_savedmodel_simple_fcn.py           |   0
 .../examples/tensorflow_v2x/l2_norm.py        |   0
 .../examples/tensorflow_v2x/scalar_product.py |   0
 otbtf/tfrecords.py                            | 208 +++++++++++++++
 {python => otbtf}/tricks.py                   |   0
 otbtf/utils.py                                |  36 +++
 13 files changed, 246 insertions(+), 247 deletions(-)
 rename {python => otbtf}/__init__.py (100%)
 rename {python => otbtf}/ckpt2savedmodel.py (100%)
 rename python/otbtf.py => otbtf/dataset.py (64%)
 rename {python => otbtf}/examples/tensorflow_v1x/create_savedmodel_ienco-m3_patchbased.py (100%)
 rename {python => otbtf}/examples/tensorflow_v1x/create_savedmodel_maggiori17_fullyconv.py (100%)
 rename {python => otbtf}/examples/tensorflow_v1x/create_savedmodel_pxs_fcn.py (100%)
 rename {python => otbtf}/examples/tensorflow_v1x/create_savedmodel_simple_cnn.py (100%)
 rename {python => otbtf}/examples/tensorflow_v1x/create_savedmodel_simple_fcn.py (100%)
 rename {python => otbtf}/examples/tensorflow_v2x/l2_norm.py (100%)
 rename {python => otbtf}/examples/tensorflow_v2x/scalar_product.py (100%)
 create mode 100644 otbtf/tfrecords.py
 rename {python => otbtf}/tricks.py (100%)
 create mode 100644 otbtf/utils.py

diff --git a/python/__init__.py b/otbtf/__init__.py
similarity index 100%
rename from python/__init__.py
rename to otbtf/__init__.py
diff --git a/python/ckpt2savedmodel.py b/otbtf/ckpt2savedmodel.py
similarity index 100%
rename from python/ckpt2savedmodel.py
rename to otbtf/ckpt2savedmodel.py
diff --git a/python/otbtf.py b/otbtf/dataset.py
similarity index 64%
rename from python/otbtf.py
rename to otbtf/dataset.py
index b28a1cc4..4b0f945d 100644
--- a/python/otbtf.py
+++ b/otbtf/dataset.py
@@ -20,60 +20,19 @@
 """
 Contains helpers for working with TensorFlow and geospatial data in the OTBTF framework.
 """
-import glob
-import json
-import os
 import threading
 import multiprocessing
 import time
 import logging
 from abc import ABC, abstractmethod
-from functools import partial
 import numpy as np
 import tensorflow as tf
-from osgeo import gdal
-from tqdm import tqdm
-
-
-# ----------------------------------------------------- Helpers --------------------------------------------------------
-
-
-def gdal_open(filename):
-    """
-    Open a GDAL raster
-    :param filename: raster file
-    :return: a GDAL dataset instance
-    """
-    gdal_ds = gdal.Open(filename)
-    if gdal_ds is None:
-        raise Exception("Unable to open file {}".format(filename))
-    return gdal_ds
-
-
-def read_as_np_arr(gdal_ds, as_patches=True):
-    """
-    Read a GDAL raster as numpy array
-    :param gdal_ds: a GDAL dataset instance
-    :param as_patches: if True, the returned numpy array has the following shape (n, psz_x, psz_x, nb_channels). If
-        False, the shape is (1, psz_y, psz_x, nb_channels)
-    :return: Numpy array of dim 4
-    """
-    buffer = gdal_ds.ReadAsArray()
-    size_x = gdal_ds.RasterXSize
-    if len(buffer.shape) == 3:
-        buffer = np.transpose(buffer, axes=(1, 2, 0))
-    if not as_patches:
-        n_elems = 1
-        size_y = gdal_ds.RasterYSize
-    else:
-        n_elems = int(gdal_ds.RasterYSize / size_x)
-        size_y = size_x
-    return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount)))
+from otbtf.utils import read_as_np_arr, gdal_open
+from otbtf.tfrecords import TFRecords
 
 
 # -------------------------------------------------- Buffer class ------------------------------------------------------
 
-
 class Buffer:
     """
     Used to store and access list of objects
@@ -106,7 +65,6 @@ class Buffer:
 
 # ---------------------------------------------- PatchesReaderBase class -----------------------------------------------
 
-
 class PatchesReaderBase(ABC):
     """
     Base class for patches delivery
@@ -151,7 +109,6 @@ class PatchesReaderBase(ABC):
 
 # --------------------------------------------- PatchesImagesReader class ----------------------------------------------
 
-
 class PatchesImagesReader(PatchesReaderBase):
     """
     This class provides a read access to a set of patches images.
@@ -327,7 +284,6 @@ class PatchesImagesReader(PatchesReaderBase):
 
 # ----------------------------------------------- IteratorBase class ---------------------------------------------------
 
-
 class IteratorBase(ABC):
     """
     Base class for iterators
@@ -340,7 +296,6 @@ class IteratorBase(ABC):
 
 # ---------------------------------------------- RandomIterator class --------------------------------------------------
 
-
 class RandomIterator(IteratorBase):
     """
     Pick a random number in the [0, handler.size) range.
@@ -370,7 +325,6 @@ class RandomIterator(IteratorBase):
 
 # ------------------------------------------------- Dataset class ------------------------------------------------------
 
-
 class Dataset:
     """
     Handles the "mining" of patches.
@@ -532,7 +486,6 @@ class Dataset:
 
 # ----------------------------------------- DatasetFromPatchesImages class ---------------------------------------------
 
-
 class DatasetFromPatchesImages(Dataset):
     """
     Handles the "mining" of a set of patches images.
@@ -559,202 +512,4 @@ class DatasetFromPatchesImages(Dataset):
         super().__init__(patches_reader=patches_reader, buffer_length=buffer_length, Iterator=Iterator)
 
 
-class TFRecords:
-    """
-    This class converts Dataset objects to TFRecords and loads them back as TensorFlow datasets.
-    """
-
-    def __init__(self, path):
-        """
-        :param path: Can be a directory where TFRecords must be saved/loaded or a single TFRecord path
-        """
-        if os.path.isdir(path) or not os.path.exists(path):
-            self.dirpath = path
-            os.makedirs(self.dirpath, exist_ok=True)
-            self.tfrecords_pattern_path = os.path.join(self.dirpath, "*.records")
-        else:
-            self.dirpath = os.path.dirname(path)
-            self.tfrecords_pattern_path = path
-        self.output_types_file = os.path.join(self.dirpath, "output_types.json")
-        self.output_shape_file = os.path.join(self.dirpath, "output_shape.json")
-        self.output_shape = self.load(self.output_shape_file) if os.path.exists(self.output_shape_file) else None
-        self.output_types = self.load(self.output_types_file) if os.path.exists(self.output_types_file) else None
-
-    @staticmethod
-    def _bytes_feature(value):
-        """
-        Convert a value to a type compatible with tf.train.Example.
-        :param value: value
-        :return: a bytes_list from a string / byte.
-        """
-        if isinstance(value, type(tf.constant(0))):
-            value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
-        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
-
-    def ds2tfrecord(self, dataset, n_samples_per_shard=100, drop_remainder=True):
-        """
-        Convert and save samples from dataset object to tfrecord files.
-        :param dataset: Dataset object to convert into a set of tfrecords
-        :param n_samples_per_shard: Number of samples per shard
-        :param drop_remainder: Whether additional samples should be dropped. Advisable if using multiworker training.
-                               If True, all TFRecords will have `n_samples_per_shard` samples
-        """
-        logging.info("%s samples", dataset.size)
-
-        nb_shards = (dataset.size // n_samples_per_shard)
-        if not drop_remainder and dataset.size % n_samples_per_shard > 0:
-            nb_shards += 1
-
-        self.convert_dataset_output_shapes(dataset)
-
-        def _convert_data(data):
-            """
-            Convert data
-            """
-            data_converted = {}
-
-            for k, d in data.items():
-                data_converted[k] = d.name
 
-            return data_converted
-
-        self.save(_convert_data(dataset.output_types), self.output_types_file)
-
-        for i in tqdm(range(nb_shards)):
-
-            if (i + 1) * n_samples_per_shard <= dataset.size:
-                nb_sample = n_samples_per_shard
-            else:
-                nb_sample = dataset.size - i * n_samples_per_shard
-
-            filepath = os.path.join(self.dirpath, f"{i}.records")
-            with tf.io.TFRecordWriter(filepath) as writer:
-                for s in range(nb_sample):
-                    sample = dataset.read_one_sample()
-                    serialized_sample = {name: tf.io.serialize_tensor(fea) for name, fea in sample.items()}
-                    features = {name: self._bytes_feature(serialized_tensor) for name, serialized_tensor in
-                                serialized_sample.items()}
-                    tf_features = tf.train.Features(feature=features)
-                    example = tf.train.Example(features=tf_features)
-                    writer.write(example.SerializeToString())
-
-    @staticmethod
-    def save(data, filepath):
-        """
-        Save data to a JSON file.
-        :param data: Data to save as JSON
-        :param filepath: Output file name
-        """
-
-        with open(filepath, 'w') as f:
-            json.dump(data, f, indent=4)
-
-    @staticmethod
-    def load(filepath):
-        """
-        Load data from a JSON file.
-        :param filepath: Input file name
-        """
-        with open(filepath, 'r') as f:
-            return json.load(f)
-
-    def convert_dataset_output_shapes(self, dataset):
-        """
-        Convert and save numpy shape to tensorflow shape.
-        :param dataset: Dataset object containing output shapes
-        """
-        output_shapes = {}
-
-        for key in dataset.output_shapes.keys():
-            output_shapes[key] = (None,) + dataset.output_shapes[key]
-
-        self.save(output_shapes, self.output_shape_file)
-
-    @staticmethod
-    def parse_tfrecord(example, features_types, target_keys, preprocessing_fn=None, **kwargs):
-        """
-        Parse example object to sample dict.
-        :param example: Example object to parse
-        :param features_types: Dict of types for each feature
-        :param target_keys: list of keys of the targets
-        :param preprocessing_fn: Optional. A preprocessing function that takes input, target as args and returns
-                                           a tuple (input_preprocessed, target_preprocessed)
-        :param kwargs: keyword arguments for preprocessing_fn
-        """
-        read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in features_types}
-        example_parsed = tf.io.parse_single_example(example, read_features)
-
-        for key in read_features.keys():
-            example_parsed[key] = tf.io.parse_tensor(example_parsed[key], out_type=features_types[key])
-
-        # Differentiating inputs and outputs
-        input_parsed = {key: value for (key, value) in example_parsed.items() if key not in target_keys}
-        target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys}
-
-        if preprocessing_fn:
-            input_parsed, target_parsed = preprocessing_fn(input_parsed, target_parsed, **kwargs)
-
-        return input_parsed, target_parsed
-
-    def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None,
-             preprocessing_fn=None, **kwargs):
-        """
-        Read all tfrecord files matching the pattern and convert data to a tensorflow dataset.
-        :param batch_size: Size of tensorflow batch
-        :param target_keys: Keys of the target, e.g. ['s2_out']
-        :param n_workers: number of workers, e.g. 4 if using 4 GPUs
-                                             e.g. 12 if using 3 nodes of 4 GPUs
-        :param drop_remainder: whether the last batch should be dropped in the case it has fewer than
-                               `batch_size` elements. True is advisable when training on multiple workers.
-                               False is advisable when evaluating metrics so that all samples are used
-        :param shuffle_buffer_size: if None, shuffle is not used. Else, blocks of shuffle_buffer_size
-                                    elements are shuffled uniformly at random.
-        :param preprocessing_fn: Optional. A preprocessing function that takes input, target as args and returns
-                                   a tuple (input_preprocessed, target_preprocessed)
-        :param kwargs: keyword arguments for preprocessing_fn
-        """
-        options = tf.data.Options()
-        if shuffle_buffer_size:
-            options.experimental_deterministic = False  # disable order, increase speed
-        options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO  # for multiworker
-        parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys,
-                        preprocessing_fn=preprocessing_fn, **kwargs)
-
-        # TODO: to be investigated:
-        # 1/ num_parallel_reads useful? I/O bottleneck or not?
-        # 2/ num_parallel_calls=tf.data.experimental.AUTOTUNE useful?
-        # 3/ shuffle or not shuffle?
-        matching_files = glob.glob(self.tfrecords_pattern_path)
-        logging.info('Searching TFRecords in %s...', self.tfrecords_pattern_path)
-        logging.info('Number of matching TFRecords: %s', len(matching_files))
-        matching_files = matching_files[:n_workers * (len(matching_files) // n_workers)]  # files multiple of workers
-        nb_matching_files = len(matching_files)
-        if nb_matching_files == 0:
-            raise Exception("At least one worker has no TFRecord file in {}. Please ensure that the number of TFRecord "
-                            "files is greater than or equal to the number of workers!".format(self.tfrecords_pattern_path))
-        logging.info('Reducing number of records to: %s', nb_matching_files)
-        dataset = tf.data.TFRecordDataset(matching_files)  # , num_parallel_reads=2)  # interleaves reads from xxx files
-        dataset = dataset.with_options(options)  # uses data as soon as it streams in, rather than in its original order
-        dataset = dataset.map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-        if shuffle_buffer_size:
-            dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
-        dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
-        dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
-        # TODO: check whether prefetch should come before batch, cf. https://keras.io/examples/keras_recipes/tfrecord/
-
-        return dataset
-
-    def read_one_sample(self, target_keys):
-        """
-        Read one tfrecord file matching the pattern and convert data to a tensorflow dataset.
-        :param target_keys: Keys of the targets, e.g. ['s2_out']
-        """
-        matching_files = glob.glob(self.tfrecords_pattern_path)
-        one_file = matching_files[0]
-        parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys)
-        dataset = tf.data.TFRecordDataset(one_file)
-        dataset = dataset.map(parse)
-        dataset = dataset.batch(1)
-
-        sample = iter(dataset).get_next()
-        return sample
diff --git a/python/examples/tensorflow_v1x/create_savedmodel_ienco-m3_patchbased.py b/otbtf/examples/tensorflow_v1x/create_savedmodel_ienco-m3_patchbased.py
similarity index 100%
rename from python/examples/tensorflow_v1x/create_savedmodel_ienco-m3_patchbased.py
rename to otbtf/examples/tensorflow_v1x/create_savedmodel_ienco-m3_patchbased.py
diff --git a/python/examples/tensorflow_v1x/create_savedmodel_maggiori17_fullyconv.py b/otbtf/examples/tensorflow_v1x/create_savedmodel_maggiori17_fullyconv.py
similarity index 100%
rename from python/examples/tensorflow_v1x/create_savedmodel_maggiori17_fullyconv.py
rename to otbtf/examples/tensorflow_v1x/create_savedmodel_maggiori17_fullyconv.py
diff --git a/python/examples/tensorflow_v1x/create_savedmodel_pxs_fcn.py b/otbtf/examples/tensorflow_v1x/create_savedmodel_pxs_fcn.py
similarity index 100%
rename from python/examples/tensorflow_v1x/create_savedmodel_pxs_fcn.py
rename to otbtf/examples/tensorflow_v1x/create_savedmodel_pxs_fcn.py
diff --git a/python/examples/tensorflow_v1x/create_savedmodel_simple_cnn.py b/otbtf/examples/tensorflow_v1x/create_savedmodel_simple_cnn.py
similarity index 100%
rename from python/examples/tensorflow_v1x/create_savedmodel_simple_cnn.py
rename to otbtf/examples/tensorflow_v1x/create_savedmodel_simple_cnn.py
diff --git a/python/examples/tensorflow_v1x/create_savedmodel_simple_fcn.py b/otbtf/examples/tensorflow_v1x/create_savedmodel_simple_fcn.py
similarity index 100%
rename from python/examples/tensorflow_v1x/create_savedmodel_simple_fcn.py
rename to otbtf/examples/tensorflow_v1x/create_savedmodel_simple_fcn.py
diff --git a/python/examples/tensorflow_v2x/l2_norm.py b/otbtf/examples/tensorflow_v2x/l2_norm.py
similarity index 100%
rename from python/examples/tensorflow_v2x/l2_norm.py
rename to otbtf/examples/tensorflow_v2x/l2_norm.py
diff --git a/python/examples/tensorflow_v2x/scalar_product.py b/otbtf/examples/tensorflow_v2x/scalar_product.py
similarity index 100%
rename from python/examples/tensorflow_v2x/scalar_product.py
rename to otbtf/examples/tensorflow_v2x/scalar_product.py
diff --git a/otbtf/tfrecords.py b/otbtf/tfrecords.py
new file mode 100644
index 00000000..17fa51e9
--- /dev/null
+++ b/otbtf/tfrecords.py
@@ -0,0 +1,208 @@
+import glob
+import json
+import os
+import logging
+from functools import partial
+import tensorflow as tf
+from tqdm import tqdm
+
+
+class TFRecords:
+    """
+    This class converts Dataset objects to TFRecords and loads them back as TensorFlow datasets.
+    """
+
+    def __init__(self, path):
+        """
+        :param path: Can be a directory where TFRecords must be saved/loaded or a single TFRecord path
+        """
+        if os.path.isdir(path) or not os.path.exists(path):
+            self.dirpath = path
+            os.makedirs(self.dirpath, exist_ok=True)
+            self.tfrecords_pattern_path = os.path.join(self.dirpath, "*.records")
+        else:
+            self.dirpath = os.path.dirname(path)
+            self.tfrecords_pattern_path = path
+        self.output_types_file = os.path.join(self.dirpath, "output_types.json")
+        self.output_shape_file = os.path.join(self.dirpath, "output_shape.json")
+        self.output_shape = self.load(self.output_shape_file) if os.path.exists(self.output_shape_file) else None
+        self.output_types = self.load(self.output_types_file) if os.path.exists(self.output_types_file) else None
+
+    @staticmethod
+    def _bytes_feature(value):
+        """
+        Convert a value to a type compatible with tf.train.Example.
+        :param value: value
+        :return: a bytes_list from a string / byte.
+        """
+        if isinstance(value, type(tf.constant(0))):
+            value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
+        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+    def ds2tfrecord(self, dataset, n_samples_per_shard=100, drop_remainder=True):
+        """
+        Convert and save samples from dataset object to tfrecord files.
+        :param dataset: Dataset object to convert into a set of tfrecords
+        :param n_samples_per_shard: Number of samples per shard
+        :param drop_remainder: Whether additional samples should be dropped. Advisable if using multiworker training.
+                               If True, all TFRecords will have `n_samples_per_shard` samples
+        """
+        logging.info("%s samples", dataset.size)
+
+        nb_shards = (dataset.size // n_samples_per_shard)
+        if not drop_remainder and dataset.size % n_samples_per_shard > 0:
+            nb_shards += 1
+
+        self.convert_dataset_output_shapes(dataset)
+
+        def _convert_data(data):
+            """
+            Convert data
+            """
+            data_converted = {}
+
+            for k, d in data.items():
+                data_converted[k] = d.name
+
+            return data_converted
+
+        self.save(_convert_data(dataset.output_types), self.output_types_file)
+
+        for i in tqdm(range(nb_shards)):
+
+            if (i + 1) * n_samples_per_shard <= dataset.size:
+                nb_sample = n_samples_per_shard
+            else:
+                nb_sample = dataset.size - i * n_samples_per_shard
+
+            filepath = os.path.join(self.dirpath, f"{i}.records")
+            with tf.io.TFRecordWriter(filepath) as writer:
+                for s in range(nb_sample):
+                    sample = dataset.read_one_sample()
+                    serialized_sample = {name: tf.io.serialize_tensor(fea) for name, fea in sample.items()}
+                    features = {name: self._bytes_feature(serialized_tensor) for name, serialized_tensor in
+                                serialized_sample.items()}
+                    tf_features = tf.train.Features(feature=features)
+                    example = tf.train.Example(features=tf_features)
+                    writer.write(example.SerializeToString())
+
+    @staticmethod
+    def save(data, filepath):
+        """
+        Save data to a JSON file.
+        :param data: Data to save as JSON
+        :param filepath: Output file name
+        """
+
+        with open(filepath, 'w') as f:
+            json.dump(data, f, indent=4)
+
+    @staticmethod
+    def load(filepath):
+        """
+        Load data from a JSON file.
+        :param filepath: Input file name
+        """
+        with open(filepath, 'r') as f:
+            return json.load(f)
+
+    def convert_dataset_output_shapes(self, dataset):
+        """
+        Convert and save numpy shape to tensorflow shape.
+        :param dataset: Dataset object containing output shapes
+        """
+        output_shapes = {}
+
+        for key in dataset.output_shapes.keys():
+            output_shapes[key] = (None,) + dataset.output_shapes[key]
+
+        self.save(output_shapes, self.output_shape_file)
+
+    @staticmethod
+    def parse_tfrecord(example, features_types, target_keys, preprocessing_fn=None, **kwargs):
+        """
+        Parse example object to sample dict.
+        :param example: Example object to parse
+        :param features_types: Dict of types for each feature
+        :param target_keys: list of keys of the targets
+        :param preprocessing_fn: Optional. A preprocessing function that takes input, target as args and returns
+                                           a tuple (input_preprocessed, target_preprocessed)
+        :param kwargs: keyword arguments for preprocessing_fn
+        """
+        read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in features_types}
+        example_parsed = tf.io.parse_single_example(example, read_features)
+
+        for key in read_features.keys():
+            example_parsed[key] = tf.io.parse_tensor(example_parsed[key], out_type=features_types[key])
+
+        # Differentiating inputs and outputs
+        input_parsed = {key: value for (key, value) in example_parsed.items() if key not in target_keys}
+        target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys}
+
+        if preprocessing_fn:
+            input_parsed, target_parsed = preprocessing_fn(input_parsed, target_parsed, **kwargs)
+
+        return input_parsed, target_parsed
+
+    def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None,
+             preprocessing_fn=None, **kwargs):
+        """
+        Read all tfrecord files matching the pattern and convert data to a tensorflow dataset.
+        :param batch_size: Size of tensorflow batch
+        :param target_keys: Keys of the target, e.g. ['s2_out']
+        :param n_workers: number of workers, e.g. 4 if using 4 GPUs
+                                             e.g. 12 if using 3 nodes of 4 GPUs
+        :param drop_remainder: whether the last batch should be dropped in the case it has fewer than
+                               `batch_size` elements. True is advisable when training on multiple workers.
+                               False is advisable when evaluating metrics so that all samples are used
+        :param shuffle_buffer_size: if None, shuffle is not used. Else, blocks of shuffle_buffer_size
+                                    elements are shuffled uniformly at random.
+        :param preprocessing_fn: Optional. A preprocessing function that takes input, target as args and returns
+                                   a tuple (input_preprocessed, target_preprocessed)
+        :param kwargs: keyword arguments for preprocessing_fn
+        """
+        options = tf.data.Options()
+        if shuffle_buffer_size:
+            options.experimental_deterministic = False  # disable order, increase speed
+        options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO  # for multiworker
+        parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys,
+                        preprocessing_fn=preprocessing_fn, **kwargs)
+
+        # TODO: to be investigated:
+        # 1/ num_parallel_reads useful? I/O bottleneck or not?
+        # 2/ num_parallel_calls=tf.data.experimental.AUTOTUNE useful?
+        # 3/ shuffle or not shuffle?
+        matching_files = glob.glob(self.tfrecords_pattern_path)
+        logging.info('Searching TFRecords in %s...', self.tfrecords_pattern_path)
+        logging.info('Number of matching TFRecords: %s', len(matching_files))
+        matching_files = matching_files[:n_workers * (len(matching_files) // n_workers)]  # files multiple of workers
+        nb_matching_files = len(matching_files)
+        if nb_matching_files == 0:
+            raise Exception("At least one worker has no TFRecord file in {}. Please ensure that the number of TFRecord "
+                            "files is greater than or equal to the number of workers!".format(self.tfrecords_pattern_path))
+        logging.info('Reducing number of records to: %s', nb_matching_files)
+        dataset = tf.data.TFRecordDataset(matching_files)  # , num_parallel_reads=2)  # interleaves reads from xxx files
+        dataset = dataset.with_options(options)  # uses data as soon as it streams in, rather than in its original order
+        dataset = dataset.map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+        if shuffle_buffer_size:
+            dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
+        dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
+        dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
+        # TODO: check whether prefetch should come before batch, cf. https://keras.io/examples/keras_recipes/tfrecord/
+
+        return dataset
+
+    def read_one_sample(self, target_keys):
+        """
+        Read one tfrecord file matching the pattern and convert data to a tensorflow dataset.
+        :param target_keys: Keys of the targets, e.g. ['s2_out']
+        """
+        matching_files = glob.glob(self.tfrecords_pattern_path)
+        one_file = matching_files[0]
+        parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys)
+        dataset = tf.data.TFRecordDataset(one_file)
+        dataset = dataset.map(parse)
+        dataset = dataset.batch(1)
+
+        sample = iter(dataset).get_next()
+        return sample
\ No newline at end of file
diff --git a/python/tricks.py b/otbtf/tricks.py
similarity index 100%
rename from python/tricks.py
rename to otbtf/tricks.py
diff --git a/otbtf/utils.py b/otbtf/utils.py
new file mode 100644
index 00000000..920b0dc6
--- /dev/null
+++ b/otbtf/utils.py
@@ -0,0 +1,36 @@
+from osgeo import gdal
+import numpy as np
+
+# ----------------------------------------------------- Helpers --------------------------------------------------------
+
+def gdal_open(filename):
+    """
+    Open a GDAL raster
+    :param filename: raster file
+    :return: a GDAL dataset instance
+    """
+    gdal_ds = gdal.Open(filename)
+    if gdal_ds is None:
+        raise Exception("Unable to open file {}".format(filename))
+    return gdal_ds
+
+
+def read_as_np_arr(gdal_ds, as_patches=True):
+    """
+    Read a GDAL raster as numpy array
+    :param gdal_ds: a GDAL dataset instance
+    :param as_patches: if True, the returned numpy array has the following shape (n, psz_x, psz_x, nb_channels). If
+        False, the shape is (1, psz_y, psz_x, nb_channels)
+    :return: Numpy array of dim 4
+    """
+    buffer = gdal_ds.ReadAsArray()
+    size_x = gdal_ds.RasterXSize
+    if len(buffer.shape) == 3:
+        buffer = np.transpose(buffer, axes=(1, 2, 0))
+    if not as_patches:
+        n_elems = 1
+        size_y = gdal_ds.RasterYSize
+    else:
+        n_elems = int(gdal_ds.RasterYSize / size_x)
+        size_y = size_x
+    return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount)))
\ No newline at end of file
-- 
GitLab
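
For reference, a minimal usage sketch of the split layout: the Dataset classes
stay in otbtf.dataset, while the TFRecords conversion now lives in
otbtf.tfrecords. The output directory, batch size and "s2_out" target key below
are purely illustrative, and `dataset` is assumed to be an existing
otbtf.dataset.Dataset instance.

    from otbtf.tfrecords import TFRecords

    # Convert the patches dataset into sharded TFRecord files
    tfrecords = TFRecords("/tmp/tfrecords")
    tfrecords.ds2tfrecord(dataset, n_samples_per_shard=100)

    # Load the records back as a batched tf.data.Dataset for training
    tf_ds = tfrecords.read(batch_size=8, target_keys=["s2_out"])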


From e4f2f7ac31e61693741c03e25712e28053631ff0 Mon Sep 17 00:00:00 2001
From: Narcon Nicolas <nicolas.narcon@inrae.fr>
Date: Thu, 21 Apr 2022 17:21:38 +0200
Subject: [PATCH 02/16] ENH: add setup.py

---
 setup.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)
 create mode 100644 setup.py

diff --git a/setup.py b/setup.py
new file mode 100644
index 00000000..3a95ac4a
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+import setuptools
+
+with open("README.md", "r", encoding="utf-8") as fh:
+    long_description = fh.read()
+
+setuptools.setup(
+    name="otbtf",
+    version="0.1",
+    author="Remi Cresson",
+    author_email="remi.cresson@inrae.fr",
+    description="OTBTF: Orfeo ToolBox meets TensorFlow",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    url="https://gitlab.irstea.fr/remi.cresson/otbtf",
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.6",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "Topic :: Scientific/Engineering :: GIS",
+        "Topic :: Scientific/Engineering :: Image Processing",
+        "License :: OSI Approved :: Apache Software License",
+        "Operating System :: OS Independent",
+    ],
+    packages=setuptools.find_packages(),
+    python_requires=">=3.6",
+    keywords="remote sensing, otb, orfeotoolbox, orfeo toolbox, tensorflow, tf, deep learning, machine learning",
+)
\ No newline at end of file
-- 
GitLab
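
With setup.py in place, the package can be installed from the repository root,
e.g. in editable mode (commands shown as an example):

    pip install -e .
    python -c "from otbtf import dataset, tfrecords"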


From 9277f4eddd83a350364f0e68168371304bd36f9d Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Fri, 22 Apr 2022 18:18:12 +0200
Subject: [PATCH 03/16] CI: change remote cache server

---
 .gitlab-ci.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 798ece96..a855105f 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -52,6 +52,7 @@ test docker image:
       --build-arg KEEP_SRC_OTB="true"
       --build-arg BZL_CONFIGS=""
       --build-arg BASE_IMG="ubuntu:20.04"
+      --build-arg BZL_OPTIONS="--verbose_failures --remote_cache=http://172.17.0.1:9090"
       .
     - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME
 
-- 
GitLab


From 846f370086884406ba37beb55796953bc0dd05d9 Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Fri, 22 Apr 2022 20:11:00 +0200
Subject: [PATCH 04/16] CI: remove dummy scripts

---
 .gitlab-ci.yml | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index a855105f..c4c1568c 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -31,9 +31,6 @@ test docker image:
     - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY
   timeout: 10 hours
   script:
-    - ls -ll /
-    - ls -ll /bzl_cache/
-    - touch /bzl_cache/toto
     - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME ||
     - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test || 
     - >
-- 
GitLab


From 061fcb65e4036d3210fccb0dbf9e4f6a3959dad4 Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Fri, 22 Apr 2022 20:11:20 +0200
Subject: [PATCH 05/16] ENH: remove useless astype() conversion

---
 otbtf/dataset.py | 11 ++++-------
 otbtf/utils.py   | 21 ++++-----------------
 2 files changed, 8 insertions(+), 24 deletions(-)

diff --git a/otbtf/dataset.py b/otbtf/dataset.py
index fc108531..2350ed80 100644
--- a/otbtf/dataset.py
+++ b/otbtf/dataset.py
@@ -27,7 +27,7 @@ import logging
 from abc import ABC, abstractmethod
 import numpy as np
 import tensorflow as tf
-from otbtf.utils import read_as_np_arr, gdal_open, GDAL_TO_NP_TYPES
+from otbtf.utils import read_as_np_arr, gdal_open
 from otbtf.tfrecords import TFRecords
 
 
@@ -203,16 +203,13 @@ class PatchesImagesReader(PatchesReaderBase):
     def _read_extract_as_np_arr(gdal_ds, offset):
         assert gdal_ds is not None
         psz = gdal_ds.RasterXSize
-        gdal_type = gdal_ds.GetRasterBand(1).DataType
         yoff = int(offset * psz)
         assert yoff + psz <= gdal_ds.RasterYSize
         buffer = gdal_ds.ReadAsArray(0, yoff, psz, psz)
         if len(buffer.shape) == 3:
-            buffer = np.transpose(buffer, axes=(1, 2, 0))
-        else:  # single-band raster
-            buffer = np.expand_dims(buffer, axis=2)
-
-        return buffer.astype(GDAL_TO_NP_TYPES[gdal_type])
+            # multi-band raster
+            return np.transpose(buffer, axes=(1, 2, 0))
+        return np.expand_dims(buffer, axis=2)
 
     def get_sample(self, index):
         """
diff --git a/otbtf/utils.py b/otbtf/utils.py
index 28677ce7..f1e803d9 100644
--- a/otbtf/utils.py
+++ b/otbtf/utils.py
@@ -1,19 +1,6 @@
 from osgeo import gdal
 import numpy as np
 
-# --------------------------------------------- GDAL to numpy types ----------------------------------------------------
-
-
-GDAL_TO_NP_TYPES = {1: 'uint8',
-                    2: 'uint16',
-                    3: 'int16',
-                    4: 'uint32',
-                    5: 'int32',
-                    6: 'float32',
-                    7: 'float64',
-                    10: 'complex64',
-                    11: 'complex128'}
-
 
 # ----------------------------------------------------- Helpers --------------------------------------------------------
 
@@ -41,10 +28,10 @@ def read_as_np_arr(gdal_ds, as_patches=True):
     size_x = gdal_ds.RasterXSize
     if len(buffer.shape) == 3:
         buffer = np.transpose(buffer, axes=(1, 2, 0))
-    if not as_patches:
-        n_elems = 1
-        size_y = gdal_ds.RasterYSize
-    else:
+    if as_patches:
         n_elems = int(gdal_ds.RasterYSize / size_x)
         size_y = size_x
+    else:
+        n_elems = 1
+        size_y = gdal_ds.RasterYSize
     return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount)))
-- 
GitLab
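
The rationale: GDAL's ReadAsArray() already returns a numpy array whose dtype
matches the raster band type, so the explicit astype() cast through
GDAL_TO_NP_TYPES was redundant. A quick check (the filename is illustrative):

    from osgeo import gdal

    ds = gdal.Open("patches.tif")
    print(ds.ReadAsArray().dtype)  # e.g. float32 for a Float32 raster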


From 92fc8d04745c4cb30511ed85594eae85624a326c Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Fri, 22 Apr 2022 20:14:10 +0200
Subject: [PATCH 06/16] CI: job name

---
 .gitlab-ci.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index c4c1568c..89d8ed57 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -18,7 +18,7 @@ stages:
   - Applications Test
   - Ship
 
-test docker image:
+docker image:
   stage: Build
   allow_failure: false
   tags: [godzilla]
-- 
GitLab


From cd4d8fecef283e76b6ead1c6fb3620aad323a102 Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Fri, 22 Apr 2022 20:23:23 +0200
Subject: [PATCH 07/16] REFAC: TFRecords

---
 otbtf/tfrecords.py | 29 ++++-------------------------
 1 file changed, 4 insertions(+), 25 deletions(-)

diff --git a/otbtf/tfrecords.py b/otbtf/tfrecords.py
index 17fa51e9..889fd898 100644
--- a/otbtf/tfrecords.py
+++ b/otbtf/tfrecords.py
@@ -53,20 +53,11 @@ class TFRecords:
         if not drop_remainder and dataset.size % n_samples_per_shard > 0:
             nb_shards += 1
 
-        self.convert_dataset_output_shapes(dataset)
-
-        def _convert_data(data):
-            """
-            Convert data
-            """
-            data_converted = {}
-
-            for k, d in data.items():
-                data_converted[k] = d.name
-
-            return data_converted
+        output_shapes = {key: (None,) + output_shape for key, output_shape in dataset.output_shapes.items()}
+        self.save(output_shapes, self.output_shape_file)
 
-        self.save(_convert_data(dataset.output_types), self.output_types_file)
+        output_types = {key: output_type.name for key, output_type in dataset.output_types.items()}
+        self.save(output_types, self.output_types_file)
 
         for i in tqdm(range(nb_shards)):
 
@@ -106,18 +97,6 @@ class TFRecords:
         with open(filepath, 'r') as f:
             return json.load(f)
 
-    def convert_dataset_output_shapes(self, dataset):
-        """
-        Convert and save numpy shape to tensorflow shape.
-        :param dataset: Dataset object containing output shapes
-        """
-        output_shapes = {}
-
-        for key in dataset.output_shapes.keys():
-            output_shapes[key] = (None,) + dataset.output_shapes[key]
-
-        self.save(output_shapes, self.output_shape_file)
-
     @staticmethod
     def parse_tfrecord(example, features_types, target_keys, preprocessing_fn=None, **kwargs):
         """
-- 
GitLab


From 4faa71a41120058d919f038761f0b1dd0f018343 Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Fri, 22 Apr 2022 21:53:56 +0200
Subject: [PATCH 08/16] CI: fix remote cache address

---
 .gitlab-ci.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 89d8ed57..d3daf916 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -49,7 +49,7 @@ docker image:
       --build-arg KEEP_SRC_OTB="true"
       --build-arg BZL_CONFIGS=""
       --build-arg BASE_IMG="ubuntu:20.04"
-      --build-arg BZL_OPTIONS="--verbose_failures --remote_cache=http://172.17.0.1:9090"
+      --build-arg BZL_OPTIONS="--verbose_failures --remote_cache=172.17.0.1:9090"
       .
     - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME
 
-- 
GitLab


From 6809cd62f1136327af4e55d07f6965f53647d673 Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Fri, 22 Apr 2022 21:55:25 +0200
Subject: [PATCH 09/16] CI: remove labels

---
 .gitlab-ci.yml | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index d3daf916..3cb141be 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -39,11 +39,6 @@ docker image:
       --network="host"
       --cache-from $CI_REGISTRY_IMAGE:cpu-basic-test
       --cache-from $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME
-      --label "org.opencontainers.image.title=$CI_PROJECT_TITLE"
-      --label "org.opencontainers.image.url=$CI_PROJECT_URL"
-      --label "org.opencontainers.image.created=$CI_JOB_STARTED_AT"
-      --label "org.opencontainers.image.revision=$CI_COMMIT_SHA"
-      --label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME"
       --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME
       --build-arg OTBTESTS="true"
       --build-arg KEEP_SRC_OTB="true"
-- 
GitLab


From 6b4e114bd3d945b854b2a8abc1b0ba75177bbfcb Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Fri, 22 Apr 2022 22:04:51 +0200
Subject: [PATCH 10/16] CI: patch dockerfile

---
 .gitlab-ci.yml | 2 --
 Dockerfile     | 8 ++++----
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 3cb141be..6c738a64 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -35,7 +35,6 @@ docker image:
     - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test || 
     - >
       docker build
-      --pull
       --network="host"
       --cache-from $CI_REGISTRY_IMAGE:cpu-basic-test
       --cache-from $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME
@@ -44,7 +43,6 @@ docker image:
       --build-arg KEEP_SRC_OTB="true"
       --build-arg BZL_CONFIGS=""
       --build-arg BASE_IMG="ubuntu:20.04"
-      --build-arg BZL_OPTIONS="--verbose_failures --remote_cache=172.17.0.1:9090"
       .
     - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME
 
diff --git a/Dockerfile b/Dockerfile
index 223a14c6..e1c48f6f 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -148,7 +148,7 @@ COPY --from=builder /src /src
 # System-wide ENV
 ENV PATH="/opt/otbtf/bin:$PATH"
 ENV LD_LIBRARY_PATH="/opt/otbtf/lib:$LD_LIBRARY_PATH"
-ENV PYTHONPATH="/opt/otbtf/lib/python3/site-packages:/opt/otbtf/lib/otb/python:/src/otbtf/python"
+ENV PYTHONPATH="/opt/otbtf/lib/python3/site-packages:/opt/otbtf/lib/otb/python:/src/otbtf/otbtf"
 ENV OTB_APPLICATION_PATH="/opt/otbtf/lib/otb/applications"
 
 # Default user, directory and command (bash is the entrypoint when using 'docker create')
@@ -169,6 +169,6 @@ USER otbuser
 # User-only ENV
 
 # Test python imports
-RUN python -c "import tensorflow"
-RUN python -c "import otbtf, tricks"
-RUN python -c "import otbApplication as otb; otb.Registry.CreateApplication('ImageClassifierFromDeepFeatures')"
+#RUN python -c "import tensorflow"
+#RUN python -c "import otbtf, tricks"
+#RUN python -c "import otbApplication as otb; otb.Registry.CreateApplication('ImageClassifierFromDeepFeatures')"
-- 
GitLab


From a4a78ed154e56e67d30a7aed23b4118fe4563fa9 Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Fri, 22 Apr 2022 22:19:54 +0200
Subject: [PATCH 11/16] FIX: update PYTHONPATH

---
 Dockerfile                           | 8 ++++----
 {otbtf => tricks}/ckpt2savedmodel.py | 0
 {otbtf => tricks}/tricks.py          | 0
 3 files changed, 4 insertions(+), 4 deletions(-)
 rename {otbtf => tricks}/ckpt2savedmodel.py (100%)
 rename {otbtf => tricks}/tricks.py (100%)

diff --git a/Dockerfile b/Dockerfile
index e1c48f6f..d5a644f7 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -148,7 +148,7 @@ COPY --from=builder /src /src
 # System-wide ENV
 ENV PATH="/opt/otbtf/bin:$PATH"
 ENV LD_LIBRARY_PATH="/opt/otbtf/lib:$LD_LIBRARY_PATH"
-ENV PYTHONPATH="/opt/otbtf/lib/python3/site-packages:/opt/otbtf/lib/otb/python:/src/otbtf/otbtf"
+ENV PYTHONPATH="/opt/otbtf/lib/python3/site-packages:/opt/otbtf/lib/otb/python:/src/otbtf"
 ENV OTB_APPLICATION_PATH="/opt/otbtf/lib/otb/applications"
 
 # Default user, directory and command (bash is the entrypoint when using 'docker create')
@@ -169,6 +169,6 @@ USER otbuser
 # User-only ENV
 
 # Test python imports
-#RUN python -c "import tensorflow"
-#RUN python -c "import otbtf, tricks"
-#RUN python -c "import otbApplication as otb; otb.Registry.CreateApplication('ImageClassifierFromDeepFeatures')"
+RUN python -c "import tensorflow"
+RUN python -c "import otbtf, tricks"
+RUN python -c "import otbApplication as otb; otb.Registry.CreateApplication('ImageClassifierFromDeepFeatures')"
diff --git a/otbtf/ckpt2savedmodel.py b/tricks/ckpt2savedmodel.py
similarity index 100%
rename from otbtf/ckpt2savedmodel.py
rename to tricks/ckpt2savedmodel.py
diff --git a/otbtf/tricks.py b/tricks/tricks.py
similarity index 100%
rename from otbtf/tricks.py
rename to tricks/tricks.py
-- 
GitLab


From d81dc370e6957b2b430021dbbe3eb8989ac03313 Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Fri, 22 Apr 2022 22:20:13 +0200
Subject: [PATCH 12/16] ADD: update __init__

---
 otbtf/__init__.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/otbtf/__init__.py b/otbtf/__init__.py
index e69de29b..8a6951a6 100644
--- a/otbtf/__init__.py
+++ b/otbtf/__init__.py
@@ -0,0 +1,26 @@
+# -*- coding: utf-8 -*-
+# ==========================================================================
+#
+#   Copyright 2018-2019 IRSTEA
+#   Copyright 2020-2022 INRAE
+#
+#   Licensed under the Apache License, Version 2.0 (the "License");
+#   you may not use this file except in compliance with the License.
+#   You may obtain a copy of the License at
+#
+#          http://www.apache.org/licenses/LICENSE-2.0.txt
+#
+#   Unless required by applicable law or agreed to in writing, software
+#   distributed under the License is distributed on an "AS IS" BASIS,
+#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#   See the License for the specific language governing permissions and
+#   limitations under the License.
+#
+# ==========================================================================*/
+"""
+OTBTF python module
+"""
+from utils import read_as_np_arr, gdal_open
+from dataset import Buffer, PatchesReaderBase, PatchesImagesReader, IteratorBase, RandomIterator, Dataset, \
+    DatasetFromPatchesImages
+from tfrecords import TFRecords
\ No newline at end of file
-- 
GitLab


From 07e436997330fc7bd8c52f312ae657a34bacca44 Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Fri, 22 Apr 2022 22:20:28 +0200
Subject: [PATCH 13/16] FIX: tricks module imports

---
 tricks/tricks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tricks/tricks.py b/tricks/tricks.py
index b31b14c3..d22e7e96 100644
--- a/tricks/tricks.py
+++ b/tricks/tricks.py
@@ -25,7 +25,7 @@ for TF 1.X versions.
 """
 import tensorflow.compat.v1 as tf
 from deprecated import deprecated
-from otbtf import gdal_open, read_as_np_arr as read_as_np_arr_from_gdal_ds
+from otbtf.utils import gdal_open, read_as_np_arr as read_as_np_arr_from_gdal_ds
 tf.disable_v2_behavior()
 
 
-- 
GitLab


From ca5c11caa9edc50847d124112fc07f6cbee9dd6b Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Fri, 22 Apr 2022 22:20:53 +0200
Subject: [PATCH 14/16] ADD: .idea

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index a29689f2..1ef65aa1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 # Compiled python source #
 *.pyc
+.idea
-- 
GitLab


From 95658a6a85653ff0ec66eb10502e1d55f791e140 Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Fri, 22 Apr 2022 22:27:47 +0200
Subject: [PATCH 15/16] FIX: imports

---
 otbtf/__init__.py         | 6 +++---
 tricks/ckpt2savedmodel.py | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/otbtf/__init__.py b/otbtf/__init__.py
index 8a6951a6..77f806b8 100644
--- a/otbtf/__init__.py
+++ b/otbtf/__init__.py
@@ -20,7 +20,7 @@
 """
 OTBTF python module
 """
-from utils import read_as_np_arr, gdal_open
-from dataset import Buffer, PatchesReaderBase, PatchesImagesReader, IteratorBase, RandomIterator, Dataset, \
+from otbtf.utils import read_as_np_arr, gdal_open
+from otbtf.dataset import Buffer, PatchesReaderBase, PatchesImagesReader, IteratorBase, RandomIterator, Dataset, \
     DatasetFromPatchesImages
-from tfrecords import TFRecords
\ No newline at end of file
+from otbtf.tfrecords import TFRecords
\ No newline at end of file
diff --git a/tricks/ckpt2savedmodel.py b/tricks/ckpt2savedmodel.py
index 117203ba..ff22965f 100755
--- a/tricks/ckpt2savedmodel.py
+++ b/tricks/ckpt2savedmodel.py
@@ -26,7 +26,7 @@ can be more conveniently exported as SavedModel (see how to build a model with
 keras in Tensorflow 2).
 """
 import argparse
-from tricks import ckpt_to_savedmodel
+from tricks.tricks import ckpt_to_savedmodel
 
 
 def main():
-- 
GitLab
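
With the package-absolute imports fixed, the top-level re-exports behave as
intended, so user code no longer needs to know which submodule a symbol lives
in, e.g.:

    from otbtf import DatasetFromPatchesImages, TFRecords, gdal_open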


From 1a22e0dc14ced66a0fb1ebc04bb3f107de77dc62 Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Fri, 22 Apr 2022 22:30:19 +0200
Subject: [PATCH 16/16] FIX: imports

---
 tricks/{tricks.py => __init__.py} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename tricks/{tricks.py => __init__.py} (100%)

diff --git a/tricks/tricks.py b/tricks/__init__.py
similarity index 100%
rename from tricks/tricks.py
rename to tricks/__init__.py
-- 
GitLab
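
After this rename the helper functions are exposed directly by the tricks
package, e.g. (a sketch; ckpt_to_savedmodel is the deprecated TF1 checkpoint
export helper referenced above):

    from tricks import ckpt_to_savedmodel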