# Systems
import os
import glob
import random
# Externals
import h5py
import numpy
import torch
from collections import OrderedDict
from torchvision import transforms
from PIL import Image
def hdf5read(fname,key=('dsi30','data','variable')):
    """
    Read input raw DAS file. The file is supposed to have an HDF5 format and the
    actual data stored in a group that has one of the following names: ``dsi30``,
    ``data`` or ``variable``.

    Parameters
    ----------
    fname : :py:class:`str`
      Full path to the HDF5 file
    key : :py:class:`str` or sequence of :py:class:`str`
      If a string, it is used directly as the path of the dataset whose
      ``[0,0]`` element points to the data. Otherwise, every top-level name
      of the file that is listed in ``key`` is a candidate group and the
      pointer is read from ``<group>/dat``.

    Returns
    -------
    data : array-like
      Raw data, as the object obtained by indexing the open file with the
      value stored at ``[0,0]`` (presumably an HDF5 object reference -- TODO
      confirm against the files in use).

    Raises
    ------
    AssertionError
      If ``key`` is not a string and none of its names exists in the file.
    """
    # NOTE: the handle is intentionally not closed on success, because the
    # returned object is obtained from (and backed by) this open file.
    f = h5py.File(fname,'r')
    if isinstance(key, str):
        return f[f[key][0,0]]
    matches = [keyname for keyname in f.keys() if keyname in key]
    if not matches:
        # Raised explicitly (rather than via ``assert``) so the check also
        # runs under ``python -O``; the exception type is unchanged.
        f.close()
        raise AssertionError('No acceptable key found in input HDF5 file.')
    # When several groups match, the last one in file order wins.
    return f[f['%s/dat'%matches[-1]][0,0]]
def load_model(fname,model):
    """
    Load saved model's parameter dictionary to initialized model.
    The function removes a leading ``module.`` prefix from every parameter
    name (such a prefix is typically present when the model was saved while
    wrapped in a data-parallel container).

    Parameters
    ----------
    fname : :py:class:`str`
      Path to saved model. The file must contain a dictionary with a
      ``'model'`` entry holding the state dictionary.
    model : :py:class:`torch.nn.Module`
      Initialized neural network architecture

    Returns
    -------
    model : :py:class:`torch.nn.Module`
      The same ``model`` object, updated in place with the saved parameters.
    """
    # Always deserialize onto the CPU so that checkpoints written on a GPU
    # machine can be read anywhere.
    checkpoint = torch.load(fname,map_location='cpu')
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for name, value in checkpoint['model'].items():
        if name.startswith(prefix):
            name = name[len(prefix):]
        new_state_dict[name] = value
    # Without this call the function would return the model untouched.
    model.load_state_dict(new_state_dict)
    return model
def save_data(data, dir_dst, fname):
    """
    Save raw data region as JPG image.

    The data are linearly rescaled so that their minimum maps to 0 and their
    maximum to 255 before being converted to 8-bit integers.

    Parameters
    ----------
    data : :py:class:`numpy.ndarray`
      Input raw data
    dir_dst : :py:class:`str`
      Path to the directory where the image is saved
    fname : :py:class:`str`
      Output file name
    """
    lowest, highest = data.min(), data.max()
    if highest > lowest:
        scaled = (data-lowest)/(highest-lowest)
    else:
        # Constant input: avoid a 0/0 division and write a uniform image.
        scaled = numpy.zeros(numpy.shape(data))
    img = Image.fromarray(numpy.uint8(scaled*255))
    # The format is forced so the output is JPEG whatever the file extension.
    img.save(os.path.join(dir_dst, fname), format='JPEG')
