diff options
Diffstat (limited to 'src/collector')
-rw-r--r-- | src/collector/__init__.py | 4 | ||||
-rwxr-xr-x | src/collector/db.py | 148 | ||||
-rwxr-xr-x | src/collector/main.py | 267 | ||||
-rw-r--r-- | src/collector/py.typed | 0 | ||||
-rw-r--r-- | src/collector/schema.py | 136 |
5 files changed, 555 insertions, 0 deletions
# ---------------------------------------------------------------------------
# diff --git a/src/collector/__init__.py b/src/collector/__init__.py
# new file mode 100644, index 0000000..6530fdd
# ---------------------------------------------------------------------------
"""Collector
"""

__version__ = "1.03"

# ---------------------------------------------------------------------------
# diff --git a/src/collector/db.py b/src/collector/db.py
# new file mode 100755, index 0000000..0bfa014
# ---------------------------------------------------------------------------
# A database storing dictionaries, keyed on a timestamp. value = A
# dict which will be stored as a JSON object encoded in UTF-8. Note
# that dict keys of type integer or float will become strings while
# values will keep their type.

# Note that there's a (slim) chance that you'd stomp on the previous
# value if you're too quick with generating the timestamps, ie
# invoking time.time() several times quickly enough.

from typing import Dict, List, Tuple, Union, Any
import os
import sys
import time

from src import couch
from .schema import as_index_list, validate_collector_data


class DictDB:
    """Thin wrapper around a CouchDB database that stores collector documents.

    Connection settings are read from the environment variables
    COUCHDB_NAME, COUCHDB_HOSTNAME, COUCHDB_USER, COUCHDB_PASSWORD and the
    optional COUCHDB_PORT (default "5984").
    """

    def __init__(self) -> None:
        """
        Connect to CouchDB and open the configured database. If the database
        does not exist it is created together with the indexes returned by
        as_index_list() in schema.py.

        Exits the process with status -1 if a required environment variable
        is missing.
        """

        # NOTE: the environment is deliberately NOT printed here; it contains
        # COUCHDB_PASSWORD and would leak the credential into the logs.

        try:
            self.database = os.environ["COUCHDB_NAME"]
            self.hostname = os.environ["COUCHDB_HOSTNAME"]
            self.username = os.environ["COUCHDB_USER"]
            self.password = os.environ["COUCHDB_PASSWORD"]
        except KeyError:
            print(
                "The environment variables COUCHDB_NAME, COUCHDB_HOSTNAME,"
                + " COUCHDB_USER and COUCHDB_PASSWORD must be set."
            )
            sys.exit(-1)

        couchdb_port = os.environ.get("COUCHDB_PORT", "5984")

        self.server = couch.client.Server(
            f"http://{self.username}:{self.password}@{self.hostname}:{couchdb_port}/"
        )

        try:
            self.couchdb = self.server.database(self.database)
            print("Database already exists")
        except couch.exceptions.NotFound:
            print("Creating database and indexes.")
            self.couchdb = self.server.create(self.database)

            for index in as_index_list():
                self.couchdb.index(index)

        # Last key handed out, in integer milliseconds (same unit that
        # unique_key() compares against and returns).
        self._ts: int = round(time.time() * 1000)

    def unique_key(self) -> int:
        """
        Create a unique key based on the current time. We will use this as
        the ID for any new documents we store in CouchDB.

        The key is the current time in milliseconds; the method spins until
        the millisecond value differs from the previously issued key.
        """

        now_ms = round(time.time() * 1000)
        while now_ms == self._ts:
            now_ms = round(time.time() * 1000)
        self._ts = now_ms

        return self._ts

    # Why batch_write???
    def add(self, data: Union[List[Dict[str, Any]], Dict[str, Any]]) -> Union[str, Tuple[str, str]]:
        """
        Store one document, or a list of documents, in CouchDB.

        Every document is validated with validate_collector_data() and gets
        a fresh "_id" from unique_key() before it is saved.

        Returns the validation error message (a non-empty str) if any
        document is invalid; nothing is saved in that case. Otherwise
        returns whatever the underlying save()/save_bulk() call returns.
        """

        if isinstance(data, list):
            for item in data:
                error = validate_collector_data(item)
                if error != "":
                    return error
                item["_id"] = str(self.unique_key())
            ret: Tuple[str, str] = self.couchdb.save_bulk(data)
        else:
            error = validate_collector_data(data)
            if error != "":
                return error
            data["_id"] = str(self.unique_key())
            ret = self.couchdb.save(data)

        return ret

    def get(self, key: int) -> Dict[str, Any]:
        """
        Get a document based on its ID, return an empty dict if not found.
        """

        try:
            doc: Dict[str, Any] = self.couchdb.get(key)
        except couch.exceptions.NotFound:
            doc = {}

        return doc

    #
    # def slice(self, key_from=None, key_to=None):
    #     pass

    def search(self, limit: int = 25, skip: int = 0, **kwargs: Any) -> List[Dict[str, Any]]:
        """
        Execute a Mango query, ideally we should have an index matching
        the query otherwise things will be slow.

        Each keyword argument becomes an equality condition
        ({field: {"$eq": value}}) in the selector. String values consisting
        only of digits are converted to int first; values of any other type
        are used as they are.

        limit and skip fall back to 25 and 0 if they cannot be converted to
        int.
        """

        try:
            limit = int(limit)
            skip = int(skip)
        except (TypeError, ValueError):
            limit = 25
            skip = 0

        # As in the original code, the query stays empty when no filter is
        # given.  NOTE(review): CouchDB presumably rejects a _find request
        # without a "selector" key - confirm whether that case can occur.
        query: Dict[str, Any] = {}

        if kwargs:
            criteria: Dict[str, Any] = {}
            for field, value in kwargs.items():
                # Only strings have isnumeric(); integer values (e.g. a port
                # parsed by the web framework) must pass through untouched.
                if isinstance(value, str) and value.isnumeric():
                    value = int(value)
                criteria[field] = {"$eq": value}
            query = {"limit": limit, "skip": skip, "selector": criteria}

        # Pass the caller's limit instead of a hard-coded 5 so it agrees
        # with the "limit" stored in the query itself.
        return list(self.couchdb.find(query, wrapper=None, limit=limit))

    def delete(self, key: int) -> Union[int, None]:
        """
        Delete a document based on its ID.

        Returns the key on success, or None if the document was not found.
        """
        try:
            self.couchdb.delete(key)
        except couch.exceptions.NotFound:
            return None

        return key


