Added auth

This commit is contained in:
2022-05-29 00:31:32 +02:00
parent 36bb9eeefa
commit 23d514050d
19 changed files with 815 additions and 59 deletions

0
api/__init__.py Normal file
View File

14
api/database.py Normal file
View File

@ -0,0 +1,14 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
SQLALCHEMY_DATABASE_URL = "sqlite:///./db.sqlite3"
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}, future=True
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, future=True)
Base = declarative_base()

13
api/hasher.py Normal file
View File

@ -0,0 +1,13 @@
import argon2
class Argon2PasswordHasher(argon2.PasswordHasher):
def verify(self, password_hash: str | bytes, password: str | bytes) -> bool:
try:
super().verify(password_hash, password)
return True
except argon2.exceptions.VerifyMismatchError:
return False
argon2_hasher = Argon2PasswordHasher()

17
api/models.py Normal file
View File

@ -0,0 +1,17 @@
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import VARCHAR
from api.database import Base
from api.database import engine
class UserModel(Base):
__tablename__ = "user"
id = Column(Integer, primary_key=True, index=True, nullable=False)
name = Column(VARCHAR(length=32), nullable=False, unique=True)
password_hash = Column(VARCHAR, nullable=False)
Base.metadata.create_all(engine)

24
api/schema/__init__.py Normal file
View File

@ -0,0 +1,24 @@
import strawberry
from api.schema.definitions.auth import AuthResult
from api.schema.definitions.auth import login
from api.schema.definitions.auth import update_me
from api.schema.definitions.common import CommonMessage
from api.schema.extensions import extensions
from api.schema.permissions import IsAuthenticated
@strawberry.type
class Query:
hello: str
@strawberry.type
class Mutation:
login: AuthResult = strawberry.field(resolver=login)
update_me: CommonMessage = strawberry.field(
resolver=update_me, permission_classes=[IsAuthenticated]
)
schema = strawberry.Schema(query=Query, mutation=Mutation, extensions=extensions)

View File

@ -0,0 +1,83 @@
from typing import TYPE_CHECKING
import strawberry
from fastapi import Request
from sqlalchemy import true
from sqlalchemy.orm import Session
from api.hasher import argon2_hasher
from api.models import UserModel
from api.schema.definitions.common import CommonError
from api.schema.definitions.common import CommonMessage
from api.schema.definitions.user import User
from api.token import decode_user_token
from api.token import encode_user_token
from api.token import token_from_headers
if TYPE_CHECKING:
from strawberry.types import Info
from api.schema import Query
@strawberry.input
class LoginInput:
name: str
password: str
@strawberry.input
class UpdateUserInput:
name: str
password: str
@strawberry.type
class AuthSuccess:
user: User
token: str
AuthResult = strawberry.union("AuthResult", (AuthSuccess, CommonError))
LogoutResult = strawberry.union("LogoutResult", (CommonMessage, CommonError))
async def login(root: "Query", info: "Info", body: LoginInput) -> AuthResult:
db: Session = info.context["db"]
stmt = (
db.query(UserModel.id, UserModel.name, UserModel.password_hash)
.filter(UserModel.name == body.name)
.limit(1)
)
user: UserModel | None = db.execute(stmt).first() # type: ignore
if not user:
return CommonError(message="Invalid credentials")
if not argon2_hasher.verify(user.password_hash, body.password):
return CommonError(message="Invalid credentials")
if argon2_hasher.check_needs_rehash(user.password_hash):
user.password_hash = argon2_hasher.hash(body.password)
db.commit()
return AuthSuccess(user=User.from_instance(user), token=encode_user_token(user))
async def update_me(
root: "Query", info: "Info", body: UpdateUserInput
) -> CommonMessage:
db: Session = info.context["db"]
req: Request = info.context["request"]
_, auth_token = token_from_headers(req.headers)
token = decode_user_token(auth_token)
updated_user = {
UserModel.name: body.name,
UserModel.password_hash: argon2_hasher.hash(body.password),
}
db.query(UserModel).filter(UserModel.id == token["id"]).update(updated_user)
db.commit()
return CommonMessage(message="Succesfully updated credentials")

View File

@ -0,0 +1,11 @@
import strawberry
@strawberry.type
class CommonError:
message: str
@strawberry.type
class CommonMessage:
message: str

View File

@ -0,0 +1,13 @@
import strawberry
from api.models import UserModel
@strawberry.type
class User:
id: int
name: str
@classmethod
def from_instance(cls, instance: UserModel):
return cls(id=instance.id, name=instance.name)

14
api/schema/extensions.py Normal file
View File

@ -0,0 +1,14 @@
from strawberry.extensions import Extension
from api.database import SessionLocal
class SQLAlchemySession(Extension):
def on_request_start(self):
self.execution_context.context["db"] = SessionLocal()
def on_request_end(self):
self.execution_context.context["db"].close()
extensions = (SQLAlchemySession,)

27
api/schema/permissions.py Normal file
View File

@ -0,0 +1,27 @@
import typing
from fastapi import Request
from strawberry.permission import BasePermission
from api.token import decode_user_token
from api.token import token_from_headers
if typing.TYPE_CHECKING:
from strawberry.types import Info
class IsAuthenticated(BasePermission):
message = "User is not authenticated"
async def has_permission(self, source: typing.Any, info: "Info", **kwargs) -> bool:
req: Request = info.context["request"]
_, auth_token = token_from_headers(req.headers)
if len(auth_token) == 0:
return False
try:
decode_user_token(auth_token)
return True
except Exception:
return False

14
api/seed.py Normal file
View File

@ -0,0 +1,14 @@
from sqlalchemy.orm import Session
from api.database import SessionLocal
from api.hasher import argon2_hasher
from api.models import UserModel
def seed():
db: Session = SessionLocal()
if db.query(UserModel).count() == 0:
admin = UserModel(name="admin", password_hash=argon2_hasher.hash("admin"))
db.add(admin)
db.commit()

35
api/token.py Normal file
View File

@ -0,0 +1,35 @@
import os
import typing
from datetime import datetime
from datetime import timedelta
from datetime import timezone
import jwt
from fastapi import Request
from starlette.datastructures import Headers
from api.models import UserModel
if typing.TYPE_CHECKING:
UserTokenData = dict[str, typing.Any]
JWT_SECRET = os.getenv("JWT_SECRET", "")
def encode_user_token(
user: UserModel, expire_in: timedelta = timedelta(hours=6)
) -> str:
payload = {}
payload["id"] = user.id
payload["exp"] = datetime.now(tz=timezone.utc) + expire_in
return jwt.encode(payload, JWT_SECRET)
def decode_user_token(token: str) -> "UserTokenData":
return jwt.decode(token, JWT_SECRET, ["HS256"])
def token_from_headers(headers: Headers) -> tuple[str, str]:
auth_header: str = headers.get("authorization", "")
scheme, _, auth_token = auth_header.partition(" ")
return scheme, auth_token