Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions armis_sdk/clients/collectors_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import contextlib
from typing import IO
from typing import AsyncIterator
from typing import Generator
from typing import Union

import httpx
import universalasync

from armis_sdk.core import response_utils
from armis_sdk.core.base_entity_client import BaseEntityClient
from armis_sdk.entities.collector_image import CollectorImage
from armis_sdk.entities.download_progress import DownloadProgress
from armis_sdk.enums.collector_image_type import CollectorImageType


@universalasync.wrap
class CollectorsClient(BaseEntityClient):
# pylint: disable=line-too-long
"""
A client for interacting with Armis collectors.

The primary entity for this client is [CollectorImage][armis_sdk.entities.collector_image.CollectorImage].
"""

async def download_image(
self,
destination: Union[str, IO[bytes]],
image_type: CollectorImageType = CollectorImageType.OVA,
) -> AsyncIterator[DownloadProgress]:
"""Download a collector image to a specified destination path / file.

Args:
destination: The file path or file-like object where the collector image will be saved.
image_type: The type of collector image to download. Defaults to "OVA".

Returns:
An (async) iterator of `DownloadProgress` object.

Example:
```python linenums="1" hl_lines="10 15"
import asyncio

from armis_sdk.clients.collectors_client import CollectorsClient


async def main():
collectors_client = CollectorsClient()

# Download to a path
async for progress in armis_sdk.collectors.download_image("/tmp/collector.ova"):
print(progress.percent)

# Download to a file
with open("/tmp/collector.ova", "wb") as file:
async for progress in armis_sdk.collectors.download_image(file):
print(progress.percent)

asyncio.run(main())
```
Will output:
```python linenums="1"
1%
2%
3%
```
etc.
"""
collector_image = await self.get_image(image_type=image_type)
async with httpx.AsyncClient() as client:
async with client.stream("GET", collector_image.url) as response:
response.raise_for_status()
total_size = int(response.headers.get("Content-Length", "0"))
# pylint: disable-next=contextmanager-generator-missing-cleanup
with self.open_file(destination) as file:
async for chunk in response.aiter_bytes():
file.write(chunk)
yield DownloadProgress(downloaded=file.tell(), total=total_size)

async def get_image(
self, image_type: CollectorImageType = CollectorImageType.OVA
) -> CollectorImage:
"""Get collector image information including download URL and credentials.

Args:
image_type: The type of collector image to retrieve. Defaults to "OVA".

Returns:
A `CollectorImage` object.

Example:
```python linenums="1" hl_lines="8"
import asyncio

from armis_sdk.clients.collectors_client import CollectorsClient


async def main():
collectors_client = CollectorsClient()
print(await collectors_client.get_image(image_type="OVA"))

asyncio.run(main())
```
Will output:
```python linenums="1"
CollectorImage(url="...", ...)
```
"""
async with self._armis_client.client() as client:
response = await client.get(
"/v3/collectors/_image", params={"image_type": image_type.value}
)
data = response_utils.get_data_dict(response)
return CollectorImage.model_validate(data)

@classmethod
@contextlib.contextmanager
def open_file(
cls, destination: Union[str, IO[bytes]]
) -> Generator[IO[bytes], None, None]:
if isinstance(destination, str):
with open(destination, "wb") as file:
yield file
else:
yield destination
3 changes: 3 additions & 0 deletions armis_sdk/core/armis_sdk.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional

