目錄
03-5. Repository 模式
⏱️ 閱讀時間: 18 分鐘 🎯 難度: ⭐⭐⭐ (進階)
🤔 一句話解釋
Repository 模式將資料存取邏輯封裝成獨立的層,讓業務邏輯不直接依賴資料庫實作。
🎯 為什麼需要 Repository 模式?
沒有 Repository 的問題
# ❌ Anti-pattern: the endpoint itself talks to the database
@app.post("/users")
async def create_user(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
    # Queries like this end up scattered across every endpoint
    stmt = select(User).where(User.email == user_data.email)
    existing = await db.execute(stmt)
    if existing.scalar():
        raise HTTPException(status_code=400, detail="Email exists")
    # model_dump() is the Pydantic v2 replacement for the deprecated .dict()
    user = User(**user_data.model_dump())
    db.add(user)
    await db.commit()
    await db.refresh(user)
    return user


# 使用 Repository 的好處
┌─────────────────────────────────────────────────────────┐
│ 分層架構 │
├─────────────────────────────────────────────────────────┤
│ API Layer → 定義端點、驗證輸入 │
│ ↓ │
│ Service Layer → 業務邏輯 │
│ ↓ │
│ Repository Layer → 資料存取(封裝 SQL) │
│ ↓ │
│ Database → 實際的資料庫 │
└─────────────────────────────────────────────────────────┘

優點:
| 優點 | 說明 |
|---|---|
| 單一職責 | 每層專注自己的工作 |
| 易於測試 | 可以 Mock Repository |
| 可維護性 | 修改資料存取不影響業務邏輯 |
| 可重用 | Repository 方法可以被多處使用 |
🔧 基本 Repository
抽象基礎類別
from abc import ABC, abstractmethod
from typing import TypeVar, Generic, Type, Sequence
from sqlalchemy import select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
T = TypeVar("T")
class AbstractRepository(ABC, Generic[T]):
"""抽象 Repository 介面"""
@abstractmethod
async def get(self, id: int) -> T | None:
"""根據 ID 取得單筆資料"""
raise NotImplementedError
@abstractmethod
async def get_all(self, skip: int = 0, limit: int = 100) -> Sequence[T]:
"""取得多筆資料"""
raise NotImplementedError
@abstractmethod
async def create(self, obj: T) -> T:
"""建立資料"""
raise NotImplementedError
@abstractmethod
async def update(self, obj: T) -> T:
"""更新資料"""
raise NotImplementedError
@abstractmethod
async def delete(self, id: int) -> bool:
"""刪除資料"""
raise NotImplementedErrorSQLAlchemy Repository 實作
class SQLAlchemyRepository(AbstractRepository[T]):
    """Reusable SQLAlchemy (async) implementation of AbstractRepository.

    Every write method commits the session immediately, so each call runs in
    its own transaction.
    """

    def __init__(self, db: AsyncSession, model: Type[T]):
        self.db = db  # session used for every query and commit
        self.model = model  # mapped ORM class this repository reads and writes

    async def get(self, id: int) -> T | None:
        """Return the row with primary key ``id``, or None if it does not exist."""
        return await self.db.get(self.model, id)

    async def get_all(self, skip: int = 0, limit: int = 100) -> Sequence[T]:
        """Return at most ``limit`` rows after skipping the first ``skip``.

        Rows are ordered by the model's primary key. Without ORDER BY, SQL does
        not define row order, so OFFSET/LIMIT pages could repeat or miss rows
        between calls.
        """
        from sqlalchemy import inspect as sa_inspect

        primary_key = sa_inspect(self.model).primary_key
        stmt = select(self.model).order_by(*primary_key).offset(skip).limit(limit)
        result = await self.db.execute(stmt)
        return result.scalars().all()

    async def create(self, obj: T) -> T:
        """Add ``obj``, commit, then refresh it so database-assigned values are loaded."""
        self.db.add(obj)
        await self.db.commit()
        await self.db.refresh(obj)
        return obj

    async def update(self, obj: T) -> T:
        """Commit pending changes and reload ``obj`` from the database.

        NOTE(review): assumes ``obj`` is already attached to ``self.db`` (for
        example loaded via ``get``); this method never calls ``add``/``merge``.
        """
        await self.db.commit()
        await self.db.refresh(obj)
        return obj

    async def delete(self, id: int) -> bool:
        """Delete the row with primary key ``id``; return False if there is none."""
        obj = await self.get(id)
        if not obj:
            return False
        await self.db.delete(obj)
        await self.db.commit()
        return True

    async def count(self) -> int:
        """Return the number of rows in the model's table."""
        from sqlalchemy import func

        stmt = select(func.count()).select_from(self.model)
        result = await self.db.execute(stmt)
        return result.scalar() or 0


# 📝 實作 User Repository
定義介面
from abc import ABC, abstractmethod
from typing import Sequence
class IUserRepository(ABC):
    """Persistence contract for User entities.

    UserService depends only on this interface, so it can run against a
    database-backed repository or an in-memory test double. The contract
    declares every method UserService calls, including ``exists_by_username``
    and the ``is_active`` filter of ``get_all``.
    """

    @abstractmethod
    async def get_by_id(self, user_id: int) -> User | None:
        """Return the user with ``user_id``, or None if absent."""

    @abstractmethod
    async def get_by_email(self, email: str) -> User | None:
        """Return the user with this email, or None if absent."""

    @abstractmethod
    async def get_by_username(self, username: str) -> User | None:
        """Return the user with this username, or None if absent."""

    @abstractmethod
    async def get_all(
        self,
        skip: int = 0,
        limit: int = 100,
        is_active: bool | None = None,
    ) -> Sequence[User]:
        """Return up to ``limit`` users after ``skip``.

        When ``is_active`` is not None, only users with that flag are returned.
        """

    @abstractmethod
    async def create(self, user: User) -> User:
        """Persist a new user and return it."""

    @abstractmethod
    async def update(self, user: User) -> User:
        """Persist changes to an existing user and return it."""

    @abstractmethod
    async def delete(self, user_id: int) -> bool:
        """Delete the user with ``user_id``; return whether one was deleted."""

    @abstractmethod
    async def exists_by_email(self, email: str) -> bool:
        """Return True if some user already uses ``email``."""

    @abstractmethod
    async def exists_by_username(self, username: str) -> bool:
        """Return True if some user already uses ``username``."""


# 具體實作
from sqlalchemy import select, func, and_, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
class UserRepository(IUserRepository):
    """SQLAlchemy (async) implementation of IUserRepository.

    Write methods commit immediately, so each call is its own transaction.
    """

    # Escape character used for LIKE patterns built by ``search``.
    _LIKE_ESCAPE = "\\"

    def __init__(self, db: AsyncSession):
        self.db = db

    async def get_by_id(self, user_id: int) -> User | None:
        """Return the user with primary key ``user_id``, or None."""
        return await self.db.get(User, user_id)

    async def get_by_id_with_posts(self, user_id: int) -> User | None:
        """Return the user with ``user_id`` with ``User.posts`` loaded eagerly.

        ``selectinload`` populates ``posts`` as part of this query, so callers
        can read it after the method returns.
        """
        stmt = (
            select(User)
            .where(User.id == user_id)
            .options(selectinload(User.posts))
        )
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_email(self, email: str) -> User | None:
        """Return the user whose email equals ``email``, or None."""
        stmt = select(User).where(User.email == email)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_username(self, username: str) -> User | None:
        """Return the user whose username equals ``username``, or None."""
        stmt = select(User).where(User.username == username)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_all(
        self,
        skip: int = 0,
        limit: int = 100,
        is_active: bool | None = None
    ) -> Sequence[User]:
        """Return one page of users, optionally filtered by ``is_active``."""
        stmt = select(User)
        if is_active is not None:
            stmt = stmt.where(User.is_active == is_active)
        # ORDER BY keeps OFFSET/LIMIT pages stable; SQL row order is otherwise undefined.
        stmt = stmt.order_by(User.id).offset(skip).limit(limit)
        result = await self.db.execute(stmt)
        return result.scalars().all()

    async def search(
        self,
        *,
        keyword: str | None = None,
        is_active: bool | None = None,
        skip: int = 0,
        limit: int = 100
    ) -> Sequence[User]:
        """Search users whose username or email contains ``keyword`` (case-insensitive).

        ``keyword`` is matched literally: ``%`` and ``_`` in it are escaped
        rather than treated as LIKE wildcards. An empty or None keyword
        applies no text filter.
        """
        stmt = select(User)
        conditions = []
        if keyword:
            pattern = f"%{self._escape_like(keyword)}%"
            conditions.append(
                or_(
                    User.username.ilike(pattern, escape=self._LIKE_ESCAPE),
                    User.email.ilike(pattern, escape=self._LIKE_ESCAPE),
                )
            )
        if is_active is not None:
            conditions.append(User.is_active == is_active)
        if conditions:
            stmt = stmt.where(and_(*conditions))
        stmt = stmt.order_by(User.id).offset(skip).limit(limit)
        result = await self.db.execute(stmt)
        return result.scalars().all()

    @classmethod
    def _escape_like(cls, value: str) -> str:
        """Escape the escape char, ``%`` and ``_`` so ``value`` matches literally in LIKE."""
        esc = cls._LIKE_ESCAPE
        # Escape the escape character first so the later replacements are not doubled.
        return (
            value.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )

    async def create(self, user: User) -> User:
        """Add ``user``, commit, and refresh it so database-assigned values are loaded."""
        self.db.add(user)
        await self.db.commit()
        await self.db.refresh(user)
        return user

    async def update(self, user: User) -> User:
        """Commit pending changes to ``user`` and reload it.

        NOTE(review): assumes ``user`` is attached to ``self.db`` (for example
        fetched through this repository).
        """
        await self.db.commit()
        await self.db.refresh(user)
        return user

    async def delete(self, user_id: int) -> bool:
        """Delete the user with ``user_id``; return False if there is none."""
        user = await self.get_by_id(user_id)
        if not user:
            return False
        await self.db.delete(user)
        await self.db.commit()
        return True

    async def exists_by_email(self, email: str) -> bool:
        """Return True if at least one user has ``email``."""
        stmt = select(func.count()).select_from(User).where(User.email == email)
        result = await self.db.execute(stmt)
        return (result.scalar() or 0) > 0

    async def exists_by_username(self, username: str) -> bool:
        """Return True if at least one user has ``username``."""
        stmt = select(func.count()).select_from(User).where(User.username == username)
        result = await self.db.execute(stmt)
        return (result.scalar() or 0) > 0

    async def count(self, is_active: bool | None = None) -> int:
        """Count users, optionally only those whose ``is_active`` equals the argument."""
        stmt = select(func.count()).select_from(User)
        if is_active is not None:
            stmt = stmt.where(User.is_active == is_active)
        result = await self.db.execute(stmt)
        return result.scalar() or 0


# 🏗️ Service 層
User Service
from dataclasses import dataclass, field
from typing import Sequence
@dataclass
class CreateUserDTO:
    """Input for UserService.create_user, built by the API layer from the request.

    ``password`` is the plain-text password. It is excluded from ``repr`` so
    that printing or logging a DTO does not leak it. It still takes part in
    ``__init__`` and ``__eq__`` as before.
    """

    username: str
    email: str
    password: str = field(repr=False)
@dataclass
class UpdateUserDTO:
"""更新使用者的資料傳輸物件"""
username: str | None = None
email: str | None = None
is_active: bool | None = None
class UserService:
    """User business-logic layer.

    Talks to storage only through the injected repository, so tests can pass
    an in-memory repository instead of a database-backed one.
    """

    # PBKDF2-HMAC-SHA256 work factor; tests may lower it on an instance to run faster.
    _PBKDF2_ITERATIONS = 600_000

    def __init__(self, user_repo: IUserRepository):
        self.user_repo = user_repo

    async def get_user(self, user_id: int) -> User | None:
        """Return the user with ``user_id``, or None if it does not exist."""
        return await self.user_repo.get_by_id(user_id)

    async def get_users(
        self,
        skip: int = 0,
        limit: int = 100,
        is_active: bool | None = None
    ) -> Sequence[User]:
        """Return one page of users, optionally filtered by ``is_active``."""
        return await self.user_repo.get_all(skip=skip, limit=limit, is_active=is_active)

    async def create_user(self, data: CreateUserDTO) -> User:
        """Register a new user with a hashed password.

        Raises:
            ValueError: if the email or username is already taken.
        """
        # Business rule: email and username must be unique. These checks run
        # before the insert, so two concurrent requests can both pass them;
        # presumably the schema should also enforce uniqueness — confirm.
        if await self.user_repo.exists_by_email(data.email):
            raise ValueError("Email already registered")
        if await self.user_repo.exists_by_username(data.username):
            raise ValueError("Username already taken")

        user = User(
            username=data.username,
            email=data.email,
            hashed_password=self._hash_password(data.password),
        )
        return await self.user_repo.create(user)

    async def update_user(self, user_id: int, data: UpdateUserDTO) -> User | None:
        """Apply the non-empty fields of ``data`` to the user.

        Returns None if the user does not exist.

        Raises:
            ValueError: if the new email or username belongs to someone else.
        """
        user = await self.user_repo.get_by_id(user_id)
        if not user:
            return None

        # Only check uniqueness when the email actually changes.
        if data.email and data.email != user.email:
            if await self.user_repo.exists_by_email(data.email):
                raise ValueError("Email already registered")
            user.email = data.email

        # Same rule for username.
        if data.username and data.username != user.username:
            if await self.user_repo.exists_by_username(data.username):
                raise ValueError("Username already taken")
            user.username = data.username

        if data.is_active is not None:
            user.is_active = data.is_active

        return await self.user_repo.update(user)

    async def delete_user(self, user_id: int) -> bool:
        """Delete the user; return False if it did not exist."""
        return await self.user_repo.delete(user_id)

    async def deactivate_user(self, user_id: int) -> User | None:
        """Set ``is_active`` to False; return None if the user does not exist."""
        user = await self.user_repo.get_by_id(user_id)
        if not user:
            return None
        user.is_active = False
        return await self.user_repo.update(user)

    def _hash_password(self, password: str) -> str:
        """Return a salted PBKDF2-HMAC-SHA256 hash of ``password``.

        Format: ``pbkdf2_sha256$<iterations>$<salt hex>$<digest hex>``, so the
        parameters needed for verification are stored with the hash. Each call
        uses a fresh random salt, so the same password hashes differently each
        time. Check passwords with ``_verify_password``, not ``==``.
        """
        import hashlib
        import secrets

        iterations = self._PBKDF2_ITERATIONS
        salt = secrets.token_bytes(16)
        digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations)
        return f"pbkdf2_sha256${iterations}${salt.hex()}${digest.hex()}"

    @staticmethod
    def _verify_password(password: str, hashed_password: str) -> bool:
        """Return True if ``password`` matches a hash produced by ``_hash_password``.

        Malformed or foreign-format hashes return False instead of raising.
        """
        import hashlib
        import hmac

        try:
            algorithm, iterations, salt_hex, digest_hex = hashed_password.split("$")
            if algorithm != "pbkdf2_sha256":
                return False
            candidate = hashlib.pbkdf2_hmac(
                "sha256", password.encode("utf-8"), bytes.fromhex(salt_hex), int(iterations)
            )
        except ValueError:
            return False
        # Constant-time comparison avoids leaking how many characters matched.
        return hmac.compare_digest(candidate.hex(), digest_hex)


# 🚀 FastAPI 整合
依賴注入
from fastapi import FastAPI, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from collections.abc import AsyncIterator

app = FastAPI()


# Database session dependency
async def get_db() -> AsyncIterator[AsyncSession]:
    """Yield one AsyncSession per request.

    This is a generator dependency, so the return type is an async iterator,
    not a bare AsyncSession. The ``async with`` closes the session once
    FastAPI finalizes the dependency.
    """
    async with AsyncSessionLocal() as session:
        yield session
# 取得 UserRepository
async def get_user_repo(db: AsyncSession = Depends(get_db)) -> UserRepository:
    """Build a UserRepository bound to the current request's session."""
    repository = UserRepository(db)
    return repository
# 取得 UserService
async def get_user_service(
    user_repo: UserRepository = Depends(get_user_repo),
) -> UserService:
    """Build a UserService on top of the request-scoped UserRepository."""
    service = UserService(user_repo)
    return service


# API 端點
from pydantic import BaseModel, EmailStr
# Pydantic Schemas
class UserCreateRequest(BaseModel):
    # JSON body accepted by POST /users.
    username: str
    email: EmailStr  # pydantic validates this as an email address
    password: str  # plain text here; UserService hashes it before storage
class UserUpdateRequest(BaseModel):
    # JSON body accepted by PATCH /users/{user_id}; omitted fields stay None.
    username: str | None = None
    email: EmailStr | None = None
    is_active: bool | None = None
