Deep Learning trainers
Base trainer class
class mldas.trainers.base.BaseTrainer(output_dir=None, gpu=None, distributed=False, rank=0, **kwargs)
Base class for PyTorch trainers. It implements the common training logic, summary logging, and checkpointing.
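As a rough illustration of this pattern (a shared training loop that records per-epoch summaries and writes checkpoints), here is a minimal pure-Python sketch that borrows the method names from this class's summary. It is not the mldas implementation: the class names SketchTrainer and MeanTrainer, the JSON checkpoint format, the file naming, and the rule that only rank 0 writes files are all assumptions made for illustration. A real PyTorch trainer would typically save the model's state_dict() with torch.save instead of writing JSON.

```python
import json
import os
import tempfile


class SketchTrainer:
    """Hypothetical base trainer: NOT the mldas implementation."""

    def __init__(self, output_dir=None, rank=0):
        self.output_dir = output_dir
        self.rank = rank
        # Per-metric history, e.g. {"epoch": [0, 1], "train_loss": [...]}.
        self.summaries = {}

    def save_summary(self, summaries):
        # Append this epoch's values to each metric's running history.
        for key, value in summaries.items():
            self.summaries.setdefault(key, []).append(value)

    def write_checkpoint(self, checkpoint_id):
        # Assumption: only rank 0 writes files in a distributed run.
        if self.output_dir is None or self.rank != 0:
            return None
        path = os.path.join(self.output_dir, f"checkpoint_{checkpoint_id:03d}.json")
        # A PyTorch trainer would typically torch.save the model's state_dict() here.
        with open(path, "w") as f:
            json.dump({"checkpoint_id": checkpoint_id, "summaries": self.summaries}, f)
        return path

    def train_epoch(self, data_loader):
        # Subclasses do the per-epoch work and return a dict of metrics.
        raise NotImplementedError

    def train(self, train_data_loader, epochs):
        # Shared loop: train one epoch, record its summary, checkpoint, repeat.
        for epoch in range(epochs):
            summary = {"epoch": epoch}
            summary.update(self.train_epoch(train_data_loader))
            self.save_summary(summary)
            self.write_checkpoint(epoch)
        return self.summaries


class MeanTrainer(SketchTrainer):
    # Toy subclass: the "loss" is simply the mean of the data.
    def train_epoch(self, data_loader):
        return {"train_loss": sum(data_loader) / len(data_loader)}


with tempfile.TemporaryDirectory() as out:
    history = MeanTrainer(output_dir=out).train([1.0, 2.0, 3.0], epochs=2)
    print(history)                  # {'epoch': [0, 1], 'train_loss': [2.0, 2.0]}
    print(sorted(os.listdir(out)))  # ['checkpoint_000.json', 'checkpoint_001.json']
```

The design choice being illustrated is that the base class owns the loop, logging, and checkpointing, while subclasses override only the per-epoch work.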
Methods Summary
- __init__([output_dir, gpu, distributed, rank]): Override as needed.
- save_summary(summaries): Save summary information.
- write_checkpoint(checkpoint_id): Write a checkpoint for the model.
- train(train_data_loader, epochs[, ...]): Run the model training.
Methods Documentation
Classifier training
class mldas.trainers.generic.GenericTrainer(rank, device, output_dir=None, **kwargs)
Trainer for basic classification problems.
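The accuracy method of this class takes an optional acc_tol argument whose meaning is not documented on this page. One plausible reading, which is only an assumption, is a tolerance: a prediction counts as correct when it lies strictly within acc_tol of its target. The hypothetical function below sketches that reading on plain Python lists; it is not the mldas code. For outputs between 0 and 1 with binary targets and acc_tol=0.5, it matches rounding each output to the nearer class, except that an output of exactly 0.5 is never counted as correct.

```python
def accuracy(batch_output, batch_target, acc_tol=0.5):
    """Hypothetical tolerance-based accuracy (an assumed reading of acc_tol):
    the fraction of predictions strictly within acc_tol of their target."""
    if len(batch_output) != len(batch_target):
        raise ValueError("batch_output and batch_target must have the same length")
    if not batch_target:
        # Empty batch: report 0.0 rather than dividing by zero.
        return 0.0
    hits = sum(abs(o - t) < acc_tol for o, t in zip(batch_output, batch_target))
    return hits / len(batch_target)


print(accuracy([0.9, 0.2, 0.6, 0.4], [1, 0, 0, 1]))               # 0.5
print(accuracy([0.9, 0.2, 0.6, 0.4], [1, 0, 0, 1], acc_tol=0.7))  # 1.0
```

With PyTorch tensors, the same assumed metric would be ((batch_output - batch_target).abs() < acc_tol).float().mean(), with .item() to obtain a Python float.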
Methods Summary
- __init__(rank, device[, output_dir])
- build_model(device_ids[, distributed, loss, ...]): Instantiate the model.
- train_epoch(data_loader[, rounded]): Train for one epoch.
- evaluate(data_loader, mode[, rounded]): Evaluate the model.
- accuracy(batch_output, batch_target[, acc_tol])

Methods Documentation