Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 108 additions & 93 deletions src/murfey/server/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import secrets
import time
from logging import getLogger
from typing import Dict
from uuid import uuid4

import aiohttp
Expand All @@ -18,7 +17,7 @@
from passlib.context import CryptContext
from pydantic import BaseModel
from sqlmodel import Session, create_engine, select
from typing_extensions import Annotated
from typing_extensions import Annotated, Any

from murfey.server.murfey_db import murfey_db, url
from murfey.util.api import url_path_for
Expand All @@ -40,17 +39,19 @@
auth_url = security_config.auth_url
ALGORITHM = security_config.auth_algorithm or "HS256"
SECRET_KEY = security_config.auth_key or secrets.token_hex(32)
if security_config.auth_type == "password":
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")
else:
oauth2_scheme = APIKeyCookie(name=security_config.cookie_key)
if security_config.instrument_auth_type == "token":
instrument_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")
else:
instrument_oauth2_scheme = lambda *args, **kwargs: None
oauth2_scheme = (
OAuth2PasswordBearer(tokenUrl="auth/token")
if security_config.auth_type == "password"
else APIKeyCookie(name=security_config.cookie_key)
)
instrument_oauth2_scheme = (
OAuth2PasswordBearer(tokenUrl="auth/token")
if security_config.instrument_auth_type == "token"
else lambda *args, **kwargs: None
)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

instrument_server_tokens: Dict[float, dict] = {}
instrument_server_tokens: dict[float, dict] = {}

# Set up database engine
try:
Expand All @@ -66,14 +67,30 @@ def hash_password(password: str) -> str:

"""
=======================================================================================
TOKEN VALIDATION FUNCTIONS
VALIDATION FUNCTIONS
=======================================================================================

Functions and helpers used to validate incoming requests from both the client and
the frontend. 'validate_token()' and 'validate_instrument_token()' are imported
int the other FastAPI modules and attached as dependencies to the routers.
the frontend.

'validate_token()' and 'validate_instrument_token()' are imported in the other FastAPI
modules and attached as dependencies to the routers. They validate the tokens passed
around internally by Murfey to ensure that the request is valid.

'validate_instrument_server_session_access()' and 'validate_frontend_session_access()'
are used to verify the IDs of sessions ot be accessed, and are attached as dependencies
to them.

'validate_user_instrument_access()' is used to verify the instrument server name being
accessed by the frontend, and is attached as a dependency as well.
"""

# Essential headers used for authentication to forward along if present
AUTH_HEADERS = (
"authorization",
"x-auth-request-access-token",
)


def check_user(username: str) -> bool:
try:
Expand All @@ -84,6 +101,39 @@ def check_user(username: str) -> bool:
return username in [u.username for u in users]


async def submit_to_auth_endpoint(
url_subpath: str,
request: Request,
token: str,
) -> dict[str, Any]:
"""
Helper function to forward incoming requests to an authentication server
to verify that they are allowed to inspect the
"""

# Forward only essentials auth-related headers
headers = {
key: value
for key, value in dict(request.headers).items()
if key.lower() in AUTH_HEADERS
}
if security_config.auth_type == "password":
headers["authorization"] = f"Bearer {token}"
cookies = (
{security_config.cookie_key: token}
if security_config.auth_type == "cookie"
else {}
)
async with aiohttp.ClientSession(cookies=cookies) as session:
async with session.get(
f"{auth_url}/{url_subpath}",
headers=headers,
) as response:
success = response.status == 200
validation_outcome: dict[str, Any] = await response.json()
return validation_outcome if success and validation_outcome else {"valid": False}


