Source code for mldas.trainers.generic

"""
This module defines a generic trainer for simple models and datasets.
"""

# System
import os
import time
import logging

# Externals
import numpy
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel

# Locals
from ..models import get_model
from ..explore import figures

class GenericTrainer():
    """Trainer code for basic classification problems."""

    def __init__(self, rank, device, output_dir=None, **kwargs):
        self.device = device
        self.output_dir = (os.path.expandvars(output_dir)
                           if output_dir is not None else None)
        self.rank = rank
        self.summaries = {}

    def build_model(self, device_ids, distributed=False, loss='CE',
                    optimizer='SGD', lr=0.01, lr_decay_epoch=[],
                    lr_decay_ratio=0.5, momentum=0.9, **model_args):
        """Instantiate our model"""
        self.loss = loss
        self.lr_decay_ratio = lr_decay_ratio
        self.lr_decay_epoch = lr_decay_epoch
        # Construct the model
        self.model = get_model(**model_args).to(self.device)
        # Distributed data parallelism
        if distributed:
            self.model = DistributedDataParallel(self.model, device_ids=device_ids)
        # TODO: add support for more optimizers and loss functions here
        opt_type = dict(SGD=torch.optim.SGD)[optimizer]
        self.optimizer = opt_type(self.model.parameters(), lr=lr, momentum=momentum)
        loss_type = dict(CE=torch.nn.CrossEntropyLoss,
                         BCE=torch.nn.BCEWithLogitsLoss,
                         MSE=torch.nn.MSELoss)[loss]
        self.loss_func = loss_type()

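    # A hedged sketch of a typical call; every keyword after the trainer
    # options is forwarded verbatim to get_model(), so 'name' and
    # 'input_dim' below are hypothetical placeholders:
    #
    #   trainer.build_model(device_ids=[0], loss='CE', optimizer='SGD',
    #                       lr=0.01, lr_decay_epoch=[10, 20],
    #                       name='cnn', input_dim=64)
    #
    # Any optimizer other than 'SGD', or loss other than 'CE'/'BCE'/'MSE',
    # raises a KeyError from the dict lookups above.
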
    def exp_lr_scheduler(self, optimizer):
        """Decay the learning rate by a factor of lr_decay_ratio."""
        for param_group in optimizer.param_groups:
            param_group['lr'] *= self.lr_decay_ratio
        return optimizer

    def train(self, train_data_loader, epochs, valid_data_loader=None,
              test_data_loader=None, **kwargs):
        """Run the model training"""
        # Loop over epochs
        for i in range(epochs):
            if i + 1 in self.lr_decay_epoch:
                self.optimizer = self.exp_lr_scheduler(self.optimizer)
            logging.info(' EPOCH {:>3}/{:<3} | Model initial sumw: {:.5f}'.format(
                i + 1, epochs, sum(p.sum() for p in self.model.parameters())))
            summary = dict(epoch=i)
            # Train on this epoch
            start_time = time.time()
            summary.update(self.train_epoch(train_data_loader, **kwargs))
            summary['train_time'] = time.time() - start_time
            summary['train_samples'] = len(train_data_loader.sampler)
            summary['train_rate'] = summary['train_samples'] / summary['train_time']
            # Evaluate on the validation set
            if valid_data_loader is not None:
                start_time = time.time()
                summary.update(self.evaluate(valid_data_loader, 'Validation', **kwargs))
                summary['valid_time'] = time.time() - start_time
                summary['valid_samples'] = len(valid_data_loader.sampler)
                summary['valid_rate'] = summary['valid_samples'] / summary['valid_time']
            # Evaluate on the test set
            if test_data_loader is not None:
                summary.update(self.evaluate(test_data_loader, 'Testing', **kwargs))
            # Save the epoch summary exactly once per epoch, so that each
            # key accumulates one value per epoch in self.summaries
            self.save_summary(summary)
            # if self.output_dir is not None and self.rank == 0:
            #     self.write_checkpoint(checkpoint_id=i)
        return self.summaries

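    # Illustration of the step-decay schedule implemented above (numbers are
    # examples only): with lr=0.01, lr_decay_epoch=[10, 20] and
    # lr_decay_ratio=0.5, the learning rate drops to 0.005 at the start of
    # the 10th epoch and to 0.0025 at the start of the 20th.
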
    def train_epoch(self, data_loader, rounded=False, **kwargs):
        """Train for one epoch"""
        self.model.train()
        sum_loss = 0
        sum_correct = 0
        # Loop over training batches
        for i, (batch_input, batch_target) in enumerate(data_loader):
            batch_input = batch_input.to(self.device)
            # BCEWithLogitsLoss expects float targets shaped like the logits
            if self.loss == 'BCE' and batch_target.dim() == 1:
                batch_target = batch_target.float().unsqueeze(1)
            batch_target = batch_target.to(self.device)
            self.model.zero_grad()
            batch_output = self.model(batch_input)
            if rounded:
                batch_output = batch_output.round()
            batch_loss = self.loss_func(batch_output, batch_target)
            batch_loss.backward()
            self.optimizer.step()
            loss = batch_loss.item()
            sum_loss += loss
            n_correct = self.accuracy(batch_output, batch_target, **kwargs)
            sum_correct += n_correct
            logging.debug('  batch {:>3}/{:<3} | {:6,} samples | Loss {:.5f} | Accuracy {:6.2f}'
                          .format(i + 1, len(data_loader), len(batch_input),
                                  loss, 100 * n_correct / len(batch_input)))
        train_loss = sum_loss / (i + 1)
        train_acc = sum_correct / len(data_loader.sampler)
        logging.info('{:>14} | {:6,} samples | Loss {:.5f} | Accuracy {:6.2f}'
                     .format('Training', len(data_loader.sampler), train_loss, 100 * train_acc))
        # Return the accuracy alongside the loss so both reach the summaries
        return dict(train_loss=train_loss, train_acc=train_acc)

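    # The BCE branch above reshapes integer labels to match the model output:
    # a target of shape (N,) such as torch.tensor([0, 1, 1]) becomes a float
    # tensor of shape (N, 1), which is what BCEWithLogitsLoss expects when
    # the model emits a single logit per sample.
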
    @torch.no_grad()
    def evaluate(self, data_loader, mode, rounded=False, **kwargs):
        """Evaluate the model"""
        self.model.eval()
        sum_loss = 0
        sum_correct = 0
        # Loop over batches
        n = 0
        for i, (batch_input, batch_target) in enumerate(data_loader):
            batch_input = batch_input.to(self.device)
            if self.loss == 'BCE' and batch_target.dim() == 1:
                batch_target = batch_target.float().unsqueeze(1)
            batch_target = batch_target.to(self.device)
            batch_output = self.model(batch_input)
            loss = self.loss_func(batch_output, batch_target).item()
            sum_loss += loss
            n_correct = self.accuracy(batch_output, batch_target, **kwargs)
            sum_correct += n_correct
            # if mode == 'Testing':
            #     os.makedirs('results', exist_ok=True)
            #     for data_input, data_output in zip(batch_target, batch_output):
            #         figures.plot_test_2d(data_input, data_output, 'results/%05i' % n)
            #         n += 1
        valid_loss = sum_loss / (i + 1)
        valid_acc = sum_correct / len(data_loader.sampler)
        logging.info('{:>14} | {:6,} samples | Loss {:.5f} | Accuracy {:6.2f}'
                     .format(mode, len(data_loader.sampler), valid_loss, 100 * valid_acc))
        mode = 'test' if mode == 'Testing' else 'valid'
        return {'%s_loss' % mode: valid_loss, '%s_acc' % mode: valid_acc}

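    # The mode argument doubles as a log label and a summary-key prefix:
    # evaluate(loader, 'Validation') returns {'valid_loss': ..., 'valid_acc': ...},
    # while evaluate(loader, 'Testing') returns {'test_loss': ..., 'test_acc': ...}.
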
    def accuracy(self, batch_output, batch_target, acc_tol=20, **kwargs):
        # Count number of correct predictions
        if self.loss == 'MSE':
            # batch_preds = torch.round(batch_output)
            batch_preds = batch_output
            # n_correct = batch_preds.eq(batch_target).float().mean(dim=1).sum().item()
            # n_correct = batch_preds.sub(batch_target).abs().lt(acc_tol).float().mean(dim=1).sum().item()
            # Per sample: fraction of output components whose relative error
            # (in percent) is below acc_tol, summed over the batch
            n_correct = batch_target.sub(batch_preds).square().div(batch_preds.square()) \
                                    .sqrt().mul(100).lt(acc_tol).float().mean(dim=1).sum().item()
        elif self.loss == 'BCE':
            batch_preds = (torch.sigmoid(batch_output) > 0.5).float()
            if batch_preds.dim() == 1:
                n_correct = batch_preds.eq(batch_target).float().sum().item()
            else:
                # A multi-label sample is correct only if every label matches
                n_correct = batch_preds.eq(batch_target).all(dim=1).float().sum().item()
        else:
            _, batch_preds = torch.max(batch_output, 1)
            n_correct = batch_preds.eq(batch_target).sum().item()
        return n_correct

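    # Worked example for the MSE branch (illustrative numbers): with the
    # default acc_tol=20, a prediction p=10.0 for a target t=11.0 has a
    # relative error of 100*|t-p|/|p| = 10%, so that component counts as
    # correct; each sample contributes the fraction of its output components
    # within tolerance, and these fractions are summed over the batch.
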
    def print_model_summary(self):
        """Override as needed"""
        logging.info('Model: \n%s\nParameters: %i' %
                     (self.model, sum(p.numel() for p in self.model.parameters())))

    def save_summary(self, summaries):
        """Save summary information"""
        for (key, val) in summaries.items():
            summary_vals = self.summaries.get(key, [])
            self.summaries[key] = summary_vals + [val]

    def write_summaries(self):
        assert self.output_dir is not None
        summary_file = os.path.join(self.output_dir, 'summaries_%i.npz' % self.rank)
        logging.info('Saving summaries to %s' % summary_file)
        numpy.savez(summary_file, **self.summaries)

    def write_checkpoint(self, checkpoint_id):
        """Write a checkpoint for the model"""
        assert self.output_dir is not None
        checkpoint_dir = os.path.join(self.output_dir, 'checkpoints')
        checkpoint_file = 'model_checkpoint_%03i_%i.pth.tar' % (checkpoint_id, self.rank)
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(dict(model=self.model.state_dict()),
                   os.path.join(checkpoint_dir, checkpoint_file))

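# Reading back the artifacts written above, as a hedged sketch: checkpoints
# hold only the model state_dict under the 'model' key, and summaries are
# one NumPy archive per rank. The paths below are illustrative.
#
#   state = torch.load('out/checkpoints/model_checkpoint_000_0.pth.tar')
#   model.load_state_dict(state['model'])
#   history = dict(numpy.load('out/summaries_0.npz'))
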
def get_trainer(**kwargs):
    """Instantiate and return a GenericTrainer with the given arguments."""
    return GenericTrainer(**kwargs)

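# A minimal usage sketch on random data; the output directory and the model
# arguments forwarded to get_model() (here the hypothetical name='mlp') are
# placeholders, so substitute whatever your get_model() actually supports.
# Because of the relative imports, run it as: python -m mldas.trainers.generic
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    logging.basicConfig(level=logging.INFO)
    # 128 random samples with 16 features and binary class labels
    dataset = TensorDataset(torch.randn(128, 16), torch.randint(0, 2, (128,)))
    loader = DataLoader(dataset, batch_size=32)
    trainer = get_trainer(rank=0, device='cpu', output_dir='out')
    trainer.build_model(device_ids=None, loss='CE', optimizer='SGD',
                        lr=0.01, name='mlp')  # 'name' is a hypothetical model_arg
    summaries = trainer.train(loader, epochs=2)
    trainer.write_summaries()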