diff --git a/helpers/find_param_importance.py b/helpers/find_param_importance.py index a23ec79082e4f8f3bc6b09ea088ba05cd21adce4..27f6cc4116b2da1be20a03d1c711e3349db67378 100644 --- a/helpers/find_param_importance.py +++ b/helpers/find_param_importance.py @@ -195,6 +195,34 @@ class FindParamImportance: if correlated_features[0][i] != correlated_features[1][i]: print(feature_names[correlated_features[0][i]], feature_names[correlated_features[1][i]]) + def find_correlation_with_the_output(self): + data = self.get_data() + # get the features + X = data.drop('target', axis=1) + # get the target + y = data['target'] + + # find how correlated each feature is with the target + correlations = [] + for i in range(len(X.columns)): + correlations.append(np.corrcoef(X.iloc[:, i], y)[0, 1]) + + # plot the correlations and make text to fit in the plot + plt.figure(figsize=(10, 6)) + plt.bar(X.columns, correlations) + plt.xticks(rotation='vertical') + plt.tight_layout() + # create a legend + plt.axhline(y=0, color='black', linestyle='--') + plt.axhline(y=0.15, color='red', linestyle='--') + plt.axhline(y=-0.15, color='red', linestyle='--') + plt.legend(['0', '0.15', '-0.15']) + plt.title('Correlation with the output') + plt.savefig('correlation_with_the_output.png') + + + + def find_highest_params_values(self): data = self.get_data() @@ -244,6 +272,7 @@ class FindParamImportance: # cluster_labels_dbscan = self.cluster_dbscan() self.find_correlation() self.find_highest_params_values() + self.find_correlation_with_the_output() if __name__ == '__main__':