mldas.explore.single.suplearn_simple

mldas.explore.single.suplearn_simple(model, criterion, optimizer, train_loader, test_loader, epochs=1, print_every=1, save_model=False, verbose=True)

Simple, non-optimized supervised training loop for single-node, single-processor execution, with a validation step performed at regular batch intervals during iteration.
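The loop described above can be sketched roughly as follows. This is a hypothetical reimplementation, not the mldas source. The names `train_with_periodic_validation` and `evaluate`, the layout of the returned history (one row of training loss, training accuracy, validation loss and validation accuracy per evaluation), the `(epoch, batch)` keys of the snapshot dictionary, and computing accuracy as `argmax` over class logits are all assumptions.

```python
import copy

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, criterion, loader):
    """Return (mean loss, accuracy) of `model` over every batch in `loader`."""
    model.eval()
    total_loss, correct, count = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs)
            # Weight each batch loss by its size to get a per-sample mean.
            total_loss += criterion(outputs, labels).item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            count += labels.size(0)
    model.train()  # restore training mode for the caller
    return total_loss / count, correct / count


def train_with_periodic_validation(model, criterion, optimizer, train_loader,
                                   test_loader, epochs=1, print_every=1,
                                   save_model=False, verbose=True):
    """Hypothetical sketch of a loop like suplearn_simple (not the mldas code)."""
    loss_hist = []   # assumed rows: [train_loss, train_acc, val_loss, val_acc]
    snapshots = {}   # assumed: model copies keyed by (epoch, batch)
    model.train()
    for epoch in range(epochs):
        for batch, (inputs, labels) in enumerate(train_loader, start=1):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            if batch % print_every == 0:
                # Training metrics come from the current batch (pre-update);
                # validation metrics cover the whole test loader.
                train_acc = (outputs.argmax(dim=1) == labels).float().mean().item()
                val_loss, val_acc = evaluate(model, criterion, test_loader)
                loss_hist.append([loss.item(), train_acc, val_loss, val_acc])
                if save_model:
                    snapshots[(epoch, batch)] = copy.deepcopy(model)
                if verbose:
                    print(f"epoch {epoch + 1} batch {batch}: "
                          f"train loss {loss.item():.4f} acc {train_acc:.3f} | "
                          f"val loss {val_loss:.4f} acc {val_acc:.3f}")
    return np.array(loss_hist), (snapshots if save_model else model)
```

Evaluating only every `print_every` batches keeps the cost of the full validation pass bounded; with `print_every=1` a validation pass runs after every optimizer step.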

Parameters
model torch.nn.Module

Model to be trained

criterion e.g. torch.nn.CrossEntropyLoss

Loss function

optimizer torch.optim.Optimizer

Optimizer to perform gradient descent

train_loader torch.utils.data.DataLoader

Data loader supplying the training batches

test_loader torch.utils.data.DataLoader

Data loader supplying the batches for the validation step

epochs int

Number of training epochs

print_every int

Interval, in batches, at which the training and validation loss and accuracy are evaluated

save_model bool

If True, save the updated model in a dictionary of models
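A minimal sketch of how the inputs might be assembled, assuming a toy three-class classification task. The random tensors, layer sizes, batch sizes, learning rate and the choice of Adam are illustrative only and are not taken from mldas.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data: 4 features, 3 classes.
X_train, y_train = torch.randn(120, 4), torch.randint(0, 3, (120,))
X_test, y_test = torch.randn(30, 4), torch.randint(0, 3, (30,))

# train_loader / test_loader: iterables of (inputs, labels) batches.
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=10, shuffle=True)
test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=30)

# model, criterion and optimizer as described in the parameter list.
model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 3))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Sanity check: one training batch flows through the model and the loss.
inputs, labels = next(iter(train_loader))
logits = model(inputs)
loss = criterion(logits, labels)
```

These objects match the documented signature and could then be passed as, for example, `suplearn_simple(model, criterion, optimizer, train_loader, test_loader, epochs=5, print_every=4)`; with 12 training batches per epoch, that would trigger three evaluations per epoch.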

Returns
loss_hist numpy.ndarray

History of loss values

model torch.nn.Module or dict

Final trained model, or a dictionary of models when save_model is True.