"""HTTP API for storing and querying collector documents.

Endpoints are guarded by JWTs verified with the ES256 public key loaded by
``get_pubkey``. Access is scoped per domain: the token's ``read`` claim lists
the domains a caller may read, and its ``write`` claim lists the domains a
caller may add documents to or delete documents from.
"""
import json
import os
import sys
import time
from typing import Any, Awaitable, Callable, Dict, List, Union

import uvicorn
from fastapi import Depends, FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi_jwt_auth import AuthJWT
from fastapi_jwt_auth.auth_config import AuthConfig
from fastapi_jwt_auth.exceptions import AuthJWTException
from pydantic import BaseModel

from .db import DictDB
from .schema import get_index_keys, validate_collector_data

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:8001"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["X-Total-Count"],
)


# TODO: X-Total-Count is a hard-coded placeholder, not a real result count.
@app.middleware("http")
async def mock_x_total_count_header(
    request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
    """Attach a placeholder ``X-Total-Count`` header to every response."""
    response: Response = await call_next(request)
    response.headers["X-Total-Count"] = "100"
    return response


# Connect to the database at import time, retrying once per second for up to
# 10 attempts; the process exits if the database never becomes available.
for i in range(10):
    try:
        db = DictDB()
    except Exception as exc:
        print(f"Database not responding, will try again soon. Attempt {i + 1} of 10. ({exc})")
    else:
        break
    time.sleep(1)
else:
    print("Database did not respond after 10 attempts, quitting.")
    sys.exit(-1)


def get_pubkey() -> str:
    """Return the PEM-encoded public key used to verify JWTs.

    The path is taken from the ``JWT_PUBKEY_PATH`` environment variable and
    defaults to ``/opt/certs/public.pem``. Exits the process if the file does
    not exist.
    """
    keypath = os.environ.get("JWT_PUBKEY_PATH", "/opt/certs/public.pem")
    try:
        with open(keypath, "r") as fd:
            pubkey = fd.read()
    except FileNotFoundError:
        print(f"Could not find JWT certificate in {keypath}")
        sys.exit(-1)
    return pubkey


def get_data(
    key: Union[int, None] = None,
    limit: int = 25,
    skip: int = 0,
    ip: Union[str, None] = None,
    port: Union[int, None] = None,
    asn: Union[str, None] = None,
    domain: Union[str, None] = None,
) -> List[Dict[str, Any]]:
    """Fetch documents by primary key or by an indexed search.

    Args:
        key: Primary key of a single document. When given, the search
            filters are ignored and the result is ``[db.get(key)]``, which
            may be ``[None]`` if the key is unknown and ``domain`` is None.
        limit: Maximum number of documents the search returns.
        skip: Number of matching documents to skip (pagination).
        ip, port, asn: Optional filters, applied only if the field is one of
            the indexed keys reported by ``get_index_keys()``.
        domain: Domain to restrict results to. With ``key``, a document from
            any other domain is not returned.

    Returns:
        The matching documents.
    """
    # ``is not None`` so that key 0 is looked up instead of triggering a search.
    if key is not None:
        document = db.get(key)
        # When scoped to a domain, never return a document from another domain:
        # the /sc/v0/get endpoint relies on this to enforce the read claim.
        if domain is not None and (not document or document.get("domain") != domain):
            return []
        return [document]

    selectors: Dict[str, Any] = {"domain": domain}
    indexes = get_index_keys()
    if ip and "ip" in indexes:
        selectors["ip"] = ip
    if port and "port" in indexes:
        selectors["port"] = port
    if asn and "asn" in indexes:
        selectors["asn"] = asn

    data: List[Dict[str, Any]] = db.search(**selectors, limit=limit, skip=skip)
    return data


class JWTConfig(BaseModel):
    """Settings consumed by ``fastapi_jwt_auth`` to verify incoming tokens."""

    authjwt_algorithm: str = "ES256"
    # Read once at import time; the process exits if the key file is missing.
    authjwt_public_key: str = get_pubkey()


@AuthJWT.load_config  # type: ignore
def jwt_config() -> JWTConfig:
    """Provide the JWT verification settings to ``AuthJWT``."""
    return JWTConfig()


def _error_response(message: str, status_code: int = 400) -> JSONResponse:
    """Build the ``{"status": "error", "message": ...}`` envelope used by all endpoints."""
    return JSONResponse(content={"status": "error", "message": message}, status_code=status_code)


@app.exception_handler(AuthJWTException)
def authjwt_exception_handler(request: Request, exc: AuthJWTException) -> JSONResponse:
    """Convert JWT validation failures into the standard error envelope."""
    return _error_response(exc.message)


@app.exception_handler(RuntimeError)
def app_exception_handler(request: Request, exc: RuntimeError) -> JSONResponse:
    """Convert uncaught ``RuntimeError``s into the standard error envelope."""
    # str(exc) is enough; with_traceback(None) only stripped the traceback in place.
    return _error_response(str(exc))


@app.get("/sc/v0/get")
async def get(
    key: Union[int, None] = None,
    limit: int = 25,
    skip: int = 0,
    ip: Union[str, None] = None,
    port: Union[int, None] = None,
    asn: Union[str, None] = None,
    Authorize: AuthJWT = Depends(),
) -> JSONResponse:
    """Return documents from every domain listed in the token's ``read`` claim.

    The query is run once per readable domain, so ``limit``/``skip`` apply
    per domain rather than to the combined result.
    """
    Authorize.jwt_required()

    raw_jwt = Authorize.get_raw_jwt()
    if "read" not in raw_jwt:
        return _error_response("Could not find read claim in JWT token")

    data: List[Dict[str, Any]] = []
    for domain in raw_jwt["read"]:
        data.extend(get_data(key, limit, skip, ip, port, asn, domain))

    return JSONResponse(content={"status": "success", "docs": data})


@app.get("/sc/v0/get/{key}")
async def get_key(key: Union[int, None] = None, Authorize: AuthJWT = Depends()) -> JSONResponse:
    """Return a single document, provided its domain is in the token's ``read`` claim."""
    Authorize.jwt_required()

    raw_jwt = Authorize.get_raw_jwt()
    if "read" not in raw_jwt:
        return _error_response("Could not find read claim in JWT token")
    allowed_domains = raw_jwt["read"]

    data = get_data(key)[0]
    if data is None:
        return _error_response("Document not found")
    # .get(): a document without a domain is denied rather than raising.
    if data.get("domain") not in allowed_domains:
        return _error_response("User not authorized to view this object")

    return JSONResponse(content={"status": "success", "docs": data})


@app.post("/sc/v0/add")
async def add(data: Request, Authorize: AuthJWT = Depends()) -> JSONResponse:
    """Store a new document whose ``domain`` is in the token's ``write`` claim.

    Returns the value produced by ``db.add`` as ``docs`` on success; if
    ``db.add`` returns a string it is treated as an error message.
    """
    # Writes are authenticated and domain-scoped, matching /sc/v0/delete.
    Authorize.jwt_required()

    raw_jwt = Authorize.get_raw_jwt()
    if "write" not in raw_jwt:
        return _error_response("Could not find write claim in JWT token")
    allowed_domains = raw_jwt["write"]

    try:
        json_data = await data.json()
    except json.decoder.JSONDecodeError:
        return _error_response("Invalid JSON.")

    # Fail closed: payloads that are not objects, or lack a writable domain, are rejected.
    domain = json_data.get("domain") if isinstance(json_data, dict) else None
    if domain not in allowed_domains:
        return _error_response("User not authorized to add this object")

    key = db.add(json_data)
    if isinstance(key, str):
        return _error_response(key)

    return JSONResponse(content={"status": "success", "docs": key})


@app.delete("/sc/v0/delete/{key}")
async def delete(key: int, Authorize: AuthJWT = Depends()) -> JSONResponse:
    """Delete a document whose domain is in the token's ``write`` claim.

    Returns the deleted document as ``docs`` on success.
    """
    Authorize.jwt_required()

    raw_jwt = Authorize.get_raw_jwt()
    if "write" not in raw_jwt:
        return _error_response("Could not find write claim in JWT token")
    allowed_domains = raw_jwt["write"]

    data = get_data(key)[0]
    # .get(): a document without a domain is denied rather than raising.
    if data is not None and data.get("domain") not in allowed_domains:
        return _error_response("User not authorized to delete this object")

    if db.delete(key) is None:
        return _error_response("Document not found")

    return JSONResponse(content={"status": "success", "docs": data})


# def main(standalone: bool = False):
#     print(type(app))
#     if not standalone:
#         return app
#     uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug")


# if __name__ == "__main__":
#     main(standalone=True)
# else:
#     app = main()