|
790 | 790 | }, |
791 | 791 | { |
792 | 792 | "cell_type": "code", |
793 | | - "execution_count": 20, |
| 793 | + "execution_count": null, |
| 794 | + "metadata": {}, |
| 795 | + "outputs": [], |
| 796 | + "source": [ |
| 797 | + "from torch.utils.data import DataLoader\n", |
| 798 | + "from spotPython.data.diabetes import Diabetes\n", |
| 799 | + "from spotPython.light.netlightregression import NetLightRegression\n", |
| 800 | + "from torch import nn\n", |
| 801 | + "import lightning as L\n", |
| 802 | + "PATH_DATASETS = './data'\n", |
| 803 | + "BATCH_SIZE = 8\n", |
| 804 | + "\n", |
| 805 | + "dataset = Diabetes()\n", |
| 806 | + "train_loader = DataLoader(dataset, batch_size=BATCH_SIZE)\n", |
| 807 | + "test_loader = DataLoader(dataset, batch_size=BATCH_SIZE)\n", |
| 808 | + "val_loader = DataLoader(dataset, batch_size=BATCH_SIZE)\n", |
| 809 | + "batch_x, batch_y = next(iter(train_loader)) \n", |
| 810 | + "print(batch_x.shape)\n", |
| 811 | + "print(batch_y.shape)\n", |
| 812 | + "\n", |
| 813 | + "net_light_base = NetLightRegression(l1=128, epochs=10, batch_size=BATCH_SIZE,\n", |
| 814 | + " initialization='xavier', act_fn=nn.ReLU(),\n", |
| 815 | + " optimizer='Adam', dropout_prob=0.1, lr_mult=0.1,\n", |
| 816 | + " patience=5, _L_in=10, _L_out=1)\n", |
| 817 | + "trainer = L.Trainer(max_epochs=2, enable_progress_bar=False)\n", |
| 818 | + "trainer.fit(net_light_base, train_loader)\n", |
| 819 | + "trainer.validate(net_light_base, val_loader)\n", |
| 820 | + "trainer.test(net_light_base, test_loader)\n" |
| 821 | + ] |
| 822 | + }, |
| 823 | + { |
| 824 | + "cell_type": "markdown", |
| 825 | + "metadata": {}, |
| 826 | + "source": [ |
| 827 | + "# Tests for optimizer_handler" |
| 828 | + ] |
| 829 | + }, |
| 830 | + { |
| 831 | + "cell_type": "code", |
| 832 | + "execution_count": 10, |
794 | 833 | "metadata": {}, |
795 | 834 | "outputs": [ |
796 | 835 | { |
|
811 | 850 | "15.9 K Trainable params\n", |
812 | 851 | "0 Non-trainable params\n", |
813 | 852 | "15.9 K Total params\n", |
814 | | - "0.064 Total estimated model params size (MB)\n" |
815 | | - ] |
816 | | - }, |
817 | | - { |
818 | | - "name": "stdout", |
819 | | - "output_type": "stream", |
820 | | - "text": [ |
821 | | - "torch.Size([8, 10])\n", |
822 | | - "torch.Size([8])\n" |
823 | | - ] |
824 | | - }, |
825 | | - { |
826 | | - "name": "stderr", |
827 | | - "output_type": "stream", |
828 | | - "text": [ |
| 853 | + "0.064 Total estimated model params size (MB)\n", |
829 | 854 | "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n", |
830 | 855 | "`Trainer.fit` stopped: `max_epochs=2` reached.\n", |
831 | | - "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n", |
832 | | - "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n" |
833 | | - ] |
834 | | - }, |
835 | | - { |
836 | | - "name": "stdout", |
837 | | - "output_type": "stream", |
838 | | - "text": [ |
839 | | - "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", |
840 | | - " Validate metric DataLoader 0\n", |
841 | | - "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", |
842 | | - " hp_metric 28981.529296875\n", |
843 | | - " val_loss 28981.529296875\n", |
844 | | - "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", |
845 | | - "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", |
846 | | - " Test metric DataLoader 0\n", |
847 | | - "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", |
848 | | - " hp_metric 28981.529296875\n", |
849 | | - " val_loss 28981.529296875\n", |
850 | | - "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n" |
| 856 | + "GPU available: True (mps), used: True\n", |
| 857 | + "TPU available: False, using: 0 TPU cores\n", |
| 858 | + "IPU available: False, using: 0 IPUs\n", |
| 859 | + "HPU available: False, using: 0 HPUs\n", |
| 860 | + "\n", |
| 861 | + " | Name | Type | Params | In sizes | Out sizes\n", |
| 862 | + "-------------------------------------------------------------\n", |
| 863 | + "0 | layers | Sequential | 15.9 K | [8, 10] | [8, 1] \n", |
| 864 | + "-------------------------------------------------------------\n", |
| 865 | + "15.9 K Trainable params\n", |
| 866 | + "0 Non-trainable params\n", |
| 867 | + "15.9 K Total params\n", |
| 868 | + "0.064 Total estimated model params size (MB)\n", |
| 869 | + "`Trainer.fit` stopped: `max_epochs=2` reached.\n" |
851 | 870 | ] |
852 | 871 | }, |
853 | 872 | { |
854 | 873 | "data": { |
855 | 874 | "text/plain": [ |
856 | | - "[{'val_loss': 28981.529296875, 'hp_metric': 28981.529296875}]" |
| 875 | + "True" |
857 | 876 | ] |
858 | 877 | }, |
859 | | - "execution_count": 20, |
| 878 | + "execution_count": 10, |
860 | 879 | "metadata": {}, |
861 | 880 | "output_type": "execute_result" |
862 | 881 | } |
|
867 | 886 | "from spotPython.light.netlightregression import NetLightRegression\n", |
868 | 887 | "from torch import nn\n", |
869 | 888 | "import lightning as L\n", |
870 | | - "PATH_DATASETS = './data'\n", |
| 889 | + "\n", |
871 | 890 | "BATCH_SIZE = 8\n", |
| 891 | + "lr_mult=0.1\n", |
872 | 892 | "\n", |
873 | 893 | "dataset = Diabetes()\n", |
874 | 894 | "train_loader = DataLoader(dataset, batch_size=BATCH_SIZE)\n", |
875 | 895 | "test_loader = DataLoader(dataset, batch_size=BATCH_SIZE)\n", |
876 | 896 | "val_loader = DataLoader(dataset, batch_size=BATCH_SIZE)\n", |
877 | | - "batch_x, batch_y = next(iter(train_loader)) \n", |
878 | | - "print(batch_x.shape)\n", |
879 | | - "print(batch_y.shape)\n", |
880 | 897 | "\n", |
881 | 898 | "net_light_base = NetLightRegression(l1=128, epochs=10, batch_size=BATCH_SIZE,\n", |
882 | 899 | " initialization='xavier', act_fn=nn.ReLU(),\n", |
883 | | - " optimizer='Adam', dropout_prob=0.1, lr_mult=0.1,\n", |
| 900 | + " optimizer='Adam', dropout_prob=0.1, lr_mult=lr_mult,\n", |
884 | 901 | " patience=5, _L_in=10, _L_out=1)\n", |
885 | 902 | "trainer = L.Trainer(max_epochs=2, enable_progress_bar=False)\n", |
886 | 903 | "trainer.fit(net_light_base, train_loader)\n", |
887 | | - "trainer.validate(net_light_base, val_loader)\n", |
888 | | - "trainer.test(net_light_base, test_loader)\n" |
| 904 | + "# Adam uses a lr which is calculated as lr=lr_mult * 0.001, so this value\n", |
| 905 | + "# should be 0.1 * 0.001 = 0.0001 \n", |
| 906 | + "trainer.optimizers[0].param_groups[0][\"lr\"] == lr_mult*0.001\n", |
| 907 | + "\n", |
| 908 | + "\n", |
| 909 | + "net_light_base = NetLightRegression(l1=128, epochs=10, batch_size=BATCH_SIZE,\n", |
| 910 | + " initialization='xavier', act_fn=nn.ReLU(),\n", |
| 911 | + " optimizer='Adadelta', dropout_prob=0.1, lr_mult=lr_mult,\n", |
| 912 | + " patience=5, _L_in=10, _L_out=1)\n", |
| 913 | + "trainer = L.Trainer(max_epochs=2, enable_progress_bar=False)\n", |
| 914 | + "trainer.fit(net_light_base, train_loader)\n", |
| 915 | + "# Adadelta's learning rate is computed as lr = lr_mult * 1.0, so this value\n", |
| 916 | + "# should be 0.1 * 1.0 = 0.1\n", |
| 917 | + "trainer.optimizers[0].param_groups[0][\"lr\"] == lr_mult*1.0\n" |
889 | 918 | ] |
890 | 919 | }, |
891 | 920 | { |
|
0 commit comments