This commit is contained in:
parent
23111f3256
commit
ced69de5b6
13
README.md
13
README.md
|
@ -50,3 +50,16 @@ with torch.no_grad():
|
|||
labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1)]
|
||||
print(labels)
|
||||
```
|
||||
|
||||
## Zero-Shot Classification
|
||||
This model can also be used for zero-shot-classification:
|
||||
```
|
||||
from transformers import pipeline
|
||||
|
||||
classifier = pipeline("zero-shot-classification", model='cross-encoder/nli-roberta-base')
|
||||
|
||||
sent = "Apple just announced the newest iPhone X"
|
||||
candidate_labels = ["technology", "sports", "politics"]
|
||||
res = classifier(sent, candidate_labels)
|
||||
print(res)
|
||||
```
|
Loading…
Reference in New Issue