class UserResponse(BaseModel):
    # Public shape of a user returned by the API; password data is not exposed.

    # Pydantic v2 configuration (replaces the deprecated inner ``class Config``).
    # from_attributes lets the model be built from ORM objects via attribute
    # access, which is how endpoints return User instances with this response_model.
    model_config = {"from_attributes": True}

    id: int
    username: str
    email: str
    is_active: bool
class PaginatedResponse(BaseModel):
    # Envelope for GET /users: one page of users plus paging metadata.
    items: list[UserResponse]  # users on this page
    total: int  # users matching the filter across all pages
    skip: int  # offset the client asked for
    limit: int  # page size the client asked for
# API 端點
@app.post("/users", response_model=UserResponse)
async def create_user(
    request: UserCreateRequest,
    service: UserService = Depends(get_user_service),
):
    # Create a user. The service raises ValueError for business-rule
    # violations (duplicate email/username), which become HTTP 400 here.
    data = CreateUserDTO(
        username=request.username,
        email=request.email,
        password=request.password,
    )
    # Keep the try body to the one call that can raise ValueError.
    try:
        user = await service.create_user(data)
    except ValueError as e:
        # Chain the cause so server-side tracebacks show the original error.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return user
@app.get("/users", response_model=PaginatedResponse)
async def get_users(
    skip: int = 0,
    limit: int = 100,
    is_active: bool | None = None,
    service: UserService = Depends(get_user_service),
    user_repo: UserRepository = Depends(get_user_repo),
):
    # List users with offset pagination plus the total number of matches.
    page = await service.get_users(skip=skip, limit=limit, is_active=is_active)
    # NOTE(review): the total comes straight from the repository, bypassing
    # UserService; consider exposing a count method on the service instead.
    matching_total = await user_repo.count(is_active=is_active)
    return PaginatedResponse(items=page, total=matching_total, skip=skip, limit=limit)
