CLAY v0 - Location Embeddings#
Imports#
import sys
sys.path.append("../")
import warnings
from pathlib import Path
import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import rasterio as rio
import torch
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from src.datamodule import ClayDataModule, ClayDataset
from src.model_clay import CLAYModule
warnings.filterwarnings("ignore")
L.seed_everything(42)
# data directory for all chips
DATA_DIR = "../data/02"
# path of best model checkpoint for Clay v0
CKPT_PATH = "../checkpoints/v0/mae_epoch-24_val-loss-0.46.ckpt"
Load Model & DataModule#
# Load the model & set in eval mode
model = CLAYModule.load_from_checkpoint(CKPT_PATH, mask_ratio=0.7)
model.eval();
data_dir = Path(DATA_DIR)
# Load the Clay DataModule
ds = ClayDataset(chips_path=list(data_dir.glob("**/*.tif")))
dm = ClayDataModule(data_dir=str(data_dir), batch_size=100)
dm.setup(stage="fit")
# Load the train DataLoader
trn_dl = iter(dm.train_dataloader())
# Load the first batch of chips
batch = next(trn_dl)
batch.keys()
# Save a copy of batch to visualize later
_batch = batch["pixels"].detach().clone().cpu().numpy()
Pass model through the CLAY model#
# Pass the pixels through the encoder & decoder of CLAY
with torch.no_grad():
# Move data from to the device of model
batch["pixels"] = batch["pixels"].to(model.device)
batch["timestep"] = batch["timestep"].to(model.device)
batch["latlon"] = batch["latlon"].to(model.device)
# Pass pixels, latlon, timestep through the encoder to create encoded patches
(
unmasked_patches,
unmasked_indices,
masked_indices,
masked_matrix,
) = model.model.encoder(batch)
# Pass the unmasked_patches through the decoder to reconstruct the pixel space
pixels = model.model.decoder(unmasked_patches, unmasked_indices, masked_indices)
Extract Location & Timestep Embeddings#
In CLAY, the encoder receives unmasked patches, latitude-longitude data, and timestep information. Notably, the last 2 embeddings from the encoder specifically represent the latitude-longitude and timestep embeddings.
latlon_embeddings = unmasked_patches[:, -2, :].detach().cpu().numpy()
time_embeddings = unmasked_patches[:, -1, :].detach().cpu().numpy()
# Get normalized latlon that were input to the model
latlon = batch["latlon"].detach().cpu().numpy()
We will just focus on location embeddings in this notebook
latlon.shape, latlon_embeddings.shape
Latitude & Longitude map to 768 dimentional vector
Preform PCA over the location embeddings to visualize them in 2 dimension#
pca = PCA(n_components=2)
latlon_embeddings = pca.fit_transform(latlon_embeddings)
latlon_embeddings.shape
Create clusters of normalized latlon & latlon embeddings to check if there are any learned patterns in them after training#
Latlon Cluster
kmeans = KMeans(n_clusters=5)
kmeans.fit_transform(latlon)
latlon = np.column_stack((latlon, kmeans.labels_))
Latlon Embeddings Cluster
kmeans = KMeans(n_clusters=5)
kmeans.fit_transform(latlon_embeddings)
latlon_embeddings = np.column_stack((latlon_embeddings, kmeans.labels_))
latlon.shape, latlon_embeddings.shape
We are a third dimension to latlon & latlon embeddings with cluster labels
Plot latlon clusters#
plt.figure(figsize=(15, 15), dpi=80)
plt.scatter(latlon[:, 0], latlon[:, 1], c=latlon[:, 2], label="Actual", alpha=0.3)
for i in range(100):
txt = f"{latlon[:,0][i]:.2f},{latlon[:, 1][i]:.2f}"
plt.annotate(txt, (latlon[:, 0][i] + 1e-5, latlon[:, 1][i] + 1e-5))
As we see in the scatter plot above, there is nothing unique about latlon that go into the model, they are cluster based on their change in longitude values above
Plot latlon embeddings cluster#
plt.figure(figsize=(15, 15), dpi=80)
plt.scatter(
latlon_embeddings[:, 0],
latlon_embeddings[:, 1],
c=latlon_embeddings[:, 2],
label="Predicted",
alpha=0.3,
)
for i in range(100):
txt = i
plt.annotate(txt, (latlon_embeddings[:, 0][i], latlon_embeddings[:, 1][i]))
def show_cluster(ids):
fig, axes = plt.subplots(1, len(ids), figsize=(10, 5))
for i, ax in zip(ids, axes.flatten()):
img_path = batch["source_url"][i]
img = rio.open(img_path).read([3, 2, 1]).transpose(1, 2, 0)
img = (img - img.min()) / (img.max() - img.min())
ax.imshow(img)
ax.set_axis_off()
show_cluster((87, 37, 40))
show_cluster((23, 11, 41))
show_cluster((68, 71, 7))
We can see location embedding capturing semantic information as well