03-5. Repository 模式

⏱️ 閱讀時間: 18 分鐘 🎯 難度: ⭐⭐⭐ (進階)


🤔 一句話解釋

Repository 模式將資料存取邏輯封裝成獨立的層,讓業務邏輯不直接依賴資料庫實作。


🎯 為什麼需要 Repository 模式?

沒有 Repository 的問題

# ❌ Anti-pattern: the route handler talks to the database directly
@app.post("/users")
async def create_user(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
    # Query construction ends up scattered across handlers like this one.
    # Look for an existing row with the same email and reject the request
    # with HTTP 400 when one is found.
    stmt = select(User).where(User.email == user_data.email)
    existing = await db.execute(stmt)
    if existing.scalar():
        raise HTTPException(status_code=400, detail="Email exists")

    # model_dump() is the Pydantic v2 spelling; .dict() is deprecated there
    # and emits a deprecation warning on every call.
    user = User(**user_data.model_dump())
    db.add(user)
    await db.commit()
    # Re-read the instance from the database before returning it.
    await db.refresh(user)
    return user

使用 Repository 的好處

┌─────────────────────────────────────────────────────────┐
│                   分層架構                               │
├─────────────────────────────────────────────────────────┤
│  API Layer          →  定義端點、驗證輸入               │
│        ↓                                                │
│  Service Layer      →  業務邏輯                         │
│        ↓                                                │
│  Repository Layer   →  資料存取(封裝 SQL)             │
│        ↓                                                │
│  Database           →  實際的資料庫                     │
└─────────────────────────────────────────────────────────┘

優點:

| 優點 | 說明 |
| --- | --- |
| 單一職責 | 每層專注自己的工作 |
| 易於測試 | 可以 Mock Repository |
| 可維護性 | 修改資料存取不影響業務邏輯 |
| 可重用 | Repository 方法可以被多處使用 |

🔧 基本 Repository

抽象基礎類別

from abc import ABC, abstractmethod
from typing import TypeVar, Generic, Type, Sequence
from sqlalchemy import select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession

# Type variable standing for the entity/model type a repository manages.
T = TypeVar("T")

class AbstractRepository(ABC, Generic[T]):
    """Abstract repository interface, generic over the entity type ``T``.

    Declares the CRUD operations a concrete repository has to provide.
    Every method is abstract; the bodies only raise ``NotImplementedError``.
    """

    @abstractmethod
    async def get(self, id: int) -> T | None:
        """Fetch a single record by its ID (the annotation allows ``None``)."""
        raise NotImplementedError

    @abstractmethod
    async def get_all(self, skip: int = 0, limit: int = 100) -> Sequence[T]:
        """Fetch multiple records; ``skip``/``limit`` default to 0 and 100."""
        raise NotImplementedError

    @abstractmethod
    async def create(self, obj: T) -> T:
        """Create a record from ``obj`` and return a ``T``."""
        raise NotImplementedError

    @abstractmethod
    async def update(self, obj: T) -> T:
        """Update the record represented by ``obj`` and return a ``T``."""
        raise NotImplementedError

    @abstractmethod
    async def delete(self, id: int) -> bool:
        """Delete the record with the given ID; returns a ``bool``."""
        raise NotImplementedError

SQLAlchemy Repository 實作

class SQLAlchemyRepository(AbstractRepository[T]):
    """Generic CRUD repository backed by a SQLAlchemy ``AsyncSession``.

    ``model`` is the mapped class this repository reads and writes. The
    write methods (``create``, ``update``, ``delete``) commit the session
    themselves.
    """

    def __init__(self, db: AsyncSession, model: Type[T]):
        """Store the session and the mapped class used by every query."""
        self.db = db
        self.model = model

    async def get(self, id: int) -> T | None:
        """Return the ``model`` instance with primary key ``id``, or ``None``."""
        return await self.db.get(self.model, id)

    async def get_all(self, skip: int = 0, limit: int = 100) -> Sequence[T]:
        """Return one page of ``model`` rows in primary-key order.

        Args:
            skip: Number of rows to skip (SQL OFFSET).
            limit: Maximum number of rows to return (SQL LIMIT).
        """
        from sqlalchemy import inspect

        # OFFSET/LIMIT without ORDER BY leaves the row order up to the
        # database, so consecutive pages could overlap or miss rows.
        # Ordering by the mapped primary key makes paging deterministic.
        pk_columns = inspect(self.model).primary_key
        stmt = (
            select(self.model)
            .order_by(*pk_columns)
            .offset(skip)
            .limit(limit)
        )
        result = await self.db.execute(stmt)
        return result.scalars().all()

    async def create(self, obj: T) -> T:
        """Add ``obj`` to the session, commit, refresh it and return it."""
        self.db.add(obj)
        await self.db.commit()
        await self.db.refresh(obj)
        return obj

    async def update(self, obj: T) -> T:
        """Commit the session's pending changes, refresh ``obj`` and return it.

        ``obj`` is not added to the session here, so only changes the session
        is already tracking get written.
        """
        await self.db.commit()
        await self.db.refresh(obj)
        return obj

    async def delete(self, id: int) -> bool:
        """Delete the row with primary key ``id``.

        Returns:
            ``False`` when no such row exists, ``True`` after the delete has
            been committed.
        """
        obj = await self.get(id)
        if not obj:
            return False
        await self.db.delete(obj)
        await self.db.commit()
        return True

    async def count(self) -> int:
        """Return the total number of ``model`` rows."""
        from sqlalchemy import func
        stmt = select(func.count()).select_from(self.model)
        result = await self.db.execute(stmt)
        return result.scalar() or 0

📝 實作 User Repository

定義介面

from abc import ABC, abstractmethod
from typing import Sequence

class IUserRepository(ABC):
    """Interface for user persistence.

    The service layer depends on this abstraction rather than on a concrete
    database implementation, so it has to declare every operation the
    service calls.
    """

    @abstractmethod
    async def get_by_id(self, user_id: int) -> User | None:
        """Return the user with the given ID, or ``None``."""

    @abstractmethod
    async def get_by_email(self, email: str) -> User | None:
        """Return the user with the given email, or ``None``."""

    @abstractmethod
    async def get_by_username(self, username: str) -> User | None:
        """Return the user with the given username, or ``None``."""

    @abstractmethod
    async def get_all(
        self,
        skip: int = 0,
        limit: int = 100,
        is_active: bool | None = None
    ) -> Sequence[User]:
        """Return a page of users.

        Args:
            skip: Number of users to skip.
            limit: Maximum number of users to return.
            is_active: When not ``None``, only return users whose
                ``is_active`` flag equals this value.
        """

    @abstractmethod
    async def create(self, user: User) -> User:
        """Persist a new user and return it."""

    @abstractmethod
    async def update(self, user: User) -> User:
        """Persist changes made to ``user`` and return it."""

    @abstractmethod
    async def delete(self, user_id: int) -> bool:
        """Delete the user with the given ID; return whether one was deleted."""

    @abstractmethod
    async def exists_by_email(self, email: str) -> bool:
        """Return whether a user with the given email exists."""

    async def exists_by_username(self, username: str) -> bool:
        """Return whether a user with the given username exists.

        Deliberately concrete rather than abstract: it is derived from
        ``get_by_username`` so subclasses that predate this method remain
        instantiable. Implementations may override it with a cheaper query.
        """
        return await self.get_by_username(username) is not None

具體實作

from sqlalchemy import select, func, and_, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

class UserRepository(IUserRepository):
    """SQLAlchemy implementation of ``IUserRepository``.

    All queries run on the ``AsyncSession`` passed to the constructor. The
    write methods (``create``, ``update``, ``delete``) commit the session
    themselves.
    """

    def __init__(self, db: AsyncSession):
        """Store the session used by every query."""
        self.db = db

    @staticmethod
    def _escape_like(text: str) -> str:
        """Escape LIKE wildcards in ``text`` using ``/`` as the escape character."""
        # The escape character itself must be doubled first, otherwise the
        # escapes added for % and _ would be escaped again.
        return text.replace("/", "//").replace("%", "/%").replace("_", "/_")

    async def get_by_id(self, user_id: int) -> User | None:
        """Return the user with primary key ``user_id``, or ``None``."""
        return await self.db.get(User, user_id)

    async def get_by_id_with_posts(self, user_id: int) -> User | None:
        """Return the user with ``user_id`` with ``User.posts`` loaded via ``selectinload``."""
        stmt = (
            select(User)
            .where(User.id == user_id)
            .options(selectinload(User.posts))
        )
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_email(self, email: str) -> User | None:
        """Return the user whose email equals ``email``, or ``None``."""
        stmt = select(User).where(User.email == email)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_username(self, username: str) -> User | None:
        """Return the user whose username equals ``username``, or ``None``."""
        stmt = select(User).where(User.username == username)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_all(
        self,
        skip: int = 0,
        limit: int = 100,
        is_active: bool | None = None
    ) -> Sequence[User]:
        """Return one page of users ordered by ``User.id``.

        Args:
            skip: Number of rows to skip (SQL OFFSET).
            limit: Maximum number of rows to return (SQL LIMIT).
            is_active: When not ``None``, filter on ``User.is_active``.
        """
        stmt = select(User)

        if is_active is not None:
            stmt = stmt.where(User.is_active == is_active)

        # OFFSET/LIMIT needs a stable ORDER BY; without one the database may
        # return rows in any order and pages can overlap or skip users.
        stmt = stmt.order_by(User.id).offset(skip).limit(limit)
        result = await self.db.execute(stmt)
        return result.scalars().all()

    async def search(
        self,
        *,
        keyword: str | None = None,
        is_active: bool | None = None,
        skip: int = 0,
        limit: int = 100
    ) -> Sequence[User]:
        """Return one page of users matching the given filters, ordered by ``User.id``.

        Args:
            keyword: Case-insensitive substring matched against username or
                email. ``%`` and ``_`` in the keyword are matched literally.
            is_active: When not ``None``, filter on ``User.is_active``.
            skip: Number of rows to skip (SQL OFFSET).
            limit: Maximum number of rows to return (SQL LIMIT).
        """
        stmt = select(User)

        conditions = []
        if keyword:
            # Escape LIKE metacharacters so a keyword such as "50%" or "a_b"
            # is treated as literal text rather than as a wildcard pattern.
            pattern = f"%{self._escape_like(keyword)}%"
            conditions.append(
                or_(
                    User.username.ilike(pattern, escape="/"),
                    User.email.ilike(pattern, escape="/")
                )
            )
        if is_active is not None:
            conditions.append(User.is_active == is_active)

        if conditions:
            stmt = stmt.where(and_(*conditions))

        # Same reasoning as get_all: paging requires a deterministic order.
        stmt = stmt.order_by(User.id).offset(skip).limit(limit)
        result = await self.db.execute(stmt)
        return result.scalars().all()

    async def create(self, user: User) -> User:
        """Add ``user`` to the session, commit, refresh it and return it."""
        self.db.add(user)
        await self.db.commit()
        await self.db.refresh(user)
        return user

    async def update(self, user: User) -> User:
        """Commit the session's pending changes, refresh ``user`` and return it.

        ``user`` is not added to the session here, so only changes the
        session is already tracking get written.
        """
        await self.db.commit()
        await self.db.refresh(user)
        return user

    async def delete(self, user_id: int) -> bool:
        """Delete the user with ``user_id``.

        Returns:
            ``False`` when no such user exists, ``True`` after the delete has
            been committed.
        """
        user = await self.get_by_id(user_id)
        if not user:
            return False
        await self.db.delete(user)
        await self.db.commit()
        return True

    async def exists_by_email(self, email: str) -> bool:
        """Return whether at least one user has the given email."""
        stmt = select(func.count()).select_from(User).where(User.email == email)
        result = await self.db.execute(stmt)
        return (result.scalar() or 0) > 0

    async def exists_by_username(self, username: str) -> bool:
        """Return whether at least one user has the given username."""
        stmt = select(func.count()).select_from(User).where(User.username == username)
        result = await self.db.execute(stmt)
        return (result.scalar() or 0) > 0

    async def count(self, is_active: bool | None = None) -> int:
        """Return the number of users, optionally filtered on ``User.is_active``."""
        stmt = select(func.count()).select_from(User)
        if is_active is not None:
            stmt = stmt.where(User.is_active == is_active)
        result = await self.db.execute(stmt)
        return result.scalar() or 0

🏗️ Service 層

User Service

from dataclasses import dataclass, field
from typing import Sequence

@dataclass
class CreateUserDTO:
    """Data transfer object carrying the fields needed to create a user."""
    username: str
    email: str
    # Plaintext password. Excluded from the generated __repr__ so that
    # printing or logging the DTO does not expose it.
    password: str = field(repr=False)

@dataclass
class UpdateUserDTO:
    """Data transfer object describing a partial update of a user.

    Every field is optional and defaults to ``None``.
    """
    # New username, or None (the default).
    username: str | None = None
    # New email address, or None (the default).
    email: str | None = None
    # New active flag, or None (the default).
    is_active: bool | None = None

class UserService:
    """Business-logic layer for users.

    Talks to persistence only through the injected ``IUserRepository`` and
    signals rule violations (duplicate email / username) with ``ValueError``.
    """

    def __init__(self, user_repo: IUserRepository):
        """Keep a reference to the repository used for all data access."""
        self.user_repo = user_repo

    async def get_user(self, user_id: int) -> User | None:
        """Look up a single user by ID; ``None`` when the repository finds none."""
        found = await self.user_repo.get_by_id(user_id)
        return found

    async def get_users(
        self,
        skip: int = 0,
        limit: int = 100,
        is_active: bool | None = None
    ) -> Sequence[User]:
        """Return a page of users, forwarding all arguments to the repository."""
        page = await self.user_repo.get_all(
            skip=skip,
            limit=limit,
            is_active=is_active,
        )
        return page

    async def create_user(self, data: CreateUserDTO) -> User:
        """Create a user after enforcing the uniqueness rules.

        Raises:
            ValueError: If the email or the username is already in use.
        """
        # Business rule: email and username must both be unused.
        email_taken = await self.user_repo.exists_by_email(data.email)
        if email_taken:
            raise ValueError("Email already registered")

        username_taken = await self.user_repo.exists_by_username(data.username)
        if username_taken:
            raise ValueError("Username already taken")

        # Business rule: only the hash of the password is stored.
        new_user = User(
            username=data.username,
            email=data.email,
            hashed_password=self._hash_password(data.password),
        )
        return await self.user_repo.create(new_user)

    async def update_user(self, user_id: int, data: UpdateUserDTO) -> User | None:
        """Apply the non-empty fields of ``data`` to an existing user.

        Returns:
            The updated user, or ``None`` when ``user_id`` does not exist.

        Raises:
            ValueError: If the new email or username is already in use.
        """
        user = await self.user_repo.get_by_id(user_id)
        if not user:
            return None

        # Business rule: a changed email must not collide with another user.
        new_email = data.email
        if new_email and new_email != user.email:
            if await self.user_repo.exists_by_email(new_email):
                raise ValueError("Email already registered")
            user.email = new_email

        # Business rule: a changed username must not collide either.
        new_username = data.username
        if new_username and new_username != user.username:
            if await self.user_repo.exists_by_username(new_username):
                raise ValueError("Username already taken")
            user.username = new_username

        # False is a meaningful value here, hence the explicit None check.
        if data.is_active is not None:
            user.is_active = data.is_active

        return await self.user_repo.update(user)

    async def delete_user(self, user_id: int) -> bool:
        """Delete a user; returns the repository's success flag."""
        removed = await self.user_repo.delete(user_id)
        return removed

    async def deactivate_user(self, user_id: int) -> User | None:
        """Set ``is_active`` to ``False``; ``None`` when the user does not exist."""
        user = await self.user_repo.get_by_id(user_id)
        if not user:
            return None

        user.is_active = False
        return await self.user_repo.update(user)

    def _hash_password(self, password: str) -> str:
        """Return the hex SHA-256 digest of ``password`` (simplified on purpose)."""
        # NOTE(review): an unsalted, fast hash is not suitable for storing
        # real passwords; production code should use bcrypt or argon2.
        import hashlib
        digest = hashlib.sha256(password.encode())
        return digest.hexdigest()

🚀 FastAPI 整合

依賴注入

from fastapi import FastAPI, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession

# Application instance the route decorators below attach to.
app = FastAPI()

# Dependency that provides a database session
async def get_db() -> AsyncSession:
    """Yield a session obtained from ``AsyncSessionLocal``.

    The session is yielded from inside an ``async with`` block, so the
    block's exit handling runs when the generator is resumed or closed
    (presumably closing the session; confirm against ``AsyncSessionLocal``).

    NOTE(review): because of the ``yield`` this is an async generator; a
    precise return annotation would be ``AsyncIterator[AsyncSession]``.
    """
    async with AsyncSessionLocal() as session:
        yield session

# Dependency that provides a UserRepository
async def get_user_repo(db: AsyncSession = Depends(get_db)) -> UserRepository:
    """Build a ``UserRepository`` around the session supplied by ``get_db``."""
    return UserRepository(db)

# Dependency that provides a UserService
async def get_user_service(
    user_repo: UserRepository = Depends(get_user_repo)
) -> UserService:
    """Build a ``UserService`` around the repository supplied by ``get_user_repo``."""
    return UserService(user_repo)

API 端點

from pydantic import BaseModel, EmailStr

# Pydantic Schemas
# Request body for POST /users. All three fields are required.
class UserCreateRequest(BaseModel):
    username: str
    email: EmailStr  # typed as EmailStr rather than plain str
    password: str  # plaintext as submitted by the client

# Request body for PATCH /users/{user_id}. Every field is optional and
# defaults to None.
class UserUpdateRequest(BaseModel):
    username: str | None = None
    email: EmailStr | None = None
    is_active: bool | None = None

# Response schema: the user fields exposed to API clients.
class UserResponse(BaseModel):
    # Pydantic v2 configuration. The nested `class Config` form is deprecated
    # in v2; `model_config` is its replacement. from_attributes lets the model
    # be populated from object attributes instead of a dict.
    model_config = {"from_attributes": True}

    id: int
    username: str
    email: str
    is_active: bool

# Envelope for list responses: one page of users plus paging metadata.
class PaginatedResponse(BaseModel):
    items: list[UserResponse]  # the users on this page
    total: int  # count reported alongside the page
    skip: int  # offset used for this page
    limit: int  # page size used for this page

# API 端點
@app.post("/users", response_model=UserResponse)
async def create_user(
    request: UserCreateRequest,
    service: UserService = Depends(get_user_service)
):
    # Translate the HTTP payload into the service-layer DTO and delegate.
    # A ValueError from the service is reported to the client as HTTP 400.
    try:
        payload = CreateUserDTO(
            username=request.username,
            email=request.email,
            password=request.password,
        )
        return await service.create_user(payload)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))

