Source code for optibeam.training

from .utils import *
from .evaluation import *
from .visualization import *

from tensorflow.keras.callbacks import Callback
from sklearn.model_selection import train_test_split
from IPython.display import clear_output
import matplotlib.pyplot as plt
import tensorflow as tf
import json
from datetime import datetime

# ------------------- callback functions for tensorflow fit -------------------
class PlotPredictionParamsCallback(Callback):
    """
    Keras callback that visualises true vs. predicted beam parameters.

    At the start of every epoch one validation sample is drawn at random,
    the current model predicts its label, and the beam image together with
    the true and predicted labels is handed to ``img_2_params_evaluation``.

    Args:
        val_images: model inputs of the validation set
            (original notes: np.array of shape (m, n) per sample).
        val_labels: ground-truth labels aligned with ``val_images``
            (original notes: np.array of shape (l, 1), normalized).
        val_beam_images: beam images aligned with ``val_images``
            (original notes: np.array of shape (p, q) per sample).
    """

    def __init__(self, val_images, val_labels, val_beam_images):
        super().__init__()
        self.val_beam_images = val_beam_images
        self.val_labels = val_labels
        self.val_images = val_images

    def on_epoch_begin(self, epoch, logs=None):
        """Plot the prediction for one randomly chosen validation sample."""
        # Wipe the previous epoch's output so the notebook cell shows one plot.
        clear_output(wait=True)

        # Random sample index; the same index is used for all three arrays.
        sample = np.random.randint(0, len(self.val_images))
        beam_image = self.val_beam_images[sample]
        expected = self.val_labels[sample]

        # Add a leading batch axis, predict, then strip the batch axis again.
        batch = self.val_images[sample][np.newaxis, :]
        predicted = self.model.predict(batch)[0]

        img_2_params_evaluation(beam_image, expected, predicted)
class PlotPredictionImageCallback(Callback):
    """
    Keras callback that plots input, ground truth and reconstruction.

    After every epoch the current model is run on a single fixed sample and
    three panels are drawn side by side: the input, the target image and the
    model output.

    Args:
        x_data: one model input (original notes: speckle pattern, n x m).
            It is reshaped to 64 x 64 for display.
        y_data: the matching target (original notes: beam image, p x q).
            It is reshaped to 32 x 32 for display.

    NOTE(review): the 64 x 64 / 32 x 32 display shapes are hard-coded, so the
    callback only works for samples with exactly that many elements.
    """

    # Default panel titles, kept immutable so they cannot be altered through
    # a shared default argument.
    DEFAULT_TITLES = (
        'MMF Speckle Pattern (Input)',
        'Original Beam Distribution (Ground Truth)',
        'Reconstructed Image (Output)',
    )

    def __init__(self, x_data, y_data):
        super().__init__()
        self.x_data = x_data
        self.y_data = y_data

    def on_epoch_end(self, epoch, logs=None, title=None):
        """
        Draw the three-panel comparison for the stored sample.

        Args:
            epoch: epoch index supplied by Keras (unused).
            logs: metrics dict supplied by Keras (unused).
            title: sequence of three panel titles; defaults to
                ``DEFAULT_TITLES`` when omitted or ``None``.
        """
        if title is None:
            title = self.DEFAULT_TITLES

        # Replace the previous epoch's figure instead of stacking outputs.
        clear_output(wait=True)

        # The model expects a batch axis in front of the single sample.
        predictions = self.model.predict(self.x_data[tf.newaxis, ...], verbose=0)

        plt.figure(figsize=(15, 15))
        panels = (
            self.x_data.reshape(64, 64),
            self.y_data.reshape(32, 32),
            predictions.reshape(32, 32),
        )
        for i, panel in enumerate(panels):
            plt.subplot(1, 3, i + 1)
            plt.title(title[i])
            # Pixel values are plotted as-is; rescale beforehand if the data
            # is not already in a displayable range.
            plt.imshow(panel, cmap='Greys_r')
            plt.axis('off')
        plt.show()
