bacpipe.model_pipelines.model_utils

Functions

check_if_cudnn_tensorflow_compatible()

Classes

ModelBaseClass(sr, segment_length, model_name)

Path(*args, **kwargs)

PurePath subclass that can make system calls.

class bacpipe.model_pipelines.model_utils.ModelBaseClass(sr, segment_length, model_name, device=None, model_base_path=None, global_batch_size=None, dim_reduction_model=False, **kwargs)

Bases: object

__init__(sr, segment_length, model_name, device=None, model_base_path=None, global_batch_size=None, dim_reduction_model=False, **kwargs)

This base class defines key methods and attributes shared by all feature extractors, so that the same processing pipeline can be used to generate embeddings for every model. The idea is to:

1. initialize the model with prepare_inference, which loads the model and moves it onto the selected device.

2. load the audio and resample it to the sample rate required by the model.

3. window the audio into segments of the input length required by the model.

4. calculate spectrograms (if the model architecture is accessible) to preprocess the audio in batches, which also makes it possible to rebuild the spectrograms afterwards for inspection.

5. initialize a torch DataLoader based on the model-specific audio loading characteristics, to speed up inference and the iteration over the segments.

6. perform batch inference.
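The resampling and windowing steps above can be sketched with NumPy. This is a minimal illustration, not bacpipe's implementation: the helper names resample_linear and window_audio are hypothetical, the resampler is plain linear interpolation (a band-limited resampler such as scipy.signal.resample_poly is preferable in practice), and the segments are assumed to be non-overlapping, with the last one zero-padded so no audio is dropped.

```python
import numpy as np


def resample_linear(audio: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray:
    """Hypothetical helper: resample a mono signal by linear interpolation.

    Illustrative only; it does no anti-aliasing filtering.
    """
    if sr_in == sr_out:
        return audio.astype(np.float32)
    n_out = int(round(len(audio) / sr_in * sr_out))
    # Time stamps (in seconds) of the input and output samples
    t_in = np.arange(len(audio)) / sr_in
    t_out = np.arange(n_out) / sr_out
    return np.interp(t_out, t_in, audio).astype(np.float32)


def window_audio(audio: np.ndarray, segment_length: int) -> np.ndarray:
    """Hypothetical helper: split audio into non-overlapping segments.

    Returns an array of shape (n_segments, segment_length); the last
    segment is zero-padded.
    """
    n_segments = int(np.ceil(len(audio) / segment_length))
    padded = np.zeros(n_segments * segment_length, dtype=audio.dtype)
    padded[: len(audio)] = audio
    return padded.reshape(n_segments, segment_length)


rng = np.random.default_rng(0)
audio = rng.standard_normal(44_100 * 3)            # 3 s at 44.1 kHz
audio_16k = resample_linear(audio, 44_100, 16_000)
segments = window_audio(audio_16k, segment_length=16_000)  # 1 s segments
print(segments.shape)  # (3, 16000)
```

Zero-padding the tail is one design choice; dropping the incomplete last segment is another, and which one a given model uses is an assumption here.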

If ‘cuda’ has been selected as the device, a threading approach is used to load data in parallel while inference is running. The return values are the embeddings.
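Loading data in a background thread while inference runs can be illustrated with a producer/consumer queue from the Python standard library. This sketch assumes the batches come from any Python iterable; prefetch, run_inference and load_batches are hypothetical names, and bacpipe's own threading code may be organized differently.

```python
import queue
import threading
import time
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")
R = TypeVar("R")

_DONE = object()  # sentinel marking the end of the data


class _LoaderError:
    """Wraps an exception raised while loading so the consumer can re-raise it."""

    def __init__(self, exc: Exception) -> None:
        self.exc = exc


def prefetch(batches: Iterable[T], max_prefetch: int = 2) -> Iterator[T]:
    """Yield items from `batches` while a background thread loads ahead."""
    q: "queue.Queue[object]" = queue.Queue(maxsize=max_prefetch)

    def producer() -> None:
        try:
            for batch in batches:
                q.put(batch)  # blocks while `max_prefetch` batches are waiting
        except Exception as exc:
            q.put(_LoaderError(exc))
            return
        q.put(_DONE)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = q.get()
        if item is _DONE:
            return
        if isinstance(item, _LoaderError):
            raise item.exc
        yield item  # type: ignore[misc]


def run_inference(batches: Iterable[T], model: Callable[[T], R]) -> List[R]:
    """Apply `model` to each batch; the next batches load in parallel."""
    return [model(batch) for batch in prefetch(batches)]


def load_batches() -> Iterator[List[int]]:
    for i in range(5):
        time.sleep(0.01)  # simulate slow disk I/O
        yield [i] * 4


embeddings = run_inference(load_batches(), model=sum)
print(embeddings)  # [0, 4, 8, 12, 16]
```

Threads suit this job because loading is mostly I/O-bound, and the bounded queue limits how many batches are held in memory ahead of the model.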

Parameters:
  • sr (int) – sample rate

  • segment_length (int) – segment length in samples

  • device (str) – ‘cpu’ or ‘cuda’

  • model_base_path (pathlib.Path) – path to the main model checkpoint directory

  • global_batch_size (int) – global batch size, used together with the segment length to calculate a model-specific batch size so that batches are approximately equal in size across different models
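The exact formula used to derive the model-specific batch size is not given here. One plausible reading, assuming global_batch_size counts the total number of audio samples per batch, is to divide it by the segment length; model_batch_size is a hypothetical helper, not part of bacpipe.

```python
def model_batch_size(global_batch_size: int, segment_length: int) -> int:
    """Hypothetical: number of segments per batch so that every model's batch
    holds roughly `global_batch_size` audio samples in total."""
    if global_batch_size <= 0 or segment_length <= 0:
        raise ValueError("global_batch_size and segment_length must be positive")
    # Integer division keeps the batch at or below the sample budget,
    # but a batch always contains at least one segment.
    return max(1, global_batch_size // segment_length)


print(model_batch_size(960_000, 16_000))  # 60 segments of 1 s at 16 kHz
print(model_batch_size(960_000, 96_000))  # 10 segments of 3 s at 32 kHz
```

Under this assumption, a model taking 1 s segments at 16 kHz and one taking 3 s segments at 32 kHz both receive batches of 960,000 samples.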

prepare_inference()
preprocessing(audio)
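The base-class idea (one shared pipeline, model-specific hooks) is an instance of the template-method pattern. The sketch below uses a hypothetical ToyModelBase, not bacpipe's ModelBaseClass: it borrows the method names prepare_inference and preprocessing from the listing above, but their behavior here, the embed driver, and the toy MeanStdModel subclass are assumptions for illustration only.

```python
from abc import ABC, abstractmethod

import numpy as np


class ToyModelBase(ABC):
    """Hypothetical stand-in for a shared base class; not bacpipe's API."""

    def __init__(self, sr: int, segment_length: int, batch_size: int = 8):
        self.sr = sr
        self.segment_length = segment_length
        self.batch_size = batch_size
        self.model = None

    @abstractmethod
    def prepare_inference(self) -> None:
        """Load the model (and, in a real extractor, move it to the device)."""

    @abstractmethod
    def preprocessing(self, audio: np.ndarray) -> np.ndarray:
        """Turn raw audio into model input of shape (n_segments, ...)."""

    @abstractmethod
    def __call__(self, batch: np.ndarray) -> np.ndarray:
        """Return one embedding per segment in `batch`."""

    def embed(self, audio: np.ndarray) -> np.ndarray:
        # Shared pipeline: every subclass reuses this loop unchanged
        if self.model is None:
            self.prepare_inference()
        inputs = self.preprocessing(audio)
        if len(inputs) == 0:
            raise ValueError("audio is shorter than one segment")
        outputs = [
            self(inputs[i : i + self.batch_size])
            for i in range(0, len(inputs), self.batch_size)
        ]
        return np.concatenate(outputs, axis=0)


class MeanStdModel(ToyModelBase):
    """Toy 'feature extractor': embeds each segment as (mean, std)."""

    def prepare_inference(self) -> None:
        self.model = lambda x: np.stack([x.mean(axis=1), x.std(axis=1)], axis=1)

    def preprocessing(self, audio: np.ndarray) -> np.ndarray:
        n = len(audio) // self.segment_length  # drop the incomplete tail
        return audio[: n * self.segment_length].reshape(n, self.segment_length)

    def __call__(self, batch: np.ndarray) -> np.ndarray:
        return self.model(batch)


extractor = MeanStdModel(sr=16_000, segment_length=16_000, batch_size=4)
embeddings = extractor.embed(np.random.default_rng(0).standard_normal(16_000 * 10))
print(embeddings.shape)  # (10, 2)
```

Adding a new model then only means implementing the hooks; the batching loop in embed stays the same for every extractor.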
bacpipe.model_pipelines.model_utils.check_if_cudnn_tensorflow_compatible()