Generating vector embeddings

Once you have a pretrained model, you can pass input images through the encoder of the Vision Transformer to produce vector embeddings that contain a semantic representation of each image.

Producing embeddings from the pretrained model

Step-by-step instructions to create embeddings for a single MGRS tile location (e.g. 27WXN):

  1. Ensure that you can access the 13-band GeoTIFF data files.

    aws s3 ls s3://clay-tiles-02/02/27WXN/

    This should report a list of filepaths if you have the correct permissions. Otherwise, please set up authentication before continuing.

  2. Download the pretrained model weights and put them in the checkpoints/ folder:

    aws s3 cp s3://clay-model-ckpt/v0/clay-small-70MT-1100T-10E.ckpt checkpoints/


    For running model inference on a large scale (hundreds or thousands of MGRS tiles), it is recommended to have a cloud VM instance with:

    1. A high-bandwidth network (>25 Gbps) to speed up data transfer from the S3 bucket to the compute device.

    2. An NVIDIA Ampere-generation GPU (e.g. A10G) or newer, which supports efficient bfloat16 computation.

    For example, an AWS g5.4xlarge instance would be a cost-effective option.

  3. Run model inference to generate the embeddings:

    python predict --ckpt_path=checkpoints/clay-small-70MT-1100T-10E.ckpt \
                   --trainer.precision=bf16-mixed \
                   --data.data_dir=s3://clay-tiles-02/02/27WXN \
                   --data.batch_size=32 \
                   --data.num_workers=16

    This should output a GeoParquet file containing the embeddings for MGRS tile 27WXN (recall that each 10000x10000 pixel MGRS tile contains hundreds of smaller 512x512 chips), saved to the data/embeddings/ folder. See the next subsection for details about the embeddings file.

    The embeddings_level flag determines how the embeddings are calculated. The default, mean, produces one averaged embedding of length 768 per 512x512 chip. If set to patch, the embeddings are kept at the patch level: the array has size 16 * 16 * 768, one 768-length embedding per patch. The third option, group, keeps the full dimensionality of the encoder output, including the band group dimension, giving an array of size 6 * 16 * 16 * 768.

    The embeddings are flattened into one-dimensional arrays because pandas does not support multidimensional arrays as column values. As a result, the flattened arrays must be reshaped to recover the patch-level (or group-level) embeddings.


    For those interested in how the embeddings were computed, the predict step above does the following:

    1. Pass the 13-band GeoTIFF input into the Vision Transformer’s encoder to produce raw embeddings of shape (B, 1538, 768), where B is the batch_size, 1538 is the patch dimension, and 768 is the embedding length. The patch dimension is a concatenation of 1536 patch tokens (6 band groups x 16x16 spatial patches, each covering 32x32 pixels of a 512x512 image) plus 2 tokens (the latlon embedding and the time embedding): 1536 + 2 = 1538.

    2. By default, the mean is taken across the 1536 patch dimension, yielding an output embedding of shape (B, 768). If patch embeddings are requested, the shape is (B, 16 * 16 * 768), one embedding per patch; for group embeddings it is (B, 6 * 16 * 16 * 768).

    More details of how this is implemented can be found by inspecting the predict_step method in the file.
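The reduction described in the steps above can be sketched in NumPy. This is a minimal illustration, not the code in predict_step: it assumes the latlon and time embeddings are the last two tokens on the patch axis, that the 1536 patch tokens are ordered by band group, then patch row, then patch column, and that the patch level averages over the six band groups. The function name reduce_embeddings is hypothetical, and NumPy stands in for the PyTorch tensors the model actually uses.

```python
import numpy as np

N_GROUPS, GRID, DIM = 6, 16, 768          # band groups, patches per side, embedding length
N_PATCH_TOKENS = N_GROUPS * GRID * GRID   # 1536 patch tokens


def reduce_embeddings(raw: np.ndarray, level: str = "mean") -> np.ndarray:
    """Reduce raw encoder output of shape (B, 1538, 768) to one flat row per chip.

    Assumptions (not taken from the Clay source): the latlon and time tokens are
    the last two entries on axis 1, patch tokens are ordered (group, row, column),
    and the "patch" level averages over the band-group axis.
    """
    batch = raw.shape[0]
    if raw.shape[1:] != (N_PATCH_TOKENS + 2, DIM):
        raise ValueError(f"expected (B, {N_PATCH_TOKENS + 2}, {DIM}), got {raw.shape}")
    patches = raw[:, :N_PATCH_TOKENS, :]              # drop the latlon and time tokens
    if level == "mean":
        return patches.mean(axis=1)                    # (B, 768)
    grouped = patches.reshape(batch, N_GROUPS, GRID, GRID, DIM)
    if level == "patch":
        return grouped.mean(axis=1).reshape(batch, -1)  # (B, 16 * 16 * 768)
    if level == "group":
        return grouped.reshape(batch, -1)               # (B, 6 * 16 * 16 * 768)
    raise ValueError(f"unknown embeddings level: {level!r}")


rng = np.random.default_rng(0)
raw = rng.standard_normal((2, 1538, 768), dtype=np.float32)  # random stand-in for encoder output
for level in ("mean", "patch", "group"):
    print(level, reduce_embeddings(raw, level).shape)

# A flattened patch-level row can be reshaped back into its 16 x 16 grid.
patch_grid = reduce_embeddings(raw, "patch").reshape(2, GRID, GRID, DIM)
```

Flattening each chip to one row is what lets the result fit into a single table column; the final line shows that the grid is recovered with a plain reshape.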

Format of the embeddings file

The vector embeddings are stored in a single column within a GeoParquet file (*.gpq), with other columns containing spatiotemporal metadata. This file format is built on top of the popular Apache Parquet columnar storage format designed for fast analytics, and it is highly interoperable across different tools like QGIS, GeoPandas (Python), sfarrow (R), and more.

Filename convention

The embeddings file utilizes the following naming convention:

Example: 27WXN_20200101_20231231_v001.gpq

  • 27WXN - The spatial location of the file’s contents in the Military Grid Reference System (MGRS), given as a 5-character string.

  • 20200101 - The minimum acquisition date of the Sentinel-2 images used to generate the embeddings, given in YYYYMMDD format.

  • 20231231 - The maximum acquisition date of the Sentinel-2 images used to generate the embeddings, given in YYYYMMDD format.

  • v001 - Version of the generated embeddings, given as a 3-digit number.
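When processing many embeddings files, it can help to split a filename into its parts programmatically. The sketch below is a hypothetical helper (parse_embeddings_filename is not part of the Clay codebase) that assumes names follow the example exactly: a 5-character uppercase alphanumeric MGRS code, two YYYYMMDD dates, and a "v"-prefixed 3-digit version.

```python
import re
from datetime import date, datetime
from typing import NamedTuple

# Hypothetical pattern for names like 27WXN_20200101_20231231_v001.gpq
FILENAME_PATTERN = re.compile(
    r"^(?P<mgrs>[0-9A-Z]{5})_(?P<start>\d{8})_(?P<end>\d{8})_v(?P<version>\d{3})\.gpq$"
)


class EmbeddingsFileName(NamedTuple):
    mgrs_tile: str
    min_date: date
    max_date: date
    version: int


def _parse_yyyymmdd(text: str) -> date:
    """Convert a YYYYMMDD string into a datetime.date."""
    return datetime.strptime(text, "%Y%m%d").date()


def parse_embeddings_filename(name: str) -> EmbeddingsFileName:
    """Split an embeddings filename into its MGRS tile, date range, and version."""
    match = FILENAME_PATTERN.match(name)
    if match is None:
        raise ValueError(f"not an embeddings filename: {name!r}")
    return EmbeddingsFileName(
        mgrs_tile=match["mgrs"],
        min_date=_parse_yyyymmdd(match["start"]),
        max_date=_parse_yyyymmdd(match["end"]),
        version=int(match["version"]),
    )


print(parse_embeddings_filename("27WXN_20200101_20231231_v001.gpq"))
```

Returning dates as datetime.date objects (rather than strings) makes it straightforward to filter files by acquisition period.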

Table schema

Each row within the GeoParquet table is generated from a 512x512 pixel image and contains a record of the embeddings, spatiotemporal metadata, and a link to the GeoTIFF file used as the source image for the embedding. The table looks something like this:

source_url    date    embeddings         geometry
…             …       [0.1, 0.4, … ]     …
…             …       [0.2, 0.5, … ]     …
…             …       [0.3, 0.6, … ]     …

Embedding size is 768 by default, 16 * 16 * 768 for patch level embeddings, and 6 * 16 * 16 * 768 for group level embeddings.


Details of each column are as follows:

  • source_url (string) - The full URL to the 13-band GeoTIFF image the embeddings were derived from.

  • date (date32) - Acquisition date of the Sentinel-2 image used to generate the embeddings, in YYYY-MM-DD format.

  • embeddings (FixedShapeTensorArray) - The vector embeddings given as a 1-D tensor or list, with a length of 768 by default (patch- and group-level embeddings are longer, as noted above).

  • geometry (binary) - The spatial bounding box of the 13-band image, provided in WKB Polygon representation.


Additional technical details of the GeoParquet file:

  • GeoParquet specification v1.0.0

  • The coordinate reference system of the geometries is OGC:CRS84.
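To illustrate what the geometry column holds, the sketch below round-trips a chip footprint through WKB using shapely (an assumption on our part: shapely is not mentioned above, although GeoPandas depends on it). The coordinates are hypothetical and written in (longitude, latitude) order, which is the axis order of OGC:CRS84; EPSG:4326, by contrast, officially lists latitude first.

```python
from shapely import wkb
from shapely.geometry import box

# Hypothetical chip footprint in OGC:CRS84, i.e. (longitude, latitude) order.
min_lon, min_lat, max_lon, max_lat = -20.5, 64.0, -20.3, 64.1
footprint = box(min_lon, min_lat, max_lon, max_lat)

blob = footprint.wkb        # WKB bytes, the kind of value stored in the geometry column
restored = wkb.loads(blob)  # decode the WKB back into a shapely Polygon

print(restored.geom_type, restored.bounds)
```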

Reading the embeddings

Sample code to read the GeoParquet embeddings file using geopandas.read_parquet:

import geopandas as gpd

gpq_file = "data/embeddings/27WXN_20200101_20231231_v001.gpq"
geodataframe = gpd.read_parquet(path=gpq_file)
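Once loaded, the embeddings column holds one 1-D array per row. For analysis it is often convenient to stack these into a single 2-D NumPy array. The sketch below uses a hypothetical helper, embeddings_matrix, and a toy pandas DataFrame in place of the real file so that it runs on its own; because a GeoDataFrame is a pandas DataFrame subclass, the same call should also work on geodataframe.

```python
import numpy as np
import pandas as pd


def embeddings_matrix(frame: pd.DataFrame, column: str = "embeddings") -> np.ndarray:
    """Stack the per-row 1-D embeddings into one (n_rows, length) float32 array."""
    return np.stack([np.asarray(value, dtype=np.float32) for value in frame[column]])


# Hypothetical stand-in for the GeoDataFrame read above: 3 rows of 768-length embeddings.
toy = pd.DataFrame({"embeddings": [np.full(768, i, dtype=np.float32) for i in range(3)]})
matrix = embeddings_matrix(toy)
print(matrix.shape)  # (3, 768)
```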

Converting to patch-level embeddings

In the case where patch-level embeddings are requested, the resulting array will have all patch embeddings ravelled in one row. Each row represents a 512x512 pixel image, and contains 16x16 patch embeddings.

To convert a row into patch-level embeddings, reshape its flattened embedding array into a 16 x 16 grid of 768-length embeddings (256 patches in total):

import numpy as np

# This assumes embeddings_level was set to "patch"
ravelled_patch_embeddings = np.asarray(geodataframe.embeddings.iloc[0])
patch_embeddings = ravelled_patch_embeddings.reshape(16, 16, 768)
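The same reshape can be applied to every row at once, and each patch can be mapped back to the pixels it covers in its 512x512 chip (16 x 16 patches of 32 x 32 pixels each). The sketch below uses hypothetical helpers and a toy DataFrame in place of geodataframe. It assumes patch (0, 0) is the top-left patch and that patches are stored row by row, the C-order layout implied by reshape(16, 16, 768).

```python
import numpy as np
import pandas as pd

GRID, PATCH_PX, DIM = 16, 32, 768  # 16x16 patches of 32x32 pixels per 512x512 chip


def patch_embeddings_array(frame: pd.DataFrame) -> np.ndarray:
    """Reshape every row's flattened patch embeddings into (n_rows, 16, 16, 768)."""
    flat = np.stack([np.asarray(value) for value in frame["embeddings"]])
    return flat.reshape(len(frame), GRID, GRID, DIM)


def patch_pixel_window(row: int, col: int) -> tuple[slice, slice]:
    """Pixel rows and columns of the 512x512 chip covered by patch (row, col).

    Assumes patch (0, 0) is the top-left corner and patches are row-major.
    """
    return (slice(row * PATCH_PX, (row + 1) * PATCH_PX),
            slice(col * PATCH_PX, (col + 1) * PATCH_PX))


# Hypothetical stand-in with two rows of patch-level embeddings.
toy = pd.DataFrame({"embeddings": [np.arange(GRID * GRID * DIM, dtype=np.float32)] * 2})
patches = patch_embeddings_array(toy)
print(patches.shape)              # (2, 16, 16, 768)
print(patch_pixel_window(15, 0))  # (slice(480, 512, None), slice(0, 32, None))
```

The returned slices can index a (height, width) image array directly, e.g. image[patch_pixel_window(3, 7)], to pull out the pixels behind a given patch embedding.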