Building a Local LLM API Service
After completing the local deployment of a large language model, the next important step is to wrap the model's capabilities in a standard API service so that other applications can call it. In this section we walk through building an API service with FastAPI, configuring load balancing, and deploying a monitoring and alerting system.
FastAPI Endpoint Development
FastAPI is a modern, high-performance Python web framework that is particularly well suited to building API services for AI models. Let's implement a complete API service:
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
from typing import List, Optional
import torch
import logging
import time
from contextlib import asynccontextmanager
import uvicorn
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('api_service.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
# Request / response models
class GenerationRequest(BaseModel):
    prompt: str
    max_length: Optional[int] = 512
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    num_return_sequences: Optional[int] = 1

class GenerationResponse(BaseModel):
    generated_text: List[str]
    generation_time: float
    token_count: int

# Global variables holding the model instance
model = None
tokenizer = None
# Load the model at startup
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the model
    global model, tokenizer
    logger.info("Loading model...")
    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        model_name = "your-model-path"  # Replace with the actual model path
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            # Many causal LM tokenizers define no pad token; reuse EOS so padding works
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        raise
    yield
    # Clean up resources on shutdown
    logger.info("Cleaning up...")
    model = None
    tokenizer = None

app = FastAPI(lifespan=lifespan)
# Health check endpoint
@app.get("/health")
async def health_check():
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {"status": "healthy"}

# Text generation endpoint
@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    start_time = time.time()
    try:
        # Log the request
        logger.info(f"Received generation request: {request}")
        # Tokenize the input
        inputs = tokenizer(request.prompt, return_tensors="pt", padding=True)
        inputs = inputs.to(model.device)
        # Generate text; do_sample=True is required for temperature/top_p to take effect.
        # Note: generation runs synchronously and blocks the event loop while it runs.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=request.max_length,
                do_sample=True,
                temperature=request.temperature,
                top_p=request.top_p,
                num_return_sequences=request.num_return_sequences,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
        # Decode the outputs
        generated_texts = []
        for sequence in outputs:
            text = tokenizer.decode(sequence, skip_special_tokens=True)
            generated_texts.append(text)
        # Count tokens (prompt tokens included)
        token_count = sum(len(sequence) for sequence in outputs)
        generation_time = time.time() - start_time
        # Log the response
        logger.info(f"Generation completed in {generation_time:.2f}s")
        return GenerationResponse(
            generated_text=generated_texts,
            generation_time=generation_time,
            token_count=token_count
        )
    except Exception as e:
        logger.error(f"Generation error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
# Batch generation endpoint
@app.post("/generate_batch")
async def generate_batch(
    requests: List[GenerationRequest],
    background_tasks: BackgroundTasks
):
    # Process the batch of requests in the background
    task_ids = []
    for request in requests:
        task_id = f"task_{time.time()}"  # time-based IDs may collide; a uuid is safer in practice
        background_tasks.add_task(process_generation_request, request, task_id)
        task_ids.append(task_id)
    return {"task_ids": task_ids}

async def process_generation_request(request: GenerationRequest, task_id: str):
    # The actual generation logic
    try:
        result = await generate_text(request)
        # Results could be stored in a database or cache here
        logger.info(f"Task {task_id} completed successfully")
    except Exception as e:
        logger.error(f"Task {task_id} failed: {str(e)}")

if __name__ == "__main__":
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        workers=1,  # Large-model services typically run a single worker per process
        log_level="info"
    )
This API service implements the following core features:
1. Model loading and resource management
• An async context manager handles model loading and unloading
• GPU acceleration and automatic device placement are supported
• Startup and shutdown are handled gracefully
2. API endpoints
• The /health endpoint is used for health checks
• The /generate endpoint provides synchronous text generation (a minimal client call is shown after this list)
• The /generate_batch endpoint supports asynchronous batch processing
3. Request handling
• Pydantic models validate incoming requests
• Errors are handled and logged end to end
• Generation parameters are configurable
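To show how a caller interacts with these endpoints, here is a minimal client sketch using the requests library. The base URL, prompt, and timeouts are placeholder assumptions; they should match your actual deployment and the GenerationRequest fields defined above.
import requests

API_URL = "http://localhost:8000"  # assumed local address of the service

# Wait for /health to report that the model has finished loading
health = requests.get(f"{API_URL}/health", timeout=5)
print(health.json())  # {"status": "healthy"} once the model is ready

# Call the synchronous generation endpoint
payload = {
    "prompt": "Write a short poem about the sea.",
    "max_length": 256,
    "temperature": 0.7,
    "top_p": 0.9,
    "num_return_sequences": 1,
}
resp = requests.post(f"{API_URL}/generate", json=payload, timeout=120)
resp.raise_for_status()
result = resp.json()
print(result["generated_text"][0])
print(f"{result['token_count']} tokens in {result['generation_time']:.2f}s")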
Load Balancing Configuration
For a model service in production, we need to configure load balancing to improve the system's reliability and performance. The following configuration implements load balancing with Nginx:
# nginx.conf
events {
    worker_connections 1024;
}

http {
    upstream llm_backend {
        least_conn;  # Route to the backend with the fewest active connections
        server 127.0.0.1:8000;
        server 127.0.0.1:8001;
        server 127.0.0.1:8002;
        keepalive 32;  # Connection pool to the backends
    }

    # Enable gzip compression
    gzip on;
    gzip_types text/plain application/json;

    # Request size limit
    client_max_body_size 10M;

    server {
        listen 80;
        server_name api.yourdomain.com;

        # SSL configuration
        listen 443 ssl;
        ssl_certificate /path/to/cert.pem;
        ssl_certificate_key /path/to/key.pem;

        # Security headers
        add_header Strict-Transport-Security "max-age=31536000" always;
        add_header X-Frame-Options "SAMEORIGIN" always;
        add_header X-XSS-Protection "1; mode=block" always;

        # Reverse proxy configuration
        location / {
            proxy_pass http://llm_backend;
            proxy_http_version 1.1;
            proxy_set_header Upgrade $http_upgrade;
            # Note: hard-coding 'upgrade' defeats upstream keepalive; use
            # proxy_set_header Connection ""; unless you need WebSocket upgrades
            proxy_set_header Connection 'upgrade';
            proxy_set_header Host $host;
            proxy_cache_bypass $http_upgrade;

            # Timeouts (increase for long-running generations)
            proxy_connect_timeout 60s;
            proxy_send_timeout 60s;
            proxy_read_timeout 60s;

            # Rate limiting (zones defined at the bottom of the http block)
            limit_req zone=api burst=20 nodelay;
            limit_conn addr 10;
        }

        # Health check endpoint
        location /health {
            proxy_pass http://llm_backend;
            access_log off;
            proxy_http_version 1.1;
            proxy_set_header Connection "";
            # The health_check directive requires NGINX Plus; open-source Nginx
            # only supports passive checks (max_fails/fail_timeout on the servers)
            health_check interval=5s
                         fails=3
                         passes=2;
        }
    }

    # Rate limiting zones
    limit_req_zone $binary_remote_addr zone=api:10m rate=10r/s;
    limit_conn_zone $binary_remote_addr zone=addr:10m;
}
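The upstream block above expects three API instances on ports 8000-8002, and since the health_check directive is an NGINX Plus feature, open-source deployments usually verify the backends themselves. Below is a small sketch under the assumption that the FastAPI app lives in main.py and that your hardware (or separate machines/GPUs) can hold one model copy per instance; paths and ports are placeholders matching the upstream above.
import subprocess
import time
import requests

# Ports must match the upstream servers in nginx.conf above
PORTS = [8000, 8001, 8002]

# Start one uvicorn process per backend port (each loads its own model copy)
processes = [
    subprocess.Popen(
        ["uvicorn", "main:app", "--host", "127.0.0.1", "--port", str(port)]
    )
    for port in PORTS
]

# Give the models time to load, then probe each instance's /health endpoint
time.sleep(60)
for port in PORTS:
    try:
        r = requests.get(f"http://127.0.0.1:{port}/health", timeout=5)
        print(f"port {port}: {r.status_code} {r.json()}")
    except requests.RequestException as exc:
        print(f"port {port}: unreachable ({exc})")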
Monitoring and Alerting Deployment
To keep the service running stably, we need to deploy a comprehensive monitoring and alerting system. The following monitoring setup is built on Prometheus and Grafana:
from prometheus_client import Counter, Histogram, start_http_server
import time
from functools import wraps

# Define monitoring metrics
REQUESTS_TOTAL = Counter(
    'llm_requests_total',
    'Total number of LLM API requests',
    ['endpoint', 'status']
)
GENERATION_TIME = Histogram(
    'llm_generation_seconds',
    'Time spent generating text',
    buckets=[0.1, 0.5, 1.0, 2.0, 5.0, 10.0, float('inf')]
)
TOKEN_COUNT = Counter(
    'llm_tokens_generated_total',
    'Total number of tokens generated'
)

def monitor_endpoint(func):
    @wraps(func)
    async def wrapper(*args, **kwargs):
        start_time = time.time()
        try:
            result = await func(*args, **kwargs)
            REQUESTS_TOTAL.labels(
                endpoint=func.__name__,
                status='success'
            ).inc()
            if hasattr(result, 'token_count'):
                TOKEN_COUNT.inc(result.token_count)
            return result
        except Exception:
            REQUESTS_TOTAL.labels(
                endpoint=func.__name__,
                status='error'
            ).inc()
            raise
        finally:
            GENERATION_TIME.observe(time.time() - start_time)
    return wrapper

# Apply the decorator to the FastAPI endpoint
@app.post("/generate")
@monitor_endpoint
async def generate_text(request: GenerationRequest):
    # The generation logic shown earlier goes here
    pass

# Start the Prometheus metrics server on its own port
# (8001 would collide with the second API instance behind Nginx, so use 9000 here)
start_http_server(9000)
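start_http_server exposes the metrics on a separate port. As an optional variation (not required by the setup above), prometheus_client also provides make_asgi_app, which lets the same FastAPI process serve its metrics under /metrics on the API port:
from prometheus_client import make_asgi_app

# Serve Prometheus metrics from the same process and port as the API,
# so no extra listener is needed
app.mount("/metrics", make_asgi_app())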
We also need to configure the corresponding alert rules:
groups:
  - name: LLMServiceAlerts
    rules:
      - alert: HighErrorRate
        # Error ratio (failed requests / all requests) over the last 5 minutes
        expr: sum(rate(llm_requests_total{status="error"}[5m])) / sum(rate(llm_requests_total[5m])) > 0.1
        for: 5m
        labels:
          severity: critical
        annotations:
          summary: High error rate detected
          description: Error rate is above 10% for the last 5 minutes
      - alert: SlowGenerationTime
        expr: histogram_quantile(0.95, rate(llm_generation_seconds_bucket[5m])) > 5
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: Slow text generation detected
          description: 95th percentile of generation time is above 5 seconds
      - alert: HighMemoryUsage
        expr: process_resident_memory_bytes / 1024 / 1024 / 1024 > 80
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: High memory usage detected
          description: Memory usage is above 80GB
      - alert: ServiceUnhealthy
        expr: up{job="llm_service"} == 0
        for: 1m
        labels:
          severity: critical
        annotations:
          summary: Service is down
          description: LLM service is not responding to health checks
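For the up{job="llm_service"} alert to fire, Prometheus must also scrape the metrics endpoint and load the rule file. A minimal prometheus.yml sketch follows; the file names, job name, and target port are assumptions that have to match your deployment and the metrics port chosen above.
# prometheus.yml (minimal sketch)
global:
  scrape_interval: 15s

rule_files:
  - "llm_alerts.yml"  # the alert rules shown above, saved under this assumed name

scrape_configs:
  - job_name: "llm_service"  # must match up{job="llm_service"} in the alerts
    static_configs:
      - targets: ["127.0.0.1:9000"]  # the metrics port exposed by start_http_server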
With the configuration above, we have built a capable and reliable API service system. It has the following characteristics:
1. High availability
• Load balancing distributes requests evenly across backends
• Health checks automatically remove unhealthy nodes
• Graceful error handling and recovery
2. Scalability
• Horizontal scaling by adding service nodes
• A batch-processing endpoint for high-concurrency scenarios
• A modular design that is easy to extend
3. Observability
• Comprehensive metric collection and monitoring
• Multi-dimensional alert rules
• Detailed logging
This system can serve as the infrastructure of a personal AI workstation, providing stable model-serving capability to the applications built on top of it. When deploying it in practice, adjust the configuration parameters to your specific needs and tune the server configuration to the hardware resources available.