Persistencia de modelos en Python: cómo guardar tu modelo entrenado de Machine Learning

Entrenar un modelo de Machine Learning es un proceso muchas veces lento, por lo que no tiene sentido volver a entrenar el modelo cada vez que lo necesitemos en el futuro.

Por suerte, una herramienta de SciKit Learn nos permite guardar nuestro modelo ya entrenado para utilizarlo cuando lo necesitemos. Vamos allá:

1. Guardar el modelo entrenado

Utilizaremos un directorio de SciKit Learn llamado “externals”, que contiene dependencias externas agrupadas que se actualizan de vez en cuando. Dicho de otra forma, el código allí no es realmente parte de scikit, son solo otras bibliotecas que usa y almacena scikit para evitar problemas de dependencia si el usuario tiene instaladas diferentes versiones.

En concreto utilizaremos joblib, una librería usada para realizar funciones de pipelining en Python. Importamos la librería de la siguiente forma:

from sklearn.externals import joblib 

Ahora supongamos que tenemos el siguiente modelo, clf_rf, ya entrenado:

clf_rf.fit(x_train, y_train) # Entrenamiento del modelo

Podemos guardarlo realizando un joblib.dump():

joblib.dump(clf_rf, 'modelo_entrenado.pkl') # Guardo el modelo.

El modelo se guardará en el fichero “modelo_entrenado.pkl” dentro del directorio que hayamos establecido por defecto en nuestro intérprete de Python.

2. Carga del modelo entrenado

Cuando necesitemos cargar el modelo ya entrenado, simplemente hacemos un joblib.load():

clf_rf = joblib.load('modelo_entrenado.pkl') # Carga del modelo.

Si queremos asegurarnos que el modelo se ha guardado correctamente, podemos calcular el rendimiento del modelo antes de guardarlo y al cargarlo de nuevo de la siguiente forma:

clf_rf.score(x_train, y_train)

Ambos rendimientos deben ser exactamente iguales.

Para saber más:
SciKit Learn Model Persistence.
Save and Load Machine Learning Models in Python with scikit-learn.
scikit-learn: Save and Restore Models.

Deja un comentario

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *

Solve : *
29 − 13 =