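"""Compute instance segmentation metrics for every matched pair of .las files in two folders.

Ground-truth and predicted .las files are paired by file name, empty files are skipped,
per-pair metrics are computed in parallel with joblib, and per-file and summary results
can optionally be written to CSV. The mean weighted F1 score is returned.
"""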
import csv
import glob
import os
import argparse
from joblib import Parallel, delayed
import laspy

from metrics.instance_segmentation_metrics import InstanceSegmentationMetrics
from nibio_postprocessing.attach_labels_to_las_file import AttachLabelsToLasFile

class InstanceSegmentationMetricsInFolder():
    GT_LABEL_NAME = 'treeID'
    TARGET_LABEL_NAME = 'instance_nr'

    def __init__(
        self,
        gt_las_folder_path,
        target_las_folder_path,
        output_folder_path=None,  # if None, no output files (updated las or CSV) are written
        remove_ground=False,
        verbose=False
    ):
        self.gt_las_folder_path = gt_las_folder_path
        self.target_las_folder_path = target_las_folder_path
        self.output_folder_path = output_folder_path
        # create the output folder if it does not exist
        if self.output_folder_path is not None:
            os.makedirs(self.output_folder_path, exist_ok=True)
        self.remove_ground = remove_ground
        self.verbose = verbose

    def main(self):
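        """Match gt and target las files by name, compute metrics for each pair in
        parallel, optionally write a summary CSV, and return the mean weighted F1 score."""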
        # get all las files in the gt_las_folder_path, sorted by name
        gt_las_file_paths = sorted(glob.glob(os.path.join(self.gt_las_folder_path, '*.las')))

        # get all las files in the target_las_folder_path, sorted by name
        target_las_file_paths = sorted(glob.glob(os.path.join(self.target_las_folder_path, '*.las')))

        # check that the number of las files in the gt_las_folder_path and target_las_folder_path are the same
        if len(gt_las_file_paths) != len(target_las_file_paths):
            # print names of the folders
            print('gt_las_folder_path: ' + self.gt_las_folder_path)
            print('target_las_folder_path: ' + self.target_las_folder_path)
            print('Number of files in gt_las_folder_path: ' + str(len(gt_las_file_paths)))
            print('Number of files in target_las_folder_path: ' + str(len(target_las_file_paths)))
            raise Exception('The number of las files in the gt_las_folder_path and target_las_folder_path are not the same')

        # skip pairs where either las file is empty; build filtered copies instead of
        # removing items from the lists while iterating over them (which can skip entries)
        filtered_gt_paths = []
        filtered_target_paths = []
        for gt_las_file_path, target_las_file_path in zip(gt_las_file_paths, target_las_file_paths):
            gt_las_file = laspy.read(gt_las_file_path)
            target_las_file = laspy.read(target_las_file_path)
            if len(gt_las_file.points) == 0 or len(target_las_file.points) == 0:
                if self.verbose:
                    print('Skipped empty las file pair: ' + gt_las_file_path + ' and ' + target_las_file_path)
                continue
            filtered_gt_paths.append(gt_las_file_path)
            filtered_target_paths.append(target_las_file_path)
        gt_las_file_paths = filtered_gt_paths
        target_las_file_paths = filtered_target_paths

        # match the core name in the gt_las_file_path and target_las_file_path and make tuples of the matched paths
        matched_paths = []
        for gt_las_file_path, target_las_file_path in zip(gt_las_file_paths, target_las_file_paths):

            # print what files are being matched and processed
            if self.verbose:
                print('Matching: ' + gt_las_file_path + ' and ' + target_las_file_path)

            # get the core name of the gt_las_file_path
            gt_las_file_core_name = os.path.basename(gt_las_file_path).split('.')[0]
            # get the core name of the target_las_file_path
            target_las_file_core_name = os.path.basename(target_las_file_path).split('.')[0]

            # check that the core name of the gt_las_file_path and target_las_file_path are the same
            if gt_las_file_core_name == target_las_file_core_name:
                # make a tuple of the matched paths
                matched_paths.append((gt_las_file_path, target_las_file_path)) 

        # check that all files were matched; if not, raise an exception
        if len(matched_paths) != len(gt_las_file_paths):
            raise Exception('Not all las files in the gt_las_folder_path and target_las_folder_path are matched')

        # run the instance segmentation metrics for each matched las file
        metric_dict_list = []
        f1_scores_weighted_list = []

        parallel_output = Parallel(n_jobs=-1, verbose=0)(
            delayed(self.compute_metrics)(gt_las_file_path, target_las_file_path)
            for gt_las_file_path, target_las_file_path in matched_paths
        )

        # extract metric_dict_list and f1_scores_weighted_list from the parallel output
        for metric_dict_mean, f1_score_weighted in parallel_output:
            metric_dict_list.append(metric_dict_mean)
            f1_scores_weighted_list.append(f1_score_weighted)

        # serial version of the above code, kept for reference:
        # for gt_las_file_path, target_las_file_path in matched_paths:
        #     metric_dict, f1_score_weighted = self.compute_metrics(gt_las_file_path, target_las_file_path)
        #     metric_dict_list.append(metric_dict)
        #     f1_scores_weighted_list.append(f1_score_weighted)

        # guard against the case where no file pairs were processed at all
        if not f1_scores_weighted_list:
            raise Exception('No las file pairs were processed, cannot compute mean metrics')

        # calculate the mean of the weighted f1 scores
        mean_f1_score = sum(f1_scores_weighted_list) / len(f1_scores_weighted_list)

        # sum each metric over all entries in metric_dict_list
        mean_metrics = {}
        for metric_dict in metric_dict_list:
            for key, value in metric_dict.items():
                mean_metrics[key] = mean_metrics.get(key, 0) + value

        # divide the summed metrics by the number of processed files
        for key in mean_metrics:
            mean_metrics[key] = mean_metrics[key] / len(metric_dict_list)
   
        if self.output_folder_path is not None:
            # create the output folder path
            save_to_csv_path = os.path.join(self.output_folder_path, 'summary_metrics_all_plots.csv')
            # save the mean metrics to a csv file
            with open(save_to_csv_path, 'w', newline='') as csv_file:
                writer = csv.writer(csv_file)
                for key, value in mean_metrics.items():
                    # round the value to 3 decimal places
                    value = round(value, 3)
                    writer.writerow([key, value])
                    
        if self.verbose:
            print('Mean F1 Score: {}'.format(mean_f1_score))
            # print the mean metrics
            print('Mean Metrics: {}'.format(mean_metrics))

        return mean_f1_score

    def compute_metrics(self, gt_las_file_path, target_las_file_path):
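        """Compute metrics for a single gt/target las file pair.

        Returns a tuple (metric_dict_mean, f1_score_weighted); when an output folder is
        set, also writes an updated .las file and a per-file CSV for the pair."""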
        # get the core name of the gt_las_file_path
        gt_las_file_core_name = os.path.basename(gt_las_file_path).split('.')[0]
        # get the core name of the target_las_file_path
        target_las_file_core_name = os.path.basename(target_las_file_path).split('.')[0]

        # check that the core name of the gt_las_file_path and target_las_file_path are the same
        if gt_las_file_core_name == target_las_file_core_name:
            if self.verbose:
                print('Processing: ' + gt_las_file_path + ' and ' + target_las_file_path)
    
            if self.output_folder_path is not None:
                # create the output folder path
                save_to_csv_path = os.path.join(self.output_folder_path, gt_las_file_core_name + '.csv')
                # attach labels to the las file
                AttachLabelsToLasFile(
                    gt_las_file_path,
                    target_las_file_path,
                    update_las_file_path = os.path.join(self.output_folder_path, gt_las_file_core_name + '.las'),
                    gt_label_name=self.GT_LABEL_NAME,
                    target_label_name=self.TARGET_LABEL_NAME,
                    verbose=self.verbose
                ).main()

            else:
                save_to_csv_path = None

            # run the instance segmentation metrics
            instance_segmentation_metrics = InstanceSegmentationMetrics(
                gt_las_file_path,
                target_las_file_path,
                remove_ground=self.remove_ground,
                csv_file_name=save_to_csv_path,
                verbose=self.verbose
            )
            metric_dict, metric_dict_weighted_by_tree_height, metric_dict_mean = instance_segmentation_metrics.main()
            f1_score_weighted = metric_dict_mean['f1_score']
        else:
            raise Exception(
                'Core names do not match: ' + gt_las_file_path + ' and ' + target_las_file_path
            )
        return metric_dict_mean, f1_score_weighted

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gt_las_folder_path', type=str, required=True)
    parser.add_argument('--target_las_folder_path', type=str, required=True)
    parser.add_argument('--output_folder_path', type=str, required=False, default=None)
    parser.add_argument('--remove_ground', action='store_true', help="Do not take into account the ground (class 0).", default=False)
    parser.add_argument('--verbose', action='store_true', help="Print information about the process")
    args = parser.parse_args()

    # run the instance segmentation metrics in folder
    instance_segmentation_metrics_in_folder = InstanceSegmentationMetricsInFolder(
        gt_las_folder_path=args.gt_las_folder_path,
        target_las_folder_path=args.target_las_folder_path,
        output_folder_path=args.output_folder_path,
        remove_ground=args.remove_ground,
        verbose=args.verbose
    )

    mean_f1_score = instance_segmentation_metrics_in_folder.main()
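
    # Example invocation (the folder paths below are only illustrative):
    #   python instance_segmentation_metrics_in_folder.py \
    #       --gt_las_folder_path /path/to/gt_las \
    #       --target_las_folder_path /path/to/predicted_las \
    #       --output_folder_path /path/to/output --verbose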