import collections
import contextlib
import hashlib
import json
import random
import re
import signal
import time

import flowws
from flowws import Argument as Arg
import gtar
import keras_gtar
import numpy as np
import tensorflow as tf
from tensorflow import keras

try:
    import tensorflow_addons as tfa
except ImportError:
    tfa = None
# Lower-case optimizer names accepted on the command line, mapped to the
# class names looked up on ``keras.optimizers`` when optimizer kwargs are given.
OPTIMIZER_MAP = {
    'adadelta': 'Adadelta',
    'adam': 'Adam',
    'rmsprop': 'RMSprop',
    'sgd': 'SGD',
}
def generator_label_shuffler(seed, gen):
    """Wrap a batch generator, randomly permuting the labels within each batch.

    Batches follow the keras ``(x, y)`` / ``(x, y, sample_weight)``
    convention: only ``y`` (index 1) is shuffled, so the sample weights of
    3-element batches are left untouched. The labels are copied before
    shuffling so that arrays owned by the wrapped generator (which may be
    views into a dataset stored elsewhere) are never modified in place.

    :param seed: seed for the numpy ``Generator`` used to shuffle
    :param gen: iterable of batches (tuples or lists)
    :returns: generator yielding batches with shuffled labels; tuple batches
        are yielded as tuples, other batches as lists. Batches with fewer
        than two elements (no labels) are passed through unchanged.
    """
    rng = np.random.default_rng(seed)
    for batch in gen:
        if len(batch) < 2:
            # nothing that looks like labels to shuffle
            yield batch
            continue
        items = list(batch)
        # copy so the generator's (possibly shared) label array is not mutated
        labels = items[1].copy()
        rng.shuffle(labels)
        items[1] = labels
        yield tuple(items) if isinstance(batch, tuple) else items
class SigtermException(Exception):
    """Exception raised from a SIGTERM handler so termination can be caught."""

    @classmethod
    def handle(cls, signum, frame):
        """Signal-handler entry point: turn the received signal into an exception."""
        error = cls()
        raise error

    @classmethod
    def register(cls):
        """Install :meth:`handle` as the process-wide SIGTERM handler."""
        import signal
        handler = cls.handle
        signal.signal(signal.SIGTERM, handler)
class SuppressExceptionScope(contextlib.AbstractContextManager):
    """Context manager that swallows one exception type and records the reason.

    If an exception whose type is (a subclass of) `exctype` leaves the
    ``with`` body, a message is printed, `label` is appended to the list at
    ``scope['exit_reason']`` (created if missing) and the exception is
    suppressed. Any other exception propagates unchanged.

    :param scope: mutable mapping receiving the ``exit_reason`` entries
    :param exctype: exception class to suppress
    :param label: reason string recorded when suppression happens
    """

    def __init__(self, scope, exctype, label):
        self._scope, self._exctype, self._label = scope, exctype, label

    def __enter__(self):
        return None

    def __exit__(self, exctype, excinst, exctb):
        if exctype is None or not issubclass(exctype, self._exctype):
            # no exception, or one we do not handle: let it propagate
            return False
        print(f'Caught {exctype!s}, exiting for {self._label}')
        self._scope.setdefault('exit_reason', []).append(self._label)
        return True
class GTARLog(keras.callbacks.Callback):
    """Keras callback buffering per-epoch log values and writing them to a GTAR file.

    Values from the ``logs`` dict passed to ``on_epoch_end`` are accumulated
    per key. As soon as any key holds `buffer_size` values, every buffer is
    written as a continuous, uniform-resolution record under `group`, using
    the current epoch number as the frame. Leftover values are written and
    the file is closed in ``on_train_end``.

    :param filename: file opened with ``gtar.GTAR(filename, 'a')``
    :param buffer_size: number of buffered values that triggers a write
    :param group: GTAR group name for the records
    """

    def __init__(self, filename, buffer_size=8, group='gtar_log'):
        # initialize the keras Callback base (model/params attributes)
        super().__init__()
        self.buffers = collections.defaultdict(list)
        self.handle = gtar.GTAR(filename, 'a')
        self.buffer_size = buffer_size
        self.group = group
        self._last_frame = '0'

    def flush(self, key, frame):
        """Write and clear the buffered values for `key`, labelled with `frame`.

        Does nothing if `key` has no buffered values.
        """
        val = self.buffers.pop(key, [])
        if not len(val):
            return
        val = np.asarray(val)
        # map numpy dtype names onto gtar.Format names (float32 -> Float32,
        # int64 -> Int64, uint8 -> UInt8); anything else falls back to Float32
        dtype = str(val.dtype).replace('f', 'F').replace('u', 'U').replace('i', 'I')
        fmt = getattr(gtar.Format, dtype, gtar.Format.Float32)
        rec = gtar.Record(self.group, key, frame, gtar.Behavior.Continuous, fmt,
                          gtar.Resolution.Uniform)
        self.handle.writeRecord(rec, val)

    def on_epoch_end(self, epoch, logs=None):
        """Buffer this epoch's logs; flush all keys once any buffer is full."""
        logs = logs or {}
        for (k, v) in logs.items():
            self.buffers[k].append(v)
        frame = self._last_frame = str(epoch)
        if any(len(v) >= self.buffer_size for v in self.buffers.values()):
            for k in list(self.buffers):
                self.flush(k, frame)

    def on_train_end(self, logs=None):
        """Flush remaining buffers at the last seen epoch and close the file."""
        try:
            for k in list(self.buffers):
                self.flush(k, self._last_frame)
        finally:
            # close the handle even if a write fails
            self.handle.close()
