Added auth
This commit is contained in:
		
							
								
								
									
										0
									
								
								api/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								api/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										14
									
								
								api/database.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								api/database.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | ||||
| from sqlalchemy import create_engine | ||||
| from sqlalchemy.ext.declarative import declarative_base | ||||
| from sqlalchemy.orm import sessionmaker | ||||
|  | ||||
|  | ||||
| SQLALCHEMY_DATABASE_URL = "sqlite:///./db.sqlite3" | ||||
|  | ||||
| engine = create_engine( | ||||
|     SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}, future=True | ||||
| ) | ||||
|  | ||||
| SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, future=True) | ||||
|  | ||||
| Base = declarative_base() | ||||
							
								
								
									
										13
									
								
								api/hasher.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								api/hasher.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | ||||
| import argon2 | ||||
|  | ||||
|  | ||||
| class Argon2PasswordHasher(argon2.PasswordHasher): | ||||
|     def verify(self, password_hash: str | bytes, password: str | bytes) -> bool: | ||||
|         try: | ||||
|             super().verify(password_hash, password) | ||||
|             return True | ||||
|         except argon2.exceptions.VerifyMismatchError: | ||||
|             return False | ||||
|  | ||||
|  | ||||
| argon2_hasher = Argon2PasswordHasher() | ||||
							
								
								
									
										17
									
								
								api/models.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								api/models.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,17 @@ | ||||
| from sqlalchemy import Column | ||||
| from sqlalchemy import Integer | ||||
| from sqlalchemy import VARCHAR | ||||
|  | ||||
| from api.database import Base | ||||
| from api.database import engine | ||||
|  | ||||
|  | ||||
| class UserModel(Base): | ||||
|     __tablename__ = "user" | ||||
|  | ||||
|     id = Column(Integer, primary_key=True, index=True, nullable=False) | ||||
|     name = Column(VARCHAR(length=32), nullable=False, unique=True) | ||||
|     password_hash = Column(VARCHAR, nullable=False) | ||||
|  | ||||
|  | ||||
| Base.metadata.create_all(engine) | ||||
							
								
								
									
										24
									
								
								api/schema/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								api/schema/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,24 @@ | ||||
| import strawberry | ||||
|  | ||||
| from api.schema.definitions.auth import AuthResult | ||||
| from api.schema.definitions.auth import login | ||||
| from api.schema.definitions.auth import update_me | ||||
| from api.schema.definitions.common import CommonMessage | ||||
| from api.schema.extensions import extensions | ||||
| from api.schema.permissions import IsAuthenticated | ||||
|  | ||||
|  | ||||
| @strawberry.type | ||||
| class Query: | ||||
|     hello: str | ||||
|  | ||||
|  | ||||
| @strawberry.type | ||||
| class Mutation: | ||||
|     login: AuthResult = strawberry.field(resolver=login) | ||||
|     update_me: CommonMessage = strawberry.field( | ||||
|         resolver=update_me, permission_classes=[IsAuthenticated] | ||||
|     ) | ||||
|  | ||||
|  | ||||
| schema = strawberry.Schema(query=Query, mutation=Mutation, extensions=extensions) | ||||
							
								
								
									
										83
									
								
								api/schema/definitions/auth.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								api/schema/definitions/auth.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,83 @@ | ||||
| from typing import TYPE_CHECKING | ||||
|  | ||||
| import strawberry | ||||
| from fastapi import Request | ||||
| from sqlalchemy import true | ||||
| from sqlalchemy.orm import Session | ||||
|  | ||||
| from api.hasher import argon2_hasher | ||||
| from api.models import UserModel | ||||
| from api.schema.definitions.common import CommonError | ||||
| from api.schema.definitions.common import CommonMessage | ||||
| from api.schema.definitions.user import User | ||||
| from api.token import decode_user_token | ||||
| from api.token import encode_user_token | ||||
| from api.token import token_from_headers | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from strawberry.types import Info | ||||
|     from api.schema import Query | ||||
|  | ||||
|  | ||||
| @strawberry.input | ||||
| class LoginInput: | ||||
|     name: str | ||||
|     password: str | ||||
|  | ||||
|  | ||||
| @strawberry.input | ||||
| class UpdateUserInput: | ||||
|     name: str | ||||
|     password: str | ||||
|  | ||||
|  | ||||
| @strawberry.type | ||||
| class AuthSuccess: | ||||
|     user: User | ||||
|     token: str | ||||
|  | ||||
|  | ||||
| AuthResult = strawberry.union("AuthResult", (AuthSuccess, CommonError)) | ||||
| LogoutResult = strawberry.union("LogoutResult", (CommonMessage, CommonError)) | ||||
|  | ||||
|  | ||||
| async def login(root: "Query", info: "Info", body: LoginInput) -> AuthResult: | ||||
|     db: Session = info.context["db"] | ||||
|  | ||||
|     stmt = ( | ||||
|         db.query(UserModel.id, UserModel.name, UserModel.password_hash) | ||||
|         .filter(UserModel.name == body.name) | ||||
|         .limit(1) | ||||
|     ) | ||||
|  | ||||
|     user: UserModel | None = db.execute(stmt).first()  # type: ignore | ||||
|     if not user: | ||||
|         return CommonError(message="Invalid credentials") | ||||
|  | ||||
|     if not argon2_hasher.verify(user.password_hash, body.password): | ||||
|         return CommonError(message="Invalid credentials") | ||||
|  | ||||
|     if argon2_hasher.check_needs_rehash(user.password_hash): | ||||
|         user.password_hash = argon2_hasher.hash(body.password) | ||||
|         db.commit() | ||||
|  | ||||
|     return AuthSuccess(user=User.from_instance(user), token=encode_user_token(user)) | ||||
|  | ||||
|  | ||||
| async def update_me( | ||||
|     root: "Query", info: "Info", body: UpdateUserInput | ||||
| ) -> CommonMessage: | ||||
|     db: Session = info.context["db"] | ||||
|     req: Request = info.context["request"] | ||||
|  | ||||
|     _, auth_token = token_from_headers(req.headers) | ||||
|     token = decode_user_token(auth_token) | ||||
|  | ||||
|     updated_user = { | ||||
|         UserModel.name: body.name, | ||||
|         UserModel.password_hash: argon2_hasher.hash(body.password), | ||||
|     } | ||||
|     db.query(UserModel).filter(UserModel.id == token["id"]).update(updated_user) | ||||
|     db.commit() | ||||
|  | ||||
|     return CommonMessage(message="Succesfully updated credentials") | ||||
							
								
								
									
										11
									
								
								api/schema/definitions/common.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								api/schema/definitions/common.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,11 @@ | ||||
