Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
instance_segmentation_classic
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Maciej Wielgosz
instance_segmentation_classic
Commits
781bd59b
Commit
781bd59b
authored
2 years ago
by
Maciej Wielgosz
Browse files
Options
Downloads
Patches
Plain Diff
updated instance segmentation metrics with params
parent
e797bc8d
No related branches found
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
metrics/instance_segmentation_metrics.py
+11
-6
11 additions, 6 deletions
metrics/instance_segmentation_metrics.py
metrics/instance_segmentation_metrics_in_folder.py
+5
-2
5 additions, 2 deletions
metrics/instance_segmentation_metrics_in_folder.py
with
16 additions
and
8 deletions
metrics/instance_segmentation_metrics.py
+
11
−
6
View file @
781bd59b
...
@@ -9,6 +9,8 @@ from sklearn.neighbors import KDTree
...
@@ -9,6 +9,8 @@ from sklearn.neighbors import KDTree
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logging
.
basicConfig
(
level
=
logging
.
INFO
)
class
InstanceSegmentationMetrics
:
class
InstanceSegmentationMetrics
:
GT_LABEL_NAME
=
'
treeID
'
#GT_LABEL_NAME = 'StemID'
TARGET_LABEL_NAME
=
'
instance_nr
'
def
__init__
(
def
__init__
(
self
,
self
,
input_file_path
,
input_file_path
,
...
@@ -27,13 +29,16 @@ class InstanceSegmentationMetrics:
...
@@ -27,13 +29,16 @@ class InstanceSegmentationMetrics:
self
.
input_las
=
laspy
.
read
(
self
.
input_file_path
)
self
.
input_las
=
laspy
.
read
(
self
.
input_file_path
)
self
.
instance_segmented_las
=
laspy
.
read
(
self
.
instance_segmented_file_path
)
self
.
instance_segmented_las
=
laspy
.
read
(
self
.
instance_segmented_file_path
)
self
.
skip_flag
=
self
.
check_if_labels_exist
()
self
.
skip_flag
=
self
.
check_if_labels_exist
(
X_label
=
self
.
GT_LABEL_NAME
,
Y_label
=
self
.
TARGET_LABEL_NAME
)
if
not
self
.
skip_flag
:
if
not
self
.
skip_flag
:
# get labels from input las file
# get labels from input las file
self
.
X_labels
=
self
.
input_las
.
treeID
.
astype
(
int
)
#TODO: generalize this to other labels
self
.
X_labels
=
self
.
input_las
[
self
.
GT_LABEL_NAME
].
astype
(
int
)
# get labels from instance segmented las file
# get labels from instance segmented las file
self
.
Y_labels
=
self
.
instance_segmented_las
.
instance_nr
.
astype
(
int
)
#TODO: generalize this to other labels
self
.
Y_labels
=
self
.
instance_segmented_las
[
self
.
TARGET_LABEL_NAME
].
astype
(
int
)
# if self.remove_ground:
# if self.remove_ground:
# # the labeling starts from 0, so we need to remove the ground
# # the labeling starts from 0, so we need to remove the ground
# self.Y_labels += 1
# self.Y_labels += 1
...
@@ -126,12 +131,12 @@ class InstanceSegmentationMetrics:
...
@@ -126,12 +131,12 @@ class InstanceSegmentationMetrics:
# define a function that finds class in input_file with the most points
# define a function that finds class in input_file with the most points
def
find_dominant_classes_in_gt
(
self
,
input_file
):
def
find_dominant_classes_in_gt
(
self
,
input_file
):
# get the unique labels
# get the unique labels
unique_labels
=
np
.
unique
(
input_file
.
treeID
).
astype
(
int
)
unique_labels
=
np
.
unique
(
input_file
[
self
.
GT_LABEL_NAME
]
).
astype
(
int
)
# create a dictionary
# create a dictionary
tmp_dict
=
{}
tmp_dict
=
{}
for
label
in
unique_labels
:
for
label
in
unique_labels
:
# get the indices of input_file.treeID == label
# get the indices of input_file.treeID == label
ind_label
=
np
.
where
(
input_file
.
treeID
==
label
)[
0
]
ind_label
=
np
.
where
(
input_file
[
self
.
GT_LABEL_NAME
]
==
label
)[
0
]
# put the number of points to the tmp_dict
# put the number of points to the tmp_dict
tmp_dict
[
str
(
label
)]
=
ind_label
.
shape
[
0
]
tmp_dict
[
str
(
label
)]
=
ind_label
.
shape
[
0
]
# sort tmp_dict by the number of points
# sort tmp_dict by the number of points
...
@@ -331,7 +336,7 @@ class InstanceSegmentationMetrics:
...
@@ -331,7 +336,7 @@ class InstanceSegmentationMetrics:
# compute tree level metrics
# compute tree level metrics
if
metric_dict
:
if
metric_dict
:
# get the number of trees in the ground truth
# get the number of trees in the ground truth
gt_trees
=
np
.
unique
(
self
.
input_las
.
treeID
)
gt_trees
=
np
.
unique
(
self
.
input_las
[
self
.
GT_LABEL_NAME
]
)
# remove 0 from gt_trees
# remove 0 from gt_trees
gt_trees
=
gt_trees
[
gt_trees
!=
0
]
gt_trees
=
gt_trees
[
gt_trees
!=
0
]
...
...
This diff is collapsed.
Click to expand it.
metrics/instance_segmentation_metrics_in_folder.py
+
5
−
2
View file @
781bd59b
...
@@ -9,6 +9,9 @@ from metrics.instance_segmentation_metrics import InstanceSegmentationMetrics
...
@@ -9,6 +9,9 @@ from metrics.instance_segmentation_metrics import InstanceSegmentationMetrics
from
nibio_postprocessing.attach_labels_to_las_file
import
AttachLabelsToLasFile
from
nibio_postprocessing.attach_labels_to_las_file
import
AttachLabelsToLasFile
class
InstanceSegmentationMetricsInFolder
():
class
InstanceSegmentationMetricsInFolder
():
GT_LABEL_NAME
=
'
treeID
'
TARGET_LABEL_NAME
=
'
instance_nr
'
def
__init__
(
def
__init__
(
self
,
self
,
gt_las_folder_path
,
gt_las_folder_path
,
...
@@ -161,8 +164,8 @@ class InstanceSegmentationMetricsInFolder():
...
@@ -161,8 +164,8 @@ class InstanceSegmentationMetricsInFolder():
gt_las_file_path
,
gt_las_file_path
,
target_las_file_path
,
target_las_file_path
,
update_las_file_path
=
os
.
path
.
join
(
self
.
output_folder_path
,
gt_las_file_core_name
+
'
.las
'
),
update_las_file_path
=
os
.
path
.
join
(
self
.
output_folder_path
,
gt_las_file_core_name
+
'
.las
'
),
gt_label_name
=
'
treeID
'
,
gt_label_name
=
self
.
GT_LABEL_NAME
,
target_label_name
=
'
treeID
'
,
target_label_name
=
self
.
GT_LABEL_NAME
,
verbose
=
self
.
verbose
verbose
=
self
.
verbose
).
main
()
).
main
()
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment