Py Part 6 V1

In this video and notebook, I'll be showing you how to save and load models. Like I said previously, you typically don't want to have to train a new model every time you want to use it. Instead, you train it once, save it, and then if you need it for inference later, you can load it back up. So, we can start by importing our normal things. I'm also importing fc_model. This is basically a file where I took the network that I built in part five and moved it over, so that I can easily load it in and train it, and all of that. So, let's go ahead and load that in. Here we're loading the data, and again, we're using Fashion MNIST. Then you can see what it looks like, nice little images. So, like I was talking about, I created this file called fc_model, and the network is in there. So, I just create our model, our network. Again I'm using 784 inputs, 10 outputs, and 512, 256, and 128 units for three different hidden layers, and then NLLLoss and our Adam optimizer. Then we can just train it. With the network trained, it is time to save it. The parameters of our model are actually stored in the model's state_dict, which you get by calling model.state_dict(). So, let's go and look at our model really quick: print the model to see what it looks like, and then we can also look at model.state_dict(). This is a dictionary, so I'm just going to go ahead and print out the keys. Then we can see that for each of our hidden layers, and for the output, we can go in there and get our weights and our biases. So basically, it stores the weights and the biases and everything in the state_dict. This is the thing that we're actually going to want to save. To do that, we use torch.save, pass it our state_dict, and save it to a file called checkpoint.pth. So, now that is saved to disk. We can load this back in: call it state_dict again, torch.load, checkpoint.pth, and just to see that it's the same thing, print out the keys again, and yeah, it's the same.
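The save and load steps described above can be sketched roughly as follows. The fc_model file is not shown in this transcript, so this sketch uses a plain `nn.Sequential` with the same layer sizes as a stand-in, and it writes to a temporary directory rather than the working directory. Both choices are assumptions made for illustration.

```python
import os
import tempfile

import torch
from torch import nn

# Stand-in for the fc_model network (an assumption; the real file is not shown).
# Same sizes as in the video: 784 inputs, hidden layers of 512, 256, 128, 10 outputs.
model = nn.Sequential(
    nn.Linear(784, 512), nn.ReLU(),
    nn.Linear(512, 256), nn.ReLU(),
    nn.Linear(256, 128), nn.ReLU(),
    nn.Linear(128, 10), nn.LogSoftmax(dim=1),
)

# state_dict() returns a dictionary mapping parameter names to tensors.
print(list(model.state_dict().keys()))

# Save only the parameters to disk (a temporary directory is used here).
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
torch.save(model.state_dict(), path)

# Load the dictionary back from disk and confirm the keys are the same.
state_dict = torch.load(path)
print(list(state_dict.keys()))

# The loaded dictionary is not attached to a model until we call load_state_dict.
model.load_state_dict(state_dict)
```

With a stand-in like this, the keys are named after the position of each layer in the `nn.Sequential`; the network in fc_model will show different key names.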
So, we can save the state_dict to a checkpoint and then we can load it back in with torch.load. Now we have the state_dict, but it's not attached to a model yet. So, we actually need to go to our model and call load_state_dict, and this actually loads the state dictionary into the model. So, that's our model, cool. So, it seems really straightforward: you just save the state_dict, load it back in, and load it into your model with load_state_dict. But check out what happens if my model has a different architecture. I'm going to create a new model, 784 inputs and 10 outputs, but this time I'm going to use 400, 200, and 100 units for my hidden layers. We've already loaded the state_dict, so now let's see what happens when we load that state dictionary into this new model. So, we get this error. What it says is, while copying the parameter, the dimensions in the model are this, and the dimensions in the checkpoint are that. So basically, what this is telling us is that the dimensions of our network are different from the dimensions in the checkpoint. What this means is that if you create a checkpoint from the state dictionary, it has to match up with exactly the same architecture when you load it back in. That means we need to save the network architecture alongside the state_dict, so that when we load things back in, we can rebuild a network with that architecture and then load the parameters into it. So, we can create a dictionary and just call it checkpoint, with my input size, 784, output size, 10, and hidden layers. Here I'm using a list comprehension to get the actual sizes of each of our hidden layers, so, for each in model.hidden_layers. Then finally, we can store the state dictionary in there too. So, we have this checkpoint that is basically just a dictionary that has our architecture parameters, the hyperparameters for how we actually build the network, as well as the state dictionary that contains all of our parameters, the weights and biases and things like that.
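To make the mismatch and the fix concrete, here is a minimal sketch. The `Network` class below is a hypothetical stand-in for the one in fc_model (assumed to keep its hidden layers in an `nn.ModuleList` named `hidden_layers`), and the dictionary keys `input_size`, `output_size`, `hidden_layers`, and `state_dict` are illustrative names.

```python
import torch
from torch import nn
import torch.nn.functional as F


class Network(nn.Module):
    """Hypothetical stand-in for the fully connected network in fc_model."""

    def __init__(self, input_size, output_size, hidden_layers):
        super().__init__()
        sizes = [input_size] + list(hidden_layers)
        # One Linear layer per consecutive pair of sizes.
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        )
        self.output = nn.Linear(sizes[-1], output_size)

    def forward(self, x):
        for layer in self.hidden_layers:
            x = F.relu(layer(x))
        return F.log_softmax(self.output(x), dim=1)


model = Network(784, 10, [512, 256, 128])
other = Network(784, 10, [400, 200, 100])

# Loading parameters into a network with different layer sizes fails.
try:
    other.load_state_dict(model.state_dict())
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = err
    print("load failed:", str(err).splitlines()[0])

# Store the architecture next to the parameters so the network can be rebuilt.
checkpoint = {
    "input_size": 784,
    "output_size": 10,
    # out_features is the number of units in each Linear layer.
    "hidden_layers": [each.out_features for each in model.hidden_layers],
    "state_dict": model.state_dict(),
}
print(checkpoint["hidden_layers"])
```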
So, we can save our checkpoint, and we'll save it, again, to our checkpoint file. Nice. So now, using this checkpoint, we can rebuild our network. I usually like to build a small function, just called something like load_checkpoint, which takes a file path. In it, we load our checkpoint and build our model from the sizes stored there. So, this builds a model for us, and now we can load the state dictionary into this model and return the model. There we go. So, we see that we can now save the description of our model itself, like the number of units, along with the state dictionary, and then load it all back up, recreate our model, and load in the appropriate state_dict. All right, so in the next video, I'll be showing you how to load in image data to use in an image classifier, or pretty much any application that requires images. Cheers!
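A load_checkpoint helper along the lines described here might look like the sketch below. As before, `Network` is a hypothetical stand-in for the class in fc_model, the checkpoint keys are illustrative names, and the file is written to a temporary directory; none of these details come from the video itself.

```python
import os
import tempfile

import torch
from torch import nn
import torch.nn.functional as F


class Network(nn.Module):
    """Hypothetical stand-in for the fully connected network in fc_model."""

    def __init__(self, input_size, output_size, hidden_layers):
        super().__init__()
        sizes = [input_size] + list(hidden_layers)
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        )
        self.output = nn.Linear(sizes[-1], output_size)

    def forward(self, x):
        for layer in self.hidden_layers:
            x = F.relu(layer(x))
        return F.log_softmax(self.output(x), dim=1)


def load_checkpoint(filepath):
    """Rebuild a Network from a saved checkpoint dictionary and load its parameters."""
    checkpoint = torch.load(filepath)
    model = Network(
        checkpoint["input_size"],
        checkpoint["output_size"],
        checkpoint["hidden_layers"],
    )
    model.load_state_dict(checkpoint["state_dict"])
    return model


# Save the architecture and the parameters together in one file.
model = Network(784, 10, [512, 256, 128])
checkpoint = {
    "input_size": 784,
    "output_size": 10,
    "hidden_layers": [each.out_features for each in model.hidden_layers],
    "state_dict": model.state_dict(),
}
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
torch.save(checkpoint, path)

# Rebuild the network from the file alone.
restored = load_checkpoint(path)
print(restored)
```

Because the sizes travel with the parameters, the caller of load_checkpoint does not need to know how many hidden units the saved network had.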