| import strawberry | ||||
|  | ||||
|  | ||||
| @strawberry.type | ||||
| class CommonError: | ||||
|     message: str | ||||
|  | ||||
|  | ||||
| @strawberry.type | ||||
| class CommonMessage: | ||||
|     message: str | ||||
							
								
								
									
										13
									
								
								api/schema/definitions/user.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								api/schema/definitions/user.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | ||||
| import strawberry | ||||
|  | ||||
| from api.models import UserModel | ||||
|  | ||||
|  | ||||
| @strawberry.type | ||||
| class User: | ||||
|     id: int | ||||
|     name: str | ||||
|  | ||||
|     @classmethod | ||||
|     def from_instance(cls, instance: UserModel): | ||||
|         return cls(id=instance.id, name=instance.name) | ||||
							
								
								
									
										14
									
								
								api/schema/extensions.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								api/schema/extensions.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | ||||
| from strawberry.extensions import Extension | ||||
|  | ||||
| from api.database import SessionLocal | ||||
|  | ||||
|  | ||||
| class SQLAlchemySession(Extension): | ||||
|     def on_request_start(self): | ||||
|         self.execution_context.context["db"] = SessionLocal() | ||||
|  | ||||
|     def on_request_end(self): | ||||
|         self.execution_context.context["db"].close() | ||||
|  | ||||
|  | ||||
| extensions = (SQLAlchemySession,) | ||||
							
								
								
									
										27
									
								
								api/schema/permissions.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								api/schema/permissions.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,27 @@ | ||||
| import typing | ||||
|  | ||||
| from fastapi import Request | ||||
| from strawberry.permission import BasePermission | ||||
|  | ||||
| from api.token import decode_user_token | ||||
| from api.token import token_from_headers | ||||
|  | ||||
| if typing.TYPE_CHECKING: | ||||
|     from strawberry.types import Info | ||||
|  | ||||
|  | ||||
| class IsAuthenticated(BasePermission): | ||||
|     message = "User is not authenticated" | ||||
|  | ||||
|     async def has_permission(self, source: typing.Any, info: "Info", **kwargs) -> bool: | ||||
|         req: Request = info.context["request"] | ||||
|         _, auth_token = token_from_headers(req.headers) | ||||
|  | ||||
|         if len(auth_token) == 0: | ||||
|             return False | ||||
|  | ||||
|         try: | ||||
|             decode_user_token(auth_token) | ||||
|             return True | ||||
|         except Exception: | ||||
|             return False | ||||
							
								
								
									
										14
									
								
								api/seed.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								api/seed.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | ||||
| from sqlalchemy.orm import Session | ||||
|  | ||||
| from api.database import SessionLocal | ||||
| from api.hasher import argon2_hasher | ||||
| from api.models import UserModel | ||||
|  | ||||
|  | ||||
| def seed(): | ||||
|     db: Session = SessionLocal() | ||||
|  | ||||
|     if db.query(UserModel).count() == 0: | ||||
|         admin = UserModel(name="admin", password_hash=argon2_hasher.hash("admin")) | ||||
|         db.add(admin) | ||||
|         db.commit() | ||||
							
								
								
									
										35
									
								
								api/token.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								api/token.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,35 @@ | ||||
| import os | ||||
| import typing | ||||
| from datetime import datetime | ||||
| from datetime import timedelta | ||||
| from datetime import timezone | ||||
|  | ||||
| import jwt | ||||
| from fastapi import Request | ||||
| from starlette.datastructures import Headers | ||||
|  | ||||
| from api.models import UserModel | ||||
|  | ||||
| if typing.TYPE_CHECKING: | ||||
|     UserTokenData = dict[str, typing.Any] | ||||
|  | ||||
| JWT_SECRET = os.getenv("JWT_SECRET", "") | ||||
|  | ||||
|  | ||||
| def encode_user_token( | ||||
|     user: UserModel, expire_in: timedelta = timedelta(hours=6) | ||||
| ) -> str: | ||||
|     payload = {} | ||||
|     payload["id"] = user.id | ||||
|     payload["exp"] = datetime.now(tz=timezone.utc) + expire_in | ||||
|     return jwt.encode(payload, JWT_SECRET) | ||||
|  | ||||
|  | ||||
| def decode_user_token(token: str) -> "UserTokenData": | ||||
|     return jwt.decode(token, JWT_SECRET, ["HS256"]) | ||||
|  | ||||
|  | ||||
| def token_from_headers(headers: Headers) -> tuple[str, str]: | ||||
|     auth_header: str = headers.get("authorization", "") | ||||
|     scheme, _, auth_token = auth_header.partition(" ") | ||||
|     return scheme, auth_token | ||||
		Reference in New Issue
	
	Block a user