Let us visualize how Clay reconstructs each sensor from 70% masked inputs

import sys
import warnings

sys.path.append("..")
warnings.filterwarnings(action="ignore")
import matplotlib.pyplot as plt
import numpy as np
import torch
from einops import rearrange

from src.datamodule import ClayDataModule
from src.model import ClayMAEModule

DATA_DIR = "/home/ubuntu/data"
CHECKPOINT_PATH = "../checkpoints/v0.5.7/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt"
METADATA_PATH = "../configs/metadata.yaml"
CHIP_SIZE = 224
MASK_RATIO = 0.7  # 70% masking of inputs
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

MODEL

Load the model from the best checkpoint and set it to eval mode, with mask_ratio set to 0.7 (70% masking) and shuffle set to True.

module = ClayMAEModule.load_from_checkpoint(
    checkpoint_path=CHECKPOINT_PATH,
    metadata_path=METADATA_PATH,
    mask_ratio=MASK_RATIO,
    shuffle=True,
)

module.eval();

DATAMODULE

Load the ClayDataModule

# For model training, we stack chips from one sensor into batches of size 128
# and store each stack as a single file. This reduces the number of workers
# needed to load batches and speeds up training. Here, although batch_size is
# 1, each item the data module yields is a pre-stacked batch of 128 chips.
dm = ClayDataModule(
    data_dir=DATA_DIR,
    metadata_path=METADATA_PATH,
    size=CHIP_SIZE,
    batch_size=1,
    num_workers=1,
)
dm.setup(stage="fit")
Total number of chips: 193

Let us look at the data directory.

We have a folder for each sensor, i.e.:

  • Landsat L1

  • Landsat L2

  • Sentinel-1 RTC

  • Sentinel-2 L2A

  • NAIP

  • LINZ

And, under each folder, we have stacks of chips as .npz files.

!tree -L 1 {DATA_DIR}
/home/ubuntu/data
├── landsat-c2l1
├── landsat-c2l2-sr
├── linz
├── naip
├── sentinel-1-rtc
└── sentinel-2-l2a

6 directories, 0 files
!tree -L 2 {DATA_DIR}/naip | head -5
/home/ubuntu/data/naip
├── cube_10.npz
├── cube_100045.npz
├── cube_100046.npz
├── cube_100072.npz

Now, let's look at what we have in each of the .npz files.

sample = np.load("/home/ubuntu/data/naip/cube_10.npz")
sample.keys()
KeysView(NpzFile '/home/ubuntu/data/naip/cube_10.npz' with keys: pixels, lon_norm, lat_norm, week_norm, hour_norm)
sample["pixels"].shape
(128, 4, 256, 256)
sample["lat_norm"].shape, sample["lon_norm"].shape
((128, 2), (128, 2))
sample["week_norm"].shape, sample["hour_norm"].shape
((128, 2), (128, 2))

As we see above, chips are stacked in batches of size 128.
The sample we are looking at is from NAIP, so it has 4 bands and chips of size 256 x 256.
We also get normalized lat/lon & timestep (week/hour) information, which is optional for the model. If you don't have this handy, feel free to pass zero tensors in its place.
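
For instance, a minimal sketch of such placeholders (B is a hypothetical batch size; the (B, 4) shapes match the stacked time & latlon inputs described in the INPUT section below):

import torch

B = 128  # hypothetical batch size
# week_norm/hour_norm and lat_norm/lon_norm are each (B, 2); the model consumes
# each pair stacked into a (B, 4) tensor. Zeros act as neutral placeholders.
zero_time = torch.zeros(B, 4)
zero_latlon = torch.zeros(B, 4)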

Load a batch of data from the ClayDataModule

ClayDataModule is designed to fetch batches from the different sensors sequentially, i.e. batches arrive in ascending order of their directory names - landsat-c2l1, landsat-c2l2-sr, linz, naip, sentinel-1-rtc, sentinel-2-l2a - and the cycle then repeats.

# We have a random sample subset of the data, so it's
# okay to use either the train or val dataloader.
dl = iter(dm.train_dataloader())
l1 = next(dl)
l2 = next(dl)
linz = next(dl)
naip = next(dl)
s1 = next(dl)
s2 = next(dl)
for sensor, chips in zip(
    ("l1", "l2", "linz", "naip", "s1", "s2"), (l1, l2, linz, naip, s1, s2)
):
    print(
        f"{chips['platform'][0]:<15}",
        chips["pixels"].shape,
        chips["time"].shape,
        chips["latlon"].shape,
    )
landsat-c2l1    torch.Size([128, 6, 224, 224]) torch.Size([128, 4]) torch.Size([128, 4])
landsat-c2l2-sr torch.Size([128, 6, 224, 224]) torch.Size([128, 4]) torch.Size([128, 4])
linz            torch.Size([128, 3, 224, 224]) torch.Size([128, 4]) torch.Size([128, 4])
naip            torch.Size([128, 4, 224, 224]) torch.Size([128, 4]) torch.Size([128, 4])
sentinel-1-rtc  torch.Size([128, 2, 224, 224]) torch.Size([128, 4]) torch.Size([128, 4])
sentinel-2-l2a  torch.Size([128, 10, 224, 224]) torch.Size([128, 4]) torch.Size([128, 4])

INPUT

The model expects a dictionary with the following keys:

  • pixels: batch x band x height x width - normalized chips of a sensor

  • time: batch x 4 - horizontally stacked week_norm & hour_norm

  • latlon: batch x 4 - horizontally stacked lat_norm & lon_norm

  • waves: list of length band - central wavelength of each band of the sensor, from the metadata.yaml file

  • gsd: scalar - ground sample distance (GSD) of the sensor, from the metadata.yaml file

Normalization & stacking are taken care of by the ClayDataModule: Clay-foundation/model

When not using the ClayDataModule, make sure you normalize the chips and pass all of the items above to the model.
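
For example, a minimal sketch of manual normalization, assuming raw_chips is a (batch, band, height, width) tensor and the per-band means & stds come from metadata.yaml:

import torch

def normalize_chips(raw_chips, means, stds):
    # Reshape per-band statistics to broadcast over (batch, band, height, width).
    means = torch.tensor(means).reshape(1, -1, 1, 1)
    stds = torch.tensor(stds).reshape(1, -1, 1, 1)
    return (raw_chips.float() - means) / stds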

def create_batch(chips, wavelengths, gsd, device):
    batch = {}

    # Normalized pixels plus time & latlon inputs, moved to the target device.
    batch["pixels"] = chips["pixels"].to(device)
    batch["time"] = chips["time"].to(device)
    batch["latlon"] = chips["latlon"].to(device)

    # Per-band wavelengths & sensor GSD, taken from metadata.yaml.
    batch["waves"] = torch.tensor(wavelengths)
    batch["gsd"] = torch.tensor(gsd)

    return batch

FORWARD PASS

Pass a batch from a single sensor through the encoder & decoder of Clay.

# Let us see an example of what input looks like for NAIP
platform = "naip"
metadata = dm.metadata[platform]
wavelengths = list(metadata.bands.wavelength.values())
gsd = metadata.gsd
batch_naip = create_batch(naip, wavelengths, gsd, DEVICE)
with torch.no_grad():
    # The encoder masks 70% of the patches and embeds only the visible ones.
    unmsk_patch, unmsk_idx, msk_idx, msk_matrix = module.model.encoder(batch_naip)
    # The decoder predicts pixel values for every patch from those embeddings.
    pixels_naip, _ = module.model.decoder(
        unmsk_patch,
        unmsk_idx,
        msk_idx,
        msk_matrix,
        batch_naip["time"],
        batch_naip["latlon"],
        batch_naip["gsd"],
        batch_naip["waves"],
    )
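
Before reshaping the decoder output, it is worth a quick sanity check of its shape. The expected numbers here are derived from the patch layout used below, not from a recorded run:

# Expected shape: [batch, patches, pixel_values_per_patch].
# 224 / 8 = 28 patches per side, so 28 * 28 = 784 patches;
# 4 bands * 8 * 8 pixels per patch = 256 values per patch.
print(pixels_naip.shape)  # expected: torch.Size([128, 784, 256])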
def denormalize_images(normalized_images, means, stds):
    # Reshape per-band statistics to broadcast over (batch, band, height, width),
    # then undo the normalization: x * std + mean.
    means = np.array(means).reshape(1, -1, 1, 1)
    stds = np.array(stds).reshape(1, -1, 1, 1)
    return normalized_images * stds + means
naip_mean = list(metadata.bands.mean.values())
naip_std = list(metadata.bands.std.values())

batch_naip_pixels = batch_naip["pixels"].detach().cpu().numpy()
batch_naip_pixels = denormalize_images(batch_naip_pixels, naip_mean, naip_std)
batch_naip_pixels = batch_naip_pixels.astype(np.uint8)

Rearrange the patches from the decoder back into pixel space.

The output of the Clay decoder has shape batch x patches x pixel_values_per_patch; we reshape that to batch x channel x height x width. With 224 x 224 chips and a patch size of 8, there are 224 / 8 = 28 patches per side, hence h = w = 28 and p1 = p2 = 8 below.

pixels_naip = rearrange(
    pixels_naip,
    "b (h w) (c p1 p2) -> b c (h p1) (w p2)",
    c=batch_naip_pixels.shape[1],
    h=28,
    w=28,
    p1=8,
    p2=8,
)
denorm_pixels_naip = denormalize_images(
    pixels_naip.cpu().detach().numpy(), naip_mean, naip_std
)
denorm_pixels_naip = denorm_pixels_naip.astype(np.uint8)
fig, axs = plt.subplots(6, 8, figsize=(20, 14))

# Rows alternate between actual chips & their reconstructions (first 3 bands as RGB).
for i in range(0, 6, 2):
    for j in range(8):
        idx = (i // 2) * 8 + j
        axs[i][j].imshow(batch_naip_pixels[idx, [0, 1, 2], ...].transpose(1, 2, 0))
        axs[i][j].set_axis_off()
        axs[i][j].set_title(f"Actual {idx}")
        axs[i + 1][j].imshow(denorm_pixels_naip[idx, [0, 1, 2], ...].transpose(1, 2, 0))
        axs[i + 1][j].set_axis_off()
        axs[i + 1][j].set_title(f"Rec {idx}")

Minicube visualization: actual NAIP chips ("Actual") alongside their reconstructions from 70% masked inputs ("Rec").
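
Alongside the visual comparison, a simple quantitative check can help. This sketch computes the mean absolute error between the denormalized actual & reconstructed arrays from above (note it averages over the visible patches too, so it flatters performance on the masked regions):

# Mean absolute error between actual & reconstructed chips, on the 0-255 scale.
mae = np.abs(
    batch_naip_pixels.astype(np.float32) - denorm_pixels_naip.astype(np.float32)
).mean()
print(f"Mean absolute reconstruction error: {mae:.2f}")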

Next steps

  • Look at reconstructions from other sensors

  • Interpolate between embedding spaces & reconstruct images from there to create a video animation (a rough sketch follows below)
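
For the interpolation idea, here is a rough, hypothetical sketch (not from the Clay codebase): encode two batches of the same sensor, linearly blend their unmasked patch embeddings, and decode each blend. batch_a & batch_b are assumed to be built with create_batch(), and reusing batch_a's mask indices for both encodings is only meaningful when they share the same masking pattern:

# Hypothetical sketch: interpolate between the embeddings of two batches.
with torch.no_grad():
    emb_a, unmsk_idx, msk_idx, msk_matrix = module.model.encoder(batch_a)
    emb_b, *_ = module.model.encoder(batch_b)
    for t in (0.0, 0.25, 0.5, 0.75, 1.0):
        blended = torch.lerp(emb_a, emb_b, t)  # (1 - t) * emb_a + t * emb_b
        pixels, _ = module.model.decoder(
            blended,
            unmsk_idx,
            msk_idx,
            msk_matrix,
            batch_a["time"],
            batch_a["latlon"],
            batch_a["gsd"],
            batch_a["waves"],
        )
        # Rearrange & denormalize `pixels` as above to render each frame.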