summaryrefslogtreecommitdiff
path: root/src/collector/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/collector/main.py')
-rwxr-xr-xsrc/collector/main.py267
1 files changed, 267 insertions, 0 deletions
diff --git a/src/collector/main.py b/src/collector/main.py
new file mode 100755
index 0000000..c363885
--- /dev/null
+++ b/src/collector/main.py
@@ -0,0 +1,267 @@
+from typing import Dict, Union, List, Callable, Awaitable, Any
+import json
+import os
+import sys
+import time
+
+import uvicorn
+from fastapi import Depends, FastAPI, Request, Response
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+from fastapi_jwt_auth import AuthJWT
+from fastapi_jwt_auth.auth_config import AuthConfig
+from fastapi_jwt_auth.exceptions import AuthJWTException
+from pydantic import BaseModel
+
+from .db import DictDB
+from .schema import get_index_keys, validate_collector_data
+
+app = FastAPI()
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["http://localhost:8001"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+ expose_headers=["X-Total-Count"],
+)
+
+# TODO: X-Total-Count
+
+
+@app.middleware("http")
+async def mock_x_total_count_header(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
+
+ print(type(call_next))
+
+ response: Response = await call_next(request)
+ response.headers["X-Total-Count"] = "100"
+ return response
+
+
+for i in range(10):
+ try:
+ db = DictDB()
+ except Exception as e:
+ print(f"Database not responding, will try again soon. Attempt {i + 1} of 10.")
+ else:
+ break
+ time.sleep(1)
+else:
+ print("Database did not respond after 10 attempts, quitting.")
+ sys.exit(-1)
+
+
+def get_pubkey() -> str:
+ try:
+ if "JWT_PUBKEY_PATH" in os.environ:
+ keypath = os.environ["JWT_PUBKEY_PATH"]
+ else:
+ keypath = "/opt/certs/public.pem"
+
+ with open(keypath, "r") as fd:
+ pubkey = fd.read()
+ except FileNotFoundError:
+ print(f"Could not find JWT certificate in {keypath}")
+ sys.exit(-1)
+
+ return pubkey
+
+
+def get_data(
+ key: Union[int, None] = None,
+ limit: int = 25,
+ skip: int = 0,
+ ip: Union[str, None] = None,
+ port: Union[int, None] = None,
+ asn: Union[str, None] = None,
+ domain: Union[str, None] = None,
+) -> List[Dict[str, Any]]:
+ if key:
+ return [db.get(key)]
+
+ selectors: Dict[str, Any] = {}
+ indexes = get_index_keys()
+ selectors["domain"] = domain
+
+ if ip and "ip" in indexes:
+ selectors["ip"] = ip
+ if port and "port" in indexes:
+ selectors["port"] = port
+ if asn and "asn" in indexes:
+ selectors["asn"] = asn
+
+ data: List[Dict[str, Any]] = db.search(**selectors, limit=limit, skip=skip)
+
+ return data
+
+
+class JWTConfig(BaseModel):
+ authjwt_algorithm: str = "ES256"
+ authjwt_public_key: str = get_pubkey()
+
+
+@AuthJWT.load_config # type: ignore
+def jwt_config():
+ return JWTConfig()
+
+
+@app.exception_handler(AuthJWTException)
+def authjwt_exception_handler(request: Request, exc: AuthJWTException) -> JSONResponse:
+ return JSONResponse(content={"status": "error", "message": exc.message}, status_code=400)
+
+
+@app.exception_handler(RuntimeError)
+def app_exception_handler(request: Request, exc: RuntimeError) -> JSONResponse:
+ return JSONResponse(content={"status": "error", "message": str(exc.with_traceback(None))}, status_code=400)
+
+
+@app.get("/sc/v0/get")
+async def get(
+ key: Union[int, None] = None,
+ limit: int = 25,
+ skip: int = 0,
+ ip: Union[str, None] = None,
+ port: Union[int, None] = None,
+ asn: Union[str, None] = None,
+ Authorize: AuthJWT = Depends(),
+) -> JSONResponse:
+
+ Authorize.jwt_required()
+
+ data = []
+ raw_jwt = Authorize.get_raw_jwt()
+
+ if "read" not in raw_jwt:
+ return JSONResponse(
+ content={
+ "status": "error",
+ "message": "Could not find read claim in JWT token",
+ },
+ status_code=400,
+ )
+ else:
+ domains = raw_jwt["read"]
+
+ for domain in domains:
+ data.extend(get_data(key, limit, skip, ip, port, asn, domain))
+
+ return JSONResponse(content={"status": "success", "docs": data})
+
+
+@app.get("/sc/v0/get/{key}")
+async def get_key(key: Union[int, None] = None, Authorize: AuthJWT = Depends()) -> JSONResponse:
+
+ Authorize.jwt_required()
+
+ raw_jwt = Authorize.get_raw_jwt()
+
+ if "read" not in raw_jwt:
+ return JSONResponse(
+ content={
+ "status": "error",
+ "message": "Could not find read claim in JWT token",
+ },
+ status_code=400,
+ )
+ else:
+ allowed_domains = raw_jwt["read"]
+
+ data_list = get_data(key)
+
+ # Handle if missing
+ data = data_list[0]
+
+ if data and data["domain"] not in allowed_domains:
+ return JSONResponse(
+ content={
+ "status": "error",
+ "message": "User not authorized to view this object",
+ },
+ status_code=400,
+ )
+
+ return JSONResponse(content={"status": "success", "docs": data})
+
+
+# WHY IS AUTH OUTCOMMENTED???
+@app.post("/sc/v0/add")
+async def add(data: Request, Authorize: AuthJWT = Depends()) -> JSONResponse:
+ # Authorize.jwt_required()
+
+ try:
+ json_data = await data.json()
+ except json.decoder.JSONDecodeError:
+ return JSONResponse(
+ content={
+ "status": "error",
+ "message": "Invalid JSON.",
+ },
+ status_code=400,
+ )
+
+ key = db.add(json_data)
+
+ if isinstance(key, str):
+ return JSONResponse(
+ content={
+ "status": "error",
+ "message": key,
+ },
+ status_code=400,
+ )
+
+ return JSONResponse(content={"status": "success", "docs": key})
+
+
+@app.delete("/sc/v0/delete/{key}")
+async def delete(key: int, Authorize: AuthJWT = Depends()) -> JSONResponse:
+
+ Authorize.jwt_required()
+
+ raw_jwt = Authorize.get_raw_jwt()
+
+ if "write" not in raw_jwt:
+ return JSONResponse(
+ content={
+ "status": "error",
+ "message": "Could not find write claim in JWT token",
+ },
+ status_code=400,
+ )
+ else:
+ allowed_domains = raw_jwt["write"]
+
+ data_list = get_data(key)
+
+ # Handle if missing
+ data = data_list[0]
+
+ if data and data["domain"] not in allowed_domains:
+ return JSONResponse(
+ content={
+ "status": "error",
+ "message": "User not authorized to delete this object",
+ },
+ status_code=400,
+ )
+
+ if db.delete(key) is None:
+ return JSONResponse(content={"status": "error", "message": "Document not found"}, status_code=400)
+
+ return JSONResponse(content={"status": "success", "docs": data})
+
+
+# def main(standalone: bool = False):
+# print(type(app))
+# if not standalone:
+# return app
+
+# uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug")
+
+
+# if __name__ == "__main__":
+# main(standalone=True)
+# else:
+# app = main()