03-9. 測試資料庫操作

⏱️ 閱讀時間: 18 分鐘 🎯 難度: ⭐⭐⭐ (進階)


🤔 一句話解釋

測試資料庫操作需要隔離的測試環境,確保測試不會影響真實資料,且每次測試結果可預測。


🎯 測試策略

┌─────────────────────────────────────────────────────────┐
│                    測試金字塔                            │
├─────────────────────────────────────────────────────────┤
│                    ┌─────┐                              │
│                   /  E2E  \     少量端對端測試           │
│                  /─────────\                            │
│                 / Integration\   整合測試               │
│                /───────────────\                        │
│               /    Unit Tests   \  大量單元測試          │
│              /───────────────────\                      │
└─────────────────────────────────────────────────────────┘

資料庫測試類型

| 類型 | 說明 | 速度 |
|------|------|------|
| Mock | 不實際連接資料庫 | 最快 |
| SQLite 記憶體 | 使用記憶體資料庫 | 快 |
| 測試資料庫 | 使用獨立的測試 DB | 中等 |
| 容器化 | Docker 啟動測試 DB | 較慢 |

📦 安裝測試工具

pip install pytest pytest-asyncio pytest-cov httpx
pip install aiosqlite  # async SQLite driver used by "sqlite+aiosqlite://" (in-memory tests)
pip install asyncpg    # async PostgreSQL driver used by "postgresql+asyncpg://" (PostgreSQL / Docker tests)
pip install factory-boy faker  # test data generation

🔧 測試環境設定

pytest 設定

# conftest.py
import pytest
import asyncio
from typing import AsyncGenerator, Generator
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from httpx import AsyncClient, ASGITransport

from app.main import app
from app.database import Base, get_db
from app.models import User, Post

# In-memory SQLite database used only by the test suite.
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"

# Each new connection to ":memory:" would open its own, empty database, so
# StaticPool is used to keep one single connection for the whole engine.
test_engine = create_async_engine(
    TEST_DATABASE_URL,
    poolclass=StaticPool,
    connect_args={"check_same_thread": False},
)

# Factory for AsyncSession objects bound to test_engine.
# expire_on_commit=False keeps loaded attributes usable after commit().
TestSessionLocal = sessionmaker(
    test_engine,
    class_=AsyncSession,
    autocommit=False,
    autoflush=False,
    expire_on_commit=False,
)


@pytest.fixture(scope="session")
def event_loop() -> Generator:
    """Provide a single event loop shared by every test in the session.

    The loop is created from the current event-loop policy and closed once
    the whole test session has finished.

    NOTE(review): newer pytest-asyncio releases deprecate overriding this
    fixture (and 1.0 removed it in favour of loop-scope settings) — confirm
    against the installed pytest-asyncio version.
    """
    policy = asyncio.get_event_loop_policy()
    session_loop = policy.new_event_loop()
    yield session_loop
    session_loop.close()