@app.get("/users", response_model=PaginatedResponse)
async def get_users(
    skip: int = 0,
    limit: int = 100,
    is_active: bool | None = None,
    service: UserService = Depends(get_user_service),
    user_repo: UserRepository = Depends(get_user_repo)
):
    # The page comes from the service; the total comes straight from the
    # repository, using the same is_active filter for both.
    page = await service.get_users(skip=skip, limit=limit, is_active=is_active)
    matching = await user_repo.count(is_active=is_active)
    return PaginatedResponse(items=page, total=matching, skip=skip, limit=limit)

@app.get("/users/{user_id}", response_model=UserResponse)
async def get_user(
    user_id: int,
    service: UserService = Depends(get_user_service)
):
    # Return the user when the service finds one, otherwise answer 404.
    found = await service.get_user(user_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="User not found")

@app.patch("/users/{user_id}", response_model=UserResponse)
async def update_user(
    user_id: int,
    request: UserUpdateRequest,
    service: UserService = Depends(get_user_service)
):
    # A ValueError from the service becomes HTTP 400; a missing user
    # (service returns a falsy result) becomes HTTP 404.
    try:
        changes = UpdateUserDTO(
            username=request.username,
            email=request.email,
            is_active=request.is_active,
        )
        updated = await service.update_user(user_id, changes)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))

    if not updated:
        raise HTTPException(status_code=404, detail="User not found")
    return updated

