This commit is contained in:
parent
23111f3256
commit
ced69de5b6
13
README.md
13
README.md
|
@ -50,3 +50,16 @@ with torch.no_grad():
|
||||||
labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1)]
|
labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1)]
|
||||||
print(labels)
|
print(labels)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Zero-Shot Classification
|
||||||
|
This model can also be used for zero-shot-classification:
|
||||||
|
```
|
||||||
|
from transformers import pipeline
|
||||||
|
|
||||||
|
classifier = pipeline("zero-shot-classification", model='cross-encoder/nli-roberta-base')
|
||||||
|
|
||||||
|
sent = "Apple just announced the newest iPhone X"
|
||||||
|
candidate_labels = ["technology", "sports", "politics"]
|
||||||
|
res = classifier(sent, candidate_labels)
|
||||||
|
print(res)
|
||||||
|
```
|
Loading…
Reference in New Issue