@app.get("/users/{user_id}", response_model=UserResponse)
async def get_user(
    user_id: int,
    service: UserService = Depends(get_user_service),
):
    # Return a single user, or 404 when the id is unknown.
    found = await service.get_user(user_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="User not found")
@app.patch("/users/{user_id}", response_model=UserResponse)
async def update_user(
    user_id: int,
    request: UserUpdateRequest,
    service: UserService = Depends(get_user_service),
):
    # Partially update a user: 404 if it does not exist, 400 if the new
    # email/username collides with another user (service raises ValueError).
    data = UpdateUserDTO(
        username=request.username,
        email=request.email,
        is_active=request.is_active,
    )
    # Only the service call can raise ValueError; the 404 check stays outside the try.
    try:
        user = await service.update_user(user_id, data)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    return user
@app.delete("/users/{user_id}")
async def delete_user(
    user_id: int,
    service: UserService = Depends(get_user_service),
):
    # Delete a user; 404 when there was nothing to delete.
    deleted = await service.delete_user(user_id)
    if deleted:
        return {"message": "User deleted"}
    raise HTTPException(status_code=404, detail="User not found")


# 🧪 測試
Mock Repository
import pytest
from unittest.mock import AsyncMock, MagicMock
class MockUserRepository(IUserRepository):
    """In-memory IUserRepository test double: users live in a dict, no database."""

    def __init__(self):
        self.users: dict[int, User] = {}  # user id -> User
        self._id_counter = 1  # id assigned by the next create()

    async def get_by_id(self, user_id: int) -> User | None:
        """Return the stored user with ``user_id``, or None."""
        return self.users.get(user_id)

    async def get_by_email(self, email: str) -> User | None:
        """Return the first stored user whose email equals ``email``, or None."""
        return next((u for u in self.users.values() if u.email == email), None)

    async def get_by_username(self, username: str) -> User | None:
        """Return the first stored user whose username equals ``username``, or None."""
        return next((u for u in self.users.values() if u.username == username), None)

    async def get_all(
        self,
        skip: int = 0,
        limit: int = 100,
        is_active: bool | None = None,
    ) -> list[User]:
        """Return a page of users in insertion order, optionally filtered by ``is_active``.

        Accepts ``is_active`` because UserService.get_users always passes it.
        Without the parameter, that call raised TypeError against this mock.
        """
        # NOTE(review): a User created in memory may have is_active None if the
        # model's default is only applied by the database — confirm for the model.
        matching = [
            user
            for user in self.users.values()
            if is_active is None or user.is_active == is_active
        ]
        return matching[skip:skip + limit]

    async def create(self, user: User) -> User:
        """Assign the next sequential id (starting at 1) and store ``user``."""
        user.id = self._id_counter
        self._id_counter += 1
        self.users[user.id] = user
        return user

    async def update(self, user: User) -> User:
        """Store ``user`` under its id, replacing any previous entry."""
        self.users[user.id] = user
        return user

    async def delete(self, user_id: int) -> bool:
        """Remove ``user_id``; return True if it was present."""
        return self.users.pop(user_id, None) is not None

    async def exists_by_email(self, email: str) -> bool:
        """Return True if a stored user has ``email``."""
        return await self.get_by_email(email) is not None

    async def exists_by_username(self, username: str) -> bool:
        """Return True if a stored user has ``username``."""
        return await self.get_by_username(username) is not None
