bacpipe.embedding_evaluation.probing.inference_probe
Functions
prepare_probe_inference | Load a linear probe that was previously trained and saved. |
run_probe_inference | Apply a previously trained linear probe to data. |
Classes
|
|
Path | PurePath subclass that can make system calls. |
- bacpipe.embedding_evaluation.probing.inference_probe.prepare_probe_inference(model, probe_path='')
Load a linear probe that was previously trained and saved. The probe is instantiated and its saved state_dict is loaded, so that the returned probe is in exactly the same state as at the end of training.
- Parameters:
model (str) – name of the backbone model
probe_path (str, optional) – path to the saved probe; if left empty, the standard bacpipe path is used, by default ''
- Returns:
torch model object – linear probe model
dict – dictionary that associates the columns of the generated predictions array with their corresponding class labels
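A minimal sketch of how the returned dictionary might be used to turn one row of a binary presence array into class labels. The orientation of the dictionary (column index as key, class label as value), the helper name labels_for_row, and the example labels are assumptions made for illustration; check the dictionary returned by prepare_probe_inference before relying on this.

```python
def labels_for_row(binary_row, column_to_label):
    """Return the class labels whose columns are marked present in one row."""
    return [
        column_to_label[col]
        for col, present in enumerate(binary_row)
        if present
    ]


# Hypothetical mapping and prediction row, for illustration only.
# Assumption: keys are column indices and values are class labels.
column_to_label = {0: "blackbird", 1: "robin", 2: "wren"}
print(labels_for_row([1, 0, 1], column_to_label))  # ['blackbird', 'wren']
```

If the dictionary turns out to map labels to column indices instead, inverting it first with {v: k for k, v in mapping.items()} gives the form used here.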
- bacpipe.embedding_evaluation.probing.inference_probe.run_probe_inference(model, linear_probe, threshold, embeds=None, return_binary_presence=True, callbacks=None, device='cpu')
Apply a previously trained linear probe to data. The embeddings must either have been created with the backbone and saved using the bacpipe folder structure, or be passed directly to this function. The function then loads the embeddings and applies the linear probe to classify the data. See the example notebooks for a use case.
- Parameters:
model (str) – name of the backbone model
linear_probe (torch model) – trained linear probe, as a torch model object
threshold (float) – threshold value used to process the predictions
embeds (torch.Tensor, optional) – embeddings to classify, passed directly instead of being loaded from the bacpipe folder structure, by default None
return_binary_presence (bool, optional) – if True, a binary presence array is returned, by default True
callbacks (function, optional) – callback used to increment custom progress bars, by default None
device (str, optional) – device on which the probe is run, by default 'cpu'
- Returns:
generated probe predictions
- Return type:
np.ndarray
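The two functions are meant to be used together: first load the saved probe, then apply it. The wrapper below is a sketch of that flow based only on the signatures documented above. The wrapper name classify_with_saved_probe is hypothetical, and it assumes that prepare_probe_inference returns its two documented values as a pair.

```python
def classify_with_saved_probe(model_name, threshold, embeds=None, device="cpu"):
    """Load a saved linear probe and apply it, following the documented signatures."""
    # Imported lazily so that defining this sketch does not require bacpipe.
    from bacpipe.embedding_evaluation.probing.inference_probe import (
        prepare_probe_inference,
        run_probe_inference,
    )

    # Assumption: the two documented return values are returned as a pair.
    # probe_path is left at its default, so the standard bacpipe path is used.
    linear_probe, column_labels = prepare_probe_inference(model_name)

    # With embeds=None, the embeddings must already be saved in the
    # bacpipe folder structure for this backbone.
    predictions = run_probe_inference(
        model_name,
        linear_probe,
        threshold,
        embeds=embeds,
        return_binary_presence=True,
        device=device,
    )
    return predictions, column_labels
```

A call such as classify_with_saved_probe("birdnet", 0.5) would return the predictions array together with the column-to-label dictionary; the model name and the threshold value here are placeholders, not recommended settings.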