package ai.djl.modality.nlp.bert;

import ai.djl.modality.nlp.preprocess.SimpleTokenizer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/* loaded from: input_file:lib/api-0.9.0.jar:ai/djl/modality/nlp/bert/BertTokenizer.class */
/**
 * A simple whitespace/punctuation tokenizer that produces BERT-style inputs for a pair of text
 * segments.
 *
 * <p>Tokenization splits the input on runs of whitespace and separates one trailing {@code .},
 * {@code ,}, {@code ?} or {@code !} from each word. Encoding lays the two segments out as {@code
 * [CLS] first [SEP] second [SEP]} and builds the matching token-type and attention-mask lists.
 */
public class BertTokenizer extends SimpleTokenizer {

    /**
     * One word (reluctant {@code \S+?} so a trailing punctuation mark can be split off), an
     * optional single trailing punctuation mark, then whitespace or end of input.
     */
    private static final Pattern PATTERN = Pattern.compile("(\\S+?)([.,?!])?(\\s+|$)");

    /**
     * Splits {@code str} into word tokens, emitting a single trailing {@code .,?!} mark of each
     * word as its own token.
     *
     * @param str the text to tokenize
     * @return a new, mutable list of tokens (empty if {@code str} contains no non-whitespace
     *     characters)
     */
    @Override // ai.djl.modality.nlp.preprocess.SimpleTokenizer, ai.djl.modality.nlp.preprocess.Tokenizer
    public List<String> tokenize(String str) {
        List<String> tokens = new ArrayList<>();
        Matcher matcher = PATTERN.matcher(str);
        while (matcher.find()) {
            tokens.add(matcher.group(1));
            String punctuation = matcher.group(2);
            if (punctuation != null) {
                tokens.add(punctuation);
            }
        }
        return tokens;
    }

    /**
     * Right-pads {@code list} with {@code e} until it contains {@code i} elements.
     *
     * <p>If {@code list} already has at least {@code i} elements, that same instance is returned
     * unchanged. It is not truncated and not copied. Otherwise a new list of size {@code i} is
     * returned, and {@code list} itself is not modified.
     *
     * @param list the list to pad
     * @param e the padding element
     * @param i the target length
     * @param <E> the element type
     * @return {@code list} itself, or a new padded list of size {@code i}
     */
    public <E> List<E> pad(List<E> list, E e, int i) {
        if (list.size() >= i) {
            return list;
        }
        List<E> padded = new ArrayList<>(i);
        padded.addAll(list);
        while (padded.size() < i) {
            padded.add(e);
        }
        return padded;
    }

    /**
     * Encodes two text segments as {@code [CLS] str [SEP] str2 [SEP]}.
     *
     * <p>Token types are {@code 0} for {@code [CLS]}, the first segment and its {@code [SEP]}, and
     * {@code 1} for the second segment and the final {@code [SEP]}. The attention mask is {@code 1}
     * for every token.
     *
     * @param str the first segment
     * @param str2 the second segment
     * @return the encoded tokens; its valid length is the number of tokens produced by {@link
     *     #tokenize(String)} for both segments, excluding the three special tokens
     */
    public BertToken encode(String str, String str2) {
        List<String> first = tokenize(str);
        List<String> second = tokenize(str2);
        // Real (non-special) token count: excludes [CLS] and both [SEP] markers.
        int validLength = first.size() + second.size();

        // Build into a fresh list rather than mutating tokenize()'s results, which an override
        // might return as a fixed-size or immutable list.
        List<String> tokens = new ArrayList<>(validLength + 3);
        tokens.add("[CLS]");
        tokens.addAll(first);
        tokens.add("[SEP]");
        // Every index from here on belongs to the second segment (token type 1).
        int secondSegmentStart = tokens.size();
        tokens.addAll(second);
        tokens.add("[SEP]");

        List<Long> tokenTypes = new ArrayList<>(tokens.size());
        List<Long> attentionMask = new ArrayList<>(tokens.size());
        for (int idx = 0; idx < tokens.size(); idx++) {
            tokenTypes.add(idx < secondSegmentStart ? 0L : 1L);
            attentionMask.add(1L);
        }
        return new BertToken(tokens, tokenTypes, attentionMask, validLength);
    }

    /**
     * Encodes two text segments as in {@link #encode(String, String)}, then right-pads the
     * tokens with {@code [PAD]` and the token types and attention mask with {@code 0} up to
     * {@code i} entries.
     *
     * <p>Sequences already longer than {@code i} are NOT truncated. They are returned at their
     * full length. NOTE(review): callers that need exactly {@code i} entries should check the size.
     *
     * @param str the first segment
     * @param str2 the second segment
     * @param i the target (minimum) sequence length
     * @return the padded encoding, with the same valid length as the unpadded encoding
     */
    public BertToken encode(String str, String str2, int i) {
        BertToken encoded = encode(str, str2);
        return new BertToken(
                pad(encoded.getTokens(), "[PAD]", i),
                pad(encoded.getTokenTypes(), 0L, i),
                pad(encoded.getAttentionMask(), 0L, i),
                encoded.getValidLength());
    }
}
