本地大模型API服务搭建

文摘   2024-12-27 09:00   湖北  

 

本地大模型API服务搭建

在完成大语言模型的本地部署后,下一个重要步骤是将模型能力包装成标准的API服务,以便其他应用程序调用。本节我们将详细介绍如何使用FastAPI框架搭建API服务,实现负载均衡,以及部署监控告警系统。

FastAPI接口开发

FastAPI是一个现代化、高性能的Python Web框架,特别适合构建AI模型的API服务。我们来实现一个完整的API服务:

from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
from typing import ListOptional
import torch
import logging
import time
from contextlib import asynccontextmanager
import uvicorn

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('api_service.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# 定义请求模型
class GenerationRequest(BaseModel):
    prompt: str
    max_length: Optional[int] = 512
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    num_return_sequences: Optional[int] = 1

class GenerationResponse(BaseModel):
    generated_text: List[str]
    generation_time: float
    token_count: int

# 全局变量用于存储模型实例
model = None
tokenizer = None

# 启动时加载模型
@asynccontextmanager
async def lifespan(app: FastAPI):
    # 加载模型
    global model, tokenizer
    logger.info("Loading model...")
    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        model_name = "your-model-path"  # 替换为实际的模型路径
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        raise
    yield
    # 清理资源
    logger.info("Cleaning up...")
    model = None
    tokenizer = None

app = FastAPI(lifespan=lifespan)

# 健康检查端点
@app.get("/health")
async def health_check():
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {"status""healthy"}

# 文本生成端点
@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    start_time = time.time()
    
    try:
        # 记录请求
        logger.info(f"Received generation request: {request}")
        
        # 输入处理
        inputs = tokenizer(request.prompt, return_tensors="pt", padding=True)
        inputs = inputs.to(model.device)
        
        # 生成文本
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=request.max_length,
                temperature=request.temperature,
                top_p=request.top_p,
                num_return_sequences=request.num_return_sequences,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
        
        # 解码输出
        generated_texts = []
        for sequence in outputs:
            text = tokenizer.decode(sequence, skip_special_tokens=True)
            generated_texts.append(text)
        
        # 计算token数量
        token_count = sum(len(sequence) for sequence in outputs)
        
        generation_time = time.time() - start_time
        
        # 记录响应
        logger.info(f"Generation completed in {generation_time:.2f}s")
        
        return GenerationResponse(
            generated_text=generated_texts,
            generation_time=generation_time,
            token_count=token_count
        )
        
    except Exception as e:
        logger.error(f"Generation error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

# 批量生成端点
@app.post("/generate_batch")
async def generate_batch(
    requests: List[GenerationRequest],
    background_tasks: BackgroundTasks
):
    # 在后台处理批量请求
    task_ids = []
    for request in requests:
        task_id = f"task_{time.time()}"
        background_tasks.add_task(process_generation_request, request, task_id)
        task_ids.append(task_id)
    
    return {"task_ids": task_ids}

async def process_generation_request(request: GenerationRequest, task_id: str):
    # 实际的生成处理逻辑
    try:
        result = await generate_text(request)
        # 这里可以将结果存储到数据库或缓存中
        logger.info(f"Task {task_id} completed successfully")
    except Exception as e:
        logger.error(f"Task {task_id} failed: {str(e)}")

if __name__ == "__main__":
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        workers=1,  # 对于大模型服务,通常使用单worker
        log_level="info"
    )

这个API服务实现了以下核心功能:

  1. 1. 模型加载与资源管理
  • • 使用异步上下文管理器处理模型的加载和卸载
  • • 支持GPU加速和自动设备选择
  • • 实现了优雅的启动和关闭流程
  • 2. API端点设计
    • • /health端点用于健康检查
    • • /generate端点提供同步文本生成服务
    • • /generate_batch端点支持异步批量处理
  • 3. 请求处理和错误处理
    • • 使用Pydantic模型进行请求验证
    • • 完整的错误处理和日志记录
    • • 支持可配置的生成参数

    负载均衡配置

    对于生产环境中的模型服务,我们需要配置负载均衡来提高系统的可靠性和性能。以下是使用Nginx实现负载均衡的配置:

    # nginx.conf

    events {
        worker_connections 1024;
    }

    http {
        upstream llm_backend {
            least_conn;  # 使用最少连接数算法
            server 127.0.0.1:8000;
            server 127.0.0.1:8001;
            server 127.0.0.1:8002;
            
            keepalive 32;  # 保持连接池
        }

        # 启用gzip压缩
        gzip on;
        gzip_types text/plain application/json;
        
        # 请求大小限制
        client_max_body_size 10M;
        
        server {
            listen 80;
            server_name api.yourdomain.com;
            
            # SSL配置
            listen 443 ssl;
            ssl_certificate /path/to/cert.pem;
            ssl_certificate_key /path/to/key.pem;
            
            # 安全headers
            add_header Strict-Transport-Security "max-age=31536000" always;
            add_header X-Frame-Options "SAMEORIGIN" always;
            add_header X-XSS-Protection "1; mode=block" always;
            
            # 反向代理配置
            location / {
                proxy_pass http://llm_backend;
                proxy_http_version 1.1;
                proxy_set_header Upgrade $http_upgrade;
                proxy_set_header Connection 'upgrade';
                proxy_set_header Host $host;
                proxy_cache_bypass $http_upgrade;
                
                # 超时设置
                proxy_connect_timeout 60s;
                proxy_send_timeout 60s;
                proxy_read_timeout 60s;
                
                # 限流配置
                limit_req zone=api burst=20 nodelay;
                limit_conn addr 10;
            }
            
            # 健康检查端点
            location /health {
                proxy_pass http://llm_backend;
                access_log off;
                proxy_http_version 1.1;
                proxy_set_header Connection "";
                
                health_check interval=5s
                             fails=3
                             passes=2;
            }
        }
        
        # 限流配置
        limit_req_zone $binary_remote_addr zone=api:10m rate=10r/s;
        limit_conn_zone $binary_remote_addr zone=addr:10m;
    }

    监控告警部署

    为了确保服务的稳定运行,我们需要部署完善的监控和告警系统。以下是使用Prometheus和Grafana实现的监控方案:

    from prometheus_client import Counter, Histogram, start_http_server
    import time
    from functools import wraps

    # 定义监控指标
    REQUESTS_TOTAL = Counter(
        'llm_requests_total',
        'Total number of LLM API requests',
        ['endpoint''status']
    )

    GENERATION_TIME = Histogram(
        'llm_generation_seconds',
        'Time spent generating text',
        buckets=[0.10.51.02.05.010.0float('inf')]
    )

    TOKEN_COUNT = Counter(
        'llm_tokens_generated_total',
        'Total number of tokens generated'
    )

    def monitor_endpoint(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            start_time = time.time()
            try:
                result = await func(*args, **kwargs)
                REQUESTS_TOTAL.labels(
                    endpoint=func.__name__,
                    status='success'
                ).inc()
                
                if hasattr(result, 'token_count'):
                    TOKEN_COUNT.inc(result.token_count)
                    
                return result
            except Exception as e:
                REQUESTS_TOTAL.labels(
                    endpoint=func.__name__,
                    status='error'
                ).inc()
                raise
            finally:
                GENERATION_TIME.observe(time.time() - start_time)
        
        return wrapper

    # 在FastAPI应用中使用装饰器
    @app.post("/generate")
    @monitor_endpoint
    async def generate_text(request: GenerationRequest):
        # 原有的生成逻辑
        pass

    # 启动Prometheus指标服务器
    start_http_server(8001)

    同时,我们需要配置相应的告警规则:

    groups:
    - name: LLMServiceAlerts
      rules:
      - alert: HighErrorRate
        expr: rate(llm_requests_total{status="error"}[5m]) > 0.1
        for: 5m
        labels:
          severity: critical
        annotations:
          summary: High error rate detected
          description: Error rate is above 10% for the last 5 minutes

      - alert: SlowGenerationTime
        expr: histogram_quantile(0.95, rate(llm_generation_seconds_bucket[5m])) > 5
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: Slow text generation detected
          description: 95th percentile of generation time is above 5 seconds

      - alert: HighMemoryUsage
        expr: process_resident_memory_bytes / 1024 / 1024 / 1024 > 80
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: High memory usage detected
          description: Memory usage is above 80GB

      - alert: ServiceUnhealthy
        expr: up{job="llm_service"} == 0
        for: 1m
        labels:
          severity: critical
        annotations:
          summary: Service is down
          description: LLM service is not responding to health checks

    通过以上配置,我们实现了一个功能完善、可靠的API服务系统。这个系统具有以下特点:

    1. 1. 高可用性
    • • 负载均衡确保请求分发的均衡性
    • • 健康检查自动剔除不健康的节点
    • • 优雅的错误处理和恢复机制
  • 2. 可扩展性
    • • 支持水平扩展增加服务节点
    • • 批量处理接口支持高并发场景
    • • 模块化的设计便于功能扩展
  • 3. 可监控性
    • • 完整的指标收集和监控
    • • 多维度的告警规则
    • • 详细的日志记录

    这套系统可以作为个人AI工作站的基础设施,为上层应用提供稳定的模型服务能力。在实际部署时,您可以根据具体需求调整配置参数,并根据硬件资源情况优化服务器配置。

     


    前端道萌
    魔界如,佛界如,一如,无二如。
     最新文章