"""insect66 feature extractor (bacpipe.model_pipelines.feature_extractors.insect66)."""

import torch
from types import SimpleNamespace
import yaml
import torch
import timm
import torch.nn as nn
import torchaudio as ta
from ..model_utils import ModelBaseClass

# Sample rate handed to ModelBaseClass as ``sr`` (44.1 kHz).
SAMPLE_RATE = 44100
# Segment length handed to ModelBaseClass: 5.5 seconds' worth of samples.
LENGTH_IN_SAMPLES = int(SAMPLE_RATE * 5.5)


class SpectrogramCNN(nn.Module):
    """Container for a waveform-to-mel-spectrogram front end and a timm CNN.

    The class only builds the submodules. It defines no ``forward``, and
    callers invoke ``wav2timefreq`` and the ``backbone`` layers directly.
    """

    def __init__(self, cfg, init_backbone=True):
        """Build the spectrogram front end and, optionally, the CNN backbone.

        Parameters
        ----------
        cfg : SimpleNamespace
            Configuration namespace. Attributes read here: ``n_classes``,
            ``sample_rate``, ``n_fft``, ``win_length``, ``hop_length``,
            ``fmin``, ``fmax``, ``n_mels``, ``power``, ``mel_normalized``,
            ``top_db``. If ``init_backbone`` is true, it also reads
            ``backbone``, ``pretrained`` and ``in_chans``.
        init_backbone : bool, default True
            Whether to download and initialize the timm backbone. Skipping
            it is mainly useful when debugging the spectrogram front end.
            Without it the module has no ``backbone`` attribute.
        """
        super().__init__()
        self.cfg = cfg
        self.n_classes = cfg.n_classes

        # Waveform -> mel spectrogram.
        self.mel_spec = ta.transforms.MelSpectrogram(
            sample_rate=cfg.sample_rate,
            n_fft=cfg.n_fft,
            win_length=cfg.win_length,
            hop_length=cfg.hop_length,
            f_min=cfg.fmin,
            f_max=cfg.fmax,
            n_mels=cfg.n_mels,
            power=cfg.power,
            normalized=cfg.mel_normalized,
        )
        # Mel spectrogram -> decibels, clipped to ``top_db`` below the peak.
        self.amplitude_to_db = ta.transforms.AmplitudeToDB(top_db=cfg.top_db)
        # Whole front end as one module. The checkpoint loader in this
        # package renames its buffers to the ``wav2timefreq`` prefix.
        self.wav2timefreq = torch.nn.Sequential(self.mel_spec, self.amplitude_to_db)

        if init_backbone:
            # timm adapts the stem to ``in_chans`` input channels and the
            # classifier head to ``num_classes`` outputs.
            self.backbone = timm.create_model(
                cfg.backbone,
                pretrained=cfg.pretrained,
                num_classes=cfg.n_classes,
                in_chans=cfg.in_chans,
            )


class Model(ModelBaseClass):
    """bacpipe feature extractor for the insect66 spectrogram CNN.

    The backbone is a timm model built from ``config_insecteffnet.yaml``.
    The file name suggests an EfficientNet, so confirm against the config.
    Embeddings are taken from the pooled features before the classifier
    head.
    """

    def __init__(self, **kwargs):
        """Load the insect66 config and traced weights into a ``SpectrogramCNN``.

        Reads ``<model_base_path>/insect66/config_insecteffnet.yaml`` and
        ``<model_base_path>/insect66/model_traced.pt``. ``model_base_path``
        is presumably set by ``ModelBaseClass`` from ``kwargs`` (TODO confirm).
        """
        super().__init__(sr=SAMPLE_RATE, segment_length=LENGTH_IN_SAMPLES, **kwargs)
        model_dir = f"{self.model_base_path}/insect66"

        with open(
            f"{model_dir}/config_insecteffnet.yaml", "rt", encoding="utf-8"
        ) as infp:
            cfg = SimpleNamespace(**yaml.safe_load(infp))

        # Load onto the CPU so a checkpoint traced on a GPU can still be read
        # on a CPU-only machine. Only its state dict is used, and
        # load_state_dict copies the values into the freshly built module.
        checkpoint = torch.jit.load(f"{model_dir}/model_traced.pt", map_location="cpu")
        state_dict = checkpoint.state_dict()
        # The traced checkpoint stores the spectrogram buffers under
        # ``wav2img``; SpectrogramCNN names that front end ``wav2timefreq``.
        for old_key in ("wav2img.0.spectrogram.window", "wav2img.0.mel_scale.fb"):
            state_dict[old_key.replace("wav2img", "wav2timefreq")] = state_dict.pop(old_key)

        self.model = SpectrogramCNN(cfg)
        self.model.load_state_dict(state_dict)
        # Inference only: make BatchNorm use its running statistics rather
        # than per-batch ones. Idempotent if the base class also calls eval().
        self.model.eval()

    def preprocess(self, audio):
        """Convert a batch of waveforms to dB-scaled mel spectrograms.

        Parameters
        ----------
        audio : torch.Tensor
            Presumably shaped (batch, samples). TODO: confirm against caller.

        Returns
        -------
        torch.Tensor
            ``wav2timefreq`` applied after inserting a channel axis. For 2-D
            input this is (batch, 1, n_mels, frames).
        """
        # Insert a singleton channel axis: (batch, samples) -> (batch, 1, samples).
        audio = audio[:, None, :]
        return self.model.wav2timefreq(audio)

    @torch.inference_mode()
    def __call__(self, input):  # noqa: A002 - parameter name kept for API compatibility
        """Run the backbone feature layers on spectrograms and pool them.

        Follows conv_stem -> bn1 -> blocks -> conv_head -> bn2 -> global_pool.
        The classifier head is never called. The block output and the pooled
        result are also stored on ``self.model`` as ``block_features`` and
        ``embeddings``.
        """
        backbone = self.model.backbone
        stem_out = backbone.bn1(backbone.conv_stem(input))
        self.model.block_features = backbone.blocks(stem_out)
        head_out = backbone.bn2(backbone.conv_head(self.model.block_features))
        self.model.embeddings = backbone.global_pool(head_out)
        return self.model.embeddings