diff --git a/api/serializers.py b/api/serializers.py index 83b9da1a..1408640a 100644 --- a/api/serializers.py +++ b/api/serializers.py @@ -118,6 +118,34 @@ def validate_ordering(self, ordering): return ordering_validation(ordering) +# NOTE: While the FeedsRequestSerializer enforces ordering on db +# Model fields, aggregation requires ordering by non-model, which is annotated fields +class ASNFeedsOrderingSerializer(FeedsRequestSerializer): + ALLOWED_ORDERING_FIELDS = frozenset( + { + "asn", + "ioc_count", + "total_attack_count", + "total_interaction_count", + "total_login_attempts", + "expected_ioc_count", + "expected_interactions", + "first_seen", + "last_seen", + } + ) + + def validate_ordering(self, ordering): + field_name = ordering.lstrip("-").strip() + + if field_name not in self.ALLOWED_ORDERING_FIELDS: + raise serializers.ValidationError( + {f"Invalid ordering field for ASN aggregated feed: '{field_name}'. Allowed fields: {', '.join(sorted(self.ALLOWED_ORDERING_FIELDS))}"} + ) + + return ordering + + class FeedsResponseSerializer(serializers.Serializer): """ Serializer for feed response data structure. diff --git a/api/urls.py b/api/urls.py index 7202fc10..f426151e 100644 --- a/api/urls.py +++ b/api/urls.py @@ -10,6 +10,7 @@ enrichment_view, feeds, feeds_advanced, + feeds_asn, feeds_pagination, general_honeypot_list, ) @@ -22,6 +23,7 @@ urlpatterns = [ path("feeds/", feeds_pagination), path("feeds/advanced/", feeds_advanced), + path("feeds/asn/", feeds_asn), path("feeds///.", feeds), path("enrichment", enrichment_view), path("cowrie_session", cowrie_session_view), diff --git a/api/views/feeds.py b/api/views/feeds.py index 617df2ac..c6e56524 100644 --- a/api/views/feeds.py +++ b/api/views/feeds.py @@ -10,9 +10,12 @@ permission_classes, ) from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response +from api.serializers import ASNFeedsOrderingSerializer from api.views.utils import ( FeedRequestParams, + asn_aggregated_queryset, feeds_response, get_queryset, get_valid_feed_types, @@ -116,3 +119,45 @@ def feeds_advanced(request): resp_data = feeds_response(iocs, feed_params, valid_feed_types, dict_only=True, verbose=verbose) return paginator.get_paginated_response(resp_data) return feeds_response(iocs_queryset, feed_params, valid_feed_types, verbose=verbose) + + +@api_view(["GET"]) +@authentication_classes([CookieTokenAuthentication]) +@permission_classes([IsAuthenticated]) +def feeds_asn(request): + """ + Retrieve aggregated IOC feed data grouped by ASN (Autonomous System Number). + + Args: + request: The HTTP request object. + feed_type (str): Filter by feed type (e.g., 'cowrie', 'log4j'). Default: 'all'. + attack_type (str): Filter by attack type (e.g., 'scanner', 'payload_request'). Default: 'all'. + max_age (int): Maximum age of IOCs in days. Default: 3. + min_days_seen (int): Minimum days an IOC must have been observed. Default: 1. + exclude_reputation (str): ';'-separated reputations to exclude (e.g., 'mass scanner'). Default: none. + ordering (str): Aggregation ordering field (e.g., 'total_attack_count', 'asn'). Default: '-ioc_count'. + asn (str, optional): Filter results to a single ASN. + + Returns: + Response: HTTP response with a JSON list of ASN aggregation objects. + Each object contains: + asn (int): ASN number. + ioc_count (int): Number of IOCs for this ASN. + total_attack_count (int): Sum of attack_count for all IOCs. + total_interaction_count (int): Sum of interaction_count for all IOCs. + total_login_attempts (int): Sum of login_attempts for all IOCs. + honeypots (List[str]): Sorted list of unique honeypots that observed these IOCs. + expected_ioc_count (float): Sum of recurrence_probability for all IOCs, rounded to 4 decimals. + expected_interactions (float): Sum of expected_interactions for all IOCs, rounded to 4 decimals. + first_seen (DateTime): Earliest first_seen timestamp among IOCs. + last_seen (DateTime): Latest last_seen timestamp among IOCs. + """ + logger.info(f"request /api/feeds/asn/ with params: {request.query_params}") + feed_params = FeedRequestParams(request.query_params) + valid_feed_types = get_valid_feed_types() + + iocs_qs = get_queryset(request, feed_params, valid_feed_types, is_aggregated=True, serializer_class=ASNFeedsOrderingSerializer) + + asn_aggregates = asn_aggregated_queryset(iocs_qs, request, feed_params) + data = list(asn_aggregates) + return Response(data) diff --git a/api/views/utils.py b/api/views/utils.py index cded1e9b..d2b53da5 100644 --- a/api/views/utils.py +++ b/api/views/utils.py @@ -8,7 +8,7 @@ from django.conf import settings from django.contrib.postgres.aggregates import ArrayAgg -from django.db.models import F +from django.db.models import Count, F, Max, Min, Q, Sum from django.http import HttpResponse, HttpResponseBadRequest, StreamingHttpResponse from rest_framework import status from rest_framework.response import Response @@ -121,7 +121,7 @@ def get_valid_feed_types() -> frozenset[str]: return frozenset(feed_types) -def get_queryset(request, feed_params, valid_feed_types): +def get_queryset(request, feed_params, valid_feed_types, is_aggregated=False, serializer_class=FeedsRequestSerializer): """ Build a queryset to filter IOC data based on the request parameters. @@ -129,6 +129,15 @@ def get_queryset(request, feed_params, valid_feed_types): request: The incoming request object. feed_params: A FeedRequestParams instance. valid_feed_types (frozenset): The set of all valid feed types. + is_aggregated (bool, optional): + - If True, disables slicing (`feed_size`) and model-level ordering. + - Ensures full dataset is available for aggregation or specialized computation. + - Default: False. + serializer_class (class, optional): + - Serializer class used to validate request parameters. + - Allows injecting a custom serializer to enforce rules for specific feed types + (e.g., to restrict ordering fields or validation for specialized feeds). + - Default: `FeedsRequestSerializer`. Returns: QuerySet: The filtered queryset of IOC data. @@ -139,7 +148,7 @@ def get_queryset(request, feed_params, valid_feed_types): f"Age: {feed_params.max_age}, format: {feed_params.format}" ) - serializer = FeedsRequestSerializer( + serializer = serializer_class( data=vars(feed_params), context={"valid_feed_types": valid_feed_types}, ) @@ -168,9 +177,14 @@ def get_queryset(request, feed_params, valid_feed_types): .annotate(value=F("name")) .annotate(honeypots=ArrayAgg("general_honeypot__name")) .distinct() - .order_by(feed_params.ordering)[: int(feed_params.feed_size)] ) + # aggregated endpoints should operate on the full queryset + # to compute sums, counts, and other metrics correctly. + if not is_aggregated: + iocs = iocs.order_by(feed_params.ordering) + iocs = iocs[: int(feed_params.feed_size)] + # save request source for statistics source_ip = str(request.META["REMOTE_ADDR"]) request_source = Statistics(source=source_ip) @@ -317,3 +331,74 @@ def is_sha256hash(string: str) -> bool: bool: True if the string is a valid SHA-256 hash, False otherwise """ return bool(re.fullmatch(r"^[A-Fa-f0-9]{64}$", string)) + + +def resolve_aggregation_ordering(ordering, *, default, fallback_fields=None): + """ + Resolve effective ordering for aggregated endpoints. + + Args + ordering (str or None): The user-provided ordering string from query params. + default (str): The default ordering to use if `ordering` is None or in fallback_fields. + fallback_fields (set[str], optional): A set of orderings that are allowed in other + contexts but should be overridden here. Defaults to None. + + Returns + str: A safe ordering string to use directly in the aggregation query. + """ + fallback_fields = fallback_fields or set() + + if not ordering or ordering in fallback_fields: + return default + + return ordering + + +def asn_aggregated_queryset(iocs_qs, request, feed_params): + """ + Perform DB-level aggregation grouped by ASN. + + Args + iocs_qs (QuerySet): Filtered IOC queryset from get_queryset; + request (Request): The API request object; + feed_params (FeedRequestParams): Validated parameter object + + Returns: A values-grouped queryset with annotated metrics and honeypot arrays. + """ + # optional asn params for single asn filter + asn_filter = request.query_params.get("asn") + if asn_filter: + iocs_qs = iocs_qs.filter(asn=asn_filter) + + aggregated = ( + iocs_qs.exclude(asn__isnull=True) + .values("asn") + .annotate( + ioc_count=Count("id", distinct=True), + total_attack_count=Sum("attack_count", distinct=True), + total_interaction_count=Sum("interaction_count", distinct=True), + total_login_attempts=Sum("login_attempts", distinct=True), + expected_ioc_count=Sum("recurrence_probability", distinct=True), + expected_interactions=Sum("expected_interactions", distinct=True), + honeypots=ArrayAgg( + "general_honeypot__name", + filter=Q(general_honeypot__name__isnull=False), + distinct=True, + ), + first_seen=Min("first_seen"), + last_seen=Max("last_seen"), + ) + ) + + resolved_ordering = resolve_aggregation_ordering( + ordering=feed_params.ordering, + default="-ioc_count", + fallback_fields={"-last_seen"}, + ) + + direction = "-" if resolved_ordering.startswith("-") else "" + field = resolved_ordering.lstrip("-").strip() + + aggregated = aggregated.order_by(f"{direction}{field}") + + return aggregated diff --git a/tests/test_views.py b/tests/test_views.py index fe869d60..c09f01ea 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,9 +1,10 @@ from django.conf import settings from django.test import override_settings +from django.utils import timezone from rest_framework.test import APIClient from api.views.utils import is_ip_address, is_sha256hash -from greedybear.models import GeneralHoneypot, Statistics, ViewType +from greedybear.models import IOC, GeneralHoneypot, Statistics, ViewType from . import CustomTestCase @@ -271,6 +272,185 @@ def test_400_feeds_pagination(self): self.assertEqual(response.status_code, 400) +class FeedsASNViewTestCase(CustomTestCase): + """Tests for ASN aggregated feeds API""" + + @classmethod + def setUpClass(cls): + super().setUpClass() + IOC.objects.all().delete() + cls.testpot1, _ = GeneralHoneypot.objects.get_or_create(name="testpot1", active=True) + cls.testpot2, _ = GeneralHoneypot.objects.get_or_create(name="testpot2", active=True) + + cls.high_asn = "13335" + cls.low_asn = "16276" + + cls.ioc_high1 = IOC.objects.create( + name="high1.example.com", + type="ip", + asn=cls.high_asn, + attack_count=15, + interaction_count=30, + login_attempts=5, + first_seen=timezone.now() - timezone.timedelta(days=10), + recurrence_probability=0.8, + expected_interactions=20.0, + ) + cls.ioc_high1.general_honeypot.add(cls.testpot1, cls.testpot2) + cls.ioc_high1.save() + + cls.ioc_high2 = IOC.objects.create( + name="high2.example.com", + type="ip", + asn=cls.high_asn, + attack_count=5, + interaction_count=10, + login_attempts=2, + first_seen=timezone.now() - timezone.timedelta(days=5), + recurrence_probability=0.3, + expected_interactions=8.0, + ) + cls.ioc_high2.general_honeypot.add(cls.testpot1, cls.testpot2) + cls.ioc_high2.save() + + cls.ioc_low = IOC.objects.create( + name="low.example.com", + type="ip", + asn=cls.low_asn, + attack_count=2, + interaction_count=5, + login_attempts=1, + first_seen=timezone.now(), + recurrence_probability=0.1, + expected_interactions=3.0, + ) + cls.ioc_low.general_honeypot.add(cls.testpot1, cls.testpot2) + cls.ioc_low.save() + + def setUp(self): + self.client = APIClient() + self.client.force_authenticate(user=self.superuser) + self.url = "/api/feeds/asn/" + + def _get_results(self, response): + payload = response.json() + self.assertIsInstance(payload, list) + return payload + + def test_200_asn_feed_aggregated_fields(self): + """Ensure aggregated fields are computed correctly per ASN using dynamic sums""" + response = self.client.get(self.url) + self.assertEqual(response.status_code, 200) + results = self._get_results(response) + + # filtering high ASN + high_item = next((item for item in results if str(item["asn"]) == self.high_asn), None) + self.assertIsNotNone(high_item) + + # getting all IOCs for high ASN from the DB + high_iocs = IOC.objects.filter(asn=self.high_asn) + + self.assertEqual(high_item["ioc_count"], high_iocs.count()) + self.assertEqual(high_item["total_attack_count"], sum(i.attack_count for i in high_iocs)) + self.assertEqual(high_item["total_interaction_count"], sum(i.interaction_count for i in high_iocs)) + self.assertEqual(high_item["total_login_attempts"], sum(i.login_attempts for i in high_iocs)) + self.assertAlmostEqual(high_item["expected_ioc_count"], sum(i.recurrence_probability for i in high_iocs)) + self.assertAlmostEqual(high_item["expected_interactions"], sum(i.expected_interactions for i in high_iocs)) + + # validating first_seen / last_seen dynamically + self.assertEqual(high_item["first_seen"], min(i.first_seen for i in high_iocs).isoformat()) + self.assertEqual(high_item["last_seen"], max(i.last_seen for i in high_iocs).isoformat()) + + # validating honeypots dynamically + expected_honeypots = sorted({hp.name for i in high_iocs for hp in i.general_honeypot.all()}) + self.assertEqual(sorted(high_item["honeypots"]), expected_honeypots) + + def test_200_asn_feed_default_ordering(self): + response = self.client.get(self.url) + self.assertEqual(response.status_code, 200) + results = self._get_results(response) + + # high_asn has ioc_count=2 > low_asn ioc_count=1 + self.assertEqual(str(results[0]["asn"]), self.high_asn) + self.assertEqual(str(results[1]["asn"]), self.low_asn) + + def test_200_asn_feed_ordering_desc_ioc_count(self): + response = self.client.get(self.url + "?ordering=-ioc_count") + self.assertEqual(response.status_code, 200) + results = self._get_results(response) + + self.assertEqual(str(results[0]["asn"]), self.high_asn) + + def test_200_asn_feed_ordering_asc_ioc_count(self): + response = self.client.get(self.url + "?ordering=ioc_count") + self.assertEqual(response.status_code, 200) + results = self._get_results(response) + self.assertEqual(str(results[0]["asn"]), self.low_asn) + + def test_200_asn_feed_ordering_desc_interaction_count(self): + response = self.client.get(self.url + "?ordering=-total_interaction_count") + self.assertEqual(response.status_code, 200) + results = self._get_results(response) + self.assertEqual(str(results[0]["asn"]), self.high_asn) + + def test_200_asn_feed_with_asn_filter(self): + response = self.client.get(self.url + f"?asn={self.high_asn}") + self.assertEqual(response.status_code, 200) + + results = self._get_results(response) + self.assertEqual(len(results), 1) + self.assertEqual(str(results[0]["asn"]), self.high_asn) + + def test_400_asn_feed_invalid_ordering_honeypots(self): + response = self.client.get(self.url + "?ordering=honeypots") + self.assertEqual(response.status_code, 400) + data = response.json() + errors_container = data.get("errors", data) + error_list = errors_container.get("ordering", []) + self.assertTrue(error_list) + error_msg = error_list[0].lower() + self.assertIn("honeypots", error_msg) + self.assertIn("invalid", error_msg) + + def test_400_asn_feed_invalid_ordering_random(self): + response = self.client.get(self.url + "?ordering=xyz123") + self.assertEqual(response.status_code, 400) + data = response.json() + errors_container = data.get("errors", data) + error_list = errors_container.get("ordering", []) + self.assertTrue(error_list) + error_msg = error_list[0].lower() + self.assertIn("xyz123", error_msg) + self.assertIn("invalid", error_msg) + + def test_400_asn_feed_invalid_ordering_model_field_not_in_agg(self): + response = self.client.get(self.url + "?ordering=attack_count") + self.assertEqual(response.status_code, 400) + data = response.json() + errors_container = data.get("errors", data) + error_list = errors_container.get("ordering", []) + self.assertTrue(error_list) + error_msg = error_list[0].lower() + self.assertIn("attack_count", error_msg) + self.assertIn("invalid", error_msg) + + def test_400_asn_feed_ordering_empty_param(self): + response = self.client.get(self.url + "?ordering=") + self.assertEqual(response.status_code, 400) + data = response.json() + errors_container = data.get("errors", data) + error_list = errors_container.get("ordering", []) + self.assertTrue(error_list) + error_msg = error_list[0].lower() + self.assertIn("blank", error_msg) + + def test_asn_feed_ignores_feed_size(self): + response = self.client.get(self.url + "?feed_size=1") + results = response.json() + # aggregation should return all ASNs regardless of feed_size + self.assertEqual(len(results), 2) + + class StatisticsViewTestCase(CustomTestCase): @classmethod def setUpClass(cls):