diff --git a/helpers/find_param_importance.py b/helpers/find_param_importance.py
index 19f5e45142203af1697e630a5657f2d3af2e2201..33e166cee7ed8ef9c133d67ca410d0265a4149ef 100644
--- a/helpers/find_param_importance.py
+++ b/helpers/find_param_importance.py
@@ -1,10 +1,12 @@
 # partially based on : https://medium.com/analytics-vidhya/feature-importance-explained-bfc8d874bcf
 import json
 import matplotlib.pyplot as plt
+import numpy as np
 import pandas as pd
 import argparse
 from sklearn.linear_model import LinearRegression
 from sklearn.preprocessing import StandardScaler
+from sklearn.cluster import DBSCAN, KMeans
 
 class FindParamImportance:
     def __init__(
@@ -67,6 +69,129 @@ class FindParamImportance:
 
         return feature_importance
 
+    def cluster_results_kmeans(self):
+        # use k-means to group the runs into K clusters of parameter
+        # settings and report the mean target value of each cluster
+        K = 7
+
+        data = self.get_data()
+
+        # split the data into features and target
+        X = data.drop('target', axis=1)
+        y = data['target']
+
+        ss = StandardScaler()
+        X_scaled = ss.fit_transform(X)
+
+        kmeans = KMeans(n_clusters=K, random_state=0).fit(X_scaled)
+        labels = kmeans.labels_
+
+        # mean target value per cluster
+        y_means = [y[labels == i].mean() for i in range(K)]
+        print('k-means cluster target means: ', y_means)
+        print('k-means labels: ', labels)
+
+        # plot the clusters on the first two scaled features
+        plt.figure(figsize=(10, 6))
+        plt.scatter(X_scaled[:, 0], X_scaled[:, 1], c=labels, s=50, cmap='viridis')
+        plt.title('K-means Clustering')
+        plt.xlabel('Feature 1')
+        plt.ylabel('Feature 2')
+        # mark each cluster id at its centroid
+        for i in range(K):
+            plt.text(X_scaled[labels == i, 0].mean(), X_scaled[labels == i, 1].mean(), str(i),
+                     ha='center', va='center', bbox=dict(facecolor='white', alpha=0.5, lw=0))
+        plt.savefig('kmeans.png')
+
+        return labels
+
+    def cluster_dbscan(self):
+        data = self.get_data()
+
+        # split the data into features and target
+        X = data.drop('target', axis=1)
+        y = data['target']
+
+        ss = StandardScaler()
+        X_scaled = ss.fit_transform(X)
+
+        db = DBSCAN(eps=2, min_samples=3).fit(X_scaled)
+        labels = db.labels_
+        print('DBSCAN labels: ', labels)
+
+        # number of clusters, ignoring noise points (label -1)
+        n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
+
+        # mean target value per cluster
+        y_means = [y[labels == i].mean() for i in range(n_clusters_)]
+        print('DBSCAN cluster target means: ', y_means)
+
+        # plot the clusters on the first two scaled features
+        plt.figure(figsize=(10, 6))
+        plt.scatter(X_scaled[:, 0], X_scaled[:, 1], c=labels, s=50, cmap='viridis')
+        plt.title('DBSCAN Clustering')
+        plt.xlabel('Feature 1')
+        plt.ylabel('Feature 2')
+        # mark each cluster id at its centroid
+        for i in range(n_clusters_):
+            plt.text(X_scaled[labels == i, 0].mean(), X_scaled[labels == i, 1].mean(), str(i),
+                     ha='center', va='center', bbox=dict(facecolor='white', alpha=0.5, lw=0))
+        plt.savefig('dbscan.png')
+
+        return labels
+
+    def find_correlation(self):
+        data = self.get_data()
+
+        # get the features
+        X = data.drop('target', axis=1)
+
+        ss = StandardScaler()
+        X_scaled = ss.fit_transform(X)
+
+        # drop all-zero columns (constant features scale to zero) and
+        # keep the matching feature names for the axis labels
+        keep = ~np.all(X_scaled == 0, axis=0)
+        X_scaled = X_scaled[:, keep]
+        feature_names = X.columns[keep]
+
+        # compute the correlation matrix and plot it with the values annotated
+        corr = np.corrcoef(X_scaled.T)
+        plt.figure(figsize=(10, 6))
+        plt.imshow(corr, cmap='viridis')
+        plt.colorbar()
+        plt.xticks(range(len(feature_names)), feature_names, rotation='vertical')
+        plt.yticks(range(len(feature_names)), feature_names)
+        for i in range(corr.shape[0]):
+            for j in range(corr.shape[1]):
+                plt.text(i, j, round(corr[i, j], 2), ha='center', va='center', color='white')
+        plt.savefig('correlation.png')
+
+        # print feature pairs whose correlation exceeds 0.1, taking the
+        # upper triangle and skipping the diagonal (note: negatively
+        # correlated pairs are not reported)
+        upper = np.triu(corr)
+        correlated_features = np.where(upper > 0.1)
+        for i, j in zip(*correlated_features):
+            if i != j:
+                print(feature_names[i], feature_names[j])
+
+    def find_highest_params_values(self):
+        data = self.get_data()
+
+        # split the data into features and target
+        X = data.drop('target', axis=1)
+        y = data['target']
+
+        # indices of the 4 highest target values
+        y_sorted = y.sort_values(ascending=False)
+        indices = y_sorted.index[:4]
+
+        # mean feature values over those runs
+        X_highest_mean = X.loc[indices].mean()
+        for name, value in X_highest_mean.items():
+            print(name, value)
+
     def gen_plot_of_feature_importance(self, feature_importance):
         plt.figure(figsize=(10, 6))
         plt.barh(feature_importance['feature'], feature_importance['importance'])
@@ -86,6 +211,11 @@ class FindParamImportance:
         print('Done')
         print('Plot saved to: ', self.plot_file_path)
 
+        self.cluster_results_kmeans()
+        self.cluster_dbscan()
+        self.find_correlation()
+        self.find_highest_params_values()
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--logs_json_file', type=str, default='logs.json')
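Review note: the `n_clusters_` computation in `cluster_dbscan` works because DBSCAN marks noise points with the label `-1`, so the number of real clusters is the number of distinct labels minus one whenever noise is present. A minimal standalone illustration of that behaviour (synthetic points and `eps`/`min_samples` values chosen purely for the example, not taken from this PR):

import numpy as np
from sklearn.cluster import DBSCAN

# two tight blobs plus one far-away outlier
X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
              [5.0, 5.0], [5.1, 5.0], [5.0, 5.1],
              [9.0, 9.0]])
labels = DBSCAN(eps=0.5, min_samples=2).fit(X).labels_
print(labels)  # expected: [ 0  0  0  1  1  1 -1] -- the outlier is labelled as noise

# same formula as in cluster_dbscan above
n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
print(n_clusters)  # expected: 2

This is also why the per-cluster target means in `cluster_dbscan` iterate over `range(n_clusters_)` rather than over `set(labels)`: the noise points are deliberately excluded from the averages.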