Pierre-Louis Bescond
Pierre-Louis BescondFuture Factory Global Project Leader @ Roquette

Symulator machine learningu z Dash

Sprawdź, jak stworzyć interaktywny symulator machine learningu z bibliotekami Dash lub Plotly, aby lepiej rozumieć zachowanie danego modelu.
15.07.20207 min
Symulator machine learningu z Dash

Udało Ci się! Wszystko zaczęło się od problemu biznesowego przedstawionego przez twoich kolegów. Od tamtego momentu przemierzaliście ciemne doliny konsolidacji danych i czyszczenia ich oraz szukaliście odpowiednich cech i modeli. Sprawdzasz, czy Twój model działa i chcesz podzielić się swoimi wnioskami. Nie wszyscy jednak znają RMSE, tablicę pomyłek, czy wartości Shapleya.

Na przykład, zespoły operacyjne w przemyśle stają przed codziennymi wyzwaniami związanymi z optymalizacją produkcji lub łańcucha dostaw, a uczenie maszynowe oferuje nowy sposób rozumienia zachowań złożonych procesów… o ile można je przełożyć na zrozumiałą analizę. Kiedy jednak przychodzi czas na podzielenie się insightami wniesionymi przez A.I., to Python Notebook nie jest niestety najlepszym wyborem.

Bardziej niż do tego, jak dobrze model działa, zespoły dążą do zrozumienia, jakie są najważniejsze czynniki wśród wielu parametrów produkcji i jak każdy z nich może wpływać na zachowanie modelu.


Biblioteki Dash oferują sposób na stworzenie dynamicznego dashboardu, dostępnego z poziomu przeglądarki internetowej razem z interaktywnymi funkcjami

