From 23ccddc0972402bfcdd3ca9f3a15ef67a59e1100 Mon Sep 17 00:00:00 2001 From: Will Dirschka Date: Mon, 1 Mar 2021 22:40:20 -0500 Subject: [PATCH] removed older file --- digits_example.ipynb | 479 ------------------------------------------- 1 file changed, 479 deletions(-) delete mode 100644 digits_example.ipynb diff --git a/digits_example.ipynb b/digits_example.ipynb deleted file mode 100644 index 8a2700a..0000000 --- a/digits_example.ipynb +++ /dev/null @@ -1,479 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 16, - "id": "senior-message", - "metadata": {}, - "outputs": [], - "source": [ - "import sklearn\n", - "import torch\n", - "import pandas\n", - "\n", - "from sklearn import datasets\n", - "import numpy as np\n", - "\n", - "import matplotlib.pyplot as plt" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "decimal-reviewer", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "dict_keys(['data', 'target', 'frame', 'feature_names', 'target_names', 'images', 'DESCR'])" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data = datasets.load_digits()\n", - "print(data.keys())\n", - "print(data[\"data\"][0])\n", - "print(np.max(data[\"data\"]))" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "id": "sporting-dealer", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "5\n" - ] - }, - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 36, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAKmElEQVR4nO3d34tc9RnH8c+nG6W1plloYpFsyORCAlJoIktAUqKJWGIVNxe9SEChUvCmSpYWRHvXf8DYiyJI1AqmShsVRKxW0LUVWmsSd1uTNSUNG7JBm4Qa/HHREH16sROIsnbPzJxf8/h+weL+GPb7jO7bM3N29nwdEQKQx9eaHgBAuYgaSIaogWSIGkiGqIFkllXxTVeuXBmdTqeKb92oo0eP1rre8uXLa12vLsuWVfJjt6irrrqqtrXqNDc3p7Nnz3qxr1Xyb7fT6ejAgQNVfOtGbd26tdb1brzxxlrXq8vo6Ghta+3evbu2teo0Pj7+pV/j4TeQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kEyhqG1vt33U9jHb91c9FID+LRm17RFJv5Z0i6RrJe2yfW3VgwHoT5Ej9SZJxyLieEScl/S0pIlqxwLQryJRr5Z08pKP57uf+xzbd9s+YPvAmTNnypoPQI9KO1EWEY9ExHhEjK9ataqsbwugR0WiPiVpzSUfj3U/B6CFikT9lqRrbK+zfbmknZKer3YsAP1a8iIJEXHB9j2SXpY0IumxiDhc+WQA+lLoyicR8aKkFyueBUAJeEUZkAxRA8kQNZAMUQPJEDWQDFEDyRA1kEx9+59UZGZmpra1pqamalur7vUmJur7w7u6dzr5quFIDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkV26HjM9mnb79QxEIDBFDlS/0bS9ornAFCSJaOOiD9J+k8NswAoQWnPqdl2B2gHtt0BkuHsN5AMUQPJFPmV1lOS/iJpve152z+pfiwA/Sqyl9auOgYBUA4efgPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJDP22O6Ojo7Wt1el0altLkqanp2tba8WKFbWthWpxpAaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIJki1yhbY/s120dsH7a9u47BAPSnyGu/L0j6eUQcsr1c0kHbr0TEkYpnA9CHItvuvBcRh7rvfyRpVtLqqgcD0J+enlPb7kjaKOnNRb7GtjtACxSO2vaVkp6RNBkRH37x62y7A7RDoahtX6aFoPdFxLPVjgRgEEXOflvSo5JmI+LB6kcCMIgiR+rNku6UtM32dPfthxXPBaBPRbbdeUOSa5gFQAl4RRmQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyQz9Xlp17jdVt6mpqdrWmpiYqG0tVIsjNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQTJELD37d9t9sz3S33fllHYMB6E+Rl4n+V9K2iPi4e6ngN2z/ISL+WvFsAPpQ5MKDIenj7oeXdd+iyqEA9K/oxfxHbE9LOi3plYhg2x2gpQpFHRGfRsQGSWOSNtn+7iK3YdsdoAV6OvsdEeckvSZpeyXTABhYkbPfq2yPdt//hqSbJb1b8VwA+lTk7PfVkp6wPaKF/wn8LiJeqHYsAP0qcvb771rYkxrAEOAVZUAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kM/Tb7mzYsKG2tc6dO1fbWpK0Y8eO2taanJysba09e/bUttZXEUdqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSKRx194L+b9vmooNAi/VypN4tabaqQQCUo+i2O2OSbpW0t9pxAAyq6JH6IUn3Sfrsy27AXlpAOxTZoeM2Sacj4uD/ux17aQHtUORIvVnS7bbnJD0taZvtJyudCkDflow6Ih6IiLGI6EjaKenViLij8skA9IXfUwPJ9HQ5o4iYkjRVySQASsGRGkiGqIFkiBpIhqiBZIgaSIaogWSIGkhm6LfdWbt2bW1rffDBB7WtJUknTpyoba1Op1PbWnVu8VPnz0dbcKQGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiCZQi8T7V5J9CNJn0q6EBHjVQ4FoH+9vPZ7a0ScrWwSAKXg4TeQTNGoQ9IfbR+0ffdiN2DbHaAdikb9/Yi4TtItkn5qe8sXb8C2O0A7FIo6Ik51/3la0nOSNlU5FID+Fdkg75u2l198X9IPJL1T9WAA+lPk7Pd3JD1n++LtfxsRL1U6FYC+LRl1RByX9L0aZgFQAn6lBSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSQz9Nvu1On111+vdb2pqala16vL3NxcbWux7Q6AoUfUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyhaK2PWp7v+13bc/avr7qwQD0p+hrv38l6aWI+JHtyyVdUeFMAAawZNS2V0jaIunHkhQR5yWdr3YsAP0q8vB7naQzkh63/bbtvd3rf38O2+4A7VAk6mWSrpP0cERslPSJpPu/eCO23QHaoUjU85LmI+LN7sf7tRA5gBZaMuqIeF/SSdvru5+6SdKRSqcC0LeiZ7/vlbSve+b7uKS7qhsJwCAKRR0R05LGqx0FQBl4RRmQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDybCXVg/27NlT63ozMzO1rTU5OVnbWjfccENta30VcaQGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpJZMmrb621PX/L2oe3JGmYD0IclXyYaEUclbZAk2yOSTkl6rtqxAPSr14ffN0n6V0ScqGIYAIPrNeqdkp5a7AtsuwO0Q+Gou9f8vl3S7xf7OtvuAO3Qy5H6FkmHIuLfVQ0DYHC9RL1LX/LQG0B7FIq6u3XtzZKerXYcAIMquu3OJ5K+XfEsAErAK8qAZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSMYRUf43tc9I6vXPM1dKOlv6MO2Q9b5xv5qzNiIW/cupSqLuh+0DETHe9BxVyHrfuF/txMNvIBmiBpJpU9SPND1AhbLeN+5XC7XmOTWAcrTpSA2gBEQNJNOKqG1vt33U9jHb9zc9Txlsr7H9mu0jtg/b3t30TGWyPWL7bdsvND1LmWyP2t5v+13bs7avb3qmXjX+nLq7QcA/tXC5pHlJb0naFRFHGh1sQLavlnR1RByyvVzSQUk7hv1+XWT7Z5LGJX0rIm5rep6y2H5C0p8jYm/3CrpXRMS5hsfqSRuO1JskHYuI4xFxXtLTkiYanmlgEfFeRBzqvv+RpFlJq5udqhy2xyTdKmlv07OUyfYKSVskPSpJEXF+2IKW2hH1akknL/l4Xkl++C+y3ZG0UdKbDY9Slock3Sfps4bnKNs6SWckPd59arG3e9HNodKGqFOzfaWkZyRNRsSHTc8zKNu3STodEQebnqUCyyRdJ+nhiNgo6RNJQ3eOpw1Rn5K05pKPx7qfG3q2L9NC0PsiIsvllTdLut32nBaeKm2z/WSzI5VmXtJ8RFx8RLVfC5EPlTZE/Zaka2yv656Y2Cnp+YZnGphta+G52WxEPNj0PGWJiAciYiwiOlr4b/VqRNzR8FiliIj3JZ20vb77qZskDd2JzULX/a5SRFywfY+klyWNSHosIg43PFYZNku6U9I/bE93P/eLiHixuZFQwL2S9nUPMMcl3dXwPD1r/FdaAMrVhoffAEpE1EAyRA0kQ9RAMkQNJEPUQDJEDSTzP9kTn3l/eNyAAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "index = 25\n", - "print(data[\"target\"][index])\n", - "plt.imshow(16 - data[\"data\"][index].reshape(8,8), cmap=\"gray\", vmin=0, vmax=16)" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "id": "inclusive-registration", - "metadata": {}, - "outputs": [], - "source": [ - "from skorch.regressor import NeuralNetRegressor\n", - "from skorch.utils import to_tensor, to_numpy\n", - "from skorch.net import NeuralNet\n", - "\n", - "import torch\n", - "import torch.nn as nn" - ] - }, - { - "cell_type": "code", - "execution_count": 98, - "id": "sitting-truth", - "metadata": {}, - "outputs": [], - "source": [ - "# This is a Neural Network where it'll be a CNN followed by a fully connected layer\n", - "class Example(nn.Module):\n", - " def __init__(self):\n", - " super(Example, self).__init__()\n", - " \"\"\"\n", - " CNN = nn.Sequential(\n", - " nn.Conv2d(1,16, kernel_size=3, padding=1, stride=1),\n", - " nn.ReLU()\n", - " )\n", - " \"\"\"\n", - " self.FINAL = nn.Sequential(\n", - " nn.Linear(64,64),\n", - " nn.Linear(64,64),\n", - " nn.Linear(64,10),\n", - " nn.ReLU()\n", - " )\n", - " \n", - " def forward(self, x):\n", - " x = self.FINAL(x)\n", - " return x" - ] - }, - { - "cell_type": "code", - "execution_count": 99, - "id": "dense-directory", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[1.0809, 1.5086, 0.0000, ..., 0.0000, 0.3085, 0.0000],\n", - " [0.0878, 1.7563, 0.6057, ..., 0.0000, 0.3286, 0.0000],\n", - " [0.0000, 1.2282, 0.5646, ..., 0.2210, 0.6609, 0.0000],\n", - " ...,\n", - " [0.7648, 2.1294, 0.0000, ..., 0.0000, 0.5024, 0.0000],\n", - " [0.5641, 2.7074, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [0.3867, 1.7069, 0.0000, ..., 0.0000, 0.6122, 0.0755]],\n", - " grad_fn=)" - ] - }, - "execution_count": 99, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "Example().forward(to_tensor(data[\"data\"].astype(np.single), device=\"cpu\"))" - ] - }, - { - "cell_type": "code", - "execution_count": 100, - "id": "wired-helmet", - "metadata": {}, - "outputs": [], - "source": [ - "true_target = np.zeros((len(data[\"target\"]), 10)).astype(np.single)\n", - "for idx, thing in enumerate(data[\"target\"]):\n", - " true_target[idx][thing] = 1\n", - " \n", - "true_feature = to_tensor(data[\"data\"].astype(np.single), device=\"cpu\")\n", - "true_target = to_tensor(true_target, device=\"cpu\")" - ] - }, - { - "cell_type": "code", - "execution_count": 101, - "id": "checked-highlight", - "metadata": {}, - "outputs": [], - "source": [ - "simple_nn = NeuralNetRegressor(\n", - " Example,\n", - " max_epochs=1,\n", - " batch_size=100,\n", - " warm_start=True,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 102, - "id": "inner-mission", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " epoch train_loss valid_loss dur\n", - "------- ------------ ------------ ------\n", - " 1 \u001b[36m0.1681\u001b[0m \u001b[32m0.1033\u001b[0m 0.0689\n", - " 2 \u001b[36m0.0932\u001b[0m \u001b[32m0.0961\u001b[0m 0.0689\n", - " 3 \u001b[36m0.0863\u001b[0m \u001b[32m0.0912\u001b[0m 0.0605\n", - " 4 \u001b[36m0.0823\u001b[0m \u001b[32m0.0879\u001b[0m 0.0690\n", - " 5 \u001b[36m0.0793\u001b[0m \u001b[32m0.0855\u001b[0m 0.0657\n", - " 6 \u001b[36m0.0767\u001b[0m \u001b[32m0.0837\u001b[0m 0.0724\n", - " 7 \u001b[36m0.0741\u001b[0m \u001b[32m0.0818\u001b[0m 0.0692\n", - " 8 \u001b[36m0.0721\u001b[0m \u001b[32m0.0801\u001b[0m 0.0693\n", - " 9 \u001b[36m0.0706\u001b[0m \u001b[32m0.0786\u001b[0m 0.0705\n", - " 10 \u001b[36m0.0695\u001b[0m \u001b[32m0.0773\u001b[0m 0.0718\n", - " 11 \u001b[36m0.0684\u001b[0m \u001b[32m0.0762\u001b[0m 0.0713\n", - " 12 \u001b[36m0.0676\u001b[0m \u001b[32m0.0753\u001b[0m 0.0638\n", - " 13 \u001b[36m0.0669\u001b[0m \u001b[32m0.0744\u001b[0m 0.0698\n", - " 14 \u001b[36m0.0662\u001b[0m \u001b[32m0.0735\u001b[0m 0.1182\n", - " 15 \u001b[36m0.0657\u001b[0m \u001b[32m0.0729\u001b[0m 0.0730\n", - " 16 \u001b[36m0.0653\u001b[0m \u001b[32m0.0723\u001b[0m 0.0708\n", - " 17 \u001b[36m0.0648\u001b[0m \u001b[32m0.0718\u001b[0m 0.0712\n", - " 18 \u001b[36m0.0644\u001b[0m \u001b[32m0.0713\u001b[0m 0.0778\n", - " 19 \u001b[36m0.0641\u001b[0m \u001b[32m0.0708\u001b[0m 0.0741\n", - " 20 \u001b[36m0.0638\u001b[0m \u001b[32m0.0704\u001b[0m 0.0763\n", - " 21 \u001b[36m0.0635\u001b[0m \u001b[32m0.0701\u001b[0m 0.0723\n", - " 22 \u001b[36m0.0633\u001b[0m \u001b[32m0.0697\u001b[0m 0.0735\n", - " 23 \u001b[36m0.0631\u001b[0m \u001b[32m0.0694\u001b[0m 0.0733\n", - " 24 \u001b[36m0.0628\u001b[0m \u001b[32m0.0690\u001b[0m 0.0694\n", - " 25 \u001b[36m0.0626\u001b[0m \u001b[32m0.0687\u001b[0m 0.0668\n", - " 26 \u001b[36m0.0625\u001b[0m \u001b[32m0.0684\u001b[0m 0.0716\n", - " 27 \u001b[36m0.0623\u001b[0m \u001b[32m0.0681\u001b[0m 0.0782\n", - " 28 \u001b[36m0.0621\u001b[0m \u001b[32m0.0678\u001b[0m 0.0778\n", - " 29 \u001b[36m0.0619\u001b[0m \u001b[32m0.0675\u001b[0m 0.0745\n", - " 30 \u001b[36m0.0617\u001b[0m \u001b[32m0.0675\u001b[0m 0.0782\n", - " 31 \u001b[36m0.0615\u001b[0m \u001b[32m0.0672\u001b[0m 0.0736\n", - " 32 \u001b[36m0.0613\u001b[0m \u001b[32m0.0671\u001b[0m 0.0755\n", - " 33 \u001b[36m0.0611\u001b[0m \u001b[32m0.0668\u001b[0m 0.0786\n", - " 34 \u001b[36m0.0609\u001b[0m \u001b[32m0.0666\u001b[0m 0.0715\n", - " 35 \u001b[36m0.0607\u001b[0m \u001b[32m0.0663\u001b[0m 0.0645\n", - " 36 \u001b[36m0.0606\u001b[0m \u001b[32m0.0661\u001b[0m 0.0723\n", - " 37 \u001b[36m0.0604\u001b[0m \u001b[32m0.0659\u001b[0m 0.0773\n", - " 38 \u001b[36m0.0603\u001b[0m \u001b[32m0.0657\u001b[0m 0.0786\n", - " 39 \u001b[36m0.0601\u001b[0m \u001b[32m0.0655\u001b[0m 0.0737\n", - " 40 \u001b[36m0.0600\u001b[0m \u001b[32m0.0653\u001b[0m 0.0753\n", - " 41 \u001b[36m0.0598\u001b[0m \u001b[32m0.0651\u001b[0m 0.0741\n", - " 42 \u001b[36m0.0597\u001b[0m \u001b[32m0.0649\u001b[0m 0.0754\n", - " 43 \u001b[36m0.0595\u001b[0m \u001b[32m0.0647\u001b[0m 0.0702\n", - " 44 \u001b[36m0.0594\u001b[0m \u001b[32m0.0645\u001b[0m 0.0788\n", - " 45 \u001b[36m0.0593\u001b[0m \u001b[32m0.0643\u001b[0m 0.0777\n", - " 46 \u001b[36m0.0591\u001b[0m \u001b[32m0.0641\u001b[0m 0.0811\n", - " 47 \u001b[36m0.0589\u001b[0m \u001b[32m0.0639\u001b[0m 0.0734\n", - " 48 \u001b[36m0.0588\u001b[0m \u001b[32m0.0637\u001b[0m 0.0825\n", - " 49 \u001b[36m0.0587\u001b[0m \u001b[32m0.0635\u001b[0m 0.0735\n", - " 50 \u001b[36m0.0586\u001b[0m 0.0636 0.0745\n", - " 51 \u001b[36m0.0583\u001b[0m \u001b[32m0.0632\u001b[0m 0.0740\n", - " 52 \u001b[36m0.0580\u001b[0m \u001b[32m0.0629\u001b[0m 0.0748\n", - " 53 \u001b[36m0.0578\u001b[0m \u001b[32m0.0624\u001b[0m 0.0785\n", - " 54 \u001b[36m0.0574\u001b[0m \u001b[32m0.0614\u001b[0m 0.0786\n", - " 55 \u001b[36m0.0570\u001b[0m \u001b[32m0.0605\u001b[0m 0.0752\n", - " 56 \u001b[36m0.0564\u001b[0m \u001b[32m0.0597\u001b[0m 0.0744\n", - " 57 \u001b[36m0.0554\u001b[0m \u001b[32m0.0585\u001b[0m 0.0735\n", - " 58 \u001b[36m0.0545\u001b[0m \u001b[32m0.0572\u001b[0m 0.0718\n", - " 59 \u001b[36m0.0536\u001b[0m \u001b[32m0.0564\u001b[0m 0.0748\n", - " 60 \u001b[36m0.0529\u001b[0m \u001b[32m0.0558\u001b[0m 0.0720\n", - " 61 \u001b[36m0.0522\u001b[0m \u001b[32m0.0553\u001b[0m 0.0671\n", - " 62 \u001b[36m0.0518\u001b[0m \u001b[32m0.0548\u001b[0m 0.0661\n", - " 63 \u001b[36m0.0513\u001b[0m \u001b[32m0.0545\u001b[0m 0.0801\n", - " 64 \u001b[36m0.0509\u001b[0m \u001b[32m0.0543\u001b[0m 0.0746\n", - " 65 \u001b[36m0.0506\u001b[0m \u001b[32m0.0541\u001b[0m 0.0782\n", - " 66 \u001b[36m0.0502\u001b[0m \u001b[32m0.0540\u001b[0m 0.0787\n", - " 67 \u001b[36m0.0499\u001b[0m \u001b[32m0.0539\u001b[0m 0.0754\n", - " 68 \u001b[36m0.0495\u001b[0m \u001b[32m0.0538\u001b[0m 0.0692\n", - " 69 \u001b[36m0.0493\u001b[0m \u001b[32m0.0538\u001b[0m 0.0691\n", - " 70 \u001b[36m0.0490\u001b[0m \u001b[32m0.0538\u001b[0m 0.0697\n", - " 71 \u001b[36m0.0488\u001b[0m \u001b[32m0.0538\u001b[0m 0.0704\n", - " 72 \u001b[36m0.0486\u001b[0m \u001b[32m0.0537\u001b[0m 0.0686\n", - " 73 \u001b[36m0.0485\u001b[0m \u001b[32m0.0537\u001b[0m 0.0658\n", - " 74 \u001b[36m0.0483\u001b[0m \u001b[32m0.0536\u001b[0m 0.0679\n", - " 75 \u001b[36m0.0482\u001b[0m \u001b[32m0.0536\u001b[0m 0.0670\n", - " 76 \u001b[36m0.0481\u001b[0m \u001b[32m0.0535\u001b[0m 0.0703\n", - " 77 \u001b[36m0.0480\u001b[0m \u001b[32m0.0535\u001b[0m 0.0650\n", - " 78 \u001b[36m0.0479\u001b[0m \u001b[32m0.0534\u001b[0m 0.0640\n", - " 79 \u001b[36m0.0478\u001b[0m \u001b[32m0.0533\u001b[0m 0.0653\n", - " 80 \u001b[36m0.0477\u001b[0m \u001b[32m0.0533\u001b[0m 0.0650\n", - " 81 \u001b[36m0.0476\u001b[0m \u001b[32m0.0532\u001b[0m 0.0707\n", - " 82 \u001b[36m0.0475\u001b[0m \u001b[32m0.0532\u001b[0m 0.0633\n", - " 83 \u001b[36m0.0474\u001b[0m \u001b[32m0.0531\u001b[0m 0.0641\n", - " 84 \u001b[36m0.0473\u001b[0m \u001b[32m0.0531\u001b[0m 0.0691\n", - " 85 \u001b[36m0.0472\u001b[0m \u001b[32m0.0530\u001b[0m 0.0644\n", - " 86 \u001b[36m0.0471\u001b[0m \u001b[32m0.0530\u001b[0m 0.0658\n", - " 87 \u001b[36m0.0471\u001b[0m 0.0530 0.0696\n", - " 88 \u001b[36m0.0470\u001b[0m \u001b[32m0.0530\u001b[0m 0.0695\n", - " 89 \u001b[36m0.0469\u001b[0m \u001b[32m0.0529\u001b[0m 0.0780\n", - " 90 \u001b[36m0.0468\u001b[0m \u001b[32m0.0529\u001b[0m 0.0805\n", - " 91 \u001b[36m0.0468\u001b[0m \u001b[32m0.0528\u001b[0m 0.0723\n", - " 92 \u001b[36m0.0467\u001b[0m \u001b[32m0.0528\u001b[0m 0.0750\n", - " 93 \u001b[36m0.0466\u001b[0m \u001b[32m0.0528\u001b[0m 0.0852\n", - " 94 \u001b[36m0.0466\u001b[0m \u001b[32m0.0528\u001b[0m 0.0759\n", - " 95 \u001b[36m0.0465\u001b[0m \u001b[32m0.0527\u001b[0m 0.0782\n", - " 96 \u001b[36m0.0465\u001b[0m \u001b[32m0.0527\u001b[0m 0.0705\n", - " 97 \u001b[36m0.0464\u001b[0m \u001b[32m0.0527\u001b[0m 0.0687\n", - " 98 \u001b[36m0.0464\u001b[0m \u001b[32m0.0527\u001b[0m 0.0764\n", - " 99 \u001b[36m0.0463\u001b[0m \u001b[32m0.0526\u001b[0m 0.0765\n", - " 100 \u001b[36m0.0463\u001b[0m \u001b[32m0.0526\u001b[0m 0.0738\n" - ] - }, - { - "data": { - "text/plain": [ - "[initialized](\n", - " module_=Example(\n", - " (FINAL): Sequential(\n", - " (0): Linear(in_features=64, out_features=64, bias=True)\n", - " (1): Linear(in_features=64, out_features=64, bias=True)\n", - " (2): Linear(in_features=64, out_features=10, bias=True)\n", - " (3): ReLU()\n", - " )\n", - " ),\n", - ")" - ] - }, - "execution_count": 102, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "simple_nn.max_epochs=100\n", - "simple_nn.fit(true_feature, true_target)" - ] - }, - { - "cell_type": "code", - "execution_count": 91, - "id": "variable-northern", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameter containing:\n", - "tensor([[-0.0682, 0.0583, -0.0530, 0.0937, -0.1139, 0.0543, -0.0610, 0.0346,\n", - " -0.1029, -0.0529, 0.0331, -0.1198, 0.0675, -0.0422, 0.0507, 0.0855,\n", - " -0.0359, -0.0737, -0.1115, 0.1023, -0.0264, 0.0894, -0.1055, 0.1010,\n", - " 0.1048, 0.0235, -0.0558, 0.0308, -0.0337, -0.0468, -0.0640, 0.0014,\n", - " 0.0745, 0.0759, -0.0443, -0.0605, 0.0350, 0.0552, 0.1105, 0.1237,\n", - " -0.0117, -0.0714, -0.0716, 0.0023, -0.0093, 0.0493, 0.0133, -0.1223,\n", - " -0.0695, -0.0871, -0.1051, -0.0561, -0.0374, 0.0455, -0.1184, -0.0959,\n", - " 0.1205, -0.0795, -0.1223, -0.0029, -0.0978, -0.1356, -0.0249, 0.0818],\n", - " [ 0.0268, -0.0151, 0.0594, 0.0007, -0.1805, 0.0184, 0.0588, 0.0713,\n", - " -0.1118, -0.1249, -0.1760, 0.0042, -0.0618, -0.0826, 0.0063, -0.0493,\n", - " -0.0999, 0.0408, -0.0771, 0.0466, -0.0897, 0.0043, -0.1119, 0.0797,\n", - " 0.0040, -0.0281, -0.0161, -0.0697, -0.1113, 0.0390, -0.1071, -0.0502,\n", - " 0.0019, 0.0091, -0.0656, -0.0504, 0.0005, -0.0169, -0.0192, 0.1181,\n", - " -0.0622, 0.0554, -0.0725, 0.0197, -0.0681, -0.1217, -0.0663, -0.0783,\n", - " 0.0783, -0.1152, -0.0111, 0.0245, -0.0715, 0.0231, -0.0308, -0.1178,\n", - " 0.0520, 0.0406, 0.0662, 0.0062, -0.1208, -0.0780, -0.0425, 0.0616],\n", - " [ 0.1000, 0.0513, -0.0716, -0.0646, 0.0159, -0.0503, 0.0974, 0.0864,\n", - " 0.0261, 0.0199, -0.0792, -0.1401, -0.0302, 0.0259, -0.0120, 0.1097,\n", - " 0.0390, -0.1284, -0.0029, -0.0645, -0.0557, -0.0680, -0.1333, -0.0501,\n", - " 0.0055, -0.1356, -0.1075, -0.1406, -0.1544, -0.0535, -0.1107, -0.0782,\n", - " -0.0704, -0.1291, -0.1302, 0.0390, -0.0837, -0.0025, 0.0362, 0.0558,\n", - " -0.0584, -0.1022, 0.0020, -0.1117, -0.0480, 0.0321, 0.0818, -0.0209,\n", - " -0.0179, -0.0105, 0.0370, -0.1360, 0.0286, -0.1461, -0.0615, -0.0715,\n", - " -0.0027, 0.1225, 0.0733, -0.1936, -0.0960, 0.0399, -0.0506, 0.0918],\n", - " [-0.0936, -0.0409, 0.0725, -0.0452, -0.0925, -0.0254, 0.0173, -0.0147,\n", - " 0.0371, 0.0394, 0.0031, -0.0767, -0.0770, 0.0132, -0.0612, -0.0763,\n", - " 0.0687, -0.0840, -0.1908, 0.0080, -0.1263, 0.0380, -0.1317, -0.0976,\n", - " 0.1072, -0.0212, 0.0227, 0.0503, -0.1678, 0.0411, -0.1306, -0.1199,\n", - " -0.0212, 0.0697, -0.1497, -0.0159, -0.0951, 0.0272, -0.0254, 0.0950,\n", - " -0.0375, 0.0219, -0.0701, 0.0274, 0.0514, -0.1323, -0.0244, 0.0209,\n", - " 0.0384, 0.0959, -0.0526, -0.0959, 0.0087, -0.0671, 0.0740, -0.0541,\n", - " -0.1209, 0.1143, -0.1211, -0.1215, 0.0241, -0.0355, 0.0496, -0.0873],\n", - " [-0.0620, 0.0964, -0.0432, -0.1068, -0.0070, -0.0644, -0.0126, 0.0664,\n", - " -0.0151, -0.0771, 0.0025, -0.0242, -0.0711, -0.0941, -0.1007, 0.0095,\n", - " -0.1179, -0.1091, 0.0729, -0.0793, -0.0850, -0.0388, -0.0042, 0.0264,\n", - " 0.0937, -0.0199, -0.0066, -0.0429, -0.1129, 0.0139, 0.0438, -0.1031,\n", - " 0.0996, -0.0426, 0.0846, 0.0011, -0.0587, 0.0588, -0.1116, -0.0752,\n", - " 0.0777, -0.0379, 0.0350, 0.0959, -0.0049, 0.0242, -0.0661, 0.0270,\n", - " 0.1113, 0.0603, -0.1295, -0.0281, -0.0134, -0.1082, 0.1021, -0.1169,\n", - " -0.1114, 0.0351, -0.1150, 0.0309, 0.0431, -0.0171, -0.1119, 0.0017],\n", - " [ 0.0734, 0.0953, 0.0616, -0.0765, -0.0249, 0.0455, -0.1004, 0.0682,\n", - " -0.0680, -0.0642, -0.0412, 0.0677, -0.0558, 0.0843, 0.0286, -0.0780,\n", - " -0.0241, -0.0779, -0.0062, 0.0342, -0.0726, -0.1107, -0.0498, 0.0803,\n", - " 0.0718, -0.0912, 0.0912, -0.0413, -0.1076, -0.1350, -0.0281, -0.0306,\n", - " -0.0570, 0.0969, 0.0343, -0.0138, 0.0775, -0.0800, 0.0467, 0.0467,\n", - " 0.0679, -0.0963, -0.1284, -0.0203, 0.0300, -0.1200, -0.0924, -0.0614,\n", - " 0.1139, 0.0177, -0.0947, 0.0778, 0.1138, -0.0736, -0.0912, 0.0732,\n", - " 0.0650, -0.0060, -0.0317, -0.0114, -0.0456, -0.0276, -0.0152, -0.1120],\n", - " [-0.1107, 0.0063, -0.0494, -0.0140, 0.1035, 0.0454, 0.0423, 0.0689,\n", - " 0.0021, 0.0475, 0.0671, -0.0212, -0.1229, 0.1020, -0.0078, 0.0880,\n", - " 0.0239, 0.0618, -0.1219, -0.0547, -0.0027, 0.0735, -0.0259, -0.0719,\n", - " 0.0393, -0.0955, -0.0895, 0.0760, -0.1398, -0.0768, 0.0255, 0.0599,\n", - " 0.0922, -0.0008, -0.1118, 0.0995, -0.1291, -0.1104, 0.0115, -0.0022,\n", - " -0.0370, 0.0026, 0.0256, -0.0527, -0.0461, 0.0648, -0.0488, 0.1144,\n", - " -0.0648, 0.0695, -0.1226, 0.0125, -0.1156, -0.0063, 0.0620, 0.0854,\n", - " -0.0051, 0.1106, -0.0881, -0.0268, -0.0463, -0.0225, 0.0864, -0.0413],\n", - " [-0.0504, 0.0472, -0.0564, -0.1158, -0.0995, -0.0740, -0.0521, -0.0234,\n", - " 0.0025, -0.0136, -0.0891, -0.0797, -0.0865, -0.1338, -0.0999, 0.1095,\n", - " -0.0205, -0.1233, -0.0582, 0.0441, -0.1451, -0.1247, 0.0698, 0.0658,\n", - " -0.1129, 0.0953, -0.0850, 0.0039, -0.2036, 0.0316, -0.0664, -0.1012,\n", - " 0.0476, -0.1400, -0.0008, -0.0369, -0.0311, -0.0517, -0.0898, -0.0993,\n", - " -0.0752, -0.0025, 0.0052, 0.0132, -0.1387, 0.0159, -0.0578, -0.0838,\n", - " 0.0899, -0.0111, -0.0181, -0.1449, -0.0837, -0.1060, 0.0211, -0.0045,\n", - " 0.0114, -0.1102, -0.0891, -0.1203, 0.0010, -0.1243, 0.0960, 0.0809],\n", - " [ 0.0860, 0.0034, -0.0498, 0.0240, -0.0842, -0.1156, -0.0980, -0.0578,\n", - " -0.0208, -0.0575, 0.0558, 0.0039, -0.0282, -0.1278, 0.0243, 0.0026,\n", - " 0.0530, 0.0711, -0.0707, -0.0561, 0.0393, -0.0645, -0.1039, 0.0328,\n", - " -0.0154, -0.1046, -0.1203, -0.1073, -0.0313, 0.0916, -0.0694, 0.1208,\n", - " -0.0155, -0.0113, -0.0538, 0.0951, -0.1426, -0.0177, 0.0761, -0.0302,\n", - " 0.0170, 0.0187, 0.0551, 0.0536, 0.0819, 0.0485, 0.0715, 0.1027,\n", - " -0.0539, 0.0185, 0.0287, -0.0661, 0.0278, -0.1256, -0.0127, -0.0422,\n", - " -0.0439, 0.0243, -0.0195, 0.0962, -0.1265, 0.0251, 0.0398, 0.0980],\n", - " [-0.0930, 0.0090, -0.0900, 0.0151, -0.0738, -0.0641, -0.1250, -0.1027,\n", - " -0.1119, -0.1132, -0.0733, -0.0020, -0.0558, -0.0918, -0.0067, 0.0356,\n", - " 0.0680, 0.0746, 0.0284, -0.0184, 0.0550, -0.1243, -0.0446, 0.0365,\n", - " -0.0409, -0.0668, 0.0492, -0.0216, -0.0328, -0.0824, -0.0770, -0.0982,\n", - " -0.0960, 0.0909, -0.0230, -0.0147, -0.0285, -0.0052, -0.0853, 0.0024,\n", - " -0.0801, 0.0154, -0.0153, -0.1093, -0.0263, -0.0813, 0.0111, 0.1049,\n", - " -0.0621, 0.0714, -0.0707, -0.0453, -0.1037, -0.0350, -0.0442, 0.1021,\n", - " -0.0793, 0.0311, 0.0718, -0.1177, -0.1275, 0.0688, -0.0619, 0.0386]],\n", - " requires_grad=True)\n" - ] - }, - { - "ename": "AttributeError", - "evalue": "'ReLU' object has no attribute 'weight'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0mTraceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0msimple_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFINAL\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mapply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 284\u001b[0m \"\"\"\n\u001b[1;32m 285\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 286\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 287\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 288\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mapply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 286\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 287\u001b[0;31m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 288\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 289\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m(s)\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0msimple_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFINAL\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 573\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 574\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 575\u001b[0;31m raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0m\u001b[1;32m 576\u001b[0m type(self).__name__, name))\n\u001b[1;32m 577\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAttributeError\u001b[0m: 'ReLU' object has no attribute 'weight'" - ] - } - ], - "source": [ - "simple_nn.module_.FINAL.apply(lambda s: print(s.weight))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "amber-uruguay", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "living-austria", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}