# Source code for bacpipe.model_pipelines.feature_extractors.audiomae

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------


import torch
import torch.nn as nn
from tqdm import tqdm
import pathlib

from timm.models.layers import trunc_normal_
from timm.models.layers import to_2tuple


import bacpipe.model_pipelines.model_specific_utils.audiomae.models_vit as models_vit
from bacpipe.model_pipelines.model_specific_utils.audiomae.dataset import AudiosetDataset
from ..model_utils import ModelBaseClass

# Batch size; keep it small when running on a laptop CPU.
BATCH_SIZE = 8

# Audio sample rate (Hz) passed to ModelBaseClass and AudiosetDataset.
SAMPLE_RATE = 16000
# Segment length passed to ModelBaseClass: 10 seconds of audio at SAMPLE_RATE.
LENGTH_IN_SAMPLES = SAMPLE_RATE * 10



class PatchEmbed_new(nn.Module):
    """Flexible image-to-patch embedding with a configurable stride.

    A single ``nn.Conv2d`` with ``kernel_size=patch_size`` and ``stride=stride``
    projects the input to ``embed_dim`` channels; ``forward`` then flattens the
    resulting ``(h, w)`` grid into a sequence of patch tokens. When ``stride``
    is smaller than ``patch_size`` the patches overlap.

    Args:
        img_size: Input (height, width); an int is used for both dimensions.
        patch_size: Convolution kernel size; an int is used for both dimensions.
        in_chans: Number of input channels.
        embed_dim: Number of output channels (width of each patch token).
        stride: Convolution stride; an int is used for both dimensions.

    Attributes:
        img_size: ``img_size`` as a 2-tuple.
        patch_size: ``patch_size`` as a 2-tuple.
        proj: The projecting ``nn.Conv2d``.
        patch_hw: ``(h, w)`` size of the patch grid produced for ``img_size``.
        num_patches: ``h * w``.
    """

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)
        self.img_size = img_size
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride
        )
        _, _, h, w = self.get_output_shape(img_size)  # (n, embed_dim, h, w)
        self.patch_hw = (h, w)
        self.num_patches = h * w

    def get_output_shape(self, img_size):
        """Return the shape ``self.proj`` produces for a single ``img_size`` input.

        The shape is computed with the ``nn.Conv2d`` output-size formula
        (using the layer's own kernel size, stride, padding and dilation).
        This avoids a dummy forward pass: the pass did not depend on the
        configured channel count and needlessly consumed the global RNG.

        Args:
            img_size: ``(height, width)`` of the input.

        Returns:
            ``torch.Size([1, embed_dim, h, w])``.

        Raises:
            RuntimeError: If the kernel does not fit inside ``img_size``.
        """
        conv = self.proj
        out_hw = []
        for dim in range(2):
            out = (
                img_size[dim]
                + 2 * conv.padding[dim]
                - conv.dilation[dim] * (conv.kernel_size[dim] - 1)
                - 1
            ) // conv.stride[dim] + 1
            if out < 1:
                # Mirror the error type nn.Conv2d raises for too-small inputs.
                raise RuntimeError(
                    f"Kernel size {tuple(conv.kernel_size)} does not fit input "
                    f"size {tuple(img_size)} (stride {tuple(conv.stride)})"
                )
            out_hw.append(out)
        return torch.Size((1, conv.out_channels, *out_hw))

    def forward(self, x):
        """Embed a batch of images into a sequence of patch tokens.

        Args:
            x: Input of shape ``(B, C, H, W)``. ``H``/``W`` are not checked
                against ``self.img_size``.

        Returns:
            Tensor of shape ``(B, h * w, embed_dim)``.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D (B, C, H, W) input, got shape {tuple(x.shape)}"
            )
        x = self.proj(x)  # (B, embed_dim, h, w)
        return x.flatten(2).transpose(1, 2)
class Model(ModelBaseClass):
    """bacpipe feature extractor for the AudioMAE ``vit_base_patch16`` model.

    Audio is handled at ``SAMPLE_RATE`` in segments of ``LENGTH_IN_SAMPLES``
    samples. ``preprocess`` converts each segment with
    ``AudiosetDataset.process`` (configured via ``audio_conf_val``), and
    ``__call__`` runs the ViT on the result under ``torch.inference_mode``.
    """

    def __init__(self, **kwargs):
        super().__init__(sr=SAMPLE_RATE, segment_length=LENGTH_IN_SAMPLES, **kwargs)
        model_name = "vit_base_patch16"
        self.nb_classes = 527
        self.model_path = self.model_base_path / "audiomae/finetuned.pth"
        self.global_pool = True
        self.drop_path = 0.1
        self.dataset = "audioset"
        self.mask_2d = True
        # Passed to AudiosetDataset below as the "mean" / "std" config entries.
        self.norm_stats = [-4.2677393, 4.5689974]
        target_length = 1024
        self.img_size = (target_length, 128)  # (target_length, num_mel_bins)
        self.in_chans = 1
        self.emb_dim = 768
        # Evaluation config for AudiosetDataset: no masking, mixup or noise.
        self.audio_conf_val = {
            "num_mel_bins": 128,
            "target_length": target_length,
            "freqm": 0,
            "timem": 0,
            "mixup": 0,
            "dataset": self.dataset,
            "mode": "val",
            "mean": self.norm_stats[0],
            "std": self.norm_stats[1],
            "noise": False,
        }

        self.model = models_vit.__dict__[model_name](
            num_classes=self.nb_classes,
            drop_path_rate=self.drop_path,
            global_pool=self.global_pool,
            mask_2d=self.mask_2d,
            use_custom_patch=False,
        )
        # Non-overlapping patches: stride == patch size == 16.
        self.model.patch_embed = PatchEmbed_new(
            img_size=self.img_size,
            patch_size=(16, 16),
            in_chans=self.in_chans,
            embed_dim=self.emb_dim,
            stride=16,
        )
        # (1024 // 16) * (128 // 16) = 512 patches for the default img_size.
        num_patches = self.model.patch_embed.num_patches
        # Fixed (non-trainable) positional embedding, +1 slot for the cls token.
        # The zeros are placeholders: load_state_dict (strict) overwrites them.
        self.model.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, self.emb_dim), requires_grad=False
        )

        checkpoint = self._load_audiomae_checkpoint()
        self.model.load_state_dict(checkpoint["model"])
        # NOTE(review): this re-initialises the classification head with random
        # (unseeded) weights *after* the fine-tuned checkpoint is loaded. That is
        # harmless only if the models_vit forward used by __call__ returns
        # features before `head`; otherwise outputs are neither the fine-tuned
        # logits nor deterministic across runs. Verify against models_vit.
        trunc_normal_(self.model.head.weight, std=2e-5)
        # Inference only: disable training-mode behaviour such as DropPath
        # (drop_path_rate=0.1 above), which torch.inference_mode does not turn
        # off. Idempotent if ModelBaseClass also calls eval().
        self.model.eval()

        self.audio_obj = AudiosetDataset(sr=SAMPLE_RATE, audio_conf=self.audio_conf_val)

    def _load_audiomae_checkpoint(self):
        """Load the checkpoint at ``self.model_path`` onto ``self.device``.

        ``torch.load(..., weights_only=False)`` unpickles arbitrary objects,
        so the checkpoint file must come from a trusted source.

        When ``self.model_path`` is a ``pathlib.WindowsPath``,
        ``pathlib.PosixPath`` is temporarily aliased to ``pathlib.WindowsPath``
        during loading. Presumably this is because the pickle contains
        PosixPath objects that cannot be instantiated on Windows. The alias is
        always restored, even if loading fails.

        Returns:
            The deserialized checkpoint object.
        """
        if not isinstance(self.model_path, pathlib.WindowsPath):
            return torch.load(
                self.model_path, map_location=self.device, weights_only=False
            )

        # Save the original before patching so the finally block can always restore it.
        original_posix_path = pathlib.PosixPath
        pathlib.PosixPath = pathlib.WindowsPath
        try:
            return torch.load(
                self.model_path, map_location=self.device, weights_only=False
            )
        finally:
            pathlib.PosixPath = original_posix_path

    def preprocess(self, audio):
        """Convert a batch of raw audio segments into model inputs.

        Args:
            audio: Iterable of waveform tensors. Each is flattened to shape
                ``(1, -1)`` before processing. Presumably ``(batch,
                LENGTH_IN_SAMPLES)``; TODO confirm against ModelBaseClass.

        Returns:
            The per-segment outputs of ``AudiosetDataset.process`` stacked along
            a new leading dimension, with a singleton channel dimension
            inserted at dim 1.
        """
        processed = torch.stack(
            [self.audio_obj.process(frame.reshape(1, -1)) for frame in audio]
        )
        return processed.unsqueeze(dim=1)

    @torch.inference_mode()
    def __call__(self, input):
        """Run the ViT on preprocessed input without autograd tracking."""
        return self.model(input)