CLAY v0 - Interpolation between images
import sys
sys.path.append("../")
import os
from pathlib import Path
import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
from einops import rearrange
from PIL import Image
from src.datamodule import ClayDataModule, ClayDataset
from src.model_clay import CLAYModule
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
Cell In[2], line 4
1 import os
2 from pathlib import Path
----> 4 import imageio
5 import matplotlib.pyplot as plt
6 import numpy as np
ModuleNotFoundError: No module named 'imageio'
# Root folder that holds all chips used in this walkthrough.
DATA_DIR = "../data/02"

# Location of the best Clay v0 checkpoint; the URL is split across two
# literals only to keep the line length readable.
CKPT_PATH = (
    "https://huggingface.co/made-with-clay/Clay/resolve/main/"
    "Clay_v0.1_epoch-24_val-loss-0.46.ckpt"
)
Load Model & DataModule
# Restore the model from the checkpoint with masking and patch shuffling
# switched off, so every patch is passed through unmasked and in order.
model = CLAYModule.load_from_checkpoint(
    CKPT_PATH,
    mask_ratio=0.0,
    shuffle=False,
)
# Switch to evaluation mode; the result is bound to a throwaway name so the
# notebook does not print the module repr.
_ = model.eval()
data_dir = Path(DATA_DIR)

# Build a dataset from every .tif file found anywhere below the data directory.
chip_paths = list(data_dir.glob("**/*.tif"))
ds = ClayDataset(chips_path=chip_paths)

# The datamodule is pointed at the same directory and yields two chips per batch.
dm = ClayDataModule(data_dir=str(data_dir), batch_size=2)
dm.setup(stage="fit")

# Take the first batch from the training dataloader.
trn_dl = iter(dm.train_dataloader())
batch = next(trn_dl)

# Inspect what the batch contains and the shape of each input.
batch.keys()
tuple(batch[key].shape for key in ("pixels", "latlon", "timestep"))
def show(sample, idx=None, save=False):
    """Plot one chip as a min-max scaled three-band composite.

    Each call draws on its own figure, so consecutive calls in the same cell
    do not paint over one another.

    Args:
        sample: Tensor laid out as ``(channels, height, width)``. It is
            de-normalized with the datamodule's ``STD`` and ``MEAN``.
        idx: Sequence number used in the output file name when saving.
        save: When true, write the figure to ``animate/chip_{idx}.png`` and
            close it; otherwise the figure is left open for display.
    """
    # Channels-last layout so the per-band statistics broadcast over the last axis.
    sample = rearrange(sample, "c h w -> h w c")
    denorm_sample = sample * torch.as_tensor(dm.STD) + torch.as_tensor(dm.MEAN)
    # Bands 2, 1, 0 in that order — presumably red, green, blue; confirm
    # against the datamodule's band ordering.
    rgb = denorm_sample[..., [2, 1, 0]]

    # Min-max scale for display. A constant image has zero range, which
    # would otherwise divide by zero and produce NaNs.
    lo = rgb.min()
    span = rgb.max() - lo
    scaled = (rgb - lo) / span if span > 0 else torch.zeros_like(rgb)

    fig, ax = plt.subplots()
    ax.imshow(scaled)
    ax.axis("off")

    if save:
        # Only create the output folder when a file is actually written.
        out_dir = Path("animate")
        out_dir.mkdir(exist_ok=True)
        fig.savefig(out_dir / f"chip_{idx}.png")
        # Saved frames are read back from disk later; release the figure so
        # a long sequence does not accumulate open figures.
        plt.close(fig)
# The batch holds exactly two chips; unpack them and plot each one.
sample1, sample2 = batch["pixels"]
for chip in (sample1, sample2):
    show(chip)
Each batch has chips of shape `13 x 512 x 512`, normalized `lat` & `lon` coords, and normalized timestep information as `year`, `month` & `day`.
# Keep an independent NumPy copy of the pixel tensor for later
# visualisation, before the batch is moved to the model's device.
_batch = (
    batch["pixels"]
    .detach()
    .clone()
    .cpu()
    .numpy()
)
Pass data through the CLAY model
# Run the batch through the CLAY encoder without tracking gradients.
with torch.no_grad():
    # Every input the encoder reads must live on the same device as the model.
    for key in ("pixels", "timestep", "latlon"):
        batch[key] = batch[key].to(model.device)

    # The encoder takes pixels, latlon and timestep and returns the encoded
    # patches together with the index and mask bookkeeping.
    encoded = model.model.encoder(batch)

unmasked_patches, unmasked_indices, masked_indices, masked_matrix = encoded
Create an image based on interpolation of the embedding values between 2 images
Images are saved inside ./animate
# Encoded patches of the two chips in the batch; unpacked once because they
# do not change between interpolation steps.
l1, l2 = unmasked_patches

for idx, alpha in enumerate(np.linspace(0, 1, 20)):
    # Linear blend of the two embeddings: alpha = 0 gives the second chip's
    # embedding, alpha = 1 gives the first chip's.
    l3 = alpha * l1 + (1 - alpha) * l2

    # Decode the blended embedding back to pixel space. It is decoded as a
    # batch of one, reusing the first chip's patch indices.
    with torch.no_grad():
        pixels = model.model.decoder(
            rearrange(l3, "gl d -> 1 gl d"), unmasked_indices[[0]], masked_indices[[0]]
        )

    # Stitch the per-patch output back into full images: a 16-row patch grid
    # with 32-pixel-high patches (presumably 512 x 512 chips — confirm).
    image = rearrange(pixels, "b c (h w) (p1 p2) -> b c (h p1) (w p2)", h=16, p1=32)
    _image = image[0].detach().cpu()
    show(_image, idx, save=True)
# Lay the 20 saved frames out on a 2 x 10 grid, in sequence order.
fig, axs = plt.subplots(2, 10, figsize=(20, 4))
for idx, ax in enumerate(axs.ravel()):
    frame = Image.open(f"./animate/chip_{idx}.png")
    ax.imshow(frame)
    ax.set_title(f"Seq {idx}")
    ax.set_axis_off()
plt.tight_layout()
Create a GIF of the interpolation of images
# Paths of the 20 frames written by the interpolation loop, in sequence order.
img_paths = [f"./animate/chip_{idx}.png" for idx in range(20)]

# Append the frames one by one to an animated GIF.
# NOTE(review): the unit of `duration` differs between imageio releases —
# confirm the intended frame time against the installed version.
with imageio.get_writer("animate/sample.gif", mode="I", duration=100) as writer:
    for frame_path in img_paths:
        writer.append_data(imageio.imread(frame_path))

# The PNG frames were only intermediates for the GIF; remove them.
for frame_path in img_paths:
    os.remove(frame_path)
# Import IPython's Image under an alias: the file already imports PIL's
# `Image`, and rebinding that name here would break any later `Image.open`
# call (for example when re-running the frame-grid cell).
from IPython.display import Image as IPyImage, display

# Render the generated GIF inline.
display(IPyImage(filename="./animate/sample.gif"))