from armis_sdk.clients.assets_client import AssetsClient
from armis_sdk.clients.collectors_client import CollectorsClient
from armis_sdk.clients.data_export_client import DataExportClient
from armis_sdk.clients.device_custom_properties_client import (
DeviceCustomPropertiesClient,
Expand All @@ -19,6 +20,7 @@ class ArmisSdk: # pylint: disable=too-few-public-methods
Attributes:
client (ArmisClient): An instance of [ArmisClient][armis_sdk.core.armis_client.ArmisClient]
assets (AssetsClient): An instance of [AssetsClient][armis_sdk.clients.assets_client.AssetsClient]
collectors (CollectorsClient): An instance of [CollectorsClient][armis_sdk.clients.collectors_client.CollectorsClient]
data_export (DataExportClient): An instance of [DataExportClient][armis_sdk.clients.data_export_client.DataExportClient]
device_custom_properties (DeviceCustomPropertiesClient): An instance of [DeviceCustomPropertiesClient][armis_sdk.clients.device_custom_properties_client.DeviceCustomPropertiesClient]
sites (SitesClient): An instance of [SitesClient][armis_sdk.clients.sites_client.SitesClient]
Expand All @@ -42,6 +44,7 @@ async def main():
def __init__(self, credentials: Optional[ClientCredentials] = None):
self.client: ArmisClient = ArmisClient(credentials=credentials)
self.assets: AssetsClient = AssetsClient(self.client)
self.collectors: CollectorsClient = CollectorsClient(self.client)
self.data_export: DataExportClient = DataExportClient(self.client)
self.device_custom_properties: DeviceCustomPropertiesClient = (
DeviceCustomPropertiesClient(self.client)
Expand Down
24 changes: 24 additions & 0 deletions armis_sdk/entities/collector_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import datetime

from pydantic import Field

from armis_sdk.core.base_entity import BaseEntity
from armis_sdk.enums.collector_image_type import CollectorImageType


class CollectorImage(BaseEntity):
"""
An entity that represents the details required to download and run a collector image.
"""

image_type: CollectorImageType = Field(strict=False)
"""The type of the image."""

image_password: str
"""The password for the OS that is encapsulated by the image."""

url: str
"""The temporary, presigned URL from which the OS image file can be downloaded."""

url_expiration_date: datetime.datetime = Field(strict=False)
"""Expiration date of the URL."""
14 changes: 14 additions & 0 deletions armis_sdk/entities/download_progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from armis_sdk.core.base_entity import BaseEntity


class DownloadProgress(BaseEntity):
downloaded: int
"""How much bytes were already downloaded."""

total: int
"""Total number of bytes to download."""

@property
def percent(self) -> str:
"""Percentage of progress."""
return f"{self.downloaded/self.total:.4%}"
36 changes: 36 additions & 0 deletions armis_sdk/enums/collector_image_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from enum import Enum


class CollectorImageType(Enum):
DARWIN_AMD64_BROKER = "DARWIN_AMD64_BROKER"
""""""

DARWIN_ARM64_BROKER = "DARWIN_ARM64_BROKER"
""""""

DEB = "DEB"
""""""

LINUX_AMD64_BROKER = "LINUX_AMD64_BROKER"
""""""

LINUX_ARM64_BROKER = "LINUX_ARM64_BROKER"
""""""

OVA = "OVA"
""""""

QCOW2 = "QCOW2"
""""""

RPM = "RPM"
""""""

VHD = "VHD"
""""""

VHDX = "VHDX"
""""""

WINDOWS_BROKER = "WINDOWS_BROKER"
""""""
1 change: 1 addition & 0 deletions docs/clients/CollectorsClient.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: armis_sdk.clients.collectors_client.CollectorsClient
1 change: 1 addition & 0 deletions docs/entities/CollectorImage.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: armis_sdk.entities.collector_image.CollectorImage
1 change: 1 addition & 0 deletions docs/entities/DownloadProgress.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: armis_sdk.entities.download_progress.DownloadProgress
1 change: 1 addition & 0 deletions docs/enums/CollectorImageType.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: armis_sdk.enums.collector_image_type.CollectorImageType
5 changes: 5 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,24 @@ nav:
- RiskFactor: entities/data_export/RiskFactor.md
- Vulnerability: entities/data_export/Vulnerability.md
- Boundary: entities/Boundary.md
- CollectorImage: entities/CollectorImage.md
- Device: entities/Device.md
- DeviceCustomProperty: entities/DeviceCustomProperty.md
- DownloadProgress: entities/DownloadProgress.md
- NetworkInterface: entities/NetworkInterface.md
- Site: entities/Site.md
- Clients:
- AssetsClient: clients/AssetsClient.md
- CollectorsClient: clients/CollectorsClient.md
- DataExportClient: clients/DataExportClient.md
- DeviceCustomPropertiesClient: clients/DeviceCustomPropertiesClient.md
- SitesClient: clients/SitesClient.md
- Core:
- ArmisClient: core/ArmisClient.md
- ArmisSdk: core/ArmisSdk.md
- Errors: core/errors.md
- Enums:
- CollectorImageType: enums/CollectorImageType.md
- About Armis: about.md

theme:
Expand Down
119 changes: 119 additions & 0 deletions tests/armis_sdk/clients/collectors_client_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import datetime
import tempfile

import pytest_httpx

from armis_sdk.clients.collectors_client import CollectorsClient
from armis_sdk.entities.collector_image import CollectorImage
from armis_sdk.entities.download_progress import DownloadProgress
from armis_sdk.enums.collector_image_type import CollectorImageType

pytest_plugins = ["tests.plugins.auto_setup_plugin"]


async def test_get_image(httpx_mock: pytest_httpx.HTTPXMock):
httpx_mock.add_response(
url="https://api.armis.com/v3/collectors/_image?image_type=OVA",
json={
"image_type": "OVA",
"image_password": "test_password",
"url": "https://example.com/collector.ova",
"url_expiration_date": "2025-12-10T00:00:00",
},
)

collectors_client = CollectorsClient()
collector_image = await collectors_client.get_image()

assert collector_image == CollectorImage(
image_type=CollectorImageType.OVA,
image_password="test_password",
url="https://example.com/collector.ova",
url_expiration_date=datetime.datetime(2025, 12, 10),
)


async def test_get_with_explicit_image_type(httpx_mock: pytest_httpx.HTTPXMock):
httpx_mock.add_response(
url="https://api.armis.com/v3/collectors/_image?image_type=QCOW2",
json={
"image_type": "QCOW2",
"image_password": "test_password_qcow2",
"url": "https://example.com/collector.qcow2",
"url_expiration_date": "2025-12-11T00:00:00",
},
)

collectors_client = CollectorsClient()
collector_image = await collectors_client.get_image(
image_type=CollectorImageType.QCOW2
)

assert collector_image == CollectorImage(
image_type=CollectorImageType.QCOW2,
image_password="test_password_qcow2",
url="https://example.com/collector.qcow2",
url_expiration_date=datetime.datetime(2025, 12, 11),
)


async def test_download_image_to_path(httpx_mock: pytest_httpx.HTTPXMock):
httpx_mock.add_response(
url="https://api.armis.com/v3/collectors/_image?image_type=OVA",
json={
"image_type": "OVA",
"image_password": "test_password",
"url": "https://example.com/collector.ova",
"url_expiration_date": "2025-12-10T00:00:00",
},
)
httpx_mock.add_response(
url="https://example.com/collector.ova",
stream=pytest_httpx.IteratorStream([b"a" * 16384, b"b" * 16384, b"c" * 16383]),
headers={"Content-Length": "49151"},
)

collectors_client = CollectorsClient()
with tempfile.NamedTemporaryFile() as temp_file:
progress_items = [
site async for site in collectors_client.download_image(temp_file.name)
]

assert progress_items == [
DownloadProgress(downloaded=16384, total=49151),
DownloadProgress(downloaded=32768, total=49151),
DownloadProgress(downloaded=49151, total=49151),
]

assert temp_file.read() == b"a" * 16384 + b"b" * 16384 + b"c" * 16383


async def test_download_image_to_file(httpx_mock: pytest_httpx.HTTPXMock):
httpx_mock.add_response(
url="https://api.armis.com/v3/collectors/_image?image_type=OVA",
json={
"image_type": "OVA",
"image_password": "test_password",
"url": "https://example.com/collector.ova",
"url_expiration_date": "2025-12-10T00:00:00",
},
)
httpx_mock.add_response(
url="https://example.com/collector.ova",
stream=pytest_httpx.IteratorStream([b"a" * 16384, b"b" * 16384, b"c" * 16383]),
headers={"Content-Length": "49151"},
)

collectors_client = CollectorsClient()
with tempfile.NamedTemporaryFile() as temp_file:
progress_items = [
site async for site in collectors_client.download_image(temp_file)
]

assert progress_items == [
DownloadProgress(downloaded=16384, total=49151),
DownloadProgress(downloaded=32768, total=49151),
DownloadProgress(downloaded=49151, total=49151),
]
temp_file.seek(0)
assert temp_file.read() == b"a" * 16384 + b"b" * 16384 + b"c" * 16383