tokenization_gptpangu.py
import os
import torch
import sentencepiece
import jieba
import numpy as np
from transformers.tokenization_utils import PreTrainedTokenizer
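
# Register the special tokens with jieba so segmentation keeps them as single
# tokens instead of splitting them into individual characters/symbols.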
jieba.add_word('<s>')
jieba.add_word('</s>')
jieba.add_word('<eot>')
jieba.add_word('<unk>')
jieba.add_word('<sep>')
jieba.add_word('<pad>')


class GPTPanguTokenizer(PreTrainedTokenizer):
    # Ref: https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha/src/branch/master/tokenization_jieba.py
    vocab_files_names = {
        "model_file": "vocab.model"
    }

    def __init__(
            self,
            model_file,
            **kwargs
    ):
        super().__init__(**kwargs)

        self.sp = sentencepiece.SentencePieceProcessor()
        self.sp.Load(model_file=model_file)
        # " " and "\n" are mapped to the placeholder characters "\u2582" and "\u2583"
        # before sentencepiece encoding (see `tokenize`), because `convert_tokens_to_ids`
        # joins tokens with spaces and `decode` strips all spaces; `decode` maps the
        # placeholders back to real spaces/newlines.
        self.translator = str.maketrans(" \n", "\u2582\u2583")

        self.vocab_file = model_file

        # special token ids
        # self.eos_token_id = self.sp.piece_to_id("<eot>")

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
        and adding special tokens. A GPT-Pangu sequence has the following format:

        - single sequence: `[BOS] X [EOS]` (or `X [EOS]` if no BOS token is configured)
        - pair of sequences: `[BOS] A [SEP] B [EOS]` (or `A [SEP] B [EOS]` if no BOS token is configured)

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        if self.bos_token_id is not None:
            if token_ids_1 is None:
                return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]

            bos = [self.bos_token_id]
            sep = [self.sep_token_id]
            eos = [self.eos_token_id]
            return bos + token_ids_0 + sep + token_ids_1 + eos
        else:
            if token_ids_1 is None:
                return token_ids_0 + [self.eos_token_id]

            sep = [self.sep_token_id]
            eos = [self.eos_token_id]
            return token_ids_0 + sep + token_ids_1 + eos

    def tokenize(self, text, **kwargs):
        """Tokenize a string with jieba, replacing spaces/newlines with their placeholder characters."""
        seg_list = [x.translate(self.translator) for x in jieba.cut(text, cut_all=False)]
        return seg_list

    def convert_tokens_to_ids(self, tokens):
        if tokens is None:
            return None

        if isinstance(tokens, str):
            return self._convert_token_to_id_with_added_voc(tokens)

        # Special tokens are looked up individually; the segments between them are
        # space-joined and encoded with sentencepiece in one pass.
        special_tokens_index = [i for i, token in enumerate(tokens) if token in self.all_special_tokens]

        ids = []
        i = 0
        for j in special_tokens_index:
            new_seg = " ".join(tokens[i:j])
            ids.extend(self.sp.encode(new_seg))
            ids.append(self._convert_token_to_id(tokens[j]))
            i = j + 1

        new_seg = " ".join(tokens[i:])
        ids.extend(self.sp.encode(new_seg))

        return ids

        # new_seg = " ".join(tokens)
        # return self.sp.encode(new_seg)
        # # return tokens

    def _convert_token_to_id(self, token):
        return self.sp.piece_to_id(token)

    def _convert_id_to_token(self, index):
        return self.sp.id_to_piece(index)

    def convert_ids_to_tokens(self, ids):
        return self.decode(ids)

    def decode(self, ids, **kwargs):
        if isinstance(ids, torch.Tensor) or isinstance(ids, np.ndarray):
            ids = ids.tolist()

        if kwargs.get('skip_special_tokens', None) is True:
            ids = [token_id for token_id in ids if token_id not in self.all_special_ids]

        text = self.sp.decode(ids)
        if isinstance(text, list):
            text = text[0]
        # Drop the spaces introduced by space-joining tokens, then restore real spaces
        # and newlines from their placeholder characters.
        text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')#.replace('⁇', self.unk_token)
        return text

    @property
    def vocab_size(self) -> int:
        """
        `int`: Size of the base vocabulary (without the added tokens).
        """
        return len(self.sp)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.
            filename_prefix (`str`, *optional*):
                An optional prefix to add to the names of the saved files.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, self.vocab_files_names["model_file"]
            )
        else:
            vocab_file = save_directory

        # Copy the raw sentencepiece model file byte-for-byte.
        with open(self.vocab_file, 'rb') as fin:
            proto_str = fin.read()

        with open(vocab_file, "wb") as writer:
            writer.write(proto_str)

        return (vocab_file,)
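

# Minimal usage sketch (not part of the original file). It assumes a local
# "vocab.model" sentencepiece file, illustrative special-token strings, and a
# transformers version this tokenizer was written against; the actual checkpoint
# ships its own tokenizer_config.json with the real values.
if __name__ == "__main__":
    tokenizer = GPTPanguTokenizer(
        model_file="vocab.model",  # hypothetical local path to the sentencepiece model
        bos_token="<s>",
        eos_token="<eot>",
        sep_token="<sep>",
        pad_token="<pad>",
        unk_token="<unk>",
    )

    text = "你好 世界"
    tokens = tokenizer.tokenize(text)             # jieba segments; the space becomes "\u2582"
    ids = tokenizer.convert_tokens_to_ids(tokens)
    inputs = tokenizer.build_inputs_with_special_tokens(ids)

    print(tokens)
    print(inputs)
    print(tokenizer.decode(ids))                  # placeholders mapped back to " " / "\n"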