Skip to content
Snippets Groups Projects
Commit 68033c66 authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

update of feature importance script

parent fcfc03fb
No related branches found
No related tags found
No related merge requests found
...@@ -162,8 +162,10 @@ class FindParamImportance: ...@@ -162,8 +162,10 @@ class FindParamImportance:
ss = StandardScaler() ss = StandardScaler()
X_scaled = ss.fit_transform(X) X_scaled = ss.fit_transform(X)
# remove from X_scaled the columns that have zeros or nan # remove from X_scaled the columns that have nan values
X_scaled = X_scaled[:, ~np.all(X_scaled == 0, axis=0)] X_scaled = X_scaled[:, ~np.isnan(X_scaled).any(axis=0)]
# replace nan values with 0
X_scaled = np.nan_to_num(X_scaled)
# compute the correlation matrix and put values in the figur # compute the correlation matrix and put values in the figur
corr = np.corrcoef(X_scaled.T) corr = np.corrcoef(X_scaled.T)
...@@ -184,10 +186,11 @@ class FindParamImportance: ...@@ -184,10 +186,11 @@ class FindParamImportance:
upper = np.triu(corr) upper = np.triu(corr)
# find the indices of the upper triangle that are not zero # find the indices of the upper triangle that are not zero
# these are the indices of the correlated features # these are the indices of the correlated features
correlated_features = np.where(upper > 0.1) correlated_features = np.where(upper > 0.15)
# get the feature names # get the feature names
feature_names = X.columns feature_names = X.columns
# print the correlated features # print the correlated features
print('Correlated features:')
for i in range(len(correlated_features[0])): for i in range(len(correlated_features[0])):
if correlated_features[0][i] != correlated_features[1][i]: if correlated_features[0][i] != correlated_features[1][i]:
print(feature_names[correlated_features[0][i]], feature_names[correlated_features[1][i]]) print(feature_names[correlated_features[0][i]], feature_names[correlated_features[1][i]])
...@@ -211,12 +214,13 @@ class FindParamImportance: ...@@ -211,12 +214,13 @@ class FindParamImportance:
X_highest_mean = X_highest.mean() X_highest_mean = X_highest.mean()
# print the features with the highest mean values # print the features with the highest mean values
print(' ')
print('Features with the highest mean values:')
print('Feature name', 'Mean value')
for i in range(len(X_highest_mean)): for i in range(len(X_highest_mean)):
print(X_highest_mean.index[i], X_highest_mean[i]) print(X_highest_mean.index[i], X_highest_mean[i])
def gen_plot_of_feature_importance(self, feature_importance): def gen_plot_of_feature_importance(self, feature_importance):
plt.figure(figsize=(10, 6)) plt.figure(figsize=(10, 6))
plt.barh(feature_importance['feature'], feature_importance['importance']) plt.barh(feature_importance['feature'], feature_importance['importance'])
...@@ -236,8 +240,8 @@ class FindParamImportance: ...@@ -236,8 +240,8 @@ class FindParamImportance:
print('Done') print('Done')
print('Plot saved to: ', self.plot_file_path) print('Plot saved to: ', self.plot_file_path)
cluster_labels = self.cluster_results_kmeans() # cluster_labels = self.cluster_results_kmeans()
cluster_labels_dbscan = self.cluster_dbscan() # cluster_labels_dbscan = self.cluster_dbscan()
self.find_correlation() self.find_correlation()
self.find_highest_params_values() self.find_highest_params_values()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment