diff --git a/helpers/find_param_importance.py b/helpers/find_param_importance.py index 33e166cee7ed8ef9c133d67ca410d0265a4149ef..a23ec79082e4f8f3bc6b09ea088ba05cd21adce4 100644 --- a/helpers/find_param_importance.py +++ b/helpers/find_param_importance.py @@ -162,8 +162,10 @@ class FindParamImportance: ss = StandardScaler() X_scaled = ss.fit_transform(X) - # remove from X_scaled the columns that have zeros or nan - X_scaled = X_scaled[:, ~np.all(X_scaled == 0, axis=0)] + # remove from X_scaled the columns that have nan values + X_scaled = X_scaled[:, ~np.isnan(X_scaled).any(axis=0)] + # replace nan values with 0 + X_scaled = np.nan_to_num(X_scaled) # compute the correlation matrix and put values in the figur corr = np.corrcoef(X_scaled.T) @@ -184,10 +186,11 @@ class FindParamImportance: upper = np.triu(corr) # find the indices of the upper triangle that are not zero # these are the indices of the correlated features - correlated_features = np.where(upper > 0.1) + correlated_features = np.where(upper > 0.15) # get the feature names feature_names = X.columns # print the correlated features + print('Correlated features:') for i in range(len(correlated_features[0])): if correlated_features[0][i] != correlated_features[1][i]: print(feature_names[correlated_features[0][i]], feature_names[correlated_features[1][i]]) @@ -211,12 +214,13 @@ class FindParamImportance: X_highest_mean = X_highest.mean() # print the features with the highest mean values + print(' ') + print('Features with the highest mean values:') + print('Feature name', 'Mean value') for i in range(len(X_highest_mean)): print(X_highest_mean.index[i], X_highest_mean[i]) - - def gen_plot_of_feature_importance(self, feature_importance): plt.figure(figsize=(10, 6)) plt.barh(feature_importance['feature'], feature_importance['importance']) @@ -236,8 +240,8 @@ class FindParamImportance: print('Done') print('Plot saved to: ', self.plot_file_path) - cluster_labels = self.cluster_results_kmeans() - cluster_labels_dbscan = self.cluster_dbscan() + # cluster_labels = self.cluster_results_kmeans() + # cluster_labels_dbscan = self.cluster_dbscan() self.find_correlation() self.find_highest_params_values()