-
Notifications
You must be signed in to change notification settings - Fork 12
/
SFT_stage1.py
293 lines (229 loc) · 12.6 KB
/
SFT_stage1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
import os
import json
import torch
import logging
import argparse
from tqdm import tqdm
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
import wandb
from accelerate import Accelerator
from transformers import set_seed, get_cosine_schedule_with_warmup
import shutil
import json
import traceback
from jinja2 import Template
from transformers import AutoModelForCausalLM, AutoTokenizer
# Clear the process umask so that no permission bits are masked off the
# files and directories this script creates.
os.umask(0)
# Module-level logger; the root logger is configured at INFO level.
logger = logging.getLogger(__name__)
logging.basicConfig(level='INFO')
class Train_dataset(torch.utils.data.Dataset):
    """SFT dataset of question / reasoning / answer records loaded from JSON.

    Every record is expected to provide the keys ``Question``, ``Complex_CoT``
    and ``Response``. ``__getitem__`` returns the raw record; rendering with
    the chat template, tokenization, label masking and padding are all done
    in ``collate_fn``, which is meant to be passed to the ``DataLoader``.
    """

    def __init__(self, config, tokenizer):
        """Load the records and prepare the chat template.

        Args:
            config: Object exposing ``data_path`` (JSON file to load) and
                ``max_seq_len`` (maximum number of tokens kept per sample).
            tokenizer: Tokenizer providing ``chat_template``, ``bos_token``,
                ``eos_token_id``, ``encode`` and ``decode``. Its
                ``chat_template`` is set in place when it is empty.
        """
        self.config = config
        self.tokenizer = tokenizer

        # JSON text is UTF-8; decode it explicitly instead of relying on the
        # platform's locale default, which breaks on non-ASCII content.
        with open(config.data_path, encoding='utf-8') as f:
            loaded = json.load(f)

        # No record is filtered out at this stage; keep a shallow copy and
        # report the before/after counts exactly as before.
        kept = list(loaded)
        print('过滤掉',len(loaded),len(kept))
        self.data = kept

        self.max_seq_len = self.config.max_seq_len
        # Number of batches already echoed to stdout by collate_fn.
        self.debug = 0

        # When training from a base LLM (tokenizer without a chat template),
        # fall back to the llama3-instruct template.
        chat_template_llama3 = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
        if not tokenizer.chat_template:
            tokenizer.chat_template = chat_template_llama3
        self.template = Template(tokenizer.chat_template)

    def __getitem__(self, index):
        """Return the raw record stored at ``index``."""
        return self.data[index]

    def get_response(self, da):
        """Format the assistant turn: reasoning section, then final answer."""
        temp = '## Thinking\n\n{}\n\n## Final Response\n\n{}'
        return temp.format(da['Complex_CoT'], da['Response'])

    def get_prompt(self, da):
        """Tokenize one record into ``input_ids`` and loss ``labels``.

        The full conversation (user + assistant) and the prompt alone (user
        turn plus generation prompt) are rendered and tokenized separately;
        the first ``len(prompt tokens)`` label positions are set to -100 so
        only the assistant turn contributes to the loss.

        Returns:
            dict with ``input_ids`` and ``labels`` (lists of ints of equal
            length), each truncated to its last ``max_seq_len`` entries.
        """
        q = da['Question']
        a = self.get_response(da)
        assert q is not None and a is not None, f'q:{q} a:{a}'

        bos = self.tokenizer.bos_token
        full_text = self.template.render(messages=[{"role": "user", "content": q},{"role": "assistant", "content": a}],bos_token=bos,add_generation_prompt=False)
        input_ids = self.tokenizer.encode(full_text, add_special_tokens=False)

        query = self.template.render(messages=[{"role": "user", "content": q}],bos_token=bos,add_generation_prompt=True)
        query_ids = self.tokenizer.encode(query, add_special_tokens=False)

        # Mask the prompt part; supervise everything after it.
        labels = [-100]*len(query_ids) + input_ids[len(query_ids):]
        assert len(labels) == len(input_ids)
        # Over-long samples are cut from the left (the tail is kept).
        return {"input_ids": input_ids[-self.max_seq_len:], "labels": labels[-self.max_seq_len:]}

    def collate_fn(self, batch):
        """Turn a list of raw records into right-padded LongTensors.

        ``input_ids`` are padded with ``eos_token_id`` and ``labels`` with
        -100, up to the longest sample in the batch (capped at
        ``max_seq_len``). The last sample of each of the first three batches
        is decoded and printed for inspection.

        Returns:
            dict with ``input_ids`` and ``labels`` tensors.
        """
        data = [self.get_prompt(da) for da in batch]
        input_ids = [item["input_ids"] for item in data]
        labels = [item["labels"] for item in data]

        max_len = min(max(len(x) for x in input_ids), self.max_seq_len)
        pad_id = self.tokenizer.eos_token_id
        input_ids = [item[:max_len] + [pad_id]*(max_len-len(item)) for item in input_ids]
        labels = [item[:max_len] + [-100]*(max_len-len(item)) for item in labels]

        if self.debug < 3:
            print('input_ids',self.tokenizer.decode(input_ids[-1]))
            # -100 is not a valid token id; show masked positions as id 0.
            print('labels',self.tokenizer.decode([0 if x == -100 else x for x in labels[-1]]))
            self.debug += 1

        return {
            "input_ids": torch.LongTensor(input_ids),
            "labels": torch.LongTensor(labels),
        }

    def __len__(self):
        """Return the number of records."""
        return len(self.data)
