diff --git a/README.md b/README.md index 7df34e5..72776c6 100755 --- a/README.md +++ b/README.md @@ -49,4 +49,17 @@ with torch.no_grad(): label_mapping = ['contradiction', 'entailment', 'neutral'] labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1)] print(labels) -``` \ No newline at end of file +``` + +## Zero-Shot Classification +This model can also be used for zero-shot-classification: +``` +from transformers import pipeline + +classifier = pipeline("zero-shot-classification", model='cross-encoder/nli-roberta-base') + +sent = "Apple just announced the newest iPhone X" +candidate_labels = ["technology", "sports", "politics"] +res = classifier(sent, candidate_labels) +print(res) +``` \ No newline at end of file