Added auth
This commit is contained in:
0
api/__init__.py
Normal file
0
api/__init__.py
Normal file
14
api/database.py
Normal file
14
api/database.py
Normal file
@ -0,0 +1,14 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = "sqlite:///./db.sqlite3"
|
||||
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}, future=True
|
||||
)
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, future=True)
|
||||
|
||||
Base = declarative_base()
|
13
api/hasher.py
Normal file
13
api/hasher.py
Normal file
@ -0,0 +1,13 @@
|
||||
import argon2
|
||||
|
||||
|
||||
class Argon2PasswordHasher(argon2.PasswordHasher):
|
||||
def verify(self, password_hash: str | bytes, password: str | bytes) -> bool:
|
||||
try:
|
||||
super().verify(password_hash, password)
|
||||
return True
|
||||
except argon2.exceptions.VerifyMismatchError:
|
||||
return False
|
||||
|
||||
|
||||
argon2_hasher = Argon2PasswordHasher()
|
17
api/models.py
Normal file
17
api/models.py
Normal file
@ -0,0 +1,17 @@
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import VARCHAR
|
||||
|
||||
from api.database import Base
|
||||
from api.database import engine
|
||||
|
||||
|
||||
class UserModel(Base):
|
||||
__tablename__ = "user"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True, nullable=False)
|
||||
name = Column(VARCHAR(length=32), nullable=False, unique=True)
|
||||
password_hash = Column(VARCHAR, nullable=False)
|
||||
|
||||
|
||||
Base.metadata.create_all(engine)
|
24
api/schema/__init__.py
Normal file
24
api/schema/__init__.py
Normal file
@ -0,0 +1,24 @@
|
||||
import strawberry
|
||||
|
||||
from api.schema.definitions.auth import AuthResult
|
||||
from api.schema.definitions.auth import login
|
||||
from api.schema.definitions.auth import update_me
|
||||
from api.schema.definitions.common import CommonMessage
|
||||
from api.schema.extensions import extensions
|
||||
from api.schema.permissions import IsAuthenticated
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class Query:
|
||||
hello: str
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class Mutation:
|
||||
login: AuthResult = strawberry.field(resolver=login)
|
||||
update_me: CommonMessage = strawberry.field(
|
||||
resolver=update_me, permission_classes=[IsAuthenticated]
|
||||
)
|
||||
|
||||
|
||||
schema = strawberry.Schema(query=Query, mutation=Mutation, extensions=extensions)
|
83
api/schema/definitions/auth.py
Normal file
83
api/schema/definitions/auth.py
Normal file
@ -0,0 +1,83 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import strawberry
|
||||
from fastapi import Request
|
||||
from sqlalchemy import true
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from api.hasher import argon2_hasher
|
||||
from api.models import UserModel
|
||||
from api.schema.definitions.common import CommonError
|
||||
from api.schema.definitions.common import CommonMessage
|
||||
from api.schema.definitions.user import User
|
||||
from api.token import decode_user_token
|
||||
from api.token import encode_user_token
|
||||
from api.token import token_from_headers
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from strawberry.types import Info
|
||||
from api.schema import Query
|
||||
|
||||
|
||||
@strawberry.input
|
||||
class LoginInput:
|
||||
name: str
|
||||
password: str
|
||||
|
||||
|
||||
@strawberry.input
|
||||
class UpdateUserInput:
|
||||
name: str
|
||||
password: str
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class AuthSuccess:
|
||||
user: User
|
||||
token: str
|
||||
|
||||
|
||||
AuthResult = strawberry.union("AuthResult", (AuthSuccess, CommonError))
|
||||
LogoutResult = strawberry.union("LogoutResult", (CommonMessage, CommonError))
|
||||
|
||||
|
||||
async def login(root: "Query", info: "Info", body: LoginInput) -> AuthResult:
|
||||
db: Session = info.context["db"]
|
||||
|
||||
stmt = (
|
||||
db.query(UserModel.id, UserModel.name, UserModel.password_hash)
|
||||
.filter(UserModel.name == body.name)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
user: UserModel | None = db.execute(stmt).first() # type: ignore
|
||||
if not user:
|
||||
return CommonError(message="Invalid credentials")
|
||||
|
||||
if not argon2_hasher.verify(user.password_hash, body.password):
|
||||
return CommonError(message="Invalid credentials")
|
||||
|
||||
if argon2_hasher.check_needs_rehash(user.password_hash):
|
||||
user.password_hash = argon2_hasher.hash(body.password)
|
||||
db.commit()
|
||||
|
||||
return AuthSuccess(user=User.from_instance(user), token=encode_user_token(user))
|
||||
|
||||
|
||||
async def update_me(
|
||||
root: "Query", info: "Info", body: UpdateUserInput
|
||||
) -> CommonMessage:
|
||||
db: Session = info.context["db"]
|
||||
req: Request = info.context["request"]
|
||||
|
||||
_, auth_token = token_from_headers(req.headers)
|
||||
token = decode_user_token(auth_token)
|
||||
|
||||
updated_user = {
|
||||
UserModel.name: body.name,
|
||||
UserModel.password_hash: argon2_hasher.hash(body.password),
|
||||
}
|
||||
db.query(UserModel).filter(UserModel.id == token["id"]).update(updated_user)
|
||||
db.commit()
|
||||
|
||||
return CommonMessage(message="Succesfully updated credentials")
|
11
api/schema/definitions/common.py
Normal file
11
api/schema/definitions/common.py
Normal file
@ -0,0 +1,11 @@
|
||||
import strawberry
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class CommonError:
|
||||
message: str
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class CommonMessage:
|
||||
message: str
|
13
api/schema/definitions/user.py
Normal file
13
api/schema/definitions/user.py
Normal file
@ -0,0 +1,13 @@
|
||||
import strawberry
|
||||
|
||||
from api.models import UserModel
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class User:
|
||||
id: int
|
||||
name: str
|
||||
|
||||
@classmethod
|
||||
def from_instance(cls, instance: UserModel):
|
||||
return cls(id=instance.id, name=instance.name)
|
14
api/schema/extensions.py
Normal file
14
api/schema/extensions.py
Normal file
@ -0,0 +1,14 @@
|
||||
from strawberry.extensions import Extension
|
||||
|
||||
from api.database import SessionLocal
|
||||
|
||||
|
||||
class SQLAlchemySession(Extension):
|
||||
def on_request_start(self):
|
||||
self.execution_context.context["db"] = SessionLocal()
|
||||
|
||||
def on_request_end(self):
|
||||
self.execution_context.context["db"].close()
|
||||
|
||||
|
||||
extensions = (SQLAlchemySession,)
|
27
api/schema/permissions.py
Normal file
27
api/schema/permissions.py
Normal file
@ -0,0 +1,27 @@
|
||||
import typing
|
||||
|
||||
from fastapi import Request
|
||||
from strawberry.permission import BasePermission
|
||||
|
||||
from api.token import decode_user_token
|
||||
from api.token import token_from_headers
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from strawberry.types import Info
|
||||
|
||||
|
||||
class IsAuthenticated(BasePermission):
|
||||
message = "User is not authenticated"
|
||||
|
||||
async def has_permission(self, source: typing.Any, info: "Info", **kwargs) -> bool:
|
||||
req: Request = info.context["request"]
|
||||
_, auth_token = token_from_headers(req.headers)
|
||||
|
||||
if len(auth_token) == 0:
|
||||
return False
|
||||
|
||||
try:
|
||||
decode_user_token(auth_token)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
14
api/seed.py
Normal file
14
api/seed.py
Normal file
@ -0,0 +1,14 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from api.database import SessionLocal
|
||||
from api.hasher import argon2_hasher
|
||||
from api.models import UserModel
|
||||
|
||||
|
||||
def seed():
|
||||
db: Session = SessionLocal()
|
||||
|
||||
if db.query(UserModel).count() == 0:
|
||||
admin = UserModel(name="admin", password_hash=argon2_hasher.hash("admin"))
|
||||
db.add(admin)
|
||||
db.commit()
|
35
api/token.py
Normal file
35
api/token.py
Normal file
@ -0,0 +1,35 @@
|
||||
import os
|
||||
import typing
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import jwt
|
||||
from fastapi import Request
|
||||
from starlette.datastructures import Headers
|
||||
|
||||
from api.models import UserModel
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
UserTokenData = dict[str, typing.Any]
|
||||
|
||||
JWT_SECRET = os.getenv("JWT_SECRET", "")
|
||||
|
||||
|
||||
def encode_user_token(
|
||||
user: UserModel, expire_in: timedelta = timedelta(hours=6)
|
||||
) -> str:
|
||||
payload = {}
|
||||
payload["id"] = user.id
|
||||
payload["exp"] = datetime.now(tz=timezone.utc) + expire_in
|
||||
return jwt.encode(payload, JWT_SECRET)
|
||||
|
||||
|
||||
def decode_user_token(token: str) -> "UserTokenData":
|
||||
return jwt.decode(token, JWT_SECRET, ["HS256"])
|
||||
|
||||
|
||||
def token_from_headers(headers: Headers) -> tuple[str, str]:
|
||||
auth_header: str = headers.get("authorization", "")
|
||||
scheme, _, auth_token = auth_header.partition(" ")
|
||||
return scheme, auth_token
|
Reference in New Issue
Block a user