mlflow_doc.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using the project's MLflow tracking server\n",
"\n",
"The MLflow tracking server is reachable at the following address: https://champi.heuzef.com\n",
"\n",
"This notebook explains how to use it."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Experiment: artifact_location='mlflow-artifacts:/103379370584144202', creation_time=1721579566179, experiment_id='103379370584144202', last_update_time=1721579566179, lifecycle_stage='active', name='champi', tags={}>"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
  31. "# Initialisation de l'URL\n",
  32. "mlflow_server_uri = \"https://champi.heuzef.com\"\n",
  33. "\n",
  34. "# Imports et paramétrage de MLflow\n",
  35. "from mlflow import MlflowClient\n",
  36. "import mlflow\n",
  37. "import setuptools\n",
  38. "\n",
  39. "mlflow.set_tracking_uri(mlflow_server_uri)\n",
  40. "mlflow.set_experiment(\"champi\") # Le nom du projet"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Checking availability\n",
"\n",
"First of all, make sure the server is actually reachable."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The MLflow tracking server is available: https://champi.heuzef.com\n"
]
},
{
"data": {
"text/plain": [
"<Response [200]>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
  76. "import requests\n",
  77. "\n",
  78. "def is_mlflow_tracking_server_available(mlflow_server_uri):\n",
  79. " try:\n",
  80. " response = requests.get(mlflow_server_uri)\n",
  81. " if response.status_code == 200:\n",
  82. " return True\n",
  83. " else:\n",
  84. " return False\n",
  85. " except requests.exceptions.RequestException:\n",
  86. " return False\n",
  87. "\n",
  88. "if is_mlflow_tracking_server_available(mlflow_server_uri):\n",
  89. " print(\"Le serveur de tracking MLflow est disponible :\", mlflow_server_uri)\n",
  90. "else:\n",
  91. " print(\"Le serveur de tracking MLflow n'est pas disponible.\")\n",
  92. "\n",
  93. "requests.get(mlflow_server_uri)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training a model for the example\n",
"\n",
"We will train a small, basic model with Scikit-learn to obtain a few metrics, which will be stored in variables."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"metrics :\n",
"{'mae': 130.27056867217001, 'mse': 27735.25123415424, 'rmse': 166.53903816869555, 'r2': 0.14952412556941785}\n",
"\n",
"params :\n",
"{'n_estimators': 10, 'max_depth': 10, 'random_state': 42}\n"
]
}
],
  123. "source": [
  124. "# Imports librairies\n",
  125. "from sklearn.ensemble import RandomForestRegressor\n",
  126. "from sklearn.model_selection import train_test_split\n",
  127. "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
  128. "import pandas as pd\n",
  129. "import numpy as np\n",
  130. "\n",
  131. "# Import d'une database (au pif juste pour ce test)\n",
  132. "data = pd.read_csv(\"https://github.com/DataScientest-Studio/MLflow/raw/main/fake_data.csv\")\n",
  133. "X = data.drop(columns=[\"date\", \"demand\", \"weekend\", \"holiday\", \"promo\"])\n",
  134. "y = data[\"demand\"]\n",
  135. "X_train, X_val, y_train, y_val = train_test_split(\n",
  136. " X, y, test_size=0.2, random_state=42\n",
  137. ")\n",
  138. "\n",
  139. "# Train model\n",
  140. "params = {\n",
  141. " \"n_estimators\": 10,\n",
  142. " \"max_depth\": 10,\n",
  143. " \"random_state\": 42,\n",
  144. "}\n",
  145. "rf = RandomForestRegressor(**params)\n",
  146. "rf.fit(X_train, y_train)\n",
  147. "\n",
  148. "# Evaluate model\n",
  149. "y_pred = rf.predict(X_val)\n",
  150. "mae = mean_absolute_error(y_val, y_pred)\n",
  151. "mse = mean_squared_error(y_val, y_pred)\n",
  152. "rmse = np.sqrt(mse)\n",
  153. "r2 = r2_score(y_val, y_pred)\n",
  154. "metrics = {\"mae\": mae, \"mse\": mse, \"rmse\": rmse, \"r2\": r2}\n",
  155. "\n",
  156. "print(\"\\nmetrics :\")\n",
  157. "print(metrics)\n",
  158. "\n",
  159. "print(\"\\nparams :\")\n",
  160. "print(params)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sending the information to the MLflow server\n",
"\n",
"Now that we have our results, we will create a \"run\" and push it to the server.\n",
"\n",
"This example uses the mlflow.sklearn module. Of course, you will need to use the module suited to your own framework: https://mlflow.org/docs/latest/python_api/"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
  180. "run_name = \"run_test_001\" # Le nom de la run, nous utiliserons notre propre nomenclature pour le projet\n",
  181. "\n",
  182. "with mlflow.start_run(run_name=run_name) as run:\n",
  183. " mlflow.log_params(params)\n",
  184. " mlflow.log_metrics(metrics)\n",
  185. " mlflow.sklearn.log_model(sk_model=rf, input_example=X_val, artifact_path=run_name+\"_artifacts\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run at the end of a script or notebook, this code is ultimately sufficient and flexible enough to transfer the information we want. It is possible to make things even simpler, though, by letting MLflow take care of it with `mlflow.autolog()`, as sketched in the cell below.\n",
"\n",
"> https://mlflow.org/docs/latest/tracking/autolog.html\n",
"\n",
"Feel free to experiment with hyperparameters and send a few metrics to compare in the web interface."
]
},
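{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is a minimal sketch of `mlflow.autolog()` usage (not executed here; the run name and hyperparameter values are purely illustrative, and the data comes from the training cell above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal autolog sketch: enable automatic logging, then train as usual.\n",
"# Parameters, training metrics and the fitted model are logged without explicit log_* calls.\n",
"# The run name and hyperparameter values below are illustrative only.\n",
"mlflow.autolog()\n",
"\n",
"with mlflow.start_run(run_name=\"run_test_002_autolog\"):\n",
"    rf_auto = RandomForestRegressor(n_estimators=20, max_depth=5, random_state=42)\n",
"    rf_auto.fit(X_train, y_train)"
]
},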
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Loading the best-performing model"
]
},
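{
"cell_type": "markdown",
"metadata": {},
"source": [
"The last cell loads the model registered in the MLflow Model Registry under the name `champi_cnn`, through its `champion` alias. As a minimal sketch (assuming the model versions were registered under that name, which is not shown in this notebook, and that `rmse` is the reference metric), the best run could be found and promoted to that alias like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch only: find the run with the lowest RMSE in the \"champi\" experiment,\n",
"# then point the \"champion\" alias of the registered model at the matching version.\n",
"# Assumption: model versions were registered under the name \"champi_cnn\".\n",
"best_run = mlflow.search_runs(\n",
"    experiment_names=[\"champi\"],\n",
"    order_by=[\"metrics.rmse ASC\"],\n",
"    max_results=1,\n",
").iloc[0]\n",
"\n",
"client = MlflowClient()\n",
"for mv in client.search_model_versions(\"name='champi_cnn'\"):\n",
"    if mv.run_id == best_run.run_id:\n",
"        client.set_registered_model_alias(\"champi_cnn\", \"champion\", mv.version)"
]
},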
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/heuzef/GIT/jan24_cds_mushrooms/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
"  from .autonotebook import tqdm as notebook_tqdm\n",
"Downloading artifacts: 100%|██████████| 6/6 [00:01<00:00, 4.64it/s] \n",
"2024-10-01 15:18:25.979867: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n",
"2024-10-01 15:18:26.089889: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n",
"2024-10-01 15:18:26.138677: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-10-01 15:18:26.233926: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-10-01 15:18:26.255095: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2024-10-01 15:18:26.371668: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2024-10-01 15:18:27.797753: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
]
},
{
"data": {
"text/plain": [
"mlflow.pyfunc.loaded_model:\n",
"  artifact_path: heuzef_efficientnetb1_010_artifacts\n",
"  flavor: mlflow.keras\n",
"  run_id: 93ce2df782da48108f127f3e6c4adb8b"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
  242. "source": [
  243. "import mlflow.pyfunc\n",
  244. "\n",
  245. "champi_cnn = mlflow.pyfunc.load_model(f\"models:/champi_cnn@champion\")\n",
  246. "\n",
  247. "champi_cnn"
  248. ]
  249. }
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}