Tutorial BYOL: aprendizaje autosupervisado sobre imágenes CIFAR con código en Pytorch

Después de presentar SimCLR, un contrastivo marco de aprendizaje autosupervisado, decidí demostrar otro método infame, llamado BYOL. Bootstrap Your Own Latent (BYOL), es un nuevo algoritmo para el aprendizaje autosupervisado de representaciones de imágenes. BYOL tiene dos ventajas principales:

  • No utiliza explícitamente muestras negativas. En cambio, minimiza directamente la similitud de las representaciones de la misma imagen bajo una vista aumentada diferente (par positivo). Las muestras negativas son imágenes del lote distintas del par positivo.

  • Como resultado, BYOL Se afirma que requiere tamaños de lote más pequeños, lo que lo convierte en una opción atractiva.

A continuación, puede examinar el método. A diferencia del documento original, llamo al estudiante de la red en línea y al maestro de la red de destino.


descripción general de byol


Descripción general del método BYOL. Fuente: documento BYOL

Red en línea, también conocida como estudiante: en comparación con SimCLR, hay un segundo MLP, llamado vaticinador, lo que hace que todo el método sea asimétrico. ¿Asimétrica en comparación con qué? Bueno, al modelo docente (red objetivo).

¿Por qué es eso importante?

Porque el modelo de profesor está actualizado. solamente a través de la media móvil exponencial (EMA) de los parámetros del alumno. En última instancia, en cada iteración, se pasa al profesor un pequeño porcentaje (menos del 1%) de los parámetros del estudiante. De este modo, los gradientes fluyen solo a través de la red de estudiantes. Esto se puede implementar como:

class EMA():

def __init__(self, alpha):

super().__init__()

self.alpha = alpha

def update_average(self, old, new):

if old is None:

return new

return old * self.alpha + (1 - self.alpha) * new

ema = EMA(0.99)

for student_params, teacher_params in zip(student_model.parameters(),teacher_model.parameters()):

old_weight, up_weight = teacher_params.data, student_params.data

teacher_params.data = ema.update_average(old_weight, up_weight)

Otra diferencia clave entre Simclr y BYOL es la función de pérdida.

Función de pérdida

los predictor MLP es solamente aplicado al estudiante, haciendo que la arquitectura asimétrico. Esta es una elección de diseño clave para evitar el colapso del modo. El colapso del modo aquí sería generar la misma proyección para todas las entradas.


byol-paper-resumen-con-tensores


Descripción general del método BYOL. Fuente: documento BYOL

Finalmente, los autores definieron el siguiente error cuadrático medio entre las predicciones normalizadas L2 y las proyecciones objetivo:

Lθ,ξqˉθ(zθ)zˉξ22=22qθ(zθ),zξqθ(zθ)2zξ2.mathcal_ triangleqleft|bar_left(z_right)-bar_^right|_^=2-2 cdot fracz_^right .

La pérdida L2 se puede implementar de la siguiente manera. La normalización L2 se aplica de antemano.

import torch

import torch.nn.functional as F

def loss_fn(x, y):

x = F.normalize(x, dim=-1, p=2)

y = F.normalize(y, dim=-1, p=2)

return 2 - 2 * (x * y).sum(dim=-1)

El código está disponible en GitHub

Seguimiento de lo que sucede en el preentrenamiento autosupervisado: precisión de KNN

No obstante, la pérdida en el aprendizaje autosupervisado no es una métrica confiable para rastrear. Lo que descubrí que es la mejor manera de rastrear lo que sucede durante el entrenamiento es medir la precisión de ΚΝΝ.

La ventaja crítica de usar KNN es que no tenemos que entrenar un clasificador lineal en la parte superior cada vez, por lo que es más rápido y completamente sin supervisión.

Nota: la medición de KNN solo se aplica a la clasificación de imágenes, pero se hace una idea. Para este propósito, hice una clase para encapsular la lógica de KNN en nuestro contexto:

import numpy as np

import torch

from sklearn.model_selection import cross_val_score

from sklearn.neighbors import KNeighborsClassifier

from torch import nn

class KNN():

def __init__(self, model, k, device):

super(KNN, self).__init__()

self.k = k

self.device = device

self.model = model.to(device)

self.model.eval()

def extract_features(self, loader):

"""

Infer/Extract features from a trained model

Args:

loader: train or test loader

Returns: 3 tensors of all: input_images, features, labels

"""

x_lst = []

features = []

label_lst = []

with torch.no_grad():

for input_tensor, label in loader:

h = self.model(input_tensor.to(self.device))

features.append(h)

x_lst.append(input_tensor)

label_lst.append(label)

x_total = torch.stack(x_lst)

h_total = torch.stack(features)

label_total = torch.stack(label_lst)

return x_total, h_total, label_total

def knn(self, features, labels, k=1):

"""

Evaluating knn accuracy in feature space.

Calculates only top-1 accuracy (returns 0 for top-5)

Args:

features: [... , dataset_size, feat_dim]

labels: [... , dataset_size]

k: nearest neighbours

Returns: train accuracy, or train and test acc

"""

feature_dim = features.shape[-1]

with torch.no_grad():

