Explore embedding space of CLAY Encoder for different sensors#

import sys
import warnings

sys.path.append("../..")
warnings.filterwarnings(action="ignore")
import matplotlib.pyplot as plt
import numpy as np
import torch
from einops import rearrange

from src.datamodule import ClayDataModule
from src.module import ClayMAEModule
DATA_DIR = "/home/ubuntu/data"
CHECKPOINT_PATH = "../../checkpoints/clay-v1.5.ckpt"
METADATA_PATH = "../../configs/metadata.yaml"
CHIP_SIZE = 256
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

MODEL#

Load the model with best checkpoint path and set it in eval mode.

# As we want to visualize the embeddings from the model,
# we neither mask the input image nor shuffle the patches
module = ClayMAEModule.load_from_checkpoint(
    checkpoint_path=CHECKPOINT_PATH,
    model_size="large",
    metadata_path=METADATA_PATH,
    dolls=[16, 32, 64, 128, 256, 768, 1024],
    doll_weights=[1, 1, 1, 1, 1, 1, 1],
    mask_ratio=0.0,  # keep every patch so embeddings cover the full chip
    shuffle=False,  # keep patches in raster order for spatial visualization
)

module.eval();  # inference only: disables dropout & other train-time behavior

Distributed setup#

import os

import torch.distributed as dist

# NOTE(review): presumably some model/datamodule code path expects a process
# group even in this single-process notebook — confirm before removing.
if not dist.is_initialized():
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"  # Can be any free port
    # "nccl" requires a CUDA device; fall back to "gloo" on CPU-only
    # machines so the notebook still runs without a GPU (see DEVICE above).
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, rank=0, world_size=1)

DATAMODULE#

Load the ClayDataModule

# For model training, we stack chips from one sensor into batches of size 128.
# This reduces the num_workers we need to load the batches and speeds up the
# training process. Here, although the batch size is 1, the data module reads
# batch of size 128.
dm = ClayDataModule(
    data_dir=DATA_DIR,  # root folder with one subdirectory of .npz stacks per sensor
    metadata_path=METADATA_PATH,  # per-sensor band wavelengths, means, stds & gsd
    size=CHIP_SIZE,  # chip height/width in pixels
    batch_size=1,  # each "sample" is already a pre-stacked batch of 128 chips
    num_workers=1,
)
dm.setup(stage="fit")  # builds the train & val datasets
Total number of chips: 258

Let us look at the data directory.

We have a folder for each sensor, i.e:

  • Landsat l1

  • Landsat l2

  • Sentinel 1 rtc

  • Sentinel 2 l2a

  • Naip

  • Linz

  • modis

And, under each folder, we have stacks of chips as .npz files.

!tree -L 1 {DATA_DIR}
/home/ubuntu/data
├── landsat-c2l1
├── landsat-c2l2-sr
├── linz
├── modis
├── naip
├── sentinel-1-rtc
└── sentinel-2-l2a

7 directories, 0 files
!tree -L 2 {DATA_DIR}/naip | head -5
/home/ubuntu/data/naip
├── cube_10.npz
├── cube_100045.npz
├── cube_100046.npz
├── cube_100072.npz

Now, lets look at what we have in each of the .npz files.

sample = np.load("/home/ubuntu/data/naip/cube_10.npz")
sample.keys()
KeysView(NpzFile '/home/ubuntu/data/naip/cube_10.npz' with keys: pixels, lon_norm, lat_norm, week_norm, hour_norm)
sample["pixels"].shape
(128, 4, 256, 256)
sample["lat_norm"].shape, sample["lon_norm"].shape
((128, 2), (128, 2))
sample["week_norm"].shape, sample["hour_norm"].shape
((128, 2), (128, 2))

As we see above, chips are stacked in batches of size 128.
The sample we are looking at is from NAIP so it has 4 bands & of size 256 x 256.
We also get normalized lat/lon & timestep (week/hour) information, which is optional for the model. If you don’t have it handy, feel free to pass zero tensors in its place.

Load a batch of data from ClayDataModule

ClayDataModule is designed to fetch random batches of data from different sensors sequentially, i.e batches are in ascending order of their directory - Landsat 1, Landsat 2, LINZ, MODIS, NAIP, Sentinel 1 rtc, Sentinel 2 L2A and it repeats after that.

# We have a random sample subset of the data, so it's
# okay to use either the train or val dataloader
# We have a random sample subset of the data, so it's
# okay to use either the train or val dataloader
dl = iter(dm.train_dataloader())
l1 = next(dl)
l2 = next(dl)
linz = next(dl)
modis = next(dl)
naip = next(dl)
s1 = next(dl)
s2 = next(dl)

# Print the tensor shapes of each sensor batch; the platform label comes
# from the batch itself, so no separate name list is needed.
for chips in (l1, l2, linz, modis, naip, s1, s2):
    print(
        f"{chips['platform'][0]:<15}",
        chips["pixels"].shape,
        chips["time"].shape,
        chips["latlon"].shape,
    )
landsat-c2l1    torch.Size([128, 6, 256, 256]) torch.Size([128, 4]) torch.Size([128, 4])
landsat-c2l2-sr torch.Size([128, 6, 256, 256]) torch.Size([128, 4]) torch.Size([128, 4])
linz            torch.Size([128, 3, 256, 256]) torch.Size([128, 4]) torch.Size([128, 4])
modis           torch.Size([128, 7, 256, 256]) torch.Size([128, 4]) torch.Size([128, 4])
naip            torch.Size([128, 4, 256, 256]) torch.Size([128, 4]) torch.Size([128, 4])
sentinel-1-rtc  torch.Size([128, 2, 256, 256]) torch.Size([128, 4]) torch.Size([128, 4])
sentinel-2-l2a  torch.Size([128, 10, 256, 256]) torch.Size([128, 4]) torch.Size([128, 4])

INPUT#

Model expects a dictionary with keys:

  • pixels: batch x band x height x width - normalized chips of a sensor

  • time: batch x 4 - horizontally stacked week_norm & hour_norm

  • latlon: batch x 4 - horizontally stacked lat_norm & lon_norm

  • waves: list[:band] - wavelengths of each band of the sensor from the metadata.yaml file

  • gsd: scalar - gsd of the sensor from metadata.yaml file

Normalization & stacking is taken care of by the ClayDataModule: Clay-foundation/model

When not using the ClayDataModule, make sure you normalize the chips & pass all items for the model.

def create_batch(chips, wavelengths, gsd, device):
    """Assemble the input dictionary expected by the Clay encoder.

    Args:
        chips: dict with "pixels", "time" & "latlon" tensors for one sensor.
        wavelengths: per-band wavelengths from metadata.yaml.
        gsd: ground sample distance of the sensor from metadata.yaml.
        device: device the pixel/time/latlon tensors are moved to.

    Returns:
        dict with "pixels", "time", "latlon" (on *device*) plus "waves"
        and "gsd" tensors.
    """
    pixels, time, latlon = (
        chips[key].to(device) for key in ("pixels", "time", "latlon")
    )
    return {
        "pixels": pixels,
        "time": time,
        "latlon": latlon,
        "waves": torch.tensor(wavelengths),
        "gsd": torch.tensor(gsd),
    }
# Let us see an example of what input looks like for NAIP & Sentinel 2
platform = "naip"
metadata = dm.metadata[platform]
# Band wavelengths from metadata.yaml identify each band to the model
wavelengths = list(metadata.bands.wavelength.values())
gsd = metadata.gsd  # ground sample distance of the sensor
batch_naip = create_batch(naip, wavelengths, gsd, DEVICE)
platform = "sentinel-2-l2a"
metadata = dm.metadata[platform]
wavelengths = list(metadata.bands.wavelength.values())
gsd = metadata.gsd
batch_s2 = create_batch(s2, wavelengths, gsd, DEVICE)

FORWARD PASS - Clay Encoder#

# Run only the encoder (embeddings don't need the decoder); no_grad since
# we are not training. The first return value holds the patch embeddings.
with torch.no_grad():
    unmsk_patch_naip, *_ = module.model.encoder(batch_naip)
    unmsk_patch_s2, *_ = module.model.encoder(batch_s2)
unmsk_patch_naip.shape, unmsk_patch_s2.shape
(torch.Size([128, 1025, 1024]), torch.Size([128, 1025, 1024]))

ClayMAE model is trained using patch_size of 8. For chip_size of 256 x 256, we have

256 // 8 -> 32 rows
256 // 8 -> 32 cols

32 * 32 -> 1024 patches are passed through the forward pass of the model.

Here we see unmsk_patch shapes of size batch x (1 + 1024) x 1024, i.e
1 -> cls_token
1024 -> patches
1024 -> embedding dim

VISUALIZE EMBEDDINGS of NAIP#

def denormalize_images(normalized_images, means, stds):
    """Invert per-band normalization: pixel * std + mean.

    *means* and *stds* are per-band sequences; they are reshaped to
    (1, band, 1, 1) so they broadcast over (batch, band, height, width).
    """
    mean_arr = np.asarray(means).reshape(1, -1, 1, 1)
    std_arr = np.asarray(stds).reshape(1, -1, 1, 1)
    return normalized_images * std_arr + mean_arr
# Per-band normalization stats for NAIP from the datamodule metadata
naip_mean = list(dm.metadata["naip"].bands.mean.values())
naip_std = list(dm.metadata["naip"].bands.std.values())

batch_naip_pixels = batch_naip["pixels"].detach().cpu().numpy()
batch_naip_pixels = denormalize_images(batch_naip_pixels, naip_mean, naip_std)
# Cast to uint8 so matplotlib renders the chips as images
batch_naip_pixels = batch_naip_pixels.astype(np.uint8)

Plot first 24 chips that are fed to the model

fig, axs = plt.subplots(3, 8, figsize=(20, 8))

# Show the first 3 bands of each of the first 24 denormalized NAIP chips,
# with channels moved last for imshow (H, W, C).
for chip_index, axis in enumerate(axs.flat):
    rgb = batch_naip_pixels[chip_index, :3, ...].transpose(1, 2, 0)
    axis.imshow(rgb)
    axis.set_axis_off()
    axis.set_title(chip_index)
../_images/7e4af3825b4a104a8ee40a88c1d4161b5a019ec2648d41c4332786b3e7dee28a.png

Rearrange the embeddings from the Clay Encoder back to images.

Embeddings are of shape: batch x (1:cls_token + 1024:patches) x 1024:embedding dimension
1024:patches can be transformed into images of shape 32 x 32
1024:embedding dimension can be moved as channel dimension
Here, each embedding dims represents a particular unique feature of the chip

# Drop the cls_token (index 0) and fold the 1024 patch embeddings back into
# a 32 x 32 spatial grid, moving the embedding axis into channel position.
unmsk_embed = rearrange(
    unmsk_patch_naip[:, 1:, :].detach().cpu().numpy(), "b (h w) d-> b d h w", h=32, w=32
)

Pick a NAIP chip from the first 24 plotted above and visualize what each of the embedding dims looks like for it. To prevent overflowing the notebook, let’s plot only the first 256 of the 1024 embedding dimensions.

embed = unmsk_embed[0]  # first chip in the batch; change the index to inspect another
fig, axs = plt.subplots(16, 16, figsize=(20, 20))

# One 32 x 32 heatmap per embedding dimension (first 256 of the 1024 dims)
for idx, ax in enumerate(axs.flatten()):
    ax.imshow(embed[idx], cmap="bwr")
    ax.set_axis_off()
    ax.set_title(idx)
plt.tight_layout()
../_images/b7a8de0a61bd3613698367007534117d0f48e27db717b658bc9d1217c4afc3d3.png

As we see above, each embedding dimension represents a feature of the chip. Some are simple & easy to interpret for human eyes like edges, patterns, features like land & water - while some are more complex.
Now, let’s pick one embedding dimension from the 256 plotted above & visualize it for all the chips.

fig, axs = plt.subplots(6, 8, figsize=(20, 14))
embed_dim = 61  # pick any embedding dimension

