# Source code for optibeam.visualization

from .utils import *
from io import BytesIO
from sklearn.decomposition import PCA
from moviepy.editor import ImageSequenceClip
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import matplotlib.patches
import seaborn as sns


# ------------------- Plot evaluation -------------------

def plot_prediction_comparison(real: np.array, predicted: np.array, param_name=''):
    """
    Show three side-by-side diagnostics comparing predictions with ground truth.

    Panels, from left to right:
      1. Scatter of predicted vs. actual values with the identity line.
      2. Residuals (``predicted - real``) plotted against the predicted values.
      3. Histogram with KDE of the residuals multiplied by 100.

    Parameters:
        real (array-like): Actual values. Converted with ``np.asarray``, so
            plain lists are accepted as well as arrays.
        predicted (array-like): Predicted values, broadcast-compatible with
            ``real``.
        param_name (str): Name of the quantity, used in titles and axis labels.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Coerce up front: .min()/.max() and the subtraction below need ndarrays.
    real = np.asarray(real)
    predicted = np.asarray(predicted)
    residuals = predicted - real

    plt.figure(figsize=(14, 5))

    # Panel 1: predicted vs. actual, with the y = x reference line.
    plt.subplot(1, 3, 1)
    plt.scatter(real, predicted, alpha=0.6, s=3)
    lowest, highest = real.min(), real.max()
    plt.plot([lowest, highest], [lowest, highest], 'k--', lw=2)
    plt.title(f'Predicted vs. Actual {param_name}')
    plt.xlabel(f'Actual {param_name}')
    plt.ylabel(f'Predicted {param_name}')
    plt.grid(True)

    # Panel 2: residuals against predictions, with a zero reference line.
    plt.subplot(1, 3, 2)
    plt.scatter(predicted, residuals, alpha=0.6, s=3)
    plt.title(f'{param_name} Residual Plot')
    plt.xlabel(f'Predicted {param_name}')
    plt.ylabel('Residuals')
    plt.axhline(y=0, color='k', linestyle='--')
    plt.grid(True)

    # Panel 3: distribution of residuals scaled by 100.
    # NOTE(review): the "%" label is only meaningful if the values are
    # normalised to a unit range -- confirm against callers.
    plt.subplot(1, 3, 3)
    # Handed over as a list of per-element values, each scaled by 100.
    sns.histplot(list(residuals * 100), kde=True)
    plt.title('Histogram of Percentage Prediction Errors')
    plt.xlabel('Prediction Error (%)')
    plt.ylabel('Frequency')
    plt.grid(True)

    plt.tight_layout()
    plt.show()
# ------------------- PCA -------------------
class visualPCA:
    """
    Fit a PCA projection and visualise its leading principal components.

    Call ``fit`` first; the plotting methods then read the projected data
    stored on the instance.
    """

    def __init__(self, n_components=3):
        """
        Parameters:
            n_components: Forwarded unchanged to ``PCA(n_components=...)``.
                The 3D plots and the GIF need at least three components.
        """
        # Underlying PCA estimator.
        self.pca = PCA(n_components=n_components)
        # Projected data produced by fit(); None until fit() has been called.
        self.pc = None

    def _require_components(self, needed):
        """Raise unless fit() has run and produced at least ``needed`` columns."""
        if self.pc is None:
            raise RuntimeError('visualPCA is not fitted yet; call fit() first.')
        available = self.pc.shape[1]
        if available < needed:
            # IndexError is what indexing the missing column would raise anyway;
            # raising it here just makes the cause explicit.
            raise IndexError(
                f'{needed} principal components are required, '
                f'but only {available} are available.'
            )

    def fit(self, data: np.array):
        """
        Fit the PCA on ``data`` and store the projection in ``self.pc``.

        Parameters:
            data (np.ndarray): Array of flattened images, one sample per row.
        """
        self.pc = self.pca.fit_transform(data)

    def plot_2d(self):
        """
        Scatter-plot the first two principal components with matplotlib.

        Raises:
            RuntimeError: If ``fit`` has not been called.
            IndexError: If fewer than two components are available.
        """
        self._require_components(2)
        plt.scatter(self.pc[:, 0], self.pc[:, 1], s=2)
        plt.xlabel('PC 1')
        plt.ylabel('PC 2')
        plt.show()

    def plot_3d(self):
        """
        Show an interactive Plotly 3D scatter of the first three components.

        Points are coloured by their third-component value.

        Raises:
            RuntimeError: If ``fit`` has not been called.
            IndexError: If fewer than three components are available.
        """
        self._require_components(3)
        fig = go.Figure(data=[go.Scatter3d(
            x=self.pc[:, 0],
            y=self.pc[:, 1],
            z=self.pc[:, 2],
            mode='markers',
            marker=dict(
                size=2,
                color=self.pc[:, 2],  # colour encodes the third component
                colorscale='Viridis',
                opacity=0.8,
            ),
        )])
        fig.update_layout(
            margin=dict(l=0, r=0, b=0, t=0),
            scene=dict(
                xaxis_title='PC 1',
                yaxis_title='PC 2',
                zaxis_title='PC 3',
            ),
        )
        fig.show()

    def plot_to_memory(self, angle: float) -> BytesIO:
        """
        Render the 3D scatter at one azimuth angle into an in-memory PNG.

        Parameters:
            angle (float): Azimuth passed to ``ax.view_init``; the elevation
                is fixed at 20.

        Returns:
            BytesIO: Buffer holding the PNG, positioned at its start.

        Raises:
            RuntimeError: If ``fit`` has not been called.
            IndexError: If fewer than three components are available.
        """
        self._require_components(3)
        fig = plt.figure(figsize=(10, 7))
        try:
            ax = fig.add_subplot(111, projection='3d')
            ax.scatter(self.pc[:, 0], self.pc[:, 1], self.pc[:, 2],
                       c=self.pc[:, 2], cmap='viridis', s=2)
            ax.set_xlabel('PC 1')
            ax.set_ylabel('PC 2')
            ax.set_zlabel('PC 3')
            ax.view_init(elev=20., azim=angle)
            buf = BytesIO()
            # Save this figure explicitly instead of the implicit current one.
            fig.savefig(buf, format='png', dpi=150)
        finally:
            # Always release the figure, even when rendering fails.
            plt.close(fig)
        buf.seek(0)
        return buf

    def create_gif(self, save_to: str, start_angle=0, end_angle=89, nums=60,
                   fps=30, reverse=True, filename='sample.gif'):
        """
        Write an animated GIF of the 3D scatter rotating about the vertical axis.

        Parameters:
            save_to (str): Directory in which the GIF is written.
            start_angle: First azimuth angle.
            end_angle: Last azimuth angle (inclusive).
            nums (int): Number of frames rendered between the two angles.
            fps (int): Frames per second of the resulting clip.
            reverse (bool): If True, append the frames in reverse order so the
                animation sweeps forth and back.
            filename (str): Name of the output file inside ``save_to``.

        Raises:
            RuntimeError: If ``fit`` has not been called.
            IndexError: If fewer than three components are available.
        """
        import os  # local import: only needed for building the output path

        frames = []
        for angle in np.linspace(start_angle, end_angle, nums):
            buffer = self.plot_to_memory(angle)
            try:
                # Decode to an array right away so image and buffer can be
                # released before the next frame is rendered.
                with Image.open(buffer) as picture:
                    frames.append(np.array(picture))
            finally:
                buffer.close()

        if reverse:
            frames = frames + frames[::-1]

        clip = ImageSequenceClip(frames, fps=fps)
        clip.write_gif(os.path.join(save_to, filename))
# ------------------- plot image -------------------
def plot_narray(narray_img, channel=1):
    """
    Display a NumPy array as an image with matplotlib.

    Arrays whose maximum value is at most 1 are first multiplied by 255 and
    cast to ``uint8``. A 2D array is drawn with a colour bar, a title and
    axis labels; an array of any other rank is drawn with the axes hidden.

    Parameters:
        narray_img (np.ndarray): Image data to plot.
        channel (int): Only used for 2D input. ``1`` selects the 'gray'
            colormap; any other value uses matplotlib's default colormap.
    """
    peak = np.max(narray_img)
    if peak <= 1:
        # Data with a maximum of at most 1 is rescaled to the 0-255 byte range.
        narray_img = (narray_img * 255).astype(np.uint8)

    # Anything that is not a plain 2D array: show it bare and stop.
    if len(narray_img.shape) != 2:
        plt.imshow(narray_img)
        plt.axis('off')
        plt.show()
        return

    # cmap=None is matplotlib's default, i.e. the same as omitting the argument.
    colormap = 'gray' if channel == 1 else None
    plt.imshow(narray_img, cmap=colormap)
    plt.colorbar()  # intensity scale
    plt.title('2D Array Image')
    plt.xlabel('X-axis')
    plt.ylabel('Y-axis')
    plt.show()
def img_2_params_evaluation(image, true_label, pred_label):
    """
    Overlay true and predicted centroid/width parameters on an image.

    Each label is indexed as ``[x, y, x_width, y_width]``. All four entries
    are fractions of the image size: x-values are multiplied by the image
    width and y-values by the image height. The width entries are doubled
    when drawn, so they act as the semi-axes of the ellipse.

    Parameters:
        image (np.ndarray): Image to display. Singleton axes are removed with
            ``squeeze()`` before plotting and before reading its size.
        true_label (sequence): Ground-truth parameters, drawn in blue.
        pred_label (sequence): Predicted parameters, drawn in dark red.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Take the size from the squeezed image that is actually displayed, so a
    # leading singleton axis (e.g. shape (1, H, W)) does not skew the scaling.
    frame = image.squeeze()
    height, width = frame.shape[:2]

    fig, ax = plt.subplots()
    ax.imshow(frame, cmap='gray')

    # (label, marker, colour, legend prefix) for each overlay.
    overlays = (
        (true_label, 'o', 'blue', 'True'),
        (pred_label, '^', 'darkred', 'Predicted'),
    )

    # Centroids first, then ellipses, so the legend lists both centroids
    # before both width ellipses.
    for label, marker, colour, prefix in overlays:
        ax.plot(label[0] * width, label[1] * height, marker,
                markersize=3, markeredgecolor=colour, markerfacecolor='none',
                label=f'{prefix} Centroid')

    for label, _marker, colour, prefix in overlays:
        ellipse = matplotlib.patches.Ellipse(
            (label[0] * width, label[1] * height),
            width=label[2] * width * 2,
            height=label[3] * height * 2,
            edgecolor=colour, facecolor='none',
            linewidth=1, linestyle='--',
            label=f'{prefix} Widths',
        )
        ax.add_patch(ellipse)

    ax.set_xlabel('Normalized Horizontal Position')
    ax.set_ylabel('Normalized Vertical Position')

    # Eleven ticks give exact 0.0, 0.1, ..., 1.0 positions, so the one-decimal
    # labels match where the ticks actually sit.
    num_ticks = 11
    tick_values = np.linspace(0, 1, num_ticks)
    tick_labels = [f"{value:.1f}" for value in tick_values]
    ax.set_xticks(tick_values * width)
    ax.set_xticklabels(tick_labels)
    ax.set_yticks(tick_values * height)
    ax.set_yticklabels(tick_labels)

    ax.legend()
    plt.show()