From 435eb54671cdd490345155bf08f6326a1678eed5 Mon Sep 17 00:00:00 2001
From: dcfidalgo
Date: Mon, 8 Mar 2021 17:31:21 +0100
Subject: [PATCH] add training script

---
 zeroshot_training_script.py | 247 ++++++++++++++++++++++++++++++++++++
 1 file changed, 247 insertions(+)
 create mode 100644 zeroshot_training_script.py

diff --git a/zeroshot_training_script.py b/zeroshot_training_script.py
new file mode 100644
index 0000000..ef8fd11
--- /dev/null
+++ b/zeroshot_training_script.py
@@ -0,0 +1,247 @@
#!/usr/bin/env python
# coding: utf-8

# # Creating a Zero-Shot classifier based on BETO
#
# This notebook/script fine-tunes a BETO (Spanish BERT, 'dccuchile/bert-base-spanish-wwm-cased')
# model on the Spanish XNLI dataset. The fine-tuned model can then be fed to a Hugging Face
# ZeroShot pipeline to obtain a ZeroShot classifier.

# In[ ]:


from datasets import load_dataset, Dataset, load_metric, load_from_disk
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import Trainer, TrainingArguments
import torch
from pathlib import Path
# from ray import tune
# from ray.tune.suggest.hyperopt import HyperOptSearch
# from ray.tune.schedulers import ASHAScheduler


# # Prepare the datasets

# In[ ]:


xnli_es = load_dataset("xnli", "es")


# In[ ]:


xnli_es


# > joeddav
# > Aug '20
# >
# > @rsk97 In addition, just make sure the model used is trained on an NLI task and that the
# > **last output label corresponds to entailment** while the **first output label corresponds
# > to contradiction**.
#
# => We therefore switch the original `label` ids (0 ↔ 2) and store them in a `labels` column,
# which is the column name expected by an `AutoModelForSequenceClassification`.

# In[ ]:


# see markdown above
def switch_label_id(row):
    if row["label"] == 0:
        return {"labels": 2}
    elif row["label"] == 2:
        return {"labels": 0}
    else:
        return {"labels": 1}

for split in xnli_es:
    xnli_es[split] = xnli_es[split].map(switch_label_id)


# ## Tokenize data

# In[ ]:


tokenizer = AutoTokenizer.from_pretrained("dccuchile/bert-base-spanish-wwm-cased")


# In a first attempt I padded all the data to the maximum length of the dataset (379 tokens).
# However, training takes substantially longer with all that padding; it is better to pass the
# tokenizer to the `Trainer` and let the `Trainer` do the padding on a batch level.
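#
# For illustration only -- a minimal sketch, not part of the training run: when the `Trainer`
# is given a tokenizer, it falls back to a `DataCollatorWithPadding`, which pads every batch
# to the longest sequence *in that batch* rather than to a global maximum.

# In[ ]:


# from transformers import DataCollatorWithPadding
#
# collator = DataCollatorWithPadding(tokenizer)
# # two encodings of different lengths end up padded to the same length within the batch
# batch = collator([
#     tokenizer("Una premisa.", "Una hipótesis."),
#     tokenizer("Otra premisa un poco más larga.", "Otra hipótesis."),
# ])
# batch["input_ids"].shape  # (2, length of the longest sequence in the batch)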
# In[ ]:


# Figured out the max length of the dataset manually
# max_length = 379
def tokenize(row):
    return tokenizer(
        row["premise"], row["hypothesis"], truncation=True, max_length=512
    )  # , padding="max_length", max_length=max_length


# In[ ]:


data = {}
for split in xnli_es:
    data[split] = xnli_es[split].map(
        tokenize,
        remove_columns=["hypothesis", "premise", "label"],
        batched=True,
        batch_size=128
    )


# In[ ]:


train_path = str(Path("./train_ds").absolute())
valid_path = str(Path("./valid_ds").absolute())

data["train"].save_to_disk(train_path)
data["validation"].save_to_disk(valid_path)


# In[ ]:


# We can use `datasets.Dataset`s directly

# class XnliDataset(torch.utils.data.Dataset):
#     def __init__(self, data):
#         self.data = data

#     def __getitem__(self, idx):
#         item = {key: torch.tensor(val) for key, val in self.data[idx].items()}
#         return item

#     def __len__(self):
#         return len(self.data)


# In[ ]:


def trainable(config):
    metric = load_metric("xnli", "es")

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        predictions = predictions.argmax(axis=-1)
        return metric.compute(predictions=predictions, references=labels)

    model = AutoModelForSequenceClassification.from_pretrained(
        "dccuchile/bert-base-spanish-wwm-cased", num_labels=3
    )

    training_args = TrainingArguments(
        output_dir='./results',                                 # output directory
        do_train=True,
        do_eval=True,
        evaluation_strategy="steps",
        eval_steps=500,
        load_best_model_at_end=True,
        metric_for_best_model="eval_accuracy",
        num_train_epochs=config["epochs"],                      # total number of training epochs
        per_device_train_batch_size=config["batch_size"],       # batch size per device during training
        per_device_eval_batch_size=config["batch_size_eval"],   # batch size for evaluation
        warmup_steps=config["warmup_steps"],                    # 500
        weight_decay=config["weight_decay"],                    # 0.001  # strength of weight decay
        learning_rate=config["learning_rate"],                  # 5e-05
        logging_dir='./logs',                                   # directory for storing logs
        logging_steps=250,
        # save_steps=500,  # ignored when using load_best_model_at_end
        save_total_limit=10,
        no_cuda=False,
        disable_tqdm=True,
    )

    # train_dataset = XnliDataset(load_from_disk(config["train_path"]))
    # valid_dataset = XnliDataset(load_from_disk(config["valid_path"]))
    train_dataset = load_from_disk(config["train_path"])
    valid_dataset = load_from_disk(config["valid_path"])

    trainer = Trainer(
        model,
        tokenizer=tokenizer,
        args=training_args,           # training arguments, defined above
        train_dataset=train_dataset,  # training dataset
        eval_dataset=valid_dataset,   # evaluation dataset
        compute_metrics=compute_metrics,
    )

    trainer.train()


# In[ ]:


trainable(
    {
        "train_path": train_path,
        "valid_path": valid_path,
        "batch_size": 16,
        "batch_size_eval": 64,
        "warmup_steps": 500,
        "weight_decay": 0.001,
        "learning_rate": 5e-5,
        "epochs": 3,
    }
)


# # HPO

# In[ ]:


# config = {
#     "train_path": train_path,
#     "valid_path": valid_path,
#     "warmup_steps": tune.randint(0, 500),
#     "weight_decay": tune.loguniform(0.00001, 0.1),
#     "learning_rate": tune.loguniform(5e-6, 5e-4),
#     "epochs": tune.choice([2, 3, 4])
# }


# # In[ ]:


# analysis = tune.run(
#     trainable,
#     config=config,
#     metric="eval_accuracy",  # the key logged by the Trainer for the xnli accuracy
#     mode="max",
#     #search_alg=HyperOptSearch(),
#     #scheduler=ASHAScheduler(),
#     num_samples=1,
# )


# # In[ ]:


# def model_init():
#     return AutoModelForSequenceClassification.from_pretrained(
#         "dccuchile/bert-base-spanish-wwm-cased", num_labels=3
#     )

# trainer = Trainer(
#     args=training_args,           # training arguments, defined above
#     train_dataset=train_dataset,  # training dataset
#     eval_dataset=valid_dataset,   # evaluation dataset
#     model_init=model_init,
#     compute_metrics=compute_metrics,
# )


# # In[ ]:


# best_trial = trainer.hyperparameter_search(
#     direction="maximize",
#     backend="ray",
#     n_trials=2,
#     # Choose among many libraries:
#     # https://docs.ray.io/en/latest/tune/api_docs/suggestion.html
#     search_alg=HyperOptSearch(mode="max", metric="accuracy"),
#     # Choose among schedulers:
#     # https://docs.ray.io/en/latest/tune/api_docs/schedulers.html
#     scheduler=ASHAScheduler(mode="max", metric="accuracy"),
#     local_dir="tune_runs",
# )
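

# # Using the fine-tuned model as a ZeroShot classifier

# In[ ]:


# A minimal sketch of the ZeroShot pipeline mentioned at the top, not executed here.
# The checkpoint path is an assumption -- point it at whichever checkpoint the `Trainer`
# saved under `./results`.

# from transformers import pipeline
#
# classifier = pipeline(
#     "zero-shot-classification",
#     model="./results/checkpoint-500",  # hypothetical checkpoint path
#     tokenizer=tokenizer,
# )
# classifier(
#     "El equipo ganó el partido por tres goles a cero.",
#     candidate_labels=["deportes", "política", "economía"],
#     hypothesis_template="Este ejemplo es {}.",
# )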