features_np = features.cpu().view(-1, feature_dim).numpy()

labels_np = labels.cpu().view(-1).numpy()

self.cls = KNeighborsClassifier(k, metric="cosine").fit(features_np, labels_np)

acc = self.eval(features, labels)

return acc

def eval(self, features, labels):

feature_dim = features.shape[-1]

features = features.cpu().view(-1, feature_dim).numpy()

labels = labels.cpu().view(-1).numpy()

acc = 100 * np.mean(cross_val_score(self.cls, features, labels))

return acc

def _find_best_indices(self, h_query, h_ref):

h_query = h_query / h_query.norm(dim=1).view(-1, 1)

h_ref = h_ref / h_ref.norm(dim=1).view(-1, 1)

scores = torch.matmul(h_query, h_ref.t())

score, indices = scores.topk(1, dim=1)

return score, indices

def fit(self, train_loader, test_loader=None):

with torch.no_grad():

x_train, h_train, l_train = self.extract_features(train_loader)

train_acc = self.knn(h_train, l_train, k=self.k)

if test_loader is not None:

x_test, h_test, l_test = self.extract_features(test_loader)

test_acc = self.eval(h_test, l_test)

return train_acc, test_acc

Ahora podemos centrarnos en el método y el modelo BYOL.

Modificar resnet: añadir cabezales de proyección MLP

Comenzaremos con un modelo base (resnet18) y lo modificaremos para el aprendizaje autosupervisado. La última capa que normalmente hace la clasificación se reemplaza con una función de identidad. Las funciones de salida de resnet18 se enviarán al proyector MLP.

import copy

import torch

from torch import nn

import torch.nn.functional as F

class MLP(nn.Module):

def __init__(self, dim, embedding_size=256, hidden_size=2048, batch_norm_mlp=False):

super().__init__()

norm = nn.BatchNorm1d(hidden_size) if batch_norm_mlp else nn.Identity()

self.net = nn.Sequential(

nn.Linear(dim, hidden_size),

norm,

nn.ReLU(inplace=True),

nn.Linear(hidden_size, embedding_size)

)

def forward(self, x):

return self.net(x)

class AddProjHead(nn.Module):

def __init__(self, model, in_features, layer_name, hidden_size=4096,

embedding_size=256, batch_norm_mlp=True):

super(AddProjHead, self).__init__()

self.backbone = model

setattr(self.backbone, layer_name, nn.Identity())

self.backbone.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

self.backbone.maxpool = torch.nn.Identity()

self.projection = MLP(in_features, embedding_size, hidden_size=hidden_size, batch_norm_mlp=batch_norm_mlp)

def forward(self, x, return_embedding=False):

embedding = self.backbone(x)

if return_embedding:

return embedding

return self.projection(embedding)

También reemplacé la primera capa conv de resnet18 de convolución 7×7 a 3×3 ya que estamos jugando con imágenes 32×32 (CIFAR-10).

El código está disponible en GitHub. Si planea solidificar su conocimiento de Pytorch, hay dos libros increíbles que recomendamos encarecidamente: Aprendizaje profundo con PyTorch de Publicaciones Manning y Aprendizaje automático con PyTorch y Scikit-Learn de Sebastián Raschka. Siempre puedes usar el código de descuento del 35% blaissummer21 para todos los productos de Manning.

El método BYOL real

Hasta ahora presenté todos los componentes importantes para llegar a este punto. Ahora construiremos el BYOL módulo con nuestras queridas redes de estudiantes y docentes. Observe que el predictor de estudiantes MLP y el proyector son idénticos.

Mi implementación de BYOL se basó en lucidrains’ repositorio. Lo modifiqué para hacerlo más simple y jugar con él.

class BYOL(nn.Module):

def __init__(

self,

net,

batch_norm_mlp=True,

layer_name='fc',

in_features=512,

projection_size=256,

projection_hidden_size=2048,

moving_average_decay=0.99,

use_momentum=True):

"""

Args:

net: model to be trained

batch_norm_mlp: whether to use batchnorm1d in the mlp predictor and projector

in_features: the number features that are produced by the backbone net i.e. resnet

projection_size: the size of the output vector of the two identical MLPs

projection_hidden_size: the size of the hidden vector of the two identical MLPs

augment_fn2: apply different augmentation the second view

moving_average_decay: t hyperparameter to control the influence in the target network weight update

use_momentum: whether to update the target network

"""

super().__init__()

self.net = net

self.student_model = AddProjHead(model=net, in_features=in_features,

layer_name=layer_name,

embedding_size=projection_size,

hidden_size=projection_hidden_size,

batch_norm_mlp=batch_norm_mlp)

self.use_momentum = use_momentum

self.teacher_model = self._get_teacher()

self.target_ema_updater = EMA(moving_average_decay)

self.student_predictor = MLP(projection_size, projection_size, projection_hidden_size)

@torch.no_grad()

def _get_teacher(self):

return copy.deepcopy(self.student_model)

@torch.no_grad()

def update_moving_average(self):

assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum '

'for the target encoder '

assert self.teacher_model is not None, 'target encoder has not been created yet'

