From 4fe33e0cf38527f1247a9d5ccfdb09410eb6ca66 Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Thu, 26 Jan 2023 11:09:45 +0100
Subject: [PATCH] oracle wrapper - reading input and outputs from env vars

---
 run_bash_scripts/tls.sh |  2 +-
 run_oracle_wrapper.py   | 37 ++++++++++++++++++++++++++++---------
 2 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/run_bash_scripts/tls.sh b/run_bash_scripts/tls.sh
index 9f76038..2e3a19a 100755
--- a/run_bash_scripts/tls.sh
+++ b/run_bash_scripts/tls.sh
@@ -3,7 +3,7 @@
 ############################ parameters #################################################
 # General parameters
 CLEAR_INPUT_FOLDER=1  # 1: clear input folder, 0: not clear input folder
-CONDA_ENV="pdal-env-1" # conda environment for running the pipeline
+CONDA_ENV="pdal-env" # conda environment for running the pipeline
 
 # Tiling parameters
 data_folder="" # path to the folder containing the data
diff --git a/run_oracle_wrapper.py b/run_oracle_wrapper.py
index 3964a08..418709a 100644
--- a/run_oracle_wrapper.py
+++ b/run_oracle_wrapper.py
@@ -27,30 +27,49 @@ def run_oracle_wrapper(path_to_config_file):
 
     # read system environment variables
     input_location = os.environ['OBJ_INPUT_LOCATION']
+    output_location = os.environ['OBJ_OUTPUT_LOCATION']
 
+    # doing for the input
     if input_location is not None:
         print('Taking the input from the location ' + input_location)
         parsed_url = urlparse(input_location)
         input_folder_in_bucket = parsed_url.path[1:]
-        bucket_name = parsed_url.netloc.split('@')[0]
-        namespace = parsed_url.netloc.split('@')[1]
+        input_bucket_name = parsed_url.netloc.split('@')[0]
+        input_namespace = parsed_url.netloc.split('@')[1]
 
     else:
         print('Taking the input from the default location')
-        # get the namespace
-        namespace = client.get_namespace().data
+        # get the input_namespace
+        input_namespace = client.get_input_namespace().data
         # get the bucket name
-        bucket_name = 'bucket_lidar_data'
+        input_bucket_name = 'bucket_lidar_data'
         # folder name inside the bucket
         input_folder_in_bucket = 'geoslam'
 
+    # doing for the output
+    if output_location is not None:
+        print('Saving the output to the location ' + output_location)
+        parsed_url = urlparse(output_location)
+        output_folder_in_bucket = parsed_url.path[1:]
+        output_bucket_name = parsed_url.netloc.split('@')[0]
+        output_namespace = parsed_url.netloc.split('@')[1]
+
+    else:
+        print('Saving the output to the default location')
+        # get the output_namespace
+        output_namespace = client.get_input_namespace().data
+        # get the bucket name
+        output_bucket_name = 'bucket_lidar_data'
+        # folder name inside the bucket
+        output_folder_in_bucket = 'output'
+
     # read the config file from config folder
     with open(path_to_config_file) as f:
         config_flow_params = yaml.load(f, Loader=yaml.FullLoader)
 
     # copy all files from the bucket to the input folder
     # get the list of objects in the bucket
-    objects = client.list_objects(namespace, bucket_name).data.objects
+    objects = client.list_objects(input_namespace, input_bucket_name).data.objects
 
     # create the input folder if it does not exist
     if not os.path.exists(config_flow_params['general']['input_folder']):
@@ -62,10 +81,10 @@ def run_oracle_wrapper(path_to_config_file):
             if not (item.name.split('/')[1] == ''):
                 object_name = item.name.split('/')[1]
 
-                print('Downloading the file ' + object_name + ' from the bucket ' + bucket_name)
+                print('Downloading the file ' + object_name + ' from the bucket ' + input_bucket_name)
                 path_to_object = os.path.join(input_folder_in_bucket, object_name)
                 # get the object
-                file = client.get_object(namespace, bucket_name, path_to_object)
+                file = client.get_object(input_namespace, input_bucket_name, path_to_object)
 
                 # write the object to a file
                 with open(object_name, 'wb') as f:
@@ -95,7 +114,7 @@ def run_oracle_wrapper(path_to_config_file):
         file_name = file
 
         # upload the file to the bucket
-        client.put_object(namespace, bucket_name, 'output/' + file_name, io.open(path_to_file, 'rb'))
+        client.put_object(output_namespace, output_bucket_name, os.path.join(output_folder_in_bucket, file_name), io.open(path_to_file, 'rb'))
 
 if __name__ == '__main__':
     # use argparse to get the path to the config file
-- 
GitLab