Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature-selection] Replace matplotlib with plotly #815

Merged
merged 1 commit into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 12 additions & 40 deletions feature_selection/feature_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,15 @@
# limitations under the License.
#
import json
import os

import matplotlib.pyplot as plt
import mlrun
import mlrun.datastore
import mlrun.utils
import mlrun.feature_store as fs
import mlrun.utils
import numpy as np
import pandas as pd
import seaborn as sns
from mlrun.artifacts import PlotArtifact
import plotly.express as px
from mlrun.artifacts import PlotlyArtifact
from mlrun.datastore.targets import ParquetTarget
# MLRun utils
from mlrun.utils.helpers import create_class
Expand All @@ -42,15 +40,6 @@
}


def _clear_current_figure():
"""
Clear matplotlib current figure.
"""
plt.cla()
plt.clf()
plt.close()


def show_values_on_bars(axs, h_v="v", space=0.4):
def _show_on_single_plot(ax_):
if h_v == "v":
Expand All @@ -74,33 +63,18 @@ def _show_on_single_plot(ax_):


def plot_stat(context, stat_name, stat_df):
_clear_current_figure()

# Add chart
ax = plt.axes()
stat_chart = sns.barplot(
sorted_df = stat_df.sort_values(stat_name)
fig = px.bar(
data_frame=sorted_df,
x=stat_name,
y="index",
data=stat_df.sort_values(stat_name, ascending=False).reset_index(),
ax=ax,
y=sorted_df.index,
title=f"{stat_name} feature scores",
color=stat_name,
)
plt.tight_layout()

for p in stat_chart.patches:
width = p.get_width()
plt.text(
5 + p.get_width(),
p.get_y() + 0.55 * p.get_height(),
"{:1.2f}".format(width),
ha="center",
va="center",
)

context.log_artifact(
PlotArtifact(f"{stat_name}", body=plt.gcf()),
local_path=os.path.join("plots", "feature_selection", f"{stat_name}.html"),
item=PlotlyArtifact(key=stat_name, figure=fig),
local_path=f"{stat_name}.html",
)
_clear_current_figure()


def feature_selection(
Expand All @@ -115,7 +89,6 @@ def feature_selection(
sample_ratio: float = None,
output_vector_name: float = None,
ignore_type_errors: bool = False,
is_feature_vector: bool = False,
):
"""
Applies selected feature selection statistical functions or models on our 'df_artifact'.
Expand All @@ -138,10 +111,9 @@ def feature_selection(
model name (ex. LinearSVC), formalized json (contains 'CLASS',
'FIT', 'META') or a path to such json file.
:param max_scaled_scores: produce feature scores table scaled with max_scaler.
:param sample_ratio: percentage of the dataset the user whishes to compute the feature selection process on.
:param sample_ratio: percentage of the dataset the user wishes to compute the feature selection process on.
:param output_vector_name: creates a new feature vector containing only the identified features.
:param ignore_type_errors: skips datatypes that are neither float nor int within the feature vector.
:param is_feature_vector: bool stating if the data is passed as a feature vector.
"""
stat_filters = stat_filters or DEFAULT_STAT_FILTERS
model_filters = model_filters or DEFAULT_MODEL_FILTERS
Expand Down
Loading
Loading