Skip to content
Snippets Groups Projects
Commit 7a0233aa authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

bayes_opt pipeline has been implemented instead of wandb

parent d3aa8935
No related branches found
No related tags found
No related merge requests found
from bayes_opt import BayesianOptimization
from bayes_opt.logger import JSONLogger
from bayes_opt.event import Events
from pipeline_test_command_params_just_tls import RunCommand
from pipeline_test_command_params_just_tls import main as pipeline_main
def bayes_opt_main(
    n_tiles,
    slice_thickness,
    find_stems_height,
    find_stems_thickness,
    graph_maximum_cumulative_gap,
    add_leaves_voxel_length,
    find_stems_min_points,
    graph_edge_length,
    add_leaves_edge_length
):
    """Objective function passed to the Bayesian optimizer.

    bayes_opt samples every dimension of the search space as a float, so
    the integer-valued parameters (``n_tiles`` and
    ``find_stems_min_points``) are cast back to ``int`` before the
    pipeline is invoked.  Returns whatever score ``pipeline_main``
    produces (the value the optimizer maximizes).
    """
    return pipeline_main(
        int(n_tiles),
        slice_thickness,
        find_stems_height,
        find_stems_thickness,
        graph_maximum_cumulative_gap,
        add_leaves_voxel_length,
        int(find_stems_min_points),
        graph_edge_length,
        add_leaves_edge_length,
    )
# Search space for the sweep: each key maps to an inclusive (lower, upper)
# bound. NOTE(review): 'n_tiles' uses the degenerate interval (3, 3), i.e.
# it is effectively pinned to a single value.
pbounds = {
'n_tiles': (3, 3),
'slice_thickness': (0.25, 0.75),
'find_stems_height': (0.5, 2.0),
'find_stems_thickness': (0.1, 1.0),
'graph_maximum_cumulative_gap': (5, 20),
'add_leaves_voxel_length': (0.1, 0.5),
'find_stems_min_points': (50, 500),
'graph_edge_length': (0.1, 2.0),
'add_leaves_edge_length': (0.2, 1.5)
}
# Maximize the pipeline score over pbounds. allow_duplicate_points avoids
# a crash if the optimizer re-samples an already-probed point.
optimizer = BayesianOptimization(
f=bayes_opt_main,
pbounds=pbounds,
random_state=1,
allow_duplicate_points=True
)
# Persist every optimization step as JSON lines for later analysis
# (consumed by FindParamImportance below).
logger = JSONLogger(path="./bayes_opt_run_logs.json")
optimizer.subscribe(Events.OPTIMIZATION_STEP, logger)
# 5 random warm-up probes, then 100 Bayesian iterations. NOTE(review):
# this runs at import time; consider guarding with __name__ == "__main__".
optimizer.maximize(
init_points=5,
n_iter=100
)
# partially based on : https://medium.com/analytics-vidhya/feature-importance-explained-bfc8d874bcf
import json
import matplotlib.pyplot as plt
import pandas as pd
import argparse
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler
class FindParamImportance:
    """Estimate parameter importance from Bayesian-optimization run logs.

    Fits a linear regression of the optimization target on the
    standardized parameter values and uses the absolute regression
    coefficients as an importance score, then renders the scores as a
    horizontal bar chart.
    """

    def __init__(
        self,
        logs_json_file,
        plot_file_path='feature_importance.png',
        verbose=False
    ) -> None:
        # Path to a JSON-lines log file (one optimization step per line,
        # each with 'target' and 'params' keys).
        self.logs_json_file = logs_json_file
        # Destination file for the importance bar chart.
        self.plot_file_path = plot_file_path
        self.verbose = verbose

    def get_data(self):
        """Load the JSON-lines log into a DataFrame.

        Returns a DataFrame with a 'target' column plus one column per
        optimized parameter (taken from the first logged run; all runs
        are assumed to share the same parameter set).
        """
        # BUGFIX: previously this opened the hard-coded path 'logs.json'
        # and ignored the configured self.logs_json_file. Also close the
        # file deterministically via a context manager.
        with open(self.logs_json_file, 'r') as f:
            runs = [json.loads(line) for line in f]
        # Column order: target first, then the parameter names.
        header = ['target'] + list(runs[0]['params'])
        data = {key: [] for key in header}
        for run in runs:
            data['target'].append(run['target'])
            for key in run['params']:
                data[key].append(run['params'][key])
        return pd.DataFrame(data)

    def get_feature_importance(self, df):
        """Return a DataFrame of (feature, importance) rows.

        Importance is the absolute coefficient of a linear regression of
        'target' on the standardized features; rows are sorted ascending
        by importance.
        """
        X = df.drop('target', axis=1)
        y = df['target']
        # Standardize so coefficient magnitudes are comparable across
        # parameters with different scales.
        X_scaled = StandardScaler().fit_transform(X)
        model = LinearRegression()
        model.fit(X_scaled, y)
        feature_importance = pd.DataFrame(
            {'feature': X.columns, 'importance': abs(model.coef_)}
        )
        # Ascending sort so the most important feature ends up at the top
        # of the horizontal bar chart.
        feature_importance.sort_values(by='importance', ascending=True, inplace=True)
        if self.verbose:
            print(feature_importance)
        return feature_importance

    def gen_plot_of_feature_importance(self, feature_importance):
        """Save a horizontal bar chart of the importances to disk."""
        plt.figure(figsize=(10, 6))
        plt.barh(feature_importance['feature'], feature_importance['importance'])
        plt.title('Feature Importance')
        plt.savefig(self.plot_file_path)

    def main(self):
        """Run the full analysis: load logs, score features, save plot."""
        df = self.get_data()
        feature_importance = self.get_feature_importance(df)
        self.gen_plot_of_feature_importance(feature_importance)
        if self.verbose:
            print('Done')
            print('Plot saved to: ', self.plot_file_path)
if __name__ == '__main__':
    # Command-line entry point: parse arguments and run the analysis.
    parser = argparse.ArgumentParser()
    parser.add_argument('--logs_json_file', type=str, default='logs.json')
    parser.add_argument('--plot_file_path', type=str, default='feature_importance.png')
    parser.add_argument('--verbose', help="Print more information.", action="store_true")
    args = parser.parse_args()

    # Construct the analyzer directly from the parsed arguments and run it.
    FindParamImportance(
        logs_json_file=args.logs_json_file,
        plot_file_path=args.plot_file_path,
        verbose=args.verbose,
    ).main()
\ No newline at end of file
......@@ -5,9 +5,6 @@ import wandb
from metrics.instance_segmentation_metrics_in_folder import \
InstanceSegmentationMetricsInFolder
# wandb.login()
# wandb.init(project="instance_segmentation_classic", entity="smart_forest")
# define a class to run the command with arguments
class RunCommand:
......@@ -30,9 +27,12 @@ def main(
graph_edge_length,
add_leaves_edge_length
):
USE_WANDB = False
# initialize the sweep
run = wandb.init(project="paper-sweep-nibio-model-just-tls", entity="smart_forest")
if USE_WANDB:
run = wandb.init(project="paper-sweep-nibio-model-just-tls", entity="smart_forest")
# get files for the sweep
print("Getting files for the sweep")
......@@ -92,7 +92,10 @@ def main(
# log the metric
print("Logging the metric")
wandb.log({"f1_score": f1_score})
if USE_WANDB:
wandb.log({"f1_score": f1_score})
return f1_score
if __name__ == "__main__":
# use argparse to get the arguments
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment