bacpipe.embedding_evaluation.probing.inference_probe

Functions

prepare_probe_inference(model[, probe_path])

Load a linear probe that was previously trained and saved.

run_probe_inference(model, linear_probe, ...)

Apply a previously trained linear probe to data.

Classes

LinearProbe(in_dim, out_dim[, device])

Path(*args, **kwargs)

PurePath subclass that can make system calls.

bacpipe.embedding_evaluation.probing.inference_probe.prepare_probe_inference(model, probe_path='')

Load a previously trained and saved linear probe. The probe is instantiated and its saved state_dict is loaded, so the probe is ready for inference in exactly the state it was in after training.

Parameters:
  • model (str) – name of the backbone model

  • probe_path (str, optional) – path to the saved probe; if empty, the standard bacpipe path is used, by default ''

Returns:

  • torch model object – linear probe model

  • dict – dictionary associating each column of the generated predictions array with its class label
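A minimal sketch of how the returned dictionary might be used to label the columns of a predictions array. It assumes the dictionary maps class labels to column indices; the direction of the mapping is not stated above, so check it against the object actually returned. The helper `label_columns`, the labels, and the predictions array are hypothetical stand-ins, not part of bacpipe.

```python
import numpy as np


def label_columns(predictions: np.ndarray, label2index: dict) -> dict:
    """Return {class_label: column of predictions} (hypothetical helper).

    Assumes label2index maps each class label to its column in `predictions`.
    """
    if len(label2index) != predictions.shape[1]:
        raise ValueError("number of labels does not match number of prediction columns")
    return {label: predictions[:, idx] for label, idx in label2index.items()}


# Hypothetical stand-in data: 4 embeddings scored against 3 classes.
predictions = np.array([
    [0.9, 0.1, 0.0],
    [0.2, 0.7, 0.1],
    [0.1, 0.1, 0.8],
    [0.6, 0.3, 0.1],
])
label2index = {"bird": 0, "frog": 1, "insect": 2}  # assumed label -> column mapping

per_class = label_columns(predictions, label2index)

# Invert the mapping to turn the highest-scoring column of each row into a label.
index2label = {idx: label for label, idx in label2index.items()}
top_labels = [index2label[i] for i in predictions.argmax(axis=1)]
print(top_labels)  # ['bird', 'frog', 'insect', 'bird']
```

If the real dictionary turns out to map column indices to labels instead, `index2label` can be used directly and the inversion step dropped.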

bacpipe.embedding_evaluation.probing.inference_probe.run_probe_inference(model, linear_probe, threshold, embeds=None, return_binary_presence=True, callbacks=None, device='cpu')

Apply a previously trained linear probe to data. The embeddings must either already have been created with the backbone model and saved in the bacpipe folder structure, or be passed directly to this function via embeds. The function loads the saved embeddings (unless they are passed in) and applies the linear probe to classify the data. See the example notebooks for a use case.

Parameters:
  • model (str) – name of the backbone model

  • linear_probe (torch model) – linear probe model object, e.g. as returned by prepare_probe_inference

  • threshold (float) – threshold applied to the probe predictions, e.g. to decide presence when return_binary_presence is True

  • embeds (torch.Tensor, optional) – embeddings to classify; if None, the embeddings saved in the bacpipe folder structure are loaded, by default None

  • return_binary_presence (bool, optional) – if True, a binary presence array is returned, by default True

  • callbacks (function, optional) – callback used to increment custom progress bars, by default None

  • device (str, optional) – device on which to run the probe, by default 'cpu'
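To illustrate how a progress callback can be driven while a linear probe scores embeddings, here is a minimal NumPy sketch. It is not bacpipe's implementation: the helper `apply_linear_probe`, its `batch_size` argument, and the callback signature (called with the number of rows processed in each batch) are all assumptions. The `(out_dim, in_dim)` weight layout mirrors the usual layout of a linear layer such as the one implied by `LinearProbe(in_dim, out_dim)`.

```python
from typing import Callable, Optional

import numpy as np


def apply_linear_probe(
    embeds: np.ndarray,
    weight: np.ndarray,
    bias: np.ndarray,
    batch_size: int = 2,
    callback: Optional[Callable[[int], None]] = None,
) -> np.ndarray:
    """Score embeddings with a linear layer in batches (hypothetical helper).

    weight has shape (out_dim, in_dim) and bias has shape (out_dim,).
    If given, `callback` is called with the number of rows processed after
    each batch, e.g. to advance a progress bar (assumed signature).
    """
    outputs = []
    for start in range(0, len(embeds), batch_size):
        batch = embeds[start:start + batch_size]
        outputs.append(batch @ weight.T + bias)  # raw per-class scores
        if callback is not None:
            callback(len(batch))
    return np.concatenate(outputs, axis=0)


# Hypothetical usage: 5 embeddings with in_dim = 2, probe with out_dim = 3.
progress = []
embeds = np.arange(10, dtype=float).reshape(5, 2)
weight = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
bias = np.zeros(3)
scores = apply_linear_probe(embeds, weight, bias, batch_size=2, callback=progress.append)
print(scores.shape, progress)  # (5, 3) [2, 2, 1]
```

Passing `progress.append` as the callback stands in for a progress bar's update method; any callable accepting one integer would work with this sketch.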

Returns:

Generated probe predictions (a binary presence array if return_binary_presence is True)

Return type:

np.ndarray
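A minimal sketch of one common way a threshold turns raw probe outputs into a binary presence array. The text above does not say whether bacpipe applies a sigmoid, a softmax, or thresholds the raw outputs directly; the per-class sigmoid used here (a multi-label setup) is an assumption, and `to_binary_presence` is a hypothetical helper.

```python
import numpy as np


def to_binary_presence(logits: np.ndarray, threshold: float) -> np.ndarray:
    """Convert raw probe scores to a 0/1 presence array (hypothetical helper).

    Assumes a multi-label setup: each class score is squashed independently
    with a sigmoid and compared against `threshold`.
    """
    probs = 1.0 / (1.0 + np.exp(-logits))
    return (probs >= threshold).astype(np.int8)


# Hypothetical raw scores for 2 embeddings and 3 classes.
logits = np.array([[2.0, -1.0, -0.5], [-3.0, 0.5, 4.0]])
presence = to_binary_presence(logits, threshold=0.5)
# presence == [[1, 0, 0], [0, 1, 1]]
```

Raising the threshold makes the presence decision stricter: with `threshold=0.9`, only the score of 4.0 (sigmoid of about 0.98) would count as present.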