# ---------------------------------------------------------------------------
# diff --git a/src/collector/main.py b/src/collector/main.py
# new file mode 100755, index 0000000..c363885
# ---------------------------------------------------------------------------
from typing import Dict, Union, List, Callable, Awaitable, Any
import json
import os
import sys
import time

import uvicorn
from fastapi import Depends, FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi_jwt_auth import AuthJWT
from fastapi_jwt_auth.auth_config import AuthConfig
from fastapi_jwt_auth.exceptions import AuthJWTException
from pydantic import BaseModel

from .db import DictDB
from .schema import get_index_keys, validate_collector_data

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:8001"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["X-Total-Count"],
)

# TODO: X-Total-Count


@app.middleware("http")
async def mock_x_total_count_header(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
    """Add a placeholder "X-Total-Count: 100" header to every response.

    The value is a constant, not a real document count (see the TODO above).
    """
    response: Response = await call_next(request)
    response.headers["X-Total-Count"] = "100"
    return response


# Connect to the database at import time, retrying up to 10 times with a
# one second pause. Any exception from DictDB() counts as "not responding";
# a missing environment variable makes DictDB() call sys.exit(), which is
# not caught here and terminates the process immediately.
for i in range(10):
    try:
        db = DictDB()
    except Exception:
        print(f"Database not responding, will try again soon. Attempt {i + 1} of 10.")
    else:
        break
    time.sleep(1)
else:
    print("Database did not respond after 10 attempts, quitting.")
    sys.exit(-1)


def get_pubkey() -> str:
    """Return the contents of the JWT public key file.

    The path is taken from the environment variable JWT_PUBKEY_PATH and
    defaults to /opt/certs/public.pem. Exits the process with status -1 if
    the file does not exist.
    """
    keypath = os.environ.get("JWT_PUBKEY_PATH", "/opt/certs/public.pem")

    try:
        with open(keypath, "r") as fd:
            pubkey = fd.read()
    except FileNotFoundError:
        print(f"Could not find JWT certificate in {keypath}")
        sys.exit(-1)

    return pubkey


def get_data(
    key: Union[int, None] = None,
    limit: int = 25,
    skip: int = 0,
    ip: Union[str, None] = None,
    port: Union[int, None] = None,
    asn: Union[str, None] = None,
    domain: Union[str, None] = None,
) -> List[Dict[str, Any]]:
    """Fetch documents from the database.

    If key is given, return a one-element list holding that document (an
    empty dict if it does not exist); no domain filtering is applied in
    that case, callers must check the document's domain themselves.

    Otherwise search for documents whose "domain" equals domain, further
    restricted by ip, port and asn when those are given.
    """
    if key is not None:
        return [db.get(key)]

    indexes = get_index_keys()
    selectors: Dict[str, Any] = {"domain": domain}

    if ip and "ip" in indexes:
        selectors["ip"] = ip
    if port and "port" in indexes:
        selectors["port"] = port
    if asn and "asn" in indexes:
        selectors["asn"] = asn

    data: List[Dict[str, Any]] = db.search(**selectors, limit=limit, skip=skip)

    return data


class JWTConfig(BaseModel):
    """Settings handed to fastapi_jwt_auth.

    The public key is read once, when this class body is executed.
    """

    authjwt_algorithm: str = "ES256"
    authjwt_public_key: str = get_pubkey()


@AuthJWT.load_config  # type: ignore
def jwt_config():
    """Provide the JWT settings to AuthJWT."""
    return JWTConfig()


@app.exception_handler(AuthJWTException)
def authjwt_exception_handler(request: Request, exc: AuthJWTException) -> JSONResponse:
    """Turn JWT authentication errors into a JSON error response (HTTP 400)."""
    return JSONResponse(content={"status": "error", "message": exc.message}, status_code=400)


@app.exception_handler(RuntimeError)
def app_exception_handler(request: Request, exc: RuntimeError) -> JSONResponse:
    """Turn RuntimeError into a JSON error response (HTTP 400)."""
    return JSONResponse(content={"status": "error", "message": str(exc.with_traceback(None))}, status_code=400)


@app.get("/sc/v0/get")
async def get(
    key: Union[int, None] = None,
    limit: int = 25,
    skip: int = 0,
    ip: Union[str, None] = None,
    port: Union[int, None] = None,
    asn: Union[str, None] = None,
    Authorize: AuthJWT = Depends(),
) -> JSONResponse:
    """Return the documents the caller is allowed to read.

    Requires a valid JWT carrying a "read" claim that lists the permitted
    domains. Without key, a search is run for every permitted domain and
    the results are concatenated (limit and skip apply per domain). With
    key, that single document is returned only if its domain is one of the
    permitted domains; otherwise "docs" is empty.
    """
    Authorize.jwt_required()

    raw_jwt = Authorize.get_raw_jwt()

    if "read" not in raw_jwt:
        return JSONResponse(
            content={
                "status": "error",
                "message": "Could not find read claim in JWT token",
            },
            status_code=400,
        )

    domains = raw_jwt["read"]
    data: List[Dict[str, Any]] = []

    if key is not None:
        # A lookup by key ignores the domain filter inside get_data(), so
        # the domain check has to happen here. Previously the document was
        # returned (once per permitted domain) no matter which domain it
        # belonged to.
        doc = get_data(key)[0]
        if doc and doc.get("domain") in domains:
            data.append(doc)
    else:
        for domain in domains:
            data.extend(get_data(key, limit, skip, ip, port, asn, domain))

    return JSONResponse(content={"status": "success", "docs": data})


@app.get("/sc/v0/get/{key}")
async def get_key(key: Union[int, None] = None, Authorize: AuthJWT = Depends()) -> JSONResponse:
    """Return a single document by its ID.

    Requires a valid JWT with a "read" claim. Responds with an error
    (HTTP 400) if the document's domain is not listed in that claim. A
    missing document yields a success response with an empty dict.
    """
    Authorize.jwt_required()

    raw_jwt = Authorize.get_raw_jwt()

    if "read" not in raw_jwt:
        return JSONResponse(
            content={
                "status": "error",
                "message": "Could not find read claim in JWT token",
            },
            status_code=400,
        )

    allowed_domains = raw_jwt["read"]

    # get_data() returns [{}] when the document is missing.
    data = get_data(key)[0]

    # .get() so that a stored document without a "domain" field is refused
    # instead of raising KeyError (HTTP 500).
    if data and data.get("domain") not in allowed_domains:
        return JSONResponse(
            content={
                "status": "error",
                "message": "User not authorized to view this object",
            },
            status_code=400,
        )

    return JSONResponse(content={"status": "success", "docs": data})


# TODO(review): why is JWT authentication commented out in the add() endpoint
# that follows?
@app.post("/sc/v0/add")
async def add(data: Request, Authorize: AuthJWT = Depends()) -> JSONResponse:
    """Store the JSON body of the request in the database.

    Responds with HTTP 400 if the body is not valid JSON or if db.add()
    reports an error (it returns the error message as a str). On success
    the value returned by db.add() is passed back under "docs".
    """
    # SECURITY NOTE(review): authentication is disabled for this endpoint,
    # so anyone who can reach the service can store documents. The other
    # endpoints require a JWT (delete additionally checks the "write"
    # claim). Left unchanged here because enabling it would break existing
    # unauthenticated clients - decide and enable deliberately.
    # Authorize.jwt_required()

    try:
        json_data = await data.json()
    except json.decoder.JSONDecodeError:
        return JSONResponse(
            content={
                "status": "error",
                "message": "Invalid JSON.",
            },
            status_code=400,
        )

    key = db.add(json_data)

    if isinstance(key, str):
        return JSONResponse(
            content={
                "status": "error",
                "message": key,
            },
            status_code=400,
        )

    return JSONResponse(content={"status": "success", "docs": key})


@app.delete("/sc/v0/delete/{key}")
async def delete(key: int, Authorize: AuthJWT = Depends()) -> JSONResponse:
    """Delete a single document by its ID.

    Requires a valid JWT with a "write" claim listing the permitted
    domains. Responds with HTTP 400 if the claim is missing, if the
    document's domain is not permitted, or if the document does not exist.
    On success the deleted document is returned under "docs".
    """
    Authorize.jwt_required()

    raw_jwt = Authorize.get_raw_jwt()

    if "write" not in raw_jwt:
        return JSONResponse(
            content={
                "status": "error",
                "message": "Could not find write claim in JWT token",
            },
            status_code=400,
        )

    allowed_domains = raw_jwt["write"]

    # A missing document comes back as an empty dict; db.delete() below
    # then reports it as not found.
    data = get_data(key)[0]

    # .get() so that a stored document without a "domain" field is refused
    # instead of raising KeyError (HTTP 500).
    if data and data.get("domain") not in allowed_domains:
        return JSONResponse(
            content={
                "status": "error",
                "message": "User not authorized to delete this object",
            },
            status_code=400,
        )

    if db.delete(key) is None:
        return JSONResponse(content={"status": "error", "message": "Document not found"}, status_code=400)

    return JSONResponse(content={"status": "success", "docs": data})


# def main(standalone: bool = False):
#     print(type(app))
#     if not standalone:
#         return app

#     uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug")


# if __name__ == "__main__":
#     main(standalone=True)
# else:
#     app = main()

# ---------------------------------------------------------------------------
# diff --git a/src/collector/py.typed b/src/collector/py.typed
# new file mode 100644, index 0000000..e69de29 (empty file)
# ---------------------------------------------------------------------------
# diff --git a/src/collector/schema.py b/src/collector/schema.py
# new file mode 100644, index 0000000..e291f10
# ---------------------------------------------------------------------------
from typing import List, Any, Dict
import json
import sys
import traceback

import jsonschema

# fmt:off
# NOTE: Commented out properties are left intentionally, so it is easier to see
# what properties are optional.

# Shape of every entry under "custom_data".
_CUSTOM_DATA_ENTRY = {
    "type": "object",
    "properties": {
        "display_name": {"type": "string"},
        "data": {"type": ["string", "boolean", "integer"]},
        "description": {"type": "string"},
    },
    "required": [
        "display_name",
        "data",
        # "description"
    ]
}

# Shape of every entry under "result". Exactly one of the two "required"
# alternatives has to match.
_RESULT_ENTRY = {
    "type": "object",
    "properties": {
        "display_name": {"type": "string"},
        "vulnerable": {"type": "boolean"},
        "investigation_needed": {"type": "boolean"},
        "reliability": {"type": "integer"},
        "description": {"type": "string"},
    },
    "oneOf": [
        {
            "required": [
                "display_name",
                "vulnerable",
                # "reliability",  # TODO: reliability is required if vulnerable = true
                # "description",
            ]
        },
        {
            "required": [
                "display_name",
                "investigation_needed",
                # "reliability",  # TODO: reliability is required if investigation_needed = true
                # "description",
            ]
        },
    ]
}

# JSON schema describing one collector document.
schema = {
    "$schema": "http://json-schema.org/schema#",
    "type": "object",
    "properties": {
        "document_version": {"type": "integer"},
        "ip": {"type": "string"},
        "port": {"type": "integer"},
        "whois_description": {"type": "string"},
        "asn": {"type": "string"},
        "asn_country_code": {"type": "string"},
        "ptr": {"type": "string"},
        "abuse_mail": {"type": "string"},
        "domain": {"type": "string"},
        # NOTE(review): "format" is only enforced when the format checker
        # supports it; "date-time" presumably needs an extra package to be
        # installed - confirm against the jsonschema documentation.
        "timestamp": {"type": "string", "format": "date-time"},
        "display_name": {"type": "string"},
        "description": {"type": "string"},
        "custom_data": {
            "type": "object",
            "patternProperties": {
                ".*": _CUSTOM_DATA_ENTRY,
            },
        },
        "result": {
            "type": "object",
            "patternProperties": {
                ".*": _RESULT_ENTRY,
            },
        },
    },
    "required": [
        "document_version",
        "ip",
        "port",
        "whois_description",
        "asn",
        "asn_country_code",
        "ptr",
        "abuse_mail",
        "domain",
        "timestamp",
        "display_name",
        # "description",
        # "custom_data",
        "result",
    ],
}
# fmt:on


def get_index_keys() -> List[str]:
    """Return the names of all top-level properties in the schema."""
    return [name for name in schema["properties"]]


def as_index_list() -> List[Dict[str, Any]]:
    """Return one JSON index definition per top-level schema property.

    Each definition indexes a single field and is named
    "<field>-json-index".
    """
    return [
        {
            "index": {"fields": [field]},
            "name": f"{field}-json-index",
            "type": "json",
        }
        for field in schema["properties"]
    ]


def validate_collector_data(json_blob: Dict[str, Any]) -> str:
    """Validate a document against the schema.

    Returns an empty string when the document is valid, otherwise a message
    describing the validation error.
    """
    checker = jsonschema.FormatChecker()
    try:
        jsonschema.validate(json_blob, schema, format_checker=checker)
    except jsonschema.exceptions.ValidationError as err:
        return f"Validation failed with error: {err.message}"
    return ""


if __name__ == "__main__":
    # Validate the JSON file given as first command line argument and print
    # the result (an empty line means the document is valid).
    with open(sys.argv[1]) as fd:
        json_data = json.load(fd)

    print(validate_collector_data(json_data))