Source code for bacpipe.model_pipelines.feature_extractors.mix2

from torchaudio.transforms import MelSpectrogram, AmplitudeToDB
from bacpipe.model_pipelines.model_specific_utils.mix2.mobile_net_v3 import mobilenetv3, MinMaxNorm
import torch

# Sampling rate (Hz) handed to the base class as `sr`.
SAMPLE_RATE = 16_000
# Segment length in samples: three seconds of audio at SAMPLE_RATE.
LENGTH_IN_SAMPLES = 3 * SAMPLE_RATE

from ..model_utils import ModelBaseClass


class Model(ModelBaseClass):
    """Mix2 feature extractor wrapping a MobileNetV3 encoder.

    Audio is turned into a dB-scaled, min-max normalised mel spectrogram by
    `preprocess` and then passed through the encoder by `__call__`.
    """

    def __init__(self, **kwargs):
        """Build the encoder, load its weights and set up the transforms.

        Args:
            **kwargs: Forwarded unchanged to `ModelBaseClass.__init__`
                together with `sr=SAMPLE_RATE` and
                `segment_length=LENGTH_IN_SAMPLES`.
        """
        super().__init__(sr=SAMPLE_RATE, segment_length=LENGTH_IN_SAMPLES, **kwargs)
        self.model = mobilenetv3()

        # NOTE(review): `model_base_path` and `device` are not set in this
        # class; presumably provided by ModelBaseClass -- confirm.
        # Named `checkpoint` (not `dict`) so the builtin is not shadowed.
        checkpoint = torch.load(
            self.model_base_path / "mix2/mix2.pth",
            map_location=self.device,
            weights_only=True,
        )
        # Only the "encoder" entry of the checkpoint is loaded into the model.
        self.model.load_state_dict(checkpoint["encoder"])

        # The sample rate is passed explicitly so the spectrogram stays tied
        # to SAMPLE_RATE instead of relying on the transform's default value.
        self.mel = MelSpectrogram(
            sample_rate=SAMPLE_RATE,
            n_fft=512,
            hop_length=128,
            n_mels=128,
        )
        self.ampl2db = AmplitudeToDB()
        self.min_max_norm = MinMaxNorm()

    def preprocess(self, audio):
        """Convert audio into the normalised mel-spectrogram model input.

        Args:
            audio: Tensor of audio samples.
                NOTE(review): presumably shaped (batch, samples) -- confirm
                against the caller.

        Returns:
            The min-max normalised dB mel spectrogram with a singleton
            dimension inserted at index 1.
        """
        # The transforms created in __init__ are never moved to self.device
        # in this class, so the input is brought to the CPU before use.
        audio = audio.cpu()
        audio = self.mel(audio)
        audio = self.ampl2db(audio)
        audio = self.min_max_norm(audio)
        # Insert a dimension at index 1 (presumably the channel axis expected
        # by the encoder -- TODO confirm).
        return audio.unsqueeze(dim=1)

    def __call__(self, x):
        """Run the encoder on `x` and return its output unchanged."""
        return self.model(x)