# mldas.trainers.base
__copyright__ = """
Machine Learning for Distributed Acoustic Sensing data (MLDAS)
Copyright (c) 2020, The Regents of the University of California,
through Lawrence Berkeley National Laboratory (subject to receipt of
any required approvals from the U.S. Dept. of Energy). All rights reserved.
If you have questions about your rights to use or distribute this software,
please contact Berkeley Lab's Intellectual Property Office at
IPO@lbl.gov.
NOTICE. This Software was developed under funding from the U.S. Department
of Energy and the U.S. Government consequently retains certain rights. As
such, the U.S. Government has been granted for itself and others acting on
its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the
Software to reproduce, distribute copies to the public, prepare derivative
works, and perform publicly and display publicly, and to permit others to do so.
"""
__license__ = "Modified BSD license (see LICENSE.txt)"
__maintainer__ = "Vincent Dumont"
__email__ = "vincentdumont11@gmail.com"
# System
import os
import time
import logging
# Externals
import numpy as np
import torch

class BaseTrainer(object):
"""
Base class for PyTorch trainers.
This implements the common training logic,
logging of summaries, and checkpoints.
"""

    def __init__(self, output_dir=None, gpu=None,
                 distributed=False, rank=0, **kwargs):
        self.logger = logging.getLogger(self.__class__.__name__)
        # Expand environment variables (e.g. $HOME) in the output path
        self.output_dir = (os.path.expandvars(output_dir)
                           if output_dir is not None else None)
        self.gpu = gpu
        # Pin this process to the requested GPU, or fall back to CPU
        if gpu is not None:
            self.device = torch.device('cuda', gpu)
            torch.cuda.set_device(gpu)
        else:
            self.device = torch.device('cpu')
        self.distributed = distributed
        self.rank = rank
        self.summaries = {}
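
    # Usage sketch (illustrative; ``MyTrainer`` is a hypothetical concrete
    # subclass, not part of this module). The constructor only selects the
    # device and resolves the output directory:
    #
    #   trainer = MyTrainer(output_dir='/tmp/run0', gpu=0)
    #   print(trainer.device)  # cuda:0 when a GPU index is given, else cpu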

    def print_model_summary(self):
        """Override as needed"""
        self.logger.info(
            'Model: \n%s\nParameters: %i' %
            (self.model, sum(p.numel()
                             for p in self.model.parameters()))
        )

    def save_summary(self, summaries):
        """Save summary information"""
        for (key, val) in summaries.items():
            summary_vals = self.summaries.get(key, [])
            self.summaries[key] = summary_vals + [val]
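
    # How the summaries accumulate (values are illustrative only):
    # save_summary({'epoch': 0, 'train_time': 12.3}) followed by
    # save_summary({'epoch': 1, 'train_time': 11.8}) leaves
    # self.summaries == {'epoch': [0, 1], 'train_time': [12.3, 11.8]},
    # i.e. one list per key, with one entry per epoch.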

    def write_summaries(self):
        """Write the accumulated summaries to a per-rank .npz file"""
        assert self.output_dir is not None
        summary_file = os.path.join(self.output_dir,
                                    'summaries_%i.npz' % self.rank)
        self.logger.info('Saving summaries to %s' % summary_file)
        np.savez(summary_file, **self.summaries)
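
    # Reading the summaries back (sketch; the file name assumes rank 0 and
    # that write_summaries has already been called):
    #
    #   data = np.load(os.path.join(output_dir, 'summaries_0.npz'))
    #   print(data['epoch'], data['train_time'])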

    def write_checkpoint(self, checkpoint_id):
        """Write a checkpoint for the model"""
        assert self.output_dir is not None
        checkpoint_dir = os.path.join(self.output_dir, 'checkpoints')
        checkpoint_file = 'model_checkpoint_%03i.pth.tar' % checkpoint_id
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(dict(model=self.model.state_dict()),
                   os.path.join(checkpoint_dir, checkpoint_file))
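
    # Restoring a checkpoint (sketch, assuming the model has already been
    # rebuilt via build_model; only the state dict is saved above):
    #
    #   checkpoint = torch.load(
    #       os.path.join(checkpoint_dir, 'model_checkpoint_000.pth.tar'),
    #       map_location=self.device)
    #   self.model.load_state_dict(checkpoint['model'])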

    def build_model(self):
        """Virtual method to construct the model"""
        raise NotImplementedError

    def train_epoch(self, data_loader, **kwargs):
        """Virtual method to train the model for one epoch"""
        raise NotImplementedError

    def evaluate(self, data_loader, mode, **kwargs):
        """Virtual method to evaluate the model (mode is e.g. 'Validation' or 'Testing')"""
        raise NotImplementedError
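
    # Minimal subclass sketch (hypothetical; method bodies are elided):
    #
    #   class MyTrainer(BaseTrainer):
    #       def build_model(self):
    #           self.model = torch.nn.Linear(16, 2).to(self.device)
    #           self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
    #           self.lr_decay_epoch = []  # consulted by train() below
    #       def train_epoch(self, data_loader, **kwargs):
    #           ...  # one pass over data_loader; return e.g. {'train_loss': 0.5}
    #       def evaluate(self, data_loader, mode, **kwargs):
    #           ...  # return e.g. {'valid_loss': 0.4}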

    def train(self, train_data_loader, epochs,
              valid_data_loader=None, test_data_loader=None, **kwargs):
        """Run the model training"""
        # Loop over epochs
        for i in range(epochs):
            # Decay the learning rate at the epochs requested by the
            # subclass (lr_decay_epoch and exp_lr_scheduler are expected
            # to be set by the concrete trainer)
            if i + 1 in self.lr_decay_epoch:
                self.optimizer = self.exp_lr_scheduler(self.optimizer)
            self.logger.info(
                ' EPOCH {:>3}/{:<3} | Model initial sumw: {:.5f} |'.format(
                    i + 1, epochs,
                    sum(p.sum() for p in self.model.parameters()).item()))
            summary = dict(epoch=i)
            # Train on this epoch
            start_time = time.time()
            summary.update(self.train_epoch(train_data_loader, **kwargs))
            summary['train_time'] = time.time() - start_time
            summary['train_samples'] = len(train_data_loader.sampler)
            summary['train_rate'] = summary['train_samples'] / summary['train_time']
            # Evaluate on the validation set
            if valid_data_loader is not None:
                start_time = time.time()
                summary.update(self.evaluate(valid_data_loader, 'Validation', **kwargs))
                summary['valid_time'] = time.time() - start_time
                summary['valid_samples'] = len(valid_data_loader.sampler)
                summary['valid_rate'] = summary['valid_samples'] / summary['valid_time']
            # Save summary and checkpoint (checkpoints from rank 0 only)
            self.save_summary(summary)
            if self.output_dir is not None and self.rank == 0:
                self.write_checkpoint(checkpoint_id=i)
        # Evaluate on the test set once training is complete
        if test_data_loader is not None:
            self.evaluate(test_data_loader, 'Testing', **kwargs)
        return self.summaries
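
# End-to-end usage sketch (illustrative; ``MyTrainer`` and the data loaders
# are assumptions, not part of this module):
#
#   trainer = MyTrainer(output_dir='$HOME/mldas-run', gpu=0)
#   trainer.build_model()
#   trainer.print_model_summary()
#   summaries = trainer.train(train_loader, epochs=10,
#                             valid_data_loader=valid_loader)
#   trainer.write_summaries()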