@pytest.fixture(scope="function")
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a session on a freshly created schema, then drop the schema.

    All tables in ``Base.metadata`` are created before the test and dropped
    after it, so every test function starts from empty tables.
    """
    metadata = Base.metadata

    # Set up: create every table known to the declarative base.
    async with test_engine.begin() as connection:
        await connection.run_sync(metadata.create_all)

    async with TestSessionLocal() as session:
        yield session

    # Tear down: drop the tables again so the next test starts clean.
    async with test_engine.begin() as connection:
        await connection.run_sync(metadata.drop_all)


@pytest.fixture(scope="function")
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """HTTP client that sends requests to the ASGI ``app`` in-process.

    The app's ``get_db`` dependency is overridden so request handlers use the
    same ``db_session`` the test sees. On teardown only that one override is
    removed (overrides installed elsewhere are left untouched), and the
    removal happens even if closing the client raises.
    """

    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        async with AsyncClient(
            transport=ASGITransport(app=app),
            base_url="http://test",
        ) as ac:
            yield ac
    finally:
        # pop() instead of clear(): don't discard overrides this fixture
        # did not install.
        app.dependency_overrides.pop(get_db, None)

使用 PostgreSQL 測試

# conftest.py (PostgreSQL version)
import os

import pytest
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool

from app.database import Base

# Test database. The default matches docker-compose.test.yml in this guide
# (user/password "test", database "test_db", published on host port 5433);
# set TEST_DATABASE_URL to point at a different server.
TEST_DATABASE_URL = os.environ.get(
    "TEST_DATABASE_URL",
    "postgresql+asyncpg://test:test@localhost:5433/test_db",
)

# NullPool: asyncpg connections are tied to the event loop that opened them.
# Pooling them lets a later test (running on a different loop) receive a
# connection from an earlier loop, which fails. A fresh connection per
# checkout avoids that.
test_engine = create_async_engine(TEST_DATABASE_URL, poolclass=NullPool)
TestSessionLocal = sessionmaker(
    test_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


@pytest.fixture(scope="function")
async def db_session():
    """Yield a session whose every change is rolled back after the test.

    The test runs on one connection inside one outer transaction that is
    always rolled back at the end, so nothing persists between tests.
    """
    async with test_engine.connect() as conn:
        trans = await conn.begin()
        try:
            # PostgreSQL DDL is transactional: tables created here vanish with
            # the final rollback. create_all skips tables that already exist
            # (checkfirst), so a migrated schema is left as-is.
            await conn.run_sync(Base.metadata.create_all)

            # create_savepoint (SQLAlchemy 2.0+): commit()/rollback() issued by
            # the code under test only affect a SAVEPOINT, never the outer
            # transaction, so isolation survives tests that roll back.
            # TestSessionLocal supplies expire_on_commit=False, so attributes
            # stay loaded after commit() instead of needing implicit async IO.
            async with TestSessionLocal(
                bind=conn,
                join_transaction_mode="create_savepoint",
            ) as session:
                yield session
        finally:
            # Roll back the outer transaction (the database is left untouched).
            await trans.rollback()

📝 Repository 測試

測試基本 CRUD

# tests/test_user_repository.py
import pytest
from sqlalchemy.ext.asyncio import AsyncSession

from app.models import User
from app.repositories.user_repository import UserRepository


@pytest.mark.asyncio
class TestUserRepository:
    """Tests for UserRepository, run against the ``db_session`` fixture."""

    @staticmethod
    def _new_user(username="testuser", email="test@example.com"):
        """Return an unsaved User with a dummy password hash."""
        return User(username=username, email=email, hashed_password="hashed")

    async def test_create_user(self, db_session: AsyncSession):
        """Creating a user assigns an id and keeps the given fields."""
        repo = UserRepository(db_session)

        created = await repo.create(self._new_user())

        assert created.id is not None
        assert created.username == "testuser"
        assert created.email == "test@example.com"

    async def test_get_by_id(self, db_session: AsyncSession):
        """A created user can be fetched back by its id."""
        repo = UserRepository(db_session)

        created = await repo.create(self._new_user())

        found = await repo.get_by_id(created.id)

        assert found is not None
        assert found.id == created.id
        assert found.username == "testuser"

    async def test_get_by_id_not_found(self, db_session: AsyncSession):
        """Looking up an unknown id returns None."""
        repo = UserRepository(db_session)

        found = await repo.get_by_id(99999)

        assert found is None

    async def test_get_by_email(self, db_session: AsyncSession):
        """A created user can be fetched back by email."""
        repo = UserRepository(db_session)

        await repo.create(self._new_user())

        found = await repo.get_by_email("test@example.com")

        assert found is not None
        assert found.email == "test@example.com"

    async def test_exists_by_email(self, db_session: AsyncSession):
        """exists_by_email is False before and True after the user is created."""
        repo = UserRepository(db_session)

        assert await repo.exists_by_email("test@example.com") is False

        await repo.create(self._new_user())

        assert await repo.exists_by_email("test@example.com") is True

    async def test_update_user(self, db_session: AsyncSession):
        """An updated username is returned and persisted."""
        repo = UserRepository(db_session)

        created = await repo.create(self._new_user())
        user_id = created.id

        created.username = "newname"
        updated = await repo.update(created)

        assert updated.username == "newname"

        # Clear the identity map so the lookup below has to read the row from
        # the database; otherwise it could hand back the very instance mutated
        # above and pass even if the change was never written.
        db_session.expunge_all()
        found = await repo.get_by_id(user_id)
        assert found is not None
        assert found.username == "newname"

    async def test_delete_user(self, db_session: AsyncSession):
        """Deleting an existing user returns True and removes it."""
        repo = UserRepository(db_session)

        created = await repo.create(self._new_user())

        result = await repo.delete(created.id)
        assert result is True

        found = await repo.get_by_id(created.id)
        assert found is None

    async def test_delete_nonexistent_user(self, db_session: AsyncSession):
        """Deleting an unknown id returns False."""
        repo = UserRepository(db_session)

        result = await repo.delete(99999)
        assert result is False

    async def test_get_all_with_pagination(self, db_session: AsyncSession):
        """skip/limit return full, non-overlapping pages."""
        repo = UserRepository(db_session)

        for i in range(10):
            await repo.create(self._new_user(f"user{i}", f"user{i}@example.com"))

        page1 = await repo.get_all(skip=0, limit=5)
        assert len(page1) == 5

        page2 = await repo.get_all(skip=5, limit=5)
        assert len(page2) == 5

        # The two pages must not share any user.
        page1_ids = {u.id for u in page1}
        page2_ids = {u.id for u in page2}
        assert page1_ids.isdisjoint(page2_ids)

🏗️ Service 測試

使用 Mock Repository

# tests/test_user_service.py
import pytest
from unittest.mock import AsyncMock, MagicMock

from app.services.user_service import UserService, CreateUserDTO
from app.models import User


@pytest.mark.asyncio
class TestUserService:
    """Unit tests for UserService with the repository replaced by an AsyncMock.

    Interactions are verified with ``assert_awaited_*`` rather than
    ``assert_called_*``: an AsyncMock records a call even when the returned
    coroutine is never awaited, so only the awaited variants catch a service
    that forgets ``await``.
    """

    async def test_create_user_success(self):
        """A new email/username pair creates the user."""
        mock_repo = AsyncMock()
        mock_repo.exists_by_email.return_value = False
        mock_repo.exists_by_username.return_value = False
        mock_repo.create.return_value = User(
            id=1,
            username="testuser",
            email="test@example.com",
            hashed_password="hashed"
        )

        service = UserService(mock_repo)
        data = CreateUserDTO(
            username="testuser",
            email="test@example.com",
            password="password123"
        )

        user = await service.create_user(data)

        assert user.id == 1
        assert user.username == "testuser"
        mock_repo.exists_by_email.assert_awaited_once_with("test@example.com")
        mock_repo.exists_by_username.assert_awaited_once_with("testuser")
        mock_repo.create.assert_awaited_once()

    async def test_create_user_email_exists(self):
        """A registered email is rejected and nothing is created."""
        mock_repo = AsyncMock()
        mock_repo.exists_by_email.return_value = True
        # Explicitly free, so the only possible failure reason is the email.
        mock_repo.exists_by_username.return_value = False

        service = UserService(mock_repo)
        data = CreateUserDTO(
            username="testuser",
            email="existing@example.com",
            password="password123"
        )

        with pytest.raises(ValueError, match="Email already registered"):
            await service.create_user(data)

        mock_repo.create.assert_not_awaited()

    async def test_create_user_username_exists(self):
        """A taken username is rejected and nothing is created."""
        mock_repo = AsyncMock()
        mock_repo.exists_by_email.return_value = False
        mock_repo.exists_by_username.return_value = True

        service = UserService(mock_repo)
        data = CreateUserDTO(
            username="existing",
            email="test@example.com",
            password="password123"
        )

        with pytest.raises(ValueError, match="Username already taken"):
            await service.create_user(data)

        mock_repo.create.assert_not_awaited()

    async def test_get_user_found(self):
        """get_user returns the repository's user for a known id."""
        mock_repo = AsyncMock()
        mock_repo.get_by_id.return_value = User(
            id=1,
            username="testuser",
            email="test@example.com",
            hashed_password="hashed"
        )

        service = UserService(mock_repo)
        user = await service.get_user(1)

        assert user is not None
        assert user.id == 1
        mock_repo.get_by_id.assert_awaited_once_with(1)

    async def test_get_user_not_found(self):
        """get_user returns None when the repository finds nothing."""
        mock_repo = AsyncMock()
        mock_repo.get_by_id.return_value = None

        service = UserService(mock_repo)
        user = await service.get_user(99999)

        assert user is None
        mock_repo.get_by_id.assert_awaited_once_with(99999)

