Data Science Asked on August 28, 2020
I am looking for help understanding the Hugging Face implementation of BERT. I will explain my understanding below and then ask my question.
Below is the question-answering code from `modeling_bert.py` (line 1483):
```python
class BertForQuestionAnswering(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)  ############FOCUS ON THIS LINE##########

        self.init_weights()

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-uncased",
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_tuple=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        """
        return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_tuple=return_tuple,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)  ############FOCUS ON THIS LINE##########
        start_logits, end_logits = logits.split(1, dim=-1)  ############FOCUS ON THIS LINE##########
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
```
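To make the tensor shapes in the focused lines concrete, here is a small NumPy stand-in. It is an assumption for illustration only: the real code uses PyTorch tensors and a trained `nn.Linear`, whereas this sketch uses random weights and toy sizes and only tracks shapes. `qa_outputs` maps each token's hidden vector to 2 numbers, and `split`/`squeeze` separate them into one start score and one end score per token position.

```python
import numpy as np

# Toy sizes, assumed for illustration (bert-base uses hidden_size=768)
batch_size, seq_len, hidden_size, num_labels = 2, 6, 8, 2

rng = np.random.default_rng(0)

# Stand-in for sequence_output = outputs[0]: one hidden vector per token position
sequence_output = rng.normal(size=(batch_size, seq_len, hidden_size))

# Stand-in for nn.Linear(hidden_size, num_labels): y = x @ W.T + b
W = rng.normal(size=(num_labels, hidden_size))
b = np.zeros(num_labels)
logits = sequence_output @ W.T + b           # (batch_size, seq_len, 2)

# Equivalent of logits.split(1, dim=-1) followed by .squeeze(-1)
start_logits, end_logits = np.split(logits, 2, axis=-1)
start_logits = start_logits.squeeze(-1)      # (batch_size, seq_len)
end_logits = end_logits.squeeze(-1)          # (batch_size, seq_len)

print(logits.shape, start_logits.shape, end_logits.shape)
# (2, 6, 2) (2, 6) (2, 6)
```

With `bert-base-uncased`, `hidden_size` is 768, so the real `qa_outputs` weight has shape (2, 768), and `seq_len` is the length of the tokenized question plus passage.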
I have marked the lines to focus on with the comment `############FOCUS ON THIS LINE##########`. My understanding of this code is as follows: first, the question and passage are passed to BERT. BERT produces the output of its final hidden layer (`sequence_output = outputs[0]`), which is then passed to `self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)`. Here `num_labels` is 2, so the output of this layer contains 2 tensors, which represent the probability (logit) for each word in the vocabulary.
The following code then computes the average loss for each word in the vocabulary:
```python
        total_loss = None  # stays None unless start/end labels are provided
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2  ############FOCUS ON THIS LINE##########

        if return_tuple:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output
```
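A NumPy sketch of what the loss lines compute, assuming the default `reduction='mean'` of `CrossEntropyLoss`. Each row of `start_logits` is treated as a classification over the token positions of one input, and the target is the index of the labelled start (or end) token. The helper `cross_entropy` and all numbers are hypothetical, made up for illustration; they are not part of transformers.

```python
import numpy as np

def cross_entropy(logits, targets, ignore_index):
    """Hypothetical helper mimicking torch.nn.CrossEntropyLoss(ignore_index=...)
    with reduction='mean': average -log_softmax(logits)[target] over the rows
    whose target is not ignore_index."""
    shifted = logits - logits.max(axis=1, keepdims=True)     # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    keep = targets != ignore_index                            # drop ignored rows
    rows = np.arange(len(targets))[keep]
    return -log_probs[rows, targets[keep]].mean()

# Made-up logits: batch of 2 inputs, 5 token positions each
start_logits = np.array([[0.1, 2.0, 0.3, -1.0, 0.0],
                         [1.5, 0.2, 0.1, 0.0, -0.5]])
end_logits = np.array([[0.0, 0.1, 0.2, 3.0, 0.5],
                       [0.3, 0.2, 2.5, 0.1, 0.0]])
start_positions = np.array([1, 0])
end_positions = np.array([3, 7])        # 7 lies outside the 5-token input

# Same clamping as the library code: out-of-range labels become seq_len
ignored_index = start_logits.shape[1]   # 5
start_positions = np.clip(start_positions, 0, ignored_index)
end_positions = np.clip(end_positions, 0, ignored_index)    # [3, 7] -> [3, 5]

start_loss = cross_entropy(start_logits, start_positions, ignored_index)
end_loss = cross_entropy(end_logits, end_positions, ignored_index)
total_loss = (start_loss + end_loss) / 2
print(start_loss, end_loss, total_loss)
```

Because `ignored_index` equals the sequence length, a label clamped to that value is skipped, so in this example the second input contributes to `start_loss` but not to `end_loss`.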
Please correct me if my understanding of the above code is wrong in any way. If it is correct, my question is: why is the loss for the start word and the end word minimal for the best answer? Why isn't the total weight for the best answer somewhere in the middle of the answer?
For example, in the image below, the best answer (by total loss) is "Retail Banking + Investment Banking". Why was the total loss not greater for "Retail Banking + Investment Solutions"?
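For reference, a sketch of how start and end logits are commonly combined into a predicted span at inference time, following the scoring described in the original BERT paper for SQuAD: a candidate span (i, j) with j ≥ i scores `start_logits[i] + end_logits[j]`. The function `best_span`, the `max_answer_len` cap of 30, and the logits are assumptions for illustration; real post-processing also skips question and special tokens, which this sketch omits.

```python
import numpy as np

def best_span(start_logits, end_logits, max_answer_len=30):
    """Hypothetical helper: return (start, end, score) maximizing
    start_logits[start] + end_logits[end] subject to
    start <= end < start + max_answer_len."""
    best = (0, 0, -np.inf)
    n = len(start_logits)
    for i in range(n):
        for j in range(i, min(n, i + max_answer_len)):
            score = start_logits[i] + end_logits[j]
            if score > best[2]:
                best = (i, j, score)
    return best

# Made-up logits over 6 token positions
start_logits = np.array([-1.0, 3.0, 0.5, 2.0, -0.5, 0.0])
end_logits = np.array([-2.0, 0.0, 1.0, 0.5, 2.5, -1.0])

start, end, score = best_span(start_logits, end_logits)
print(start, end, score)  # 1 4 5.5
```

The tokens between the chosen endpoints do not enter the score; a span is ranked only by the logits at its start and end positions.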