Conversation


ByrdOfAFeather commented Aug 7, 2022

Rather than removing text, which can create oddities, we may want to consider ways to replace tokens that would otherwise be removed. I added support for a custom replacement_fn, similar to the existing classifier_fn. My particular use case was T5, so I also modified the generation of perturbed data to run in batches rather than one sample at a time.

This partially solves #648.
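For context, here is a minimal, library-agnostic sketch of the difference (the sentence and mask below are made up for illustration): dropping a masked word leaves an odd fragment, while a replacement_fn like the T5 example further down can fill the slot with plausible text.

words = ["the", "movie", "was", "good"]
mask = [True, True, False, True]  # False marks a word selected for perturbation

# Removing the masked word can create oddities:
removed = " ".join(w for w, keep in zip(words, mask) if keep)
# -> "the movie good"

# Replacing it with a T5 sentinel keeps a slot that the model can fill in:
masked = " ".join(w if keep else "<extra_id_0>" for w, keep in zip(words, mask))
# -> "the movie <extra_id_0> good"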

Example replacement_fn:

from typing import List

from transformers import T5ForConditionalGeneration, T5Tokenizer


def t5_wrapper(text_as_list: List[str], masks: List[List[bool]]):
    # Replace masked-out words with T5's fill-in suggestions instead of dropping them.
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    out_refs = []      # unmasked reference text for each perturbed sample
    masker_idxs = []   # number of sentinel tokens used per sample
    outs = []          # masked strings to feed to T5
    for mask in masks:
        local_out = ""
        local_out_ref = ""
        local_masker_idx = 0
        for idx in range(len(mask)):
            if mask[idx]:
                # Kept word: copy it through unchanged. Words are concatenated
                # directly, so they are expected to carry their own whitespace.
                local_out += text_as_list[idx]
                local_out_ref += text_as_list[idx]
            else:
                # Masked word: substitute the next T5 sentinel token (<extra_id_n>).
                try:
                    local_out += tokenizer.additional_special_tokens[local_masker_idx]
                    local_masker_idx += 1
                except IndexError:
                    # T5 only provides 100 sentinel tokens; drop any further masked words.
                    continue
        masker_idxs.append(local_masker_idx)
        outs.append(local_out)
        out_refs.append(local_out_ref)

    model.cuda()
    batch_size = 50
    if len(outs) > batch_size:
        # Tokenize everything once, then generate in chunks of batch_size.
        input_ids = tokenizer(outs, return_tensors="pt", padding=True, max_length=512, truncation=True)
        model_suggestions = []
        for idx in range(0, len(input_ids.input_ids), batch_size):
            local_inputs = {key: value[idx: idx + batch_size].cuda() for key, value in input_ids.items()}
            outputs = model.generate(**local_inputs)
            model_suggestions.extend(tokenizer.batch_decode(outputs, skip_special_tokens=False))
    else:
        # Small enough to generate in a single batch.
        input_ids = tokenizer(outs, return_tensors="pt", padding=True, max_length=512, truncation=True)
        for key, value in input_ids.items():
            input_ids[key] = value.cuda()
        outputs = model.generate(**input_ids)
        model_suggestions = tokenizer.batch_decode(outputs, skip_special_tokens=False)

    # Splice T5's generated spans back into the masked strings.
    inversed_data = []
    for idx, suggestion in enumerate(model_suggestions):
        local_out = outs[idx]
        local_masker_idx = masker_idxs[idx]
        # Sentinel tokens that actually appear in the model's output for this sample.
        present_tokens = [tokenizer.additional_special_tokens[i] for i in range(local_masker_idx)
                          if tokenizer.additional_special_tokens[i] in suggestion]
        for token_idx, present in enumerate(present_tokens):
            if token_idx == len(present_tokens) - 1:
                # Last sentinel: take everything generated after it.
                index = suggestion.find(present)
                start_idx = index + len(present)
                local_out = local_out.replace(present, suggestion[start_idx:])
            else:
                # Take the text generated between this sentinel and the next one.
                base_index = suggestion.find(present)
                start_idx = base_index + len(present)
                upper_index = suggestion.find(present_tokens[token_idx + 1])
                local_out = local_out.replace(present, suggestion[start_idx:upper_index])
        # Strip any sentinels that were never filled in, plus pad/eos tokens carried
        # over from decoding with skip_special_tokens=False.
        for item in tokenizer.additional_special_tokens + [tokenizer.pad_token, tokenizer.eos_token]:
            local_out = local_out.replace(item, "")
        inversed_data.append(local_out)
    return inversed_data
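A hypothetical call illustrating the expected contract (the words, masks, and availability of a CUDA device are assumptions for this sketch; as written, t5_wrapper concatenates the words directly, so they should carry their own whitespace):

words = ["The", " movie", " was", " absolutely", " wonderful"]
masks = [
    [True, True, True, False, True],   # one boolean list per perturbed sample
    [True, True, False, False, True],
]
perturbed = t5_wrapper(words, masks)   # requires a CUDA device, as written
# perturbed holds one string per mask, with each masked word replaced by T5's
# suggestion instead of being removed.
print(perturbed)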