🌐 API 整合測試

測試端點

# tests/test_user_api.py
import pytest
from httpx import AsyncClient


@pytest.mark.asyncio
class TestUserAPI:
    """Integration tests for the /users endpoints through the ``client`` fixture."""

    @staticmethod
    async def _create_user(client, username="testuser", email="test@example.com"):
        """POST /users with a fixed password and return the raw response."""
        return await client.post(
            "/users",
            json={
                "username": username,
                "email": email,
                "password": "password123",
            },
        )

    async def test_create_user(self, client: AsyncClient):
        """POST /users returns the new user without any password data."""
        response = await self._create_user(client)

        assert response.status_code == 200
        data = response.json()
        assert data["username"] == "testuser"
        assert data["email"] == "test@example.com"
        assert "id" in data
        # Neither the plain password nor its stored hash may be returned.
        assert "password" not in data
        assert "hashed_password" not in data

    async def test_create_user_invalid_email(self, client: AsyncClient):
        """A malformed email is rejected by request validation."""
        response = await self._create_user(client, email="invalid-email")

        assert response.status_code == 422  # Validation error

    async def test_create_user_duplicate_email(self, client: AsyncClient):
        """Registering the same email twice returns 400."""
        first = await self._create_user(client, username="user1")
        assert first.status_code == 200  # precondition: first signup succeeds

        response = await self._create_user(client, username="user2")

        assert response.status_code == 400
        assert "Email already registered" in response.json()["detail"]

    async def test_get_users(self, client: AsyncClient):
        """GET /users lists every created user."""
        for i in range(3):
            created = await self._create_user(
                client, username=f"user{i}", email=f"user{i}@example.com"
            )
            assert created.status_code == 200

        response = await client.get("/users")

        assert response.status_code == 200
        data = response.json()
        assert len(data) == 3

    async def test_get_user_by_id(self, client: AsyncClient):
        """GET /users/{id} returns the matching user."""
        create_response = await self._create_user(client)
        assert create_response.status_code == 200
        user_id = create_response.json()["id"]

        response = await client.get(f"/users/{user_id}")

        assert response.status_code == 200
        data = response.json()
        assert data["id"] == user_id
        assert data["username"] == "testuser"

    async def test_get_user_not_found(self, client: AsyncClient):
        """GET /users/{id} for an unknown id returns 404."""
        response = await client.get("/users/99999")

        assert response.status_code == 404
        assert "not found" in response.json()["detail"].lower()

    async def test_update_user(self, client: AsyncClient):
        """PATCH /users/{id} changes the username."""
        create_response = await self._create_user(client)
        assert create_response.status_code == 200
        user_id = create_response.json()["id"]

        response = await client.patch(
            f"/users/{user_id}",
            json={"username": "newname"}
        )

        assert response.status_code == 200
        data = response.json()
        assert data["username"] == "newname"

    async def test_delete_user(self, client: AsyncClient):
        """DELETE /users/{id} removes the user; a later GET returns 404."""
        create_response = await self._create_user(client)
        assert create_response.status_code == 200
        user_id = create_response.json()["id"]

        response = await client.delete(f"/users/{user_id}")
        assert response.status_code == 200

        get_response = await client.get(f"/users/{user_id}")
        assert get_response.status_code == 404

