diff --git a/pipeline.py b/pipeline.py index cd5543e..2b7fef8 100644 --- a/pipeline.py +++ b/pipeline.py @@ -9,7 +9,7 @@ class PreTrainedPipeline(): def __init__(self, path=""): - model_dir = os.path.join(path, "ckpt_epoch_3_step_6900") + model_dir = path self.model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir) self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)