Pomysł jest prosty:

  • Tworzysz skrypt .py, w którym „klasyczne” zadania uczenia maszynowego są połączone z zakodowanym układem strony internetowej, blisko kodu HTML.
  • Skrypt Pythona generuje po jego uruchomieniu dynamiczną stronę internetową pod lokalnym adresem (http://127.0.0.1:8050/), która zawiera współdziałające ze sobą komponenty.


Ikony zaprojektowane przez Pixel Buddha i Pixelmeetup (źródło: flaticon.com)

Tworzymy demo

Po pierwsze, musimy zaprojektować w miarę prawdopodobny przypadek użycia. „make_regression” od Scikit-learn wygeneruje nasze źródło danych. Będziemy dostosowywać niektóre cechy, aby uzyskać realistyczne liczby w przykładzie:

# -*- coding: utf-8 -*-

# We start with the import of standard ML librairies
import pandas as pd
import numpy as np
import math

from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

# We add all Plotly and Dash necessary librairies
import plotly.graph_objects as go

import dash
import dash_core_components as dcc
import dash_html_components as html
import dash_daq as daq
from dash.dependencies import Input, Output


# We start by creating a virtual regression use-case
X, y = make_regression(n_samples=1000, n_features=8, n_informative=5, random_state=22)

# We rename columns as industrial parameters
col_names = ["Temperature","Viscosity","Pressure", "pH","Inlet_flow", "Rotating_Speed","Particles_size","Color_density"]

df = pd.DataFrame(X, columns=col_names)

# We change the most important features ranges to make them look like actual figures
df["pH"]=6.5+df["pH"]/4
df["Pressure"]=10+df["Pressure"]
df["Temperature"]=20+df["Temperature"]
df["Y"] = 90+y/20

# We train a simple RF model
model = RandomForestRegressor()
model.fit(df.drop("Y", axis=1), df["Y"])

# We create a DataFrame to store the features' importance and their corresponding label
df_feature_importances = pd.DataFrame(model.feature_importances_*100,columns=["Importance"],index=col_names)
df_feature_importances = df_feature_importances.sort_values("Importance", ascending=False)

Fazy generowania i modelowania danych


Po wykonaniu zadań związanych z uczeniem maszynowym musimy jeszcze przygotować informacje dynamiczne, które będą wyświetlane w naszej przeglądarce. W tym przykładzie utworzymy:

  • Wykres słupkowy pokazujący znaczenie cech modelu
  • Trzy suwaki umożliwiające zmianę wartości trzech najważniejszych cech i zrozumienie ich wpływu na prognozę modelu.

# We create a Features Importance Bar Chart
fig_features_importance = go.Figure()
fig_features_importance.add_trace(go.Bar(x=df_feature_importances.index,
                                         y=df_feature_importances["Importance"],
                                         marker_color='rgb(171, 226, 251)')
                                 )
fig_features_importance.update_layout(title_text='<b>Features Importance of the model<b>', title_x=0.5)
# The command below can be activated in a standard notebook to display the chart
#fig_features_importance.show()

# We record the name, min, mean and max of the three most important features
slider_1_label = df_feature_importances.index[0]
slider_1_min = math.floor(df[slider_1_label].min())
slider_1_mean = round(df[slider_1_label].mean())
slider_1_max = round(df[slider_1_label].max())

slider_2_label = df_feature_importances.index[1]
slider_2_min = math.floor(df[slider_2_label].min())
slider_2_mean = round(df[slider_2_label].mean())
slider_2_max = round(df[slider_2_label].max())

slider_3_label = df_feature_importances.index[2]
slider_3_min = math.floor(df[slider_3_label].min())
slider_3_mean = round(df[slider_3_label].mean())
slider_3_max = round(df[slider_3_label].max())


Być może zauważyliście, że tworzenie wykresu słupkowego jest podobne do tworzenia tego używanego w Plotly, z takim tylko wyjątkiem, że nie wywołujemy tutaj funkcji .show(). A to dlatego, że wykres będzie renderowany w sekcji „Layout”. Sekcja „Layout” (patrz poniżej) wymaga pewnej wprawy dla użytkowników, którzy nie znają zarówno Pythona, jak i HTML.

Pomysł polega tutaj na połączeniu elementów HTML (wywoływanych przez htm.xxx) z elementami Dynamic Dash Components (dcc.xxx).

Uprośćmy to:

<H1>     Title
<DCC> Features Importance Chart
<H4> #1 Importance Feature Name
<DCC> #1 Feature slider
<H4> #2 Importance Feature Name
<DCC> #2 Feature slider
<H4> #2 Importance Feature Name
<DCC> #2 Feature slider
<H2> Updated predictions
###############################################################################

app = dash.Dash()

# The page structure will be:
#    Features Importance Chart
#    <H4> Feature #1 name
#    Slider to update Feature #1 value
#    <H4> Feature #2 name
#    Slider to update Feature #2 value
#    <H4> Feature #3 name
#    Slider to update Feature #3 value
#    <H2> Updated Prediction
#    Callback fuction with Sliders values as inputs and Prediction as Output

# We apply basic HTML formatting to the layout
app.layout = html.Div(style={'textAlign': 'center', 'width': '800px', 'font-family': 'Verdana'},
                      
                    children=[

                        # Title display
                        html.H1(children="Simulation Tool"),
                        
                        # Dash Graph Component calls the fig_features_importance parameters
                        dcc.Graph(figure=fig_features_importance),
                        
                        # We display the most important feature's name
                        html.H4(children=slider_1_label),

                        # The Dash Slider is built according to Feature #1 ranges
                        dcc.Slider(
                            id='X1_slider',
                            min=slider_1_min,
                            max=slider_1_max,
                            step=0.5,
                            value=slider_1_mean,
                            marks={i: '{} bars'.format(i) for i in range(slider_1_min, slider_1_max+1)}
                            ),

                        # The same logic is applied to the following names / sliders
                        html.H4(children=slider_2_label),

                        dcc.Slider(
                            id='X2_slider',
                            min=slider_2_min,
                            max=slider_2_max,
                            step=0.5,
                            value=slider_2_mean,
                            marks={i: '{}°'.format(i) for i in range(slider_2_min, slider_2_max+1)}
                        ),

                        html.H4(children=slider_3_label),

                        dcc.Slider(
                            id='X3_slider',
                            min=slider_3_min,
                            max=slider_3_max,
                            step=0.1,
                            value=slider_3_mean,
                            marks={i: '{}'.format(i) for i in np.linspace(slider_3_min, slider_3_max,1+(slider_3_max-slider_3_min)*5)},
                        ),
                        
                        # The predictin result will be displayed and updated here
                        html.H2(id="prediction_result"),

                    ])

Oto „najbardziej wrażliwa” część skryptu: funkcja app.callback umożliwia enkapsulację standardowej funkcji python, dzięki czemu będzie ona oddziaływać z komponentami strony internetowej zaprojektowanymi powyżej. Jej mechanizm można uprościć w następujący sposób:

@app.callback(Output_on_webpage,
              Input_1_from_webpage,
              Input_2_from_webpage,
              Input_3_from_webpage)
python_function(Input_1, Input_2, Input_3):
       Output = model_evaluation([Input_1, Input_2, Input_3])
       return Output
# The callback function will provide one "Ouput" in the form of a string (=children)
@app.callback(Output(component_id="prediction_result",component_property="children"),
# The values correspnding to the three sliders are obtained by calling their id and value property
              [Input("X1_slider","value"), Input("X2_slider","value"), Input("X3_slider","value")])

# The input variable are set in the same order as the callback Inputs
def update_prediction(X1, X2, X3):

    # We create a NumPy array in the form of the original features
    # ["Pressure","Viscosity","Particles_size", "Temperature","Inlet_flow", "Rotating_Speed","pH","Color_density"]
    # Except for the X1, X2 and X3, all other non-influencing parameters are set to their mean
    input_X = np.array([X1,
                       df["Viscosity"].mean(),
                       df["Particles_size"].mean(),
                       X2,
                       df["Inlet_flow"].mean(),
                       df["Rotating_Speed"].mean(),
                       X3,
                       df["Color_density"].mean()]).reshape(1,-1)        
    
    # Prediction is calculated based on the input_X array
    prediction = model.predict(input_X)[0]
    
    # And retuned to the Output of the callback function
    return "Prediction: {}".format(round(prediction,1))

if __name__ == "__main__":
    app.run_server()


Założyliśmy tu, że tylko trzy najważniejsze funkcje były warte aktualizacji. Dlatego rozważaliśmy średnią wartość każdej innej cechy, gdy przetwarzał dane z tablicy. Jest to oczywiście wybór dokonany dla tego przykładu, dzięki czemu będzie prościej.

I gotowe!

Otwieramy wiersz polecenia i wpisujemy python dashboard.py.

(base) PS C:\Users\...\> jupyter dashboard.py
 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
 * Debug mode: off
 * Running on http://127.0.0.1:8050/ (Press CTRL+C to quit)


Teraz wystarczy otworzyć naszą przeglądarkę pod adresem http://127.0.0.1:8050/ i sprawdzić wynik:


Za każdym razem, gdy suwak się porusza, funkcja app.callback uruchamia skrypt python, aby ponownie wyliczyć prognozę.

Skrypt dostępny tutaj

Jak możesz sobie wyobrazić, jest to bardzo potężne narzędzie, umożliwiające interakcje z modelem i zrozumienie jego zachowania dla tych, którzy nie zajmują się data science. Jest to również bardzo dobra platforma do testowania dokładności modelu: zespoły operacyjne będą mogły sprawdzić i potwierdzić, czy zachowanie modelu odpowiada ich doświadczeniu w terenie.

Polecę jeszcze dwie rzeczy, zanim skończymy:

  • Jeśli nie znasz Plotly oraz Dash, zacznij od treningu na standardowych wykresach Plotly (kroki są bardzo dobrze wyjaśnione).
  • Ten kurs online jest naprawdę dobrze zaprojektowany i przeprowadzi Cię przez podstawy Pythona (Pandas / NumPy), składnię wykresów oraz projektowanie z Dash. 


Oryginał tekstu w języku angielskim możesz przeczytać tutaj.

<p>Loading...</p>