diff --git a/API/app/application/users_repository.py b/API/app/application/users_repository.py index 4f08e5a..f2841a5 100644 --- a/API/app/application/users_repository.py +++ b/API/app/application/users_repository.py @@ -13,6 +13,9 @@ class UsersRepository: def get_by_id(self, user_id: int): return self.db.query(User).filter(User.id == user_id).first() + def get_by_id_with_role(self, user_id: int): + return self.db.query(User).filter(User.id == user_id).join(User.role).first() + def get_by_login(self, login: str): return self.db.query(User).filter(User.login == login).first() diff --git a/API/app/controllers/answer_files_entity.py b/API/app/controllers/answer_files_entity.py index 3f95da5..cecf710 100644 --- a/API/app/controllers/answer_files_entity.py +++ b/API/app/controllers/answer_files_entity.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from app.application.answer_files_repository import AnswerFilesRepository from app.database.dependencies import get_db from app.domain.entities.answer_files_entitity import AnswerFileEntity -from app.infrastructure.dependencies import get_current_user +from app.infrastructure.dependencies import require_admin router = APIRouter() @@ -14,8 +14,7 @@ router = APIRouter() @router.get("/answer_files/", response_model=List[AnswerFileEntity]) def get_answer_files( db: Session = Depends(get_db), - user=Depends(get_current_user), + user=Depends(require_admin), ): answer_files_service = AnswerFilesRepository(db) return answer_files_service.get_all() - diff --git a/API/app/controllers/auth_router.py b/API/app/controllers/auth_router.py index 7068f1e..24fd530 100644 --- a/API/app/controllers/auth_router.py +++ b/API/app/controllers/auth_router.py @@ -4,12 +4,13 @@ from sqlalchemy.orm import Session from app.database.dependencies import get_db from app.domain.entities.auth_entity import AuthEntity +from app.domain.entities.token_entity import TokenEntity from app.infrastructure.auth_service import AuthService router = APIRouter() -@router.get("/login/", response_model=dict) +@router.post("/login/", response_model=TokenEntity) def login( auth_data: AuthEntity, db: Session = Depends(get_db) diff --git a/API/app/domain/entities/token_entity.py b/API/app/domain/entities/token_entity.py new file mode 100644 index 0000000..cc190c0 --- /dev/null +++ b/API/app/domain/entities/token_entity.py @@ -0,0 +1,11 @@ +from typing import Optional + +from pydantic import BaseModel + + +class TokenEntity(BaseModel): + access_token: str + user_id: int + + class Config: + from_attributes = True diff --git a/API/app/infrastructure/dependencies.py b/API/app/infrastructure/dependencies.py index d542ab7..a3939d3 100644 --- a/API/app/infrastructure/dependencies.py +++ b/API/app/infrastructure/dependencies.py @@ -1,11 +1,12 @@ -import jwt from fastapi import Depends, HTTPException, Security from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials -from sqlalchemy.orm import Session +import jwt +from app.domain.models.users import User +from app.settings import get_auth_data from app.application.users_repository import UsersRepository from app.database.dependencies import get_db -from app.settings import get_auth_data +from sqlalchemy.orm import Session security = HTTPBearer() @@ -14,23 +15,28 @@ def get_current_user( credentials: HTTPAuthorizationCredentials = Security(security), db: Session = Depends(get_db) ): - token = credentials.credentials auth_data = get_auth_data() try: - payload = jwt.decode(token, auth_data["secret_key"], algorithms=[auth_data["algorithm"]]) - user_id = payload.get("user_id") - - if user_id is None: - raise HTTPException(status_code=401, detail="Invalid token") - - user = UsersRepository(db).get_by_id(user_id) - if user is None: - raise HTTPException(status_code=401, detail="User not found") - - return user - + payload = jwt.decode(credentials.credentials, auth_data["secret_key"], algorithms=[auth_data["algorithm"]]) except jwt.ExpiredSignatureError: - raise HTTPException(status_code=401, detail="Token expired") + raise HTTPException(status_code=401, detail="Token has expired") except jwt.InvalidTokenError: raise HTTPException(status_code=401, detail="Invalid token") + + user_id = payload.get("user_id") + if user_id is None: + raise HTTPException(status_code=401, detail="Invalid token") + + user = UsersRepository(db).get_by_id_with_role(user_id) + if user is None: + raise HTTPException(status_code=401, detail="User not found") + + return user + + +def require_admin(user: User = Depends(get_current_user)): + if user.role.title != "Администратор": + raise HTTPException(status_code=403, detail="Access denied") + + return user