  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 3,
  6. "metadata": {},
  7. "outputs": [],
  8. "source": [
  9. "import mlflow\n",
  10. "\n",
  11. "mlflow_server_uri = \"https://champi.heuzef.com\"\n",
  12. "\n",
  13. "mlflow.set_tracking_uri(mlflow_server_uri)\n",
  14. "mlflow.set_experiment(\"champi\") # Le nom du projet\n",
  15. "\n",
  16. "# NOTE: run the cell defining epochs_data (below) before executing this one\n",
  16. "with mlflow.start_run():\n",
  17. " for epoch, metrics in epochs_data.items():\n",
  18. " mlflow.log_metric(\"accuracy\", metrics[\"accuracy\"], step=epoch)\n",
  19. " mlflow.log_metric(\"loss\", metrics[\"loss\"], step=epoch)\n",
  20. " mlflow.log_metric(\"val_accuracy\", metrics[\"val_accuracy\"], step=epoch)\n",
  21. " mlflow.log_metric(\"val_loss\", metrics[\"val_loss\"], step=epoch)\n",
  22. " mlflow.log_metric(\"learning_rate\", metrics[\"learning_rate\"], step=epoch)\n",
  23. " mlflow.log_param(\"model\", \"resnet50_callbacks_5layers_unfreeze\")"
  24. ]
  25. },
  26. {
  27. "cell_type": "code",
  28. "execution_count": 1,
  29. "metadata": {},
  30. "outputs": [],
  31. "source": [
  32. "epochs_data = {\n",
  33. " 1: {\n",
  34. " \"accuracy\": 0.0847,\n",
  35. " \"loss\": 5.8520,\n",
  36. " \"val_accuracy\": 0.4417,\n",
  37. " \"val_loss\": 3.9198,\n",
  38. " \"learning_rate\": 1.0000e-05\n",
  39. " },\n",
  40. " 2: {\n",
  41. " \"accuracy\": 0.1360,\n",
  42. " \"loss\": 4.2380,\n",
  43. " \"val_accuracy\": 0.5399,\n",
  44. " \"val_loss\": 3.6218,\n",
  45. " \"learning_rate\": 1.0000e-05\n",
  46. " },\n",
  47. " 3: {\n",
  48. " \"accuracy\": 0.2402,\n",
  49. " \"loss\": 3.9099,\n",
  50. " \"val_accuracy\": 0.6810,\n",
  51. " \"val_loss\": 3.2485,\n",
  52. " \"learning_rate\": 1.0000e-05\n",
  53. " },\n",
  54. " 4: {\n",
  55. " \"accuracy\": 0.3660,\n",
  56. " \"loss\": 3.5690,\n",
  57. " \"val_accuracy\": 0.7362,\n",
  58. " \"val_loss\": 2.7518,\n",
  59. " \"learning_rate\": 1.0000e-05\n",
  60. " },\n",
  61. " 5: {\n",
  62. " \"accuracy\": 0.4996,\n",
  63. " \"loss\": 3.1509,\n",
  64. " \"val_accuracy\": 0.7730,\n",
  65. " \"val_loss\": 2.3880,\n",
  66. " \"learning_rate\": 1.0000e-05\n",
  67. " },\n",
  68. " 6: {\n",
  69. " \"accuracy\": 0.6404,\n",
  70. " \"loss\": 2.7446,\n",
  71. " \"val_accuracy\": 0.7914,\n",
  72. " \"val_loss\": 2.2776,\n",
  73. " \"learning_rate\": 1.0000e-05\n",
  74. " },\n",
  75. " 7: {\n",
  76. " \"accuracy\": 0.7238,\n",
  77. " \"loss\": 2.4671,\n",
  78. " \"val_accuracy\": 0.8160,\n",
  79. " \"val_loss\": 2.2215,\n",
  80. " \"learning_rate\": 1.0000e-05\n",
  81. " },\n",
  82. " 8: {\n",
  83. " \"accuracy\": 0.7835,\n",
  84. " \"loss\": 2.2709,\n",
  85. " \"val_accuracy\": 0.8221,\n",
  86. " \"val_loss\": 2.1896,\n",
  87. " \"learning_rate\": 1.0000e-05\n",
  88. " },\n",
  89. " 9: {\n",
  90. " \"accuracy\": 0.8369,\n",
  91. " \"loss\": 2.0898,\n",
  92. " \"val_accuracy\": 0.8344,\n",
  93. " \"val_loss\": 2.1191,\n",
  94. " \"learning_rate\": 1.0000e-05\n",
  95. " },\n",
  96. " 10: {\n",
  97. " \"accuracy\": 0.8759,\n",
  98. " \"loss\": 1.9587,\n",
  99. " \"val_accuracy\": 0.8589,\n",
  100. " \"val_loss\": 2.1327,\n",
  101. " \"learning_rate\": 1.0000e-05\n",
  102. " },\n",
  103. " 11: {\n",
  104. " \"accuracy\": 0.9061,\n",
  105. " \"loss\": 1.8641,\n",
  106. " \"val_accuracy\": 0.8528,\n",
  107. " \"val_loss\": 2.0924,\n",
  108. " \"learning_rate\": 1.0000e-05\n",
  109. " },\n",
  110. " 12: {\n",
  111. " \"accuracy\": 0.9285,\n",
  112. " \"loss\": 1.7836,\n",
  113. " \"val_accuracy\": 0.8405,\n",
  114. " \"val_loss\": 2.1150,\n",
  115. " \"learning_rate\": 1.0000e-05\n",
  116. " },\n",
  117. " 13: {\n",
  118. " \"accuracy\": 0.9438,\n",
  119. " \"loss\": 1.7205,\n",
  120. " \"val_accuracy\": 0.8466,\n",
  121. " \"val_loss\": 2.0380,\n",
  122. " \"learning_rate\": 1.0000e-05\n",
  123. " },\n",
  124. " 14: {\n",
  125. " \"accuracy\": 0.9596,\n",
  126. " \"loss\": 1.6534,\n",
  127. " \"val_accuracy\": 0.8405,\n",
  128. " \"val_loss\": 2.0600,\n",
  129. " \"learning_rate\": 1.0000e-05\n",
  130. " },\n",
  131. " 15: {\n",
  132. " \"accuracy\": 0.9639,\n",
  133. " \"loss\": 1.6170,\n",
  134. " \"val_accuracy\": 0.8528,\n",
  135. " \"val_loss\": 2.0319,\n",
  136. " \"learning_rate\": 1.0000e-05\n",
  137. " },\n",
  138. " 16: {\n",
  139. " \"accuracy\": 0.9740,\n",
  140. " \"loss\": 1.5687,\n",
  141. " \"val_accuracy\": 0.8528,\n",
  142. " \"val_loss\": 2.0340,\n",
  143. " \"learning_rate\": 1.0000e-05\n",
  144. " },\n",
  145. " 17: {\n",
  146. " \"accuracy\": 0.9745,\n",
  147. " \"loss\": 1.5440,\n",
  148. " \"val_accuracy\": 0.8405,\n",
  149. " \"val_loss\": 2.0015,\n",
  150. " \"learning_rate\": 1.0000e-05\n",
  151. " },\n",
  152. " 18: {\n",
  153. " \"accuracy\": 0.9821,\n",
  154. " \"loss\": 1.5040,\n",
  155. " \"val_accuracy\": 0.8650,\n",
  156. " \"val_loss\": 1.9666,\n",
  157. " \"learning_rate\": 1.0000e-05\n",
  158. " },\n",
  159. " 19: {\n",
  160. " \"accuracy\": 0.9877,\n",
  161. " \"loss\": 1.4689,\n",
  162. " \"val_accuracy\": 0.8650,\n",
  163. " \"val_loss\": 1.9311,\n",
  164. " \"learning_rate\": 1.0000e-05\n",
  165. " },\n",
  166. " 20: {\n",
  167. " \"accuracy\": 0.9911,\n",
  168. " \"loss\": 1.4414,\n",
  169. " \"val_accuracy\": 0.8773,\n",
  170. " \"val_loss\": 1.9213,\n",
  171. " \"learning_rate\": 1.0000e-05\n",
  172. " },\n",
  173. " 21: {\n",
  174. " \"accuracy\": 0.9876,\n",
  175. " \"loss\": 1.4215,\n",
  176. " \"val_accuracy\": 0.8528,\n",
  177. " \"val_loss\": 1.9028,\n",
  178. " \"learning_rate\": 1.0000e-05\n",
  179. " },\n",
  180. " 22: {\n",
  181. " \"accuracy\": 0.9938,\n",
  182. " \"loss\": 1.3837,\n",
  183. " \"val_accuracy\": 0.8773,\n",
  184. " \"val_loss\": 1.8655,\n",
  185. " \"learning_rate\": 1.0000e-05\n",
  186. " },\n",
  187. " 23: {\n",
  188. " \"accuracy\": 0.9908,\n",
  189. " \"loss\": 1.3609,\n",
  190. " \"val_accuracy\": 0.8528,\n",
  191. " \"val_loss\": 1.9050,\n",
  192. " \"learning_rate\": 1.0000e-05\n",
  193. " },\n",
  194. " 24: {\n",
  195. " \"accuracy\": 0.9949,\n",
  196. " \"loss\": 1.3233,\n",
  197. " \"val_accuracy\": 0.8589,\n",
  198. " \"val_loss\": 1.8978,\n",
  199. " \"learning_rate\": 1.0000e-05\n",
  200. " },\n",
  201. " 25: {\n",
  202. " \"accuracy\": 0.9954,\n",
  203. " \"loss\": 1.2955,\n",
  204. " \"val_accuracy\": 0.8650,\n",
  205. " \"val_loss\": 1.8539,\n",
  206. " \"learning_rate\": 2.0000e-06\n",
  207. " },\n",
  208. " 26: {\n",
  209. " \"accuracy\": 0.9962,\n",
  210. " \"loss\": 1.2885,\n",
  211. " \"val_accuracy\": 0.8712,\n",
  212. " \"val_loss\": 1.8314,\n",
  213. " \"learning_rate\": 2.0000e-06\n",
  214. " },\n",
  215. " 27: {\n",
  216. " \"accuracy\": 0.9958,\n",
  217. " \"loss\": 1.2817,\n",
  218. " \"val_accuracy\": 0.8650,\n",
  219. " \"val_loss\": 1.8324,\n",
  220. " \"learning_rate\": 2.0000e-06\n",
  221. " },\n",
  222. " 28: {\n",
  223. " \"accuracy\": 0.9976,\n",
  224. " \"loss\": 1.2710,\n",
  225. " \"val_accuracy\": 0.8589,\n",
  226. " \"val_loss\": 1.8316,\n",
  227. " \"learning_rate\": 2.0000e-06\n",
  228. " },\n",
  229. " 29: {\n",
  230. " \"accuracy\": 0.9963,\n",
  231. " \"loss\": 1.2652,\n",
  232. " \"val_accuracy\": 0.8650,\n",
  233. " \"val_loss\": 1.8263,\n",
  234. " \"learning_rate\": 1.0000e-06\n",
  235. " },\n",
  236. " 30: {\n",
  237. " \"accuracy\": 0.9975,\n",
  238. " \"loss\": 1.2591,\n",
  239. " \"val_accuracy\": 0.8712,\n",
  240. " \"val_loss\": 1.8105,\n",
  241. " \"learning_rate\": 1.0000e-06\n",
  242. " },\n",
  243. " 31: {\n",
  244. " \"accuracy\": 0.9979,\n",
  245. " \"loss\": 1.2559,\n",
  246. " \"val_accuracy\": 0.8650,\n",
  247. " \"val_loss\": 1.8199,\n",
  248. " \"learning_rate\": 1.0000e-06\n",
  249. " },\n",
  250. " 32: {\n",
  251. " \"accuracy\": 0.9978,\n",
  252. " \"loss\": 1.2485,\n",
  253. " \"val_accuracy\": 0.8650,\n",
  254. " \"val_loss\": 1.8312,\n",
  255. " \"learning_rate\": 1.0000e-06\n",
  256. " },\n",
  257. " 33: {\n",
  258. " \"accuracy\": 0.9972,\n",
  259. " \"loss\": 1.2443,\n",
  260. " \"val_accuracy\": 0.8528,\n",
  261. " \"val_loss\": 1.8201,\n",
  262. " \"learning_rate\": 1.0000e-06\n",
  263. " },\n",
  264. " 34: {\n",
  265. " \"accuracy\": 0.9980,\n",
  266. " \"loss\": 1.2402,\n",
  267. " \"val_accuracy\": 0.8650,\n",
  268. " \"val_loss\": 1.8054,\n",
  269. " \"learning_rate\": 1.0000e-06\n",
  270. " },\n",
  271. " 35: {\n",
  272. " \"accuracy\": 0.9987,\n",
  273. " \"loss\": 1.2324,\n",
  274. " \"val_accuracy\": 0.8528,\n",
  275. " \"val_loss\": 1.7969,\n",
  276. " \"learning_rate\": 1.0000e-06\n",
  277. " },\n",
  278. " 36: {\n",
  279. " \"accuracy\": 0.9985,\n",
  280. " \"loss\": 1.2261,\n",
  281. " \"val_accuracy\": 0.8589,\n",
  282. " \"val_loss\": 1.7965,\n",
  283. " \"learning_rate\": 1.0000e-06\n",
  284. " },\n",
  285. " 37: {\n",
  286. " \"accuracy\": 0.9986,\n",
  287. " \"loss\": 1.2198,\n",
  288. " \"val_accuracy\": 0.8589,\n",
  289. " \"val_loss\": 1.7878,\n",
  290. " \"learning_rate\": 1.0000e-06\n",
  291. " },\n",
  292. " 38: {\n",
  293. " \"accuracy\": 0.9990,\n",
  294. " \"loss\": 1.2131,\n",
  295. " \"val_accuracy\": 0.8650,\n",
  296. " \"val_loss\": 1.7906,\n",
  297. " \"learning_rate\": 1.0000e-06\n",
  298. " },\n",
  299. " 39: {\n",
  300. " \"accuracy\": 0.9977,\n",
  301. " \"loss\": 1.2106,\n",
  302. " \"val_accuracy\": 0.8589,\n",
  303. " \"val_loss\": 1.7810,\n",
  304. " \"learning_rate\": 1.0000e-06\n",
  305. " }\n",
  306. "}\n"
  307. ]
  308. },
  309. {
  310. "cell_type": "code",
  311. "execution_count": null,
  312. "metadata": {},
  313. "outputs": [],
  314. "source": [
  315. "import json\n",
  316. "import tensorflow as tf\n",
  317. "from tensorflow.keras.applications import ResNet50\n",
  318. "from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout\n",
  319. "from tensorflow.keras.models import Model\n",
  320. "from tensorflow.keras.optimizers import Adam\n",
  321. "from tensorflow.keras.regularizers import l2\n",
  322. "from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau\n",
  323. "\n",
  324. "# Define paths to directories\n",
  325. "train_dir = 'C:/Users/vsavelev/GITHUB/DS_projet/LAYER2/MO/train'\n",
  326. "validation_dir = 'C:/Users/vsavelev/GITHUB/DS_projet/LAYER2/MO/validation'\n",
  327. "\n",
  328. "# Create Dataset from directories\n",
  329. "def create_dataset(directory, batch_size):\n",
  330. " dataset = tf.keras.preprocessing.image_dataset_from_directory(\n",
  331. " directory,\n",
  332. " image_size=(224, 224),\n",
  333. " batch_size=batch_size,\n",
  334. " label_mode='categorical', # Labels are one-hot encoded\n",
  335. " shuffle=True\n",
  336. " )\n",
  337. " return dataset\n",
  338. "\n",
  339. "# Load data from directories\n",
  340. "batch_size = 32\n",
  341. "train_dataset = create_dataset(train_dir, batch_size)\n",
  342. "validation_dataset = create_dataset(validation_dir, batch_size)\n",
  343. "\n",
  344. "# Build ResNet50 Model\n",
  345. "base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))\n",
  346. "x = base_model.output\n",
  347. "x = GlobalAveragePooling2D()(x)\n",
  348. "x = Dropout(0.7)(x)\n",
  349. "x = Dense(1024, activation='relu', kernel_regularizer=l2(0.001))(x)\n",
  350. "x = Dropout(0.7)(x)\n",
  351. "predictions = Dense(len(train_dataset.class_names), activation='softmax', kernel_regularizer=l2(0.01))(x)\n",
  352. "\n",
  353. "model = Model(inputs=base_model.input, outputs=predictions)\n",
  354. "\n",
  355. "# Unfreeze the last few layers\n",
  356. "for layer in base_model.layers[-5:]:\n",
  357. " layer.trainable = True\n",
  358. "\n",
  359. "# Compile the model\n",
  360. "model.compile(optimizer=Adam(learning_rate=1e-5),\n",
  361. " loss='categorical_crossentropy',\n",
  362. " metrics=['accuracy'])\n",
  363. "\n",
  364. "# Callbacks for early stopping and learning rate reduction\n",
  365. "early_stopping = EarlyStopping(\n",
  366. " monitor='val_loss', \n",
  367. " patience=3, \n",
  368. " restore_best_weights=True \n",
  369. ")\n",
  370. "\n",
  371. "reduce_lr = ReduceLROnPlateau(\n",
  372. " monitor='val_loss', \n",
  373. " factor=0.2, \n",
  374. " patience=2, \n",
  375. " min_lr=1e-6 \n",
  376. ")\n",
  377. "\n",
  378. "# Custom callback to save metrics per epoch to a JSON file\n",
  379. "class MetricsLogger(tf.keras.callbacks.Callback):\n",
  380. " def __init__(self, log_file='training_metrics.json'):\n",
  381. " super().__init__()\n",
  382. " self.log_file = log_file\n",
  383. " self.logs = []\n",
  384. "\n",
  385. " def on_epoch_end(self, epoch, logs=None):\n",
  386. " self.logs.append({\n",
  387. " 'epoch': epoch + 1, # Keras passes 0-based epochs; store 1-based for consistency\n",
  388. " 'loss': logs['loss'],\n",
  389. " 'accuracy': logs['accuracy'],\n",
  390. " 'val_loss': logs['val_loss'],\n",
  391. " 'val_accuracy': logs['val_accuracy']\n",
  392. " })\n",
  393. " with open(self.log_file, 'w') as f:\n",
  394. " json.dump(self.logs, f, indent=4)\n",
  395. "\n",
  396. "# Train the model and log metrics\n",
  397. "metrics_logger = MetricsLogger(log_file='training_metrics.json')\n",
  398. "\n",
  399. "history = model.fit(\n",
  400. " train_dataset,\n",
  401. " epochs=50,\n",
  402. " validation_data=validation_dataset,\n",
  403. " callbacks=[reduce_lr, early_stopping, metrics_logger]\n",
  404. ")\n",
  405. "\n",
  406. "# Save the model\n",
  407. "model.save('resnet50_model_callbacks.h5')\n"
  408. ]
  409. },
  410. {
  411. "cell_type": "code",
  412. "execution_count": null,
  413. "metadata": {},
  414. "outputs": [],
  415. "source": [
  416. "import mlflow\n",
  417. "import mlflow.keras\n",
  418. "import json\n",
  419. "import tensorflow as tf\n",
  419. "import matplotlib.pyplot as plt\n",
  420. "\n",
  421. "# Function to load metrics from JSON\n",
  422. "def load_metrics_from_json(log_file='training_metrics.json'):\n",
  423. " with open(log_file, 'r') as f:\n",
  424. " return json.load(f)\n",
  425. "\n",
  426. "# Function to plot training history\n",
  427. "def plot_training_history(metrics):\n",
  428. " epochs = [m['epoch'] for m in metrics]\n",
  429. " train_accuracy = [m['accuracy'] for m in metrics]\n",
  430. " val_accuracy = [m['val_accuracy'] for m in metrics]\n",
  431. " train_loss = [m['loss'] for m in metrics]\n",
  432. " val_loss = [m['val_loss'] for m in metrics]\n",
  433. " \n",
  434. " plt.figure(figsize=(12, 8))\n",
  435. " \n",
  436. " # Plot training & validation accuracy values\n",
  437. " plt.subplot(1, 2, 1)\n",
  438. " plt.plot(epochs, train_accuracy)\n",
  439. " plt.plot(epochs, val_accuracy)\n",
  440. " plt.title('Model accuracy')\n",
  441. " plt.ylabel('Accuracy')\n",
  442. " plt.xlabel('Epoch')\n",
  443. " plt.legend(['Train', 'Validation'], loc='upper left')\n",
  444. "\n",
  445. " # Plot training & validation loss values\n",
  446. " plt.subplot(1, 2, 2)\n",
  447. " plt.plot(epochs, train_loss)\n",
  448. " plt.plot(epochs, val_loss)\n",
  449. " plt.title('Model loss')\n",
  450. " plt.ylabel('Loss')\n",
  451. " plt.xlabel('Epoch')\n",
  452. " plt.legend(['Train', 'Validation'], loc='upper left')\n",
  453. "\n",
  454. " plt.savefig(\"training_history.png\")\n",
  455. " plt.show()\n",
  456. "\n",
  457. "# Load metrics and log to MLflow\n",
  458. "with mlflow.start_run():\n",
  459. " # Load metrics from JSON\n",
  460. " metrics = load_metrics_from_json('training_metrics.json')\n",
  461. "\n",
  462. " # Log metrics for each epoch\n",
  463. " for m in metrics:\n",
  464. " mlflow.log_metric('train_loss', m['loss'], step=m['epoch'])\n",
  465. " mlflow.log_metric('train_accuracy', m['accuracy'], step=m['epoch'])\n",
  466. " mlflow.log_metric('val_loss', m['val_loss'], step=m['epoch'])\n",
  467. " mlflow.log_metric('val_accuracy', m['val_accuracy'], step=m['epoch'])\n",
  468. "\n",
  469. " # Log hyperparameters\n",
  470. " mlflow.log_param(\"learning_rate\", 1e-5)\n",
  471. " mlflow.log_param(\"batch_size\", 32)\n",
  472. " mlflow.log_param(\"dropout_rate\", 0.7)\n",
  473. " mlflow.log_param(\"l2_regularization\", 0.01)\n",
  474. " \n",
  475. " # Plot and log training history\n",
  476. " plot_training_history(metrics)\n",
  477. " mlflow.log_artifact(\"training_history.png\")\n",
  478. "\n",
  479. " # Log the trained model\n",
  480. " mlflow.keras.log_model(tf.keras.models.load_model('resnet50_model_callbacks.h5'), \"resnet50_model\")\n",
  481. "\n"
  482. ]
  483. }
  484. ],
  485. "metadata": {
  486. "kernelspec": {
  487. "display_name": "base",
  488. "language": "python",
  489. "name": "python3"
  490. },
  491. "language_info": {
  492. "codemirror_mode": {
  493. "name": "ipython",
  494. "version": 3
  495. },
  496. "file_extension": ".py",
  497. "mimetype": "text/x-python",
  498. "name": "python",
  499. "nbconvert_exporter": "python",
  500. "pygments_lexer": "ipython3",
  501. "version": "3.11.7"
  502. }
  503. },
  504. "nbformat": 4,
  505. "nbformat_minor": 2
  506. }