🏭 測試資料工廠

使用 Factory Boy

# tests/factories.py
import factory
from factory import fuzzy
from faker import Faker

from app.models import User, Post
from app.database import SessionLocal

fake = Faker()


class SQLAlchemyModelFactory(factory.Factory):
    """Abstract base factory that just instantiates the model class.

    Instances are not added to any session here; tests add and commit them
    through their own session.
    """

    class Meta:
        abstract = True

    @classmethod
    def _create(cls, model_class, *args, **kwargs):
        """Build and return a new, unsaved ``model_class`` instance."""
        instance = model_class(*args, **kwargs)
        return instance


class UserFactory(SQLAlchemyModelFactory):
    """Factory for unsaved ``User`` instances with unique username and email.

    Random Faker user names / emails can repeat within one run, which makes
    tests that insert several users fail intermittently on duplicate checks.
    A per-factory sequence number makes every generated username (and the
    email derived from it) distinct. Explicit overrides still work, e.g.
    ``UserFactory(username="alice")`` gives ``alice@example.com``.
    """

    class Meta:
        model = User

    username = factory.Sequence(lambda n: f"user{n}")
    # Derived from the (unique) username, so emails are unique as well.
    email = factory.LazyAttribute(lambda obj: f"{obj.username}@example.com")
    hashed_password = factory.LazyFunction(fake.sha256)
    is_active = True