class SFTMetric:
    """Accumulates next-token accuracy and loss and reduces them across ranks.

    Counters are one-element tensors living on ``device`` so they can be
    summed over all processes with ``all_reduce`` in ``get_metric``.
    """

    @staticmethod
    def _new_counter(device):
        # One-element tensor holding 0, placed on the requested device.
        return torch.Tensor([0]).to(device=device)

    def __init__(self, device):
        """Create zeroed counters on ``device`` and cache the world size."""
        self.n_step = 0
        self.right = self._new_counter(device)
        self.total = self._new_counter(device)
        self.total_loss = self._new_counter(device)
        self.world_size = dist.get_world_size()

    def __call__(self, logits, labels, loss):
        """Alias for ``update``."""
        return self.update(logits, labels, loss)

    def update(self, logits, labels, loss):
        """Add one step's predictions and loss to the running counters.

        Predictions at position ``i`` are compared with the label at position
        ``i + 1``; label positions equal to -100 are ignored.
        """
        self.n_step += 1
        with torch.no_grad():
            predicted = logits[..., :-1, :].argmax(dim=-1)
            targets = labels[..., 1:]
            counted = targets != -100
            hits = (predicted == targets) & counted
            self.right += hits.sum().item()
            self.total += counted.sum().item()
            self.total_loss += loss.item()

    def get_metric(self, reset=True):
        """Sum the counters over all ranks and return ``(acc, loss)``.

        ``acc`` is correct / counted tokens; ``loss`` is the summed loss
        divided by ``world_size * n_step``. When ``reset`` is true the local
        counters are zeroed afterwards.
        """
        for counter in (self.right, self.total, self.total_loss):
            dist.all_reduce(counter, op=dist.ReduceOp.SUM)

        acc = (self.right / self.total).item()
        loss = self.total_loss.item() / (self.world_size * self.n_step)

        if reset:
            self.n_step = 0
            for counter in (self.right, self.total, self.total_loss):
                counter.fill_(0)
        return acc, loss
