Update pipeline.py
This commit is contained in:
parent
c1c837b30e
commit
93bcb8d51d
|
@ -9,7 +9,7 @@ class PreTrainedPipeline():
|
|||
|
||||
def __init__(self, path=""):
|
||||
|
||||
model_dir = os.path.join(path, "ckpt_epoch_3_step_6900")
|
||||
model_dir = path
|
||||
|
||||
self.model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
|
||||
self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
|
||||
|
|
Loading…
Reference in New Issue