Source code for bacpipe.model_pipelines.feature_extractors.audioprotopnet

import torch
import os
# Force Hugging Face to use PyTorch and ignore TensorFlow
os.environ["USE_TF"] = "0"
os.environ["TRANSFORMERS_NO_TF"] = "1"


from transformers import (
    AutoFeatureExtractor,
    AutoModel,
    AutoModelForSequenceClassification,
)
import pandas as pd


# Sample rate handed to ModelBaseClass as `sr` in Model.__init__.
SAMPLE_RATE = 32_000
# Segment length, in samples, handed to ModelBaseClass as `segment_length`
# (160_000 / 32_000 = 5, i.e. 5 s of audio assuming SAMPLE_RATE is in Hz).
LENGTH_IN_SAMPLES = 160_000

from ..model_utils import ModelBaseClass


class Model(ModelBaseClass):
    """Wrapper around the AudioProtoPNet-5 BirdSet-XCL checkpoint.

    The pretrained sequence-classification model is split into two parts:
    its backbone (``self.model``) is used to produce embeddings, and its
    head (``self.classifier``) is used to produce per-class scores.
    """

    def __init__(self, **kwargs):
        """Load the pretrained model, its feature extractor and class names.

        Parameters
        ----------
        **kwargs
            Forwarded unchanged to ``ModelBaseClass.__init__`` together with
            ``sr=SAMPLE_RATE`` and ``segment_length=LENGTH_IN_SAMPLES``.
        """
        super().__init__(sr=SAMPLE_RATE, segment_length=LENGTH_IN_SAMPLES, **kwargs)
        self.batch_size = 4

        checkpoint = "DBD-research-group/AudioProtoPNet-5-BirdSet-XCL"

        # NOTE: trust_remote_code=True executes Python code shipped in the
        # Hub repository of this checkpoint.
        model = AutoModelForSequenceClassification.from_pretrained(
            checkpoint,
            trust_remote_code=True,
        )

        # optional: patch missing attribute if other code expects it
        if not hasattr(model, "incorrect_class_connection"):
            model.incorrect_class_connection = None

        self.preprocessor = AutoFeatureExtractor.from_pretrained(
            checkpoint,
            trust_remote_code=True,
        )

        # Keep backbone and head as separate modules so that embeddings and
        # class predictions can be requested independently.
        self.model = model.model.backbone.to(self.device)
        self.classifier = model.head.to(self.device)
        self.model.eval()

        # `model_utils_base_path` is not defined in this class; presumably it
        # is provided by ModelBaseClass.
        ebird2name = pd.read_csv(
            self.model_utils_base_path / "perch_v2/perch_hoplite/eBird2name.csv"
        )
        self.classes = self._resolve_class_names(
            model.config.id2label.values(), ebird2name
        )

    @staticmethod
    def _resolve_class_names(labels, ebird2name):
        """Map eBird species codes to English names with a single table pass.

        Parameters
        ----------
        labels : iterable
            Class labels in model output order.
        ebird2name : pandas.DataFrame
            Table with ``species_code`` and ``English name`` columns.

        Returns
        -------
        list
            For every label, the English name of the first table row whose
            ``species_code`` equals the label, or the label itself when the
            table has no such row.
        """
        # Keep only the first row per code so duplicates resolve exactly as a
        # "first match" selection would, then look labels up in O(1) each
        # instead of re-scanning the whole table for every label.
        first_rows = ebird2name.drop_duplicates(subset="species_code", keep="first")
        code_to_name = dict(
            zip(first_rows["species_code"], first_rows["English name"])
        )
        return [code_to_name.get(label, label) for label in labels]

    def preprocess(self, audio):
        """Run the pretrained feature extractor on ``audio`` and return its output."""
        return self.preprocessor(audio)

    def __call__(self, x):
        """Run the backbone on ``x`` and return its pooled output.

        The full backbone output is cached in ``self.results`` because
        ``classifier_predictions`` needs its ``last_hidden_state``.
        """
        self.results = self.model(x)
        return self.results.pooler_output

    def classifier_predictions(self, embeddings):
        """Return sigmoid class scores for the most recent ``__call__``.

        Parameters
        ----------
        embeddings
            Unused. The head is applied to ``self.results.last_hidden_state``
            cached by the preceding ``__call__``, so ``__call__`` must have
            been run on the corresponding batch first.

        Returns
        -------
        torch.Tensor
            ``torch.sigmoid`` of the head's logits, detached from the
            autograd graph.
        """
        # The head returns a pair; only the first element (logits) is used.
        logits, _ = self.classifier(self.results.last_hidden_state)
        return torch.sigmoid(logits).detach()