for student_params, teacher_params in zip(self.student_model.parameters(), self.teacher_model.parameters()):

old_weight, up_weight = teacher_params.data, student_params.data

teacher_params.data = self.target_ema_updater.update_average(old_weight, up_weight)

def forward(

self,

image_one, image_two=None,

return_embedding=False):

if return_embedding or (image_two is None):

return self.student_model(image_one, return_embedding=True)

student_proj_one = self.student_model(image_one)

student_proj_two = self.student_model(image_two)

student_pred_one = self.student_predictor(student_proj_one)

student_pred_two = self.student_predictor(student_proj_two)

with torch.no_grad():

teacher_proj_one = self.teacher_model(image_one).detach_()

teacher_proj_two = self.teacher_model(image_two).detach_()

loss_one = loss_fn(student_pred_one, teacher_proj_one)

loss_two = loss_fn(student_pred_two, teacher_proj_two)

return (loss_one + loss_two).mean()

Para CIFAR-10 es suficiente usar 2048 como dimensión oculta y 256 como dimensión incrustada. Entrenaremos un resnet18 que genere 512 funciones para 100 épocas. Las partes del código que se refieren a la carga de datos y los aumentos se omiten para aumentar la legibilidad. Puedes buscarlos en el código.

Puedes usar el optimizador de Adam (

yor=3104lr=3 * 10^

por supuesto) o LARS con

yor=0.1lr=0.1

. Los resultados informados son con Adam, pero también validé que KNN aumenta en las primeras épocas con LARS.

Lo único que se cambiará en el código del tren es la actualización de EMA.

def training_step(model, data):

(view1, view2), _ = data

loss = model(view1.cuda(), view2.cuda())

return loss

def train_one_epoch(model, train_dataloader, optimizer):

model.train()

total_loss = 0.

num_batches = len(train_dataloader)

for data in train_dataloader:

optimizer.zero_grad()

loss = training_step(model, data)

loss.backward()

optimizer.step()

model.update_moving_average()

total_loss += loss.item()

return total_loss/num_batches

¡Vamos a saltar a los resultados!

Resultados: Precisión KNN VS épocas previas al entrenamiento


knn-byol-entrenamiento


Precisión KNN cada 4 épocas. Imagen por autor

¿No es sorprendente que sin ninguna etiqueta podamos alcanzar una precisión de validación del 70 %? Encontré esto sorprendente, especialmente para este método que parece ser menos sensible al tamaño del lote.

Pero, ¿por qué el tamaño del lote tiene un efecto aquí? ¿No se supone que no debe usar parís negativo? ¿De dónde viene la dependencia del tamaño del lote?

Respuesta corta: ¡Bueno, es la normalización por lotes en las capas MLP!

Aquí están los experimentos que hice para verificarlo.

Una nota sobre la norma de lotes en las redes MLP y el impulso de EMA

Tenía curiosidad por observar el colapso del modo sin normalización por lotes. Puede intentarlo usted mismo configurando:

model = BYOL(model, in_features=512, batch_norm_mlp=False)

Observé que la distancia L2 llega a casi cero desde las primeras épocas:

Epoch 0: loss:0.06423207696957084

Epoch 8: loss:0.005584242034894534

Epoch 20: loss:0.005460431350347323

La pérdida llega aproximadamente a cero y KNN deja de aumentar (35 % VS 60 % en la configuración normal). Es por eso que se afirma que BYOL implícitamente utiliza una forma de aprendizaje contrastivo al aprovechar las estadísticas de lotes en los MLP. Aquí está la precisión de KNN:


mode-collapse-byol-no-batch-norma


Colapso de modo en BYOL al eliminar la norma por lotes en MLP. Imagen por autor

Soy muy consciente de los documentos que muestran que las estadísticas de lotes no son la única condición para que BYOL funcione. Esta es una publicación experimental, así que no voy a jugar ese juego. Tenía curiosidad por observar el colapso del modo aquí.

Conclusión

Para obtener una explicación más detallada del método, consulte el video de Yannic sobre BYOL:

En este tutorial, implementamos BYOL paso a paso y nos entrenamos previamente en CIFAR10. Observamos el aumento masivo en la precisión de KNN al hacer coincidir las representaciones de la misma imagen. Un clasificador aleatorio tendría un 10 % y con 100 épocas alcanzamos una precisión de validación KNN del 70 % sin ninguna etiqueta. ¿Cuan genial es eso?

Para obtener más información sobre el aprendizaje autosupervisado, ¡estén atentos! Apóyenos compartiendo en las redes sociales, haciendo una donación o comprando nuestro libro Aprendizaje profundo en producción. Sería muy apreciado.

Libro Aprendizaje Profundo en Producción 📖

Aprenda a crear, entrenar, implementar, escalar y mantener modelos de aprendizaje profundo. Comprenda la infraestructura de ML y MLOps con ejemplos prácticos.

Aprende más

* Divulgación: tenga en cuenta que algunos de los enlaces anteriores pueden ser enlaces de afiliados y, sin costo adicional para usted, ganaremos una comisión si decide realizar una compra después de hacer clic.

Fuente del artículo

¿Que te ha parecido?

Deja un comentario