"""This module defines a generic trainer for simple models and datasets."""# Localsimportosimporttimeimportlogging# Externalsimportnumpyimporttorchfromtorchimportnnfromtorch.nn.parallelimportDistributedDataParallel# Localsfrom..modelsimportget_modelfrom..exploreimportfigures
class GenericTrainer():
    """Trainer code for basic classification problems.

    NOTE(review): this chunk shows no ``__init__``; the methods below rely on
    ``self.device``, ``self.output_dir``, ``self.rank`` and ``self.summaries``
    being set elsewhere (presumably a base class or constructor outside this
    view) -- confirm before use.
    """

    def build_model(self, device_ids, distributed=False, loss='CE',
                    optimizer='SGD', lr=0.01, lr_decay_epoch=(),
                    lr_decay_ratio=0.5, momentum=0.9, **model_args):
        """Instantiate our model, optimizer and loss function.

        Args:
            device_ids: device list handed to DistributedDataParallel.
            distributed: wrap the model in DistributedDataParallel if True.
            loss: loss selector, one of 'CE', 'BCE', 'MSE'.
            optimizer: optimizer selector; only 'SGD' is supported.
            lr: initial learning rate.
            lr_decay_epoch: 1-based epoch numbers at which to decay the LR.
                Default changed from the mutable ``[]`` to the immutable
                ``()``; behavior is identical since it is only used for
                membership tests, but the shared-mutable-default pitfall
                is removed.
            lr_decay_ratio: multiplicative LR decay factor.
            momentum: SGD momentum.
            **model_args: forwarded verbatim to ``get_model``.
        """
        self.loss = loss
        self.lr_decay_ratio = lr_decay_ratio
        self.lr_decay_epoch = lr_decay_epoch
        # Construct the model
        self.model = get_model(**model_args).to(self.device)
        # Distributed data parallelism
        if distributed:
            self.model = DistributedDataParallel(self.model,
                                                 device_ids=device_ids)
        # TODO: add support for more optimizers and loss functions here
        opt_type = dict(SGD=torch.optim.SGD)[optimizer]
        self.optimizer = opt_type(self.model.parameters(), lr=lr,
                                  momentum=momentum)
        loss_type = dict(CE=torch.nn.CrossEntropyLoss,
                         BCE=torch.nn.BCEWithLogitsLoss,
                         MSE=torch.nn.MSELoss)[loss]
        self.loss_func = loss_type()

    def exp_lr_scheduler(self, optimizer):
        """Decay learning rate of every param group by ``lr_decay_ratio``."""
        for param_group in optimizer.param_groups:
            param_group['lr'] *= self.lr_decay_ratio
        return optimizer

    def train(self, train_data_loader, epochs, valid_data_loader=None,
              test_data_loader=None, **kwargs):
        """Run the model training.

        Args:
            train_data_loader: loader for the training set.
            epochs: number of epochs to run.
            valid_data_loader: optional loader for per-epoch validation.
            test_data_loader: optional loader for per-epoch test evaluation.
            **kwargs: forwarded to ``train_epoch``/``evaluate`` (e.g. acc_tol).

        Returns:
            ``self.summaries`` -- dict mapping metric name to per-epoch list.
        """
        # Loop over epochs
        for i in range(epochs):
            # Decay the LR at the configured (1-based) epochs
            if i + 1 in self.lr_decay_epoch:
                self.optimizer = self.exp_lr_scheduler(self.optimizer)
            # .item() so a plain float is formatted, not a 0-dim tensor
            logging.info(' EPOCH {:>3}/{:<3} | Model initial sumw: {:.5f}'.format(
                i + 1, epochs,
                sum(p.sum().item() for p in self.model.parameters())))
            summary = dict(epoch=i)
            # Train on this epoch
            start_time = time.time()
            summary.update(self.train_epoch(train_data_loader, **kwargs))
            summary['train_time'] = time.time() - start_time
            summary['train_samples'] = len(train_data_loader.sampler)
            summary['train_rate'] = summary['train_samples'] / summary['train_time']
            # Evaluate on the validation set
            if valid_data_loader is not None:
                start_time = time.time()
                summary.update(self.evaluate(valid_data_loader, 'Validation',
                                             **kwargs))
                summary['valid_time'] = time.time() - start_time
                summary['valid_samples'] = len(valid_data_loader.sampler)
                summary['valid_rate'] = summary['valid_samples'] / summary['valid_time']
            # Evaluate on the test set
            if test_data_loader is not None:
                summary.update(self.evaluate(test_data_loader, 'Testing',
                                             **kwargs))
            # Save the summary exactly once per epoch. The original saved it
            # twice when a test loader was given, appending every metric's
            # value to its history list twice per epoch.
            self.save_summary(summary)
            # write_checkpoint asserts output_dir is not None, so the original
            # unconditional call crashed when no output dir was configured.
            # NOTE(review): the commented-out rank==0 guard may also be wanted
            # under DDP to avoid every rank writing -- confirm.
            if self.output_dir is not None:
                self.write_checkpoint(checkpoint_id=i)
        return self.summaries

    def train_epoch(self, data_loader, rounded=False, **kwargs):
        """Train for one epoch.

        Returns a dict with ``train_loss`` and ``train_acc`` (the original
        computed the accuracy but dropped it, unlike ``evaluate``).
        """
        self.model.train()
        sum_loss = 0
        sum_correct = 0
        # Loop over training batches
        for i, (batch_input, batch_target) in enumerate(data_loader):
            batch_input = batch_input.to(self.device)
            # BCEWithLogitsLoss needs float targets shaped like the output
            if self.loss == 'BCE' and batch_target.dim() == 1:
                batch_target = batch_target.float().unsqueeze(1)
            batch_target = batch_target.to(self.device)
            self.model.zero_grad()
            batch_output = self.model(batch_input)
            # NOTE(review): round() has zero gradient almost everywhere, so
            # training with rounded=True blocks backprop -- confirm intended.
            if rounded:
                batch_output = batch_output.round()
            batch_loss = self.loss_func(batch_output, batch_target)
            batch_loss.backward()
            self.optimizer.step()
            loss = batch_loss.item()
            sum_loss += loss
            n_correct = self.accuracy(batch_output, batch_target, **kwargs)
            sum_correct += n_correct
            logging.debug(' batch {:>3}/{:<3} | {:6,} samples | Loss {:.5f} | Accuracy {:6.2f}'.format(
                i + 1, len(data_loader), len(batch_input), loss,
                100 * n_correct / len(batch_input)))
        train_loss = sum_loss / (i + 1)
        train_acc = sum_correct / len(data_loader.sampler)
        logging.info('{:>14} | {:6,} samples | Loss {:.5f} | Accuracy {:6.2f}'.format(
            'Training', len(data_loader.sampler), train_loss, 100 * train_acc))
        return dict(train_loss=train_loss, train_acc=train_acc)

    @torch.no_grad()
    def evaluate(self, data_loader, mode, rounded=False, **kwargs):
        """Evaluate the model.

        Args:
            data_loader: loader for the evaluation set.
            mode: display label; 'Testing' selects ``test_*`` result keys,
                anything else selects ``valid_*``.
            rounded: accepted for signature parity with ``train_epoch`` but
                unused here (as in the original).
            **kwargs: forwarded to ``accuracy``.

        Returns:
            dict with ``<prefix>_loss`` and ``<prefix>_acc``.
        """
        self.model.eval()
        sum_loss = 0
        sum_correct = 0
        # Loop over batches
        for i, (batch_input, batch_target) in enumerate(data_loader):
            batch_input = batch_input.to(self.device)
            # BCEWithLogitsLoss needs float targets shaped like the output
            if self.loss == 'BCE' and batch_target.dim() == 1:
                batch_target = batch_target.float().unsqueeze(1)
            batch_target = batch_target.to(self.device)
            batch_output = self.model(batch_input)
            loss = self.loss_func(batch_output, batch_target).item()
            sum_loss += loss
            n_correct = self.accuracy(batch_output, batch_target, **kwargs)
            sum_correct += n_correct
        valid_loss = sum_loss / (i + 1)
        valid_acc = sum_correct / len(data_loader.sampler)
        logging.info('{:>14} | {:6,} samples | Loss {:.5f} | Accuracy {:6.2f}'.format(
            mode, len(data_loader.sampler), valid_loss, 100 * valid_acc))
        # Prefix the returned keys according to the evaluation mode
        key = 'test' if mode == 'Testing' else 'valid'
        return {'%s_loss' % key: valid_loss, '%s_acc' % key: valid_acc}

    def accuracy(self, batch_output, batch_target, acc_tol=20, **kwargs):
        """Count the number of correct predictions in a batch.

        Args:
            batch_output: raw model outputs (logits for CE/BCE, values for MSE).
            batch_target: ground-truth targets.
            acc_tol: MSE only -- percentage relative-error tolerance below
                which an output element counts as correct.

        Returns:
            Number of correct predictions as a Python number.
        """
        if self.loss == 'MSE':
            # Per-element relative error in percent, averaged over each
            # sample's outputs; fractional credit per sample.
            # NOTE(review): divides by the prediction, so a prediction of
            # exactly zero produces inf/nan -- confirm acceptable.
            n_correct = batch_target.sub(batch_output).square().div(
                batch_output.square()).sqrt().mul(100).lt(acc_tol)\
                .float().mean(dim=1).sum().item()
        elif self.loss == 'BCE':
            batch_preds = (torch.sigmoid(batch_output) > 0.5).float()
            # .item() added: the original returned 0-dim tensors from these
            # two branches while every other branch returns a Python number,
            # which made the summed accuracies tensors downstream.
            if batch_preds.dim() == 1:
                n_correct = batch_preds.eq(batch_target).float().sum().item()
            else:
                n_correct = batch_preds.eq(batch_target).all(dim=1).float().sum().item()
        else:
            _, batch_preds = torch.max(batch_output, 1)
            n_correct = batch_preds.eq(batch_target).sum().item()
        return n_correct

    def print_model_summary(self):
        """Override as needed"""
        logging.info('Model: \n%s\nParameters: %i' %
                     (self.model,
                      sum(p.numel() for p in self.model.parameters())))

    def save_summary(self, summaries):
        """Append each value in ``summaries`` to its per-key history list."""
        for (key, val) in summaries.items():
            summary_vals = self.summaries.get(key, [])
            self.summaries[key] = summary_vals + [val]

    def write_summaries(self):
        """Save the accumulated summaries as a per-rank .npz file."""
        assert self.output_dir is not None
        summary_file = os.path.join(self.output_dir,
                                    'summaries_%i.npz' % self.rank)
        logging.info('Saving summaries to %s' % summary_file)
        numpy.savez(summary_file, **self.summaries)

    def write_checkpoint(self, checkpoint_id):
        """Write a checkpoint for the model"""
        assert self.output_dir is not None
        checkpoint_dir = os.path.join(self.output_dir, 'checkpoints')
        checkpoint_file = 'model_checkpoint_%03i_%i.pth.tar' % (checkpoint_id,
                                                                self.rank)
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(dict(model=self.model.state_dict()),
                   os.path.join(checkpoint_dir, checkpoint_file))
def get_trainer(**kwargs):
    """Factory: build a GenericTrainer from the given keyword arguments."""
    return GenericTrainer(**kwargs)