# 測試 UserService
@pytest.mark.asyncio
async def test_create_user():
    """The first user created gets id 1 and keeps the submitted username/email."""
    service = UserService(MockUserRepository())

    created = await service.create_user(
        CreateUserDTO(
            username="john",
            email="john@example.com",
            password="password123",
        )
    )

    assert created.id == 1
    assert created.username == "john"
    assert created.email == "john@example.com"
@pytest.mark.asyncio
async def test_create_user_duplicate_email():
    """Registering a second user with an already-used email raises ValueError."""
    service = UserService(MockUserRepository())

    # Seed one user.
    await service.create_user(
        CreateUserDTO(username="john", email="john@example.com", password="pass")
    )

    # Different username, same email as the seeded user.
    duplicate = CreateUserDTO(username="jane", email="john@example.com", password="pass")
    with pytest.raises(ValueError, match="Email already registered"):
        await service.create_user(duplicate)


# ✅ 重點總結
Repository 模式結構
API Layer
↓
Service Layer (業務邏輯)
↓
Repository Layer (資料存取)
↓
Database

職責分離
| 層級 | 職責 |
|---|---|
| API | 處理請求、驗證輸入、回傳響應 |
| Service | 業務邏輯、跨 Repository 操作 |
| Repository | 資料存取、SQL 查詢 |
優點
- 可測試性:可以 Mock Repository 測試 Service
- 可維護性:修改資料庫不影響業務邏輯
- 可重用性:Repository 方法可以被多處使用
- 關注點分離:每層專注自己的工作
🎤 面試這樣答
Q: 什麼是 Repository 模式?為什麼要用?
答案:
Repository 模式是將資料存取邏輯封裝成獨立的層,讓業務邏輯不直接依賴資料庫實作。
好處:
- 可測試性:可以用 Mock Repository 測試業務邏輯
- 關注點分離:Service 專注業務邏輯,Repository 專注資料存取
- 可替換性:可以輕鬆更換資料庫實作
class UserRepository:
    async def get_by_id(self, id: int) -> User | None: ...
    async def exists_by_email(self, email: str) -> bool: ...
    async def create(self, user: User) -> User: ...


class UserService:
    def __init__(self, repo: UserRepository):
        self.repo = repo

    async def create_user(self, data) -> User:
        # Business rules live here, not in the endpoint or the repository.
        # Assumes data is a CreateUserDTO-like object (username/email/password).
        if await self.repo.exists_by_email(data.email):
            raise ValueError("Email exists")
        # Map fields explicitly: a DTO object cannot be unpacked with **, and
        # only the hashed password should be handed to the model.
        user = User(
            username=data.username,
            email=data.email,
            hashed_password=self._hash_password(data.password),
        )
        return await self.repo.create(user)

    def _hash_password(self, password: str) -> str: ...
上一篇: 03-4. 資料庫遷移 Alembic 下一篇: 03-6. 交易管理
最後更新:2025-12-17