83 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			83 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from typing import TYPE_CHECKING
 | |
| 
 | |
| import strawberry
 | |
| from fastapi import Request
 | |
| from sqlalchemy.orm import Session
 | |
| 
 | |
| from api.hasher import argon2_hasher
 | |
| from api.models import User as UserModel
 | |
| from api.schema.definitions.common import CommonError
 | |
| from api.schema.definitions.common import CommonMessage
 | |
| from api.schema.definitions.user import User
 | |
| from api.token import decode_user_token
 | |
| from api.token import encode_user_token
 | |
| from api.token import token_from_headers
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     from strawberry.types import Info
 | |
|     from api.schema import Query
 | |
| 
 | |
| 
 | |
@strawberry.input
class LoginInput:
    """Credentials submitted to the ``login`` resolver."""

    name: str  # account name; login matches it exactly against UserModel.name
    password: str  # plaintext password; login verifies it against the stored hash
 | |
| 
 | |
| 
 | |
@strawberry.input
class UpdateUserInput:
    """New credentials for the authenticated user, consumed by ``update_me``."""

    name: str  # replaces UserModel.name for the user identified by the token
    password: str  # plaintext; update_me hashes it with argon2_hasher before storing
 | |
| 
 | |
| 
 | |
@strawberry.type
class AuthSuccess:
    """Successful ``login`` result: the authenticated user plus an encoded token."""

    user: User  # built by login via User.from_instance
    token: str  # produced by login via encode_user_token
 | |
| 
 | |
| 
 | |
# GraphQL result unions: each pairs a success payload with the shared
# CommonError type, so resolvers return errors as data instead of raising.
AuthResult = strawberry.union("AuthResult", types=(AuthSuccess, CommonError))
LogoutResult = strawberry.union("LogoutResult", types=(CommonMessage, CommonError))
 | |
| 
 | |
| 
 | |
async def login(root: "Query", info: "Info", body: LoginInput) -> AuthResult:
    """Authenticate a user by name and password.

    Looks up the user whose name equals ``body.name``, verifies
    ``body.password`` against the stored hash and, on success, returns the
    user together with a freshly encoded token. If the stored hash was made
    with outdated parameters, it is re-hashed and persisted.

    Args:
        root: Parent GraphQL object (unused).
        info: Strawberry resolver info; ``info.context["db"]`` must hold the
            SQLAlchemy session for this request.
        body: The submitted credentials.

    Returns:
        ``AuthSuccess`` on valid credentials. Otherwise a ``CommonError`` with
        the same generic message for unknown users and wrong passwords, so the
        response does not reveal which of the two was wrong.
    """
    db: Session = info.context["db"]

    stmt = (
        db.query(UserModel.id, UserModel.name, UserModel.password_hash)
        .filter(UserModel.name == body.name)
        .limit(1)
    )

    # This is a column-only query, so the result is a read-only Row, not a
    # session-tracked UserModel instance.
    user = db.execute(stmt).first()  # type: ignore
    if user is None:
        return CommonError(message="Invalid credentials")

    # NOTE(review): argon2-cffi's PasswordHasher.verify raises
    # (VerifyMismatchError / InvalidHashError) instead of returning False; this
    # assumes argon2_hasher is such an instance. Narrow this except clause to
    # argon2.exceptions.VerificationError and InvalidHashError once they are
    # imported in this module.
    try:
        verified = argon2_hasher.verify(user.password_hash, body.password)
    except Exception:
        verified = False
    if not verified:
        return CommonError(message="Invalid credentials")

    if argon2_hasher.check_needs_rehash(user.password_hash):
        # The Row cannot be assigned to and is not tracked by the session, so
        # the upgraded hash must be written with an explicit UPDATE.
        db.query(UserModel).filter(UserModel.id == user.id).update(
            {UserModel.password_hash: argon2_hasher.hash(body.password)}
        )
        db.commit()

    return AuthSuccess(user=User.from_instance(user), token=encode_user_token(user))
 | |
| 
 | |
| 
 | |
async def update_me(
    root: "Query", info: "Info", body: UpdateUserInput
) -> CommonMessage:
    """Replace the authenticated user's name and password.

    The user is identified by the ``"id"`` claim of the token read from the
    request headers. The new password is hashed with ``argon2_hasher`` before
    it is stored.

    Args:
        root: Parent GraphQL object (unused).
        info: Strawberry resolver info; ``info.context`` must provide the
            SQLAlchemy session under ``"db"`` and the FastAPI request under
            ``"request"``.
        body: The new name and plaintext password.

    Returns:
        A ``CommonMessage`` confirming the update.
    """
    db: Session = info.context["db"]
    req: Request = info.context["request"]

    # NOTE(review): a missing or invalid token is not handled here; presumably
    # token_from_headers / decode_user_token raise in that case. Confirm.
    _, auth_token = token_from_headers(req.headers)
    token = decode_user_token(auth_token)

    updated_user = {
        UserModel.name: body.name,
        UserModel.password_hash: argon2_hasher.hash(body.password),
    }
    # NOTE(review): the affected row count is not checked, so this reports
    # success even when no user matches token["id"].
    db.query(UserModel).filter(UserModel.id == token["id"]).update(updated_user)
    db.commit()

    return CommonMessage(message="Successfully updated credentials")
 |