def train(args):
    """Run the supervised fine-tuning loop.

    Builds the tokenizer, model, AdamW optimizer, dataloader and cosine LR
    schedule from ``args``, trains for ``args.n_epochs`` epochs in bf16 mixed
    precision, logs to wandb (offline mode) from the main process and writes
    checkpoints below ``args.output_dir``.

    Args:
        args: Parsed command-line namespace. Fields read in this function:
            ``gradient_accumulation_steps``, ``experiment_name``, ``log_dir``,
            ``train_bsz_per_gpu``, ``model_path``, ``weight_decay``,
            ``learning_rate``, ``n_epochs``, ``warmup_rates``, ``data_path``,
            ``output_dir`` and ``max_ckpts``. The namespace is also handed to
            ``Train_dataset`` as its config.
    """
    accelerator = Accelerator(mixed_precision='bf16', gradient_accumulation_steps=args.gradient_accumulation_steps)

    if accelerator.is_main_process:
        wandb.init(project = args.experiment_name, config=args, dir=args.log_dir, mode="offline")

    accelerator.print(f'args:\n{args}')

    # Overwrite the DeepSpeed batch-size entries with the values derived from
    # the CLI arguments.
    # NOTE(review): assumes a DeepSpeed plugin is configured for this run and
    # that the process group is already initialized -- confirm.
    accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_bsz_per_gpu
    accelerator.state.deepspeed_plugin.deepspeed_config['train_batch_size'] = args.train_bsz_per_gpu*dist.get_world_size()*accelerator.gradient_accumulation_steps

    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True)

    # Gradient checkpointing is enabled unconditionally; the
    # --gradient_checkpointing flag is not consulted in this function.
    model.gradient_checkpointing_enable()

    # Parameters whose name contains one of these substrings get no weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    train_dataset = Train_dataset(args, tokenizer)
    train_dataloader = DataLoader(train_dataset, batch_size=args.train_bsz_per_gpu, shuffle=True, drop_last=True, collate_fn=train_dataset.collate_fn)

    # The dataloader length is taken before accelerator.prepare(), hence the
    # explicit division by the world size and the accumulation steps.
    num_training_steps = int(len(train_dataloader) * (args.n_epochs)) // accelerator.gradient_accumulation_steps // dist.get_world_size()
    lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_rates * num_training_steps), num_training_steps=num_training_steps)
    accelerator.print(f'gradient_accumulation_steps:{accelerator.gradient_accumulation_steps} data_path:{args.data_path} lr:{args.learning_rate} num_training_steps:{num_training_steps}')

    # The LR scheduler is not passed to prepare(); it is stepped by hand in
    # the training loop below.
    model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

    # Resume bookkeeping; all three are always 0 here, so the skip branch in
    # the batch loop never triggers.
    start_epoch = 0
    start_step = 0
    global_step = 0

    metric = SFTMetric(device=torch.cuda.current_device())

    def save_checkpoint(epoch, step, global_step):
        """Write ``checkpoint-{epoch}-{global_step}`` under ``args.output_dir``.

        The main process removes the oldest ``checkpoint-*`` directory (by
        ctime) once ``args.max_ckpts`` of them exist, saves the tokenizer,
        and copies the remaining plain files of ``args.model_path`` that are
        neither weights nor index files. Model weights go to the ``tfmr``
        sub-directory, through a different code path for ZeRO stage 3.
        """
        save_dir = os.path.join(args.output_dir, f"checkpoint-{epoch}-{global_step}")
        if accelerator.is_main_process:
            checkpoint_files = os.listdir(args.output_dir)
            checkpoint_files = [file for file in checkpoint_files if file.startswith("checkpoint-")]
            num_checkpoints = len(checkpoint_files)
            # max_ckpts <= 0 disables rotation; otherwise drop one (the
            # oldest) checkpoint before writing the new one.
            if args.max_ckpts>0:
                if num_checkpoints >= args.max_ckpts:
                    checkpoint_files.sort(key=lambda x: os.path.getctime(os.path.join(args.output_dir, x)))
                    oldest_checkpoint = checkpoint_files[0]
                    shutil.rmtree(os.path.join(args.output_dir, oldest_checkpoint))
            os.makedirs(save_dir, exist_ok=True)
            output_dir = os.path.join(save_dir, 'tfmr')
            # Below ZeRO-3 the main process saves the weights on its own.
            if accelerator.state.deepspeed_plugin.zero_stage!=3:
                model.save_pretrained(output_dir,state_dict=accelerator.get_state_dict(model))
            tokenizer.save_pretrained(output_dir)
            # Copy the other files of the source model directory, skipping
            # anything already written, weight shards and index files.
            copy_files = []
            for item in os.listdir(args.model_path):
                if os.path.exists(os.path.join(output_dir,item)):
                    continue
                if item.startswith("pytorch_model") and item.endswith(".bin"):
                    continue
                if item.endswith(".index.json") or item.endswith(".safetensors"):
                    continue
                s = os.path.join(args.model_path, item)
                if os.path.isfile(s):
                    shutil.copy(s, os.path.join(output_dir,item))
                    copy_files.append(item)
            print(f'huggingface model save in {output_dir}, copy file:{copy_files}')

        # With ZeRO-3 this runs on every process (not only the main one);
        # is_main_process is forwarded to save_pretrained.
        if accelerator.state.deepspeed_plugin.zero_stage==3:
            unwrap_model = accelerator.unwrap_model(model)
            unwrap_model.save_pretrained(os.path.join(save_dir, f'tfmr'),is_main_process=accelerator.is_main_process,save_function=accelerator.save,state_dict=accelerator.get_state_dict(model))

        accelerator.wait_for_everyone()
        accelerator.save({"epoch": epoch, "step": step, "global_step": global_step}, os.path.join(save_dir, "training_state.pt"))
        accelerator.print(f'checkpoint checkpoint-{epoch}-{global_step} is saved...')

    accelerator.print(accelerator.deepspeed_config)
    model.train()

    for epoch in range(start_epoch, args.n_epochs):
        # Only the main process shows a progress bar.
        train_dataloader_iterator = tqdm(enumerate(train_dataloader), total=len(train_dataloader)) if accelerator.is_main_process else enumerate(train_dataloader)
        for batch_cnt, batch in train_dataloader_iterator:
            if epoch==start_epoch and batch_cnt<start_step:
                continue

            # Release cached CUDA memory once, after the very first batch.
            if batch_cnt == 1 and epoch == 0:
                torch.cuda.empty_cache()

            input_ids=batch['input_ids']
            labels=batch['labels']

            output = model(input_ids=input_ids, labels=labels, return_dict=True,use_cache=False)
            loss = output.loss

            # get_metric() resets its counters by default, so acc/train_loss
            # describe this step only, reduced over all ranks.
            metric(output.logits, labels, loss)
            acc, train_loss = metric.get_metric()
            accelerator.backward(loss)
            # Step optimizer and scheduler once every
            # gradient_accumulation_steps micro-batches.
            if (global_step+1) % accelerator.gradient_accumulation_steps == 0:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            global_step += 1

            if accelerator.is_main_process:
                train_dataloader_iterator.set_postfix(epoch=epoch, current_step=batch_cnt, total_step=len(train_dataloader), skip=accelerator.optimizer_step_was_skipped, loss=round(train_loss, 3), acc=round(acc, 3), length=len(input_ids[0]), lr=lr_scheduler.get_last_lr()[0])

            # Log to wandb every third micro-batch, main process only.
            if global_step % 3 == 0 and accelerator.is_main_process:
                wandb.log({
                    'skip': int(accelerator.optimizer_step_was_skipped),
                    'loss': train_loss,
                    'acc': acc,
                    'lr': lr_scheduler.get_last_lr()[0]
                }, step=global_step)

        # NOTE(review): the original indentation was lost; saving once at the
        # end of every epoch is assumed here -- confirm against the upstream file.
        accelerator.wait_for_everyone()
        save_checkpoint(epoch, batch_cnt, global_step)