def load_image(data,rgb=False,to_numpy=False,squeeze=False):
    """
    Load input data as tensor image. Input data can be either in the form
    of a saved image or a raw data region. Raw data are rescaled to the 0-255
    range and converted into numpy.uint8 format; the image is then converted
    to either RGB or grayscale.

    Parameters
    ----------
    data : :py:class:`str` or :py:class:`numpy.ndarray`
      Input data, either path to a saved image (environment variables in the
      path are expanded), or raw data array.
    rgb : :py:class:`bool`
      Convert data to RGB data image (3 channels) instead of grayscale (1 channel)
    to_numpy : :py:class:`bool`
      Convert data to numpy array
    squeeze : :py:class:`bool`
      Squeeze data to remove any dimensions equal to 1.

    Returns
    -------
    image : :py:class:`torch.Tensor` or :py:class:`numpy.ndarray`
      Image with a leading batch dimension of size 1 followed by the channel
      dimension, unless ``squeeze`` is set.
    """
    mode = "RGB" if rgb else "L"
    if isinstance(data, str):
        # Convert inside the context manager so the file is read and released.
        with Image.open(os.path.expandvars(data)) as handle:
            image = handle.convert(mode)
    else:
        lowest, highest = data.min(), data.max()
        if highest > lowest:
            scaled = (data-lowest)/(highest-lowest)
        else:
            # Constant input: avoid a 0/0 division and use a uniform image.
            scaled = numpy.zeros(numpy.shape(data))
        image = Image.fromarray(numpy.uint8(scaled*255)).convert(mode)
    # Only a batch axis is added. Reshaping with the PIL ``size`` attribute
    # (width, height) would scramble the pixels of non-square images, since
    # the tensor layout is (channel, height, width).
    image = transforms.ToTensor()(image).unsqueeze(0)
    if squeeze:
        image = torch.squeeze(image)
    if to_numpy:
        image = image.numpy()
    return image
def load_bulk(dname,size,rgb=False,to_numpy=False,labeled=True):
    """
    Load multiple randomly selected images from directory.

    Parameters
    ----------
    dname : :py:class:`str`
      Path to directory where images are saved. Environment variables in the
      path are expanded.
    size : :py:class:`int`
      Number of images to be loaded
    rgb : :py:class:`bool`
      Convert data to RGB data image
    to_numpy : :py:class:`bool`
      Convert data to numpy array
    labeled : :py:class:`bool`
      Whether images are saved in a labeled structure (same structure as for
      :py:class:`torchvision.datasets.ImageFolder` class) or directly in target repository.
      In the labeled case, the label names are the entries of ``<dname>/train/``
      and images are searched in ``<dname>/*/<label>/*.jpg``.

    Returns
    -------
    tensors : :py:class:`torch.Tensor` or :py:class:`numpy.ndarray`
      Output list of loaded images either in tensor or numpy array formats.
    labels : :py:class:`numpy.ndarray`
      Integer label of each image (position of its label in the directory
      listing). All zeros when ``labeled`` is false.

    Raises
    ------
    AssertionError
      If ``labeled`` is set and ``size`` is not a multiple of the number of labels.
    ValueError
      If fewer images are available than requested (raised by ``random.sample``).
    """
    # Use the expanded path everywhere, including in the glob patterns.
    root = os.path.expandvars(dname)
    if labeled:
        all_labels = [label for label in os.listdir(os.path.join(root,'train')) if label!='.DS_Store']
        assert size%len(all_labels)==0, 'Requested bulk size not multiple of number of labels.'
        per_label = size//len(all_labels)
        labels, tensors = [], []
        for i,label in enumerate(all_labels):
            file_list = glob.glob('%s/*/%s/*.jpg'%(root,label))
            for path in random.sample(file_list,per_label):
                tensors.append(load_image(path,rgb=rgb))
            labels += [i]*per_label
        # Each image carries its own batch axis of size 1; drop it after stacking.
        tensors = torch.stack(tensors).squeeze(dim=1)
        if to_numpy:
            tensors = torch.squeeze(tensors).numpy()
        # Shuffle so that the images are not grouped by label.
        idxs = random.sample(range(size), size)
        return tensors[idxs], numpy.array(labels)[idxs]
    file_list = glob.glob('%s/*.jpg'%root)
    tensors = [load_image(path,rgb=rgb) for path in random.sample(file_list,size)]
    tensors = torch.stack(tensors).squeeze(dim=1)
    if to_numpy:
        tensors = torch.squeeze(tensors).numpy()
    return tensors,numpy.array([0]*(size))