Last week we got pytest to run asynchronous test methods. That was the preparation for this post, in which we switch our to-do application to asynchronous SQLAlchemy. As it turns out, this switch takes a lot of work, so let us go through the changes one by one.
This post is part of my journey to learn Python. You can find the code for this post in my PythonFriday repository on GitHub.
Install the asynchronous SQLite driver
The default SQLite driver only supports synchronous calls. To use SQLite with SQLAlchemy's asynchronous engine, we need to install the aiosqlite package:
```shell
pip install aiosqlite
```
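Why do we need a separate driver? The sqlite3 module blocks while a query runs. According to its documentation, aiosqlite avoids blocking the event loop by running the sqlite3 calls on a single, shared thread per connection. The sketch below only illustrates that general idea with the standard library: it uses asyncio.to_thread() instead of aiosqlite's own thread, and run_query() is a hypothetical helper, not part of aiosqlite.

```python
import asyncio
import sqlite3
from typing import List, Tuple


def _blocking_query(conn: sqlite3.Connection, sql: str) -> List[Tuple]:
    # Runs on a worker thread; sqlite3 blocks here until the query is done.
    return conn.execute(sql).fetchall()


async def run_query(conn: sqlite3.Connection, sql: str) -> List[Tuple]:
    # asyncio.to_thread (Python 3.9+) moves the blocking call to a thread,
    # so other coroutines can run while we wait for the result.
    return await asyncio.to_thread(_blocking_query, conn, sql)


async def main() -> List[Tuple]:
    # check_same_thread=False: the connection is created here but used
    # from a worker thread in this sketch.
    conn = sqlite3.connect(":memory:", check_same_thread=False)
    try:
        return await run_query(conn, "SELECT 1 + 1")
    finally:
        conn.close()


rows = asyncio.run(main())
print(rows)  # [(2,)]
```

In our application we never write this plumbing ourselves: SQLAlchemy picks up aiosqlite through the sqlite+aiosqlite:/// URL.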
Update the data store tests
We start by turning the tests for our data store into asynchronous methods. As we learned last week, this takes 4 main steps:
- Add the async keyword in front of our test functions.
- Add the @pytest.mark.asyncio marker to our test functions.
- Add an await in front of all calls to our data store.
- Use the @pytest_asyncio.fixture() decorator for our (now asynchronous) fixture.
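Under the hood, the marker and the async fixture come down to running coroutines in an event loop: the fixture runs up to its yield, the test runs with the yielded value, and the fixture resumes for its teardown. The following stdlib-only sketch imitates that sequence. It is a simplified model, not how pytest-asyncio is implemented; run_async_test(), with_db() and sample_test() are hypothetical names (the test is deliberately not called test_* so pytest does not collect it).

```python
import asyncio
from typing import AsyncIterator, Awaitable, Callable, List

events: List[str] = []


async def with_db() -> AsyncIterator[List[str]]:
    # Stand-in for an async fixture: setup, yield a value, then teardown.
    events.append("setup")
    tasks: List[str] = []
    yield tasks
    events.append("teardown")


async def sample_test(tasks: List[str]) -> None:
    # Stand-in for an async test body that awaits the data store.
    await asyncio.sleep(0)
    tasks.append("a simple task")
    events.append("test")
    assert tasks == ["a simple task"]


def run_async_test(test: Callable[[List[str]], Awaitable[None]],
                   fixture: Callable[[], AsyncIterator[List[str]]]) -> None:
    async def runner() -> None:
        gen = fixture()
        value = await gen.__anext__()  # run the fixture up to its yield
        try:
            await test(value)  # run the test in the same event loop
        finally:
            # Resume the fixture after its yield so the teardown code runs.
            try:
                await gen.__anext__()
            except StopAsyncIteration:
                pass

    asyncio.run(runner())


run_async_test(sample_test, with_db)
print(events)  # ['setup', 'test', 'teardown']
```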
Our fixture calls create_async_session_factory() to get an asynchronous session factory. In the synchronous world, our data store could work with a session it received from outside. With the asynchronous engine, that leads to an endless list of errors about closed transactions and other problems. Therefore, we pass the factory to our data store instead and let it open its own sessions.
```python
import os
from datetime import date, timedelta

import pytest
import pytest_asyncio

from ..data.database import create_async_session_factory
from ..data.datastore_db import DataStoreDb
from ..models.todo import TaskInput


@pytest_asyncio.fixture()
async def with_db():
    db_file = os.path.join(
        os.path.dirname(__file__), '..', 'db', 'test_db.sqlite')
    factory = await create_async_session_factory(db_file)
    yield factory


@pytest.mark.asyncio
async def test_can_add_entry(with_db):
    entry = TaskInput(name="a simple task", priority=1,
                      due_date=date.today(), done=False)
    store = DataStoreDb(with_db)

    data = await store.add(entry)

    assert data.name == "a simple task"
    assert data.priority == 1
    assert data.due_date == date.today()
    assert data.done == False
    assert data.created_at == date.today()
    assert data.id >= 1


@pytest.mark.asyncio
async def test_can_add_multiple_entries(with_db):
    entry_a = TaskInput(name="a simple task", priority=1,
                        due_date=date.today(), done=False)
    entry_b = TaskInput(name="b simple task", priority=2,
                        due_date=date.today(), done=False)
    store = DataStoreDb(with_db)

    data_a = await store.add(entry_a)
    data_b = await store.add(entry_b)

    assert data_a.id < data_b.id


@pytest.mark.asyncio
async def test_can_get_specific_entry_back(with_db):
    entry_a = TaskInput(name="Find a specific task", priority=1,
                        due_date=date.today(), done=False)
    store = DataStoreDb(with_db)
    saved = await store.add(entry_a)

    entry = await store.get(saved.id)

    assert saved == entry


@pytest.mark.asyncio
async def test_missing_entry_gets_None_back(with_db):
    store = DataStoreDb(with_db)

    entry = await store.get(-1)

    assert entry is None


@pytest.mark.asyncio
async def test_can_get_all_entries_back(with_db):
    entry_a = TaskInput(name="a simple task", priority=1,
                        due_date=date.today(), done=False)
    entry_b = TaskInput(name="b simple task", priority=2,
                        due_date=date.today(), done=False)
    entry_c = TaskInput(name="b simple task", priority=2,
                        due_date=date.today(), done=False)
    store = DataStoreDb(with_db)
    await store.add(entry_a)
    await store.add(entry_b)
    await store.add(entry_c)

    entries = await store.all()

    assert len(entries) >= 3


@pytest.mark.asyncio
async def test_can_delete_entry(with_db):
    entry_a = TaskInput(name="a simple task", priority=1,
                        due_date=date.today(), done=False)
    store = DataStoreDb(with_db)
    saved = await store.add(entry_a)
    task_id = saved.id

    await store.delete(task_id)

    result = await store.get(task_id)
    assert result is None


@pytest.mark.asyncio
async def test_can_update_entry(with_db):
    old = TaskInput(name="a simple task", priority=1,
                    due_date=date.today(), done=False)
    store = DataStoreDb(with_db)
    old_saved = await store.add(old)
    new = TaskInput(name="b simple task", priority=2,
                    due_date=date.today() + timedelta(days=2), done=True)

    await store.update(old_saved.id, new)

    entry = await store.get(old_saved.id)
    assert entry.name == "b simple task"
    assert entry.priority == 2
    assert entry.due_date == date.today() + timedelta(days=2)
    assert entry.done == True


@pytest.mark.asyncio
async def test_non_existing_entry_cannot_be_updated(with_db):
    store = DataStoreDb(with_db)
    new = TaskInput(name="b simple task", priority=2,
                    due_date=date.today() + timedelta(days=2), done=True)

    with pytest.raises(ValueError) as e_info:
        await store.update(-123, new)

    assert str(e_info.value) == "no taks known with id '-123'"


@pytest.mark.asyncio
async def test_fetches_statistics(with_db):
    store = DataStoreDb(with_db)
    await store.add(TaskInput(name="counter", priority=1,
                              due_date=date.today(), done=False))

    stats = await store.get_statistics()

    assert stats.total_tasks == stats.total_open + stats.total_done
    assert stats.total_tasks >= 1
```
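The idea behind passing the factory: every data store method opens a short-lived session with async with self.db() as session and closes it when the block ends, so no session is shared between operations. The stdlib-only sketch below mimics that shape. FakeSession, make_factory() and Store are hypothetical stand-ins for AsyncSession, async_sessionmaker and DataStoreDb; the sketch does not reproduce SQLAlchemy's transaction handling, it only shows one session per call.

```python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncContextManager, AsyncIterator, Callable, Dict, List, Optional


class FakeSession:
    """Hypothetical stand-in for AsyncSession."""

    opened = 0  # how many sessions were created in total

    def __init__(self, rows: Dict[int, str]) -> None:
        FakeSession.opened += 1
        self.rows = rows
        self.closed = False

    async def get(self, key: int) -> Optional[str]:
        if self.closed:
            raise RuntimeError("session is closed")
        await asyncio.sleep(0)  # pretend to wait for the database
        return self.rows.get(key)


SessionFactory = Callable[[], AsyncContextManager[FakeSession]]


def make_factory(rows: Dict[int, str]) -> SessionFactory:
    """Hypothetical stand-in for async_sessionmaker: every call gives a new session."""

    @asynccontextmanager
    async def factory() -> AsyncIterator[FakeSession]:
        session = FakeSession(rows)
        try:
            yield session
        finally:
            session.closed = True  # closed when the "async with" block ends

    return factory


class Store:
    """Same shape as DataStoreDb: keep the factory, open one session per call."""

    def __init__(self, db: SessionFactory) -> None:
        self.db = db

    async def get(self, key: int) -> Optional[str]:
        async with self.db() as session:
            return await session.get(key)


async def main() -> List[Optional[str]]:
    store = Store(make_factory({1: "a simple task"}))
    # Two concurrent calls; each one opens and closes its own session.
    return await asyncio.gather(store.get(1), store.get(2))


results = asyncio.run(main())
print(results, FakeSession.opened)  # ['a simple task', None] 2
```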
Create the asynchronous session factory
In our data/database.py file, we add a new function that creates an asynchronous session factory:
```python
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine


def create_session_factory(db_file: str) -> sessionmaker:
    engine = create_engine(
        'sqlite:///' + db_file,
        connect_args={"check_same_thread": False},
        echo=False
    )
    factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    return factory


async def create_async_session_factory(db_file: str) -> async_sessionmaker[AsyncSession]:
    # The sqlite+aiosqlite:// URL selects the asynchronous aiosqlite driver.
    engine = create_async_engine(
        'sqlite+aiosqlite:///' + db_file,
        connect_args={"check_same_thread": False},
        echo=False
    )
    factory = async_sessionmaker(engine, autocommit=False, autoflush=False,
                                 expire_on_commit=False, class_=AsyncSession)
    return factory
```
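The new function uses the aiosqlite driver (note the sqlite+aiosqlite:/// prefix) and async_sessionmaker instead of sessionmaker; otherwise it works like the synchronous version. We also set expire_on_commit=False: by default, a commit expires the loaded attributes of our objects, and reloading them on the next access would need implicit database I/O, which the asynchronous session cannot do.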
Change the DataStoreDb
For our DataStoreDb class, we need to make these adjustments:
- Accept the session factory instead of a session in the constructor.
- Add the async keyword in front of all public methods.
- Wrap an async with block around everything that uses the database, so that each method opens and closes its own session.
- Await the calls to session.scalar(), session.scalars(), session.execute(), session.delete() and session.commit().
- In get_statistics(), turn the result into a list before taking its first row to prevent a cursor error.
With all those changes, our DataStoreDb class now looks like this:
```python
from datetime import date, datetime

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from ..models.statistics import StatisticOverview
from ..models.todo import TaskInput, TaskOutput
from .entities import Task


class DataStoreDb:
    def __init__(self, db: async_sessionmaker[AsyncSession]):
        # Keep the session factory; every method opens its own session.
        self.db = db

    async def add(self, entry: TaskInput) -> TaskOutput:
        async with self.db() as session:
            task = Task(id=None, created_at=datetime.now(), **dict(entry))
            session.add(task)
            await session.commit()
            return self.__to_output(task)

    async def get(self, id: int) -> TaskOutput | None:
        async with self.db() as session:
            query = select(Task).where(Task.id == id)
            result = await session.scalar(query)
            if result:
                return self.__to_output(result)
            else:
                return None

    async def all(self) -> list[TaskOutput]:
        async with self.db() as session:
            query = select(Task)
            entries = await session.scalars(query)
            results = []
            for entry in entries:
                results.append(self.__to_output(entry))
            return results

    async def delete(self, id: int) -> None:
        async with self.db() as session:
            query = select(Task).where(Task.id == id)
            entry = await session.scalar(query)
            if entry:
                await session.delete(entry)
                await session.commit()

    async def update(self, id: int, update: TaskInput) -> TaskOutput:
        async with self.db() as session:
            query = select(Task).where(Task.id == id)
            entry = await session.scalar(query)
            if entry:
                entry.name = update.name
                entry.priority = update.priority
                entry.due_date = update.due_date
                entry.done = update.done
                await session.commit()
                return self.__to_output(entry)
            else:
                raise ValueError(f"no taks known with id '{id}'")

    async def get_statistics(self) -> StatisticOverview:
        async with self.db() as session:
            query = (
                select(
                    func.count("*").label("total"),
                    func.count("*").filter(Task.done == True).label("done"),
                    func.count("*").filter(Task.done == False).label("open"),
                )
            )
            result_db = await session.execute(query)
            # Turn the result into a list before taking the single row, see
            # https://stackoverflow.com/questions/36515882/command-cursor-object-is-not-subscriptable
            result = list(result_db)[0]
            return StatisticOverview(total_tasks=result[0],
                                     total_done=result[1],
                                     total_open=result[2])

    def __to_output(self, entity: Task) -> TaskOutput:
        return TaskOutput(id=entity.id, name=entity.name, priority=entity.priority,
                          due_date=entity.due_date, done=entity.done,
                          created_at=date.today())
```
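The statistics query counts all tasks plus the finished and the open ones in a single row, using an aggregate FILTER clause. To see that kind of SQL in isolation, here is a sketch with the standard sqlite3 module; the tasks table is a simplified, hypothetical version of our entity, the SQL is hand-written rather than what SQLAlchemy emits, and FILTER on aggregates needs SQLite 3.30 or newer. Because the query returns exactly one row, SQLAlchemy's Result.one() would be another way to read it instead of list(result_db)[0].

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Hypothetical, trimmed-down table; sqlite3 stores True/False as 1/0.
conn.execute("CREATE TABLE tasks (id INTEGER PRIMARY KEY, name TEXT, done BOOLEAN)")
conn.executemany(
    "INSERT INTO tasks (name, done) VALUES (?, ?)",
    [("a", True), ("b", False), ("c", False)],
)

# One row with three counters, like the select() in get_statistics().
row = conn.execute(
    """
    SELECT count(*)                         AS total,
           count(*) FILTER (WHERE done = 1) AS done,
           count(*) FILTER (WHERE done = 0) AS open
    FROM tasks
    """
).fetchone()
conn.close()

total, done, open_ = row
print(total, done, open_)  # 3 1 2
```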
We can now run our tests for the data store, and they should all pass. However, the tests for the FastAPI application will fail. Let us fix that.
Fix the endpoint tests
In the tests for our FastAPI endpoints, the override_get_db() function must now await create_async_session_factory(). All test functions get the async keyword and the @pytest.mark.asyncio marker, and we need to await every call to the prepare_task() helper:
```python
from datetime import date, timedelta
import logging
import os
import re

from bs4 import BeautifulSoup
from fastapi.testclient import TestClient
import pytest

from ..dependencies import get_db
from ..data.datastore_db import DataStoreDb
from ..data.database import create_async_session_factory
from ..main import app

logging.getLogger("httpx").setLevel(logging.WARNING)


async def override_get_db():
    db_file = os.path.join(
        os.path.dirname(__file__), '..', 'db', 'test_db.sqlite')
    factory = await create_async_session_factory(db_file)
    db = DataStoreDb(factory)
    yield db


client = TestClient(app)
app.dependency_overrides[get_db] = override_get_db


@pytest.mark.asyncio
async def test_create_task():
    data = {
        "name": "A first task",
        "priority": 5,
        "due_date": str(date.today() + timedelta(days=1)),
        "done": False
    }

    response = client.post("/api/todo/", json=data)

    assert response.status_code == 201
    result = response.json()
    assert result['id'] > 0
    assert result['done'] == False
    assert result['created_at'] == str(date.today())
    assert result['name'] == data['name']
    assert result['priority'] == data['priority']
    assert result['due_date'] == data['due_date']
    assert f"http://testserver/api/todo/{result['id']}" == response.headers['location']


async def prepare_task(name, priority=4, due_date=None, done=False):
    # Helper, not a test: it needs no asyncio marker.
    if due_date is None:
        due_date = date.today() + timedelta(days=1)
    data = {
        "name": name,
        "priority": priority,
        "due_date": str(due_date),
        "done": done
    }
    prepare_response = client.post("/api/todo/", json=data)
    assert prepare_response.status_code == 201
    return prepare_response.json()['id']


@pytest.mark.asyncio
async def test_show_task():
    name = "A second task"
    task_id = await prepare_task(name)

    response = client.get(f"/api/todo/{task_id}")

    assert response.status_code == 200
    details = response.json()
    assert details['name'] == name


@pytest.mark.asyncio
async def test_show_task_where_task_is_unknown():
    response = client.get("/api/todo/-1")

    assert response.status_code == 404
    assert response.json()['detail'] == "Task not found"


@pytest.mark.asyncio
async def test_update_task():
    task_id = await prepare_task("original")
    update = {
        "name": "An updated task",
        "priority": 5,
        "due_date": str(date.today() + timedelta(days=2)),
        "done": False
    }

    response = client.put(f"/api/todo/{task_id}", json=update)

    assert response.status_code == 200
    assert response.json()['name'] == "An updated task"
    check = client.get(f"/api/todo/{task_id}")
    assert check.json()['name'] == "An updated task"


@pytest.mark.asyncio
async def test_delete_task():
    task_id = await prepare_task("to delete")

    response = client.delete(f"/api/todo/{task_id}")

    assert response.status_code == 204
    check = client.get(f"/api/todo/{task_id}")
    assert check.status_code == 404


@pytest.mark.asyncio
async def test_main_page_shows_info_message():
    response = client.get("/")

    assert response.status_code == 200
    assert response.json()['message'] == "The minimalistic ToDo API"


@pytest.mark.asyncio
async def test_show_all_tasks():
    await prepare_task("a first task")
    await prepare_task("a second task")
    await prepare_task("a third task")

    response = client.get("/api/todo")

    assert response.status_code == 200
    tasks = response.json()
    assert len(tasks) >= 3


@pytest.mark.asyncio
async def test_show_all_tasks_that_are_not_done():
    await prepare_task("a finished task", done=True)
    await prepare_task("an open task", done=False)

    response = client.get("/api/todo?include_done=false")

    assert response.status_code == 200
    tasks = response.json()
    done = [task for task in tasks if task['done'] == True]
    assert len(done) == 0


@pytest.mark.asyncio
async def test_show_all_tasks_that_are_due_within_five_days():
    await prepare_task("in 10 days", due_date=date.today() + timedelta(days=10))

    response = client.get(f"/api/todo?due_before={date.today() + timedelta(days=5)}")

    assert response.status_code == 200
    tasks = response.json()
    too_late = [task for task in tasks
                if date.fromisoformat(task['due_date']) > date.today() + timedelta(days=5)]
    assert len(too_late) == 0


@pytest.mark.asyncio
async def test_docs_endpoint_works():
    response = client.get("/openapi.json")

    assert response.status_code == 200


@pytest.mark.asyncio
async def test_about_page():
    response = client.get("/about")

    assert response.status_code == 200
    soup = BeautifulSoup(response.text, 'html.parser')
    assert soup.title.text == "About To-Do Task API"
    assert soup.body.h1.text == "About"


@pytest.mark.asyncio
async def test_dashboard():
    response = client.get("/dashboard")

    assert response.status_code == 200
    soup = BeautifulSoup(response.text, 'html.parser')
    assert soup.title.text == "Dashboard To-Do Task API"
    assert soup.body.h1.text == "Dashboard"
    numbers = re.findall(r"\d+", soup.body.p.text)
    assert int(numbers[0]) == int(numbers[1]) + int(numbers[2])
```
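The test module swaps the data store through app.dependency_overrides, a dictionary that maps the original dependency (get_db) to its replacement (override_get_db). Because the endpoints still declare Depends(get_db), the override must be registered under that exact function, and it should have the same shape (an async generator that yields a DataStoreDb). As a rough, stdlib-only model, where resolve() and the string stand-ins are hypothetical and not FastAPI's resolver:

```python
import asyncio
from typing import AsyncIterator, Callable, Dict


async def get_db() -> AsyncIterator[str]:
    yield "production store"  # stand-in for the real data store


async def override_get_db() -> AsyncIterator[str]:
    yield "test store"  # stand-in for the store on test_db.sqlite


# Simplified model of app.dependency_overrides: it maps the original
# dependency function to its replacement.
dependency_overrides: Dict[Callable, Callable] = {}


async def resolve(dependency: Callable[[], AsyncIterator[str]]) -> str:
    # Use the override if one is registered, otherwise the original function.
    provider = dependency_overrides.get(dependency, dependency)
    gen = provider()
    try:
        return await gen.__anext__()
    finally:
        await gen.aclose()


before = asyncio.run(resolve(get_db))
dependency_overrides[get_db] = override_get_db
after = asyncio.run(resolve(get_db))
print(before, "->", after)  # production store -> test store
```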
Fix the dependencies.py code
In get_db(), we switch to the create_async_session_factory() function (which we must await) and pass the factory to our data store:
```python
import os

from .data.database import create_async_session_factory
from .data.datastore_db import DataStoreDb


async def get_db():
    """
    Creates the datastore
    """
    db_file = os.path.join(
        os.path.dirname(__file__), '.', 'db', 'todo_api.sqlite')
    factory = await create_async_session_factory(db_file)
    db = DataStoreDb(factory)
    yield db
```
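Since get_db() is an async generator, FastAPI runs it up to the yield, hands the yielded DataStoreDb to the endpoint, and resumes the generator later so that any code after the yield can clean up. Our get_db() has nothing to clean up; the sketch below adds a hypothetical try/finally only to show where such code would run, and handle_request() is a simplified model, not FastAPI's actual dependency handling.

```python
import asyncio
from typing import AsyncIterator, List

log: List[str] = []


async def get_db() -> AsyncIterator[str]:
    # Hypothetical variant of get_db() with a cleanup step after the yield.
    log.append("create store")
    try:
        yield "store"
    finally:
        log.append("cleanup")


async def handle_request() -> str:
    # Simplified model of how a framework drives a yield dependency.
    gen = get_db()
    db = await gen.__anext__()  # run get_db() up to its yield
    try:
        log.append(f"endpoint uses {db}")  # the endpoint works with the value
        return "response"
    finally:
        await gen.aclose()  # resumes get_db() so its finally block runs


result = asyncio.run(handle_request())
print(result, log)  # response ['create store', 'endpoint uses store', 'cleanup']
```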
Fix the FastAPI endpoints
In all our endpoints, we need to await the calls to the db.* methods; the rest of the files can stay the same.
In our main.py file, only the dashboard() endpoint needs an await, for the call to db.get_statistics():
```python
from pathlib import Path

from fastapi import Depends, FastAPI, Request
from fastapi.encoders import jsonable_encoder
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from .data.datastore_db import DataStoreDb
from .dependencies import get_db
from .routers import todo

BASE_DIR = Path(__file__).resolve().parent

app = FastAPI()
app.include_router(todo.router, prefix="/api/todo")
app.mount("/static", StaticFiles(directory=str(Path(BASE_DIR, 'static'))), name="static")

templates = Jinja2Templates(directory=str(Path(BASE_DIR, 'templates')))


@app.get("/about", response_class=HTMLResponse)
async def about(request: Request):
    return templates.TemplateResponse(
        request=request, name="about.html"
    )


@app.get("/", include_in_schema=False)
async def main():
    return {'message': 'The minimalistic ToDo API'}


@app.get("/dashboard", include_in_schema=False)
async def dashboard(request: Request, db: DataStoreDb = Depends(get_db)):
    stats = await db.get_statistics()
    return templates.TemplateResponse(
        request=request, name="dashboard.html", context=jsonable_encoder(stats)
    )
```
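Forgetting one of these awaits does not fail at the call site: calling an async def function only creates a coroutine object, so dashboard() would pass that coroutine on instead of the statistics. The same applies to create_async_session_factory(), which is declared async even though nothing inside it awaits. A stdlib sketch with a hypothetical get_statistics() stand-in:

```python
import asyncio
import inspect


async def get_statistics() -> dict:
    # Hypothetical stand-in for DataStoreDb.get_statistics().
    await asyncio.sleep(0)
    return {"total_tasks": 3, "total_done": 1, "total_open": 2}


async def dashboard_without_await() -> bool:
    stats = get_statistics()  # bug: missing await, stats is a coroutine
    is_coroutine = inspect.iscoroutine(stats)
    stats.close()  # close it to avoid the "never awaited" RuntimeWarning
    return is_coroutine


async def dashboard_with_await() -> int:
    stats = await get_statistics()  # stats is the dict we expect
    return stats["total_tasks"]


forgot = asyncio.run(dashboard_without_await())
total = asyncio.run(dashboard_with_await())
print(forgot, total)  # True 3
```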
In the routers/todo.py file, we need an await in every endpoint:
```python
from datetime import date, timedelta
from typing import Annotated, List

from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse

from ..dependencies import get_db
from ..models.todo import TaskOutput, TaskInput
from ..data.datastore_db import DataStoreDb

router = APIRouter()


async def filter_parameters(q: str | None = None,
                            include_done: bool = True,
                            due_before: date | None = None):
    # Compute the default per request; a default expression in the signature
    # would only be evaluated once, when the module is imported.
    if due_before is None:
        due_before = date.today() + timedelta(days=365)
    return {"q": q, "include_done": include_done, "due_before": due_before}


@router.get("/")
async def show_all_tasks(filter: Annotated[dict, Depends(filter_parameters)],
                         db: DataStoreDb = Depends(get_db)) -> List[TaskOutput]:
    result = await db.all()
    if not filter["include_done"]:
        result = [item for item in result if item.done == False]

    result = [item for item in result if item.due_date <= filter["due_before"]]
    return result


@router.post("/", status_code=status.HTTP_201_CREATED)
async def create_task(task: TaskInput, request: Request,
                      db: DataStoreDb = Depends(get_db)) -> TaskOutput:
    result = await db.add(task)
    headers = {"Location": f"{request.base_url}api/todo/{result.id}"}
    return JSONResponse(content=jsonable_encoder(result),
                        status_code=status.HTTP_201_CREATED, headers=headers)


@router.get("/{id}")
async def show_task(id: int, db: DataStoreDb = Depends(get_db)) -> TaskOutput:
    result = await db.get(id)
    if result:
        return result
    else:
        raise HTTPException(status_code=404, detail="Task not found")


@router.put("/{id}")
async def update_task(id: int, task: TaskInput,
                      db: DataStoreDb = Depends(get_db)) -> TaskOutput:
    try:
        result = await db.update(id, task)
        return result
    except ValueError:
        raise HTTPException(status_code=404, detail="Task not found")


@router.delete("/{id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_task(id: int, db: DataStoreDb = Depends(get_db)) -> None:
    await db.delete(id)
    return Response(status_code=status.HTTP_204_NO_CONTENT)
```
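The filtering in show_all_tasks() happens in Python after db.all() has returned every task. Pulled out into a plain function, the same two steps can be unit-tested without a database or an event loop. apply_filter() and the trimmed-down Task dataclass are hypothetical and not part of the project:

```python
from dataclasses import dataclass
from datetime import date, timedelta
from typing import List


@dataclass
class Task:
    # Hypothetical, trimmed-down version of TaskOutput.
    name: str
    done: bool
    due_date: date


def apply_filter(tasks: List[Task], include_done: bool, due_before: date) -> List[Task]:
    # Same two steps as show_all_tasks(): optionally drop finished tasks,
    # then keep only tasks that are due on or before the cut-off date.
    if not include_done:
        tasks = [task for task in tasks if not task.done]
    return [task for task in tasks if task.due_date <= due_before]


today = date(2024, 1, 1)  # fixed date so the example is reproducible
tasks = [
    Task("finished", True, today),
    Task("open", False, today + timedelta(days=1)),
    Task("later", False, today + timedelta(days=10)),
]
result = apply_filter(tasks, include_done=False, due_before=today + timedelta(days=5))
print([task.name for task in result])  # ['open']
```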
We can now run all tests, and they should pass. Without those tests, we would have to spend an awful amount of time checking the application by hand. Therefore, make sure that you have good test coverage for your code before you start such a massive change.
Next
With all these changes, our to-do application now runs on asynchronous SQLAlchemy. As this post shows, even for a small application that is a massive change. It took me a few rounds to get everything back into a working state, and I would not attempt it without a reliable test suite.
Next week we explore a way to get rid of our hand-written filter and replace it with something more powerful.
2 thoughts on “Python Friday #240: Asynchronous SQLAlchemy With FastAPI”