# Source code for bacpipe.model_pipelines.feature_extractors.bat

import torch
import numpy as np
import librosa

from ..model_specific_utils.bat.module import BAT
from ..model_specific_utils.bat.prepare_data import prepareData, getSequences, slideWindow, germanBats
from ..model_utils import ModelBaseClass

# Audio settings for the BAT model. IS_EXPANDED selects between the plain
# 22050 Hz rate and the ten-times-higher "time expand" rate; the window length
# is scaled the other way, so both settings give 171990 samples per segment.
IS_EXPANDED = False
SAMPLE_RATE = 22050 if IS_EXPANDED else 22050 * 10  # time expand
LENGTH_IN_SAMPLES = (
    int(0.78 * SAMPLE_RATE * 10) if IS_EXPANDED else int(0.78 * SAMPLE_RATE)
)


class Model(ModelBaseClass):
    """Feature extractor built around a pretrained ``BAT`` network.

    The wrapper loads the ``bat_2_convnet_mixed.pth`` checkpoint on the CPU,
    converts raw audio batches into the sequence layout the network is called
    with, and returns the network's token output (optionally together with
    sigmoid class scores for the ``germanBats`` classes).
    """

    def __init__(self, threshold=0.5, **kwargs):
        """Build the BAT network and load its pretrained weights.

        Args:
            threshold: Stored as ``self.threshold``; none of the methods
                defined in this class read it.
            **kwargs: Forwarded unchanged to ``ModelBaseClass.__init__``
                together with the module-level sample rate and segment
                length.
        """
        super().__init__(sr=SAMPLE_RATE, segment_length=LENGTH_IN_SAMPLES, **kwargs)
        self.threshold = threshold
        self.classes = list(germanBats)

        self.model = BAT(
            max_len=60,
            patch_dim=44 * 257,
            d_model=64,
            num_classes=len(self.classes),
            nhead=2,
            dim_feedforward=32,
            num_layers=2,
            seq=False,
        )

        # NOTE(review): model_base_path is presumably provided by
        # ModelBaseClass -- it is not set anywhere in this class.
        state_dict = torch.load(
            self.model_base_path / "bat/bat_2_convnet_mixed.pth",
            map_location="cpu",
            weights_only=True,
        )
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def preprocess(self, audio: torch.Tensor):
        """Turn a batch of waveforms into stacked window sequences.

        Each item of the batch is passed through ``prepareData``, cut into
        windows by ``slideWindow(size=44, step=22)`` with the last window
        dropped, and the resulting 3-D array of shape ``(n, w, h)`` is
        flattened to ``(n * w, h)``.

        Args:
            audio: Batch of waveforms; iterated along its first dimension.

        Returns:
            A float32 tensor stacking the per-item sequences along dim 0.
        """
        # detach() keeps numpy() from raising on tensors that track gradients;
        # it is a no-op for ordinary input tensors.
        batch = audio.detach().to('cpu').numpy()  # b n

        input_seq = []
        for y in batch:
            # Custom filtering + denoising. prepareData works on the raw
            # waveform, so no separate STFT is computed here.
            S_db = prepareData(y)

            # Sequence extraction
            sequence = np.asarray(slideWindow(S_db, size=44, step=22)[:-1])
            n, w, h = sequence.shape
            input_seq.append(
                torch.tensor(sequence, dtype=torch.float32).reshape(n * w, h)
            )
        return torch.stack(input_seq, dim=0)

    def classifier_predictions(self, cls_token):
        """Return sigmoid scores of the model's classifier for ``cls_token``.

        Args:
            cls_token: Token tensor as returned by ``__call__``.

        Returns:
            ``torch.sigmoid`` applied to ``self.model.classifier(cls_token)``,
            computed without gradient tracking.
        """
        with torch.no_grad():
            logits = self.model.classifier(cls_token)
        return torch.sigmoid(logits)

    def __call__(self, x, return_class_results=False):
        """Run the network on preprocessed input.

        Args:
            x: Input passed straight to the BAT network.
            return_class_results: When true, also return the class scores
                from ``classifier_predictions``.

        Returns:
            The token produced by ``self.model(x, return_token=True)``, or a
            ``(token, class_scores)`` tuple when ``return_class_results`` is
            true.
        """
        with torch.no_grad():
            cls_token = self.model(x, return_token=True)
        if not return_class_results:
            return cls_token
        return cls_token, self.classifier_predictions(cls_token)