class TimedBackupAndRestore(keras.callbacks.BackupAndRestore):
    """BackupAndRestore variant that writes a backup at most once per time interval.

    A backup (the parent's ``on_epoch_end``) is only performed when more than
    `time_limit` has elapsed since the previous one; because the last backup
    time starts at 0, the first epoch end always triggers one. On the first
    ``on_epoch_begin`` call, the given training/validation generators are
    advanced by ``epoch * steps`` batches, i.e. past the batches the epochs
    preceding the (possibly restored) starting epoch would have consumed.

    :param time_limit: minimum time between backups as a duration string,
        e.g. ``'10m'`` or ``'1h30m'``
    :param train_generator: training batch generator to fast-forward
    :param train_generator_steps: batches per training epoch; required when
        `train_generator` is given
    :param validation_generator: validation batch generator to fast-forward
    :param validation_generator_steps: batches per validation epoch; required
        when `validation_generator` is given
    :raises ValueError: if `time_limit` cannot be parsed, or a generator is
        given without its step count

    Remaining arguments are forwarded to ``keras.callbacks.BackupAndRestore``.
    """

    def __init__(self, time_limit, *args,
                 train_generator=None, train_generator_steps=None,
                 validation_generator=None, validation_generator_steps=None,
                 **kwargs):
        self.time_limit = time_limit
        self._last_runtime = 0
        self._duration = self.parse_time(self.time_limit)
        # explicit checks rather than asserts, which vanish under python -O
        if train_generator is not None and train_generator_steps is None:
            raise ValueError(
                'train_generator_steps is required when train_generator is given')
        if validation_generator is not None and validation_generator_steps is None:
            raise ValueError(
                'validation_generator_steps is required when validation_generator is given')
        self._train_generator = train_generator
        self._train_generator_steps = train_generator_steps
        self._validation_generator = validation_generator
        self._validation_generator_steps = validation_generator_steps
        # first epoch index seen by on_epoch_begin (the resume point after a restore)
        self._loaded_epoch = None
        super().__init__(*args, **kwargs)

    @staticmethod
    def parse_time(t):
        """Parse a duration string (e.g. ``18h3m``) into a number of seconds.

        :raises ValueError: if `t` is not made up entirely of
            ``<count><h|m|s>`` pieces (an empty string parses to 0)
        """
        seconds = dict(h=60*60, m=60, s=1)
        time_regex = re.compile(r'(?P<count>\d+)(?P<unit>[hms])')
        components = time_regex.findall(t)
        matched_pieces = ''.join(count + unit for (count, unit) in components)
        # reject leftovers such as '10x' or '5 m' instead of silently ignoring them
        if matched_pieces != t:
            raise ValueError('Failed to fully parse "{}"'.format(t))
        return sum(int(count)*seconds[unit] for (count, unit) in components)

    @staticmethod
    def _skip_batches(generator, count):
        """Advance `generator` by `count` items, discarding them."""
        for _ in range(count):
            next(generator)

    def on_epoch_begin(self, epoch, logs=None):
        # keep the parent's per-epoch bookkeeping (some Keras versions record
        # the current epoch here for their backups)
        super().on_epoch_begin(epoch, logs)
        if self._loaded_epoch is not None:
            return
        self._loaded_epoch = epoch
        if self._train_generator is not None:
            self._skip_batches(
                self._train_generator, epoch*self._train_generator_steps)
            self._train_generator = None
        if self._validation_generator is not None:
            self._skip_batches(
                self._validation_generator, epoch*self._validation_generator_steps)
            self._validation_generator = None

    def on_epoch_end(self, *args, **kwargs):
        current_time = time.time()
        if current_time > self._last_runtime + self._duration:
            super().on_epoch_end(*args, **kwargs)
            self._last_runtime = current_time

    def get_config(self):
        # NOTE(review): relies on the parent class providing get_config;
        # confirm for the installed Keras version
        result = super().get_config()
        result['time_limit'] = self.time_limit
        return result
