commit b2b5013007
parent ced69de5b6
@@ -24,7 +24,7 @@ For evaluation results, see [SBERT.net - Pretrained Cross-Encoder](https://www.s
 Pre-trained models can be used like this:
 ```python
 from sentence_transformers import CrossEncoder
-model = CrossEncoder('model_name')
+model = CrossEncoder('cross-encoder/nli-roberta-base')
 scores = model.predict([('A man is eating pizza', 'A man eats something'), ('A black race car starts up in front of a crowd of people.', 'A man is driving down a lonely road.')])
 
 #Convert scores to labels
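The hunk above ends just before the label-conversion step, which lies outside the diff context. A minimal sketch of that step, assuming the label order `['contradiction', 'entailment', 'neutral']` (an assumption about this model family, not shown in the diff):

```python
# Sketch only: map CrossEncoder scores to NLI labels.
# The label order below is an assumption; verify it against the model's config.
from sentence_transformers import CrossEncoder

model = CrossEncoder('cross-encoder/nli-roberta-base')
scores = model.predict([
    ('A man is eating pizza', 'A man eats something'),
    ('A black race car starts up in front of a crowd of people.',
     'A man is driving down a lonely road.'),
])  # shape: (2, num_labels)

label_mapping = ['contradiction', 'entailment', 'neutral']  # assumed order
labels = [label_mapping[idx] for idx in scores.argmax(axis=1)]
print(labels)
```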
@@ -38,8 +38,8 @@ You can use the model also directly with Transformers library (without SentenceT
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 import torch
 
-model = AutoModelForSequenceClassification.from_pretrained('model_name')
-tokenizer = AutoTokenizer.from_pretrained('model_name')
+model = AutoModelForSequenceClassification.from_pretrained('cross-encoder/nli-roberta-base')
+tokenizer = AutoTokenizer.from_pretrained('cross-encoder/nli-roberta-base')
 
 features = tokenizer(['A man is eating pizza', 'A black race car starts up in front of a crowd of people.'], ['A man eats something', 'A man is driving down a lonely road.'], padding=True, truncation=True, return_tensors="pt")
 
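The scoring step between this hunk and the next is elided; the header below only shows that it sits under `with torch.no_grad():`. A minimal sketch of how that step is commonly written with plain Transformers (not necessarily the file's exact code):

```python
# Sketch only: score the tokenized sentence pairs directly with Transformers.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained('cross-encoder/nli-roberta-base')
tokenizer = AutoTokenizer.from_pretrained('cross-encoder/nli-roberta-base')

features = tokenizer(
    ['A man is eating pizza', 'A black race car starts up in front of a crowd of people.'],
    ['A man eats something', 'A man is driving down a lonely road.'],
    padding=True, truncation=True, return_tensors="pt")

model.eval()
with torch.no_grad():
    logits = model(**features).logits        # shape: (2, num_labels)
    probs = torch.softmax(logits, dim=-1)    # per-pair label probabilities
print(probs)
```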
@@ -53,7 +53,7 @@ with torch.no_grad():
 
 ## Zero-Shot Classification
 This model can also be used for zero-shot-classification:
-```
+```python
 from transformers import pipeline
 
 classifier = pipeline("zero-shot-classification", model='cross-encoder/nli-roberta-base')
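The diff cuts off before the pipeline is actually invoked. An illustrative call, with a sentence and candidate labels that are assumptions rather than part of the diff:

```python
# Sketch only: using the zero-shot pipeline built above.
from transformers import pipeline

classifier = pipeline("zero-shot-classification", model='cross-encoder/nli-roberta-base')

result = classifier(
    "Apple just announced the newest iPhone X",        # illustrative input
    candidate_labels=["technology", "sports", "politics"],
)
print(result["labels"], result["scores"])
```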