[Feature] (Willing to PR) Proposal: Drop-in fast replacement of PreTrainedModel.generate
#2569
Comments
Thanks for pointing this out! For your information, we may not try to integrate SGLang into TRL ourselves, since TRL is more or less out of date. But for OpenRLHF, yes, we will definitely do this, and I am working on that.
Also, I am not quite sure why we need to discuss this. In my experience, every time the policy model gets updated, we should also update the inference engine's weights. It's not related to the PPO batch size, only to the number of training epochs?
Also, about your title, “Drop-in fast replacement of PreTrainedModel.generate”: do you mean changing the inference engine in these post-training frameworks from huggingface/vllm to sglang? I am working on this and making some progress. If you really have time and are willing to contribute, we would be really glad. Do you have time for a quick discussion? I will send you the link on WeChat. Thanks so much!
Interesting! I have looked into how people do RLHF nowadays and found TRL and OpenRLHF (and maybe other frameworks?), and it seems TRL is more popular than OpenRLHF. May I ask a bit about the "out of date" part? If it is out of date, I will not spend much time looking at it.
Yes, that would be super great. I have tried the gloo backend and it can broadcast weights to the same GPU (while nccl throws an error), but I am not sure whether that is suboptimal. For example, I guess it involves at least one memory copy? Instead, if we already know the weights are on the same GPU, maybe we can do a zero-copy update to save time, and also avoid broadcasting, which is extra complexity and can introduce bugs.
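To make the same-GPU idea concrete, here is a rough sketch of what I mean; engine_params is a hypothetical name-to-tensor view of the inference engine's weights, not a real SGLang interface:

```python
import torch

@torch.no_grad()
def sync_weights_same_gpu(hf_model, engine_params):
    # Instead of a torch.distributed broadcast, copy each trained parameter
    # directly into the inference engine's tensor of the same name.
    # `engine_params` (name -> tensor) is an assumed interface for this sketch.
    for name, param in hf_model.named_parameters():
        dst = engine_params[name]
        if dst.data_ptr() == param.data_ptr():
            continue  # engine already aliases the training tensor: true zero-copy
        dst.copy_(param.data)  # otherwise one device-to-device copy, no broadcast
```

If the engine could be made to alias the training tensors directly, even that copy would disappear, at the cost of coupling the two memory layouts.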
My proposal was a drop-in replacement. In other words, users only need to change a small piece of code, e.g. wrap the model object, and can then directly use TRL's PPOTrainer or OpenRLHF's train_ppo (non-ray) trainer with no other changes. At the same time, we surely can do the non-drop-in thing, i.e. directly modify the TRL/OpenRLHF/whatever framework's code to add it. Happy to see this is WIP, and I am also happy to contribute some PRs (for this one and #2542)!
Well, a lot of my friends use OpenRLHF since it's easier to hack than TRL. I think TRL and OpenRLHF are both good and both worth contributing to.
Well, yesterday I used OpenRLHF with vLLM on Ray across 8 * H100, but the broadcast failed. Maybe I should use gloo instead of nccl this time; I will try that out. Also, I am not sure how big a model we can use on a single GPU. Running PPO on a 7B model takes 3 * H100 with Adam offloading and co-located ref/actor and critic/reward. If we can do this on one GPU, that would be perfect.
Sorry, I do not fully understand what drop-in means 😂 I always do the non-drop-in approach, and we call that open to use, like how we integrate SGLang with xgrammar. We can discuss this tomorrow. Thanks so much for the help, and merry Christmas.
I see, thanks for the info.
Theoretically speaking, if we use bf16 model weights, we need at least 56 GB (my guess at the arithmetic is below). If we can make SGLang use almost zero extra memory by temporarily releasing its model weights and KV cache, then it looks like we can fit everything on one 80 GB card, though I am not sure whether enough memory is left for a large enough forward/backward batch size.
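My guess at the arithmetic behind that figure, assuming the four PPO models (actor, ref, critic, reward) are each about 7B parameters and stored in bf16 (2 bytes per parameter):

$$
4 \times 7 \times 10^{9}\ \text{params} \times 2\ \text{bytes/param} = 56\ \text{GB}
$$

Gradients, optimizer states (unless offloaded), and activations all come on top of this, hence "at least".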
A super naive version would be:
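Here is a rough sketch of the kind of wrapper I have in mind; the engine's update_weights and generate methods are placeholders for whatever the real SGLang entry points turn out to be, not confirmed API:

```python
import torch
from transformers import PreTrainedModel


class SGLangModelWrapper:
    """Wrap a HF PreTrainedModel so that .generate() is served by SGLang.

    Everything except generate() is forwarded to the underlying HF model, so a
    trainer that only expects a PreTrainedModel keeps working. The engine
    methods used below are assumed placeholders, not confirmed SGLang API.
    """

    def __init__(self, hf_model: PreTrainedModel, sglang_engine):
        self._hf_model = hf_model
        self._engine = sglang_engine

    def __getattr__(self, name):
        # Forward config, forward(), parameters(), etc. to the HF model.
        return getattr(self._hf_model, name)

    @torch.no_grad()
    def generate(self, *args, **kwargs):
        # Super naive weight sync: push the (possibly just-updated) HF weights
        # into the inference engine before every generation call.
        self._engine.update_weights(self._hf_model.state_dict())  # assumed API
        return self._engine.generate(*args, **kwargs)              # assumed API
```

One caveat: because this is composition rather than subclassing, any isinstance(model, PreTrainedModel) check inside a trainer would fail, and the engine's generate output likely needs converting back to HF-style token-id tensors; I left that out of the sketch.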
usage:
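Something like the following, where make_sglang_engine and the model path are placeholders of mine, not real API:

```python
from transformers import AutoModelForCausalLM

hf_model = AutoModelForCausalLM.from_pretrained("path/to/model")
engine = make_sglang_engine("path/to/model")  # placeholder for real engine construction
model = SGLangModelWrapper(hf_model, engine)

# `model` then goes to TRL's PPOTrainer (or OpenRLHF's train_ppo) exactly where
# the plain HF model used to go; the trainer code itself stays unchanged.
```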
Then, even though PPOTrainer still calls PreTrainedModel.generate, our SGLangModelWrapper makes that call go through SGLang's generate, which is faster.
You are welcome, and also merry Christmas :)
Motivation
Hi, thanks for the lib! Currently a lot of code uses model.generate(), such as TRL's PPOTrainer, etc. If we can make a drop-in replacement of it using SGLang, then everyone can very easily speed up their generation code, for example TRL's PPOTrainer and OpenRLHF's train_ppo.py (not train_ppo_ray.py, which is more for distributed training). IMHO there are many places where this can be useful: many online RL algorithms can benefit from it.

As for when to update the SGLang weights from the HF weights, the most naive solution may be to update them every time generate is called. This may not be a big problem, because we can configure the PPO batch size to be so large that model.generate is only called once.
Related: #2542. With that, we can reduce the memory footprint outside of generate.
Related resources
No response