-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdecoding_layer_train.py
More file actions
25 lines (21 loc) · 1.12 KB
/
decoding_layer_train.py
File metadata and controls
25 lines (21 loc) · 1.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def decoding_layer_train(encoder_state, dec_cell, dec_embed_input,
                         target_sequence_length, max_summary_length,
                         output_layer, keep_prob):
    """
    Build the training-time decoding branch of the decoder.

    :param encoder_state: initial state handed to the BasicDecoder
    :param dec_cell: decoder RNN cell; wrapped with dropout before use
    :param dec_embed_input: inputs fed to the TrainingHelper
        (presumably the embedded target sequence -- confirm with caller)
    :param target_sequence_length: sequence lengths given to the
        TrainingHelper
    :param max_summary_length: upper bound on decoding steps, passed to
        dynamic_decode as maximum_iterations
    :param output_layer: passed to the BasicDecoder as its output layer
    :param keep_prob: keep probability used as output_keep_prob of the
        dropout wrapper
    :return: BasicDecoderOutput containing training logits and sample_id
    """
    # Dropout is configured through output_keep_prob, i.e. on the cell
    # outputs; the input keep probability is left at its default.
    dropout_cell = tf.contrib.rnn.DropoutWrapper(
        cell=dec_cell,
        output_keep_prob=keep_prob)

    # The helper supplies the decoder inputs at each step during training.
    training_helper = tf.contrib.seq2seq.TrainingHelper(
        inputs=dec_embed_input,
        sequence_length=target_sequence_length)

    training_decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=dropout_cell,
        helper=training_helper,
        initial_state=encoder_state,
        output_layer=output_layer)

    # Unroll the decoder; only the decoder outputs are returned, the final
    # state and final sequence lengths are discarded.
    training_output, _final_state, _final_lengths = (
        tf.contrib.seq2seq.dynamic_decode(
            decoder=training_decoder,
            impute_finished=True,
            maximum_iterations=max_summary_length))
    return training_output