SQL (Relational) Databases with Peewee

Warning

If you are just starting, the SQLAlchemy tutorial should be enough.

Feel free to skip this.

If you are starting a project from scratch, you are probably better off with SQLAlchemy ORM, or any other async ORM.

If you already have a code base that uses Peewee ORM, you can check here how to use it with FastAPI.

Python 3.7+ required

You will need Python 3.7 or above to safely use Peewee with FastAPI.

Peewee for async

Peewee was not designed for async frameworks, or with them in mind.

Peewee has some heavy assumptions about its defaults and about how it should be used.

If you are developing an application with an older non-async framework, and can work with all its defaults, it can be a great tool.

But if you need to change some of the defaults, support more than one predefined database, work with an async framework (like FastAPI), etc., you will need to add quite a bit of complex extra code to override those defaults.

Nevertheless, it's possible to do it, and here you'll see exactly what code you have to add to be able to use Peewee with FastAPI.

Technical Details

You can read more about Peewee's stance on async in Python in the docs, an issue, and a PR.

The same app

We are going to create the same application as in the SQLAlchemy tutorial.

Most of the code is actually the same.

So, we are going to focus only on the differences.

File structure

Let's say you have a directory named my_super_project that contains a sub-directory called sql_app with a structure like this:

.
└── sql_app
    ├── __init__.py
    ├── crud.py
    ├── database.py
    ├── main.py
    ├── models.py
    └── schemas.py

This is almost the same structure as we had for the SQLAlchemy tutorial.

Now let's see what each file/module does.

Create the Peewee parts

Let's refer to the file sql_app/database.py.

The standard Peewee code

Let's first check the standard Peewee code that creates a Peewee database:

from contextvars import ContextVar

import peewee

DATABASE_NAME = "test.db"
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
db_state = ContextVar("db_state", default=db_state_default.copy())


class PeeweeConnectionState(peewee._ConnectionState):
    def __init__(self, **kwargs):
        super().__setattr__("_state", db_state)
        super().__init__(**kwargs)

    def __setattr__(self, name, value):
        self._state.get()[name] = value

    def __getattr__(self, name):
        return self._state.get()[name]


db = peewee.SqliteDatabase(DATABASE_NAME, check_same_thread=False)

db._state = PeeweeConnectionState()

Tip

Keep in mind that if you wanted to use a different database, like PostgreSQL, you couldn't just change the database name string. You would need to use a different Peewee database class.

Note

The argument:

check_same_thread=False

is equivalent to the one in the SQLAlchemy tutorial:

connect_args={"check_same_thread": False}

...it is needed only for SQLite.
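To see what check_same_thread guards against, here is a minimal sketch that uses only Python's built-in sqlite3 module (no Peewee involved). A connection created in the main thread is used from another thread, first with the default check enabled and then with check_same_thread=False. The helper name query_from_other_thread is hypothetical.

```python
import sqlite3
import threading


def query_from_other_thread(conn: sqlite3.Connection) -> str:
    """Run a trivial query on `conn` from a new thread and report the outcome."""
    result = {}

    def worker() -> None:
        try:
            conn.execute("SELECT 1").fetchone()
            result["outcome"] = "ok"
        except sqlite3.ProgrammingError:
            # Raised when a connection is used outside the thread that created it.
            result["outcome"] = "ProgrammingError"

    t = threading.Thread(target=worker)
    t.start()
    t.join()
    return result["outcome"]


strict = sqlite3.connect(":memory:")  # default: check_same_thread=True
relaxed = sqlite3.connect(":memory:", check_same_thread=False)

print(query_from_other_thread(strict))   # ProgrammingError
print(query_from_other_thread(relaxed))  # ok
```

In the FastAPI app, the dependency that opens the connection and the path operation that uses it may run in different threadpool threads. This is presumably why the check has to be disabled here.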

Technical Details

Exactly the same technical details as in the SQLAlchemy tutorial apply.

Make Peewee async-compatible PeeweeConnectionState

The main issue with Peewee and FastAPI is that Peewee relies heavily on Python's threading.local, and it doesn't have a direct way to override it or let you handle connections/sessions directly (as is done in the SQLAlchemy tutorial).

And threading.local is not compatible with the new async features of modern Python.

Technical Details

threading.local is used to have a "magic" variable that has a different value for each thread.

This was useful in older frameworks designed to have one single thread per request, no more, no less.

Using this, each request would have its own database connection/session, which is the actual final goal.

But FastAPI, using the new async features, can handle more than one request on the same thread. At the same time, for a single request, it can run multiple things in different threads (in a threadpool), depending on whether you use async def or normal def. This is what gives FastAPI its performance improvements.

Python 3.7 and above provide a more advanced alternative to threading.local. It can be used in the same places where threading.local would be used, and it is also compatible with the new async features.

We are going to use that. It's called contextvars.
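Here is a minimal standard-library sketch of the difference. Two coroutines stand in for concurrent requests and run as asyncio tasks on the same thread. Before yielding control, each one stores a hypothetical connection name in both a threading.local and a ContextVar.

```python
import asyncio
import threading
from contextvars import ContextVar

local_state = threading.local()
ctx_state = ContextVar("ctx_state", default="unset")


async def handle_request(name: str, results: dict) -> None:
    # Each "request" stores its own connection name...
    local_state.conn = f"conn-{name}"
    ctx_state.set(f"conn-{name}")
    # ...then yields control; the other task runs here, on the same thread.
    await asyncio.sleep(0.01)
    results[name] = (local_state.conn, ctx_state.get())


async def main() -> dict:
    results: dict = {}
    # gather() wraps each coroutine in a Task, and every Task runs in its own
    # copy of the current context.
    await asyncio.gather(handle_request("a", results), handle_request("b", results))
    return results


results = asyncio.run(main())
print(results)
```

Request "a" reads back "conn-b" from the threading.local, because both tasks share one thread. Its ContextVar still holds "conn-a", because each task has its own context.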

We are going to override the internal parts of Peewee that use threading.local and replace them with contextvars, with the corresponding updates.

This might seem a bit complex (and it actually is), but you don't need to completely understand how it works to use it.

We will create a PeeweeConnectionState:

from contextvars import ContextVar

import peewee

DATABASE_NAME = "test.db"
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
db_state = ContextVar("db_state", default=db_state_default.copy())


class PeeweeConnectionState(peewee._ConnectionState):
    def __init__(self, **kwargs):
        super().__setattr__("_state", db_state)
        super().__init__(**kwargs)

    def __setattr__(self, name, value):
        self._state.get()[name] = value

    def __getattr__(self, name):
        return self._state.get()[name]

This class inherits from a special internal class used by Peewee.

It has all the logic to make Peewee use contextvars instead of threading.local.

contextvars works a bit differently than threading.local. But the rest of Peewee's internal code assumes that this class works with threading.local.

So, we need to do some extra tricks to make it work as if it was just using threading.local. The __init__, __setattr__, and __getattr__ implement all the required tricks for this to be used by Peewee without knowing that it is now compatible with FastAPI.
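You can try the same three tricks without Peewee. In this minimal sketch, ContextState is a hypothetical stand-in for PeeweeConnectionState:

- `__init__` stores the ContextVar through `super().__setattr__`, so it bypasses the overridden `__setattr__`.
- `__setattr__` writes into the dict held by the current context.
- `__getattr__` reads from that dict.

Each simulated request runs in its own context created with `contextvars.copy_context()`. Like the middleware described later, it installs a fresh dict first.

```python
from contextvars import ContextVar, copy_context

# Hypothetical stand-in for Peewee's connection state dict.
state_default = {"conn": None, "closed": True}
state_var = ContextVar("state_var", default=state_default.copy())


class ContextState:
    """Hypothetical stand-in for PeeweeConnectionState (no Peewee needed)."""

    def __init__(self):
        # Bypass the overridden __setattr__ so _state lands in the instance dict.
        super().__setattr__("_state", state_var)

    def __setattr__(self, name, value):
        # Every other write goes into the dict held by the current context.
        self._state.get()[name] = value

    def __getattr__(self, name):
        # Only called when normal lookup fails, i.e. for everything except _state.
        return self._state.get()[name]


state = ContextState()


def fake_request(conn_name):
    # What the middleware does: give this context its own fresh dict first.
    state_var.set(state_default.copy())
    state.conn = conn_name
    state.closed = False
    return state.conn


ctx_a, ctx_b = copy_context(), copy_context()
print(ctx_a.run(fake_request, "conn-a"))  # conn-a
print(ctx_b.run(fake_request, "conn-b"))  # conn-b
print(ctx_a.run(lambda: state.conn))  # conn-a (context a kept its own dict)
print(state.conn)  # None (the outer context still sees the untouched default)
```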

Tip

This only makes Peewee behave correctly when used with FastAPI: it will no longer randomly open or close connections that are still in use, causing errors, etc.

But it doesn't give Peewee async super-powers. You should still use normal def functions and not async def.

Use the custom PeeweeConnectionState class

Now, overwrite the ._state internal attribute in the Peewee database db object using the new PeeweeConnectionState:

db = peewee.SqliteDatabase(DATABASE_NAME, check_same_thread=False)

db._state = PeeweeConnectionState()

Tip

Make sure you overwrite db._state after creating db.

Tip

You would do the same for any other Peewee database, including PostgresqlDatabase, MySQLDatabase, etc.

Create the database models

Let's now see the file sql_app/models.py.

Create Peewee models for our data

Now create the Peewee models (classes) for User and Item.

This is the same thing you would do if you followed the Peewee tutorial and updated the models to have the same data as in the SQLAlchemy tutorial.

Tip

Peewee also uses the term "model" to refer to these classes and instances that interact with the database.

But Pydantic also uses the term "model" to refer to something different, the data validation, conversion, and documentation classes and instances.

Import db from database (the file database.py from above) and use it here.

import peewee

from .database import db


class User(peewee.Model):
    email = peewee.CharField(unique=True, index=True)
    hashed_password = peewee.CharField()
    is_active = peewee.BooleanField(default=True)

    class Meta:
        database = db


class Item(peewee.Model):
    title = peewee.CharField(index=True)
    description = peewee.CharField(index=True)
    owner = peewee.ForeignKeyField(User, backref="items")

    class Meta:
        database = db

Tip

Peewee creates several magic attributes.

It will automatically add an id attribute as an integer to be the primary key.

It will choose the names of the tables based on the class names.

For the Item, it will create an attribute owner_id with the integer ID of the User. But we don't declare it anywhere.
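To make these implicit columns visible, here is a hand-written approximation of the tables Peewee derives from the two models, created with the built-in sqlite3 module. It is not Peewee's literal DDL: indexes are omitted and column types may differ slightly. It does show the implicit id primary keys and the owner_id column behind the owner field.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Hand-written approximation of the tables Peewee derives from User and Item
# (indexes omitted); table names come from the lowercased class names.
conn.executescript(
    """
    CREATE TABLE "user" (
        "id" INTEGER NOT NULL PRIMARY KEY,  -- implicit primary key
        "email" VARCHAR(255) NOT NULL,
        "hashed_password" VARCHAR(255) NOT NULL,
        "is_active" INTEGER NOT NULL
    );
    CREATE TABLE "item" (
        "id" INTEGER NOT NULL PRIMARY KEY,  -- implicit primary key
        "title" VARCHAR(255) NOT NULL,
        "description" VARCHAR(255) NOT NULL,
        "owner_id" INTEGER NOT NULL REFERENCES "user" ("id")  -- from `owner`
    );
    """
)


def column_names(table):
    # PRAGMA table_info returns one row per column; index 1 is the column name.
    return [row[1] for row in conn.execute(f'PRAGMA table_info("{table}")')]


print(column_names("user"))  # ['id', 'email', 'hashed_password', 'is_active']
print(column_names("item"))  # ['id', 'title', 'description', 'owner_id']
```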

Create the Pydantic models

Now let's check the file sql_app/schemas.py.

Tip

To avoid confusion between the Peewee models and the Pydantic models, we will have the file models.py with the Peewee models, and the file schemas.py with the Pydantic models.

These Pydantic models define more or less a "schema" (a valid data shape).

So this will help us avoid confusion while using both.

Create the Pydantic models / schemas

Create all the same Pydantic models as in the SQLAlchemy tutorial:

from typing import Any, List, Optional

import peewee
from pydantic import BaseModel
from pydantic.utils import GetterDict


class PeeweeGetterDict(GetterDict):
    def get(self, key: Any, default: Any = None):
        res = getattr(self._obj, key, default)
        if isinstance(res, peewee.ModelSelect):
            return list(res)
        return res


class ItemBase(BaseModel):
    title: str
    description: Optional[str] = None


class ItemCreate(ItemBase):
    pass


class Item(ItemBase):
    id: int
    owner_id: int

    class Config:
        orm_mode = True
        getter_dict = PeeweeGetterDict


class UserBase(BaseModel):
    email: str


class UserCreate(UserBase):
    password: str


class User(UserBase):
    id: int
    is_active: bool
    items: List[Item] = []

    class Config:
        orm_mode = True
        getter_dict = PeeweeGetterDict

Tip

Here we are creating the models with an id.

We didn't explicitly specify an id attribute in the Peewee models, but Peewee adds one automatically.

We are also adding the magic owner_id attribute to Item.

Create a PeeweeGetterDict for the Pydantic models / schemas

When you access a relationship in a Peewee object, like in some_user.items, Peewee doesn't provide a list of Item.

It provides a special custom object of class ModelSelect.

It's possible to create a list of its items with list(some_user.items).

But the object itself is not a list. And it's also not an actual Python generator. Because of this, Pydantic doesn't know by default how to convert it to a list of Pydantic models / schemas.

But Pydantic (version 1.x) lets you provide a custom class that inherits from pydantic.utils.GetterDict. This class supplies the functionality that orm_mode = True uses to retrieve the values of ORM model attributes.

We are going to create a custom PeeweeGetterDict class and use it in all the same Pydantic models / schemas that use orm_mode:

from typing import Any, List

import peewee
from pydantic.utils import GetterDict


class PeeweeGetterDict(GetterDict):
    def get(self, key: Any, default: Any = None):
        res = getattr(self._obj, key, default)
        if isinstance(res, peewee.ModelSelect):
            return list(res)
        return res


class Item(ItemBase):
    id: int
    owner_id: int

    class Config:
        orm_mode = True
        getter_dict = PeeweeGetterDict


class User(UserBase):
    id: int
    is_active: bool
    items: List[Item] = []

    class Config:
        orm_mode = True
        getter_dict = PeeweeGetterDict

Here we are checking if the attribute that is being accessed (e.g. .items in some_user.items) is an instance of peewee.ModelSelect.

If that's the case, it returns a list built from it.

And then we use it in the Pydantic models / schemas that use orm_mode = True, with the configuration variable getter_dict = PeeweeGetterDict.
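You can try the rule inside PeeweeGetterDict.get in isolation. This sketch replaces peewee.ModelSelect with a hypothetical LazyQuery class (iterable, but neither a list nor a generator). It applies the same check to the attributes of a plain object and needs neither Peewee nor Pydantic.

```python
from typing import Any


class LazyQuery:
    """Hypothetical stand-in for peewee.ModelSelect: iterable, but not a list."""

    def __init__(self, rows):
        self._rows = rows

    def __iter__(self):
        return iter(self._rows)


class Owner:
    def __init__(self):
        self.email = "a@example.com"
        self.items = LazyQuery(["item-1", "item-2"])


def getter(obj: Any, key: str, default: Any = None) -> Any:
    # Same rule as PeeweeGetterDict.get: materialize lazy queries into lists,
    # pass every other attribute value through unchanged.
    res = getattr(obj, key, default)
    if isinstance(res, LazyQuery):
        return list(res)
    return res


owner = Owner()
print(getter(owner, "items"))    # ['item-1', 'item-2']
print(getter(owner, "email"))    # a@example.com
print(getter(owner, "missing"))  # None
```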

Tip

We only need to create one PeeweeGetterDict class, and we can use it in all the Pydantic models / schemas.

CRUD utils

Now let's see the file sql_app/crud.py.

Create all the CRUD utils

Create all the same CRUD utils as in the SQLAlchemy tutorial; all the code is very similar:

from . import models, schemas


def get_user(user_id: int):
    return models.User.filter(models.User.id == user_id).first()


def get_user_by_email(email: str):
    return models.User.filter(models.User.email == email).first()


def get_users(skip: int = 0, limit: int = 100):
    return list(models.User.select().offset(skip).limit(limit))


def create_user(user: schemas.UserCreate):
    fake_hashed_password = user.password + "notreallyhashed"
    db_user = models.User(email=user.email, hashed_password=fake_hashed_password)
    db_user.save()
    return db_user


def get_items(skip: int = 0, limit: int = 100):
    return list(models.Item.select().offset(skip).limit(limit))


def create_user_item(item: schemas.ItemCreate, user_id: int):
    db_item = models.Item(**item.dict(), owner_id=user_id)
    db_item.save()
    return db_item

There are some differences from the code in the SQLAlchemy tutorial.

We don't pass a db attribute around. Instead, we use the models directly. This is because the db object is a global object that includes all the connection logic. That's why we had to do all the contextvars updates above.

Also, when returning several objects, like in get_users, we directly call list, like in:

list(models.User.select())

This is for the same reason that we had to create a custom PeeweeGetterDict. By returning something that is already a list instead of a peewee.ModelSelect, the response_model in the path operation with List[schemas.User] (that we'll see later) will work correctly.
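For reference, here is roughly the kind of query that get_users runs, written as a hypothetical sqlite3 version instead of Peewee. It is an approximation, not the literal SQL Peewee generates. The ORDER BY is added only to make the result deterministic, and the table has fewer columns than the real one. As in crud.get_users, list(...) turns the lazy cursor into a real list.

```python
import sqlite3


def get_users(conn, skip=0, limit=100):
    # Approximately the SQL behind models.User.select().offset(skip).limit(limit);
    # ORDER BY is added here only for a deterministic result.
    cursor = conn.execute(
        'SELECT "id", "email" FROM "user" ORDER BY "id" LIMIT ? OFFSET ?',
        (limit, skip),
    )
    # Materialize the rows, just as crud.get_users does with list(...).
    return list(cursor)


conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "user" ("id" INTEGER PRIMARY KEY, "email" TEXT)')
conn.executemany(
    'INSERT INTO "user" ("email") VALUES (?)',
    [(f"u{i}@example.com",) for i in range(5)],
)

page = get_users(conn, skip=2, limit=2)
print(page)  # [(3, 'u2@example.com'), (4, 'u3@example.com')]
```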

Main FastAPI app

And now in the file sql_app/main.py let's integrate and use all the other parts we created before.

Create the database tables

Create the database tables in a very simplistic way:

import time
from typing import List

from fastapi import Depends, FastAPI, HTTPException
from starlette.requests import Request

from . import crud, database, models, schemas
from .database import db_state_default

database.db.connect()
database.db.create_tables([models.User, models.Item])
database.db.close()

app = FastAPI()

sleep_time = 10


def get_db():
    try:
        database.db.connect()
        yield
    finally:
        if not database.db.is_closed():
            database.db.close()


@app.middleware("http")
async def reset_db_middleware(request: Request, call_next):
    database.db._state._state.set(db_state_default.copy())
    database.db._state.reset()
    response = await call_next(request)
    return response


@app.post("/users/", response_model=schemas.User, dependencies=[Depends(get_db)])
def create_user(user: schemas.UserCreate):
    db_user = crud.get_user_by_email(email=user.email)
    if db_user:
        raise HTTPException(status_code=400, detail="Email already registered")
    return crud.create_user(user=user)


@app.get("/users/", response_model=List[schemas.User], dependencies=[Depends(get_db)])
def read_users(skip: int = 0, limit: int = 100):
    users = crud.get_users(skip=skip, limit=limit)
    return users


@app.get(
    "/users/{user_id}", response_model=schemas.User, dependencies=[Depends(get_db)]
)
def read_user(user_id: int):
    db_user = crud.get_user(user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return db_user


@app.post(
    "/users/{user_id}/items/",
    response_model=schemas.Item,
    dependencies=[Depends(get_db)],
)
def create_item_for_user(user_id: int, item: schemas.ItemCreate):
    return crud.create_user_item(item=item, user_id=user_id)


@app.get("/items/", response_model=List[schemas.Item], dependencies=[Depends(get_db)])
def read_items(skip: int = 0, limit: int = 100):
    items = crud.get_items(skip=skip, limit=limit)
    return items


@app.get(
    "/slowusers/", response_model=List[schemas.User], dependencies=[Depends(get_db)]
)
def read_slow_users(skip: int = 0, limit: int = 100):
    global sleep_time
    sleep_time = max(0, sleep_time - 1)
    time.sleep(sleep_time)  # Fake long processing request
    users = crud.get_users(skip=skip, limit=limit)
    return users

Create a dependency

Create a dependency that will connect to the database right at the beginning of a request and disconnect from it at the end:

def get_db():
    try:
        database.db.connect()
        yield
    finally:
        if not database.db.is_closed():
            database.db.close()

Here we have an empty yield because we are actually not using the database object directly.

It is connecting to the database and storing the connection data in an internal variable that is independent for each request (using the contextvars tricks from above).
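Here is a minimal plain-Python sketch of that lifecycle, using a hypothetical FakeDatabase in place of the Peewee db object. contextlib.contextmanager drives get_db roughly the way FastAPI drives a yield dependency. The code before yield runs before the handler, and the finally block runs afterwards, even when the handler raises.

```python
from contextlib import contextmanager

events = []


class FakeDatabase:
    """Hypothetical stand-in for the Peewee `db` object."""

    def __init__(self):
        self.closed = True

    def connect(self):
        self.closed = False
        events.append("connect")

    def is_closed(self):
        return self.closed

    def close(self):
        self.closed = True
        events.append("close")


db = FakeDatabase()


def get_db():
    try:
        db.connect()
        yield  # empty yield: the dependency provides no value
    finally:
        if not db.is_closed():
            db.close()


def handle(handler):
    # Run the code before `yield`, call the handler, then always run the
    # code after `yield` (the finally block), even if the handler raises.
    with contextmanager(get_db)():
        return handler()


handle(lambda: events.append("handler"))
try:
    handle(lambda: 1 / 0)
except ZeroDivisionError:
    pass

print(events)  # ['connect', 'handler', 'close', 'connect', 'close']
```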

And then, in each path operation function that needs to access the database we add it as a dependency.

But we are not using the value given by this dependency (it actually doesn't give any value, as it has an empty yield). So, we don't add it to the path operation function but to the path operation decorator in the dependencies parameter:

@app.get("/users/", response_model=List[schemas.User], dependencies=[Depends(get_db)])
def read_users(skip: int = 0, limit: int = 100):
    users = crud.get_users(skip=skip, limit=limit)
    return users

Context Variable Middleware

For all the contextvars parts to work, we need to make sure there's a new "context" each time there's a new request, so that we have a specific context variable Peewee can use to save its state (database connection, transactions, etc).

For that, we need to create a middleware.

Right before the request, we are going to reset the database state. We will "set" a value to the context variable and then we will ask the Peewee database state to "reset" (this will create the default values it uses).

And then the rest of the request is processed with that new context variable we just set, all automatically and more or less "magically".

For the next request, as we will reset that context variable again in the middleware, that new request will have its own database state (connection, transactions, etc).
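You can see why the reset matters with the standard library alone. In this sketch, the ContextVar's default is a single dict shared by every context that has not set its own value, just like `db_state_default.copy()` in database.py. A simulated request that skips the reset writes into that shared dict and leaks its state. A request that first sets a fresh copy, as reset_db_middleware does, stays isolated.

```python
from contextvars import ContextVar, copy_context

state_default = {"conn": None}
# As in database.py, the default is ONE dict object, shared by every context
# that has not set a value of its own.
state_var = ContextVar("state_var", default=state_default.copy())


def request_without_reset(conn_name):
    # Mutates whatever dict the variable currently holds: the shared default.
    state_var.get()["conn"] = conn_name


def request_with_reset(conn_name):
    # Like reset_db_middleware: install a fresh dict for this context first.
    state_var.set(state_default.copy())
    state_var.get()["conn"] = conn_name


copy_context().run(request_without_reset, "conn-a")
leaked = state_var.get()["conn"]  # the outer context now sees "conn-a"

state_var.get()["conn"] = None  # undo the leak on the shared default dict
copy_context().run(request_with_reset, "conn-b")
isolated = state_var.get()["conn"]  # still None

print(leaked, isolated)  # conn-a None
```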

@app.middleware("http")
async def reset_db_middleware(request: Request, call_next):
    database.db._state._state.set(db_state_default.copy())
    database.db._state.reset()
    response = await call_next(request)
    return response

Tip

As FastAPI is an async framework, one request could start being processed, and before finishing, another request could be received and start processing as well, and it all could be processed in the same thread.

But context variables are aware of these async features, so, a Peewee database state set in the middleware will keep its own data throughout the entire request.

And at the same time, the other concurrent request will have its own database state that will be independent for the whole request.

Peewee Proxy

If you are using a Peewee Proxy, the actual database is at db.obj.

So, you would reset it with:

@app.middleware("http")
async def reset_db_middleware(request: Request, call_next):
    database.db.obj._state._state.set(db_state_default.copy())
    database.db.obj._state.reset()
    response = await call_next(request)
    return response

Create your FastAPI path operations

Now, finally, here's the standard FastAPI path operations code.

@app.post("/users/", response_model=schemas.User, dependencies=[Depends(get_db)])
def create_user(user: schemas.UserCreate):
    db_user = crud.get_user_by_email(email=user.email)
    if db_user:
        raise HTTPException(status_code=400, detail="Email already registered")
    return crud.create_user(user=user)


@app.get("/users/", response_model=List[schemas.User], dependencies=[Depends(get_db)])
def read_users(skip: int = 0, limit: int = 100):
    users = crud.get_users(skip=skip, limit=limit)
    return users


@app.get(
    "/users/{user_id}", response_model=schemas.User, dependencies=[Depends(get_db)]
)
def read_user(user_id: int):
    db_user = crud.get_user(user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return db_user


@app.post(
    "/users/{user_id}/items/",
    response_model=schemas.Item,
    dependencies=[Depends(get_db)],
)
def create_item_for_user(user_id: int, item: schemas.ItemCreate):
    return crud.create_user_item(item=item, user_id=user_id)


@app.get("/items/", response_model=List[schemas.Item], dependencies=[Depends(get_db)])
def read_items(skip: int = 0, limit: int = 100):
    items = crud.get_items(skip=skip, limit=limit)
    return items


@app.get(
    "/slowusers/", response_model=List[schemas.User], dependencies=[Depends(get_db)]
)
def read_slow_users(skip: int = 0, limit: int = 100):
    global sleep_time
    sleep_time = max(0, sleep_time - 1)
    time.sleep(sleep_time)  # Fake long processing request
    users = crud.get_users(skip=skip, limit=limit)
    return users

About def vs async def

The same as with SQLAlchemy, we are not doing something like:

user = await models.User.select().first()

...but instead we are using:

user = models.User.select().first()

So, again, we should declare the path operation functions and the dependency without async def, just with a normal def, as:
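Here is a minimal, timing-based sketch of why this matters. blocking_query is a hypothetical stand-in for a synchronous Peewee call. Called directly inside an async def handler, it blocks the event loop, so two concurrent requests run one after the other. Run in a worker thread (roughly what FastAPI does for a normal def path operation), the two overlap. The exact timings depend on your machine.

```python
import asyncio
import time


def blocking_query():
    time.sleep(0.3)  # stands in for a synchronous Peewee query


async def handler_async_def():
    blocking_query()  # blocks the event loop: nothing else can run meanwhile


async def handler_plain_def():
    loop = asyncio.get_running_loop()
    # Roughly what FastAPI does for a normal `def`: run it in a worker thread.
    await loop.run_in_executor(None, blocking_query)


async def timed(handler):
    start = time.perf_counter()
    await asyncio.gather(handler(), handler())  # two "concurrent requests"
    return time.perf_counter() - start


blocked = asyncio.run(timed(handler_async_def))
threaded = asyncio.run(timed(handler_plain_def))
print(f"async def: {blocked:.2f}s, def in threadpool: {threaded:.2f}s")
```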

# Something goes here
def read_users(skip: int = 0, limit: int = 100):
    # Something goes here
    ...

Testing Peewee with async

This example includes an extra path operation that simulates a long processing request with time.sleep(sleep_time).

It will have the database connection open at the beginning and will just wait some seconds before replying back. And each new request will wait one second less.

This lets you easily test that your app with Peewee and FastAPI behaves correctly with respect to threads.

If you want to check how Peewee would break your app if used without modification, go to the sql_app/database.py file and comment out the line:

# db._state = PeeweeConnectionState()

And in the sql_app/main.py file, comment out the middleware:

# @app.middleware("http")
# async def reset_db_middleware(request: Request, call_next):
#     database.db._state._state.set(db_state_default.copy())
#     database.db._state.reset()
#     response = await call_next(request)
#     return response

Then run your app with Uvicorn:

uvicorn sql_app.main:app --reload

Open your browser at http://127.0.0.1:8000/docs and create a couple of users.

Then open 10 tabs at http://127.0.0.1:8000/docs#/default/read_slow_users_slowusers__get at the same time.

Go to the path operation "Get /slowusers/" in all of the tabs. Use the "Try it out" button and execute the request in each tab, one right after the other.

The tabs will wait for a bit and then some of them will show Internal Server Error.

What happens

The first tab will make your app create a connection to the database and wait for some seconds before replying back and closing the connection.

Then, for the request in the next tab, your app will wait for one second less, and so on.

This means that some of the later tabs' requests will end up finishing before some of the earlier ones.

Then one of the last requests, which waits fewer seconds, will try to open a database connection. But one of the previous requests from the other tabs will probably be handled in the same thread as the first one, so it will see the same database connection that is already open. Peewee will throw an error, you will see it in the terminal, and the response will be an Internal Server Error.

This will probably happen for more than one of those tabs.

If you had multiple clients talking to your app exactly at the same time, this is what could happen.

And as your app starts to handle more and more clients at the same time, the waiting time in a single request needs to be shorter and shorter to trigger the error.
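A little arithmetic confirms the ordering described above. This sketch assumes, hypothetically, that the tabs are triggered half a second apart. It mirrors the sleep_time logic of read_slow_users: each request decrements the shared sleep time and then sleeps for it.

```python
def finish_order(n_requests=10, start_gap=0.5, initial_sleep=10):
    """Simulate /slowusers/: return request indices in the order they finish."""
    sleep_time = initial_sleep
    finishes = []
    for i in range(n_requests):
        sleep_time = max(0, sleep_time - 1)  # same update as read_slow_users
        start = i * start_gap  # assumed: tabs are clicked start_gap seconds apart
        finishes.append((start + sleep_time, i))
    return [i for _, i in sorted(finishes)]


print(finish_order())  # [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
```

While the last request is already finishing, the earlier ones are still open. That overlap is what exposes the shared threading.local state.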

Fix Peewee with FastAPI

Now go back to the sql_app/database.py file and uncomment the line:

db._state = PeeweeConnectionState()

And in the sql_app/main.py file, uncomment the middleware:

@app.middleware("http")
async def reset_db_middleware(request: Request, call_next):
    database.db._state._state.set(db_state_default.copy())
    database.db._state.reset()
    response = await call_next(request)
    return response

Terminate your running app and start it again.

Repeat the same process with the 10 tabs. This time all of them will wait and you will get all the results without errors.

...You fixed it!

Review all the files

Remember that you should have a directory named my_super_project (or whatever you want to call it) that contains a sub-directory called sql_app.

sql_app should have the following files:

  • sql_app/__init__.py: is an empty file.

  • sql_app/database.py:

from contextvars import ContextVar

import peewee

DATABASE_NAME = "test.db"
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
db_state = ContextVar("db_state", default=db_state_default.copy())


class PeeweeConnectionState(peewee._ConnectionState):
    def __init__(self, **kwargs):
        super().__setattr__("_state", db_state)
        super().__init__(**kwargs)

    def __setattr__(self, name, value):
        self._state.get()[name] = value

    def __getattr__(self, name):
        return self._state.get()[name]


db = peewee.SqliteDatabase(DATABASE_NAME, check_same_thread=False)

db._state = PeeweeConnectionState()

  • sql_app/models.py:

import peewee

from .database import db


class User(peewee.Model):
    email = peewee.CharField(unique=True, index=True)
    hashed_password = peewee.CharField()
    is_active = peewee.BooleanField(default=True)

    class Meta:
        database = db


class Item(peewee.Model):
    title = peewee.CharField(index=True)
    description = peewee.CharField(index=True)
    owner = peewee.ForeignKeyField(User, backref="items")

    class Meta:
        database = db

  • sql_app/schemas.py:

from typing import Any, List

import peewee
from pydantic import BaseModel
from pydantic.utils import GetterDict


class PeeweeGetterDict(GetterDict):
    def get(self, key: Any, default: Any = None):
        res = getattr(self._obj, key, default)
        if isinstance(res, peewee.ModelSelect):
            return list(res)
        return res


class ItemBase(BaseModel):
    title: str
    description: str = None


class ItemCreate(ItemBase):
    pass


class Item(ItemBase):
    id: int
    owner_id: int

    class Config:
        orm_mode = True
        getter_dict = PeeweeGetterDict


class UserBase(BaseModel):
    email: str


class UserCreate(UserBase):
    password: str


class User(UserBase):
    id: int
    is_active: bool
    items: List[Item] = []

    class Config:
        orm_mode = True
        getter_dict = PeeweeGetterDict

  • sql_app/crud.py:

from . import models, schemas


def get_user(user_id: int):
    return models.User.filter(models.User.id == user_id).first()


def get_user_by_email(email: str):
    return models.User.filter(models.User.email == email).first()


def get_users(skip: int = 0, limit: int = 100):
    return list(models.User.select().offset(skip).limit(limit))


def create_user(user: schemas.UserCreate):
    fake_hashed_password = user.password + "notreallyhashed"
    db_user = models.User(email=user.email, hashed_password=fake_hashed_password)
    db_user.save()
    return db_user


def get_items(skip: int = 0, limit: int = 100):
    return list(models.Item.select().offset(skip).limit(limit))


def create_user_item(item: schemas.ItemCreate, user_id: int):
    db_item = models.Item(**item.dict(), owner_id=user_id)
    db_item.save()
    return db_item

  • sql_app/main.py:

import time
from typing import List

from fastapi import Depends, FastAPI, HTTPException
from starlette.requests import Request

from . import crud, database, models, schemas
from .database import db_state_default

database.db.connect()
database.db.create_tables([models.User, models.Item])
database.db.close()

app = FastAPI()

sleep_time = 10


def get_db():
    try:
        database.db.connect()
        yield
    finally:
        if not database.db.is_closed():
            database.db.close()


@app.middleware("http")
async def reset_db_middleware(request: Request, call_next):
    database.db._state._state.set(db_state_default.copy())
    database.db._state.reset()
    response = await call_next(request)
    return response


@app.post("/users/", response_model=schemas.User, dependencies=[Depends(get_db)])
def create_user(user: schemas.UserCreate):
    db_user = crud.get_user_by_email(email=user.email)
    if db_user:
        raise HTTPException(status_code=400, detail="Email already registered")
    return crud.create_user(user=user)


@app.get("/users/", response_model=List[schemas.User], dependencies=[Depends(get_db)])
def read_users(skip: int = 0, limit: int = 100):
    users = crud.get_users(skip=skip, limit=limit)
    return users


@app.get(
    "/users/{user_id}", response_model=schemas.User, dependencies=[Depends(get_db)]
)
def read_user(user_id: int):
    db_user = crud.get_user(user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return db_user


@app.post(
    "/users/{user_id}/items/",
    response_model=schemas.Item,
    dependencies=[Depends(get_db)],
)
def create_item_for_user(user_id: int, item: schemas.ItemCreate):
    return crud.create_user_item(item=item, user_id=user_id)


@app.get("/items/", response_model=List[schemas.Item], dependencies=[Depends(get_db)])
def read_items(skip: int = 0, limit: int = 100):
    items = crud.get_items(skip=skip, limit=limit)
    return items


@app.get(
    "/slowusers/", response_model=List[schemas.User], dependencies=[Depends(get_db)]
)
def read_slow_users(skip: int = 0, limit: int = 100):
    global sleep_time
    sleep_time = max(0, sleep_time - 1)
    time.sleep(sleep_time)  # Fake long processing request
    users = crud.get_users(skip=skip, limit=limit)
    return users

Technical Details

Warning

These are very technical details that you probably don't need.

The problem

Peewee uses threading.local by default to store its database "state" data (connection, transactions, etc.).

threading.local creates a value exclusive to the current thread, but an async framework would run all the "tasks" (e.g. requests) in the same thread, and possibly not in order.
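A small asyncio sketch shows the difference. The names local_state, ctx_state and handle() are hypothetical: two tasks on the same thread each store a "connection" name in a threading.local and in a ContextVar, then read both back after yielding to the other task.

```python
import asyncio
import contextvars
import threading

local_state = threading.local()
ctx_state = contextvars.ContextVar("ctx_state", default=None)


async def handle(name, results):
    # Each "request" stores its connection name in both kinds of storage.
    local_state.conn = name
    ctx_state.set(name)
    await asyncio.sleep(0)  # yield so the other task runs in between
    results[name] = (local_state.conn, ctx_state.get())


async def main():
    results = {}
    # gather() wraps each coroutine in a Task with its own copy of the context.
    await asyncio.gather(handle("A", results), handle("B", results))
    return results


results = asyncio.run(main())
print(results)  # {'A': ('B', 'A'), 'B': ('B', 'B')}
```

Task A reads back task B's value from the threading.local, while its ContextVar still holds its own value.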

On top of that, an async framework could run some sync code in a threadpool (e.g. with loop.run_in_executor), but still as part of the same "task" (e.g. the same request).

This means that, with Peewee's current implementation, multiple tasks could end up using the same threading.local variable and sharing the same connection and data. At the same time, if they execute sync IO-blocking code in a threadpool (as with normal def functions in FastAPI, in path operations and dependencies), that code won't have access to the database state variables, even though it's part of the same "task" (request) and should be able to access them.
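The threadpool side can be sketched with the standard library only. In this hypothetical example, blocking_db_work() plays the role of a normal def path operation; it reads a threading.local value and a ContextVar that were set by the async code that scheduled it. The sketch assumes a regular CPython build, where new threads start with an empty context.

```python
import asyncio
import contextvars
import threading

local_state = threading.local()
ctx_state = contextvars.ContextVar("ctx_state", default=None)


def blocking_db_work():
    # Runs in a worker thread: threading.local data from the loop thread is absent.
    return getattr(local_state, "conn", None), ctx_state.get()


async def handle_request():
    local_state.conn = "conn-1"
    ctx_state.set("conn-1")
    loop = asyncio.get_running_loop()
    # Plain run_in_executor does not carry the current context over...
    plain = await loop.run_in_executor(None, blocking_db_work)
    # ...so copy it explicitly and run the function inside the copy.
    ctx = contextvars.copy_context()
    copied = await loop.run_in_executor(None, ctx.run, blocking_db_work)
    return plain, copied


plain, copied = asyncio.run(handle_request())
print(plain)   # (None, None)
print(copied)  # (None, 'conn-1')
```

The threading.local value never reaches the worker thread, while the ContextVar does once the function runs inside a copy of the context. asyncio.to_thread(), added in Python 3.9, does this copying for you.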

Context variables

Python 3.7 has contextvars that can create a local variable very similar to threading.local, but also supporting these async features.

There are several things to keep in mind.

The ContextVar has to be created at the top of the module, like some_var = ContextVar("some_var", default="default value").

To set a value used in the current "context" (e.g. for the current request) use some_var.set("new value").

To get a value anywhere inside of the context (e.g. in any part handling the current request) use some_var.get().
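The three steps above, as a minimal sketch using the same illustrative names:

```python
from contextvars import ContextVar

# Created once, at the top of the module.
some_var = ContextVar("some_var", default="default value")

print(some_var.get())  # default value
some_var.set("new value")
print(some_var.get())  # new value
```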

Set context variables in middleware

If some part of the async code sets the value with some_var.set("updated in function") (e.g. the middleware), the rest of the code in it will see that new value.

And if it then calls any other function with await some_function() (e.g. response = await call_next(request) in our middleware), that inner function and everything it calls will see the same new value "updated in function".

So, in our case, if we set the Peewee state variable in the middleware and then call response = await call_next(request) all the rest of the internal code in our app (that is called by call_next()) will see this value we set in the middleware and will be able to reuse it.
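This propagation can be checked with plain asyncio. Here middleware(), some_function() and deepest() are hypothetical stand-ins: some_function() plays the role of call_next(request), and it is awaited directly rather than wrapped in a new task.

```python
import asyncio
from contextvars import ContextVar

some_var = ContextVar("some_var", default="default value")


async def deepest():
    return some_var.get()


async def some_function():
    # Plays the role of call_next(request): awaited directly, same context.
    return await deepest()


async def middleware():
    some_var.set("updated in function")
    return await some_function()


seen = asyncio.run(middleware())
print(seen)            # updated in function
print(some_var.get())  # default value
```

The last line prints the default because asyncio.run() runs middleware() in a task with its own copy of the context, so the set() never reaches the caller.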

But if the value is set in an internal function (e.g. in get_db()) that value will be seen only by that internal function and any code it calls, not by the parent function nor by any sibling function. So, we can't set the Peewee database state in get_db(), or the path operation functions wouldn't see the new Peewee database state for that "context".
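A sketch of the parent/child/sibling rule. get_db_like() and path_operation_like() are hypothetical, and running each of them in its own copy of the parent context is a simplified model of how sync dependencies and path operations are sent to the threadpool.

```python
import contextvars

state = contextvars.ContextVar("state", default="parent value")


def get_db_like():
    # Hypothetical dependency that sets the variable inside its own context copy.
    state.set("set in get_db")
    return state.get()


def path_operation_like():
    # Hypothetical sibling that runs in another copy of the parent's context.
    return state.get()


seen_by_dependency = contextvars.copy_context().run(get_db_like)
seen_by_sibling = contextvars.copy_context().run(path_operation_like)
seen_by_parent = state.get()
print(seen_by_dependency)  # set in get_db
print(seen_by_sibling)     # parent value
print(seen_by_parent)      # parent value
```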

But get_db is an async context manager

You might be thinking that get_db() is not actually used as a plain function: FastAPI converts it to a context manager.

So the path operation function is part of it.

But the code after the yield, in the finally block, is not executed in the same "context".

So, if you reset the state in get_db(), the path operation function would see the database connection set there. But the finally block would not see the same context variable value. As the database object would not find the same state through its context variable, it would not have the same connection, so you couldn't close it in the finally block of get_db() after the request is done.

In the middleware, we set the Peewee state to a context variable that holds a dict, and we do it for every new request.

And as the database state variables are stored inside of that dict instead of in new context variables, when Peewee sets the new database state (connection, transactions, etc.) in any part of the internal code, all of it is stored underneath as keys in that dict. But the dict is still the same one we set in the middleware. That's what allows the get_db() dependency to make Peewee create a new connection (stored in that dict) and allows the finally block to still have access to that same connection.

This works because the context variable is set outside all of that, in the middleware.
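The dict trick can be checked with a few lines of standard-library code. In this sketch, open_by_mutation(), open_by_set(), read_conn_in_finally() and run_in_copy() are hypothetical helpers, and running each step in its own copy of the context is a simplified model of how the dependency's setup code and its finally block can end up in different contexts.

```python
import contextvars

db_state_default = {"closed": None, "conn": None}
db_state = contextvars.ContextVar("db_state", default=db_state_default.copy())


def open_by_mutation():
    # Like PeeweeConnectionState: write a key into the dict set by the middleware.
    db_state.get()["conn"] = "conn-1"


def open_by_set():
    # Hypothetical alternative: replace the value with a brand new dict.
    db_state.set({"closed": False, "conn": "conn-2"})


def read_conn_in_finally():
    # Stands in for the finally block of get_db(), run in another context copy.
    return db_state.get()["conn"]


def run_in_copy(func):
    # Each step runs in its own copy of the request context.
    return contextvars.copy_context().run(func)


# Middleware: a fresh dict for this request, then the inner steps.
db_state.set(db_state_default.copy())
run_in_copy(open_by_mutation)
after_mutation = run_in_copy(read_conn_in_finally)

# Same flow, but the inner step calls set() instead of mutating the dict.
db_state.set(db_state_default.copy())
run_in_copy(open_by_set)
after_set = run_in_copy(read_conn_in_finally)

print(after_mutation)  # conn-1
print(after_set)       # None
```

copy_context() makes a shallow copy: the copied context refers to the same dict object, so writes into it are shared, while a set() done inside the copy is not.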

Connect and disconnect in dependency

Then the next question would be, why not just connect and disconnect the database in the middleware itself, instead of get_db()?

First, the middleware has to be async, and creating and closing the database connection is potentially blocking, so it could degrade performance.

But more importantly, the middleware returns a response, and this response is actually an awaitable function that will do all the work in your code, including background tasks.

If you closed the connection in the middleware right before returning the response, some of your code would not have the chance to use the database connection set in the context variable.

That's because some other code will call that response with await response(...), and it is inside that await response(...) that, for example, background tasks are run. If the connection was already closed before the response is awaited, that code won't be able to use it.
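A toy model of this ordering problem, with no FastAPI involved. FakeConnection, build_response(), closing_middleware() and server() are all hypothetical; the response is a coroutine function that only touches the connection when it is awaited, the way background tasks do.

```python
import asyncio


class FakeConnection:
    """Hypothetical stand-in for a database connection."""

    def __init__(self):
        self.closed = False

    def query(self):
        if self.closed:
            raise RuntimeError("connection already closed")
        return "rows"


def build_response(conn):
    # Like an ASGI response, nothing runs until the caller awaits it.
    async def send_response():
        return conn.query()  # e.g. a background task that needs the DB

    return send_response


async def closing_middleware():
    conn = FakeConnection()
    response = build_response(conn)
    conn.closed = True  # closed "right before returning the response"
    return response


async def server():
    response = await closing_middleware()
    try:
        return await response()
    except RuntimeError as exc:
        return str(exc)


result = asyncio.run(server())
print(result)  # connection already closed
```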