@app.delete("/users/{user_id}")
async def delete_user(
    user_id: int,
    service: UserService = Depends(get_user_service)
):
    # The service reports whether a user was actually deleted.
    deleted = await service.delete_user(user_id)
    if deleted:
        return {"message": "User deleted"}
    raise HTTPException(status_code=404, detail="User not found")

🧪 測試

Mock Repository

import pytest
from unittest.mock import AsyncMock, MagicMock

class MockUserRepository(IUserRepository):
    """In-memory ``IUserRepository`` used to test the service layer.

    Users are kept in a dict keyed by ID; IDs are handed out sequentially
    starting at 1.
    """

    def __init__(self):
        """Start with no users and the ID counter at 1."""
        self.users: dict[int, User] = {}
        self._id_counter = 1

    async def get_by_id(self, user_id: int) -> User | None:
        """Return the stored user with ``user_id``, or ``None``."""
        return self.users.get(user_id)

    async def get_by_email(self, email: str) -> User | None:
        """Return the first stored user whose email equals ``email``, or ``None``."""
        return next(
            (user for user in self.users.values() if user.email == email),
            None,
        )

    async def get_by_username(self, username: str) -> User | None:
        """Return the first stored user whose username equals ``username``, or ``None``."""
        return next(
            (user for user in self.users.values() if user.username == username),
            None,
        )

    async def get_all(
        self,
        skip: int = 0,
        limit: int = 100,
        is_active: bool | None = None
    ) -> list[User]:
        """Return a slice of the stored users, in insertion order.

        Args:
            skip: Number of users to skip.
            limit: Maximum number of users to return.
            is_active: When not ``None``, keep only users whose ``is_active``
                equals this value. ``UserService.get_users`` always passes
                this keyword, so the mock has to accept it.
        """
        users = [
            user
            for user in self.users.values()
            if is_active is None or user.is_active == is_active
        ]
        return users[skip:skip + limit]

    async def create(self, user: User) -> User:
        """Assign the next ID to ``user``, store it and return it."""
        user.id = self._id_counter
        self._id_counter += 1
        self.users[user.id] = user
        return user

    async def update(self, user: User) -> User:
        """Store ``user`` under its current ID and return it."""
        self.users[user.id] = user
        return user

    async def delete(self, user_id: int) -> bool:
        """Remove the user with ``user_id``; return whether it was present."""
        if user_id in self.users:
            del self.users[user_id]
            return True
        return False

    async def exists_by_email(self, email: str) -> bool:
        """Return whether any stored user has the given email."""
        return await self.get_by_email(email) is not None

    async def exists_by_username(self, username: str) -> bool:
        """Return whether any stored user has the given username."""
        return await self.get_by_username(username) is not None


