Author: HelloWorld. Reposted from https://zhuanlan.zhihu.com/p/694856151
About 3,000 words; estimated reading time: 10 minutes.
This series covers FastChat from every angle.
FastChat[1] is an open-source library for training, serving, and evaluating large language models. This article first introduces FastChat's serving architecture, and then implements a minimal FastChat in roughly 200 lines of code to make FastChat's implementation logic easier to understand.
FastChat Serving Architecture
FastChat's serving code lives in fastchat/serve[2]. There are three core files:
- controller.py[3]: implements the Controller, which registers new Workers, removes Workers, and assigns a Worker to each request.
- model_worker.py[4]: implements the Worker, which runs the model on a request and returns the result to the Server. Each Worker holds a complete copy of a model, and several Workers can serve the same model; for example, Worker 1 and Worker 2 can both serve Model A, which increases Model A's request throughput. A Worker can also span multiple GPUs, e.g. when tensor parallelism is used to shard one model across several GPUs.
- openai_api_server.py[5]: implements an OpenAI-compatible RESTful API.
Their relationship is shown in Figure 1.
Let's walk through how a single request is handled:
The user sends a request to the Server (e.g., the OpenAI API Server); the request contains the model name and the input, for example:
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Llama-3-8B-Instruct",
"messages": [{"role": "user", "content": "Hello! What is your name?"}]
}'
The Server sends a request to the Controller to obtain the address of a Worker that serves the requested model
The Controller assigns a Worker according to its load-balancing strategy
The Server forwards the request to that Worker
The Worker processes the request and returns the result to the Server
The Server returns the result to the user
That is how FastChat handles a single request. Next, we will implement a minimal FastChat.
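Stripped of the web-framework plumbing, the Server side of this flow boils down to two HTTP calls. The synchronous sketch below is only a preview: the function name handle_chat_completion is made up for illustration, while the two endpoints are the ones the Mini FastChat code defines later.
import requests

def handle_chat_completion(data: dict, controller_addr: str = 'http://localhost:21001'):
    # 1. Ask the Controller for the address of a Worker serving data['model']
    worker_addr = requests.post(
        controller_addr + '/get_worker_address', json={'model': data['model']},
    ).json()['address']
    # 2. Forward the request to that Worker and return its output
    return requests.post(worker_addr + '/worker_generate', json=data).json()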
Implementing Mini FastChat
Mini FastChat supports similar functionality and follows a similar design to FastChat, but in a simplified form; the code is adapted from FastChat.
The directory structure of Mini FastChat is as follows:
mini-fastchat
├── controller.py
├── worker.py
└── openai_api_server.py
Controller
Create a controller.py file. It mainly implements the Controller class, whose job is to register Workers and to randomly assign a Worker to each request. controller.py also exposes two endpoints, register_worker and get_worker_address: the former is called by Workers to register themselves with the Controller, and the latter is called by the API Server to obtain a Worker's address.
import argparse
import uvicorn
import random
from fastapi import FastAPI, Request
from loguru import logger


class Controller:
    def __init__(self):
        # Maps worker address -> name of the model served by that worker
        self.worker_info = {}

    def register_worker(
        self,
        worker_addr: str,
        model_name: str,
    ):
        logger.info(f'Register worker: {worker_addr} {model_name}')
        self.worker_info[worker_addr] = model_name

    def get_worker_address(self, model_name: str):
        # Assign a worker to the request
        worker_addr_list = []
        for worker_addr, _model_name in self.worker_info.items():
            if _model_name == model_name:
                worker_addr_list.append(worker_addr)
        assert len(worker_addr_list) > 0, f'No worker for model: {model_name}'
        # Pick a worker at random
        worker_addr = random.choice(worker_addr_list)
        return worker_addr


app = FastAPI()


@app.post('/register_worker')
async def register_worker(request: Request):
    data = await request.json()
    controller.register_worker(
        worker_addr=data['worker_addr'],
        model_name=data['model_name'],
    )


@app.post("/get_worker_address")
async def get_worker_address(request: Request):
    data = await request.json()
    addr = controller.get_worker_address(data['model'])
    return {"address": addr}


def create_controller():
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', type=str, default='localhost')
    parser.add_argument('--port', type=int, default=21001)
    args = parser.parse_args()
    logger.info(f'args: {args}')
    controller = Controller()
    return args, controller


if __name__ == '__main__':
    args, controller = create_controller()
    uvicorn.run(app, host=args.host, port=args.port, log_level='info')
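With only the Controller running, it can be sanity-checked from a Python shell. The worker address and model name below are placeholders for illustration:
import requests

controller = 'http://localhost:21001'

# Register a (hypothetical) worker address with the Controller
requests.post(controller + '/register_worker', json={
    'worker_addr': 'http://localhost:21002',
    'model_name': 'Llama-3-8B-Instruct',
})

# Ask the Controller which worker serves the model
resp = requests.post(controller + '/get_worker_address',
                     json={'model': 'Llama-3-8B-Instruct'})
print(resp.json())  # e.g. {'address': 'http://localhost:21002'}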
Worker
Create a worker.py file. It mainly implements the Worker class and exposes an api_generate endpoint, which the API Server calls to process user requests.
import argparse
import asyncio
from typing import Optional
import requests
import uvicorn
import torch
from loguru import logger
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastapi import FastAPI, Request


def load_model(model_path: str):
    logger.info(f'Load model from {model_path}')
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map='auto',
    )
    logger.info(f'model device: {model.device}')
    return model, tokenizer


def generate(model, tokenizer, params: dict):
    input_ids = tokenizer.apply_chat_template(
        params['messages'],
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    # Stop on EOS or the Llama 3 end-of-turn token
    terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
    outputs = model.generate(
        input_ids,
        max_new_tokens=256,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    # Drop the prompt tokens and decode only the newly generated ones
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)


class Worker:
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        model_path: str,
        model_name: Optional[str] = None,
    ) -> None:
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.model, self.tokenizer = load_model(model_path)
        self.model_name = model_name
        self.register_to_controller()

    def register_to_controller(self) -> None:
        logger.info('Register to controller')
        url = self.controller_addr + '/register_worker'
        data = {
            'worker_addr': self.worker_addr,
            'model_name': self.model_name,
        }
        response = requests.post(url, json=data)
        assert response.status_code == 200

    def generate_gate(self, params: dict):
        return generate(self.model, self.tokenizer, params)


app = FastAPI()


@app.post("/worker_generate")
async def api_generate(request: Request):
    params = await request.json()
    # Run the blocking generation in a thread so the event loop stays responsive
    output = await asyncio.to_thread(worker.generate_gate, params)
    return {'output': output}


def create_worker():
    parser = argparse.ArgumentParser()
    parser.add_argument('model_path', type=str, help='Path to the model')
    parser.add_argument('model_name', type=str)
    parser.add_argument('--host', type=str, default='localhost')
    parser.add_argument('--port', type=int, default=21002)
    parser.add_argument('--controller-address', type=str, default='http://localhost:21001')
    args = parser.parse_args()
    logger.info(f'args: {args}')
    args.worker_address = f'http://{args.host}:{args.port}'
    worker = Worker(
        worker_addr=args.worker_address,
        controller_addr=args.controller_address,
        model_path=args.model_path,
        model_name=args.model_name,
    )
    return args, worker


if __name__ == '__main__':
    args, worker = create_worker()
    uvicorn.run(app, host=args.host, port=args.port, log_level='info')
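Once a Worker is running (see the run commands below), it can also be tested directly, bypassing the Controller and the API Server; the request body mirrors the params dict consumed by generate:
import requests

# Call the Worker's /worker_generate endpoint directly (default port 21002)
resp = requests.post('http://localhost:21002/worker_generate', json={
    'messages': [{'role': 'user', 'content': 'Hello! What is your name?'}],
})
print(resp.json()['output'])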
Server
Create an openai_api_server.py file. For each request it asks the Controller for the address of a Worker that serves the requested model, forwards the request to that Worker, and returns the Worker's result to the user.
import argparse
import asyncio
import aiohttp
import uvicorn
from fastapi import FastAPI, Request
from loguru import logger

app = FastAPI()
app_settings = {}


async def fetch_remote(url, payload):
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload) as response:
            return await response.json()


async def generate_completion(payload, worker_addr: str):
    return await fetch_remote(worker_addr + "/worker_generate", payload)


async def get_worker_address(model_name: str) -> str:
    controller_address = app_settings['controller_address']
    res = await fetch_remote(
        controller_address + "/get_worker_address", {"model": model_name}
    )
    return res['address']


@app.post('/v1/chat/completions')
async def create_chat_completion(request: Request):
    data = await request.json()
    worker_addr = await get_worker_address(data['model'])
    response = asyncio.create_task(generate_completion(data, worker_addr))
    await response
    return response.result()


def create_openai_api_server():
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', type=str, default='localhost')
    parser.add_argument('--port', type=int, default=8000)
    parser.add_argument('--controller-address', type=str, default='http://localhost:21001')
    args = parser.parse_args()
    logger.info(f'args: {args}')
    app_settings['controller_address'] = args.controller_address
    return args


if __name__ == '__main__':
    args = create_openai_api_server()
    uvicorn.run(app, host=args.host, port=args.port, log_level='info')
Running Mini FastChat
Set up the environment
Create a conda environment
conda create -n fastchat python=3.10 -y
conda activate fastchat
Install PyTorch 2.2.1
conda install pytorch==2.2.1 pytorch-cuda=12.1 -c pytorch -c nvidia
Install the remaining dependencies (accelerate is required because the Worker loads the model with device_map='auto')
pip install requests aiohttp uvicorn fastapi loguru transformers accelerate
Run
Start the controller
python mini-fastchat/controller.py
Start the worker
python mini-fastchat/worker.py meta-llama/Meta-Llama-3-8B-Instruct Llama-3-8B-Instruct
# If there are spare GPUs in the environment, you can start another worker
CUDA_VISIBLE_DEVICES=1 python mini-fastchat/worker.py meta-llama/Meta-Llama-3-8B-Instruct Llama-3-8B-Instruct --port 21003
Start the API server
python mini-fastchat/openai_api_server.py
Test
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Llama-3-8B-Instruct",
"messages": [{"role": "user", "content": "Hello! What is your name?"}]
}'
If the command above returns output, everything is running correctly.
Possible improvements
Mini FastChat is a bare-bones implementation of a FastChat-style serving stack. Compared with FastChat, there is still plenty of room for improvement, for example:
- Load-balancing strategy: Mini FastChat's Controller only assigns Workers at random, whereas the FastChat Controller[6] supports the LOTTERY and SHORTEST_QUEUE strategies (see the sketch after this list).
- Robustness: to keep the implementation simple, Mini FastChat does not handle error conditions that may occur in practice, such as malformed input or network failures.
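As an illustration only (not FastChat's actual code), a SHORTEST_QUEUE-style policy could look roughly like the sketch below, assuming the Controller tracked a hypothetical queue_length per Worker, e.g. reported through heartbeats:
import random

class Controller:
    def __init__(self):
        # worker_addr -> {'model_name': ..., 'queue_length': ...}
        # queue_length is a hypothetical counter; Mini FastChat does not track it.
        self.worker_info = {}

    def get_worker_address(self, model_name: str, strategy: str = 'random'):
        candidates = [
            addr for addr, info in self.worker_info.items()
            if info['model_name'] == model_name
        ]
        assert len(candidates) > 0, f'No worker for model: {model_name}'
        if strategy == 'shortest_queue':
            # Pick the worker with the fewest queued requests
            return min(candidates, key=lambda a: self.worker_info[a]['queue_length'])
        # Default: random assignment, as in Mini FastChat
        return random.choice(candidates)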
References
[1]FastChat: https://github.com/lm-sys/FastChat
[2]fastchat/serve: https://github.com/lm-sys/FastChat/tree/main/fastchat/serve
[3]controller.py: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/controller.py
[4]model_worker.py: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/base_model_worker.py
[5]openai_api_server.py: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py
[6]Controller: https://github.com/lm-sys/FastChat/blob/827aaba091a03ca4f9ed3fc2c74ddf8ab567cadf/fastchat/serve/controller.py#L156