Added auth

This commit is contained in:
2022-05-29 00:31:32 +02:00
parent 36bb9eeefa
commit 23d514050d
19 changed files with 815 additions and 59 deletions

24
api/schema/__init__.py Normal file
View File

@ -0,0 +1,24 @@
import strawberry
from api.schema.definitions.auth import AuthResult
from api.schema.definitions.auth import login
from api.schema.definitions.auth import update_me
from api.schema.definitions.common import CommonMessage
from api.schema.extensions import extensions
from api.schema.permissions import IsAuthenticated
@strawberry.type
class Query:
hello: str
@strawberry.type
class Mutation:
login: AuthResult = strawberry.field(resolver=login)
update_me: CommonMessage = strawberry.field(
resolver=update_me, permission_classes=[IsAuthenticated]
)
schema = strawberry.Schema(query=Query, mutation=Mutation, extensions=extensions)

View File

@ -0,0 +1,83 @@
from typing import TYPE_CHECKING
import strawberry
from fastapi import Request
from sqlalchemy import true
from sqlalchemy.orm import Session
from api.hasher import argon2_hasher
from api.models import UserModel
from api.schema.definitions.common import CommonError
from api.schema.definitions.common import CommonMessage
from api.schema.definitions.user import User
from api.token import decode_user_token
from api.token import encode_user_token
from api.token import token_from_headers
if TYPE_CHECKING:
from strawberry.types import Info
from api.schema import Query
@strawberry.input
class LoginInput:
name: str
password: str
@strawberry.input
class UpdateUserInput:
name: str
password: str
@strawberry.type
class AuthSuccess:
user: User
token: str
AuthResult = strawberry.union("AuthResult", (AuthSuccess, CommonError))
LogoutResult = strawberry.union("LogoutResult", (CommonMessage, CommonError))
async def login(root: "Query", info: "Info", body: LoginInput) -> AuthResult:
db: Session = info.context["db"]
stmt = (
db.query(UserModel.id, UserModel.name, UserModel.password_hash)
.filter(UserModel.name == body.name)
.limit(1)
)
user: UserModel | None = db.execute(stmt).first() # type: ignore
if not user:
return CommonError(message="Invalid credentials")
if not argon2_hasher.verify(user.password_hash, body.password):
return CommonError(message="Invalid credentials")
if argon2_hasher.check_needs_rehash(user.password_hash):
user.password_hash = argon2_hasher.hash(body.password)
db.commit()
return AuthSuccess(user=User.from_instance(user), token=encode_user_token(user))
async def update_me(
root: "Query", info: "Info", body: UpdateUserInput
) -> CommonMessage:
db: Session = info.context["db"]
req: Request = info.context["request"]
_, auth_token = token_from_headers(req.headers)
token = decode_user_token(auth_token)
updated_user = {
UserModel.name: body.name,
UserModel.password_hash: argon2_hasher.hash(body.password),
}
db.query(UserModel).filter(UserModel.id == token["id"]).update(updated_user)
db.commit()
return CommonMessage(message="Succesfully updated credentials")

View File

@ -0,0 +1,11 @@
import strawberry
@strawberry.type
class CommonError:
message: str
@strawberry.type
class CommonMessage:
message: str

View File

@ -0,0 +1,13 @@
import strawberry
from api.models import UserModel
@strawberry.type
class User:
id: int
name: str
@classmethod
def from_instance(cls, instance: UserModel):
return cls(id=instance.id, name=instance.name)

14
api/schema/extensions.py Normal file
View File

@ -0,0 +1,14 @@
from strawberry.extensions import Extension
from api.database import SessionLocal
class SQLAlchemySession(Extension):
def on_request_start(self):
self.execution_context.context["db"] = SessionLocal()
def on_request_end(self):
self.execution_context.context["db"].close()
extensions = (SQLAlchemySession,)

27
api/schema/permissions.py Normal file
View File

@ -0,0 +1,27 @@
import typing
from fastapi import Request
from strawberry.permission import BasePermission
from api.token import decode_user_token
from api.token import token_from_headers
if typing.TYPE_CHECKING:
from strawberry.types import Info
class IsAuthenticated(BasePermission):
message = "User is not authenticated"
async def has_permission(self, source: typing.Any, info: "Info", **kwargs) -> bool:
req: Request = info.context["request"]
_, auth_token = token_from_headers(req.headers)
if len(auth_token) == 0:
return False
try:
decode_user_token(auth_token)
return True
except Exception:
return False