def find_correlation_with_the_output(
    self,
    *,
    threshold=0.15,
    output_path='correlation_with_the_output.png',
):
    """Plot each feature's correlation with the target and save the chart.

    The data frame returned by ``self.get_data()`` is split into the
    ``'target'`` column and all remaining columns (the features). For every
    feature the correlation coefficient with the target is taken from
    ``np.corrcoef`` and drawn as one bar. Dashed guide lines are drawn at
    zero and at ``+threshold`` / ``-threshold``.

    Args:
        threshold: Height of the two red guide lines (drawn at
            ``+threshold`` and ``-threshold``). Defaults to 0.15.
        output_path: File the figure is written to. Defaults to
            ``'correlation_with_the_output.png'`` in the working directory.

    Returns:
        None. The only result is the image written to ``output_path``.
    """
    data = self.get_data()
    features = data.drop('target', axis=1)
    target = data['target']

    # items() walks the columns positionally, so duplicated column names
    # still yield one Series (and therefore one coefficient) per column.
    correlations = [
        np.corrcoef(values, target)[0, 1] for _, values in features.items()
    ]

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.bar(features.columns, correlations)
        plt.xticks(rotation='vertical')

        # Label each guide line explicitly instead of relying on the order
        # in which pyplot happens to enumerate the artists for the legend.
        plt.axhline(y=0, color='black', linestyle='--', label='0')
        plt.axhline(y=threshold, color='red', linestyle='--',
                    label=str(threshold))
        plt.axhline(y=-threshold, color='red', linestyle='--',
                    label=str(-threshold))
        plt.legend()
        plt.title('Correlation with the output')

        # Run the layout pass last so that the rotated tick labels, the
        # title and the legend are all taken into account.
        plt.tight_layout()
        plt.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until it is closed; release it so
        # repeated calls do not accumulate open figures.
        plt.close(fig)