import os
import random
import time
import datetime

import tensorflow as tf

import deepdanbooru as dd


def train_project(project_path, source_model):
    project_context_path = os.path.join(project_path, 'project.json')
    project_context = dd.io.deserialize_from_json(project_context_path)

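    # Read training hyperparameters and augmentation settings from the project
    # configuration; learning rate, schedule, and export frequency fall back to
    # defaults when absent.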
    width = project_context['image_width']
    height = project_context['image_height']
    database_path = project_context['database_path']
    minimum_tag_count = project_context['minimum_tag_count']
    model_type = project_context['model']
    optimizer_type = project_context['optimizer']
    learning_rate = project_context.get('learning_rate', 0.001)
    learning_rates = project_context.get('learning_rates')
    minibatch_size = project_context['minibatch_size']
    epoch_count = project_context['epoch_count']
    export_model_per_epoch = project_context.get('export_model_per_epoch', 10)
    checkpoint_frequency_mb = project_context['checkpoint_frequency_mb']
    console_logging_frequency_mb = project_context['console_logging_frequency_mb']
    rotation_range = project_context['rotation_range']
    scale_range = project_context['scale_range']
    shift_range = project_context['shift_range']

    # Disable PNG warnings (TF_CPP_MIN_LOG_LEVEL=2 silences INFO and WARNING logs).
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
    # tf.logging.set_verbosity(tf.logging.ERROR)

    # tf.keras.backend.set_epsilon(1e-6)
    # tf.keras.mixed_precision.experimental.set_policy('infer_float32_vars')
    # tf.config.gpu.set_per_process_memory_growth(True)

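    # Build the optimizer named in the project settings.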
    if optimizer_type == 'adam':
        optimizer = tf.optimizers.Adam(learning_rate)
        print('Using Adam optimizer ... ')
    elif optimizer_type == 'sgd':
        optimizer = tf.optimizers.SGD(
            learning_rate, momentum=0.9, nesterov=True)
        print('Using SGD optimizer ... ')
    elif optimizer_type == 'rmsprop':
        optimizer = tf.optimizers.RMSprop(learning_rate)
        print('Using RMSprop optimizer ... ')
    else:
        raise Exception(f'Unsupported optimizer: {optimizer_type}')

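    # Pick the network-builder function for the configured architecture.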
    if model_type == 'resnet_152':
        model_delegate = dd.model.resnet.create_resnet_152
    elif model_type == 'resnet_custom_v1':
        model_delegate = dd.model.resnet.create_resnet_custom_v1
    elif model_type == 'resnet_custom_v2':
        model_delegate = dd.model.resnet.create_resnet_custom_v2
    elif model_type == 'resnet_custom_v3':
        model_delegate = dd.model.resnet.create_resnet_custom_v3
    elif model_type == 'resnet_custom_v4':
        model_delegate = dd.model.resnet.create_resnet_custom_v4
    else:
        raise Exception(f'Unsupported model: {model_type}')

    print('Loading tags ... ')
    tags = dd.project.load_tags_from_project(project_path)
    output_dim = len(tags)

    print(f'Creating model ({model_type}) ... ')
    # tf.keras.backend.set_learning_phase(1)

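    # Resume from an existing Keras model if one was given; otherwise build a
    # fresh network with one output per tag (output_dim = len(tags)).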
    if source_model:
        model = tf.keras.models.load_model(source_model)
        print(f'Model : {model.input_shape} -> {model.output_shape} (loaded from {source_model})')
    else:
        inputs = tf.keras.Input(shape=(height, width, 3),
                                dtype=tf.float32)  # HWC
        outputs = model_delegate(inputs, output_dim)
        model = tf.keras.Model(inputs=inputs, outputs=outputs, name=model_type)
        print(f'Model : {model.input_shape} -> {model.output_shape}')

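    # Tagging is multi-label classification: each tag is an independent binary
    # decision, hence binary cross-entropy with precision/recall metrics.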
    model.compile(optimizer=optimizer, loss=tf.keras.losses.BinaryCrossentropy(),
                  metrics=[tf.keras.metrics.Precision(), tf.keras.metrics.Recall()])

    print('Loading database ... ')
    image_records = dd.data.load_image_records(
        database_path, minimum_tag_count)

    # Checkpoint variables: progress counters live in tf.Variables so they are
    # saved and restored together with the model and optimizer state.
    used_epoch = tf.Variable(0, dtype=tf.int64)
    used_minibatch = tf.Variable(0, dtype=tf.int64)
    used_sample = tf.Variable(0, dtype=tf.int64)
    offset = tf.Variable(0, dtype=tf.int64)
    random_seed = tf.Variable(0, dtype=tf.int64)

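    # Bundle model, optimizer, and progress counters into one checkpoint;
    # only the three most recent saves are kept.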
    checkpoint = tf.train.Checkpoint(
        optimizer=optimizer,
        model=model,
        used_epoch=used_epoch,
        used_minibatch=used_minibatch,
        used_sample=used_sample,
        offset=offset,
        random_seed=random_seed)

    manager = tf.train.CheckpointManager(
        checkpoint=checkpoint,
        directory=os.path.join(project_path, 'checkpoints'),
        max_to_keep=3)

    if manager.latest_checkpoint:
        print(f"Checkpoint exists. Continuing training ... ({datetime.datetime.now()})")
        checkpoint.restore(manager.latest_checkpoint)
        print(f'used_epoch={int(used_epoch)}, used_minibatch={int(used_minibatch)}, used_sample={int(used_sample)}, offset={int(offset)}, random_seed={int(random_seed)}')
    else:
        print(f'No checkpoint. Starting new training ... ({datetime.datetime.now()})')

    epoch_size = len(image_records)
    slice_size = minibatch_size * checkpoint_frequency_mb
    loss_sum = 0.0
    loss_count = 0
    used_sample_sum = 0
    last_time = time.time()

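    # Main training loop. Each epoch reshuffles the records with a
    # deterministic seed so a restored run replays the same order; 'offset'
    # tracks the position inside the epoch so training can resume mid-epoch.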
    while int(used_epoch) < epoch_count:
        print(f'Shuffling samples (epoch {int(used_epoch)}) ... ')
        epoch_random = random.Random(int(random_seed))
        epoch_random.shuffle(image_records)

        # Update learning rate from the optional per-epoch schedule.
        if learning_rates:
            for learning_rate_per_epoch in learning_rates:
                if learning_rate_per_epoch['used_epoch'] <= int(used_epoch):
                    learning_rate = learning_rate_per_epoch['learning_rate']
                    print(f'Trying to change learning rate to {learning_rate} ...')
                    optimizer.learning_rate.assign(learning_rate)
                    print(f'Learning rate is changed to {optimizer.learning_rate} ...')

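        # Walk the epoch in slices of checkpoint_frequency_mb minibatches;
        # a checkpoint is written after every slice.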
        while int(offset) < epoch_size:
            image_records_slice = image_records[int(offset):min(
                int(offset) + slice_size, epoch_size)]

            image_paths = [image_record[0]
                           for image_record in image_records_slice]
            tag_strings = [image_record[1]
                           for image_record in image_records_slice]

            dataset_wrapper = dd.data.DatasetWrapper(
                (image_paths, tag_strings), tags, width, height,
                scale_range=scale_range, rotation_range=rotation_range,
                shift_range=shift_range)
            dataset = dataset_wrapper.get_dataset(minibatch_size)

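            # One optimizer step per minibatch. reset_metrics=False lets
            # precision/recall accumulate until the next console log.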
            for (x_train, y_train) in dataset:
                sample_count = x_train.shape[0]

                step_result = model.train_on_batch(
                    x_train, y_train, reset_metrics=False)

                used_minibatch.assign_add(1)
                used_sample.assign_add(sample_count)
                used_sample_sum += sample_count
                loss_sum += step_result[0]
                loss_count += 1

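                # Periodically report average loss, precision/recall and
                # their harmonic mean (F1), throughput, progress, and ETA.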
                if int(used_minibatch) % console_logging_frequency_mb == 0:
                    # calculate logging information
                    current_time = time.time()
                    delta_time = current_time - last_time
                    step_metric_precision = step_result[1]
                    step_metric_recall = step_result[2]
                    if step_metric_precision + step_metric_recall > 0.0:
                        step_metric_f1_score = 2.0 * \
                            (step_metric_precision * step_metric_recall) / \
                            (step_metric_precision + step_metric_recall)
                    else:
                        step_metric_f1_score = 0.0
                    average_loss = loss_sum / float(loss_count)
                    samples_per_seconds = float(
                        used_sample_sum) / max(delta_time, 0.001)
                    progress = float(int(used_sample)) / \
                        float(epoch_size * epoch_count) * 100.0
                    remain_seconds = float(
                        epoch_size * epoch_count - int(used_sample)) / max(samples_per_seconds, 0.001)
                    eta_datetime = datetime.datetime.now() + datetime.timedelta(seconds=remain_seconds)
                    eta_datetime_string = eta_datetime.strftime(
                        '%Y-%m-%d %H:%M:%S')
                    print(
                        f'Epoch[{int(used_epoch)}] Loss={average_loss:.6f}, P={step_metric_precision:.6f}, R={step_metric_recall:.6f}, F1={step_metric_f1_score:.6f}, Speed = {samples_per_seconds:.1f} samples/s, {progress:.2f} %, ETA = {eta_datetime_string}')

                    # reset accumulators for the next logging window
                    model.reset_metrics()
                    loss_sum = 0.0
                    loss_count = 0
                    used_sample_sum = 0
                    last_time = current_time

            offset.assign_add(slice_size)
            print(f'Saving checkpoint ... ({datetime.datetime.now()})')
            manager.save()

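        # Epoch finished: advance the epoch counter, reseed the shuffle, and
        # rewind the offset for the next epoch.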
        used_epoch.assign_add(1)
        random_seed.assign_add(1)
        offset.assign(0)

        if int(used_epoch) % export_model_per_epoch == 0:
            print(f'Saving model ... (per epoch {export_model_per_epoch})')
            export_path = os.path.join(
                project_path, f'model-{model_type}.h5.e{int(used_epoch)}')
            model.save(export_path, include_optimizer=False, save_format='h5')

    print('Saving model ...')
    model_path = os.path.join(
        project_path, f'model-{model_type}.h5')

    # tf.keras.experimental.export_saved_model throws an exception for now;
    # see https://github.com/tensorflow/tensorflow/issues/27112
    model.save(model_path, include_optimizer=False)

    print('Training is complete.')
    print(
        f'used_epoch={int(used_epoch)}, used_minibatch={int(used_minibatch)}, used_sample={int(used_sample)}')
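

# Example invocation (a minimal sketch: the project path is a placeholder, and
# in the real package this function is typically driven by a CLI entry point
# rather than called directly):
#
#     train_project('/path/to/project', source_model=None)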