[WIP][Infer] Inference Distributed RPC Framework Optimization #5756

Open · wants to merge 5 commits into base: main
65 changes: 65 additions & 0 deletions colossalai/inference/batch_bucket.py
@@ -558,3 +558,68 @@ def fd_inter_tensor(self) -> None:

def __repr__(self) -> str:
return f"(sequences_dict={self._sequences_dict}, is_prompts={self.is_prompts})"


class RPCBatchBucket(BatchBucket):
def __init__(self, *args, **argv):
self.is_rpc = True
self.device = "cpu"
super().__init__(*args, **argv)

# For compatibility
def get_1D_inputs(self) -> List[int]:
assert len(self._sequences_dict) > 0, "No sequence in the batch"
first_seq = next(iter(self._sequences_dict.values())) # not exactly the first sequence
if first_seq.output_len == 0:
# Assume prefill stage
assert all(
seq.output_len == 0 for seq in self._sequences_dict.values()
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
out_li = []
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
for seq_id in seq_ids:
seq: Sequence = self._sequences_dict[seq_id]
out_li.extend(seq.input_token_id)
return out_li
else:
# Assume decoding stage
if self.use_spec_dec:
# For Speculative Decoding
# the number of tokens to be verified in parallel plus the correct token in the last step
return self.get_1D_inputs_spec_dec(self.num_tokens_to_verify + 1)
assert all(
seq.output_len > 0 for seq in self._sequences_dict.values()
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
assert self.is_compact, "BatchBucket is not compact"
out = [0] * self.current_batch_size
for seq_id, index_in_b in self._sequences_indexes.items():
seq: Sequence = self._sequences_dict[seq_id]
out[index_in_b] = seq.output_token_id[-1]
return out

# For compatibility
def get_sequence_lengths(self) -> List[int]:
assert self.is_compact # Debug usage
sequence_lengths = self.seq_lengths[: self.current_batch_size]
return sequence_lengths

def get_1D_inputs_spec_dec(self, n: int) -> List[int]:
# Used for main model verification in **Decoding Stage**
# `n` is the number of tokens to be verified,
# so the last `n` tokens of each sequence are prepared as the inputs
assert len(self._sequences_dict) > 0, "No sequence in the batch"
assert all(
seq.output_len >= n for seq in self._sequences_dict.values()
), "Sequence output tokens must be greater than or equal to the number of tokens to be verified."
out_li = []
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
for seq_id in seq_ids:
seq: Sequence = self._sequences_dict[seq_id]
out_li.extend(seq.output_token_id[-n:])
return out_li

# For compatibility
def get_block_table_tensor(self) -> torch.Tensor:
assert self.is_compact # Debug usage
block_table = self.block_tables[: self.current_batch_size]
return block_table
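
Note on the new `RPCBatchBucket`: it pins `device = "cpu"` and returns plain Python lists from `get_1D_inputs` / `get_1D_inputs_spec_dec`, so the engine can pickle the batch metadata straight onto the rpyc wire instead of converting CUDA tensors with `.tolist()` on every step. Below is a small standalone sketch of that trade-off; it does not use the ColossalAI API, and the token ids are made up.

```python
# Standalone sketch of the design idea behind RPCBatchBucket: batch metadata
# kept as plain CPU lists pickles directly, while a CUDA tensor would need a
# device->host copy plus a .tolist() conversion on every decoding step.
import pickle

import torch

token_ids_as_list = [101, 2054, 2003, 102]              # what RPCBatchBucket.get_1D_inputs returns
token_ids_as_tensor = torch.tensor(token_ids_as_list)   # what the original BatchBucket works with

payload_from_list = pickle.dumps(token_ids_as_list)               # no conversion needed
payload_from_tensor = pickle.dumps(token_ids_as_tensor.tolist())  # extra hop the RPC path avoids
assert pickle.loads(payload_from_list) == pickle.loads(payload_from_tensor)
```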
23 changes: 17 additions & 6 deletions colossalai/inference/config.py
@@ -89,8 +89,14 @@ class InputMetaData(RPC_PARAM):

def to_rpc_param(self) -> Dict[str, any]:
return {
"block_tables": self.block_tables.tolist(),
"sequence_lengths": self.sequence_lengths.tolist(),
"block_tables": self.block_tables,
# "block_tables": self.block_tables.tolist()
# if isinstance(self.block_tables, torch.Tensor)
# else self.block_tables,
"sequence_lengths": self.sequence_lengths,
# "sequence_lengths": self.sequence_lengths.tolist()
# if isinstance(self.block_tables, torch.Tensor)
# else self.sequence_lengths,
"batch_size": self.batch_size,
"is_prompts": self.is_prompts,
"use_cuda_kernel": self.use_cuda_kernel,
@@ -112,12 +118,17 @@ def from_rpc_param(rpc_dict: Dict[str, any]) -> "InputMetaData":
from colossalai.accelerator import get_accelerator

dtype = getattr(torch, rpc_dict["dtype"])
device = get_accelerator().get_current_device()
return InputMetaData(
block_tables=torch.tensor(
rpc_dict["block_tables"], dtype=torch.int, device=get_accelerator().get_current_device()
block_tables=(
torch.tensor(rpc_dict["block_tables"], dtype=torch.int, device=device)
if isinstance(rpc_dict["block_tables"], list)
else rpc_dict["block_tables"].to(device)
),
sequence_lengths=torch.tensor(
rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device()
sequence_lengths=(
torch.tensor(rpc_dict["sequence_lengths"], dtype=torch.int, device=device)
if isinstance(rpc_dict["sequence_lengths"], list)
else rpc_dict["sequence_lengths"].to(device)
),
batch_size=rpc_dict["batch_size"],
is_prompts=rpc_dict["is_prompts"],
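
The `from_rpc_param` change makes deserialization tolerant of both wire formats: a plain list (the old behaviour) or a tensor sent as-is. Here is a sketch of that normalization pattern, assuming CPU as the target device for readability; the real code targets `get_accelerator().get_current_device()`, and `to_device` is a hypothetical helper, not part of the ColossalAI API.

```python
# Sketch of the list-or-tensor normalization now used in from_rpc_param.
import torch


def to_device(value, device: str = "cpu") -> torch.Tensor:
    # Accept either a plain Python list (legacy wire format) or an
    # already-built tensor, and return a tensor on the target device.
    if isinstance(value, list):
        return torch.tensor(value, dtype=torch.int, device=device)
    return value.to(device)


print(to_device([1, 2, 3]))                # list coming from a pickled RPC payload
print(to_device(torch.tensor([1, 2, 3])))  # tensor forwarded without .tolist()
```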
1 change: 0 additions & 1 deletion colossalai/inference/core/engine.py
@@ -78,7 +78,6 @@ def generate(

Args:
request_ids (List[int], optional): The request ID. Defaults to None.
prompts (Union[List[str], optional): Input prompts. Defaults to None.
"""

assert self.engine is not None, "Please init Engine first"
6 changes: 3 additions & 3 deletions colossalai/inference/core/request_handler.py
@@ -4,7 +4,7 @@
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationConfig

from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.batch_bucket import BatchBucket, RPCBatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
@@ -427,7 +427,7 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo

# TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
# which may cause bugs and this issue should be fixed later.
self.running_bb = BatchBucket(
self.running_bb = RPCBatchBucket(
num_heads=model_config.num_attention_heads // inference_config.tp_size,
head_dim=head_dim,
max_batch_size=self.max_batch_size,
@@ -437,7 +437,7 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo
fd_interm_tensor=None,
dtype=self.dtype,
)
self.prefill_bb = BatchBucket(
self.prefill_bb = RPCBatchBucket(
num_heads=model_config.num_attention_heads // inference_config.tp_size,
head_dim=head_dim,
max_batch_size=self.max_batch_size,
73 changes: 53 additions & 20 deletions colossalai/inference/core/rpc_engine.py
@@ -1,4 +1,5 @@
import asyncio
import pickle
from itertools import count
from time import sleep
from typing import List, Tuple, Union
@@ -11,7 +12,7 @@
from transformers import AutoConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.configuration_utils import PretrainedConfig

from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.batch_bucket import RPCBatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.executor.rpc_worker import rpcWorkerService
from colossalai.inference.utils import find_available_ports
@@ -120,6 +121,9 @@ def __init__(
self.counter = count()
self._verify_args()

self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)

self.logger.info("engine init over ")

def _verify_args(self) -> None:
@@ -162,8 +166,16 @@ def init_workers(self):
raise Exception("conn error!")
self.logger.info(f"Build RPC Connection Success! Begin to load model...")
asyncio.run(self.init_worker_env())
self._init_worker_forward()
self.logger.info(f"init dist env over")

def _init_worker_forward(self):
"""
Build async wrappers for the workers' forward calls once, since they will be invoked many times.
"""
assert len(self.workers) == self.tp_size, "init workers first"
self.worker_forwards = [rpyc.async_(worker.execute_model_forward) for worker in self.workers]

async def async_parallel_wrapper(self, f, *args, **kwargs):
async_res = rpyc.async_(f)(*args, **kwargs)
await asyncio.to_thread(async_res.wait)
@@ -210,7 +222,8 @@ async def _init_device_cache(self, alloc_shape: Tuple[int, int, int, int]):
def init_device_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]):
asyncio.run(self._init_device_cache(alloc_shape))

def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]:
def prepare_input(self, batch: RPCBatchBucket) -> Tuple[List[int], InputMetaData]:
assert batch.is_rpc, "the batch must be RPCBatchBucket"
input_ids = batch.get_1D_inputs()
sequence_lengths = batch.get_sequence_lengths()

@@ -220,7 +233,7 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]:
n_tokens = batch.current_batch_size
if batch.use_spec_dec:
n_tokens = batch.num_tokens_to_verify + 1
assert n_tokens == input_ids.size(0)
assert n_tokens == len(input_ids)
n_tokens = n_tokens * batch.current_batch_size

batch_token_ids = None
@@ -252,40 +265,60 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]:
batch_token_ids=batch_token_ids,
)

return input_ids.tolist(), input_meta_data
return input_ids, input_meta_data

async def async_parallel_forward(self, async_f, *args, **kwargs):
async_res = async_f(*args, **kwargs)
await asyncio.to_thread(async_res.wait)
assert async_res.ready
return async_res.value

async def step_(self, input_token_ids, input_meta_data: InputMetaData):
async def step_async(self, input_token_ids, input_meta_data: InputMetaData):
assert len(self.workers) == self.tp_size, "init workers first"

init_tasks = [
self.async_parallel_wrapper(
worker.execute_model_forward,
input_token_ids,
input_meta_data.to_rpc_param(),
self.generation_config_dict,
)
for worker in self.workers
]
init_tasks = []
for rank, async_forward in enumerate(self.worker_forwards):
if rank == 0:
init_tasks.append(
self.async_parallel_forward(
async_forward,
pickle.dumps(input_token_ids),
pickle.dumps(input_meta_data.to_rpc_param()),
pickle.dumps(self.generation_config_dict),
)
)
else:
init_tasks.append(
self.async_parallel_forward(
async_forward,
None,
None,
None,
)
)

ret = await asyncio.gather(*init_tasks)

return ret[0]

def step(self) -> List[str]:
batch = self.request_handler.schedule()
with self.t_prepare:
batch = self.request_handler.schedule()

input_token_ids, input_meta_data = self.prepare_input(batch)

input_token_ids, input_meta_data = self.prepare_input(batch)
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
next_tokens = asyncio.run(self.step_(input_token_ids, input_meta_data))
with self.t_exe:
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
next_tokens = self.loop.run_until_complete(self.step_async(input_token_ids, input_meta_data))

# update the request_handler
next_tokens = torch.tensor(next_tokens, dtype=torch.int)
self.request_handler.append_next_tokens(next_tokens)
finished_sequences = self.request_handler.update()
return finished_sequences

def kill_workers(self):
"""
I don't find a good way to implicit invoke self.kill_workers
NOTE(@lry89757) Haven't found a good way to implicitly invoke self.kill_workers
"""
assert len(self.workers) != 0
for proc in self.worker_processes:
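
The forward fan-out now (1) builds the `rpyc.async_` proxies once in `_init_worker_forward`, (2) pickles the inputs once and ships them only to rank 0 while the other ranks receive `None` (presumably the workers recover the inputs among themselves), and (3) drives everything from one persistent event loop instead of calling `asyncio.run` on every step. The snippet below is a runnable sketch of that pattern with faked worker proxies; names like `FakeAsyncResult`, `fake_forward`, and `fan_out` are illustrative only.

```python
# Runnable sketch of the step_async fan-out pattern with faked rpyc proxies.
import asyncio
import pickle


class FakeAsyncResult:
    """Mimics the rpyc AsyncResult interface (.wait() / .value) for this sketch."""

    def __init__(self, value):
        self.value = value

    def wait(self):
        pass


def fake_forward(token_ids, meta, gen_cfg):
    # Stand-in for an rpyc.async_(worker.execute_model_forward) proxy;
    # returns dummy "next tokens" immediately.
    return FakeAsyncResult([42])


async def wait_async(async_res):
    # rpyc async results are not awaitable, so block on .wait() in a thread.
    await asyncio.to_thread(async_res.wait)
    return async_res.value


async def fan_out(forwards, token_ids, meta, gen_cfg):
    tasks = []
    for rank, fwd in enumerate(forwards):
        if rank == 0:
            # Only rank 0 receives the real pickled payload.
            tasks.append(wait_async(fwd(pickle.dumps(token_ids), pickle.dumps(meta), pickle.dumps(gen_cfg))))
        else:
            # The other tensor-parallel ranks get None placeholders.
            tasks.append(wait_async(fwd(None, None, None)))
    return (await asyncio.gather(*tasks))[0]


loop = asyncio.new_event_loop()  # one persistent loop, reused across steps
print(loop.run_until_complete(fan_out([fake_forward, fake_forward], [1, 2, 3], {}, {})))
loop.close()
```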