# 測試 UserService
@pytest.mark.asyncio
async def test_create_user():
    """Creating a user through the service returns it with ID 1 and the given fields."""
    service = UserService(MockUserRepository())

    created = await service.create_user(
        CreateUserDTO(
            username="john",
            email="john@example.com",
            password="password123",
        )
    )

    assert created.id == 1
    assert created.username == "john"
    assert created.email == "john@example.com"

@pytest.mark.asyncio
async def test_create_user_duplicate_email():
    """A second user with an already-registered email is rejected with ValueError."""
    service = UserService(MockUserRepository())

    # Seed the repository with one user.
    first = CreateUserDTO(username="john", email="john@example.com", password="pass")
    await service.create_user(first)

    # Same email, different username: the email check must fire.
    duplicate = CreateUserDTO(username="jane", email="john@example.com", password="pass")
    with pytest.raises(ValueError, match="Email already registered"):
        await service.create_user(duplicate)

✅ 重點總結

Repository 模式結構

API Layer
    ↓
Service Layer (業務邏輯)
    ↓
Repository Layer (資料存取)
    ↓
Database

職責分離

| 層級 | 職責 |
| --- | --- |
| API | 處理請求、驗證輸入、回傳響應 |
| Service | 業務邏輯、跨 Repository 操作 |
| Repository | 資料存取、SQL 查詢 |

