from strawberry.extensions import Extension

from api.database import SessionLocal


class SQLAlchemySession(Extension):
    """Strawberry extension that keeps a database session in the context.

    ``on_request_start`` calls ``SessionLocal()`` and stores the result under
    the ``"db"`` key of the execution context; ``on_request_end`` closes the
    object stored under that key.

    NOTE(review): ``SessionLocal`` is presumably a SQLAlchemy session
    factory defined in ``api.database`` -- confirm there.
    """

    def on_request_start(self):
        """Create a new session and store it in the context as ``"db"``."""
        # Build the session first so a failing factory leaves the context
        # untouched.
        session = SessionLocal()
        request_context = self.execution_context.context
        request_context["db"] = session

    def on_request_end(self):
        """Close the object stored under the ``"db"`` context key."""
        request_context = self.execution_context.context
        session = request_context["db"]
        session.close()


# Extension classes exported by this module, as a one-element tuple.
# NOTE(review): presumably passed to the Strawberry schema's ``extensions``
# argument -- confirm at the import site.
extensions = (SQLAlchemySession,)