mldas.explore.single.suplearn_simple
mldas.explore.single.suplearn_simple(model, criterion, optimizer, train_loader, test_loader, epochs=1, print_every=1, save_model=False, verbose=True)
Simple, non-optimized supervised training for single-node, single-processor execution, with a validation step performed at regular batch intervals during training.
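The kind of loop described above can be sketched with the standard PyTorch API as follows. This is not the mldas source. The function name `train_with_periodic_validation`, the assumption that the model returns class scores of shape (batch, n_classes) with integer labels, the rule "evaluate when the 1-based batch index is a multiple of `print_every`", and the `(model, loss_hist)` return order are all illustrative assumptions. `save_model` is not reproduced.

```python
import numpy as np
import torch


def train_with_periodic_validation(model, criterion, optimizer, train_loader,
                                   test_loader, epochs=1, print_every=1,
                                   verbose=True):
    """Hypothetical sketch of supervised training with periodic validation.

    Assumes a classifier whose output has shape (batch, n_classes) and
    integer class labels. Not the mldas implementation.
    """
    loss_hist = []  # one training-loss value per batch (assumed layout)
    for epoch in range(epochs):
        for i, (inputs, labels) in enumerate(train_loader, start=1):
            model.train()
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            loss_hist.append(loss.item())

            # Every `print_every` batches, evaluate on the whole validation set.
            if i % print_every == 0:
                train_acc = (outputs.argmax(dim=1) == labels).float().mean().item()
                model.eval()
                val_loss, correct, total = 0.0, 0, 0
                with torch.no_grad():
                    for x, y in test_loader:
                        out = model(x)
                        # criterion is assumed to average over the batch,
                        # so re-weight by batch size before summing.
                        val_loss += criterion(out, y).item() * y.size(0)
                        correct += (out.argmax(dim=1) == y).sum().item()
                        total += y.size(0)
                if verbose:
                    print(f"epoch {epoch + 1}, batch {i}: "
                          f"train loss {loss.item():.4f}, train acc {train_acc:.3f}, "
                          f"val loss {val_loss / total:.4f}, val acc {correct / total:.3f}")
    return model, np.array(loss_hist)
```

Recording one loss per batch keeps the history aligned with the batch counter that `print_every` refers to; a real implementation may record only at evaluation points instead.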
- Parameters
  - model (torch.nn.Module) - Model to train
  - criterion (e.g. torch.nn.CrossEntropyLoss) - Loss function
  - optimizer (torch.optim.Optimizer) - Optimizer used to perform gradient descent
  - train_loader (torch.utils.data.DataLoader) - Data loader for the training set
  - test_loader (torch.utils.data.DataLoader) - Data loader for the validation step
  - epochs (int) - Number of training epochs
  - print_every (int) - Interval, in batches, at which the training and validation loss and accuracy are evaluated
  - save_model (bool) - Whether to save the updated model in a dictionary
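Arguments matching the types listed above might be prepared as in the sketch below. The synthetic tensors, layer sizes, class count, batch sizes and learning rate are arbitrary assumptions for illustration. The call to `suplearn_simple` is left commented out because it needs mldas installed, and this reference does not state the order of the returned values.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in data: 8 features, 3 classes (illustrative only).
torch.manual_seed(0)
X_train, y_train = torch.randn(128, 8), torch.randint(0, 3, (128,))
X_test, y_test = torch.randn(32, 8), torch.randint(0, 3, (32,))

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))  # torch.nn.Module
criterion = nn.CrossEntropyLoss()                                      # loss function
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)              # torch.optim.Optimizer
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=16, shuffle=True)
test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=32)

# With mldas installed, the call would look like this:
# from mldas.explore.single import suplearn_simple
# result = suplearn_simple(model, criterion, optimizer, train_loader, test_loader,
#                          epochs=5, print_every=4, save_model=False, verbose=True)
```

With 128 training samples and a batch size of 16 there are 8 batches per epoch, so `print_every=4` would trigger two evaluations per epoch under a "multiple of `print_every`" rule.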
 
- Returns
  - loss_hist (numpy.ndarray) - History of loss values
  - model (torch.nn.Module or dict) - Final trained model, or dictionary of models
 
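Per-batch losses are often noisy, so a moving average can make the trend in `loss_hist` easier to read. This sketch assumes `loss_hist` is a 1-D array with one value per recorded step; the reference only says it is a history of loss values. `smooth_loss` is a hypothetical helper, not part of mldas, and the sample values are made up.

```python
import numpy as np


def smooth_loss(loss_hist, window=5):
    """Moving average over a 1-D loss history (hypothetical helper)."""
    loss_hist = np.asarray(loss_hist, dtype=float)
    if window < 1 or window > loss_hist.size:
        raise ValueError("window must be between 1 and len(loss_hist)")
    kernel = np.ones(window) / window
    # mode="valid" keeps only positions where the window fully overlaps the data.
    return np.convolve(loss_hist, kernel, mode="valid")


loss_hist = np.array([1.0, 0.8, 0.9, 0.6, 0.5, 0.4])  # made-up example values
smoothed = smooth_loss(loss_hist, window=3)
```

Using `mode="valid"` shortens the output by `window - 1` points but avoids the edge bias that zero padding would introduce.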