@flowws.add_stage_arguments
class Train(flowws.Stage):
    """Build a model and perform some number of training steps.

    Training will proceed for the model and dataset that have been
    specified in previous stages.
    """
    ARGS = [
        Arg('optimizer', '-o', str, 'adam',
            help='optimizer to use'),
        # NOTE: values are parsed with eval(); only use trusted workflow
        # descriptions / command lines
        Arg('optimizer_kwargs', None, [(str, eval)], [],
            help='Keyword arguments to pass to optimizer'),
        Arg('epochs', '-e', int, 2000,
            help='Max number of epochs'),
        Arg('batch_size', '-b', int, 256,
            help='Batch size'),
        Arg('validation_split', '-v', float, .3),
        Arg('early_stopping', type=int),
        Arg('early_stopping_best', None, type=bool,
            help='If True, restore the best weights at the end of early stopping'),
        Arg('reduce_lr', type=int),
        Arg('reduce_lr_factor', None, float, .5,
            help='Factor to scale learning rate with reduce_lr enabled'),
        Arg('dump_period', '-d', int),
        Arg('hash_size', '-c', int, 0,
            help='If given, use a hash of the workflow description for the dump filename'),
        Arg('seed', '-s', int),
        Arg('summarize', None, bool, False,
            help='If True, print the model summary before training'),
        Arg('recompile', None, bool, False,
            help='If True, always compile the model in this stage'),
        Arg('verbose', None, bool, True,
            help='If True, print the training progress'),
        Arg('clean_batch_multiple', None, bool, False,
            help='If True, make the training data a clean multiple of the batch size'),
        Arg('rebuild_model', '-r', bool, False,
            help='If True, always rebuild the model when one already exists'),
        Arg('generator_train_steps', None, int, None,
            help='Number of steps to use as an epoch for training from a generator'),
        Arg('generator_val_steps', None, int, None,
            help='Number of steps to use as an epoch for evaluation from a generator'),
        Arg('disable_tqdm', None, bool, False,
            help='If True, don\'t use tqdm to display a progress bar'),
        Arg('use_multiprocessing', None, bool, True,
            help='If True, use multiprocessing with generators'),
        Arg('accumulate_gradients', None, int,
            help='Number of batches over which to accumulate gradients before applying'),
        Arg('catch_keyboard_interrupt', None, bool, False,
            help='If True, catch keyboard interrupts and continue to the next stage'),
        Arg('monitor_quantity', None, str, 'val_loss',
            help='Quantity to monitor for reduce_lr and early_stopping'),
        Arg('shuffle_labels', None, bool, False,
            help='If True, shuffle labels for training'),
        Arg('checkpoint_dir', None, str,
            help='If given, save and restore model checkpoints at the given location'),
        Arg('checkpoint_duration', None, str, '10m',
            help='Time duration for model checkpointing, if enabled'),
        Arg('catch_sigterm', None, bool, False,
            help='If True, catch sigterm events and continue to the next stage'),
        Arg('terminate_on_nan', None, bool, False,
            help='If True, terminate training on nan loss'),
        Arg('gtar_log_period', None, int,
            help='Number of epochs to buffer for logging quantities via GTAR'),
        Arg('gtar_log_modifiers', None, [str],
            help='Filename modifiers for live logging of quantities via GTAR'),
    ]

    @staticmethod
    def _seed_random_generators(seed):
        """Deterministically seed the python, numpy and tensorflow RNGs from `seed`."""
        random.seed(seed)
        # reseed from a derived value, then draw a separate seed for each library
        random.seed(random.randrange(2**32))
        np.random.seed(random.randrange(2**32))
        tf.random.set_seed(random.randrange(2**32))

    def _make_optimizer(self):
        """Return an optimizer instance, or just its name when no kwargs are given.

        Names in OPTIMIZER_MAP are mapped to their ``keras.optimizers`` class;
        any other name is resolved with ``keras.optimizers.get`` so that
        optimizer_kwargs work for every optimizer keras can deserialize.
        """
        name = self.arguments['optimizer']
        optimizer_kwargs = dict(self.arguments['optimizer_kwargs'])
        if not optimizer_kwargs:
            return name
        if name in OPTIMIZER_MAP:
            optimizer_cls = getattr(keras.optimizers, OPTIMIZER_MAP[name])
            return optimizer_cls(**optimizer_kwargs)
        return keras.optimizers.get(dict(class_name=name, config=optimizer_kwargs))

    def _generator_steps(self, scope, key):
        """Steps per epoch for `key`: the stage argument if set, else the scope value."""
        return self.arguments.get(key, None) or scope.get(key, None)

    def run(self, scope, storage):
        """Compile (if needed) and fit the model described by `scope`.

        The model is taken from ``scope['model']`` or built from
        ``scope['input_symbol']``/``scope['output']``; training data comes
        from ``scope['train_generator']`` or ``scope['x_train']``/
        ``scope['y_train']``. When `epochs` is nonzero, ``scope['last_epoch']``
        is advanced by the number of completed epochs and the training
        history dict is appended to ``scope['log_quantities']``.
        """
        if 'seed' in self.arguments:
            self._seed_random_generators(self.arguments['seed'])

        if self.arguments['clean_batch_multiple']:
            bs = self.arguments['batch_size']
            # generator-based workflows have no x_train/y_train to trim
            for key in ('x_train', 'y_train'):
                if key in scope:
                    data = scope[key]
                    scope[key] = data[:len(data)//bs*bs]

        metrics = scope.get('metrics', [])

        should_compile = self.arguments['recompile']
        should_compile |= 'accumulate_gradients' in self.arguments

        if 'model' not in scope or self.arguments['rebuild_model']:
            ModelCls = scope.get('custom_model_class', keras.models.Model)
            model = ModelCls(scope['input_symbol'], scope['output'])
            scope['model'] = model
            for term in scope.get('extra_losses', []):
                model.add_loss(term)
            should_compile = True
        else:
            model = scope['model']

        if self.arguments['summarize']:
            model.summary()

        if should_compile:
            # only build the optimizer when it will actually be used
            optimizer = self._make_optimizer()
            if isinstance(optimizer, str):
                optimizer = keras.optimizers.get(optimizer)
            if 'accumulate_gradients' in self.arguments:
                from .accumulate_gradients import convert
                convert(optimizer, self.arguments['accumulate_gradients'])
            compile_kwargs = scope.get('compile_kwargs', {})
            model.compile(optimizer, loss=scope['loss'], metrics=metrics, **compile_kwargs)

        callbacks = list(scope.get('callbacks', []))

        if 'early_stopping' in self.arguments:
            callbacks.append(keras.callbacks.EarlyStopping(
                patience=self.arguments['early_stopping'],
                monitor=self.arguments['monitor_quantity'],
                restore_best_weights=self.arguments.get('early_stopping_best', False)))

        if 'reduce_lr' in self.arguments:
            callbacks.append(keras.callbacks.ReduceLROnPlateau(
                patience=self.arguments['reduce_lr'],
                monitor=self.arguments['monitor_quantity'],
                factor=self.arguments['reduce_lr_factor'],
                verbose=True, min_delta=0))

        restore_callback = None
        if 'checkpoint_dir' in self.arguments:
            restore_kwargs = {}
            if 'train_generator' in scope:
                restore_kwargs['train_generator'] = scope['train_generator']
                restore_kwargs['train_generator_steps'] = self._generator_steps(
                    scope, 'generator_train_steps')
            if 'validation_generator' in scope:
                restore_kwargs['validation_generator'] = scope['validation_generator']
                restore_kwargs['validation_generator_steps'] = self._generator_steps(
                    scope, 'generator_val_steps')
            restore_callback = TimedBackupAndRestore(
                self.arguments['checkpoint_duration'],
                self.arguments['checkpoint_dir'], **restore_kwargs)
            callbacks.append(restore_callback)

        if self.arguments['terminate_on_nan']:
            callbacks.append(keras.callbacks.TerminateOnNaN())

        verbose = self.arguments['verbose']
        if tfa is not None and verbose and not self.arguments['disable_tqdm']:
            callbacks.append(tfa.callbacks.TQDMProgressBar(
                show_epoch_progress=False, update_per_second=1))
            verbose = False

        # remembered so a stale history (no epoch finished in this call) is
        # not counted a second time below
        previous_history = getattr(model, 'history', None)

        with contextlib.ExitStack() as context_stack:
            if self.arguments.get('dump_period', None):
                modifiers = []
                if self.arguments['hash_size']:
                    N = self.arguments['hash_size']
                    mod = hashlib.sha1(json.dumps(
                        scope['workflow'].to_JSON()).encode()).hexdigest()[:N]
                    modifiers.append(mod)
                handle = context_stack.enter_context(storage.open(
                    scope.get('dump_filename', 'dump.tar'), 'a', modifiers, on_filesystem=True))
                cbk = keras_gtar.GTARLogger(
                    handle.name, self.arguments['dump_period'], append=True, when='pre_epoch')
                callbacks.append(cbk)

            initial_epoch = scope.setdefault('last_epoch', 0)
            total_epochs = initial_epoch + self.arguments['epochs']

            args = []
            kwargs = dict(
                verbose=verbose,
                epochs=total_epochs,
                callbacks=callbacks,
                initial_epoch=initial_epoch
            )
            if 'train_generator' in scope:
                train_gen = scope['train_generator']
                if self.arguments['shuffle_labels']:
                    train_gen = generator_label_shuffler(
                        self.arguments.get('seed', 13), train_gen)
                args.append(train_gen)
                kwargs['steps_per_epoch'] = self._generator_steps(
                    scope, 'generator_train_steps')
                kwargs['use_multiprocessing'] = self.arguments['use_multiprocessing']
                if 'validation_generator' in scope:
                    val_gen = scope['validation_generator']
                    if self.arguments['shuffle_labels']:
                        val_gen = generator_label_shuffler(
                            self.arguments.get('seed', 13), val_gen)
                    kwargs['validation_data'] = val_gen
                    kwargs['validation_steps'] = self._generator_steps(
                        scope, 'generator_val_steps')
            else:
                labels = scope['y_train']
                if self.arguments['shuffle_labels']:
                    labels = labels.copy()
                    np.random.shuffle(labels)
                args.extend([scope['x_train'], labels])
                kwargs['batch_size'] = self.arguments['batch_size']
                kwargs['validation_split'] = self.arguments['validation_split']
                if 'validation_data' in scope:
                    kwargs['validation_data'] = scope['validation_data']

            with contextlib.ExitStack() as st:
                if self.arguments.get('gtar_log_period', None):
                    # gtar_log_modifiers has no default, so it may be absent
                    mods = self.arguments.get('gtar_log_modifiers', [])
                    storage_handle = st.enter_context(storage.open(
                        scope.get('dump_filename', 'dump.sqlite'), 'a',
                        mods, on_filesystem=True))
                    filename = storage_handle.name
                    callback = GTARLog(filename, self.arguments['gtar_log_period'])
                    callbacks.append(callback)

                if self.arguments['catch_keyboard_interrupt']:
                    st.enter_context(SuppressExceptionScope(
                        scope, KeyboardInterrupt, 'catch_keyboard_interrupt'))

                if self.arguments['catch_sigterm']:
                    # reinstall the previous SIGTERM handler once this stage is
                    # done instead of leaving ours active for later stages;
                    # ExitStack unwinds LIFO, so this runs after suppression
                    previous_handler = signal.getsignal(signal.SIGTERM)
                    SigtermException.register()
                    if previous_handler is not None:
                        st.callback(signal.signal, signal.SIGTERM, previous_handler)
                    st.enter_context(SuppressExceptionScope(
                        scope, SigtermException, 'catch_sigterm'))

                model.fit(*args, **kwargs)

        if self.arguments['epochs']:
            history = getattr(model, 'history', None)
            if history is None or history is previous_history:
                # fit was interrupted before any history was recorded
                history_dict = {}
            else:
                history_dict = history.history
            last_epoch = scope['last_epoch']
            if restore_callback is not None and restore_callback._loaded_epoch is not None:
                # training resumed from a backup at this epoch
                last_epoch = restore_callback._loaded_epoch
            current_epoch = last_epoch + len(history_dict.get('loss', []))
            scope['last_epoch'] = current_epoch
            log_quantities = scope.setdefault('log_quantities', [])
            log_quantities.append((current_epoch, history_dict))