class PostFactory(SQLAlchemyModelFactory):
    """Factory for unsaved ``Post`` instances.

    Unless an ``author`` is passed in, a new author is built with UserFactory.
    """

    class Meta:
        model = Post

    # Sentence truncated to at most 200 characters.
    title = factory.LazyFunction(lambda: fake.sentence()[:200])
    content = factory.LazyFunction(lambda: fake.text())
    author = factory.SubFactory(UserFactory)

在測試中使用工廠

# tests/test_with_factories.py
import pytest
from sqlalchemy.ext.asyncio import AsyncSession

from tests.factories import UserFactory, PostFactory
from app.repositories.user_repository import UserRepository


@pytest.mark.asyncio
class TestWithFactories:
    """Tests whose data is built with UserFactory / PostFactory."""

    async def test_user_factory(self, db_session: AsyncSession):
        """A factory-built user can be saved and gets an id."""
        user = UserFactory()
        db_session.add(user)
        await db_session.commit()

        assert user.id is not None
        assert user.username is not None
        assert user.email is not None

    async def test_create_multiple_users(self, db_session: AsyncSession):
        """Several factory-built users are all visible through the repository."""
        users = [UserFactory() for _ in range(5)]
        db_session.add_all(users)
        await db_session.commit()

        repo = UserRepository(db_session)
        all_users = await repo.get_all()

        assert len(all_users) == 5

    async def test_user_with_posts(self, db_session: AsyncSession):
        """Posts created for a user show up in ``user.posts``."""
        user = UserFactory()
        posts = [PostFactory(author=user) for _ in range(3)]

        db_session.add(user)
        db_session.add_all(posts)
        await db_session.commit()

        # Under asyncio, touching an unloaded lazy relationship would need
        # implicit IO, which AsyncSession does not allow. Naming "posts" in
        # refresh() loads it right here (SQLAlchemy 2.0.4+). Assumes
        # User.posts uses default lazy loading — eager-loaded relationships
        # work either way.
        await db_session.refresh(user, attribute_names=["posts"])
        assert len(user.posts) == 3

🐳 Docker 測試環境

docker-compose.test.yml

version: '3.8'

services:
  test-db:
    image: postgres:15
    environment:
      POSTGRES_USER: test
      POSTGRES_PASSWORD: test
      POSTGRES_DB: test_db
    ports:
      - "5433:5432"  # host port 5433 -> container port 5432
    tmpfs:
      - /var/lib/postgresql/data  # keep data in memory: faster, discarded on shutdown
    healthcheck:
      # -h 127.0.0.1 checks over TCP, i.e. the real server is accepting connections
      test: ["CMD-SHELL", "pg_isready -h 127.0.0.1 -U test -d test_db"]
      interval: 1s
      timeout: 3s
      retries: 30

測試設定

# conftest.py(Docker 版本)
import pytest
import subprocess
import time

