diff --git a/armis_sdk/clients/collectors_client.py b/armis_sdk/clients/collectors_client.py new file mode 100644 index 0000000..2c48cd5 --- /dev/null +++ b/armis_sdk/clients/collectors_client.py @@ -0,0 +1,125 @@ +import contextlib +from typing import IO +from typing import AsyncIterator +from typing import Generator +from typing import Union + +import httpx +import universalasync + +from armis_sdk.core import response_utils +from armis_sdk.core.base_entity_client import BaseEntityClient +from armis_sdk.entities.collector_image import CollectorImage +from armis_sdk.entities.download_progress import DownloadProgress +from armis_sdk.enums.collector_image_type import CollectorImageType + + +@universalasync.wrap +class CollectorsClient(BaseEntityClient): + # pylint: disable=line-too-long + """ + A client for interacting with Armis collectors. + + The primary entity for this client is [CollectorImage][armis_sdk.entities.collector_image.CollectorImage]. + """ + + async def download_image( + self, + destination: Union[str, IO[bytes]], + image_type: CollectorImageType = CollectorImageType.OVA, + ) -> AsyncIterator[DownloadProgress]: + """Download a collector image to a specified destination path / file. + + Args: + destination: The file path or file-like object where the collector image will be saved. + image_type: The type of collector image to download. Defaults to "OVA". + + Returns: + An (async) iterator of `DownloadProgress` object. + + Example: + ```python linenums="1" hl_lines="10 15" + import asyncio + + from armis_sdk.clients.collectors_client import CollectorsClient + + + async def main(): + collectors_client = CollectorsClient() + + # Download to a path + async for progress in armis_sdk.collectors.download_image("/tmp/collector.ova"): + print(progress.percent) + + # Download to a file + with open("/tmp/collector.ova", "wb") as file: + async for progress in armis_sdk.collectors.download_image(file): + print(progress.percent) + + asyncio.run(main()) + ``` + Will output: + ```python linenums="1" + 1% + 2% + 3% + ``` + etc. + """ + collector_image = await self.get_image(image_type=image_type) + async with httpx.AsyncClient() as client: + async with client.stream("GET", collector_image.url) as response: + response.raise_for_status() + total_size = int(response.headers.get("Content-Length", "0")) + # pylint: disable-next=contextmanager-generator-missing-cleanup + with self.open_file(destination) as file: + async for chunk in response.aiter_bytes(): + file.write(chunk) + yield DownloadProgress(downloaded=file.tell(), total=total_size) + + async def get_image( + self, image_type: CollectorImageType = CollectorImageType.OVA + ) -> CollectorImage: + """Get collector image information including download URL and credentials. + + Args: + image_type: The type of collector image to retrieve. Defaults to "OVA". + + Returns: + A `CollectorImage` object. + + Example: + ```python linenums="1" hl_lines="8" + import asyncio + + from armis_sdk.clients.collectors_client import CollectorsClient + + + async def main(): + collectors_client = CollectorsClient() + print(await collectors_client.get_image(image_type="OVA")) + + asyncio.run(main()) + ``` + Will output: + ```python linenums="1" + CollectorImage(url="...", ...) + ``` + """ + async with self._armis_client.client() as client: + response = await client.get( + "/v3/collectors/_image", params={"image_type": image_type.value} + ) + data = response_utils.get_data_dict(response) + return CollectorImage.model_validate(data) + + @classmethod + @contextlib.contextmanager + def open_file( + cls, destination: Union[str, IO[bytes]] + ) -> Generator[IO[bytes], None, None]: + if isinstance(destination, str): + with open(destination, "wb") as file: + yield file + else: + yield destination diff --git a/armis_sdk/core/armis_sdk.py b/armis_sdk/core/armis_sdk.py index 7212aaa..e4653d4 100644 --- a/armis_sdk/core/armis_sdk.py +++ b/armis_sdk/core/armis_sdk.py @@ -1,6 +1,7 @@ from typing import Optional from armis_sdk.clients.assets_client import AssetsClient +from armis_sdk.clients.collectors_client import CollectorsClient from armis_sdk.clients.data_export_client import DataExportClient from armis_sdk.clients.device_custom_properties_client import ( DeviceCustomPropertiesClient, @@ -19,6 +20,7 @@ class ArmisSdk: # pylint: disable=too-few-public-methods Attributes: client (ArmisClient): An instance of [ArmisClient][armis_sdk.core.armis_client.ArmisClient] assets (AssetsClient): An instance of [AssetsClient][armis_sdk.clients.assets_client.AssetsClient] + collectors (CollectorsClient): An instance of [CollectorsClient][armis_sdk.clients.collectors_client.CollectorsClient] data_export (DataExportClient): An instance of [DataExportClient][armis_sdk.clients.data_export_client.DataExportClient] device_custom_properties (DeviceCustomPropertiesClient): An instance of [DeviceCustomPropertiesClient][armis_sdk.clients.device_custom_properties_client.DeviceCustomPropertiesClient] sites (SitesClient): An instance of [SitesClient][armis_sdk.clients.sites_client.SitesClient] @@ -42,6 +44,7 @@ async def main(): def __init__(self, credentials: Optional[ClientCredentials] = None): self.client: ArmisClient = ArmisClient(credentials=credentials) self.assets: AssetsClient = AssetsClient(self.client) + self.collectors: CollectorsClient = CollectorsClient(self.client) self.data_export: DataExportClient = DataExportClient(self.client) self.device_custom_properties: DeviceCustomPropertiesClient = ( DeviceCustomPropertiesClient(self.client) diff --git a/armis_sdk/entities/collector_image.py b/armis_sdk/entities/collector_image.py new file mode 100644 index 0000000..bbb0b3c --- /dev/null +++ b/armis_sdk/entities/collector_image.py @@ -0,0 +1,24 @@ +import datetime + +from pydantic import Field + +from armis_sdk.core.base_entity import BaseEntity +from armis_sdk.enums.collector_image_type import CollectorImageType + + +class CollectorImage(BaseEntity): + """ + An entity that represents the details required to download and run a collector image. + """ + + image_type: CollectorImageType = Field(strict=False) + """The type of the image.""" + + image_password: str + """The password for the OS that is encapsulated by the image.""" + + url: str + """The temporary, presigned URL from which the OS image file can be downloaded.""" + + url_expiration_date: datetime.datetime = Field(strict=False) + """Expiration date of the URL.""" diff --git a/armis_sdk/entities/download_progress.py b/armis_sdk/entities/download_progress.py new file mode 100644 index 0000000..88ba17a --- /dev/null +++ b/armis_sdk/entities/download_progress.py @@ -0,0 +1,14 @@ +from armis_sdk.core.base_entity import BaseEntity + + +class DownloadProgress(BaseEntity): + downloaded: int + """How much bytes were already downloaded.""" + + total: int + """Total number of bytes to download.""" + + @property + def percent(self) -> str: + """Percentage of progress.""" + return f"{self.downloaded/self.total:.4%}" diff --git a/armis_sdk/enums/collector_image_type.py b/armis_sdk/enums/collector_image_type.py new file mode 100644 index 0000000..d6c73b5 --- /dev/null +++ b/armis_sdk/enums/collector_image_type.py @@ -0,0 +1,36 @@ +from enum import Enum + + +class CollectorImageType(Enum): + DARWIN_AMD64_BROKER = "DARWIN_AMD64_BROKER" + """""" + + DARWIN_ARM64_BROKER = "DARWIN_ARM64_BROKER" + """""" + + DEB = "DEB" + """""" + + LINUX_AMD64_BROKER = "LINUX_AMD64_BROKER" + """""" + + LINUX_ARM64_BROKER = "LINUX_ARM64_BROKER" + """""" + + OVA = "OVA" + """""" + + QCOW2 = "QCOW2" + """""" + + RPM = "RPM" + """""" + + VHD = "VHD" + """""" + + VHDX = "VHDX" + """""" + + WINDOWS_BROKER = "WINDOWS_BROKER" + """""" diff --git a/docs/clients/CollectorsClient.md b/docs/clients/CollectorsClient.md new file mode 100644 index 0000000..bdb5570 --- /dev/null +++ b/docs/clients/CollectorsClient.md @@ -0,0 +1 @@ +::: armis_sdk.clients.collectors_client.CollectorsClient diff --git a/docs/entities/CollectorImage.md b/docs/entities/CollectorImage.md new file mode 100644 index 0000000..e9a6677 --- /dev/null +++ b/docs/entities/CollectorImage.md @@ -0,0 +1 @@ +::: armis_sdk.entities.collector_image.CollectorImage diff --git a/docs/entities/DownloadProgress.md b/docs/entities/DownloadProgress.md new file mode 100644 index 0000000..6f5fd36 --- /dev/null +++ b/docs/entities/DownloadProgress.md @@ -0,0 +1 @@ +::: armis_sdk.entities.download_progress.DownloadProgress diff --git a/docs/enums/CollectorImageType.md b/docs/enums/CollectorImageType.md new file mode 100644 index 0000000..6c351b3 --- /dev/null +++ b/docs/enums/CollectorImageType.md @@ -0,0 +1 @@ +::: armis_sdk.enums.collector_image_type.CollectorImageType diff --git a/mkdocs.yml b/mkdocs.yml index 66fa2d7..db56f97 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -13,12 +13,15 @@ nav: - RiskFactor: entities/data_export/RiskFactor.md - Vulnerability: entities/data_export/Vulnerability.md - Boundary: entities/Boundary.md + - CollectorImage: entities/CollectorImage.md - Device: entities/Device.md - DeviceCustomProperty: entities/DeviceCustomProperty.md + - DownloadProgress: entities/DownloadProgress.md - NetworkInterface: entities/NetworkInterface.md - Site: entities/Site.md - Clients: - AssetsClient: clients/AssetsClient.md + - CollectorsClient: clients/CollectorsClient.md - DataExportClient: clients/DataExportClient.md - DeviceCustomPropertiesClient: clients/DeviceCustomPropertiesClient.md - SitesClient: clients/SitesClient.md @@ -26,6 +29,8 @@ nav: - ArmisClient: core/ArmisClient.md - ArmisSdk: core/ArmisSdk.md - Errors: core/errors.md + - Enums: + - CollectorImageType: enums/CollectorImageType.md - About Armis: about.md theme: diff --git a/tests/armis_sdk/clients/collectors_client_test.py b/tests/armis_sdk/clients/collectors_client_test.py new file mode 100644 index 0000000..61ea7f9 --- /dev/null +++ b/tests/armis_sdk/clients/collectors_client_test.py @@ -0,0 +1,119 @@ +import datetime +import tempfile + +import pytest_httpx + +from armis_sdk.clients.collectors_client import CollectorsClient +from armis_sdk.entities.collector_image import CollectorImage +from armis_sdk.entities.download_progress import DownloadProgress +from armis_sdk.enums.collector_image_type import CollectorImageType + +pytest_plugins = ["tests.plugins.auto_setup_plugin"] + + +async def test_get_image(httpx_mock: pytest_httpx.HTTPXMock): + httpx_mock.add_response( + url="https://api.armis.com/v3/collectors/_image?image_type=OVA", + json={ + "image_type": "OVA", + "image_password": "test_password", + "url": "https://example.com/collector.ova", + "url_expiration_date": "2025-12-10T00:00:00", + }, + ) + + collectors_client = CollectorsClient() + collector_image = await collectors_client.get_image() + + assert collector_image == CollectorImage( + image_type=CollectorImageType.OVA, + image_password="test_password", + url="https://example.com/collector.ova", + url_expiration_date=datetime.datetime(2025, 12, 10), + ) + + +async def test_get_with_explicit_image_type(httpx_mock: pytest_httpx.HTTPXMock): + httpx_mock.add_response( + url="https://api.armis.com/v3/collectors/_image?image_type=QCOW2", + json={ + "image_type": "QCOW2", + "image_password": "test_password_qcow2", + "url": "https://example.com/collector.qcow2", + "url_expiration_date": "2025-12-11T00:00:00", + }, + ) + + collectors_client = CollectorsClient() + collector_image = await collectors_client.get_image( + image_type=CollectorImageType.QCOW2 + ) + + assert collector_image == CollectorImage( + image_type=CollectorImageType.QCOW2, + image_password="test_password_qcow2", + url="https://example.com/collector.qcow2", + url_expiration_date=datetime.datetime(2025, 12, 11), + ) + + +async def test_download_image_to_path(httpx_mock: pytest_httpx.HTTPXMock): + httpx_mock.add_response( + url="https://api.armis.com/v3/collectors/_image?image_type=OVA", + json={ + "image_type": "OVA", + "image_password": "test_password", + "url": "https://example.com/collector.ova", + "url_expiration_date": "2025-12-10T00:00:00", + }, + ) + httpx_mock.add_response( + url="https://example.com/collector.ova", + stream=pytest_httpx.IteratorStream([b"a" * 16384, b"b" * 16384, b"c" * 16383]), + headers={"Content-Length": "49151"}, + ) + + collectors_client = CollectorsClient() + with tempfile.NamedTemporaryFile() as temp_file: + progress_items = [ + site async for site in collectors_client.download_image(temp_file.name) + ] + + assert progress_items == [ + DownloadProgress(downloaded=16384, total=49151), + DownloadProgress(downloaded=32768, total=49151), + DownloadProgress(downloaded=49151, total=49151), + ] + + assert temp_file.read() == b"a" * 16384 + b"b" * 16384 + b"c" * 16383 + + +async def test_download_image_to_file(httpx_mock: pytest_httpx.HTTPXMock): + httpx_mock.add_response( + url="https://api.armis.com/v3/collectors/_image?image_type=OVA", + json={ + "image_type": "OVA", + "image_password": "test_password", + "url": "https://example.com/collector.ova", + "url_expiration_date": "2025-12-10T00:00:00", + }, + ) + httpx_mock.add_response( + url="https://example.com/collector.ova", + stream=pytest_httpx.IteratorStream([b"a" * 16384, b"b" * 16384, b"c" * 16383]), + headers={"Content-Length": "49151"}, + ) + + collectors_client = CollectorsClient() + with tempfile.NamedTemporaryFile() as temp_file: + progress_items = [ + site async for site in collectors_client.download_image(temp_file) + ] + + assert progress_items == [ + DownloadProgress(downloaded=16384, total=49151), + DownloadProgress(downloaded=32768, total=49151), + DownloadProgress(downloaded=49151, total=49151), + ] + temp_file.seek(0) + assert temp_file.read() == b"a" * 16384 + b"b" * 16384 + b"c" * 16383