"""
This module defines a generic trainer for simple models and datasets.
"""
# System
import os
import time
import logging
# Externals
import numpy
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel
# Locals
from ..models import get_model
from ..explore import figures
class GenericTrainer:
"""Trainer code for basic classification problems."""
    def __init__(self, rank, device, output_dir=None, **kwargs):
self.device = device
self.output_dir = (os.path.expandvars(output_dir) if output_dir is not None else None)
self.rank = rank
self.summaries = {}
    def build_model(self, device_ids, distributed=False, loss='CE', optimizer='SGD',
                    lr=0.01, lr_decay_epoch=(), lr_decay_ratio=0.5, momentum=0.9, **model_args):
"""Instantiate our model"""
self.loss = loss
self.lr_decay_ratio = lr_decay_ratio
self.lr_decay_epoch = lr_decay_epoch
# Construct the model
self.model = get_model(**model_args).to(self.device)
# Distributed data parallelism
if distributed:
self.model = DistributedDataParallel(self.model, device_ids=device_ids)
# TODO: add support for more optimizers and loss functions here
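        # (e.g. Adam could be registered as Adam=torch.optim.Adam, but it does not
        # accept the momentum argument passed below, so per-optimizer kwargs would be needed)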
opt_type = dict(SGD=torch.optim.SGD)[optimizer]
self.optimizer = opt_type(self.model.parameters(), lr=lr, momentum=momentum)
loss_type = dict(CE=torch.nn.CrossEntropyLoss,
BCE=torch.nn.BCEWithLogitsLoss,
MSE=torch.nn.MSELoss)[loss]
self.loss_func = loss_type()
def exp_lr_scheduler(self, optimizer):
"""
Decay learning rate by a factor of lr_decay
"""
for param_group in optimizer.param_groups:
param_group['lr'] *= self.lr_decay_ratio
return optimizer
def train(self, train_data_loader, epochs, valid_data_loader=None, test_data_loader=None, **kwargs):
"""Run the model training"""
# Loop over epochs
for i in range(epochs):
if i+1 in self.lr_decay_epoch:
self.optimizer = self.exp_lr_scheduler(self.optimizer)
            logging.info(' EPOCH {:>3}/{:<3} | Model initial sumw: {:.5f}'.format(
                i + 1, epochs, sum(p.sum().item() for p in self.model.parameters())))
summary = dict(epoch=i)
# Train on this epoch
start_time = time.time()
summary.update(self.train_epoch(train_data_loader,**kwargs))
summary['train_time'] = time.time() - start_time
summary['train_samples'] = len(train_data_loader.sampler)
summary['train_rate'] = summary['train_samples'] / summary['train_time']
# Evaluate on this epoch
if valid_data_loader is not None:
start_time = time.time()
summary.update(self.evaluate(valid_data_loader,'Validation',**kwargs))
summary['valid_time'] = time.time() - start_time
summary['valid_samples'] = len(valid_data_loader.sampler)
summary['valid_rate'] = summary['valid_samples'] / summary['valid_time']
# Save summary, checkpoint
self.save_summary(summary)
            if self.output_dir is not None:
                self.write_checkpoint(checkpoint_id=i)
        # Evaluate on the test set after training
if test_data_loader is not None:
summary.update(self.evaluate(test_data_loader,'Testing',**kwargs))
            # Save the final summary
self.save_summary(summary)
return self.summaries
    def train_epoch(self, data_loader, rounded=False, **kwargs):
"""Train for one epoch"""
self.model.train()
sum_loss = 0
sum_correct = 0
# Loop over training batches
for i, (batch_input, batch_target) in enumerate(data_loader):
batch_input = batch_input.to(self.device)
if self.loss=='BCE' and batch_target.dim()==1:
batch_target = batch_target.float().unsqueeze(1)
batch_target = batch_target.to(self.device)
self.model.zero_grad()
batch_output = self.model(batch_input)
            # Note: round() has zero gradient almost everywhere, so the loss below
            # cannot backpropagate through rounded outputs
            if rounded:
                batch_output = batch_output.round()
batch_loss = self.loss_func(batch_output, batch_target)
batch_loss.backward()
self.optimizer.step()
loss = batch_loss.item()
sum_loss += loss
n_correct = self.accuracy(batch_output, batch_target, **kwargs)
sum_correct += n_correct
logging.debug(' batch {:>3}/{:<3} | {:6,} samples | Loss {:.5f} | Accuracy {:6.2f}'
.format(i+1, len(data_loader), len(batch_input), loss, 100*n_correct/len(batch_input)))
        train_loss = sum_loss / len(data_loader)
        train_acc = sum_correct / len(data_loader.sampler)
        logging.info('{:>14} | {:6,} samples | Loss {:.5f} | Accuracy {:6.2f}'
                     .format('Training', len(data_loader.sampler), train_loss, 100*train_acc))
        return dict(train_loss=train_loss, train_acc=train_acc)
    @torch.no_grad()
    def evaluate(self, data_loader, mode, rounded=False, **kwargs):
        """Evaluate the model"""
self.model.eval()
sum_loss = 0
sum_correct = 0
# Loop over batches
n = 0
for i, (batch_input, batch_target) in enumerate(data_loader):
batch_input = batch_input.to(self.device)
if self.loss=='BCE' and batch_target.dim()==1:
batch_target = batch_target.float().unsqueeze(1)
batch_target = batch_target.to(self.device)
            batch_output = self.model(batch_input)
            if rounded:
                batch_output = batch_output.round()
loss = self.loss_func(batch_output, batch_target).item()
sum_loss += loss
n_correct = self.accuracy(batch_output, batch_target, **kwargs)
sum_correct += n_correct
            # Optionally dump test predictions to figures:
            # if mode=='Testing':
            #     os.makedirs('results', exist_ok=True)
            #     for data_target, data_output in zip(batch_target, batch_output):
            #         figures.plot_test_2d(data_target, data_output, 'results/%05i' % n)
            #         n += 1
        valid_loss = sum_loss / len(data_loader)
valid_acc = sum_correct / len(data_loader.sampler)
logging.info('{:>14} | {:6,} samples | Loss {:.5f} | Accuracy {:6.2f}'
.format(mode, len(data_loader.sampler), valid_loss, 100*valid_acc))
mode = 'test' if mode=='Testing' else 'valid'
return {'%s_loss'%mode:valid_loss, '%s_acc'%mode:valid_acc}
    def accuracy(self, batch_output, batch_target, acc_tol=20, **kwargs):
        """Count the number of correct predictions in a batch"""
        if self.loss=='MSE':
            # Correct if the relative error |target - pred| / |pred| is below acc_tol percent
            batch_preds = batch_output
            n_correct = (batch_target.sub(batch_preds).square().div(batch_preds.square())
                         .sqrt().mul(100).lt(acc_tol).float().mean(dim=1).sum().item())
elif self.loss=='BCE':
batch_preds = (torch.sigmoid(batch_output)>0.5).float()
            if batch_preds.dim()==1:
                n_correct = batch_preds.eq(batch_target).float().sum().item()
            else:
                n_correct = batch_preds.eq(batch_target).all(dim=1).float().sum().item()
else:
_, batch_preds = torch.max(batch_output, 1)
n_correct = batch_preds.eq(batch_target).sum().item()
return n_correct
def print_model_summary(self):
"""Override as needed"""
logging.info(
'Model: \n%s\nParameters: %i' %
(self.model, sum(p.numel()
for p in self.model.parameters()))
)
    def save_summary(self, summary):
        """Accumulate summary values across calls"""
        for (key, val) in summary.items():
summary_vals = self.summaries.get(key, [])
self.summaries[key] = summary_vals + [val]
    def write_summaries(self):
        """Write the accumulated summaries to a per-rank .npz file"""
assert self.output_dir is not None
summary_file = os.path.join(self.output_dir, 'summaries_%i.npz' % self.rank)
logging.info('Saving summaries to %s' % summary_file)
numpy.savez(summary_file, **self.summaries)
def write_checkpoint(self, checkpoint_id):
"""Write a checkpoint for the model"""
assert self.output_dir is not None
checkpoint_dir = os.path.join(self.output_dir, 'checkpoints')
checkpoint_file = 'model_checkpoint_%03i_%i.pth.tar' % (checkpoint_id,self.rank)
os.makedirs(checkpoint_dir, exist_ok=True)
torch.save(dict(model=self.model.state_dict()),os.path.join(checkpoint_dir, checkpoint_file))
def get_trainer(**kwargs):
    """Construct and return a GenericTrainer"""
return GenericTrainer(**kwargs)
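
# Minimal usage sketch (illustrative, not part of the trainer): trains on random
# data on CPU. The keyword arguments forwarded to get_model (the hypothetical
# name='mlp' below) depend on this project's model registry; the model is assumed
# to map 16 input features to 4 class logits.
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset
    logging.basicConfig(level=logging.INFO)
    inputs = torch.randn(128, 16)
    targets = torch.randint(0, 4, (128,))
    loader = DataLoader(TensorDataset(inputs, targets), batch_size=32)
    trainer = get_trainer(rank=0, device='cpu')
    trainer.build_model(device_ids=None, loss='CE', optimizer='SGD', lr=0.01,
                        name='mlp')  # hypothetical model_args
    trainer.train(loader, epochs=2)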