Skip to content
Snippets Groups Projects
Commit a7bc9d77 authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

update of feature importance script

parent 68033c66
Branches
No related tags found
No related merge requests found
......@@ -195,6 +195,34 @@ class FindParamImportance:
if correlated_features[0][i] != correlated_features[1][i]:
print(feature_names[correlated_features[0][i]], feature_names[correlated_features[1][i]])
def find_correlation_with_the_output(self):
data = self.get_data()
# get the features
X = data.drop('target', axis=1)
# get the target
y = data['target']
# find how correlated each feature is with the target
correlations = []
for i in range(len(X.columns)):
correlations.append(np.corrcoef(X.iloc[:, i], y)[0, 1])
# plot the correlations and make text to fit in the plot
plt.figure(figsize=(10, 6))
plt.bar(X.columns, correlations)
plt.xticks(rotation='vertical')
plt.tight_layout()
# create a legend
plt.axhline(y=0, color='black', linestyle='--')
plt.axhline(y=0.15, color='red', linestyle='--')
plt.axhline(y=-0.15, color='red', linestyle='--')
plt.legend(['0', '0.15', '-0.15'])
plt.title('Correlation with the output')
plt.savefig('correlation_with_the_output.png')
def find_highest_params_values(self):
data = self.get_data()
......@@ -244,6 +272,7 @@ class FindParamImportance:
# cluster_labels_dbscan = self.cluster_dbscan()
self.find_correlation()
self.find_highest_params_values()
self.find_correlation_with_the_output()
if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment