Patch level cloud coverage#

This notebook obtains patch level (32x32 pixel subset) cloud cover percentages from the Scene Classification (SCL) mask that accompanies a Sentinel-2 L2A dataset.

We will demonstrate how to do the following:

  1. Leverage the AWS Sentinel-2 STAC catalog to obtain fine “patch” level cloud cover percentages from the Scene Classification (SCL) mask. These percentages are written to the patch level GeoParquet files so that they can be loaded into database tables and used as filters for similarity search and other downstream tasks.

  2. Generate fine level embeddings (from 10m x 10m pixels) for a 5.12km x 5.12km area.

  3. Save the fine level (patch) embeddings and run a similarity search that uses the cloud cover percentages as a reference.

import glob
from pathlib import Path

import geopandas as gpd
import lancedb
import matplotlib.pyplot as plt
import numpy
import pandas as pd
import pystac_client
import rasterio
import rioxarray  # noqa: F401
import shapely
import stackstac
import torch
from rasterio.enums import Resampling
from shapely.geometry import Polygon, box

from src.datamodule import ClayDataModule
from src.model_clay import CLAYModule

pd.set_option("display.max_colwidth", None)

BAND_GROUPS_L2A = {
    "rgb": ["red", "green", "blue"],
    "rededge": ["rededge1", "rededge2", "rededge3", "nir08"],
    "nir": [
        "nir",
    ],
    "swir": ["swir16", "swir22"],
    "sar": ["vv", "vh"],
    "scl": ["scl"],
}

STAC_API_L2A = "https://earth-search.aws.element84.com/v1"
COLLECTION_L2A = "sentinel-2-l2a"

SCL_CLOUD_LABELS = [7, 8, 9, 10]
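
For reference, the SCL band encodes each pixel into one of twelve classes. The list above treats the cloud classes (8, 9, 10) together with class 7 (unclassified) as cloudy. Below is a lookup of the ESA Level-2A class definitions, added here purely for convenience (the SCL_CLASS_NAMES name is not used elsewhere in this notebook):

# ESA Scene Classification (SCL) class values, for reference only
SCL_CLASS_NAMES = {
    0: "No data",
    1: "Saturated or defective",
    2: "Dark area pixels",
    3: "Cloud shadows",
    4: "Vegetation",
    5: "Not vegetated",
    6: "Water",
    7: "Unclassified",
    8: "Cloud medium probability",
    9: "Cloud high probability",
    10: "Thin cirrus",
    11: "Snow or ice",
}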

Find Sentinel-2 scenes stored as Cloud-Optimized GeoTIFFs#

Define an area of interest#

This is a hotspot area where mining extraction occurs on the island of Fiji. We used the same area in another tutorial, albeit with a cloud-free composite. Here it helps demonstrate how we can capture clouds for the same region and time frame in the absence of a cloud-free composite.

# sample cluster
bbox_bl = (177.4199, -17.8579)
bbox_tl = (177.4156, -17.6812)
bbox_br = (177.5657, -17.8572)
bbox_tr = (177.5657, -17.6812)

Define spatiotemporal query#

# Define area of interest
area_of_interest = shapely.box(
    xmin=bbox_bl[0], ymin=bbox_bl[1], xmax=bbox_tr[0], ymax=bbox_tr[1]
)

# Define temporal range
daterange = ["2021-01-01T00:00:00Z", "2021-12-31T23:59:59Z"]
catalog_L2A = pystac_client.Client.open(STAC_API_L2A)

search = catalog_L2A.search(
    collections=[COLLECTION_L2A],
    datetime=daterange,
    intersects=area_of_interest,
    max_items=100,
    query={"eo:cloud_cover": {"lt": 80}},
)

items_L2A = search.get_all_items()

print(f"Found {len(items_L2A)} items")

Download the data#

Fetch the pixel data into an xarray DataArray via stackstac and visualize the imagery.

# Extract coordinate system from first item
epsg = items_L2A[0].properties["proj:epsg"]

# Convert point from lon/lat to UTM projection
poidf = gpd.GeoDataFrame(crs="OGC:CRS84", geometry=[area_of_interest.centroid]).to_crs(
    epsg
)
geom = poidf.iloc[0].geometry

# Create bounds of the correct size, the model
# requires 512x512 pixels at 10m resolution.
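# (2560 m is half of 512 pixels x 10 m = 5120 m)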
bounds = (geom.x - 2560, geom.y - 2560, geom.x + 2560, geom.y + 2560)

# Retrieve the pixel values, for the bounding box in
# the target projection. In this example we use only
# the RGB and SCL band groups.
stack_L2A = stackstac.stack(
    items_L2A[0],
    bounds=bounds,
    snap_bounds=False,
    epsg=epsg,
    resolution=10,
    dtype="float32",
    rescale=False,
    fill_value=0,
    assets=BAND_GROUPS_L2A["rgb"] + BAND_GROUPS_L2A["scl"],
    resampling=Resampling.nearest,
    xy_coords="center",
)

stack_L2A = stack_L2A.compute()
print(stack_L2A.shape)
assert stack_L2A.shape == (1, 4, 512, 512)

stack_L2A.sel(band=["red", "green", "blue"]).plot.imshow(
    row="time", rgb="band", vmin=0, vmax=2000
)

Write the stack to file#

outdir = Path("data/minicubes_cloud")
outdir.mkdir(exist_ok=True, parents=True)

write = True
if write:
    # Write tile to output dir, whilst dropping the SCL band in the process
    for tile in stack_L2A.sel(band=["red", "green", "blue"]):
        date = str(tile.time.values)[:10]

        name = "{dir}/claytile_{date}.tif".format(
            dir=outdir,
            date=date.replace("-", ""),
        )
        tile.rio.to_raster(name, compress="deflate")

        with rasterio.open(name, "r+") as rst:
            rst.update_tags(date=date)

Get the geospatial bounds and cloud cover percentages for the 32x32 windows#

We will use the geospatial bounds of the 32x32 windowed subsets (“chunks”) to store the patch level embeddings. For a 512x512 pixel image this yields a 16x16 grid of chunks, each covering 320m x 320m at 10m resolution, which matches the 16x16 patch grid produced by the encoder.

# Function to count cloud pixels in a subset


def count_cloud_pixels(subset_scl, cloud_labels):
    cloud_pixels = 0
    for label in cloud_labels:
        cloud_pixels += numpy.count_nonzero(subset_scl == label)
    return cloud_pixels
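

# A vectorized alternative (a sketch, not used below) that counts all
# cloud-labelled pixels in a single numpy.isin call
def count_cloud_pixels_vectorized(subset_scl, cloud_labels):
    return int(numpy.isin(subset_scl, cloud_labels).sum())

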
# Define the chunk size for tiling
chunk_size = {"x": 32, "y": 32}  # Adjust the chunk size as needed

# Tile the data
ds_chunked_L2A = stack_L2A.chunk(chunk_size)

# Get the dimensions of the data array
dims = ds_chunked_L2A.dims

# Get the geospatial information from the original dataset
geo_info = ds_chunked_L2A.attrs

# Iterate over the chunks and compute the geospatial bounds for each chunk
chunk_bounds = {}

# Iterate over the chunks and compute the cloud count for each chunk
cloud_pcts = {}

# Get the geospatial transform and CRS
transform = ds_chunked_L2A.attrs["transform"]
crs = ds_chunked_L2A.attrs["crs"]

for x in range(ds_chunked_L2A.sizes["x"] // chunk_size["x"]):  # + 1):
    for y in range(ds_chunked_L2A.sizes["y"] // chunk_size["y"]):  # + 1):
        # Compute chunk coordinates
        x_start = x * chunk_size["x"]
        y_start = y * chunk_size["y"]
        x_end = min(x_start + chunk_size["x"], ds_chunked_L2A.sizes["x"])
        y_end = min(y_start + chunk_size["y"], ds_chunked_L2A.sizes["y"])

        # Compute chunk geospatial bounds
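        # (with a projected CRS these are UTM easting/northing in metres,
        #  despite the lon/lat variable names)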
        lon_start, lat_start = transform * (x_start, y_start)
        lon_end, lat_end = transform * (x_end, y_end)
        # print(lon_start, lat_start, lon_end, lat_end, x, y)

        # Store chunk bounds
        chunk_bounds[(x, y)] = {
            "lon_start": lon_start,
            "lat_start": lat_start,
            "lon_end": lon_end,
            "lat_end": lat_end,
        }

        # Extract the subset of the SCL band
        subset_scl = ds_chunked_L2A.sel(band="scl")[:, y_start:y_end, x_start:x_end]

        # Count the cloud pixels in the subset
        cloud_count = count_cloud_pixels(subset_scl, SCL_CLOUD_LABELS)

        # Store the cloud percent for this chunk
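        # (a 32x32 chunk contains 1024 pixels, hence the divisor)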
        cloud_pcts[(x, y)] = int(100 * (cloud_count / 1024))


# Print chunk bounds
# for key, value in chunk_bounds.items():
# print(f"Chunk {key}: {value}")

# Print indices where cloud percentages exceed some interesting threshold
cloud_threshold = 50
for key, value in cloud_pcts.items():
    if value > cloud_threshold:
        print(f"Chunk {key}: Cloud percentage = {value}")
DATA_DIR = "data/minicubes_cloud"
CKPT_PATH = (
    "https://huggingface.co/made-with-clay/Clay/resolve/main/"
    "Clay_v0.1_epoch-24_val-loss-0.46.ckpt"
)
# Load model
multi_model = CLAYModule.load_from_checkpoint(
    CKPT_PATH,
    mask_ratio=0.0,
    band_groups={"rgb": (2, 1, 0)},
    bands=3,
    strict=False,  # ignore the extra parameters in the checkpoint
    embeddings_level="group",
)
# Set the model to evaluation mode
multi_model.eval()


# Load the datamodule, with the reduced set of RGB bands
class ClayDataModuleMulti(ClayDataModule):
    MEAN = [
        1369.03,  # red
        1597.68,  # green
        1741.10,  # blue
    ]
    STD = [
        2026.96,  # red
        2011.88,  # green
        2146.35,  # blue
    ]


data_dir = Path(DATA_DIR)

dm = ClayDataModuleMulti(data_dir=str(data_dir.absolute()), batch_size=1)
dm.setup(stage="predict")
trn_dl = iter(dm.predict_dataloader())
embeddings = []
for batch in trn_dl:
    with torch.no_grad():
        # Move the batch tensors to the device of the model;
        # only the RGB band group is passed through the model
        batch["pixels"] = batch["pixels"].to(multi_model.device)
        batch["timestep"] = batch["timestep"].to(multi_model.device)
        batch["latlon"] = batch["latlon"].to(multi_model.device)

        # Pass pixels, latlon, timestep through the encoder to create encoded patches
        (
            unmasked_patches,
            unmasked_indices,
            masked_indices,
            masked_matrix,
        ) = multi_model.model.encoder(batch)
        print(unmasked_patches.detach().cpu().numpy())

        embeddings.append(unmasked_patches.detach().cpu().numpy())
print(len(embeddings[0]))  # embeddings is a list
print(embeddings[0].shape)  # with date and lat/lon
print(embeddings[0][:, :-2, :].shape)  # remove date and lat/lon
# remove date and lat/lon and reshape to disaggregated patches
embeddings_patch = embeddings[0][:, :-2, :].reshape([1, 16, 16, 768])
embeddings_patch.shape
# average over the band groups
embeddings_patch_avg_group = embeddings_patch.mean(axis=0)
embeddings_patch_avg_group.shape

Save the patch level embeddings to independent GeoParquet files#

Save the patch level embeddings with the matching geospatial bounds and cloud cover percentages from the chunks we computed earlier. We are correlating patch to chunk bounds based on matching index. This assumes the patches and chunks both define 32x32 subsets with zero overlap.
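
As a quick sanity check before writing the files, the patch grid from the encoder should have the same extent as the grid of SCL chunks computed above. A minimal sketch of that check:

assert embeddings_patch_avg_group.shape[:2] == (
    stack_L2A.sizes["x"] // chunk_size["x"],
    stack_L2A.sizes["y"] // chunk_size["y"],
)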

outdir_embeddings = Path("data/embeddings_cloud")
outdir_embeddings.mkdir(exist_ok=True, parents=True)
# Iterate through each patch
for i in range(embeddings_patch_avg_group.shape[0]):
    for j in range(embeddings_patch_avg_group.shape[1]):
        embeddings_output_patch = embeddings_patch_avg_group[i, j]

        item_ = [
            element for element in list(chunk_bounds.items()) if element[0] == (i, j)
        ]
        box_ = [
            item_[0][1]["lon_start"],
            item_[0][1]["lat_start"],
            item_[0][1]["lon_end"],
            item_[0][1]["lat_end"],
        ]
        cloud_pct_ = [
            element for element in list(cloud_pcts.items()) if element[0] == (i, j)
        ]
        source_url = batch["source_url"]
        date = batch["date"]
        data = {
            "source_url": batch["source_url"][0],
            "date": pd.to_datetime(arg=date, format="%Y-%m-%d").astype(
                dtype="date32[day][pyarrow]"
            ),
            "embeddings": [numpy.ascontiguousarray(embeddings_output_patch)],
            "cloud_cover": cloud_pct_[0][1],
        }

        # Define the bounding box as a Polygon (xmin, ymin, xmax, ymax)
        # The box_ list is encoded as
        # [bottom left x, bottom left y, top right x, top right y]
        box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])

        # Create the GeoDataFrame
        gdf = gpd.GeoDataFrame(data, geometry=[box_emb], crs=f"EPSG:{epsg}")

        # Reproject to WGS84 (lon/lat coordinates)
        gdf = gdf.to_crs(epsg=4326)

        outpath = (
            f"{outdir_embeddings}/"
            f"{batch['source_url'][0].split('/')[-1][:-4]}_{i}_{j}.gpq"
        )
        gdf.to_parquet(path=outpath, compression="ZSTD", schema_version="1.0.0")
        print(
            f"Saved {len(gdf)} rows of embeddings of "
            f"shape {gdf.embeddings.iloc[0].shape} to {outpath}"
        )

Similarity search on the patch embedding level#

We will use reference indices based on cloud cover percentage to define a filtered search.

db = lancedb.connect("embeddings")
# Data for DB table
data = []
# Dataframe to find overlaps within
gdfs = []
for emb in glob.glob(f"{outdir_embeddings}/*.gpq"):
    gdf = gpd.read_parquet(emb)
    gdf["year"] = gdf.date.dt.year
    gdf["tile"] = gdf["source_url"].apply(
        lambda x: Path(x).stem.rsplit("/")[-1].rsplit("_")[0]
    )
    gdf["idx"] = "_".join(emb.split("/")[-1].split("_")[2:]).replace(".gpq", "")
    gdf["box"] = [box(*geom.bounds) for geom in gdf.geometry]
    gdfs.append(gdf)

    for _, row in gdf.iterrows():
        data.append(
            {
                "vector": row["embeddings"],
                "path": row["source_url"],
                "tile": row["tile"],
                "date": row["date"],
                "year": int(row["year"]),
                "cloud_cover": row["cloud_cover"],
                "idx": row["idx"],
                "box": row["box"].bounds,
            }
        )
# Combine patch level geodataframes into one
embeddings_gdf = pd.concat(gdfs, ignore_index=True)
embeddings_gdf.columns

(Optional) Check what an embedding’s RGB subset looks like#

embeddings_gdf_shuffled = embeddings_gdf.sample(frac=1).reset_index(drop=True)

area_of_interest_embedding = embeddings_gdf_shuffled.box.iloc[0]

# Extract coordinate system from first item
epsg = items_L2A[0].properties["proj:epsg"]

# Convert point from lon/lat to UTM projection
box_embedding = gpd.GeoDataFrame(
    crs="OGC:CRS84", geometry=[area_of_interest_embedding]
).to_crs(epsg)
geom_embedding = box_embedding.iloc[0].geometry

# The embedding bounding box is already the correct size:
# each patch covers 32x32 pixels at 10m resolution.

# Retrieve the pixel values, for the bounding box in
# the target projection. In this example we use only
# the RGB group.
stack_embedding = stackstac.stack(
    items_L2A[0],
    bounds=geom_embedding.bounds,
    snap_bounds=False,
    epsg=epsg,
    resolution=10,
    dtype="float32",
    rescale=False,
    fill_value=0,
    assets=BAND_GROUPS_L2A["rgb"],
    resampling=Resampling.nearest,
    xy_coords="center",
)

stack_embedding = stack_embedding.compute()
assert stack_embedding.shape == (1, 3, 32, 32)

stack_embedding.sel(band=["red", "green", "blue"]).plot.imshow(
    row="time", rgb="band", vmin=0, vmax=2000
)

Instantiate a dedicated DB table#

db.drop_table("clay-v001")
db.table_names()
tbl = db.create_table("clay-v001", data=data, mode="overwrite")

Set up filtered searches#

# Function to get the average of some list of reference vectors
def get_average_vector(idxs):
    reformatted_idxs = ["_".join(map(str, idx)) for idx in idxs]
    # Load the table into a dataframe once, rather than once per index
    table_df = tbl.to_pandas()
    matching_rows = [table_df.query(f"idx == '{idx}'") for idx in reformatted_idxs]
    matching_vectors = [row.iloc[0]["vector"] for row in matching_rows]
    vector_mean = numpy.mean(matching_vectors, axis=0)
    return vector_mean

Let’s remind ourselves which patches have clouds in them.

cloud_threshold = 80
# Get indices for patches where cloud percentages
# exceed some interesting threshold
cloudy_indices = []
for key, value in cloud_pcts.items():
    if value > cloud_threshold:
        print(f"Chunk {key}: Cloud percentage = {value}")
        cloudy_indices.append(key)
v_cloudy = get_average_vector(cloudy_indices)
result_cloudy = tbl.search(query=v_cloudy).limit(10).to_pandas()

Now let’s set up a filtered search for patches that have very little cloud coverage.

cloud_threshold = 10
# Get indices for patches where cloud percentages
# do not exceed some interesting threshold
non_cloudy_indices = []
for key, value in cloud_pcts.items():
    if value < cloud_threshold:
        # print(f"Chunk {key}: Cloud percentage = {value}")
        non_cloudy_indices.append(key)
v_non_cloudy = get_average_vector(non_cloudy_indices)
result_non_cloudy = tbl.search(query=v_non_cloudy).limit(10).to_pandas()
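
Because the cloud cover percentages were also written into the table, they can be combined with the vector search as a metadata pre-filter. A minimal sketch (the result_filtered name is illustrative; the exact filter syntax may vary across LanceDB versions):

result_filtered = (
    tbl.search(query=v_non_cloudy).where("cloud_cover < 10").limit(10).to_pandas()
)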

Plot similar patches#

def plot(df, cols=10):
    fig, axs = plt.subplots(1, cols, figsize=(20, 10))

    row_0 = df.iloc[0]
    path = row_0["path"]
    chip = rasterio.open(path)
    tile = row_0["tile"]
    width = chip.width
    height = chip.height
    # Define the window size
    window_size = (32, 32)

    idxs_windows = {"idx": [], "window": []}

    # Iterate over the image in 32x32 windows
    for col in range(0, width, window_size[0]):
        for row in range(0, height, window_size[1]):
            # Define the window
            window = ((row, row + window_size[1]), (col, col + window_size[0]))

            # Read the data within the window
            data = chip.read(window=window)

            # Get the index of the window
            index = (col // window_size[0], row // window_size[1])

            # Process the window data here
            # For example, print the index and the shape of the window data
            # print("Index:", index)
            # print("Window Shape:", data.shape)

            idxs_windows["idx"].append("_".join(map(str, index)))
            idxs_windows["window"].append(data)

    # print(idxs_windows)

    for ax, (_, row) in zip(axs.flatten(), df.iterrows()):
        idx = row["idx"]
        # Find the corresponding window based on the idx
        window_index = idxs_windows["idx"].index(idx)
        window_data = idxs_windows["window"][window_index]
        # print(window_data.shape)
        subset_img = numpy.clip(
            (window_data.transpose(1, 2, 0)[:, :, :3] / 10_000) * 3, 0, 1
        )
        ax.imshow(subset_img)
        ax.set_title(f"{tile}/{idx}/{row.cloud_cover}")
        ax.set_axis_off()
    plt.tight_layout()
    fig.savefig("similar.png")

Result from searching for cloudy samples#

plot(result_cloudy)

Result from searching for non-cloudy samples#

plot(result_non_cloudy)

Visualize the area of interest with the cloudy and non-cloudy patch results#

# Make geodataframe of the search results
# cloudy
result_cloudy_boxes = [
    Polygon(
        [(bbox[0], bbox[1]), (bbox[2], bbox[1]), (bbox[2], bbox[3]), (bbox[0], bbox[3])]
    )
    for bbox in result_cloudy["box"]
]
result_cloudy_gdf = gpd.GeoDataFrame(result_cloudy, geometry=result_cloudy_boxes)
result_cloudy_gdf.crs = "EPSG:4326"
# non-cloudy
result_non_cloudy_boxes = [
    Polygon(
        [(bbox[0], bbox[1]), (bbox[2], bbox[1]), (bbox[2], bbox[3]), (bbox[0], bbox[3])]
    )
    for bbox in result_non_cloudy["box"]
]
result_non_cloudy_gdf = gpd.GeoDataFrame(
    result_non_cloudy, geometry=result_non_cloudy_boxes
)
result_non_cloudy_gdf.crs = "EPSG:4326"

# Plot the AOI in RGB
stack_L2A.sel(band=["red", "green", "blue"]).plot.imshow(
    row="time", rgb="band", vmin=0, vmax=2000
)

# Overlay the bounding boxes of the patches identified from the similarity search
result_cloudy_gdf.to_crs(epsg).plot(ax=plt.gca(), color="red", alpha=0.5)
result_non_cloudy_gdf.to_crs(epsg).plot(ax=plt.gca(), color="blue", alpha=0.5)


# Set plot title and labels
plt.title("Sentinel-2 with cloudy and non-cloudy embeddings")
plt.xlabel("Longitude")
plt.ylabel("Latitude")

# Show the plot
plt.show()