Source code for bacpipe.model_pipelines.feature_extractors.beats

import torch

from bacpipe.model_pipelines.model_specific_utils.naturebeats.BEATs import BEATs, BEATsConfig
from ..model_utils import ModelBaseClass


SAMPLE_RATE = 16_000
LENGTH_IN_SAMPLES = int(5 * SAMPLE_RATE)


BEATS_PRETRAINED_PATH_FT = "beats/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt1.pt"


[docs] class BeatsModel: def __init__(self, checkpoint_path): # load the fine-tuned checkpoints checkpoint = torch.load(checkpoint_path) cfg = BEATsConfig(checkpoint["cfg"]) self.model = BEATs(cfg) self.model.load_state_dict(checkpoint["model"]) self.model.eval() self.avg_pooling = True # disable classifier self.model.predictor = None self.process_audio_beats = self.model.preprocess
[docs] def get_embeddings(self, spectrogram_input): """ Taken from the BEATS forward call. Adapted to work based on the spectrogram input to enable visualization of spectrograms for model result interpretation. Parameters ---------- spectrogram_input : torch.Tensor batched spectrograms from self.model.preprocess Returns ------- torch.Tensor batched embeddings """ spectrogram_input = spectrogram_input.unsqueeze(1) features = self.model.patch_embedding(spectrogram_input) features = features.reshape(features.shape[0], features.shape[1], -1) features = features.transpose(1, 2) features = self.model.layer_norm(features) if self.model.post_extract_proj is not None: features = self.model.post_extract_proj(features) x = self.model.dropout_input(features) x, _ = self.model.encoder( x, padding_mask=None, ) if self.avg_pooling: x = x.mean(dim=1) return x
[docs] class Model(ModelBaseClass): def __init__(self, **kwargs): super().__init__(sr=SAMPLE_RATE, segment_length=LENGTH_IN_SAMPLES, **kwargs) self.model = BeatsModel( checkpoint_path=self.model_base_path / BEATS_PRETRAINED_PATH_FT ) self.model.model.eval() self.model.model.to(self.device)
[docs] def preprocess(self, audio): return self.model.process_audio_beats(audio)
def __call__(self, x): return self.model.get_embeddings(x)