8 – 07 RNN Training V4

Let’s take a closer look at how the decoder trains on a given caption. The decoder will be made of LSTM cells, which are good at remembering lengthy sequences of words. Each LSTM cell expects an input vector of the same shape at each time step. The very first cell is connected to the output feature vector of the CNN encoder. Previously, I mentioned an embedding layer that transformed each input word into a vector of a certain shape before being fed as input to the RNN. We need to apply the same transformation to the output feature vector of the CNN. Once this feature vector is embedded into the expected input shape, we can begin the RNN training process with this as the first input. The input to the RNN for all future time steps will be the individual words of the training caption.

So at the start of training, we have some input from our CNN and an LSTM cell with an initial state. Now the RNN has two responsibilities. One, remember spatial information from the input feature vector. And two, predict the next word. We know that the very first word it produces should always be the start token, and the next words should be those in the training caption. For our caption, “a man holding a slice of pizza,” we know that after the start token comes “a,” and after “a” comes “man,” and so on.

At every time step, we look at the current caption word as input and combine it with the hidden state of the LSTM cell to produce an output. This output is then passed through a fully connected layer that produces a distribution representing the most likely next word. This is like how we’ve seen softmax applied to classification tasks, but in this case, it produces a list of next-word scores instead of a list of class scores. We feed the next word in the caption to the network, and so on, until we reach the end token. The hidden state of an LSTM is a function of the input token to the LSTM and the previous state. I’ll refer to this function as the recurrence function.
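The decoder described above could be sketched roughly as follows. This is a minimal PyTorch sketch, not the lesson’s actual code: all layer sizes and names (`DecoderRNN`, `embed_size`, and so on) are hypothetical, and the CNN feature vector is assumed to already be projected to the same shape as a word embedding.

```python
import torch
import torch.nn as nn

class DecoderRNN(nn.Module):
    """Hypothetical caption decoder: embedding -> LSTM -> next-word scores."""
    def __init__(self, embed_size, hidden_size, vocab_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        # features: (batch, embed_size) CNN output, already embedded to the
        # same shape as a word vector, used as the first LSTM input
        word_vecs = self.embed(captions[:, :-1])            # caption words as inputs
        inputs = torch.cat([features.unsqueeze(1), word_vecs], dim=1)
        hiddens, _ = self.lstm(inputs)                      # one hidden state per step
        return self.fc(hiddens)                             # next-word scores per step

# toy usage with made-up sizes
decoder = DecoderRNN(embed_size=8, hidden_size=16, vocab_size=20)
features = torch.randn(2, 8)                   # stand-in for CNN encoder output
captions = torch.randint(0, 20, (2, 6))        # two tokenized training captions
scores = decoder(features, captions)           # shape (2, 6, 20): batch, time, vocab
```

Note that the scores are left un-normalized here; the softmax over the vocabulary is typically folded into the training loss rather than applied inside the model.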
The recurrence function is defined by weights, and during the training process this model uses backpropagation to update these weights until the LSTM cells learn to produce the correct next word in the caption given the current input word. As with most models, you can also take advantage of batching the training data. The model updates its weights after each training batch, where the batch size is the number of image-caption pairs sent through the network during a single training step. Once the model is trained, it will have learned from many image-caption pairs and should be able to generate captions for new image data.
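One batched training step of the kind described above might look like this. Again a hedged sketch: the stand-in decoder, the random batch, the optimizer, and the learning rate are all illustrative assumptions, but the pattern of shifting the caption by one position so each word’s target is the next word, then backpropagating a cross-entropy loss over the vocabulary, is the core of the training loop.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, embed_size, hidden_size, batch_size, T = 20, 8, 16, 4, 6

# minimal stand-in decoder: embedding -> LSTM -> linear next-word scores
embed = nn.Embedding(vocab_size, embed_size)
lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
fc = nn.Linear(hidden_size, vocab_size)
params = list(embed.parameters()) + list(lstm.parameters()) + list(fc.parameters())
optimizer = torch.optim.Adam(params, lr=1e-2)
criterion = nn.CrossEntropyLoss()   # softmax over vocab + negative log-likelihood

# one toy batch of tokenized captions (batch size = number of caption sequences)
captions = torch.randint(0, vocab_size, (batch_size, T))

losses = []
for step in range(20):
    inputs = embed(captions[:, :-1])    # current words fed into the LSTM
    targets = captions[:, 1:]           # the correct "next word" at each step
    hiddens, _ = lstm(inputs)
    scores = fc(hiddens)                # (batch, T-1, vocab) next-word scores
    loss = criterion(scores.reshape(-1, vocab_size), targets.reshape(-1))
    optimizer.zero_grad()
    loss.backward()                     # backpropagation through the recurrence
    optimizer.step()                    # weight update after the batch
    losses.append(loss.item())
```

Repeating this over many image-caption batches is what lets the decoder learn to score the correct next word highly; on this single fixed toy batch the loss simply falls as the model memorizes the captions.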
