Deep Learning trainers

Base trainer class

class mldas.trainers.base.BaseTrainer(output_dir=None, gpu=None, distributed=False, rank=0, **kwargs)

Base class for PyTorch trainers. It implements the common training logic, summary logging, and checkpoint writing.

Methods Summary

__init__([output_dir, gpu, distributed, rank])

print_model_summary()

Override as needed

save_summary(summaries)

Save summary information

write_summaries()

write_checkpoint(checkpoint_id)

Write a checkpoint for the model

train(train_data_loader, epochs[, ...])

Run the model training

Methods Documentation

__init__(output_dir=None, gpu=None, distributed=False, rank=0, **kwargs)

print_model_summary()

Override as needed

save_summary(summaries)

Save summary information

write_summaries()

write_checkpoint(checkpoint_id)

Write a checkpoint for the model

train(train_data_loader, epochs, valid_data_loader=None, test_data_loader=None, **kwargs)

Run the model training
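
The division of labour between `train`, `save_summary`, and `write_checkpoint` can be pictured with a small framework-free sketch. `SketchTrainer` and `MeanTrainer` below are hypothetical stand-ins, not the mldas implementation. The sketch assumes that `train` loops over epochs, calls per-epoch hooks supplied by a subclass, records one summary per epoch, and writes one checkpoint per epoch; the real class may order or gate these steps differently.

```python
class SketchTrainer:
    """Hypothetical stand-in for a trainer base class (not the mldas code)."""

    def __init__(self, output_dir=None):
        self.output_dir = output_dir
        self.summaries = []    # one dict per finished epoch
        self.checkpoints = []  # ids of checkpoints "written" so far

    def train_epoch(self, data_loader):
        # Subclasses return a dict of training metrics for one epoch.
        raise NotImplementedError

    def evaluate(self, data_loader):
        # Subclasses return a dict of validation metrics.
        raise NotImplementedError

    def save_summary(self, summaries):
        self.summaries.append(dict(summaries))

    def write_checkpoint(self, checkpoint_id):
        # Assumption: a real trainer would serialise model state under
        # output_dir here; the sketch only records the id.
        self.checkpoints.append(checkpoint_id)

    def train(self, train_data_loader, epochs, valid_data_loader=None):
        for epoch in range(epochs):
            summary = {"epoch": epoch}
            summary.update(self.train_epoch(train_data_loader))
            if valid_data_loader is not None:
                summary.update(self.evaluate(valid_data_loader))
            self.save_summary(summary)
            self.write_checkpoint(epoch)
        return self.summaries


class MeanTrainer(SketchTrainer):
    """Toy subclass: the 'loss' is the mean of the numbers in the loader."""

    def train_epoch(self, data_loader):
        values = list(data_loader)
        return {"train_loss": sum(values) / len(values)}

    def evaluate(self, data_loader):
        values = list(data_loader)
        return {"valid_loss": sum(values) / len(values)}


trainer = MeanTrainer()
history = trainer.train([1.0, 2.0, 3.0], epochs=2, valid_data_loader=[4.0])
```

Keeping the epoch loop in the base class and the per-epoch work in subclasses means a new task only has to supply the hooks, which matches the "override as needed" note on `print_model_summary`.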

Classifier training

class mldas.trainers.generic.GenericTrainer(rank, device, output_dir=None, **kwargs)

Trainer code for basic classification problems.

Methods Summary

__init__(rank, device[, output_dir])

build_model(device_ids[, distributed, loss, ...])

Instantiate our model

train_epoch(data_loader[, rounded])

Train for one epoch

evaluate(data_loader, mode[, rounded])

Evaluate the model

accuracy(batch_output, batch_target[, acc_tol])

Methods Documentation

__init__(rank, device, output_dir=None, **kwargs)

build_model(device_ids, distributed=False, loss='CE', optimizer='SGD', lr=0.01, lr_decay_epoch=[], lr_decay_ratio=0.5, momentum=0.9, **model_args)

Instantiate our model

train_epoch(data_loader, rounded=False, **kwargs)

Train for one epoch

evaluate(data_loader, mode, rounded=False, **kwargs)

Evaluate the model

accuracy(batch_output, batch_target, acc_tol=20, **kwargs)