依赖注入(Dependency Injection)是 FastAPI 最强大的特性之一,也是理解 FastAPI 架构的关键。本文从概念到实战,系统讲解 Depends 的用法、嵌套依赖、数据库会话管理、认证鉴权和权限控制。
4.1 什么是依赖注入
4.1.1 概念解释
依赖注入的核心思想:不要自己创建依赖,而是由外部提供。
# 没有依赖注入:函数自己负责获取数据库连接
async def get_users():
db = create_database_connection() # 硬编码依赖
users = await db.execute("SELECT * FROM users")
db.close()
return users
# 有依赖注入:依赖由外部注入
async def get_users(db: Database = Depends(get_db)):
users = await db.execute("SELECT * FROM users")
return users
# db 的创建和销毁由 get_db 负责
4.1.2 为什么需要依赖注入
| 问题 | 没有 DI | 有 DI |
|---|---|---|
| 代码复用 | 每个路由重复写数据库连接逻辑 | 集中定义,多处复用 |
| 可测试性 | 难以替换真实数据库 | 注入 mock 对象即可 |
| 关注点分离 | 路由函数混杂基础设施代码 | 路由只关心业务逻辑 |
| 配置管理 | 散落在各处 | 集中管理 |
4.1.3 FastAPI 的依赖注入
FastAPI 内置了依赖注入系统,不需要第三方库(如 dependency-injector)。它的核心是 Depends 函数。
from fastapi import Depends, FastAPI
app = FastAPI()
async def get_db():
"""依赖函数:提供数据库会话"""
db = DatabaseSession()
try:
yield db # yield 之前是"进入"逻辑
finally:
db.close() # yield 之后是"退出"逻辑(清理资源)
@app.get("/users/")
async def list_users(db = Depends(get_db)):
# FastAPI 自动调用 get_db(),将返回值注入到 db 参数
return await db.get_users()
4.2 Depends 基础用法
4.2.1 简单依赖
from fastapi import Depends
# 依赖函数
async def get_pagination(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
) -> dict:
return {
"offset": (page - 1) * page_size,
"limit": page_size,
"page": page,
"page_size": page_size,
}
# 路由使用依赖
@app.get("/users/")
async def list_users(pagination: dict = Depends(get_pagination)):
return {
"offset": pagination["offset"],
"limit": pagination["limit"],
}
@app.get("/products/")
async def list_products(pagination: dict = Depends(get_pagination)):
# 复用同一个依赖
return {"offset": pagination["offset"]}
4.2.2 依赖的返回类型
from typing import Annotated
# 方式一:直接类型注解
@app.get("/users/")
async def list_users(pagination: dict = Depends(get_pagination)):
...
# 方式二:使用 Annotated(推荐,IDE 支持更好)
Pagination = Annotated[dict, Depends(get_pagination)]
@app.get("/users/")
async def list_users(pagination: Pagination):
...
@app.get("/products/")
async def list_products(pagination: Pagination):
...
4.2.3 yield 依赖(资源管理)
yield 依赖是管理数据库连接、文件句柄等资源的标准模式。
from fastapi import Depends
async def get_db():
"""数据库会话依赖"""
db = AsyncSessionLocal() # 创建会话
try:
yield db # 将会话提供给路由
except Exception:
await db.rollback() # 异常时回滚
raise
finally:
await db.close() # 无论如何都关闭会话
# 嵌套 yield 依赖
async def get_user_service(db = Depends(get_db)):
"""用户服务依赖(依赖数据库会话)"""
service = UserService(db)
try:
yield service
finally:
# service 的清理逻辑(如有)
pass
4.3 嵌套依赖
FastAPI 会自动解析依赖链,支持任意深度的嵌套。
4.3.1 依赖链
from fastapi import Depends
# 第一层:配置
async def get_settings() -> Settings:
return Settings()
# 第二层:数据库(依赖配置)
async def get_db(settings: Settings = Depends(get_settings)):
engine = create_async_engine(settings.DATABASE_URL)
async with AsyncSession(engine) as session:
yield session
# 第三层:用户服务(依赖数据库)
async def get_user_service(db = Depends(get_db)):
return UserService(db)
# 第四层:当前用户(依赖用户服务)
async def get_current_user(
token: str = Header(),
service: UserService = Depends(get_user_service),
) -> User:
return await service.verify_token(token)
# 路由使用
@app.get("/me")
async def get_me(current_user: User = Depends(get_current_user)):
return current_user
依赖链解析过程:
GET /me 请求到达
-> FastAPI 发现 get_me 需要 get_current_user
-> get_current_user 需要 get_user_service 和 token
-> get_user_service 需要 get_db
-> get_db 需要 get_settings
-> 从 get_settings 开始,逐层解析并注入
4.3.2 依赖缓存
同一个请求中,相同的依赖只会执行一次。FastAPI 会缓存依赖的结果。
async def get_db():
print("Creating DB session") # 只会打印一次
db = AsyncSessionLocal()
try:
yield db
finally:
await db.close()
async def get_user_repo(db = Depends(get_db)):
return UserRepository(db)
async def get_order_repo(db = Depends(get_db)):
# 这里的 db 和 get_user_repo 的是同一个实例
# get_db 只执行一次
return OrderRepository(db)
@app.get("/data")
async def get_data(
user_repo = Depends(get_user_repo),
order_repo = Depends(get_order_repo),
):
# 两个 repo 共享同一个数据库会话
...
4.4 数据库会话管理
这是依赖注入最常见的应用场景。
4.4.1 SQLAlchemy 异步会话
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from fastapi import Depends
# 创建引擎和会话工厂
engine = create_async_engine(
"postgresql+asyncpg://user:pass@localhost:5432/mydb",
pool_size=5,
max_overflow=10,
echo=False,
)
AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
)
# 依赖函数
async def get_db() -> AsyncSession:
async with AsyncSessionLocal() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
# 类型别名,简化使用
DBSession = Annotated[AsyncSession, Depends(get_db)]
# 路由使用
@app.get("/users/", response_model=list[UserResponse])
async def list_users(db: DBSession):
result = await db.execute(select(User))
return result.scalars().all()
@app.post("/users/", response_model=UserResponse, status_code=201)
async def create_user(body: UserCreate, db: DBSession):
user = User(**body.model_dump())
db.add(user)
await db.flush() # 刷新以获取 ID
await db.refresh(user) # 刷新以获取默认值
return user
4.4.2 事务管理
async def get_db() -> AsyncSession:
async with AsyncSessionLocal() as session:
try:
yield session
await session.commit() # 成功时提交
except Exception:
await session.rollback() # 失败时回滚
raise
# 在服务层手动控制事务
class UserService:
def __init__(self, db: AsyncSession):
self.db = db
async def transfer_balance(self, from_id: int, to_id: int, amount: float):
"""转账操作(在同一事务中)"""
from_user = await self.db.get(User, from_id)
to_user = await self.db.get(User, to_id)
if from_user.balance < amount:
raise ValueError("余额不足")
from_user.balance -= amount
to_user.balance += amount
# 无需手动 commit,依赖函数会处理
4.5 认证与鉴权
4.5.1 API Key 认证
from fastapi import Depends, HTTPException, Security
from fastapi.security import APIKeyHeader
api_key_header = APIKeyHeader(name="X-API-Key")
async def verify_api_key(api_key: str = Security(api_key_header)):
if api_key != "my-secret-key":
raise HTTPException(status_code=403, detail="无效的 API Key")
return api_key
@app.get("/protected/", dependencies=[Depends(verify_api_key)])
async def protected_route():
return {"message": "你通过了 API Key 认证"}
4.5.2 JWT Bearer Token 认证
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
from pydantic import BaseModel
# 配置
SECRET_KEY = "your-secret-key"
ALGORITHM = "HS256"
# 安全方案
security = HTTPBearer()
class TokenData(BaseModel):
user_id: int
username: str
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Security(security),
) -> TokenData:
"""JWT 认证依赖"""
token = credentials.credentials
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id = payload.get("sub")
username = payload.get("username")
if user_id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的认证凭据",
)
return TokenData(user_id=int(user_id), username=username)
except JWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token 已过期或无效",
headers={"WWW-Authenticate": "Bearer"},
)
# 使用
@app.get("/me")
async def get_me(current_user: TokenData = Depends(get_current_user)):
return {"user_id": current_user.user_id, "username": current_user.username}
4.5.3 权限控制
from enum import Enum
from functools import wraps
class Role(str, Enum):
ADMIN = "admin"
MODERATOR = "moderator"
USER = "user"
class PermissionChecker:
"""权限检查器(可作为依赖使用)"""
def __init__(self, required_roles: list[Role]):
self.required_roles = required_roles
async def __call__(self, current_user: TokenData = Depends(get_current_user)):
# 这里应该从数据库查询用户角色
user_role = await get_user_role(current_user.user_id)
if user_role not in self.required_roles:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="权限不足",
)
return current_user
# 使用
@app.get("/admin/users/")
async def admin_list_users(
current_user: TokenData = Depends(PermissionChecker([Role.ADMIN])),
):
return {"message": f"管理员 {current_user.username} 访问成功"}
@app.get("/moderator/posts/")
async def mod_list_posts(
current_user: TokenData = Depends(PermissionChecker([Role.ADMIN, Role.MODERATOR])),
):
return {"message": f"版主 {current_user.username} 访问成功"}
4.6 依赖复用技巧
4.6.1 路由级依赖
在路由装饰器中添加依赖,所有该路由的请求都会执行。
from fastapi import Depends
# 单个路由的依赖
@app.get("/admin/", dependencies=[Depends(verify_api_key)])
async def admin_dashboard():
return {"data": "admin dashboard"}
# 路由组的依赖
from fastapi import APIRouter
admin_router = APIRouter(
prefix="/admin",
dependencies=[Depends(verify_api_key), Depends(require_admin)],
)
@admin_router.get("/users/")
async def list_users():
...
@admin_router.get("/settings/")
async def get_settings():
...
4.6.2 参数化依赖
通过类的 __init__ 方法实现参数化的依赖。
class RateLimiter:
"""速率限制依赖"""
def __init__(self, max_requests: int, window_seconds: int):
self.max_requests = max_requests
self.window_seconds = window_seconds
self.requests: dict[str, list[float]] = {}
async def __call__(self, request: Request):
client_ip = request.client.host
now = time.time()
# 清理过期记录
if client_ip in self.requests:
self.requests[client_ip] = [
t for t in self.requests[client_ip]
if now - t < self.window_seconds
]
else:
self.requests[client_ip] = []
if len(self.requests[client_ip]) >= self.max_requests:
raise HTTPException(status_code=429, detail="请求过于频繁")
self.requests[client_ip].append(now)
# 使用
@app.post("/login/", dependencies=[Depends(RateLimiter(max_requests=5, window_seconds=60))])
async def login():
...
4.6.3 依赖类 vs 依赖函数
# 函数式依赖(简单场景)
async def get_db():
async with AsyncSessionLocal() as session:
yield session
# 类依赖(需要参数化或复杂逻辑)
class DatabaseDependency:
def __init__(self, url: str):
self.url = url
self.engine = create_async_engine(url)
self.session_factory = async_sessionmaker(self.engine)
async def __call__(self) -> AsyncSession:
async with self.session_factory() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
# 不同数据库的依赖
get_main_db = DatabaseDependency("postgresql+asyncpg://main...")
get_analytics_db = DatabaseDependency("postgresql+asyncpg://analytics...")
@app.get("/users/")
async def list_users(db = Depends(get_main_db)):
...
@app.get("/stats/")
async def get_stats(db = Depends(get_analytics_db)):
...
4.7 测试中的依赖替换
依赖注入最大的好处之一是易于测试。
4.7.1 使用 app.dependency_overrides
# main.py
from fastapi import Depends
async def get_db():
async with AsyncSessionLocal() as session:
yield session
@app.get("/users/")
async def list_users(db = Depends(get_db)):
result = await db.execute(select(User))
return result.scalars().all()
# test_main.py
from fastapi.testclient import TestClient
from main import app
# Mock 数据库
class FakeDB:
async def execute(self, *args, **kwargs):
return FakeResult()
# 覆盖依赖
async def override_get_db():
return FakeDB()
app.dependency_overrides[get_db] = override_get_db
client = TestClient(app)
def test_list_users():
response = client.get("/users/")
assert response.status_code == 200
4.7.2 使用 pytest fixtures
# conftest.py
import pytest
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from main import app
from app.api.deps import get_db
from app.models import Base
# 测试数据库
TEST_DATABASE_URL = "sqlite+aiosqlite:///./test.db"
test_engine = create_async_engine(TEST_DATABASE_URL)
TestSessionLocal = async_sessionmaker(test_engine, class_=AsyncSession)
@pytest.fixture(autouse=True)
async def setup_db():
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@pytest.fixture
async def db():
async with TestSessionLocal() as session:
yield session
@pytest.fixture
async def client(db):
async def override_get_db():
yield db
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as c:
yield c
app.dependency_overrides.clear()
# test_users.py
async def test_create_user(client: AsyncClient):
response = await client.post("/api/v1/users/", json={
"name": "Alice",
"email": "alice@example.com",
})
assert response.status_code == 201
data = response.json()
assert data["name"] == "Alice"
4.8 实战:完整的认证鉴权系统
from fastapi import Depends, HTTPException, status, Security
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
from passlib.context import CryptContext
from datetime import datetime, timedelta
# ========== 配置 ==========
SECRET_KEY = "your-secret-key"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
# ========== 工具 ==========
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
security = HTTPBearer()
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
to_encode = data.copy()
expire = datetime.utcnow() + (expires_delta or timedelta(minutes=15))
to_encode.update({"exp": expire})
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def verify_password(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)
def hash_password(password: str) -> str:
return pwd_context.hash(password)
# ========== 依赖 ==========
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Security(security),
db: AsyncSession = Depends(get_db),
) -> User:
"""获取当前认证用户"""
try:
payload = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM])
user_id: int = payload.get("sub")
if user_id is None:
raise HTTPException(status_code=401, detail="无效的 Token")
except JWTError:
raise HTTPException(status_code=401, detail="Token 已过期或无效")
user = await db.get(User, user_id)
if user is None:
raise HTTPException(status_code=401, detail="用户不存在")
if not user.is_active:
raise HTTPException(status_code=403, detail="用户已被停用")
return user
async def get_current_active_user(
current_user: User = Depends(get_current_user),
) -> User:
"""确保用户处于活跃状态"""
if not current_user.is_active:
raise HTTPException(status_code=403, detail="用户已被停用")
return current_user
class require_role:
"""角色权限检查器"""
def __init__(self, *roles: str):
self.roles = roles
async def __call__(self, current_user: User = Depends(get_current_user)) -> User:
if current_user.role not in self.roles:
raise HTTPException(status_code=403, detail="权限不足")
return current_user
# ========== 路由 ==========
@app.post("/auth/login")
async def login(body: LoginRequest, db: DBSession):
"""登录接口"""
user = await db.execute(select(User).where(User.email == body.email))
user = user.scalar_one_or_none()
if not user or not verify_password(body.password, user.hashed_password):
raise HTTPException(status_code=401, detail="邮箱或密码错误")
token = create_access_token(
{"sub": str(user.id), "username": user.username},
expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
)
return {"access_token": token, "token_type": "bearer"}
@app.get("/users/me", response_model=UserResponse)
async def get_me(current_user: User = Depends(get_current_active_user)):
"""获取当前用户信息"""
return current_user
@app.get("/admin/users/", response_model=list[UserResponse])
async def admin_list_users(
current_user: User = Depends(require_role("admin")),
db: DBSession,
):
"""管理员获取用户列表"""
result = await db.execute(select(User))
return result.scalars().all()
4.9 小结
| 知识点 | 关键 API |
|---|---|
| 基础依赖 | Depends(dependency_function) |
| yield 依赖 | yield resource(自动清理) |
| 嵌套依赖 | 依赖函数中再用 Depends |
| 类型简化 | Annotated[Type, Depends(...)] |
| 认证 | HTTPBearer + Security |
| 权限控制 | 可调用类 + Depends |
| 路由级依赖 | dependencies=[Depends(...)] |
| 测试覆盖 | app.dependency_overrides[dep] = mock |

