# Source code for flowws_keras_experimental.Save

import contextlib
import hashlib
import json

import flowws
from flowws import Argument as Arg
import gtar
import keras_gtar
import numpy as np

@flowws.add_stage_arguments
class Save(flowws.Stage):
    """Save the architecture and weights of a model using keras_gtar.

    Per-frame quantities are written as continuous GTAR records. They are
    taken from ``scope['log_quantities']`` when that key is present, and
    otherwise from ``scope['model'].history.history``. The workflow
    description and ``scope['metadata']`` are always stored as JSON strings
    (``workflow.json`` / ``metadata.json``). The model itself is stored only
    when ``save_model`` is set.
    """

    ARGS = [
        Arg('save_model', '-s', bool, False),
        Arg('hash_size', '-c', int, 0,
            help='If given, use a hash of the workflow description for the dump filename'),
        Arg('group', '-g', str,
            help='If given, set the GTAR group ID to save with'),
        Arg('file_modifiers', '-f', [str], [],
            help='List of additional filename modifiers to use'),
    ]

    @staticmethod
    def _make_record(key, val, frame):
        """Build a continuous, uniform-resolution GTAR record for one quantity.

        :param key: record name
        :param val: value(s) for the quantity; converted with ``np.asarray``
        :param frame: frame index the record is stored under (as a string)
        :returns: ``(record, array)`` pair ready for ``writeRecord``
        """
        val = np.asarray(val)
        # Map numpy dtype names onto gtar.Format member names, e.g.
        # 'float32' -> 'Float32', 'int64' -> 'Int64', 'uint8' -> 'UInt8'.
        # Dtypes without a matching gtar.Format member fall back to Float32.
        dtype = str(val.dtype).replace('f', 'F').replace('u', 'U').replace('i', 'I')
        fmt = getattr(gtar.Format, dtype, gtar.Format.Float32)
        rec = gtar.Record('', key, str(frame), gtar.Behavior.Continuous,
                          fmt, gtar.Resolution.Uniform)
        return rec, val

    def run(self, scope, storage):
        """Write the collected quantities (and optionally the model) to a dump.

        :param scope: workflow scope. Reads ``workflow`` always. Reads
            ``log_quantities`` (an iterable of ``(frame, {name: value})``
            pairs) if present, otherwise ``model``. Optionally reads
            ``last_epoch``, ``dump_filename`` (default ``'dump.sqlite'``)
            and ``metadata``.
        :param storage: object whose ``open()`` is used to open the dump
            file in append mode
        """
        varying = []

        if 'log_quantities' in scope:
            # Fallback frame, used only if log_quantities is empty. Without
            # it, `frame` would be unbound when save_model is set.
            frame = scope.get('last_epoch', 0)
            # After this loop `frame` holds the last logged frame.
            for (frame, quantities) in scope['log_quantities']:
                for (key, val) in quantities.items():
                    varying.append(self._make_record(key, val, frame))
        else:
            history = scope['model'].history.history
            # Check for last_epoch explicitly so that history['loss'] is only
            # required when last_epoch is not supplied.
            if 'last_epoch' in scope:
                frame = scope['last_epoch']
            else:
                frame = len(history['loss'])
            for (key, val) in history.items():
                varying.append(self._make_record(key, val, frame))

        # Serialize the workflow once. It is used both for the optional
        # filename hash and for the stored workflow.json string.
        workflow_json = json.dumps(scope['workflow'].to_JSON())

        modifiers = list(self.arguments['file_modifiers'])
        hash_size = self.arguments['hash_size']
        if hash_size:
            modifiers.append(
                hashlib.sha1(workflow_json.encode()).hexdigest()[:hash_size])

        with contextlib.ExitStack() as stack:
            storage_handle = stack.enter_context(storage.open(
                scope.get('dump_filename', 'dump.sqlite'), 'a', modifiers,
                on_filesystem=True))
            traj = stack.enter_context(keras_gtar.Trajectory(
                storage_handle.name, 'a', group=self.arguments.get('group', None)))

            if self.arguments['save_model']:
                traj.save(scope['model'], str(frame))

            for (rec, val) in varying:
                traj.handle.writeRecord(rec, val)

            traj.handle.writeStr('workflow.json', workflow_json)
            traj.handle.writeStr(
                'metadata.json', json.dumps(scope.get('metadata', {})))