if __name__ == '__main__':
    # Script entry point: parse the CLI, create the per-experiment log and
    # checkpoint directories, seed the RNGs and start training.
    parser = argparse.ArgumentParser(description='Args of sft')

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        # Experiment Args
        ('--experiment_name', {'type': str, 'default': 'sft_stage1'}),
        # Model Args
        ('--model_path', {'required': True, 'type': str}),
        # Data Args
        ('--data_path', {'required': True, 'type': str}),
        # Training Args
        ('--output_dir', {'default': './ckpts', 'type': str}),
        ('--max_ckpts', {'default': 2, 'type': int}),
        ('--log_dir', {'default': './train_logs', 'type': str}),
        ('--max_seq_len', {'default': 8192, 'type': int}),
        ('--gradient_checkpointing', {'action': 'store_true'}),
        ('--gradient_accumulation_steps', {'default': 8, 'type': int}),
        ('--train_bsz_per_gpu', {'default': 2, 'type': int}),
        ('--weight_decay', {'default': 0.1, 'type': float}),
        ('--learning_rate', {'default': 5e-6, 'type': float}),
        ('--warmup_rates', {'default': 0.05, 'type': float}),
        ('--n_epochs', {'default': 3, 'type': int}),
        # Other Args
        ('--seed', {'default': 42, 'type': int}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Logs and checkpoints both live in a sub-directory named after the
    # experiment.
    args.log_dir = os.path.join(args.log_dir, args.experiment_name)
    args.output_dir = os.path.join(args.output_dir, args.experiment_name)
    for directory in (args.log_dir, args.output_dir):
        os.makedirs(directory, exist_ok=True)

    set_seed(args.seed)
    train(args)