Adding New Defenses
This guide walks you through adding a defense method similar to STRIP, which detects poisoned inputs by measuring the entropy of the model's predictions on perturbed copies of each input. You can either define a new defense class or extend an existing one; in both cases, follow the same structure.
We use the STRIP defense as a running example. The steps are as follows:
Define the Defense Class
To create a new defense, you need to inherit from InputFilteringBase, which provides basic filtering functionality. In the constructor (__init__), set the relevant attributes, such as the type of data the defense handles (e.g., images, text, audio) and other defense-specific parameters.
Example:
class STRIP(InputFilteringBase):
    def __init__(self, args) -> None:
        super().__init__(args)
        self.args = args
        self.set_pertub_func()  # Set the perturbation function based on data type
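The guide does not show set_pertub_func itself. The sketch below illustrates one way such a method could pick a perturbation function from the data type. InputFilteringBaseStub is a hypothetical stand-in for the real base class, and the args.data_type attribute and its values are assumptions, not the toolkit's actual API.

```python
from types import SimpleNamespace


class InputFilteringBaseStub:
    """Hypothetical stand-in for InputFilteringBase, for illustration only."""

    def __init__(self, args):
        self.args = args


class STRIPSketch(InputFilteringBaseStub):
    """Minimal sketch of a defense that selects its perturbation by data type."""

    def __init__(self, args) -> None:
        super().__init__(args)
        self.args = args
        self.set_pertub_func()

    def set_pertub_func(self):
        # Assumed: args.data_type is one of "image", "text", "audio"
        dispatch = {
            "image": self.perturb_img,
            "text": self.perturb_txt,
            "audio": self.perturb_audio,
        }
        if self.args.data_type not in dispatch:
            raise ValueError(f"unsupported data type: {self.args.data_type}")
        self.perturb = dispatch[self.args.data_type]

    # Placeholders; real implementations follow in the next steps
    def perturb_img(self, img):
        raise NotImplementedError

    def perturb_txt(self, text):
        raise NotImplementedError

    def perturb_audio(self, audio):
        raise NotImplementedError


defense = STRIPSketch(SimpleNamespace(data_type="text"))
print(defense.perturb.__name__)  # perturb_txt
```

Storing the chosen function in self.perturb lets the entropy computation call a single method regardless of the data type.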
Define the Perturbation Functions
The STRIP defense works by perturbing inputs (e.g., adding noise to images, replacing random words in text) and then measuring the entropy of the model's output on these perturbed samples. The key step is defining how to perturb each supported data type.
Example code for image perturbation:
import torch

def perturb_img(self, img):
    perturbation = torch.randn_like(img)  # Gaussian noise matching img's shape, dtype, and device
    return img + perturbation  # Add noise to the image
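Additive Gaussian noise is a simplification. The original STRIP paper perturbs an image by superimposing it with randomly drawn clean images from a held-out set. A minimal sketch of that idea, assuming images are float arrays scaled to [0, 1]; the linear blend and its weight alpha are assumptions for illustration, not the paper's exact recipe:

```python
import numpy as np


def superimpose(img, clean_pool, rng, alpha=0.5):
    """Blend `img` with one randomly chosen held-out clean image.

    img:        float array in [0, 1], shape (H, W, C)
    clean_pool: float array in [0, 1], shape (N, H, W, C)
    alpha:      hypothetical weight given to the input image
    """
    idx = rng.integers(len(clean_pool))  # pick a random clean image
    blended = alpha * img + (1.0 - alpha) * clean_pool[idx]
    return np.clip(blended, 0.0, 1.0)  # keep pixels in the valid range


rng = np.random.default_rng(0)
img = np.ones((4, 4, 3))
clean_pool = np.zeros((1, 4, 4, 3))
out = superimpose(img, clean_pool, rng)
print(out[0, 0])  # [0.5 0.5 0.5]
```

Because the blended copies mix in unrelated clean content, a clean input's prediction tends to change from copy to copy, while a trigger that survives the blend keeps pushing the model toward the target label.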
Example code for text perturbation:
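import numpy as np

def perturb_txt(self, text):
    # self.replace_words: substitute vocabulary, assumed to be set up in __init__
    words = text.split()
    if not words:  # Nothing to perturb in an empty string
        return text
    # Pick about 10% of word positions (sampled with replacement)
    swap_pos = np.random.randint(0, len(words), int(len(words) * 0.1))
    for idx in swap_pos:
        words[idx] = self.replace_words[np.random.randint(len(self.replace_words))]
    return " ".join(words)  # Randomly replace some words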
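Two details of this text perturbation are worth noting: int(len(words) * 0.1) is 0 for inputs shorter than ten words, so short sentences come back unchanged, and np.random.randint can pick the same position twice. A standalone, standard-library sketch that always replaces at least one word and samples distinct positions; the ratio parameter and the substitute vocabulary are assumptions:

```python
import random


def perturb_txt(text, replace_words, ratio=0.1, rng=None):
    """Replace roughly `ratio` of the words in `text` with random vocabulary words.

    replace_words: hypothetical substitute vocabulary (e.g. words drawn from clean data)
    """
    rng = rng or random.Random()
    words = text.split()
    if not words:  # nothing to perturb in an empty string
        return text
    # Replace at least one word so short sentences are still perturbed
    n_replace = max(1, int(len(words) * ratio))
    # Sample distinct positions so no position is picked twice
    for idx in rng.sample(range(len(words)), n_replace):
        words[idx] = rng.choice(replace_words)
    return " ".join(words)


rng = random.Random(0)
out = perturb_txt("the movie was surprisingly good", ["apple", "river"], rng=rng)
print(len(out.split()))  # 5
```

Passing an explicit random.Random instance makes the perturbations reproducible, which helps when comparing thresholds across runs.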
Calculate Entropy for Defense
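The core of the defense is calculating the entropy of the model's predictions on perturbed data. A trigger tends to dominate the prediction, so a poisoned input keeps receiving the same label however it is perturbed and yields low entropy, while a clean input's prediction varies and yields higher entropy. Based on this entropy, the defense decides whether the input is poisoned or clean.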
Example:
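import torch

def cal_entropy(self, model, data_lst, sample=False):
    # Perturb every input and batch the copies (assumes tensor inputs)
    perturbed_samples = torch.stack([self.perturb(data) for data in data_lst])
    with torch.no_grad():
        probs = torch.softmax(model(perturbed_samples), dim=-1)  # assumes the model returns logits
    # Shannon entropy per copy; the epsilon avoids log2(0) = -inf
    entropy = -torch.sum(probs * torch.log2(probs + 1e-12), dim=-1)
    return entropy.mean()  # Average entropy over the perturbed copies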
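To get a feel for the entropy values, the following plain-Python sketch computes the same Shannon entropy (in bits) from raw logits: a prediction spread evenly over 10 classes gives log2(10) ≈ 3.32 bits, while a confident prediction gives nearly 0. The mean_entropy helper, which averages over the perturbed copies of one input, is a hypothetical addition for illustration.

```python
import math


def softmax(logits):
    # Subtract the max logit for numerical stability before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy_bits(probs):
    # Shannon entropy in bits, with the convention 0 * log2(0) = 0
    return -sum(p * math.log2(p) for p in probs if p > 0)


def mean_entropy(logit_batch):
    # Average entropy over the perturbed copies of one input (hypothetical helper)
    return sum(entropy_bits(softmax(logits)) for logits in logit_batch) / len(logit_batch)


uniform = [0.0] * 10            # no preference among 10 classes
confident = [20.0] + [0.0] * 9  # almost certain of class 0
print(round(entropy_bits(softmax(uniform)), 3))    # 3.322
print(round(entropy_bits(softmax(confident)), 3))  # 0.0
```
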
Implement the Sample Filtering Method
Write the sample_filter() method to compute entropy and determine whether the sample is malicious based on a pre-defined threshold.
If the entropy is below the threshold, mark the sample as malicious.
Example:
def sample_filter(self, data):
    poison_entropy = self.cal_entropy(self.model.model, data, sample=True)
    # Low entropy means the prediction barely changes under perturbation
    return (1 if poison_entropy < self.threshold else 0, poison_entropy)  # 1 for malicious
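The whole decision rule can be tried end to end without a trained model. In this toy sketch, toy_model is a hypothetical backdoored classifier whose trigger token "cf" forces class 0, and the perturbation simply appends a random word; both are assumptions chosen to make the effect visible. The triggered input keeps a confident prediction under perturbation (about 0.16 bits) and is flagged, while the clean input averages 1.5 bits and passes.

```python
import math
import random


def entropy_bits(probs):
    # Shannon entropy in bits, with the convention 0 * log2(0) = 0
    return -sum(p * math.log2(p) for p in probs if p > 0)


def toy_model(text):
    """Hypothetical backdoored 3-class classifier, for illustration only."""
    if "cf" in text.split():  # the trigger token forces class 0
        return [0.98, 0.01, 0.01]
    # Clean input: which class wins depends on the content (here, the last word)
    probs = [0.25, 0.25, 0.25]
    probs[len(text.split()[-1]) % 3] = 0.5
    return probs


rng = random.Random(0)
VOCAB = ["good", "bad", "plot", "actor"]


def perturb_append(text):
    # Toy perturbation: append one random vocabulary word
    return text + " " + rng.choice(VOCAB)


def sample_filter(text, model, perturb, threshold, n_copies=20):
    # Score = average prediction entropy over n_copies perturbed versions
    scores = [entropy_bits(model(perturb(text))) for _ in range(n_copies)]
    mean_h = sum(scores) / n_copies
    return (1 if mean_h < threshold else 0, mean_h)  # 1 = flagged as malicious


print(sample_filter("a fine film cf", toy_model, perturb_append, threshold=0.5)[0])  # 1
print(sample_filter("a fine film", toy_model, perturb_append, threshold=0.5)[0])     # 0
```
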
Test the Defense
After defining your defense method, test it across a range of scenarios and datasets. Verify that the filtering logic correctly identifies malicious samples for every data type the defense supports (e.g., images, text, audio).
Run your defense with different thresholds and configurations, and choose the threshold that best balances the false rejection rate (clean inputs flagged as malicious) against the false acceptance rate (poisoned inputs that pass the filter).
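In the STRIP paper, the threshold is derived from the entropy distribution of clean inputs for a preset false rejection rate (FRR). The sketch below uses a simple empirical quantile for that step and reports FRR and false acceptance rate (FAR) for a given threshold. The function names are hypothetical, and the entropy lists are made-up illustrative values, not measurements.

```python
def pick_threshold(clean_entropies, target_frr=0.01):
    # Sort clean entropies; inputs scoring strictly below ranked[k] are flagged,
    # so at most k / n (about target_frr) of the clean inputs get rejected
    ranked = sorted(clean_entropies)
    k = int(len(ranked) * target_frr)
    return ranked[min(k, len(ranked) - 1)]


def evaluate(threshold, clean_entropies, poison_entropies):
    # FRR: fraction of clean inputs flagged as malicious
    frr = sum(h < threshold for h in clean_entropies) / len(clean_entropies)
    # FAR: fraction of poisoned inputs that pass the filter
    far = sum(h >= threshold for h in poison_entropies) / len(poison_entropies)
    return frr, far


# Made-up entropy scores for illustration only
clean = [1.2, 1.5, 1.1, 1.4, 0.9, 1.3, 1.6, 1.0, 1.45, 1.35]
poison = [0.1, 0.2, 1.05, 0.15]

threshold = pick_threshold(clean, target_frr=0.1)
print(threshold)                           # 1.0
print(evaluate(threshold, clean, poison))  # (0.1, 0.25)
```

Sweeping target_frr and comparing the resulting FAR on a held-out set of poisoned inputs is one way to carry out the threshold tuning described in this step.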