Imagine that you are taken with a sudden desire to understand how the fruit of a tropical tree gets transformed into chocolate bars, or want to understand the role of fever in the human body's immune response: how would you go about finding that information?
If your specific question has already been asked and answered clearly and succinctly on one of the many question answering platforms available on the Internet (such as Quora, Reddit, or Yahoo Answers), you're in luck: modern search engines will probably take you to that pre-existing answer pretty reliably in a matter of a few clicks.
If no one else has asked the exact question you are interested in, however, the process will be a little more involved. You will likely have to collect relevant information from a variety of sources, figure out how these pieces of knowledge fit together in relation to your query, and synthesize a narrative that answers your initial question.
Now, wouldn't it be great if your computer could do all of that for you: gather the right sources (e.g. paragraphs from relevant Wikipedia pages), synthesize the information, and write up an easy-to-read, original summary of the relevant points? Such a system isn't quite available yet, at least not one that can provide reliable information in its summary. Even though current systems excel at finding an extractive span that answers a factoid question in a given document, they still struggle with open-domain settings, where a model needs to find its own sources of information, and with long answer generation.
Thankfully, a number of recent advances in natural language understanding and generation have made working toward solving this problem much easier! These advances include progress in the pre-training (e.g. BART, T5) and evaluation (e.g. for factuality) of sequence-to-sequence models for conditional text generation, new ways to use language understanding models to find information in Wikipedia (e.g. REALM, DPR), and a new training dataset introduced in the paper ELI5: Long Form Question Answering.
The ELI5 dataset was built by gathering questions that were asked by community members of the r/explainlikeimfive subreddit, along with the answers that were provided by other users. The rules of the subreddit make this data particularly well suited to training a model for abstractive question answering: the questions need to seek an objective explanation about well-established facts, and the answers provided need to be understandable to a layperson without any particular domain knowledge.
In this notebook, we show how we can take advantage of these recent advances to train a long form question answering system which takes in a question, fetches 10 relevant passages from a Wikipedia snapshot, and writes a multi-sentence answer based on the question and retrieved passages. In particular, training embedding-based retrieval models to gather supporting evidence for open-domain questions is a relatively new research area: the last few months have seen some significant progress in cases where direct supervision is available, or with extensive task-specific pretraining. Here, we show how the ELI5 dataset allows us to train a dense retrieval system without access to either, making dense retrieval models more accessible. See this presentation from the Hugging Face reading group for a non-exhaustive overview of recent work in the field.
Follow along to learn about the steps involved and read some background on the state of the art for some related tasks, or go straight to the:
(And don't forget to scroll down on the left sidebar to show all of the generation options!)
The implementation presented here relies on the Hugging Face 🤗transformers and 🤗nlp libraries. Wikipedia indexing relies on ElasticSearch with its python bindings for the sparse version, and faiss for the dense version. You can get all of these by running:
pip install elasticsearch
pip install faiss_gpu
pip install nlp
pip install transformers
wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.7.1-linux-x86_64.tar.gz
tar -xzvf elasticsearch-7.7.1-linux-x86_64.tar.gz
The training relies on two datasets: ELI5, a processed version of the r/explainlikeimfive subreddit, and the Wiki40b Wikipedia snapshot. You can download both using the 🤗nlp library with:
import nlp
eli5 = nlp.load_dataset('eli5')
wiki40b_snippets = nlp.load_dataset('wiki_snippets', name='wiki40b_en_100_0')['train']
Additionally, all of the useful methods used in this notebook are compiled in the lfqa_utils.py script:
from lfqa_utils import *
Before we go any further, let us take a moment to talk about the provenance of our training data. While Reddit hosts a number of thriving communities with high quality discussions, it is also widely known to have corners where sexism, hate, and harassment are significant issues. See for example the recent post from Reddit founder u/spez outlining some of the ways he thinks the website's historical policies have been responsible for this problem, Adrienne Massanari's 2015 article on GamerGate and follow-up works, or a 2019 Wired article on misogyny on Reddit.
While there has been some recent work in the NLP community on de-biasing models (e.g. Black is to Criminal as Caucasian is to Police: Detecting and Removing Multiclass Bias in Word Embeddings for word embeddings trained specifically on Reddit data), this problem is far from solved, and the likelihood that a trained model might learn the biases present in the data remains a significant concern.
As mentioned above, the magnitude of the problem depends on the specific communities/subreddits. This work uses data from r/explainlikeimfive, and the nlp library also gives access to examples from r/askscience and r/AskHistorians. There are some encouraging signs for all of these communities: r/explainlikeimfive and r/askscience have similar structures and purposes, and r/askscience was found in 2015 to show medium supportiveness and very low toxicity when compared to other subreddits (see a hackerfall post, thecut.com write-up and supporting data). Meanwhile, the r/AskHistorians rules mention that the admins will not tolerate "racism, sexism, or any other forms of bigotry".
This is obviously not enough to exonerate the model (the pre-training step, for example, raises its own questions on that topic), and there is still a lot of interesting work to do to be able to quantify the biases in a conditional text generation model. One thing you can do to help: if you find any particularly egregious answers provided by the model when using the demo, or want to collaborate on this research question, please send a DM to @YJernite on Twitter!
Let's recap: we are interested in the task of Long Form Question Answering. As in other Question Answering tasks, the model is presented with a question, and is required to generate a natural language answer. Whereas a majority of QA datasets contain mostly factoid questions, where the answer, such as a date or the name of a single entity, can be expressed in a few words or a single sentence, Long Form QA focuses on questions which call for an explanation consisting of a few sentences or a few paragraphs.
In order to teach a model to answer such questions, we use questions and answers written by Reddit users. Note that the nlp.load_dataset command above actually downloaded questions and their associated answers from the r/explainlikeimfive, r/askscience, and r/AskHistorians subreddits. We focus here on the ELI5/explainlikeimfive part to train the system, as these examples tend to be a little simpler.
Let's look at one item from the test set:
eli5['test_eli5'][12345]
So here we have the question:
Why does water heated to room temperature feel colder than the air around it?
This definitely requires a multi-step explanation: no single phrase can sum up all of the information we are looking for. Here are the answers that were given on ELI5, and were given scores of +5 and +2 respectively by Reddit users:
Water transfers heat more efficiently than air. When something feels cold it's because heat is being transferred from your skin to whatever you're touching. Since water absorbs the heat more readily than air, it feels colder.
Air isn't as good at transferring heat compared to something like water or steel (sit on a room temperature steel bench vs. a room temperature wooden bench, and the steel one will feel more cold). When you feel cold, what you're feeling is heat being transferred out of you. If there is no breeze, you feel a certain way. If there's a breeze, you will get colder faster (because the moving air is pulling the heat away from you), and if you get into water, its quite good at pulling heat from you. Get out of the water and have a breeze blow on you while you're wet, all of the water starts evaporating, pulling even more heat from you.
First, note that in this case we have two answers which broadly describe the same phenomenon: the first one is scored higher because it is more succinct and to the point. This example already illustrates one important feature of the LFQA task: there are usually several valid ways to answer a given question. Of the 272K examples in the ELI5 training set, nearly two thirds (167K) have at least two answers. We'll need to keep this in mind when training and evaluating the model.
Secondly, we need to give our model access to the information that is expressed in both these answers. Recently released models have been shown to include a significant amount of world knowledge in their parameters without the need for any external knowledge (see e.g. the Closed-book QA performance of the T5 model). However, there are several advantages to giving the model explicit access to information in text form. First, a larger number of parameters in a model implies a larger computational cost. Secondly, getting information from a text database allows us to easily update the model's knowledge without having to re-train its parameters.
Here, we choose to give the model access to Wikipedia text. Full Wikipedia articles are typically too long for most current models to handle, and notable exceptions like the Reformer or Longformer architectures unfortunately do not yet have pre-trained sequence-to-sequence variants. Thus, we follow previous work in splitting Wikipedia articles into disjoint snippets of 100 words, and keep track of the title of the article and section a snippet came from. Here's how you can get a pre-processed Wiki40b version split into 100-word passages with the nlp library, and an example snippet which has some of the information we're looking for ("little conduction would occur since air is a poor conductor of heat"):
wiki40b_snippets[8991855]
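If you are curious about what that splitting looks like, here is a minimal sketch of the idea, under the assumption that we only keep the raw text (the actual wiki_snippets preprocessing also keeps track of article titles, section titles, and passage ids):
# minimal sketch of splitting an article into disjoint 100-word snippets
# (illustration only; the real preprocessing lives in the wiki_snippets dataset script)
def split_into_snippets(article_text, snippet_length=100):
    words = article_text.split()
    return [
        " ".join(words[i: i + snippet_length])
        for i in range(0, len(words), snippet_length)
    ]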
In the next two sections, we show how we can use either a sparse retriever or a trained dense retriever to automatically find relevant snippets for a question.
In this section, we show how to use a "classical" Information Retrieval (IR) system based on sparse word matching with ElasticSearch, an extremely popular and efficient search engine that finds documents matching a given query based on word overlap.
Specifically, ElasticSearch provides a convenient way to index documents so they can easily be queried for nearest neighbor search using the BM25 similarity function (which relies on TF-IDF weighting of words). While this word-matching based approach has obvious limitations, such as failing to take synonyms and sometimes grammatical variation into account, it does pretty well overall and has only recently been overtaken by embedding-based systems for Wikipedia-based Open-Domain QA tasks.
In order to use ElasticSearch, you will first need to launch a server. In a different window, run:
./elasticsearch-7.7.1/bin/elasticsearch
By default, your ElasticSearch server will be listening on localhost port 9200. To connect to it, run:
from elasticsearch import Elasticsearch

es_client = Elasticsearch([{'host': 'localhost', 'port': '9200'}])
The lfqa_utils.py script provides utilities to create (make_es_index_snippets) and query (query_es_index) an ElasticSearch index from within Python.
The main implementation details are as follows. First, the index is created with the following index_config variable, which indexes the article title, section title, and passage text of each snippet as separate fields, all scored with the BM25 similarity:
index_config = {
"settings": {
"number_of_shards": 1,
},
"mappings": {
"properties": {
"article_title": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
"section_title": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
"passage_text": {"type": "text", "analyzer": "standard", "similarity": "BM25"}
}
}
}
Second, before a question is sent to ElasticSearch, common question words and other uninformative tokens are stripped from it, and the query itself is a cross_fields multi_match over the three fields, with the passage text weighted twice as much as the titles:
banned = ['how', 'why', 'what', 'where', 'which', 'do', 'does', 'is', '?', 'eli5', 'eli5:']
q = ' '.join([w for w in q.split() if w not in banned])
response = es_client.search(
index = index_name,
body = {
"query": {
"multi_match": {
"query": q,
"fields": ["article_title", "section_title", "passage_text^2"],
"type": "cross_fields",
}
},
"size": n_results,
}
)
Here's the command to create the index; it should take one to three hours depending on your system.
if not es_client.indices.exists('wiki40b_snippets_100w'):
make_es_index_snippets(es_client, wiki40b_snippets, index_name='wiki40b_snippets_100w')
Now let's test the ElasticSearch retriever with our running example ELI5 question about skin-to-water heat transfer by returning the 10 best candidate passages:
question = eli5['test_eli5'][12345]['title']
doc, res_list = query_es_index(question, es_client, index_name='wiki40b_snippets_100w')
df = pd.DataFrame({
'Article': ['---'] + [res['article_title'] for res in res_list],
'Sections': ['---'] + [res['section_title'] if res['section_title'].strip() != '' else res['article_title']
for res in res_list],
'Text': ['--- ' + question] + [res['passage_text'] for res in res_list],
})
df.style.set_properties(**{'text-align': 'left'})
We can immediately see both the strengths and limitations of this approach. The system manages to retrieve documents that are all broadly on topic, mentioning some combination of water, air, relative temperature, and temperature transfer. In spite of this, only example 8 ends up containing information that is actually relevant to the question:
Cold air with high relative humidity "feels" colder than dry air of the same temperature because high humidity in cold weather increases the conduction of heat from the body.
We got lucky this time, but this passage could as easily have been ranked 11th and not been included in the support document we provide to the answer generation system. As it is, the model will have to sort through mostly off-topic information to find this sentence when reading the resulting supporting document.
The sparse retriever works by finding passages which feature the words from the query. However, it has no way to know a priori which of these words are more important in context, and seems to struggle with understanding the central theme of the query (human-perceived temperature).
Thankfully, some recent works have taken advantage of advances in pre-trained contextual word representations to solve this problem. Models such as DPR or REALM, for example, learn to compute a vector representation of the query, as well as vector representations of Wikipedia passages, in such a way that the passage that best answers a question maximizes the dot product between the two representations. Retrieval is then reduced to a Maximum Inner Product Search, which can be executed efficiently using systems like FAISS.
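To make the MIPS step concrete, here is a toy, brute-force illustration of the idea; the FAISS index we build below performs the same computation much more efficiently:
import numpy as np

# brute-force Maximum Inner Product Search: score every passage against the query
# and keep the k passages with the largest dot products
def mips_topk(query_rep, passage_reps, k=10):
    scores = passage_reps @ query_rep        # (num_passages,)
    top_k = np.argsort(-scores)[:k]          # indices of the best-scoring passages
    return top_k, scores[top_k]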
These successes are very encouraging for our Open-Domain Long Form QA application. However, our task and setup do not quite meet the requirements of either of these approaches. On the one hand, the DPR system is trained using gold passage annotations: most major QA datasets tell the system which Wikipedia passage contains the answer. Unfortunately, we do not have such annotations for the ELI5 data. On the other hand, while REALM is trained without passage supervision, it requires a pretty expensive pre-training step with an Inverse Cloze Task (100,000 steps with batch size 4096), and the ability to re-compute the embeddings of all Wikipedia passages regularly during training.
In order to train a similar dense retrieval system at reduced cost without having access to gold passage annotation, we will have to take advantage of another unique feature of our dataset, namely the fact that the long form answers are quite similar in style to the Wikipedia passages we want to index. Our hypothesis then is that if we train a system to embed the questions and answers in our dataset in a way that allows us to easily match questions to answers, then using the answer embedder on Wikipedia passages should allow us to similarly match questions to supporting evidence from Wikipedia.
As mentioned above, we want to train a system to produce question and answer embeddings, such that the dot product between the representation of a question and any of its answers is greater than between it and answers of all of the other questions in the dataset.
Unfortunately, actually comparing all questions to all answers before taking every single gradient step is computationally prohibitive: instead, we follow previous work in simply processing medium to large batches of question-answer pairs, and making sure that the dot product of a question with its answer is larger than with all other answers in the batch, and vice versa.
We use a cross-entropy loss for the multinomial distribution over all of the answers (or questions) in a batch, and make use of PyTorch gradient checkpointing to be able to use large batches with limited GPU memory: you can find all implementation details in the RetrievalQAEmbedder class in lfqa_utils.py.
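As an illustration, a minimal version of that in-batch loss could look like the following sketch (the actual RetrievalQAEmbedder also handles tokenization and gradient checkpointing):
import torch
import torch.nn.functional as F

def in_batch_contrastive_loss(q_reps, a_reps):
    # q_reps, a_reps: (batch_size, 128) embeddings of matched question-answer pairs
    scores = torch.mm(q_reps, a_reps.t())                        # all pairwise dot products
    targets = torch.arange(scores.size(0), device=scores.device) # the i-th question matches the i-th answer
    # each question should score highest with its own answer, and vice versa
    return (F.cross_entropy(scores, targets) + F.cross_entropy(scores.t(), targets)) / 2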
We use a single BERT-style pre-trained model to embed the questions and answers, and learn different projection matrices to bring both representations down to dimension 128: the projection matrices are trained from scratch as the sentence embedding model is fine-tuned. We found that the 8-layer distilled version of BERT from the Well-Read Students Learn Better paper performed as well as or better than full BERT, for a notable gain in computation speed: if you want an even faster model, that work provides pre-trained models spanning the full range of computation/accuracy trade-offs.
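For intuition, a stripped-down version of that architecture might look like this (a sketch only: the class and method names below are made up for illustration, and the real model is the RetrievalQAEmbedder in lfqa_utils.py):
import torch.nn as nn
from transformers import AutoModel

class QAEmbedderSketch(nn.Module):
    def __init__(self, model_name="google/bert_uncased_L-8_H-768_A-12", dim=128):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)   # single shared encoder
        hidden = self.bert.config.hidden_size
        self.project_q = nn.Linear(hidden, dim)             # question projection
        self.project_a = nn.Linear(hidden, dim)             # answer/passage projection

    def embed(self, input_ids, attention_mask, projection):
        # here we simply use the [CLS] token representation as the sequence embedding
        hidden_states = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
        return projection(hidden_states[:, 0])

    def forward(self, q_ids, q_mask, a_ids, a_mask):
        return (self.embed(q_ids, q_mask, self.project_q),
                self.embed(a_ids, a_mask, self.project_a))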
The model can then be trained with the following code: with a batch size of 512 (processed in gradient-checkpointed sub-batches of 32) on a single 16GB GPU, you can run 10 training epochs in under 6 hours.
# training arguments
class ArgumentsQAR():
def __init__(self):
self.batch_size = 512
self.max_length = 128
self.checkpoint_batch_size = 32
self.print_freq = 100
self.pretrained_model_name = "google/bert_uncased_L-8_H-768_A-12"
self.model_save_name = "retriever_models/eli5_retriever_model_l-8_h-768_b-512-512"
self.learning_rate = 2e-4
self.num_epochs = 10
qar_args = ArgumentsQAR()
# prepare torch Dataset objects
qar_train_dset = ELI5DatasetQARetriver(eli5['train_eli5'], training=True)
qar_valid_dset = ELI5DatasetQARetriver(eli5['validation_eli5'], training=False)
# load pre-trained BERT and make model
qar_tokenizer, qar_model = make_qa_retriever_model(
model_name=qar_args.pretrained_model_name,
from_file=None,
device="cuda:0"
)
# train the model
train_qa_retriever(qar_model, qar_tokenizer, qar_train_dset, qar_valid_dset, qar_args)
If you don't want to train the model yourself, you can also download trained weights from the Hugging Face model repository with:
qar_tokenizer = AutoTokenizer.from_pretrained('yjernite/retribert-base-uncased')
qar_model = AutoModel.from_pretrained('yjernite/retribert-base-uncased').to('cuda:1')
_ = qar_model.eval()
Once the model is trained, it can be used to compute passage embeddings for all Wikipedia snippets. The make_qa_dense_index method takes advantage of numpy memory-mapping, so embeddings are written directly to disk. Again, with a single GPU, computing the full set of passage embeddings should take about 18 hours.
if not os.path.isfile('wiki40b_passages_reps_32_l-8_h-768_b-512-512.dat'):
make_qa_dense_index(
qar_model, qar_tokenizer, wiki40b_snippets, device='cuda:0',
index_name='wiki40b_passages_reps_32_l-8_h-768_b-512-512.dat'
)
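If you are curious about the memory-mapping trick, here is a simplified sketch of what happens under the hood; embed_batch_fn is a hypothetical helper standing in for the batched forward pass through the trained embedder, which make_qa_dense_index handles itself:
import numpy as np

def write_embeddings_to_memmap(embed_batch_fn, num_passages, dim, file_name, batch_size=512):
    # pre-allocate the on-disk array, then fill it batch by batch
    fp = np.memmap(file_name, dtype='float32', mode='w+', shape=(num_passages, dim))
    for start in range(0, num_passages, batch_size):
        end = min(start + batch_size, num_passages)
        fp[start:end] = embed_batch_fn(start, end)  # (end - start, dim) float32 array
    fp.flush()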
Now that we have trained our model to compute query and answer embeddings and used it to compute passage embeddings for all our Wikipedia snippets, let's see whether it can actually find supporting evidence for a new question. Recall the two steps to using the dense retriever: we first compute an embedding for a new question, then do Max Inner Product Search with the pre-computed passage representations.
The MIPS part can be executed efficiently with the faiss library. Additionally, since we computed 128-dimensional passage embeddings, the whole set of representations fits on a GPU, making retrieval even faster. We can create the faiss GPU index with the following code:
faiss_res = faiss.StandardGpuResources()
wiki40b_passage_reps = np.memmap(
'wiki40b_passages_reps_32_l-8_h-768_b-512-512.dat',
dtype='float32', mode='r',
shape=(wiki40b_snippets.num_rows, 128)
)
wiki40b_index_flat = faiss.IndexFlatIP(128)
wiki40b_gpu_index = faiss.index_cpu_to_gpu(faiss_res, 1, wiki40b_index_flat)
wiki40b_gpu_index.add(wiki40b_passage_reps)
Now we can use the query_qa_dense_index function to query the dense index for our running example question about perceived temperature:
question = eli5['test_eli5'][12345]['title']
doc, res_list = query_qa_dense_index(question, qar_model, qar_tokenizer, wiki40b_snippets, wiki40b_gpu_index, device='cuda:1')
df = pd.DataFrame({
'Article': ['---'] + [res['article_title'] for res in res_list],
'Sections': ['---'] + [res['section_title'] if res['section_title'].strip() != '' else res['article_title']
for res in res_list],
'Text': ['--- ' + question] + [res['passage_text'] for res in res_list],
})
df.style.set_properties(**{'text-align': 'left'})
The retrieved documents are quite different from the ones returned by the sparse retrieval, with a greater focus on how water helps draw heat from a body, either through evaporation or through better conduction, which is information the model needs to answer this question.
The retriever still misses out on one aspect of the query: the way the question is formulated implies that in the considered scenario the person is immersed in water rather than just wet, which makes the "latent heat" and evaporation arguments a little less relevant, but that's a really subtle distinction!
We have trained a retrieval model that seems to be working a little better than the traditional word-matching based approach, at least on our running example. Before we use it to actually answer questions, however, we would like to be able to get some quantitative evaluation of the performance of both approaches.
For the retriever, we want to favor recall over precision: our first priority is to make sure that all of the information needed to write the answers is present in the support document. If there is unrelated information, the generation model can learn to sort it out. We measure this by computing the proportion of words in the high-scoring answers which are present in the retrieved support document. To focus on important words, we also weigh answer words by their Inverse Document Frequency. This gives us the following IDF-recall scoring function:
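Concretely, for a support document $D$ and a reference answer $A$, the function implemented below computes:
$$\mathrm{IDF\text{-}recall}(D, A) = \frac{\sum_{w \in A,\; w \in D} 1/\log\big(1 + f(w)\big)}{\sum_{w \in A} 1/\log\big(1 + f(w)\big)}$$
where the sums run over the tokens $w$ of the answer $A$, and $f(w)$ is the frequency of $w$ across all test-set answers (with a default of 1 for words never seen in an answer).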
# We first select high-scoring answers (answers beyond the first must have a score of at least 3)
test_qa_list = [(exple['title'],
' '.join([a
for i, (a, sc) in enumerate(zip(exple['answers']['text'], exple['answers']['score'])) \
if i == 0 or sc >= 3
]))
for exple in eli5['test_eli5']]
# We then compute word frequencies in answer text
answer_doc_freq = {}
for q, a in test_qa_list:
for w in a.lower().split():
answer_doc_freq[w] = answer_doc_freq.get(w, 0) + 1
# The IDF-recall function is then:
import math

def da_idf_recall(doc, answer):
d_words = dict([(w, True) for w in doc.lower().split()])
a_words = answer.lower().split()
recall = sum([1. / math.log(1 + answer_doc_freq.get(w, 1)) for w in a_words if w in d_words]) / \
sum([1. / math.log(1 + answer_doc_freq.get(w, 1)) for w in a_words])
return recall
The evaluate_retriever function in lfqa_utils.py takes a retrieval and scoring function and computes both the average retrieval time and the score of the document relative to the provided answer. Let's write some short-hand functions for the dense and sparse retrievers with our currently loaded indexes, and evaluate them on the ELI5 test set (be advised that evaluating the retriever on the full test set takes up to two hours):
def dense_ret_for_eval(question, n_ret):
_, dense_res_list = query_qa_dense_index(
question, qar_model, qar_tokenizer, wiki40b_snippets, wiki40b_gpu_index, n_results=n_ret, device='cuda:1'
)
dense_doc = ' '.join([res['passage_text'] for res in dense_res_list])
return dense_doc
def sparse_ret_for_eval(question, n_ret):
_, sparse_res_list = query_es_index(
question, es_client, index_name='wiki40b_snippets_100w', n_results=n_ret
)
sparse_doc = ' '.join([res['passage_text'] for res in sparse_res_list])
return sparse_doc
dense_score = evaluate_retriever(test_qa_list, dense_ret_for_eval, da_idf_recall)
sparse_score = evaluate_retriever(test_qa_list, sparse_ret_for_eval, da_idf_recall)
df = pd.DataFrame({
'IDF-Recall': [sparse_score['idf_recall'], dense_score['idf_recall']],
'Time/Query': [sparse_score['retrieval_time'], dense_score['retrieval_time']],
}, index=[ 'Sparse', 'Dense'])
df.style.format({'IDF-Recall': "{:.4f}", 'Time/Query': "{:.4f}"})
This metric obviously has limitations. Since it only looks at individual word matches, it is oblivious to word order or paraphrases among others. However, we can be encouraged by the fact that the dense retriever not only yields higher IDF-recall, it also takes less than a third of the time of the ElasticSearch-based system! Considering these results, we can confidently use it for the next part: training the sequence-to-sequence answer generation system.
Now that we know how to create an evidence document with supporting information for a given question, let's look into training the second component of our system: the answer generation module. We will instantiate it as a sequence-to-sequence model which uses the BART architecture, and initialize it with the bart-large pretrained weights.
In short, the BART paper uses a denoising auto-encoder style objective to pre-train an encoder-decoder model (similarly to how masked language modeling is used to pre-train BERT-style encoders). Among other applications, they show that large-scale pre-training with their objective followed by fine-tuning on ELI5 data yields the state-of-the-art ROUGE performance for the original version of the dataset (which uses pre-computed support documents made from CommonCrawl pages).
We provide the concatenation of the question and support document as input to the model, and train the decoder to minimize the perplexity of the gold answer. One notable choice is that we train the model using all high-scoring answers in the training set, so the model will see several instances of the same question-document input with different outputs. The supporting passages are separated by a special token <P>, so the input for our running example will look like:
question: Why does water heated to room temperature feel colder than the air around it? context: <P> when the skin is completely wet. The body continuously loses ... this heat comes from the liquid itself and the surrounding gas and surfaces. <P> protected by a glass panel. Consequently, these types of collectors... Since heat loss due to convection cannot cross a vacuum, it forms an efficient isolation mechanism to keep heat inside the collector pipes. Since two flat <P> ... <P> changes. Conduction On... Fluids—especially gases—are less conductive. Thermal contact conductance is the study of heat conduction between solid bodies in contact. The process of heat transfer
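Under the hood, a single training step boils down to something like the following sketch; question_doc and gold_answer are placeholder variables, and the actual train_qa_s2s function in lfqa_utils.py additionally handles batching, gradient accumulation, and learning rate scheduling:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large')
model = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-large')
model.train()

question_doc = "question: Why does water heated to room temperature feel colder than the air around it? context: <P> ..."
gold_answer = "Water transfers heat more efficiently than air. ..."

# encode the question + support document as the encoder input,
# and the gold answer as the decoder target
inputs = tokenizer(question_doc, max_length=1024, truncation=True, return_tensors='pt')
labels = tokenizer(gold_answer, return_tensors='pt')['input_ids']

# the loss is the cross-entropy (perplexity) of the gold answer tokens
loss = model(input_ids=inputs['input_ids'],
             attention_mask=inputs['attention_mask'],
             labels=labels)[0]
loss.backward()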
The first thing we do is pre-compute the support documents for the training and validation sets so we can use all available GPUs to train the sequence-to-sequence model. The model is then trained with the train_qa_s2s function in lfqa_utils.py. A 16GB GPU accommodates up to two examples at a time, so here is the code to train the model using 4 GPUs with torch.nn.DataParallel. One epoch should take about 18 hours:
# pre-computing support documents
eli5_train_docs = []
for example in eli5['train_eli5']:
support_doc, dense_res_list = query_qa_dense_index(
example['title'], qar_model, qar_tokenizer, wiki40b_snippets, wiki40b_gpu_index, n_results=10
)
eli5_train_docs += [(example['q_id'], support_doc, dense_res_list)]
eli5_valid_docs = []
for example in eli5['validation_eli5']:
support_doc, dense_res_list = query_qa_dense_index(
example['title'], qar_model, qar_tokenizer, wiki40b_snippets, wiki40b_gpu_index, n_results=10
)
eli5_valid_docs += [(example['q_id'], support_doc, dense_res_list)]
# training loop proper
class ArgumentsS2S():
def __init__(self):
self.batch_size = 8
self.backward_freq = 16
self.max_length = 1024
self.print_freq = 100
self.model_save_name = "seq2seq_models/eli5_bart_model"
self.learning_rate = 2e-4
self.num_epochs = 3
s2s_args = ArgumentsS2S()
eli5_train_docs = json.load(open('precomputed/eli5_train_precomputed_dense_docs.json'))
eli5_valid_docs = json.load(open('precomputed/eli5_valid_precomputed_dense_docs.json'))
s2s_train_dset = ELI5DatasetS2S(eli5['train_eli5'], document_cache=dict([(k, d) for k, d, src_ls in eli5_train_docs]))
s2s_valid_dset = ELI5DatasetS2S(eli5['validation_eli5'], document_cache=dict([(k, d) for k, d, src_ls in eli5_valid_docs]), training=False)
qa_s2s_tokenizer, pre_model = make_qa_s2s_model(
model_name="facebook/bart-large",
from_file=None,
device="cuda:0"
)
qa_s2s_model = torch.nn.DataParallel(pre_model)
train_qa_s2s(qa_s2s_model, qa_s2s_tokenizer, s2s_train_dset, s2s_valid_dset, s2s_args)
Again, if you don't want to train the model yourself, we made trained weights available on the Hugging Face model repository, which you can download with:
qa_s2s_tokenizer = AutoTokenizer.from_pretrained('yjernite/bart_eli5')
qa_s2s_model = AutoModelForSeq2SeqLM.from_pretrained('yjernite/bart_eli5').to('cuda:0')
_ = qa_s2s_model.eval()
We now have everything we need to answer any question! Let's try the full system on our running example along with the first four questions of the test set:
questions = []
answers = []
for i in [12345] + [j for j in range(4)]:
# create support document with the dense index
question = eli5['test_eli5'][i]['title']
doc, res_list = query_qa_dense_index(
question, qar_model, qar_tokenizer,
wiki40b_snippets, wiki40b_gpu_index, device='cuda:1'
)
# concatenate question and support document into BART input
question_doc = "question: {} context: {}".format(question, doc)
# generate an answer with beam search
answer = qa_s2s_generate(
question_doc, qa_s2s_model, qa_s2s_tokenizer,
num_answers=1,
num_beams=8,
min_len=64,
max_len=256,
max_input_length=1024,
device="cuda:0"
)[0]
questions += [question]
answers += [answer]
df = pd.DataFrame({
'Question': questions,
'Answer': answers,
})
df.style.set_properties(**{'text-align': 'left'})
We made it, and a lot of these answers actually make sense! The model seems to sometimes struggle with coherence and with starting some of the answers, but we're getting some pretty good information overall.
The last thing we'll do is see how we can get a quantitative evaluation of the model performance. Here, we'll use the ROUGE implementation provided in the nlp library.
Note that it is a different implementation than the one used in the BART and ELI5 papers: the rouge Python package they use normalizes all numerical values, among other pre-processing choices, leading to higher numbers. We reproduce their evaluation in the Appendix section, but recommend using the more sensitive metric provided by the nlp package, which can be computed with:
predicted = []
reference = []
# Generate answers for the full test set
for i in range(eli5['test_eli5'].num_rows):
# create support document with the dense index
question = eli5['test_eli5'][i]['title']
doc, res_list = query_qa_dense_index(
question, qar_model, qar_tokenizer,
wiki40b_snippets, wiki40b_gpu_index, device='cuda:1'
)
# concatenate question and support document into BART input
question_doc = "question: {} context: {}".format(question, doc)
# generate an answer with beam search
answer = qa_s2s_generate(
question_doc, qa_s2s_model, qa_s2s_tokenizer,
num_answers=1,
num_beams=8,
min_len=96,
max_len=256,
max_input_length=1024,
device="cuda:0"
)[0]
predicted += [answer]
reference += [eli5['test_eli5'][i]['answers']['text'][0]]
# Compare each generation to the first answer from the dataset
nlp_rouge = nlp.load_metric('rouge')
scores = nlp_rouge.compute(
predicted, reference,
rouge_types=['rouge1', 'rouge2', 'rougeL', 'rougeLsum'],
use_agregator=True, use_stemmer=False
)
df = pd.DataFrame({
'rouge1': [scores['rouge1'].mid.precision, scores['rouge1'].mid.recall, scores['rouge1'].mid.fmeasure],
'rouge2': [scores['rouge2'].mid.precision, scores['rouge2'].mid.recall, scores['rouge2'].mid.fmeasure],
'rougeL': [scores['rougeL'].mid.precision, scores['rougeL'].mid.recall, scores['rougeL'].mid.fmeasure],
}, index=[ 'P', 'R', 'F'])
df.style.format({'rouge1': "{:.4f}", 'rouge2': "{:.4f}", 'rougeL': "{:.4f}"})
That's it for today! And once again, if you want to play with the model a bit more and ask it whatever question comes to mind, please feel free to head over to:
Thank you for reading!
Here we reproduce the ROUGE evaluation from the original ELI5 paper to be able to compare our performance to theirs. Our generation setting leads to lower ROUGE-1 and ROUGE-2 than the state-of-the-art reported in BART (30.6 and 6.2 respectively), and higher ROUGE-L (24.3).
from nltk import PorterStemmer
from rouge import Rouge
from spacy.lang.en import English
from time import time
stemmer = PorterStemmer()
rouge = Rouge()
tokenizer = English().Defaults.create_tokenizer()
def compute_rouge_eli5(compare_list):
preds = [" ".join([stemmer.stem(str(w))
for w in tokenizer(pred)])
for gold, pred in compare_list]
golds = [" ".join([stemmer.stem(str(w))
for w in tokenizer(gold)])
for gold, pred in compare_list]
scores = rouge.get_scores(preds, golds, avg=True)
return scores
compare_list = [(g, p) for p, g in zip(predicted, reference)]
scores = compute_rouge_eli5(compare_list)
df = pd.DataFrame({
'rouge1': [scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f']],
'rouge2': [scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f']],
'rougeL': [scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f']],
}, index=[ 'P', 'R', 'F'])
df.style.format({'rouge1': "{:.4f}", 'rouge2': "{:.4f}", 'rougeL': "{:.4f}"})