# Rows alternate: row 2k shows the chip, row 2k+1 the chosen embedding dim.
for row in range(0, 6, 2):
    for col in range(8):
        idx = (row // 2) * 8 + col
        image_ax, embed_ax = axs[row][col], axs[row + 1][col]
        image_ax.imshow(batch_naip_pixels[idx, :3, ...].transpose(1, 2, 0))
        image_ax.set_axis_off()
        image_ax.set_title(f"Image {idx}")
        embed_ax.imshow(unmsk_embed[idx][embed_dim], cmap="gray")
        embed_ax.set_axis_off()
        embed_ax.set_title(f"Embed {idx}")
../_images/6c17609bfa50f9913494c428e74cdc2c4405b9a168135d14309575b3e5289e2c.png

VISUALIZE EMBEDDINGS of S2#

We will repeat the same set of steps for Sentinel 2 now.

# Per-band normalization stats for Sentinel-2 L2A from the datamodule metadata
s2_mean = list(dm.metadata["sentinel-2-l2a"].bands.mean.values())
s2_std = list(dm.metadata["sentinel-2-l2a"].bands.std.values())

batch_s2_pixels = batch_s2["pixels"].detach().cpu().numpy()
batch_s2_pixels = denormalize_images(batch_s2_pixels, s2_mean, s2_std)
fig, axs = plt.subplots(3, 8, figsize=(20, 8))

# Reorder channels [2, 1, 0] for display, scale reflectance down by 2000 and
# clip to [0, 1] so imshow gets a valid float image.
for chip_index, axis in enumerate(axs.flat):
    rgb = batch_s2_pixels[chip_index, [2, 1, 0], ...].transpose(1, 2, 0)
    axis.imshow(np.clip(rgb / 2000, 0, 1))
    axis.set_axis_off()
    axis.set_title(chip_index)
../_images/32a5fdc58c9880b9d7a804ff6e7f05f837b84316dd68f51c366fba9d97daf044.png
# Same as for NAIP: drop the cls_token and reshape patches to a 32 x 32 grid
unmsk_embed_s2 = rearrange(
    unmsk_patch_s2[:, 1:, :].detach().cpu().numpy(), "b (h w) d-> b d h w", h=32, w=32
)
embed_s2 = unmsk_embed_s2[2]  # chip at index 2; change the index to inspect another
fig, axs = plt.subplots(16, 16, figsize=(20, 20))

# One 32 x 32 heatmap per embedding dimension (first 256 of the 1024 dims).
for dim_index, axis in enumerate(axs.flat):
    axis.imshow(embed_s2[dim_index], cmap="bwr")
    axis.set_axis_off()
    axis.set_title(dim_index)
plt.tight_layout()
../_images/1f52402ac203f2add3ad890cd03629f1ad72b304848ea6fcdabafb4ecf929c72.png
fig, axs = plt.subplots(6, 8, figsize=(20, 14))
embed_dim = 61

# Rows alternate: row 2k shows the chip, row 2k+1 the chosen embedding dim.
for row in range(0, 6, 2):
    for col in range(8):
        idx = (row // 2) * 8 + col
        image_ax, embed_ax = axs[row][col], axs[row + 1][col]
        scaled = batch_s2_pixels[idx, [2, 1, 0], ...].transpose(1, 2, 0) / 2000
        image_ax.imshow(np.clip(scaled, 0, 1))
        image_ax.set_axis_off()
        image_ax.set_title(f"Image {idx}")
        embed_s2 = unmsk_embed_s2[idx]
        embed_ax.imshow(embed_s2[embed_dim], cmap="gray")
        embed_ax.set_axis_off()
        embed_ax.set_title(f"Embed {idx}")
../_images/3c8f5da2a8e3008ee8b658a7546761a8fef880b3ed459fd17b60e94e42df8fb0.png

Next steps#

  • Visualize embeddings for other sensors that the model is trained on i.e Landsat, Sentinel-1, LINZ

  • Visualize embeddings for sensors that the model has not seen during training. As the model has seen imagery from 0.5cm to 30m resolution, feel free to pick a sensor that falls in or around this range. We will add support for other sensors in later release.

  • Pick embeddings that seem to solve your tasks & try doing segmentation or detection using classical computer vision (will be a fun exercise).