async def validate_token(
token: Annotated[str, Depends(oauth2_scheme)],
request: Request,
Expand All @@ -94,25 +144,9 @@ async def validate_token(
try:
# Validate using auth URL if provided; will error if invalid
if auth_url:
# Extract and forward headers as-is
headers = dict(request.headers)
# Update/add authorization header if authenticating using password
if security_config.auth_type == "password":
headers["authorization"] = f"Bearer {token}"
# Forward the cookie along if authenticating using cookie
cookies = (
{security_config.cookie_key: token}
if security_config.auth_type == "cookie"
else {}
)
async with aiohttp.ClientSession(cookies=cookies) as session:
async with session.get(
f"{auth_url}/validate_token",
headers=headers,
) as response:
success = response.status == 200
validation_outcome = await response.json()
if not (success and validation_outcome.get("valid")):
if not (
await submit_to_auth_endpoint("validate_token", request, token)
).get("valid"):
raise JWTError
# If authenticating using cookies; an auth URL MUST be provided
else:
Expand Down Expand Up @@ -199,20 +233,6 @@ async def validate_instrument_token(
return None


"""
=======================================================================================
SESSION ID VALIDATION
=======================================================================================

Annotated ints are defined here that trigger validation of the session IDs in incoming
requests, verifying that the session is allowed to access the particular visit.

The 'MurfeySessionID...' types are imported and used in the type hints of the endpoint
functions in the other FastAPI routers, depending on whether requests from the frontend
or the instrument are expected.
"""


def get_visit_name(session_id: int) -> str:
with Session(engine) as murfey_db:
return (
Expand All @@ -222,46 +242,6 @@ def get_visit_name(session_id: int) -> str:
)


async def submit_to_auth_endpoint(url_subpath: str, token: str) -> None:
if auth_url:
headers = (
{}
if security_config.auth_type == "cookie"
else {"Authorization": f"Bearer {token}"}
)
cookies = (
{security_config.cookie_key: token}
if security_config.auth_type == "cookie"
else {}
)
async with aiohttp.ClientSession(cookies=cookies) as session:
async with session.get(
f"{auth_url}/{url_subpath}",
headers=headers,
) as response:
success = response.status == 200
validation_outcome: dict = await response.json()
if not (success and validation_outcome.get("valid")):
logger.warning("Unauthorised visit access request from frontend")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="You do not have access to this visit",
headers={"WWW-Authenticate": "Bearer"},
)


async def validate_frontend_session_access(
session_id: int,
token: Annotated[str, Depends(oauth2_scheme)],
) -> int:
"""
Validates whether a frontend request can access information about this session
"""
visit_name = get_visit_name(session_id)
await submit_to_auth_endpoint(f"validate_visit_access/{visit_name}", token)
return session_id


async def validate_instrument_server_session_access(
session_id: int,
token: Annotated[str, Depends(instrument_oauth2_scheme)],
Expand Down Expand Up @@ -294,25 +274,60 @@ async def validate_instrument_server_session_access(
return session_id


async def validate_frontend_session_access(
session_id: int,
request: Request,
token: Annotated[str, Depends(oauth2_scheme)],
) -> int:
"""
Validates whether a frontend request can access information about this session
"""
visit_name = get_visit_name(session_id)
if auth_url:
if not (
await submit_to_auth_endpoint(
f"validate_visit_access/{visit_name}",
request,
token,
)
).get("valid"):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="You do not have access to this visit",
headers={"WWW-Authenticate": "Bearer"},
)
return session_id


async def validate_user_instrument_access(
instrument_name: str,
request: Request,
token: Annotated[str, Depends(oauth2_scheme)],
) -> str:
"""
Validates whether a frontend request can access information about this instrument
"""
await submit_to_auth_endpoint(
f"validate_instrument_access/{instrument_name}", token
)
if auth_url:
if not (
await submit_to_auth_endpoint(
f"validate_instrument_access/{instrument_name}",
request,
token,
)
).get("valid"):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="You do not have access to this instrument",
headers={"WWW-Authenticate": "Bearer"},
)
return instrument_name


# Set validation conditions for the session ID based on where the request is from
MurfeySessionIDFrontend = Annotated[int, Depends(validate_frontend_session_access)]
# Create annotated session ID and instrument name for endpoints that need to verify them
MurfeySessionIDInstrument = Annotated[
int, Depends(validate_instrument_server_session_access)
]

MurfeySessionIDFrontend = Annotated[int, Depends(validate_frontend_session_access)]
MurfeyInstrumentNameFrontend = Annotated[str, Depends(validate_user_instrument_access)]


Expand Down
Loading