Source code for bacpipe.embedding_evaluation.probing.train_probe

import torch
import torch.nn as nn

from sklearn.neighbors import KNeighborsClassifier
from .dataset_probe import probe_dataset_loader
import bacpipe

import logging

logger = logging.getLogger("bacpipe")


[docs] class LinearProbe(nn.Module):
[docs] def __init__(self, in_dim, out_dim, device='cpu', **kwargs): """ Linear classification layer. Parameters ---------- in_dim : int number of input dimensions (dictated by embeddings) out_dim : int number of output dimensions (dictated by classes in ground truth) """ super(LinearProbe, self).__init__() self.probe = nn.Linear(in_dim, out_dim) self.probe.to(device)
[docs] def forward(self, x): return self.probe(x)
[docs] def train_linear_probe( linear_classifier, train_dataloader, learning_rate, num_epochs, device="cuda:0", **kwargs, ): """ Linear classification training pipeline. Hyperparameters are specified in settings.yaml file and passed to this function. Parameters ---------- linear_classifier : object classification object train_dataloader : DataLoader object dataset loader to iterate over learning_rate : float learning rate num_epochs : int number of epochs for training device : str, optional 'cpu' or 'cuda', by default "cuda:0" Returns ------- object trained linear classification object """ device = torch.device(device) try: linear_classifier = linear_classifier.to(device) except RuntimeError: logger.error('Traceback', exc_info=True) logger.info( "This problem is likely caused by tensorflow hogging all the gpu vram. " "The best fix for this is to simply restart bacpipe with the same settings, " "that way the GPU should be available for pytorch. Alternatively select " "`cpu` for device in the settings.yaml file." ) import sys sys.exit(0) # Define optimizer and loss function optimizer = torch.optim.Adam(linear_classifier.parameters(), lr=learning_rate) criterion = nn.CrossEntropyLoss() # Training loop for epoch in range(num_epochs): linear_classifier.train() logger.info(f"Epoch {epoch+1}/{num_epochs}") running_loss = 0.0 correct_train = 0 total_train = 0 for embeddings, y in train_dataloader: embeddings, y = embeddings.to(device), y.to(device) # Forward pass through linear classifier outputs = linear_classifier(embeddings) # Compute loss loss = criterion(outputs, y) # Backward pass and optimization optimizer.zero_grad() loss.backward() optimizer.step() # Track training loss and accuracy running_loss += loss.item() * embeddings.size(0) _, predicted = torch.max(outputs, 1) total_train += embeddings.size(0) correct_train += (predicted == y).sum().item() train_loss = running_loss / len(train_dataloader.dataset) train_accuracy = 100 * correct_train / total_train # logger.info(f"Epoch {epoch + 1}/{num_epochs}, Loss: {train_loss}, Accuracy: {train_accuracy}") logger.info( f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.2f}%" ) return linear_classifier
[docs] class KNNProbe(nn.Module):
[docs] def __init__(self, n_neighbors=15, testing=False, **kwargs): """ K-nearest neighbor classifier. Parameters ---------- n_neighbors : int, optional hyperparameter specified in settings.yaml file, by default 15 """ super(KNNProbe, self).__init__() self.knn = KNeighborsClassifier(n_neighbors=n_neighbors) self.is_trained = False # Flag to track if KNN is trained
[docs] def fit(self, x, y): """Train KNN classifier with numpy data""" x_np = x.cpu().detach().numpy() # Convert tensor to NumPy y_np = y.cpu().detach().numpy() self.knn.fit(x_np, y_np) self.is_trained = True
[docs] def forward(self, x): """Predict using KNN (only after it's trained)""" if not self.is_trained: error = ("\nKNN model is not trained. Call `fit()` first.") logger.exception(error) raise ValueError(error) x_np = x.cpu().detach().numpy() preds = self.knn.predict(x_np) # Predict labels probs = self.knn.predict_proba(x_np) # Predict probabilities preds_tensor = torch.tensor(preds, dtype=torch.long, device=x.device) probs_tensor = torch.tensor(probs, dtype=torch.float32, device=x.device) return preds_tensor, probs_tensor
[docs] def train_knn_probe(knn_classifier, train_dataloader, device="cpu", **kwargs): """ Pipeline for knn classifier training. Parameters ---------- knn_classifier : object classifier object train_dataloader : DataLoader object iterator for dataset device : str, optional 'cpu' or 'cuda', by default "cpu" Returns ------- object classifier object """ device = torch.device(device) knn_classifier.to(device) all_embeddings = [] all_labels = [] # Collect all embeddings and labels to train KNN for embeddings, y in train_dataloader: embeddings, y = embeddings.to(device), y.to(device) all_embeddings.append(embeddings) all_labels.append(y) all_embeddings = torch.cat(all_embeddings, dim=0) all_labels = torch.cat(all_labels, dim=0) # Train KNN knn_classifier.fit(all_embeddings, all_labels) logger.info("KNN Training Complete!") return knn_classifier
[docs] def train_probe( embeds, df, label2index, config="linear", learning_rate=None, num_epochs=None, n_neighbors=None, **kwargs): """ Classification pipeline. First the classification dataframe is loaded, then a dict is created to link labels to ints, then the dataset loaders are created to iterate over. Next depending of the specified config a linear or KNN classification is performed. Finally the classifiers are used for inference and based on that performance metrics are created. Parameters ---------- paths : SimpleNamespace dict dictionary object containing paths for loading and saving dataset_csv_path : string name of classification dataframe as secified in the settings.yaml file embeds : np.array the embeddings config : str, optional type of classification, by default 'linear' Returns ------- dict performance dictionary """ # generate the loaders train_gen = probe_dataset_loader("train", df, embeds, label2index, **kwargs) embed_size = embeds[0].shape[-1] if config == "linear": if learning_rate is None: learning_rate = bacpipe.settings.probe_configs['config_1']['learning_rate'] if num_epochs is None: num_epochs = bacpipe.settings.probe_configs['config_1']['num_epochs'] probe = LinearProbe( in_dim=embed_size, out_dim=len(df.label.unique()), **kwargs ) probe = train_linear_probe( probe, train_gen, learning_rate=learning_rate, num_epochs=num_epochs, **kwargs ) elif config == "knn": if n_neighbors is None: n_neighbors = bacpipe.settings.probe_configs['config_2']['n_neighbors'] if len(df[df.predefined_set =='test']) < n_neighbors: kwargs['n_neighbors'] = len(df[df.predefined_set =='test']) - 1 probe = KNNProbe(**kwargs) probe = train_knn_probe(probe, train_gen, **kwargs) return probe