diff options
Diffstat (limited to 'src/collector/main.py')
-rwxr-xr-x | src/collector/main.py | 267 |
1 files changed, 267 insertions, 0 deletions
diff --git a/src/collector/main.py b/src/collector/main.py new file mode 100755 index 0000000..c363885 --- /dev/null +++ b/src/collector/main.py @@ -0,0 +1,267 @@ +from typing import Dict, Union, List, Callable, Awaitable, Any +import json +import os +import sys +import time + +import uvicorn +from fastapi import Depends, FastAPI, Request, Response +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from fastapi_jwt_auth import AuthJWT +from fastapi_jwt_auth.auth_config import AuthConfig +from fastapi_jwt_auth.exceptions import AuthJWTException +from pydantic import BaseModel + +from .db import DictDB +from .schema import get_index_keys, validate_collector_data + +app = FastAPI() + +app.add_middleware( + CORSMiddleware, + allow_origins=["http://localhost:8001"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + expose_headers=["X-Total-Count"], +) + +# TODO: X-Total-Count + + +@app.middleware("http") +async def mock_x_total_count_header(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: + + print(type(call_next)) + + response: Response = await call_next(request) + response.headers["X-Total-Count"] = "100" + return response + + +for i in range(10): + try: + db = DictDB() + except Exception as e: + print(f"Database not responding, will try again soon. Attempt {i + 1} of 10.") + else: + break + time.sleep(1) +else: + print("Database did not respond after 10 attempts, quitting.") + sys.exit(-1) + + +def get_pubkey() -> str: + try: + if "JWT_PUBKEY_PATH" in os.environ: + keypath = os.environ["JWT_PUBKEY_PATH"] + else: + keypath = "/opt/certs/public.pem" + + with open(keypath, "r") as fd: + pubkey = fd.read() + except FileNotFoundError: + print(f"Could not find JWT certificate in {keypath}") + sys.exit(-1) + + return pubkey + + +def get_data( + key: Union[int, None] = None, + limit: int = 25, + skip: int = 0, + ip: Union[str, None] = None, + port: Union[int, None] = None, + asn: Union[str, None] = None, + domain: Union[str, None] = None, +) -> List[Dict[str, Any]]: + if key: + return [db.get(key)] + + selectors: Dict[str, Any] = {} + indexes = get_index_keys() + selectors["domain"] = domain + + if ip and "ip" in indexes: + selectors["ip"] = ip + if port and "port" in indexes: + selectors["port"] = port + if asn and "asn" in indexes: + selectors["asn"] = asn + + data: List[Dict[str, Any]] = db.search(**selectors, limit=limit, skip=skip) + + return data + + +class JWTConfig(BaseModel): + authjwt_algorithm: str = "ES256" + authjwt_public_key: str = get_pubkey() + + +@AuthJWT.load_config # type: ignore +def jwt_config(): + return JWTConfig() + + +@app.exception_handler(AuthJWTException) +def authjwt_exception_handler(request: Request, exc: AuthJWTException) -> JSONResponse: + return JSONResponse(content={"status": "error", "message": exc.message}, status_code=400) + + +@app.exception_handler(RuntimeError) +def app_exception_handler(request: Request, exc: RuntimeError) -> JSONResponse: + return JSONResponse(content={"status": "error", "message": str(exc.with_traceback(None))}, status_code=400) + + +@app.get("/sc/v0/get") +async def get( + key: Union[int, None] = None, + limit: int = 25, + skip: int = 0, + ip: Union[str, None] = None, + port: Union[int, None] = None, + asn: Union[str, None] = None, + Authorize: AuthJWT = Depends(), +) -> JSONResponse: + + Authorize.jwt_required() + + data = [] + raw_jwt = Authorize.get_raw_jwt() + + if "read" not in raw_jwt: + return JSONResponse( + content={ + "status": "error", + "message": "Could not find read claim in JWT token", + }, + status_code=400, + ) + else: + domains = raw_jwt["read"] + + for domain in domains: + data.extend(get_data(key, limit, skip, ip, port, asn, domain)) + + return JSONResponse(content={"status": "success", "docs": data}) + + +@app.get("/sc/v0/get/{key}") +async def get_key(key: Union[int, None] = None, Authorize: AuthJWT = Depends()) -> JSONResponse: + + Authorize.jwt_required() + + raw_jwt = Authorize.get_raw_jwt() + + if "read" not in raw_jwt: + return JSONResponse( + content={ + "status": "error", + "message": "Could not find read claim in JWT token", + }, + status_code=400, + ) + else: + allowed_domains = raw_jwt["read"] + + data_list = get_data(key) + + # Handle if missing + data = data_list[0] + + if data and data["domain"] not in allowed_domains: + return JSONResponse( + content={ + "status": "error", + "message": "User not authorized to view this object", + }, + status_code=400, + ) + + return JSONResponse(content={"status": "success", "docs": data}) + + +# WHY IS AUTH OUTCOMMENTED??? +@app.post("/sc/v0/add") +async def add(data: Request, Authorize: AuthJWT = Depends()) -> JSONResponse: + # Authorize.jwt_required() + + try: + json_data = await data.json() + except json.decoder.JSONDecodeError: + return JSONResponse( + content={ + "status": "error", + "message": "Invalid JSON.", + }, + status_code=400, + ) + + key = db.add(json_data) + + if isinstance(key, str): + return JSONResponse( + content={ + "status": "error", + "message": key, + }, + status_code=400, + ) + + return JSONResponse(content={"status": "success", "docs": key}) + + +@app.delete("/sc/v0/delete/{key}") +async def delete(key: int, Authorize: AuthJWT = Depends()) -> JSONResponse: + + Authorize.jwt_required() + + raw_jwt = Authorize.get_raw_jwt() + + if "write" not in raw_jwt: + return JSONResponse( + content={ + "status": "error", + "message": "Could not find write claim in JWT token", + }, + status_code=400, + ) + else: + allowed_domains = raw_jwt["write"] + + data_list = get_data(key) + + # Handle if missing + data = data_list[0] + + if data and data["domain"] not in allowed_domains: + return JSONResponse( + content={ + "status": "error", + "message": "User not authorized to delete this object", + }, + status_code=400, + ) + + if db.delete(key) is None: + return JSONResponse(content={"status": "error", "message": "Document not found"}, status_code=400) + + return JSONResponse(content={"status": "success", "docs": data}) + + +# def main(standalone: bool = False): +# print(type(app)) +# if not standalone: +# return app + +# uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug") + + +# if __name__ == "__main__": +# main(standalone=True) +# else: +# app = main() |