
# 03-5. The Repository Pattern

> ⏱️ **Reading time:** 18 minutes
> 🎯 **Difficulty:** ⭐⭐⭐ (Advanced)

---

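## 🤔 In One Sentence

**The Repository pattern encapsulates data-access logic in its own layer, so business logic does not depend directly on the database implementation.**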

---

## 🎯 Why Use the Repository Pattern?

### The Problem Without a Repository

```python
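# ❌ Business logic talks to the database directly
# (app, get_db, User and UserCreate are defined in earlier chapters)
from fastapi import Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

@app.post("/users")
async def create_user(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
    # Queries like this end up scattered across every endpoint
    stmt = select(User).where(User.email == user_data.email)
    existing = await db.execute(stmt)
    if existing.scalar():
        raise HTTPException(status_code=400, detail="Email exists")

    user = User(**user_data.model_dump())
    db.add(user)
    await db.commit()
    await db.refresh(user)
    return user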
```

### Benefits of Using a Repository

```
┌─────────────────────────────────────────────────────────┐
│                  Layered Architecture                   │
├─────────────────────────────────────────────────────────┤
│  API Layer          →  Endpoints, input validation      │
│        ↓                                                │
│  Service Layer      →  Business logic                   │
│        ↓                                                │
│  Repository Layer   →  Data access (wraps SQL)          │
│        ↓                                                │
│  Database           →  The actual database              │
└─────────────────────────────────────────────────────────┘
```

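**Advantages:**

| Advantage | Description |
|------|------|
| **Single responsibility** | Each layer focuses on its own job |
| **Easier testing** | The Repository can be replaced with a mock |
| **Maintainability** | Changes to data access don't ripple into business logic |
| **Reusability** | Repository methods can be called from many places |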
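To make the layering concrete without a database, here is a minimal stdlib-only sketch, assuming an in-memory dict in place of SQLAlchemy. The names `InMemoryUserRepository`, `register`, and `create_user_endpoint` are hypothetical, and the "API layer" is a plain function returning `(status, body)` rather than a real FastAPI route.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class User:
    id: int
    email: str


class InMemoryUserRepository:
    """Repository layer: the only code that knows how users are stored."""

    def __init__(self) -> None:
        self._users: dict[int, User] = {}
        self._next_id = 1

    async def exists_by_email(self, email: str) -> bool:
        return any(u.email == email for u in self._users.values())

    async def create(self, email: str) -> User:
        user = User(id=self._next_id, email=email)
        self._users[user.id] = user
        self._next_id += 1
        return user


class UserService:
    """Service layer: business rules only, no storage details."""

    def __init__(self, repo: InMemoryUserRepository) -> None:
        self.repo = repo

    async def register(self, email: str) -> User:
        if await self.repo.exists_by_email(email):
            raise ValueError("Email exists")
        return await self.repo.create(email)


async def create_user_endpoint(service: UserService, email: str) -> tuple[int, dict]:
    """API layer: turn service results and errors into (status, body)."""
    try:
        user = await service.register(email)
    except ValueError as e:
        return 400, {"detail": str(e)}
    return 200, {"id": user.id, "email": user.email}


async def main() -> None:
    service = UserService(InMemoryUserRepository())
    print(await create_user_endpoint(service, "a@example.com"))  # (200, {'id': 1, 'email': 'a@example.com'})
    print(await create_user_endpoint(service, "a@example.com"))  # (400, {'detail': 'Email exists'})


if __name__ == "__main__":
    asyncio.run(main())
```

Because `UserService` only calls two repository methods, replacing the dict with a real database would leave the service and the endpoint untouched.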

---

## 🔧 A Basic Repository

### Abstract Base Class

```python
from abc import ABC, abstractmethod
from typing import TypeVar, Generic, Type, Sequence
from sqlalchemy import select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession

T = TypeVar("T")

class AbstractRepository(ABC, Generic[T]):
    """Abstract repository interface"""

    @abstractmethod
    async def get(self, id: int) -> T | None:
        """Fetch a single record by ID"""
        raise NotImplementedError

    @abstractmethod
    async def get_all(self, skip: int = 0, limit: int = 100) -> Sequence[T]:
        """Fetch a page of records"""
        raise NotImplementedError

    @abstractmethod
    async def create(self, obj: T) -> T:
        """Create a record"""
        raise NotImplementedError

    @abstractmethod
    async def update(self, obj: T) -> T:
        """Update a record"""
        raise NotImplementedError

    @abstractmethod
    async def delete(self, id: int) -> bool:
        """Delete a record; return False if it does not exist"""
        raise NotImplementedError
```

### SQLAlchemy Repository Implementation

```python
from typing import Sequence, Type

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

# AbstractRepository and T come from the previous snippet

class SQLAlchemyRepository(AbstractRepository[T]):
    """Base repository implementation backed by SQLAlchemy"""

    def __init__(self, db: AsyncSession, model: Type[T]):
        self.db = db
        self.model = model

    async def get(self, id: int) -> T | None:
        """Fetch a single record by primary key"""
        return await self.db.get(self.model, id)

    async def get_all(self, skip: int = 0, limit: int = 100) -> Sequence[T]:
        """Fetch a page of records (OFFSET skip LIMIT limit)"""
        stmt = select(self.model).offset(skip).limit(limit)
        result = await self.db.execute(stmt)
        return result.scalars().all()

    async def create(self, obj: T) -> T:
        """Insert a record, then reload it to pick up DB-generated values such as the id"""
        self.db.add(obj)
        await self.db.commit()
        await self.db.refresh(obj)
        return obj

    async def update(self, obj: T) -> T:
        """Commit changes already made to a session-tracked object"""
        await self.db.commit()
        await self.db.refresh(obj)
        return obj

    async def delete(self, id: int) -> bool:
        """Delete a record; return False if it does not exist"""
        obj = await self.get(id)
        if not obj:
            return False
        await self.db.delete(obj)
        await self.db.commit()
        return True

    async def count(self) -> int:
        """Count all records"""
        stmt = select(func.count()).select_from(self.model)
        result = await self.db.execute(stmt)
        return result.scalar() or 0
```

---

## 📝 Implementing a User Repository

### Defining the Interface

```python
from abc import ABC, abstractmethod
from typing import Sequence

# User is the SQLAlchemy model defined in earlier chapters

class IUserRepository(ABC):
    """User repository interface"""

    @abstractmethod
    async def get_by_id(self, user_id: int) -> User | None:
        pass

    @abstractmethod
    async def get_by_email(self, email: str) -> User | None:
        pass

    @abstractmethod
    async def get_by_username(self, username: str) -> User | None:
        pass

    # UserService passes is_active, so the interface must accept it
    @abstractmethod
    async def get_all(
        self, skip: int = 0, limit: int = 100, is_active: bool | None = None
    ) -> Sequence[User]:
        pass

    @abstractmethod
    async def create(self, user: User) -> User:
        pass

    @abstractmethod
    async def update(self, user: User) -> User:
        pass

    @abstractmethod
    async def delete(self, user_id: int) -> bool:
        pass

    @abstractmethod
    async def exists_by_email(self, email: str) -> bool:
        pass

    # UserService also relies on this check, so it belongs in the interface
    @abstractmethod
    async def exists_by_username(self, username: str) -> bool:
        pass
```

### Concrete Implementation

```python
from typing import Sequence

from sqlalchemy import select, func, and_, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

class UserRepository(IUserRepository):
    """SQLAlchemy implementation of the User repository"""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def get_by_id(self, user_id: int) -> User | None:
        """Look up a user by ID"""
        return await self.db.get(User, user_id)

    async def get_by_id_with_posts(self, user_id: int) -> User | None:
        """Look up a user by ID, including their posts"""
        # Implicit lazy loading does not work with AsyncSession,
        # so the posts are loaded eagerly here
        stmt = (
            select(User)
            .where(User.id == user_id)
            .options(selectinload(User.posts))
        )
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_email(self, email: str) -> User | None:
        """Look up a user by email"""
        stmt = select(User).where(User.email == email)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_username(self, username: str) -> User | None:
        """Look up a user by username"""
        stmt = select(User).where(User.username == username)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_all(
        self,
        skip: int = 0,
        limit: int = 100,
        is_active: bool | None = None
    ) -> Sequence[User]:
        """List users, optionally filtered by active status"""
        stmt = select(User)

        if is_active is not None:
            stmt = stmt.where(User.is_active == is_active)

        stmt = stmt.offset(skip).limit(limit)
        result = await self.db.execute(stmt)
        return result.scalars().all()

    async def search(
        self,
        *,
        keyword: str | None = None,
        is_active: bool | None = None,
        skip: int = 0,
        limit: int = 100
    ) -> Sequence[User]:
        """Search users by keyword (case-insensitive match on username or email)"""
        stmt = select(User)

        conditions = []
        if keyword:
            conditions.append(
                or_(
                    User.username.ilike(f"%{keyword}%"),
                    User.email.ilike(f"%{keyword}%")
                )
            )
        if is_active is not None:
            conditions.append(User.is_active == is_active)

        if conditions:
            stmt = stmt.where(and_(*conditions))

        stmt = stmt.offset(skip).limit(limit)
        result = await self.db.execute(stmt)
        return result.scalars().all()

    async def create(self, user: User) -> User:
        """Create a user"""
        self.db.add(user)
        await self.db.commit()
        await self.db.refresh(user)
        return user

    async def update(self, user: User) -> User:
        """Persist changes made to a user"""
        await self.db.commit()
        await self.db.refresh(user)
        return user

    async def delete(self, user_id: int) -> bool:
        """Delete a user; return False if not found"""
        user = await self.get_by_id(user_id)
        if not user:
            return False
        await self.db.delete(user)
        await self.db.commit()
        return True

    async def exists_by_email(self, email: str) -> bool:
        """Check whether an email is already registered"""
        stmt = select(func.count()).select_from(User).where(User.email == email)
        result = await self.db.execute(stmt)
        return (result.scalar() or 0) > 0

    async def exists_by_username(self, username: str) -> bool:
        """Check whether a username is already taken"""
        stmt = select(func.count()).select_from(User).where(User.username == username)
        result = await self.db.execute(stmt)
        return (result.scalar() or 0) > 0

    async def count(self, is_active: bool | None = None) -> int:
        """Count users, optionally filtered by active status"""
        stmt = select(func.count()).select_from(User)
        if is_active is not None:
            stmt = stmt.where(User.is_active == is_active)
        result = await self.db.execute(stmt)
        return result.scalar() or 0
```
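One subtlety in `search`: the keyword is still sent as a bound parameter, so this is not a SQL-injection issue, but any `%` or `_` the user types will act as a LIKE wildcard. A minimal sketch of escaping them, assuming backslash as the escape character; `escape_like` is a hypothetical helper, and the commented lines show how it would plug into SQLAlchemy's `ilike`, which accepts an `escape` argument.

```python
def escape_like(keyword: str, escape_char: str = "\\") -> str:
    """Escape LIKE/ILIKE wildcards so user input is matched literally."""
    return (
        keyword.replace(escape_char, escape_char * 2)  # escape the escape char first
        .replace("%", escape_char + "%")
        .replace("_", escape_char + "_")
    )


# Inside UserRepository.search this would be used roughly as:
#     pattern = f"%{escape_like(keyword)}%"
#     User.username.ilike(pattern, escape="\\")
print(escape_like("50%_off"))  # 50\%\_off
```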

---

## 🏗️ The Service Layer

### User Service

```python
import hashlib
from dataclasses import dataclass
from typing import Sequence

# User and IUserRepository come from the earlier snippets

@dataclass
class CreateUserDTO:
    """Data transfer object for creating a user"""
    username: str
    email: str
    password: str

@dataclass
class UpdateUserDTO:
    """Data transfer object for updating a user; None fields are left unchanged"""
    username: str | None = None
    email: str | None = None
    is_active: bool | None = None

class UserService:
    """Business-logic layer for users"""

    def __init__(self, user_repo: IUserRepository):
        self.user_repo = user_repo

    async def get_user(self, user_id: int) -> User | None:
        """Get a single user"""
        return await self.user_repo.get_by_id(user_id)

    async def get_users(
        self,
        skip: int = 0,
        limit: int = 100,
        is_active: bool | None = None
    ) -> Sequence[User]:
        """List users"""
        return await self.user_repo.get_all(skip=skip, limit=limit, is_active=is_active)

    async def create_user(self, data: CreateUserDTO) -> User:
        """Create a user"""
        # Business rule: email and username must be unique
        # (a UNIQUE constraint is still needed to catch concurrent inserts)
        if await self.user_repo.exists_by_email(data.email):
            raise ValueError("Email already registered")

        if await self.user_repo.exists_by_username(data.username):
            raise ValueError("Username already taken")

        # Business rule: never store the plain-text password
        hashed_password = self._hash_password(data.password)

        user = User(
            username=data.username,
            email=data.email,
            hashed_password=hashed_password
        )

        return await self.user_repo.create(user)

    async def update_user(self, user_id: int, data: UpdateUserDTO) -> User | None:
        """Update a user; return None if the user does not exist"""
        user = await self.user_repo.get_by_id(user_id)
        if not user:
            return None

        # Business rule: a new email must not already be registered
        if data.email and data.email != user.email:
            if await self.user_repo.exists_by_email(data.email):
                raise ValueError("Email already registered")
            user.email = data.email

        # Business rule: a new username must not already be taken
        if data.username and data.username != user.username:
            if await self.user_repo.exists_by_username(data.username):
                raise ValueError("Username already taken")
            user.username = data.username

        if data.is_active is not None:
            user.is_active = data.is_active

        return await self.user_repo.update(user)

    async def delete_user(self, user_id: int) -> bool:
        """Delete a user"""
        return await self.user_repo.delete(user_id)

    async def deactivate_user(self, user_id: int) -> User | None:
        """Deactivate a user (keeps the row, sets is_active to False)"""
        user = await self.user_repo.get_by_id(user_id)
        if not user:
            return None

        user.is_active = False
        return await self.user_repo.update(user)

    def _hash_password(self, password: str) -> str:
        """Hash a password (simplified)"""
        # Illustration only: real code should use bcrypt or argon2,
        # which are salted and deliberately slow
        return hashlib.sha256(password.encode()).hexdigest()
```
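`_hash_password` above uses unsalted SHA-256 only to keep the example short. If adding a bcrypt or argon2 dependency is not an option, the standard library's `hashlib.pbkdf2_hmac` provides a salted, deliberately slow hash. This is a sketch, not a drop-in replacement: the `salt$hash` storage format, the function names, and the iteration count are assumptions to adjust for your own setup.

```python
import hashlib
import hmac
import secrets

ITERATIONS = 600_000  # illustrative work factor; tune for your hardware


def hash_password(password: str) -> str:
    """Return 'salt$hash' using PBKDF2-HMAC-SHA256 with a random 16-byte salt."""
    salt = secrets.token_bytes(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, ITERATIONS)
    return f"{salt.hex()}${digest.hex()}"


def verify_password(password: str, stored: str) -> bool:
    """Recompute the hash with the stored salt and compare in constant time."""
    salt_hex, digest_hex = stored.split("$")
    digest = hashlib.pbkdf2_hmac(
        "sha256", password.encode(), bytes.fromhex(salt_hex), ITERATIONS
    )
    return hmac.compare_digest(digest.hex(), digest_hex)
```

Because each call draws a fresh salt, hashing the same password twice yields different strings, which is why verification must recompute with the stored salt instead of comparing two hashes directly.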

---

## 🚀 FastAPI Integration

### Dependency Injection

```python
from collections.abc import AsyncGenerator

from fastapi import FastAPI, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession

app = FastAPI()

# Provide one database session per request
# (AsyncSessionLocal is the session factory from earlier chapters)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    async with AsyncSessionLocal() as session:
        yield session

# Provide a UserRepository bound to that session
async def get_user_repo(db: AsyncSession = Depends(get_db)) -> UserRepository:
    return UserRepository(db)

# Provide a UserService built on that repository
async def get_user_service(
    user_repo: UserRepository = Depends(get_user_repo)
) -> UserService:
    return UserService(user_repo)
```
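The three providers form a chain: FastAPI resolves `get_db` first, hands the session to `get_user_repo`, then hands the repository to `get_user_service`. For a `yield` dependency, the code after `yield` runs as cleanup once the endpoint has finished. The stdlib-only simulation below imitates that lifecycle with `contextlib.asynccontextmanager` and a hypothetical `FakeSession`; it is an approximation for intuition, not FastAPI's actual machinery.

```python
import asyncio
from contextlib import asynccontextmanager

events: list[str] = []


class FakeSession:
    """Hypothetical stand-in for AsyncSession."""

    async def close(self) -> None:
        events.append("session closed")


class UserRepository:
    def __init__(self, db: FakeSession) -> None:
        self.db = db


class UserService:
    def __init__(self, user_repo: UserRepository) -> None:
        self.user_repo = user_repo


async def get_db():
    session = FakeSession()
    events.append("session opened")
    try:
        yield session  # the value injected into downstream dependencies
    finally:
        await session.close()  # cleanup after the endpoint has finished


async def handle_request() -> None:
    # Resolve get_db -> repository -> service, then run the "endpoint"
    async with asynccontextmanager(get_db)() as db:
        service = UserService(UserRepository(db))
        events.append(f"endpoint ran with {type(service.user_repo.db).__name__}")


asyncio.run(handle_request())
print(events)  # ['session opened', 'endpoint ran with FakeSession', 'session closed']
```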

### API Endpoints

```python
from pydantic import BaseModel, ConfigDict, EmailStr

# Pydantic schemas (EmailStr needs the email-validator package installed)
class UserCreateRequest(BaseModel):
    username: str
    email: EmailStr
    password: str

class UserUpdateRequest(BaseModel):
    username: str | None = None
    email: EmailStr | None = None
    is_active: bool | None = None

class UserResponse(BaseModel):
    # Pydantic v2: allow building the model from ORM object attributes
    model_config = ConfigDict(from_attributes=True)

    id: int
    username: str
    email: str
    is_active: bool

class PaginatedResponse(BaseModel):
    items: list[UserResponse]
    total: int
    skip: int
    limit: int

# API endpoints
@app.post("/users", response_model=UserResponse)
async def create_user(
    request: UserCreateRequest,
    service: UserService = Depends(get_user_service)
):
    try:
        data = CreateUserDTO(
            username=request.username,
            email=request.email,
            password=request.password
        )
        user = await service.create_user(data)
        return user
    except ValueError as e:
        # Map business-rule violations to 400 Bad Request
        raise HTTPException(status_code=400, detail=str(e)) from e

@app.get("/users", response_model=PaginatedResponse)
async def get_users(
    skip: int = 0,
    limit: int = 100,
    is_active: bool | None = None,
    service: UserService = Depends(get_user_service),
    user_repo: UserRepository = Depends(get_user_repo)
):
    users = await service.get_users(skip=skip, limit=limit, is_active=is_active)
    # total is read from the repository directly; it could also be exposed via UserService
    total = await user_repo.count(is_active=is_active)
    return PaginatedResponse(
        items=users,
        total=total,
        skip=skip,
        limit=limit
    )

@app.get("/users/{user_id}", response_model=UserResponse)
async def get_user(
    user_id: int,
    service: UserService = Depends(get_user_service)
):
    user = await service.get_user(user_id)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    return user

@app.patch("/users/{user_id}", response_model=UserResponse)
async def update_user(
    user_id: int,
    request: UserUpdateRequest,
    service: UserService = Depends(get_user_service)
):
    try:
        data = UpdateUserDTO(
            username=request.username,
            email=request.email,
            is_active=request.is_active
        )
        user = await service.update_user(user_id, data)
        if not user:
            raise HTTPException(status_code=404, detail="User not found")
        return user
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

@app.delete("/users/{user_id}")
async def delete_user(
    user_id: int,
    service: UserService = Depends(get_user_service)
):
    success = await service.delete_user(user_id)
    if not success:
        raise HTTPException(status_code=404, detail="User not found")
    return {"message": "User deleted"}
```
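`GET /users` returns `total`, `skip`, and `limit` rather than page numbers. A client could derive page metadata from those three fields; the sketch below does that arithmetic, assuming `limit > 0` and 1-based pages. `PageInfo` and `page_to_skip` are hypothetical names, not part of the API above.

```python
import math
from dataclasses import dataclass


def page_to_skip(page: int, size: int) -> int:
    """Convert a 1-based page number into the skip (OFFSET) value for GET /users."""
    return (page - 1) * size


@dataclass
class PageInfo:
    """Page metadata derived from a PaginatedResponse's skip/limit/total."""
    skip: int
    limit: int
    total: int

    @property
    def page(self) -> int:
        return self.skip // self.limit + 1

    @property
    def total_pages(self) -> int:
        return math.ceil(self.total / self.limit)

    @property
    def has_more(self) -> bool:
        return self.skip + self.limit < self.total


info = PageInfo(skip=page_to_skip(3, 100), limit=100, total=250)
print(info.page, info.total_pages, info.has_more)  # 3 3 False
```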

---

## 🧪 Testing

### Mock Repository

```python
import pytest

# User, IUserRepository, UserService and CreateUserDTO come from the earlier snippets.
# @pytest.mark.asyncio requires the pytest-asyncio plugin.

class MockUserRepository(IUserRepository):
    """In-memory User repository used in tests"""

    def __init__(self):
        self.users: dict[int, User] = {}
        self._id_counter = 1

    async def get_by_id(self, user_id: int) -> User | None:
        return self.users.get(user_id)

    async def get_by_email(self, email: str) -> User | None:
        for user in self.users.values():
            if user.email == email:
                return user
        return None

    async def get_by_username(self, username: str) -> User | None:
        for user in self.users.values():
            if user.username == username:
                return user
        return None

    async def get_all(
        self, skip: int = 0, limit: int = 100, is_active: bool | None = None
    ) -> list[User]:
        users = [
            u for u in self.users.values()
            if is_active is None or u.is_active == is_active
        ]
        return users[skip:skip + limit]

    async def create(self, user: User) -> User:
        user.id = self._id_counter  # simulate an auto-increment primary key
        self._id_counter += 1
        self.users[user.id] = user
        return user

    async def update(self, user: User) -> User:
        self.users[user.id] = user
        return user

    async def delete(self, user_id: int) -> bool:
        if user_id in self.users:
            del self.users[user_id]
            return True
        return False

    async def exists_by_email(self, email: str) -> bool:
        return await self.get_by_email(email) is not None

    async def exists_by_username(self, username: str) -> bool:
        return await self.get_by_username(username) is not None


# Tests for UserService
@pytest.mark.asyncio
async def test_create_user():
    repo = MockUserRepository()
    service = UserService(repo)

    data = CreateUserDTO(
        username="john",
        email="john@example.com",
        password="password123"
    )

    user = await service.create_user(data)

    assert user.id == 1
    assert user.username == "john"
    assert user.email == "john@example.com"

@pytest.mark.asyncio
async def test_create_user_duplicate_email():
    repo = MockUserRepository()
    service = UserService(repo)

    # Create a first user
    data1 = CreateUserDTO(username="john", email="john@example.com", password="pass")
    await service.create_user(data1)

    # Try to create another user with the same email
    data2 = CreateUserDTO(username="jane", email="john@example.com", password="pass")

    with pytest.raises(ValueError, match="Email already registered"):
        await service.create_user(data2)
```
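A hand-written fake is not the only option. `unittest.mock.AsyncMock` can stand in for the repository and also records how each method was awaited, which lets a test assert that `create` is never reached when the email is taken. To stay self-contained, this sketch uses a dataclass `User` and a trimmed `UserService` that copies only the duplicate-email rule; both are stand-ins for the real classes.

```python
import asyncio
from dataclasses import dataclass
from unittest.mock import AsyncMock


@dataclass
class User:
    """Stand-in for the ORM model."""
    username: str
    email: str


class UserService:
    """Trimmed stand-in that keeps only create_user's duplicate-email rule."""

    def __init__(self, user_repo) -> None:
        self.user_repo = user_repo

    async def create_user(self, username: str, email: str) -> User:
        if await self.user_repo.exists_by_email(email):
            raise ValueError("Email already registered")
        return await self.user_repo.create(User(username=username, email=email))


async def test_duplicate_email_never_calls_create() -> None:
    repo = AsyncMock()  # child attributes are AsyncMocks too, so they can be awaited
    repo.exists_by_email.return_value = True  # pretend the email is taken

    service = UserService(repo)
    try:
        await service.create_user("jane", "john@example.com")
    except ValueError as e:
        assert str(e) == "Email already registered"
    else:
        raise AssertionError("expected ValueError")

    repo.exists_by_email.assert_awaited_once_with("john@example.com")
    repo.create.assert_not_awaited()  # the service stopped before writing


if __name__ == "__main__":
    asyncio.run(test_duplicate_email_never_calls_create())
```

Under pytest, the same test function would carry `@pytest.mark.asyncio`, like the tests above.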

---

## ✅ Key Takeaways

### Structure of the Repository Pattern

```
API Layer
    ↓
Service Layer (business logic)
    ↓
Repository Layer (data access)
    ↓
Database
```

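### Separation of Responsibilities

| Layer | Responsibility |
|------|------|
| API | Handle requests, validate input, return responses |
| Service | Business logic, operations spanning multiple repositories |
| Repository | Data access, SQL queries |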

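### Advantages

1. **Testability**: the Service can be tested against a mock Repository
2. **Maintainability**: changes to the database layer don't affect business logic
3. **Reusability**: Repository methods can be called from many places
4. **Separation of concerns**: each layer focuses on its own job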

---

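## 🎤 How to Answer in an Interview

### Q: What is the Repository pattern, and why use it?

**Answer:**

> The Repository pattern encapsulates data-access logic in its own layer, so business logic does not depend directly on the database implementation.
>
> **Benefits:**
> 1. **Testability**: business logic can be tested with a mock Repository
> 2. **Separation of concerns**: the Service focuses on business logic, the Repository on data access
> 3. **Replaceability**: the database implementation can be swapped without changing the Service
>
> ```python
> class UserRepository:
>     async def get_by_id(self, id: int) -> User | None: ...
>     async def exists_by_email(self, email: str) -> bool: ...
>     async def create(self, user: User) -> User: ...
>
> class UserService:
>     def __init__(self, repo: UserRepository):
>         self.repo = repo
>
>     async def create_user(self, data: CreateUserDTO) -> User:
>         # Business rules live here
>         if await self.repo.exists_by_email(data.email):
>             raise ValueError("Email exists")
>         # Password hashing omitted for brevity
>         return await self.repo.create(User(username=data.username, email=data.email))
> ```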

---

**Previous:** [03-4. Database Migrations with Alembic](./03-4)
**Next:** [03-6. Transaction Management](./03-6)

---

Last updated: 2025-12-17


---

> Author: luk  
> URL: https://yoru-karu-blog-lalaluk-52581ac5e0cef170a3c8922c19182ecb6f7bd604.gitlab.io/posts/tutorial/fastapi/03-5/  

