Compare commits

10 Commits: a9aaa74e9e ... 9fc9c4e180

| Author | SHA1 | Date |
|---|---|---|
| | 9fc9c4e180 | |
| | c626438eec | |
| | df7df4d210 | |
| | 6a35c499ad | |
| | 231de06fb7 | |
| | a4deec410f | |
| | 210c8358e0 | |
| | bffd62b384 | |
| | e96ffdb447 | |
| | 048916aef0 | |
@@ -6,3 +6,5 @@
*.tar.gz filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
model.safetensors filter=lfs diff=lfs merge=lfs -text
@@ -0,0 +1,81 @@
---
license: mit
thumbnail: https://huggingface.co/front/thumbnails/facebook.png
pipeline_tag: zero-shot-classification
datasets:
- multi_nli
---

# bart-large-mnli

This is the checkpoint for [bart-large](https://huggingface.co/facebook/bart-large) after being trained on the [MultiNLI (MNLI)](https://huggingface.co/datasets/multi_nli) dataset.

Additional information about this model:
- The [bart-large](https://huggingface.co/facebook/bart-large) model page
- [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461)
- [BART fairseq implementation](https://github.com/pytorch/fairseq/tree/master/fairseq/models/bart)

## NLI-based Zero Shot Text Classification

[Yin et al.](https://arxiv.org/abs/1909.00161) proposed a method for using pre-trained NLI models as ready-made zero-shot sequence classifiers. The method works by posing the sequence to be classified as the NLI premise and constructing a hypothesis from each candidate label. For example, if we want to evaluate whether a sequence belongs to the class "politics", we could construct the hypothesis `This text is about politics.`. The probabilities for entailment and contradiction are then converted to label probabilities.

This method is surprisingly effective in many cases, particularly when used with larger pre-trained models like BART and RoBERTa. See [this blog post](https://joeddav.github.io/blog/2020/05/29/ZSL.html) for a more expansive introduction to this and other zero-shot methods, and see the code snippets below for examples of using this model for zero-shot classification both with Hugging Face's built-in pipeline and with native Transformers/PyTorch code.

#### With the zero-shot classification pipeline

The model can be loaded with the `zero-shot-classification` pipeline like so:

```python
from transformers import pipeline
classifier = pipeline("zero-shot-classification",
                      model="facebook/bart-large-mnli")
```

You can then use this pipeline to classify sequences into any of the class names you specify.

```python
sequence_to_classify = "one day I will see the world"
candidate_labels = ['travel', 'cooking', 'dancing']
classifier(sequence_to_classify, candidate_labels)
#{'labels': ['travel', 'dancing', 'cooking'],
# 'scores': [0.9938651323318481, 0.0032737774308770895, 0.002861034357920289],
# 'sequence': 'one day I will see the world'}
```

If more than one candidate label can be correct, pass `multi_class=True` to calculate each class independently:

```python
candidate_labels = ['travel', 'cooking', 'dancing', 'exploration']
classifier(sequence_to_classify, candidate_labels, multi_class=True)
#{'labels': ['travel', 'exploration', 'dancing', 'cooking'],
# 'scores': [0.9945111274719238,
#  0.9383890628814697,
#  0.0057061901316046715,
#  0.0018193122232332826],
# 'sequence': 'one day I will see the world'}
```
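The hypothesis used for each label can also be customized. In recent versions of `transformers`, the zero-shot pipeline accepts a `hypothesis_template` argument whose `{}` placeholder is filled with each candidate label (the default is `This example is {}.`); this is not shown in the original card, so treat the snippet below as an illustrative sketch:

```python
# override the default hypothesis template used to build each NLI hypothesis
classifier(sequence_to_classify, candidate_labels,
           hypothesis_template="This text is about {}.")
```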
#### With manual PyTorch

```python
# pose sequence as an NLI premise and label as a hypothesis
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
nli_model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli').to(device)
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')

sequence = "one day I will see the world"   # text to classify
label = "travel"                            # candidate label
premise = sequence
hypothesis = f'This example is {label}.'

# run through model pre-trained on MNLI
x = tokenizer.encode(premise, hypothesis, return_tensors='pt',
                     truncation_strategy='only_first')
logits = nli_model(x.to(device))[0]

# we throw away "neutral" (dim 1) and take the probability of
# "entailment" (2) as the probability of the label being true
entail_contradiction_logits = logits[:,[0,2]]
probs = entail_contradiction_logits.softmax(dim=1)
prob_label_is_true = probs[:,1]
```
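To mimic the pipeline's single-label behaviour with this manual approach, each candidate label can be scored the same way and the entailment logits normalized across labels. A minimal sketch, reusing `nli_model`, `tokenizer`, and `device` from the snippet above (the `rank_labels` helper is illustrative, not part of the original card):

```python
import torch

def rank_labels(sequence, labels, template='This example is {}.'):
    # one NLI forward pass per candidate label, keeping the "entailment"
    # logit (index 2 per this model's id2label mapping)
    entail_logits = []
    with torch.no_grad():
        for label in labels:
            x = tokenizer.encode(sequence, template.format(label),
                                 return_tensors='pt',
                                 truncation_strategy='only_first')
            entail_logits.append(nli_model(x.to(device))[0][0, 2])
    # normalize entailment scores across labels, like the pipeline's default mode
    scores = torch.stack(entail_logits).softmax(dim=0)
    return sorted(zip(labels, scores.tolist()), key=lambda pair: -pair[1])

rank_labels("one day I will see the world", ['travel', 'cooking', 'dancing'])
# expected: 'travel' ranked first, as in the pipeline example above
```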
config.json (36 changed lines)
@@ -1,49 +1,49 @@
{
  "_num_labels": 3,
  "activation_dropout": 0.0,
  "architectures": null,
  "activation_function": "gelu",
  "add_final_layer_norm": false,
  "architectures": [
    "BartForSequenceClassification"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classif_dropout": 0.0,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "do_sample": false,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "finetuning_task": null,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "contradiction",
    "1": "neutral",
    "2": "entailment"
  },
  "init_std": 0.02,
  "is_decoder": false,
  "is_encoder_decoder": true,
  "label2id": {
    "contradiction": 0,
    "entailment": 2,
    "neutral": 1
  },
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position_embeddings": 1024,
  "model_type": "bart",
  "num_beams": 1,
  "normalize_before": false,
  "num_hidden_layers": 12,
  "num_labels": 3,
  "num_return_sequences": 1,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": false,
  "pruned_heads": {},
  "repetition_penalty": 1.0,
  "temperature": 1.0,
  "top_k": 50,
  "top_p": 1.0,
  "torchscript": false,
  "use_bfloat16": false,
  "pad_token_id": 1,
  "scale_embedding": false,
  "transformers_version": "4.7.0.dev0",
  "use_cache": true,
  "vocab_size": 50265
}
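The `id2label`/`label2id` mapping in this config is what the manual PyTorch snippet above relies on when it reads the entailment score from index 2. A small sketch (using the standard `AutoConfig` API, not code from the original card) that derives the indices from the config instead of hardcoding them:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained('facebook/bart-large-mnli')
entailment_idx = config.label2id['entailment']        # 2 for this checkpoint
contradiction_idx = config.label2id['contradiction']  # 0 for this checkpoint
print(entailment_idx, contradiction_idx)
```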
Binary file not shown.
File diff suppressed because it is too large
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -0,0 +1 @@
{"model_max_length": 1024}
File diff suppressed because one or more lines are too long