# ------------------- dataset preparation -------------------
def clean_tensor(narray):
    """
    Discard problematic images based on their computed beam parameters.

    The normalized beam parameters of ``narray`` are computed with
    ``beam_params(narray, normalize=True)``. The image is kept only if every
    parameter lies strictly between 0 and 1. NaN parameters fail that check
    as well, so images whose parameters could not be computed are rejected
    instead of slipping through.

    In future, a better evaluation function (``beam_params``) may be needed
    to handle this properly.

    Args:
        narray: image passed straight to ``beam_params``.

    Returns:
        ``(narray, labels)`` where ``labels`` is the list of parameter values
        when the image is accepted, otherwise ``(None, None)``.
    """
    labels = list(beam_params(narray, normalize=True).values())
    # `0 < v < 1` is False for NaN, unlike the negated form `v >= 1 or v <= 0`.
    if all(0 < value < 1 for value in labels):
        return narray, labels
    return None, None
def split_dataset(data, labels, proportion=(8, 1, 1), random_state=42):
    """
    Split a dataset and its labels into train, validation and test sets.

    The split is done in two ``train_test_split`` passes: first the test set
    is carved off, then the remainder is divided into train and validation.

    Original notes on expected dimensions (TensorFlow layout):
    data (n, 2, width, height, channel), labels (n, number of beam parameters).

    Args:
        data: samples; must expose ``.shape`` after splitting.
        labels: labels aligned with ``data``; must expose ``.shape`` as well.
        proportion: three positive weights ``(train, val, test)``; only their
            ratio matters.
        random_state: seed forwarded to both ``train_test_split`` calls.
            Defaults to 42, the value that used to be hard-coded.

    Returns:
        dict with keys ``x_train``, ``label_train``, ``x_val``, ``label_val``,
        ``x_test`` and ``label_test``.

    Raises:
        ValueError: if ``proportion`` does not hold exactly three positive
            weights.
    """
    if len(proportion) != 3 or any(weight <= 0 for weight in proportion):
        raise ValueError(
            "proportion must contain exactly three positive weights "
            f"(train, val, test), got {proportion!r}"
        )

    _, val_weight, test_weight = proportion
    total = sum(proportion)
    test_fraction = test_weight / total
    # The validation fraction is relative to what is left after the test
    # set has been removed.
    val_fraction = val_weight / (total - test_weight)

    train_val, test, labels_train_val, labels_test = train_test_split(
        data, labels, test_size=test_fraction, random_state=random_state)
    train, val, labels_train, labels_val = train_test_split(
        train_val, labels_train_val, test_size=val_fraction,
        random_state=random_state)

    print("-"*50)
    print(f'train set shape: {train.shape}')
    print(f'train label shape: {labels_train.shape}')
    print(f'validation set shape: {val.shape}')
    print(f'validation label shape: {labels_val.shape}')
    print(f'test set shape: {test.shape}')
    print(f'test label shape: {labels_test.shape}')
    print("-"*50)

    return {
        'x_train': train,
        'label_train': labels_train,
        'x_val': val,
        'label_val': labels_val,
        'x_test': test,
        'label_test': labels_test,
    }
def seperate_img(data):
    """
    Split paired data into its two image stacks.

    Temporary helper used to separate the original beam images from the
    speckle patterns for later use in the callbacks. ``data`` is assumed to
    hold both, paired along axis 1 of a 5-D array.

    Args:
        data: 5-D array-like whose axis 1 indexes the pair members.

    Returns:
        Tuple ``(beam_images, speckle_patterns)``: the entries at index 0 and
        index 1 of axis 1, respectively.
    """
    # Move the pair axis to the front so each member is a plain leading index.
    pair_first = np.transpose(data, axes=(1, 0, 2, 3, 4))
    beam_images = pair_first[0]
    speckle_patterns = pair_first[1]
    return beam_images, speckle_patterns
