diff --git a/README.md b/README.md index ca69c5c..3693ee7 100644 --- a/README.md +++ b/README.md @@ -17,13 +17,15 @@ import torch import warnings warnings.filterwarnings("ignore") - +''' +uncomment to get reproducable paraphrase generations def random_state(seed): torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) random_state(1234) +''' #Init models (make sure you init ONLY once if you integrate this to your code) parrot = Parrot(model_tag="prithivida/parrot_paraphraser_on_T5", use_gpu=False)