Source code for bacpipe.model_pipelines.feature_extractors.google_whale

import numpy as np

from .perch_v2 import Model

SAMPLE_RATE = 24_000
LENGTH_IN_SAMPLES = 50_000

[docs] class Model(Model): def __init__(self, **kwargs): super().__init__( sr=SAMPLE_RATE, segment_length=LENGTH_IN_SAMPLES, model_choice="multispecies_whale", **kwargs ) self.abbrev2label = { "Mn": "Humpback", "Oo": "Orca", "Be": "Bryde's", "Ba": "Minke", "Bm": "Blue", "Bp": "Fin", "Eg": "Right (Atlantic)", "Upcall": "Right (Pacific, upcall)", "Gunshot": "Right (Pacific, gunshot)", "Echolocation": "Orca echolocation", "Whistle": "Orca whistle", "Call": "Orca call", } self.class_label_key = "multispecies_whale" self.classes = [ self.abbrev2label[v] for v in self.class_list.classes ] def __call__(self, input, return_class_results=False): # if return_class_results: # embeds, class_preds = [], [] embeds = [] self.logits = [] for frame in input: results = self.model(frame) self.logits.append(list(results.logits.values())) embeds.append(results.embeddings.squeeze()) return np.array(embeds)
[docs] def classifier_predictions(self, embeddings): return np.array(self.logits).squeeze()