# ------------------- experiment logs -------------------
class Logger:
    """
    Collect a snapshot of an experiment and write it to ``log.json``.

    The log lives in a sub-directory of ``log_dir`` named
    ``<YYYY-MM-DD>_<info>``. Dataset, model and training details are gathered
    into ``log_content`` when the corresponding object is available; call
    ``save`` (typically after training) to create the directory and write the
    file.

    Args:
        log_dir: parent directory for the experiment folder.
        model: optional model; summarised if it is a ``tf.keras`` Model.
        dataset: optional dataset; summarised if it is a numpy array.
        history: optional training history exposing ``.epoch`` and
            ``.history``.
        info: free-text tag used in the folder name and stored in the log.
    """

    def __init__(self, log_dir, model=None, dataset=None, history=None, info=''):
        now = datetime.now()
        # `info` is appended AFTER formatting so that a '%' inside it is not
        # interpreted as a strftime directive.
        self.log_dir = os.path.join(log_dir, now.strftime("%Y-%m-%d_") + info)
        self.model = model
        self.dataset = dataset
        self.history = history
        self.log_file = os.path.join(self.log_dir, 'log.json')
        self.log_content = {
            'info': info,
            'experiment_date': now.strftime("%Y-%m-%d %H:%M:%S"),
            'dataset_info': None,
            'model_info': None,
            'training_info': None,
        }
        self.update()

    def update(self):
        """Refresh every log section whose source object has been set."""
        if self.dataset is not None:
            self.register_dataset()
        if self.model is not None:
            self.register_model()
        if self.history is not None:
            self.register_training()

    def register_extra(self, extra_info):
        """Store arbitrary additional information under ``extra_info``."""
        self.log_content['extra_info'] = extra_info

    def register_dataset(self):
        """Record shape, dtype and basic statistics of a numpy dataset."""
        if isinstance(self.dataset, np.ndarray):
            # Everything is stringified so the entry is always serialisable.
            self.log_content['dataset_info'] = {
                'dataset_shape': str(self.dataset.shape),
                'dataset_dtype': str(self.dataset.dtype),
                'dataset_mean': str(np.mean(self.dataset)),
                'dataset_std': str(np.std(self.dataset)),
                'dataset_min': str(np.min(self.dataset)),
                'dataset_max': str(np.max(self.dataset)),
            }

    def register_model(self):
        """Record the textual summary of a Keras model."""
        if isinstance(self.model, tf.keras.models.Model):
            self.log_content['model_info'] = self.tf_model_summary()

    def register_training(self):
        """
        Record system info, compile settings and the training history.

        Compile settings are only included when the model is a Keras Model;
        the TensorFlow version is always recorded. Requires ``self.history``
        to be set.
        """
        os_info = get_system_info()
        # Start empty so a missing or non-Keras model no longer causes a
        # NameError further down.
        compiled_info = {}
        if isinstance(self.model, tf.keras.models.Model):
            compiled_info.update({
                'loss': self.model.loss,
                'optimizer': type(self.model.optimizer).__name__,
                'optimizer_config': {
                    k: str(v)
                    for k, v in self.model.optimizer.get_config().items()
                },
                'metrics': [m.name for m in self.model.metrics],
            })
        compiled_info['tensorflow_version'] = tf.__version__
        self.log_content['training_info'] = {
            'os_info': os_info,
            'compiled_info': compiled_info,
            'epoch': len(self.history.epoch),
            'training_history': self.history.history,
        }

    def tf_model_summary(self):
        """Return the lines of ``model.summary()`` as a list of strings."""
        summary = []
        self.model.summary(print_fn=summary.append)
        return summary

    def log_parse(self):
        """Placeholder; not implemented yet."""
        pass

    @staticmethod
    def _json_default(value):
        """Convert objects that ``json`` cannot serialise on its own."""
        # numpy scalars and arrays expose tolist(), which yields plain
        # Python numbers / lists.
        if hasattr(value, 'tolist'):
            return value.tolist()
        # Anything else (e.g. a loss function object) is logged as text.
        return str(value)

    def save(self):
        """
        Write ``log_content`` to the log file, creating the directory first.

        Returns:
            Path of the written log file.
        """
        os.makedirs(self.log_dir, exist_ok=True)
        with open(self.log_file, 'w') as f:
            json.dump(self.log_content, f, indent=4, default=self._json_default)
        return self.log_file