Source code for bacpipe.model_pipelines.feature_extractors.perch_v2

from bacpipe.model_pipelines.model_specific_utils.perch_v2.perch_hoplite.zoo.model_configs import load_model_by_name
import tensorflow as tf
import numpy as np
import pandas as pd
import logging
logger = logging.getLogger("bacpipe")

tf.keras.backend.clear_session()

from ..model_utils import ModelBaseClass

SAMPLE_RATE = 32000
LENGTH_IN_SAMPLES = 160000


[docs] class Model(ModelBaseClass): def __init__( self, model_choice="perch_v2_cpu", sr=SAMPLE_RATE, segment_length=LENGTH_IN_SAMPLES, **kwargs ): super().__init__(sr=sr, segment_length=segment_length, **kwargs) if model_choice == 'vggish': self.bool_classifier = False if self.device == 'cuda' and model_choice.startswith('perch_v2'): if len(tf.config.list_physical_devices("GPU")) > 0: model_choice = 'perch_v2' mod = load_model_by_name(model_choice) self.model = mod.embed if not hasattr(self, 'class_label_key'): self.class_label_key = "labels" if model_choice in ['vggish']: return elif not model_choice in ["multispecies_whale"]: self.class_list = mod.class_list[self.class_label_key].classes self.ebird2name = pd.read_csv( self.model_utils_base_path / "perch_v2/perch_hoplite/eBird2name.csv" ) self.classes = self.class_list self.classes = [ ( self.ebird2name["English name"][ self.ebird2name.species_code == cls ].iloc[0] if cls in self.ebird2name.species_code.values else cls ) for cls in self.classes ] else: self.class_list = mod.class_list if model_choice.startswith('perch_v2'): self.class_label_key = 'label'
[docs] def preprocess(self, audio): audio = audio.cpu() return tf.convert_to_tensor(audio, dtype=tf.float32)
def __call__(self, input): try: self.results = self.model(input) except Exception as e: logger.exception( "You are on a operating system that does not currently support this model. " "Perch V2 is currently only supported on linux or other machines supporting " "XLA deserialization. ", e ) import sys sys.exit(0) return self.results.embeddings
[docs] def classifier_predictions(self, embeddings): inferece_results = self.results.logits[self.class_label_key] return tf.nn.sigmoid(inferece_results).numpy()