83 lines
2.2 KiB
Python
83 lines
2.2 KiB
Python
from typing import TYPE_CHECKING
|
|
|
|
import strawberry
|
|
from fastapi import Request
|
|
from sqlalchemy.orm import Session
|
|
|
|
from api.hasher import argon2_hasher
|
|
from api.models import User as UserModel
|
|
from api.schema.definitions.common import CommonError
|
|
from api.schema.definitions.common import CommonMessage
|
|
from api.schema.definitions.user import User
|
|
from api.token import decode_user_token
|
|
from api.token import encode_user_token
|
|
from api.token import token_from_headers
|
|
|
|
if TYPE_CHECKING:
|
|
from strawberry.types import Info
|
|
from api.schema import Query
|
|
|
|
|
|
@strawberry.input
class LoginInput:
    """Credentials accepted by the ``login`` resolver."""

    # Looked up by equality against ``UserModel.name``.
    name: str
    # Checked with ``argon2_hasher.verify`` against the stored password hash.
    password: str
|
|
|
|
|
|
@strawberry.input
class UpdateUserInput:
    """New credentials accepted by the ``update_me`` resolver."""

    # Written to ``UserModel.name`` for the user identified by the request token.
    name: str
    # Hashed with ``argon2_hasher.hash`` before being stored.
    password: str
|
|
|
|
|
|
@strawberry.type
class AuthSuccess:
    """Successful-login payload returned by ``login``."""

    # The authenticated user, built with ``User.from_instance``.
    user: User
    # Token produced by ``encode_user_token`` for that user.
    token: str
|
|
|
|
|
|
# Result of ``login``: either the success payload or a ``CommonError``.
AuthResult = strawberry.union(
    name="AuthResult",
    types=(AuthSuccess, CommonError),
)
# Message-or-error result union; not referenced by the resolvers in this file.
LogoutResult = strawberry.union(
    name="LogoutResult",
    types=(CommonMessage, CommonError),
)
|
|
|
|
|
|
async def login(root: "Query", info: "Info", body: LoginInput) -> AuthResult:
    """Authenticate a user by name and password.

    Looks up the user whose ``name`` equals ``body.name`` and verifies
    ``body.password`` against the stored hash with ``argon2_hasher``. When
    ``argon2_hasher.check_needs_rehash`` reports the stored hash as outdated,
    the password is re-hashed and the new hash is written back.

    Args:
        root: Parent value supplied by strawberry (unused).
        info: Resolver info; ``info.context["db"]`` holds the SQLAlchemy session.
        body: The submitted credentials.

    Returns:
        ``AuthSuccess`` carrying the user and a token from
        ``encode_user_token``. On failure, returns
        ``CommonError(message="Invalid credentials")``, whether the name is
        unknown or the password does not verify. Both failure paths use the
        same message, so the response does not reveal which check failed.
    """
    db: Session = info.context["db"]

    # Query.first() applies LIMIT 1 itself. Because only three columns are
    # selected, the result is a read-only Row (or None), not a UserModel
    # instance.
    user = (
        db.query(UserModel.id, UserModel.name, UserModel.password_hash)
        .filter(UserModel.name == body.name)
        .first()
    )
    if user is None:
        return CommonError(message="Invalid credentials")

    # NOTE(review): if ``argon2_hasher`` is argon2-cffi's ``PasswordHasher``,
    # ``verify`` raises ``VerifyMismatchError`` on a wrong password instead of
    # returning False. In that case this branch never runs and the exception
    # surfaces to the client. Confirm, and if so catch
    # ``argon2.exceptions.VerificationError`` / ``InvalidHashError`` here.
    if not argon2_hasher.verify(user.password_hash, body.password):
        return CommonError(message="Invalid credentials")

    if argon2_hasher.check_needs_rehash(user.password_hash):
        # ``user`` is a Row, which rejects attribute assignment and is not
        # tracked by the session, so the new hash must be persisted with an
        # explicit UPDATE.
        db.query(UserModel).filter(UserModel.id == user.id).update(
            {UserModel.password_hash: argon2_hasher.hash(body.password)}
        )
        db.commit()

    return AuthSuccess(user=User.from_instance(user), token=encode_user_token(user))
|
|
|
|
|
|
async def update_me(
    root: "Query", info: "Info", body: UpdateUserInput
) -> CommonMessage:
    """Replace the name and password of the user identified by the request token.

    The token is taken from the request headers with ``token_from_headers``
    and decoded with ``decode_user_token``. The ``UserModel`` row whose id
    equals ``token["id"]`` then receives ``body.name`` and a fresh
    ``argon2_hasher`` hash of ``body.password``.

    Args:
        root: Parent value supplied by strawberry (unused).
        info: Resolver info; ``info.context`` holds the SQLAlchemy session
            under ``"db"`` and the incoming request under ``"request"``.
        body: The new credentials.

    Returns:
        A ``CommonMessage`` confirming the update.
    """
    db: Session = info.context["db"]
    req: Request = info.context["request"]

    # token_from_headers yields two values; only the second (the token) is used.
    _, auth_token = token_from_headers(req.headers)
    token = decode_user_token(auth_token)

    # NOTE(review): nothing here handles a missing or invalid token, nor an id
    # that matches no row (the UPDATE would then touch 0 rows and success is
    # still reported). Confirm how token_from_headers / decode_user_token
    # behave on failure.
    updated_user = {
        UserModel.name: body.name,
        UserModel.password_hash: argon2_hasher.hash(body.password),
    }
    db.query(UserModel).filter(UserModel.id == token["id"]).update(updated_user)
    db.commit()

    return CommonMessage(message="Successfully updated credentials")
|