16 – Saving & Loading Models V1

Hello. In this video, I’m going to talk about saving and loading PyTorch models. So far, you’ve seen how to train models and use them to make predictions. But often you’ll want to train a model and then come back to it later to make predictions, or even to continue training it on new data. Here I’m going to show you how to save your trained models and load them, so that you can come back to them later.
First things first, let’s load our modules. This is pretty much everything we’ve seen before: torch and torchvision. Something new here is fc_model, a module I wrote myself that implements a fully connected classifier. It’s pretty straightforward; I wrote it just for convenience for this particular notebook and video. Next, we load our dataset: FashionMNIST again, with a training set and a test set. As a reminder, these are 28 by 28 grayscale images. This one looks like a purse.
Now that we have the modules and the data loaded, we can train a model. This is just a fully connected network. Like I mentioned before, it has 784 input units and 10 output units. The list in between contains the sizes of the hidden layers, so we have three hidden layers: the first has 512 units, the second has 256 units, and the last has 128 units. After those comes the output layer, which uses log softmax. That means that for the criterion, the loss, we use the negative log likelihood loss. Finally, we use the Adam optimizer to train the network and update the parameters. With this model set up, we can simply call fc_model’s train function and it will train the network for us. After two epochs of training, we get to around 84 percent accuracy. We could probably do better by changing the network architecture or training longer, but this is just a demonstration.
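fc_model isn’t shown in this video, so the sketch below uses a hypothetical stand-in class, Network, that builds the same 784-512-256-128-10 classifier with a log softmax output, then sets up NLLLoss and Adam. The class internals, the ReLU activations, and the random batch standing in for FashionMNIST images are all assumptions; it runs a single optimization step rather than reproducing fc_model’s actual train function.

```python
import torch
from torch import nn, optim
import torch.nn.functional as F


class Network(nn.Module):
    """Hypothetical stand-in for fc_model.Network (not the author's code)."""

    def __init__(self, input_size, output_size, hidden_layers):
        super().__init__()
        # First hidden layer: input_size -> hidden_layers[0]
        self.hidden_layers = nn.ModuleList([nn.Linear(input_size, hidden_layers[0])])
        # Remaining hidden layers connect consecutive sizes, e.g. 512 -> 256 -> 128
        for h1, h2 in zip(hidden_layers[:-1], hidden_layers[1:]):
            self.hidden_layers.append(nn.Linear(h1, h2))
        self.output = nn.Linear(hidden_layers[-1], output_size)

    def forward(self, x):
        # ReLU between hidden layers is an assumption
        for layer in self.hidden_layers:
            x = F.relu(layer(x))
        # Log probabilities, which is what nn.NLLLoss expects
        return F.log_softmax(self.output(x), dim=1)


model = Network(784, 10, [512, 256, 128])
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters())  # default learning rate

# Random batch standing in for 64 flattened 28x28 FashionMNIST images
images = torch.randn(64, 784)
labels = torch.randint(0, 10, (64,))

# One optimization step: forward, loss, backward, update
optimizer.zero_grad()
log_ps = model(images)
loss = criterion(log_ps, labels)
loss.backward()
optimizer.step()

print(model)
print(f"loss: {loss.item():.4f}")
```

Pairing log_softmax with NLLLoss is equivalent to applying nn.CrossEntropyLoss to raw scores; returning log probabilities just keeps the output directly usable for predictions.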
Now that the network is trained, we can save it to a file. Again, we typically do this because later we want to load the trained model and make predictions or train it more. What we actually save is called the state_dict: a dictionary that contains all of the parameters for the model, that is, all of the weight and bias tensors. Here I’m printing out the model. We can see the hidden layers, each of which is a linear layer, and then the linear layer for the output. Remember that each linear layer has a weight tensor and a bias tensor as parameters. Looking at the state_dict, we see hidden_layers.0.weight and hidden_layers.0.bias, the weight and bias for the first hidden layer, then the weight and bias for the second hidden layer, hidden_layers.1, and so on. So the state_dict contains all of the weights and biases for our network, and we can save those to a file and load them back to rebuild the model.
The simplest way to do this is to save the state_dict with torch.save: pass in the model’s state_dict and the name of the file where we want to keep it. Here that’s checkpoint.pth; .pth is the typical extension for PyTorch checkpoints. With the file saved, we can load the state_dict back in. If we load it from the checkpoint file and print out the keys, we again see the weights and biases for the hidden layers. At this point we’ve only loaded the state_dict itself; it isn’t in the model yet. To get it there, we call the model’s load_state_dict method and pass in the state_dict. Now the model is ready to be used for making predictions or whatever else we need. This seems pretty straightforward, but it’s actually more complicated.
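A minimal sketch of the state_dict round trip described above, again using a hypothetical Network stand-in for fc_model.Network. Writing checkpoint.pth to the system temp directory is an assumption made to keep the working directory clean; the torch.save, torch.load, and load_state_dict calls are the standard PyTorch ones.

```python
import os
import tempfile

import torch
from torch import nn
import torch.nn.functional as F


class Network(nn.Module):
    """Hypothetical stand-in for fc_model.Network (not the author's code)."""

    def __init__(self, input_size, output_size, hidden_layers):
        super().__init__()
        self.hidden_layers = nn.ModuleList([nn.Linear(input_size, hidden_layers[0])])
        for h1, h2 in zip(hidden_layers[:-1], hidden_layers[1:]):
            self.hidden_layers.append(nn.Linear(h1, h2))
        self.output = nn.Linear(hidden_layers[-1], output_size)

    def forward(self, x):
        for layer in self.hidden_layers:
            x = F.relu(layer(x))
        return F.log_softmax(self.output(x), dim=1)


model = Network(784, 10, [512, 256, 128])

# Each entry maps a parameter name to its tensor
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))

# Save only the parameters, not the model object
path = os.path.join(tempfile.gettempdir(), "checkpoint.pth")
torch.save(model.state_dict(), path)

# Loading gives back a plain dictionary of tensors
state_dict = torch.load(path)
print(list(state_dict.keys()))

# The dictionary must be loaded into a model with the same architecture
restored = Network(784, 10, [512, 256, 128])
result = restored.load_state_dict(state_dict)
print(result)
```

Saving the state_dict rather than the whole model object keeps the file independent of the exact class definition that was pickled, which is why the architecture has to be rebuilt in code before loading.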
If we take the state_dict we loaded and try to load it into a model with a different architecture, we get an error. The error tells us there is a size mismatch for hidden_layers.0.weight: it’s trying to copy a parameter with shape 512 by 784, but the shape in the current model is 400 by 784. When you load a state_dict into a model, the model’s parameters must have the same shapes as the tensors in the state_dict. This means that to load a checkpoint, we have to rebuild the model exactly as it was when it was trained, so we need to include information about the model’s architecture in the checkpoint itself.
Here we create a checkpoint, which is just a dictionary. In it, we define the architecture: the input size is 784 and the output size is 10. For hidden_layers, we go through each hidden layer in the model and take its output features (out_features), that is, the size of that layer. Finally, we add a state_dict key holding the model’s state_dict. Then we save this entire checkpoint dictionary to the checkpoint file, so the file now carries information about the model architecture as well as the parameters.
That lets us write a function for loading checkpoints. Given a file path, it loads the checkpoint, which is the dictionary we just built, with input_size, output_size, hidden_layers, and state_dict. Using these values, we recreate the model by calling fc_model.Network with the input_size, output_size, and hidden_layers from the checkpoint. This creates a model with the same architecture as the one we trained.
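The sketch below reproduces the size mismatch and then the architecture-aware checkpoint, assuming the same hypothetical Network stand-in for fc_model.Network. The [400, 200, 100] hidden sizes for the mismatched model and the temp-directory path are illustrative assumptions; the dictionary keys mirror the ones named in the video.

```python
import os
import tempfile

import torch
from torch import nn
import torch.nn.functional as F


class Network(nn.Module):
    """Hypothetical stand-in for fc_model.Network (not the author's code)."""

    def __init__(self, input_size, output_size, hidden_layers):
        super().__init__()
        self.hidden_layers = nn.ModuleList([nn.Linear(input_size, hidden_layers[0])])
        for h1, h2 in zip(hidden_layers[:-1], hidden_layers[1:]):
            self.hidden_layers.append(nn.Linear(h1, h2))
        self.output = nn.Linear(hidden_layers[-1], output_size)

    def forward(self, x):
        for layer in self.hidden_layers:
            x = F.relu(layer(x))
        return F.log_softmax(self.output(x), dim=1)


model = Network(784, 10, [512, 256, 128])

# Same parameter names, different shapes: load_state_dict raises RuntimeError
mismatch_error = ""
other = Network(784, 10, [400, 200, 100])
try:
    other.load_state_dict(model.state_dict())
except RuntimeError as err:
    mismatch_error = str(err)
    print(mismatch_error)

# Store the architecture alongside the parameters
checkpoint = {
    "input_size": 784,
    "output_size": 10,
    "hidden_layers": [layer.out_features for layer in model.hidden_layers],
    "state_dict": model.state_dict(),
}
path = os.path.join(tempfile.gettempdir(), "checkpoint.pth")
torch.save(checkpoint, path)


def load_checkpoint(filepath):
    """Rebuild the model from the saved architecture, then load its weights."""
    checkpoint = torch.load(filepath)
    model = Network(checkpoint["input_size"],
                    checkpoint["output_size"],
                    checkpoint["hidden_layers"])
    model.load_state_dict(checkpoint["state_dict"])
    return model


restored = load_checkpoint(path)
print(restored)
```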
Now that the model is recreated, we use load_state_dict to load in the saved state_dict, and that gives us our model back. There we go: we took our model, saved it, reloaded it into a newly created model, and that new model is the same as the one we trained. I should point out that the load_checkpoint function you write depends on the architecture of your model, that is, on how you’ve implemented it. So you won’t be able to use the same load_checkpoint function for every model; you’ll have to write a custom one for each model architecture you implement. Cheers. See you in the next video.
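One way to confirm that the reloaded model is the same as the trained one is to compare their predictions on the same input. This sketch repeats the hypothetical Network stand-in and adds a hypothetical save_checkpoint helper so it runs on its own; its load_checkpoint only works for this particular constructor, which is exactly the point made above about writing one per architecture.

```python
import os
import tempfile

import torch
from torch import nn
import torch.nn.functional as F


class Network(nn.Module):
    """Hypothetical stand-in for fc_model.Network (not the author's code)."""

    def __init__(self, input_size, output_size, hidden_layers):
        super().__init__()
        self.hidden_layers = nn.ModuleList([nn.Linear(input_size, hidden_layers[0])])
        for h1, h2 in zip(hidden_layers[:-1], hidden_layers[1:]):
            self.hidden_layers.append(nn.Linear(h1, h2))
        self.output = nn.Linear(hidden_layers[-1], output_size)

    def forward(self, x):
        for layer in self.hidden_layers:
            x = F.relu(layer(x))
        return F.log_softmax(self.output(x), dim=1)


def save_checkpoint(model, filepath):
    """Hypothetical helper: read the architecture off the model and save it with the weights."""
    checkpoint = {
        "input_size": model.hidden_layers[0].in_features,
        "output_size": model.output.out_features,
        "hidden_layers": [layer.out_features for layer in model.hidden_layers],
        "state_dict": model.state_dict(),
    }
    torch.save(checkpoint, filepath)


def load_checkpoint(filepath):
    """Specific to Network: another architecture needs its own loader."""
    checkpoint = torch.load(filepath)
    model = Network(checkpoint["input_size"],
                    checkpoint["output_size"],
                    checkpoint["hidden_layers"])
    model.load_state_dict(checkpoint["state_dict"])
    return model


model = Network(784, 10, [512, 256, 128])
path = os.path.join(tempfile.gettempdir(), "checkpoint.pth")
save_checkpoint(model, path)
restored = load_checkpoint(path)

# Switch both to evaluation mode and compare outputs without tracking gradients
model.eval()
restored.eval()
x = torch.randn(8, 784)
with torch.no_grad():
    same = torch.allclose(model(x), restored(x))
print("same predictions:", same)
```

Calling eval() has no effect on this particular stand-in, but it matters for models with dropout or batch normalization, where train-mode outputs would not match.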