優點

  1. 可測試性:可以 Mock Repository 測試 Service
  2. 可維護性:修改資料庫不影響業務邏輯
  3. 可重用性:Repository 方法可以被多處使用
  4. 關注點分離:每層專注自己的工作

🎤 面試這樣答

Q: 什麼是 Repository 模式?為什麼要用?

答案:

Repository 模式是將資料存取邏輯封裝成獨立的層,讓業務邏輯不直接依賴資料庫實作。

好處:

  1. 可測試性:可以用 Mock Repository 測試業務邏輯
  2. 關注點分離:Service 專注業務邏輯,Repository 專注資料存取
  3. 可替換性:可以輕鬆更換資料庫實作

程式碼範例:

class UserRepository:
    """Sketch of the data-access layer: only persistence operations live here."""

    async def get_by_id(self, id: int) -> User | None: ...
    # Declared because the service in this example calls it.
    async def exists_by_email(self, email: str) -> bool: ...
    async def create(self, user: User) -> User: ...

class UserService:
    """Sketch of the service layer: business rules on top of a repository."""

    def __init__(self, repo: UserRepository):
        self.repo = repo

    async def create_user(self, data) -> User:
        """Create a user unless the email is already taken.

        Raises:
            ValueError: If a user with ``data.email`` already exists.
        """
        # Business rules live here
        if await self.repo.exists_by_email(data.email):
            raise ValueError("Email exists")
        # `data` is read through attributes above, so it is not a mapping and
        # cannot be **-unpacked; pass the fields explicitly instead.
        # NOTE(review): assumes `data` also carries `username`, as the
        # CreateUserDTO in the full example does.
        user = User(username=data.username, email=data.email)
        return await self.repo.create(user)

上一篇: 03-4. 資料庫遷移 Alembic 下一篇: 03-6. 交易管理


最後更新:2025-12-17

0%