Source code for bacpipe.model_pipelines.feature_extractors.batdetect2_clip_avg

from functools import partial

import numpy as np
import torch

from ..model_utils import ModelBaseClass

# --- Audio input -----------------------------------------------------------
# Sample rate (presumably in Hz) handed to the base class as `sr` and to
# batdetect2's spectrogram generation as `samp_rate`.
SAMPLE_RATE = 256_000
# Default for `Model(segment_duration=...)`; it is multiplied by SAMPLE_RATE
# to obtain the segment length in samples (so presumably seconds).
DEFAULT_SEGMENT_DURATION = 1
# NOTE(review): not referenced by the code in this module; numerically equal
# to DEFAULT_SEGMENT_DURATION * SAMPLE_RATE -- confirm whether still needed.
LENGTH_IN_SAMPLES = 256_000

# --- Model output ----------------------------------------------------------
# NOTE(review): neither value is referenced by the code in this module.
# Presumably the size of the embedding and the number of classifier outputs;
# verify against the batdetect2 model before relying on them.
NUM_FEATURES = 32
NUM_CLASSES = 17


class Model(ModelBaseClass):
    """BatDetect2 wrapper that yields one averaged embedding per segment.

    Audio segments are converted to spectrograms with batdetect2's
    ``api.generate_spectrogram`` and passed through the model returned by
    ``api.load_model``. The model's feature map is averaged over its last
    two dimensions to give the embedding; class scores are taken as the
    maximum of the model's class predictions over the same two dimensions.

    Attributes set in ``__init__``:
        config: Configuration returned by ``api.get_config()``.
        model: First element of the pair returned by ``api.load_model``.
        generate_spectrogram: ``api.generate_spectrogram`` with ``config``,
            ``samp_rate`` and ``device`` pre-bound.
        classes: ``config["class_names"]``.
    """

    def __init__(
        self,
        segment_duration=DEFAULT_SEGMENT_DURATION,
        **kwargs,
    ):
        """Load the batdetect2 model and bind the spectrogram function.

        Args:
            segment_duration: Segment duration; multiplied by
                ``SAMPLE_RATE`` and truncated to ``int`` to give the
                ``segment_length`` passed to the base class.
            **kwargs: Forwarded unchanged to ``ModelBaseClass.__init__``.
        """
        super().__init__(
            sr=SAMPLE_RATE,
            segment_length=int(segment_duration * SAMPLE_RATE),
            **kwargs,
        )

        # Imported on instantiation, so batdetect2 is only needed once a
        # Model is actually created.
        from batdetect2 import api

        self.config = api.get_config()
        # NOTE(review): `self.device` is not assigned in this class;
        # presumably the base-class initializer sets it -- confirm.
        # Only the first element of the returned pair is kept.
        self.model, _ = api.load_model(device=self.device)  # type: ignore
        self.generate_spectrogram = partial(
            api.generate_spectrogram,
            config=self.config,
            samp_rate=SAMPLE_RATE,
            device=self.device,
        )
        self.classes = self.config["class_names"]

    def preprocess(self, audio):
        """Convert a batch of audio segments into a stacked spectrogram tensor.

        Args:
            audio: ``torch.Tensor`` on any device. It is iterated along its
                first dimension (presumably ``(batch, samples)`` -- TODO
                confirm) and each item is processed on its own.

        Returns:
            The per-segment spectrograms stacked along a new leading
            dimension. If the stacked tensor has more than four dimensions,
            dimension 1 is squeezed.
        """
        # Always go through detach().cpu(): `.numpy()` only works on host
        # tensors that do not require grad. `.cpu()` returns the tensor
        # itself when it already lives in host memory, so this also covers
        # accelerators other than CUDA (e.g. MPS) without a device check.
        segments = audio.detach().cpu().numpy()

        # NOTE: Need to pre-process each segment separately
        spectrograms = torch.stack(
            [self.generate_spectrogram(segment) for segment in segments]
        )

        if spectrograms.ndim > 4:
            spectrograms = spectrograms.squeeze(1)

        return spectrograms

    @torch.no_grad()
    def __call__(self, x):
        """Run the model and return the feature map averaged per item.

        The raw model output is stored on ``self.output`` so that
        ``classifier_predictions`` can read it afterwards.

        Args:
            x: Model input, e.g. the tensor returned by ``preprocess``.

        Returns:
            ``self.output.features`` averaged over its last two dimensions.
        """
        self.output = self.model(x)
        return self.output.features.mean(dim=(-2, -1))

    def classifier_predictions(self, embeddings):
        """Return class scores for the batch most recently passed to ``__call__``.

        Args:
            embeddings: Unused. The scores are read from ``self.output``,
                so ``__call__`` must have been run first; otherwise
                ``self.output`` does not exist and an ``AttributeError``
                is raised.

        Returns:
            The maximum of ``self.output.pred_class`` over its last two
            dimensions, with the last entry along dimension 1 removed.
        """
        peak_scores = self.output.pred_class.amax(dim=(-2, -1))
        # NOTE: Last element is the background class
        return peak_scores[:, :-1]