from .utils import *
from .evaluation import *
from .visualization import *
from tensorflow.keras.callbacks import Callback
from sklearn.model_selection import train_test_split
from IPython.display import clear_output
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import json
import os
from datetime import datetime
# ------------------- callback functions for tensorflow fit -------------------
class PlotPredictionParamsCallback(Callback):
    """
    Callback that plots the true and predicted beam parameters for a randomly
    chosen validation image at the start of each epoch.

    input: val_images (np.array, (m, n)), val_labels (np.array, (l, 1), normalized),
           val_beam_images (np.array, (p, q))
    """
    def __init__(self, val_images, val_labels, val_beam_images):
        super().__init__()
        self.val_images = val_images
        self.val_labels = val_labels
        self.val_beam_images = val_beam_images

    def on_epoch_begin(self, epoch, logs=None):
        clear_output(wait=True)
        # Randomly select an image from the validation set
        idx = np.random.randint(0, len(self.val_images))
        image = self.val_beam_images[idx]
        true_label = self.val_labels[idx]
        # Predict the output using the current model state
        pred_label = self.model.predict(self.val_images[idx][np.newaxis, :])[0]
        img_2_params_evaluation(image, true_label, pred_label)
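
# Usage sketch (hypothetical variable names; assumes a compiled Keras `model`
# and validation arrays shaped as documented above):
#
#   cb = PlotPredictionParamsCallback(x_val, label_val, beam_val)
#   model.fit(x_train, label_train, validation_data=(x_val, label_val),
#             epochs=20, callbacks=[cb])
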
class PlotPredictionImageCallback(Callback):
    """
    Callback that plots the input speckle pattern, the ground-truth beam image,
    and the reconstructed image for a fixed sample at the end of each epoch.

    input: x_data (np.array, (n, m), speckle pattern), y_data (np.array, (p, q), beam image)
    """
    def __init__(self, x_data, y_data):
        super().__init__()
        self.x_data = x_data
        self.y_data = y_data

    def on_epoch_end(self, epoch, logs=None, title=('MMF Speckle Pattern (Input)',
                                                    'Original Beam Distribution (Ground Truth)',
                                                    'Reconstructed Image (Output)')):
        clear_output(wait=True)
        predictions = self.model.predict(self.x_data[tf.newaxis, ...], verbose=0)
        plt.figure(figsize=(15, 15))
        display_list = [self.x_data.reshape(64, 64), self.y_data.reshape(32, 32),
                        predictions.reshape(32, 32)]
        for i in range(3):  # present the three panels side by side
            plt.subplot(1, 3, i + 1)
            plt.title(title[i])
            # Pixel values are plotted as-is; rescale to [0, 1] first if needed,
            # e.g. plt.imshow(display_list[i] * 0.5 + 0.5)
            plt.imshow(display_list[i], cmap='Greys_r')
            plt.axis('off')
        plt.show()
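
# Usage sketch (hypothetical names; assumes a single 64x64 speckle pattern and
# its matching 32x32 beam image, e.g. one sample taken from the validation set):
#
#   cb = PlotPredictionImageCallback(x_val[0], beam_val[0])
#   model.fit(x_train, y_train, epochs=20, callbacks=[cb])
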
# ------------------- dataset preparation -------------------
def clean_tensor(narray):
    """
    Discard problematic images based on the beam parameter calculation: if any
    normalized parameter falls outside the open interval (0, 1), the sample is
    rejected. TODO: develop a better evaluation function (beam_params) to handle
    this properly.
    """
    labels = list(beam_params(narray, normalize=True).values())
    for i in labels:
        if i >= 1 or i <= 0:
            return None, None
    return narray, labels
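
# Minimal filtering sketch (hypothetical `raw_images` array; keeps only the
# samples that clean_tensor accepts):
#
#   cleaned = [clean_tensor(img) for img in raw_images]
#   pairs = [(img, lbl) for img, lbl in cleaned if img is not None]
#   data = np.array([p[0] for p in pairs])
#   labels = np.array([p[1] for p in pairs])
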
def split_dataset(data, labels, proportion=(8, 1, 1)):
    """
    Split a dataset into train/validation/test subsets (TensorFlow workflows only).
    dimensions: (n, 2, width, height, channel) for data, (n, number of beam parameters) for labels
    """
    total = sum(proportion)
    prop_test = proportion[2] / total
    prop_val = proportion[1] / (total - proportion[2])
    train_val, test, labels_train_val, labels_test = train_test_split(
        data, labels, test_size=prop_test, random_state=42)
    train, val, labels_train, labels_val = train_test_split(
        train_val, labels_train_val, test_size=prop_val, random_state=42)
    print("-" * 50)
    print(f'train set shape: {train.shape}')
    print(f'train label shape: {labels_train.shape}')
    print(f'validation set shape: {val.shape}')
    print(f'validation label shape: {labels_val.shape}')
    print(f'test set shape: {test.shape}')
    print(f'test label shape: {labels_test.shape}')
    print("-" * 50)
    return {'x_train': train, 'label_train': labels_train,
            'x_val': val, 'label_val': labels_val,
            'x_test': test, 'label_test': labels_test}
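
# Usage sketch (hypothetical `data`/`labels` arrays; the default (8, 1, 1)
# proportion yields an 80/10/10 split):
#
#   splits = split_dataset(data, labels)
#   model.fit(splits['x_train'], splits['label_train'],
#             validation_data=(splits['x_val'], splits['label_val']))
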
def seperate_img(data):
    """
    Temporary helper that splits the original beam images and speckle patterns
    for later use in the callback functions. Assumes axis 1 of `data` holds the
    beam image at index 0 and the speckle pattern at index 1.
    """
    new_data = np.transpose(data, (1, 0, 2, 3, 4))
    return new_data[0], new_data[1]  # beam images, speckle patterns
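
# Usage sketch (assumes `data` shaped (n, 2, width, height, channel) as
# documented above; names are hypothetical):
#
#   beam_images, speckle_patterns = seperate_img(splits['x_train'])
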
# ------------------- experiment logs -------------------
class Logger:
    """
    Create a folder and a log file in the specified directory containing a
    snapshot of the experiment details. After training, save the log content
    to the log file under the log directory.
    """
    def __init__(self, log_dir, model=None, dataset=None, history=None, info=''):
        self.log_dir = os.path.join(log_dir, datetime.now().strftime("%Y-%m-%d_" + info))
        self.model = model
        self.dataset = dataset
        self.history = history
        self.log_file = os.path.join(self.log_dir, 'log.json')
        self.log_content = {'info': info,
                            'experiment_date': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                            'dataset_info': None,
                            'model_info': None,
                            'training_info': None}
        self.update()

    def update(self):
        if self.dataset is not None:
            self.register_dataset()
        if self.model is not None:
            self.register_model()
        if self.history is not None:
            self.register_training()

    def register_dataset(self):
        if isinstance(self.dataset, np.ndarray):
            self.log_content['dataset_info'] = {'dataset_shape': str(self.dataset.shape),
                                                'dataset_dtype': str(self.dataset.dtype),
                                                'dataset_mean': str(np.mean(self.dataset)),
                                                'dataset_std': str(np.std(self.dataset)),
                                                'dataset_min': str(np.min(self.dataset)),
                                                'dataset_max': str(np.max(self.dataset))}

    def register_model(self):
        if isinstance(self.model, tf.keras.models.Model):
            self.log_content['model_info'] = self.tf_model_summary()

    def register_training(self):
        os_info = get_system_info()
        if isinstance(self.model, tf.keras.models.Model):
            compiled_info = {
                'loss': str(self.model.loss),
                'optimizer': type(self.model.optimizer).__name__,
                'optimizer_config': {k: str(v) for k, v in self.model.optimizer.get_config().items()},
                'metrics': [m.name for m in self.model.metrics],
                'tensorflow_version': tf.__version__
            }
            self.log_content['training_info'] = {'os_info': os_info,
                                                 'compiled_info': compiled_info,
                                                 'epoch': len(self.history.epoch),
                                                 'training_history': self.history.history}

    def tf_model_summary(self):
        # Capture model.summary() output line by line instead of printing it
        summary = []
        self.model.summary(print_fn=lambda x: summary.append(x))
        return summary

    def log_parse(self):
        pass

    def save(self):
        os.makedirs(self.log_dir, exist_ok=True)
        with open(self.log_file, 'w') as f:
            json.dump(self.log_content, f, indent=4)
        return self.log_file
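
# Usage sketch (hypothetical names; assumes a trained Keras `model`, the
# `history` object returned by model.fit, and a dataset array):
#
#   logger = Logger('logs', model=model, dataset=data, history=history,
#                   info='baseline_run')
#   logger.save()  # writes logs/<date>_baseline_run/log.json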