Finetuning

Fine-tuning is a process in machine learning where a pre-trained model is further trained on a specific dataset to adapt its parameters to a downstream task in a relevant domain. It is distinct from training a model from scratch using only the downstream task dataset.

Related to finetuning in the context of training Foundation models is linear probing, a technique used to analyze the representations a Foundation model learns as it trains. When a large-scale model (such as a vision transformer) is pre-trained on a vast corpus of data, it learns rich and complex representations of patterns in that data. Linear probing examines these learned representations by periodically (e.g. every few epochs of the Foundation model’s training cycle) training a small downstream model on top of the pre-trained model’s frozen layers or embeddings.
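
As a rough illustration (a minimal sketch, not Clay’s implementation), linear probing can be as simple as freezing a pre-trained backbone and training a single small head on its embeddings; the backbone, dimensions, and task below are placeholders.

import torch
import torch.nn as nn

# Placeholder "pre-trained" backbone that maps flattened inputs to 768-dim embeddings.
backbone = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 768))
for p in backbone.parameters():
    p.requires_grad = False              # freeze the pre-trained weights

probe = nn.Linear(768, 2)                # small linear head for a toy 2-class downstream task
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)

x, y = torch.randn(8, 1, 32, 32), torch.randint(0, 2, (8,))
with torch.no_grad():
    emb = backbone(x)                    # embeddings from the frozen model
loss = nn.functional.cross_entropy(probe(emb), y)
loss.backward()
optimizer.step()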

We use full finetuning and linear probing in Clay to evaluate the usefulness of the Foundation model both during its pre-training and afterwards.

Let’s take a look at how we are finetuning on the benchmark datacube-adapted Cloud to Street - Microsoft Flood Dataset. As a reminder, that is a downstream segmentation task for identifying water pixels in recorded flood events. Specifically, it is a binary segmentation problem.

We process the datacubes into batches formatted the way the pretrained Clay model expects, with the addition of the label images. Here’s an example subset of a batch dictionary:

{'labels': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]]),
 'pixels': tensor([[[[-0.5994, -0.6108, -0.6034,  ..., -0.5610, -0.5590, -0.5614],
           [-0.5767, -0.5950, -0.6004,  ..., -0.5619, -0.5536, -0.5610],
           [-0.5841, -0.5762, -0.5930,  ..., -0.5491, -0.5304, -0.5373],
           ...,
           [-0.5087, -0.5447, -0.4351,  ..., -0.6162, -0.6083, -0.6044],
           [-0.4184, -0.5432, -0.5003,  ..., -0.6108, -0.6128, -0.6073],
           [-0.2496, -0.5348, -0.5225,  ..., -0.6137, -0.6167, -0.6128]],

          [[-0.6371, -0.6435, -0.6425,  ..., -0.5834, -0.5898, -0.5923],
           [-0.6296, -0.6410, -0.6385,  ..., -0.5794, -0.5983, -0.5958],
           [-0.6167, -0.6177, -0.6182,  ..., -0.5545, -0.5913, -0.5834],
           ...,
           [-0.4800, -0.5153, -0.4308,  ..., -0.6525, -0.6410, -0.6331],
           [-0.4104, -0.5034, -0.4318,  ..., -0.6331, -0.6226, -0.6087],
           [-0.2404, -0.5222, -0.4522,  ..., -0.6231, -0.6241, -0.6177]],

          [[-0.7068, -0.7217, -0.7101,  ..., -0.6118, -0.6178, -0.6290],
           [-0.7087, -0.7022, -0.6924,  ..., -0.6141, -0.6146, -0.6234],
           [-0.7017, -0.6998, -0.6831,  ..., -0.5927, -0.6085, -0.6104],
           ...,
           [-0.5563, -0.5480, -0.4571,  ..., -0.7106, -0.7045, -0.6933],
           [-0.4725, -0.5526, -0.4781,  ..., -0.6975, -0.6789, -0.6807],
           [-0.3117, -0.4995, -0.5000,  ..., -0.6952, -0.6835, -0.6845]],

          ...,
          ]),
 'bbox': tensor([[ 661415., 5369305.,  666535., 5374425.]], dtype=torch.float64),
 'epsg': tensor([32633], dtype=torch.int32),
 'date': ['2020-10-20'],
 'latlon': tensor([[-0.8192, -0.7854]]),
 'timestep': tensor([[-1.2217,  2.7132, -2.4086]]),
 'source_url': ['S2A_L2A_20201022T100051_N0209_R122_T33UXP_20201022T111023_06144-02560_S1B_IW_GRDH_1SDV_20201020T164222_20201020T164247_023899_02D6C4_rtc']}
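
For orientation, here is a hedged sketch of how such a batch dictionary could be assembled from dummy tensors. The key names mirror the example above, while the chip size (512×512), the 13 image bands, and the location and time encodings are assumptions about the configuration rather than exact Clay specifications.

import torch

batch = {
    "pixels": torch.randn(1, 13, 512, 512),   # normalized image bands (assumed chip size)
    "labels": torch.zeros(1, 1, 512, 512),    # binary water mask
    "bbox": torch.tensor([[661415., 5369305., 666535., 5374425.]], dtype=torch.float64),
    "epsg": torch.tensor([32633], dtype=torch.int32),
    "date": ["2020-10-20"],
    "latlon": torch.randn(1, 2),              # normalized location encoding (assumed)
    "timestep": torch.randn(1, 3),            # normalized acquisition-time encoding (assumed)
    "source_url": ["<granule identifier>"],
}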

Batches of dictionaries like this run through the Clay model’s encoder to generate embeddings, such as:

[figure: example embedding]

from batches with image bands such as:

[figure: example red band image]

and labels:

[figure: example label mask]

These embeddings are reshaped from (batch size, band groups × number of patches, embedding size) to (batch size, band groups × embedding size, patch height, patch width) before being passed to a series of 2D transposed convolution and ReLU layers in a downstream decoder network.
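
A minimal sketch of that reshape, assuming einops and illustrative sizes: 6 band groups × 768-dimensional embeddings give the 4608 input channels of the decoder shown below, while the 16 × 16 patch grid is a placeholder.

import torch
from einops import rearrange

B, G, H, W, E = 2, 6, 16, 16, 768            # batch, band groups, patch grid, embedding size (illustrative)
embeddings = torch.randn(B, G * H * W, E)    # (batch, band groups * patches, embedding size)

# (B, G*H*W, E) -> (B, G*E, H, W): a 2D grid that Conv2d/ConvTranspose2d layers can consume
grid = rearrange(embeddings, "b (g h w) e -> b (g e) h w", g=G, h=H, w=W)
print(grid.shape)                             # torch.Size([2, 4608, 16, 16])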

That decoder network is the core of the downstream task. In a forward pass, it ingests the embeddings, runs them through those layers and computes a loss value with respect to the labels. The loss is back-propagated and the decoder gradually finetunes itself to the downstream dataset. Here’s a peek at the decoder layers:

Model(
  (decoder): Sequential(
    (0): Conv2d(4608, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): Upsample(scale_factor=2.0, mode='nearest')
    (2): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): Upsample(scale_factor=2.0, mode='nearest')
    (5): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Upsample(scale_factor=2.0, mode='nearest')
    (8): ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Upsample(scale_factor=2.0, mode='nearest')
    (11): ConvTranspose2d(8, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): Upsample(scale_factor=2.0, mode='nearest')
  )
)

Note the absence of an encoder. That is deliberate: in this finetuning architecture, the encoder is replaced by the embeddings from the pre-trained Clay model.

In comparison, the network we are using to train the downstream task from scratch looks notably different:

Model(
  (encoder): Sequential(
    (0): Conv2d(13, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): ConvTranspose2d(128, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): Upsample(scale_factor=2.0, mode='nearest')
    (5): Conv2d(512, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

In this architecture, an encoder is defined explicitly because there are no pre-trained embeddings to supply the latent representation; the network has to learn that encoding itself from the raw inputs.

For both the finetuning and “from scratch” architectures, we use a binary_cross_entropy_with_logits loss function, since this is a binary segmentation problem, and we apply sigmoid and max operations to the predictions to obtain the final segmentation results.
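
Here is a rough sketch of that loss-and-prediction step; the tensors are dummies and the 0.5 threshold is an assumption standing in for the final step that turns probabilities into a binary mask.

import torch
import torch.nn.functional as F

logits = torch.randn(1, 1, 512, 512)          # raw decoder output
labels = torch.zeros(1, 1, 512, 512)          # binary water mask

loss = F.binary_cross_entropy_with_logits(logits, labels)

probs = torch.sigmoid(logits)                 # per-pixel water probability
mask = (probs > 0.5).float()                  # assumed threshold for the final segmentation mask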

We measure the relative performance of the finetuned and “from scratch” model variants by computing evaluation metrics common for segmentation, such as the Dice coefficient, Intersection over Union (IoU), F1 score, precision, and recall.
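
A minimal sketch of how these metrics can be computed from a predicted binary mask and the ground truth (own helper code, not Clay’s evaluation module):

import torch

def segmentation_metrics(pred, target, eps=1e-7):
    # pred and target are binary {0, 1} tensors of the same shape
    tp = (pred * target).sum()
    fp = (pred * (1 - target)).sum()
    fn = ((1 - pred) * target).sum()

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    iou = tp / (tp + fp + fn + eps)
    dice = 2 * tp / (2 * tp + fp + fn + eps)   # equals F1 for binary masks
    return {"dice": dice, "iou": iou, "f1": f1, "precision": precision, "recall": recall}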

Linear probing

For linear probing, we implement the finetuned architecture in a PyTorch callback that will execute every n epochs during the Foundation model’s training.
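
A hedged sketch of what such a callback could look like with PyTorch Lightning; the run_probe helper and the probe_every interval are hypothetical placeholders rather than Clay’s actual implementation.

import pytorch_lightning as pl

class LinearProbeCallback(pl.Callback):
    """Run a linear-probe evaluation every `probe_every` epochs (illustrative sketch)."""

    def __init__(self, run_probe, probe_every: int = 5):
        self.run_probe = run_probe        # callable that trains/evaluates the probe head
        self.probe_every = probe_every

    def on_train_epoch_end(self, trainer, pl_module):
        if (trainer.current_epoch + 1) % self.probe_every == 0:
            metrics = self.run_probe(pl_module)   # e.g. fit the decoder on frozen embeddings
            pl_module.log_dict({f"probe/{k}": v for k, v in metrics.items()})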