Burn scar analysis using embeddings from partial inputs#

This notebook contains a complete example of how to run Clay. It combines the following three aspects:

  1. Create single-chip datacubes with time series data for a location and a date range

  2. Run the model with partial inputs, in this case RGB + NIR

  3. Study burn scars through the embeddings generated for that datacube

Let’s start by importing libraries and defining constants#

# Ensure working directory is the repo home
import os

os.chdir("..")
import warnings
from pathlib import Path

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy
import pandas as pd
import pystac_client
import rasterio
import rioxarray  # noqa: F401
import stackstac
import torch
from rasterio.enums import Resampling
from shapely import Point
from sklearn import decomposition

from src.datamodule import ClayDataModule
from src.model_clay import CLAYModule

warnings.filterwarnings("ignore")

BAND_GROUPS = {
    "rgb": ["red", "green", "blue"],
    "rededge": ["rededge1", "rededge2", "rededge3", "nir08"],
    "nir": [
        "nir",
    ],
    "swir": ["swir16", "swir22"],
    "sar": ["vv", "vh"],
}

STAC_API = "https://earth-search.aws.element84.com/v1"
COLLECTION = "sentinel-2-l2a"

Search for imagery over an area of interest#

In this example we use a location and date range to visualize a forest fire that happened near Monchique, Portugal, in 2018.

# Point over Monchique, Portugal
poi = 37.30939, -8.57207

# Dates of a large forest fire
start = "2018-07-01"
end = "2018-09-01"

catalog = pystac_client.Client.open(STAC_API)

search = catalog.search(
    collections=[COLLECTION],
    datetime=f"{start}/{end}",
    bbox=(poi[1] - 1e-5, poi[0] - 1e-5, poi[1] + 1e-5, poi[0] + 1e-5),
    max_items=100,
    query={"eo:cloud_cover": {"lt": 80}},
)

items = search.get_all_items()

print(f"Found {len(items)} items")
Found 12 items
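
Before downloading any pixels, it can help to sanity-check the search results, for example by listing the acquisition date and cloud cover of each scene (both standard STAC properties that are also used in the query above):

for item in items:
    print(item.datetime.date(), item.properties["eo:cloud_cover"])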

Download the data#

Load the data into an in-memory array and visualize the imagery. The burn scar is visible in the last five images.

# Extract coordinate system from first item
epsg = items[0].properties["proj:epsg"]

# Convert point into the image projection
poidf = gpd.GeoDataFrame(
    pd.DataFrame(),
    crs="EPSG:4326",
    geometry=[Point(poi[1], poi[0])],
).to_crs(epsg)

coords = poidf.iloc[0].geometry.coords[0]

# Create bounds of the correct size; the model requires
# 512x512 pixels at 10m resolution, i.e. a 5120m x 5120m
# square, so we buffer the point by 2560m in each direction.
bounds = (
    coords[0] - 2560,
    coords[1] - 2560,
    coords[0] + 2560,
    coords[1] + 2560,
)

# Retrieve the pixel values, for the bounding box in
# the target projection. In this example we use only
# the RGB and NIR band groups.
stack = stackstac.stack(
    items,
    bounds=bounds,
    snap_bounds=False,
    epsg=epsg,
    resolution=10,
    dtype="float32",
    rescale=False,
    fill_value=0,
    assets=BAND_GROUPS["rgb"] + BAND_GROUPS["nir"],
    resampling=Resampling.nearest,
)

stack = stack.compute()

stack.sel(band=["red", "green", "blue"]).plot.imshow(
    row="time", rgb="band", vmin=0, vmax=2000, col_wrap=6
)

Minicube visualization

Write data to tif files#

To use the mini datacube in the Clay data loader, we need to write the images to tif files on disk. These files are then used by the data loader to create the embeddings below.

outdir = Path("data/minicubes")
outdir.mkdir(exist_ok=True, parents=True)

# Write tile to output dir
for tile in stack:
    # Grid code like MGRS-29SNB
    mgrs = str(tile.coords["grid:code"].values).split("-")[1]
    date = str(tile.time.values)[:10]

    name = "{dir}/claytile_{mgrs}_{date}.tif".format(
        dir=outdir,
        mgrs=mgrs,
        date=date.replace("-", ""),
    )
    tile.rio.to_raster(name, compress="deflate")

    with rasterio.open(name, "r+") as rst:
        rst.update_tags(date=date)

Create embeddings#

Now switch gears and load the tiles to create embeddings and analyze them.

The model checkpoint can be loaded directly from Hugging Face, and the data directory points to the directory we created in the steps above.

Note that the normalization parameters for the data module need to be adapted based on the band groups that were selected as partial input. The full set of normalization parameters can be found here.
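
If you want to derive values for a different band selection yourself, here is a minimal sketch that estimates per-band mean and standard deviation from the tif files written above. These local statistics are only illustrative; the official values were computed over the full training dataset, so prefer those for real use.

# Illustrative only: estimate per-band statistics from the local minicubes
arrays = []
for path in sorted(Path("data/minicubes").glob("claytile_*.tif")):
    with rasterio.open(path) as src:
        arrays.append(src.read().reshape(src.count, -1))

pixels = numpy.concatenate(arrays, axis=1)
print("MEAN =", pixels.mean(axis=1).round(2).tolist())
print("STD =", pixels.std(axis=1).round(2).tolist())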

Load the model and set up the data module#

DATA_DIR = "data/minicubes"
CKPT_PATH = "https://huggingface.co/made-with-clay/Clay/resolve/main/Clay_v0.1_epoch-24_val-loss-0.46.ckpt"

# Load model
rgb_model = CLAYModule.load_from_checkpoint(
    CKPT_PATH,
    mask_ratio=0.0,
    band_groups={"rgb": (2, 1, 0), "nir": (3,)},
    bands=4,
    strict=False,  # ignore the extra parameters in the checkpoint
)
# Set the model to evaluation mode
rgb_model.eval()


# Load the datamodule, with the reduced set of bands
class ClayDataModuleRGB(ClayDataModule):
    MEAN = [
        1369.03,  # red
        1597.68,  # green
        1741.10,  # blue
        2858.43,  # nir
    ]
    STD = [
        2026.96,  # red
        2011.88,  # green
        2146.35,  # blue
        2016.38,  # nir
    ]


data_dir = Path(DATA_DIR)

dm = ClayDataModuleRGB(data_dir=str(data_dir.absolute()), batch_size=20)
dm.setup(stage="predict")
trn_dl = iter(dm.predict_dataloader())
Total number of chips: 12
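
As a quick optional check, we can peek at a single batch to confirm the shapes the model will receive. The keys match those used in the loop below; with 12 chips and a batch size of 20, everything fits in a single batch.

batch = next(iter(dm.predict_dataloader()))
for key in ("pixels", "timestep", "latlon"):
    print(key, batch[key].shape)
# pixels should be (12, 4, 512, 512): 12 chips, 4 bands, 512x512 pixels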

Create the embeddings for the images over the forest fire#

This loops through the images returned by the data loader and evaluates the model for each one. The raw patch embeddings are reduced to their mean values to simplify the data.

embeddings = []

for batch in trn_dl:
    with torch.inference_mode():
        # Move the data to the device of the model
        batch["pixels"] = batch["pixels"].to(rgb_model.device)
        batch["timestep"] = batch["timestep"].to(rgb_model.device)
        batch["latlon"] = batch["latlon"].to(rgb_model.device)

        # Pass pixels, latlon, timestep through the encoder to create encoded patches
        (
            unmasked_patches,
            unmasked_indices,
            masked_indices,
            masked_matrix,
        ) = rgb_model.model.encoder(batch)

        embeddings.append(unmasked_patches.detach().cpu().numpy())

embeddings = numpy.vstack(embeddings)

# Average over the patch embeddings, dropping the last two
# tokens, which encode the lat/lon and time metadata
embeddings_mean = embeddings[:, :-2, :].mean(axis=1)

print(f"Average embeddings have shape {embeddings_mean.shape}")
Average embeddings have shape (12, 768)
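
Optionally, we can write the averaged embeddings to disk so that the analysis below can be repeated without re-running the model. A minimal sketch that pairs each embedding with its acquisition date (the output file name is arbitrary):

embeddings_df = pd.DataFrame(embeddings_mean)
embeddings_df.insert(0, "date", [str(t)[:10] for t in stack.time.values])
embeddings_df.to_csv(outdir / "embeddings_mean.csv", index=False)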

Analyze embeddings#

Now we can make a simple analysis of the embeddings. We reduce each embedding to a single number using Principal Component Analysis and plot the resulting principal component over time. The effect of the fire on the embeddings is clearly visible. We use the following color code in the graph:

| Color | Interpretation  |
|-------|-----------------|
| Green | Cloudy images   |
| Blue  | Before the fire |
| Red   | After the fire  |

pca = decomposition.PCA(n_components=1)
pca_result = pca.fit_transform(embeddings_mean)

plt.xticks(rotation=-30)
# All points
plt.scatter(stack.time, pca_result, color="blue")

# Cloudy images
plt.scatter(stack.time[0], pca_result[0], color="green")
plt.scatter(stack.time[2], pca_result[2], color="green")

# After fire
plt.scatter(stack.time[-5:], pca_result[-5:], color="red")
Scatter plot of the first principal component for each image over time

In the plot above, each image embedding is one point. The two cloudy images stand out clearly, and the values after the fire are consistently low.
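
To turn this visual separation into a simple rule, one could threshold the first principal component. A rough sketch, assuming the post-fire cluster falls below zero as in the plot above; the sign of a principal component is arbitrary, so check your own plot first.

# Assumes post-fire scenes have negative component values; flip the
# comparison if the sign is reversed in your run.
post_fire = pca_result[:, 0] < 0
for time, flag in zip(stack.time.values, post_fire):
    print(str(time)[:10], "post-fire" if flag else "pre-fire or cloudy")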