Source code for bacpipe.model_pipelines.feature_extractors.aves_especies

from torchaudio.models import wav2vec2_model
import json
import torch
import torch.nn as nn

# extract_features in the torchaudio version returns the outputs of all 12 layers; index -1 selects the final one
from ..model_utils import ModelBaseClass

BATCH_SIZE = 1  # necessary due to padding problem, experiment with this

SAMPLE_RATE = 16000
LENGTH_IN_SAMPLES = 16000

# paper: https://arxiv.org/abs/2210.14493


class Model(ModelBaseClass, nn.Module):
    """Embedding extractor wrapping a torchaudio wav2vec2 model with AVES weights.

    One of three checkpoints is loaded from below ``self.model_base_path``
    (presumably provided by ``ModelBaseClass`` -- confirm there):

    * ``birdaves=True``   -> ``birdaves_especies/birdaves-biox-large``
    * ``nonbioaves=True`` -> ``nonbioaves_especies/aves-base-nonbio``
    * otherwise           -> ``aves_especies/aves-base-bio``

    ``birdaves`` takes precedence when both flags are truthy.

    Paper: https://arxiv.org/abs/2210.14493
    """

    def __init__(self, birdaves=False, nonbioaves=False, **kwargs):
        """Build the wav2vec2 network and load the selected checkpoint.

        Args:
            birdaves: If truthy, load the ``birdaves-biox-large`` checkpoint.
            nonbioaves: If truthy (and ``birdaves`` is falsy), load the
                ``aves-base-nonbio`` checkpoint.
            **kwargs: Forwarded unchanged to ``ModelBaseClass.__init__``.
        """
        super().__init__(sr=SAMPLE_RATE, segment_length=LENGTH_IN_SAMPLES, **kwargs)
        nn.Module.__init__(self)

        # reference: https://pytorch.org/audio/stable/_modules/torchaudio/models/wav2vec2/utils/import_fairseq.html
        base_path = self.model_base_path
        if birdaves:
            model_config_path = f"{base_path}/birdaves_especies/birdaves-biox-large.torchaudio.model_config.json"
            model_path = (
                f"{base_path}/birdaves_especies/birdaves-biox-large.torchaudio.pt"
            )
        elif nonbioaves:
            model_config_path = f"{base_path}/nonbioaves_especies/aves-base-nonbio.torchaudio.model_config.json"
            model_path = (
                f"{base_path}/nonbioaves_especies/aves-base-nonbio.torchaudio.pt"
            )
        else:
            model_config_path = (
                f"{base_path}/aves_especies/aves-base-bio.torchaudio.model_config.json"
            )
            model_path = f"{base_path}/aves_especies/aves-base-bio.torchaudio.pt"

        # Use a context manager so the config file handle is closed
        # deterministically instead of being left to the garbage collector.
        with open(model_config_path, "r") as config_file:
            model_config = json.load(config_file)

        # aux_num_out=None: no auxiliary output layer is requested, the
        # model is only used through extract_features.
        self.model = wav2vec2_model(**model_config, aux_num_out=None)
        self.model.load_state_dict(torch.load(model_path, weights_only=True))
        # Disable gradients for the feature extractor submodule and put the
        # whole network into eval mode.
        self.model.feature_extractor.requires_grad_(False)
        self.model.eval()

    def preprocess(self, audio):
        """Return ``audio`` unchanged; no preprocessing is applied here."""
        return audio

    @torch.inference_mode()
    def __call__(self, input):
        """Compute one embedding per row of ``input``.

        ``input`` is split along its first dimension into chunks of
        ``BATCH_SIZE``; each chunk is passed through
        ``self.model.extract_features``, the final layer's output is
        averaged over dimension 1, and the per-chunk results are
        concatenated along the first dimension.

        Args:
            input: Tensor of audio. Assumed to be shaped (batch, samples)
                -- TODO confirm against the caller.

        Returns:
            Tensor holding the concatenated mean-pooled embeddings.
        """
        embeds = []
        for batch in input.split(BATCH_SIZE):
            # extract_features returns a tuple; element 0 is the list of
            # per-layer outputs.
            layer_outputs = self.model.extract_features(batch)[0]
            # Take the final layer directly; stacking every layer just to
            # index the last one only adds an extra copy.
            final_layer = layer_outputs[-1]
            # mean pooling over dimension 1 (assumed to be time -- TODO confirm)
            embeds.append(final_layer.mean(axis=1))
        return torch.cat(embeds)
if __name__ == "__main__":
    # Smoke test: push one random clip through the model.
    # NOTE(review): "mean" is passed positionally and therefore binds to
    # `birdaves`; being a non-empty string it is truthy, so this selects the
    # birdaves checkpoint. Confirm whether that is intended.
    extractor = Model("mean")
    extractor.eval()
    # Random 16_000-element vector with a leading dimension of size 1 added.
    dummy_batch = torch.rand((16_000)).unsqueeze(0)
    embedding = extractor(dummy_batch)