def _wait_for_postgres(compose_cmd, timeout=30.0, interval=0.5):
    """Block until PostgreSQL in the ``test-db`` service accepts connections.

    Runs ``pg_isready`` inside the container every *interval* seconds.
    ``-h 127.0.0.1`` forces a TCP check: the official postgres image runs its
    init scripts on a temporary server that does not listen on TCP, so a
    Unix-socket check could report ready before the real server is up.

    Raises:
        RuntimeError: if the database is not ready within *timeout* seconds.
    """
    deadline = time.monotonic() + timeout
    probe = [
        *compose_cmd, "exec", "-T", "test-db",
        "pg_isready", "-h", "127.0.0.1", "-U", "test", "-d", "test_db",
    ]
    while True:
        if subprocess.run(probe, capture_output=True).returncode == 0:
            return
        if time.monotonic() >= deadline:
            raise RuntimeError(
                f"test database not ready after {timeout:.0f} seconds"
            )
        time.sleep(interval)


@pytest.fixture(scope="session", autouse=True)
def docker_compose():
    """Start the test database container for the session, remove it afterwards.

    Polls the database until it accepts connections rather than sleeping a
    fixed amount of time, and tears the container (and its volumes) down
    even when that wait fails.
    """
    compose_cmd = ["docker-compose", "-f", "docker-compose.test.yml"]
    subprocess.run([*compose_cmd, "up", "-d"], check=True)

    try:
        _wait_for_postgres(compose_cmd)
        yield
    finally:
        # Clean up: stop the container and delete its volumes.
        subprocess.run([*compose_cmd, "down", "-v"], check=True)

📊 測試覆蓋率

執行測試

# Run the whole test suite
pytest

# Run a single test file
pytest tests/test_user_repository.py

# Run a single test class
pytest tests/test_user_repository.py::TestUserRepository

# Run a single test function
pytest tests/test_user_repository.py::TestUserRepository::test_create_user

# Verbose output (long form of -v)
pytest --verbose

# Show print() output, i.e. disable capturing (long form of -s)
pytest --capture=no

# Coverage for the app package, HTML report
pytest --cov=app --cov-report=html

# Only tests carrying the asyncio marker (the async tests)
pytest -m asyncio

pytest.ini 設定

[pytest]
# Where and how tests are discovered
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*

# pytest-asyncio: run async def tests automatically
asyncio_mode = auto

# Default options: verbose output, short tracebacks
addopts = -v --tb=short

# Registered custom markers
markers =
    asyncio: mark test as async
    slow: mark test as slow

✅ 重點總結

測試層級

| 層級 | 測試對象 | 資料庫 |
|------|----------|--------|
| Unit | Repository / Service | Mock 或 SQLite |
| Integration | API | 測試資料庫 |
| E2E | 完整流程 | 測試資料庫 |

測試要點

  1. 隔離性:每個測試獨立,不互相影響
  2. 可重複:多次執行結果一致
  3. 快速:使用記憶體資料庫或交易回滾
  4. 完整:涵蓋正常和異常情況

常用 Fixture

@pytest.fixture
async def db_session():
    """Database session."""
    ...


@pytest.fixture
async def client():
    """HTTP client."""
    ...


@pytest.fixture
def user_factory():
    """Test data factory."""
    ...

🎤 面試這樣答

Q: 如何測試資料庫操作?

答案:

測試資料庫操作有幾種策略:

  1. 使用測試資料庫:獨立的測試 DB,每次測試前清空
  2. 交易回滾:每個測試在交易中執行,結束後回滾
  3. SQLite 記憶體:快速但可能有 SQL 差異
@pytest.fixture
async def db_session():
    """Yield a session whose changes are all rolled back after the test."""
    async with test_engine.connect() as conn:
        trans = await conn.begin()
        # create_savepoint (SQLAlchemy 2.0+): commit()/rollback() inside the
        # test only touch a SAVEPOINT, so the outer transaction survives and
        # the final rollback still discards everything.
        # expire_on_commit=False keeps attributes loaded after commit().
        async with AsyncSession(
            bind=conn,
            expire_on_commit=False,
            join_transaction_mode="create_savepoint",
        ) as session:
            yield session
        await trans.rollback()  # roll back after the test

重點是確保測試隔離、可重複且快速。


上一篇: 03-8. 多資料庫支援 下一篇: 04-1. 認證基礎


最後更新:2025-12-17

0%