Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 79 additions & 71 deletions smarttub/api.py
Original file line number Diff line number Diff line change
@@ -1,105 +1,113 @@
import asyncio
import base64
import datetime
from enum import Enum
import json
import logging
import time
from typing import List

import aiohttp
import dateutil.parser
from inflection import underscore
import jwt

logger = logging.getLogger(__name__)


class SmartTub:
"""Interface to the SmartTub API"""

AUTH_AUDIENCE = "https://api.operation-link.com/"
AUTH_URL = "https://smarttub.auth0.com/oauth/token"
AUTH_CLIENT_ID = "dB7Rcp3rfKKh0vHw2uqkwOZmRb5WNjQC"
AUTH_REALM = "Username-Password-Authentication"
AUTH_ACCOUNT_ID_KEY = "http://operation-link.com/account_id"
AUTH_GRANT_TYPE = "http://auth0.com/oauth/grant-type/password-realm"
AUTH_SCOPE = "openid email offline_access User Admin"
"""Interface to the SmartTub API."""

AUTH_URL = "https://api.smarttub.io/idp/signin"
API_BASE = "https://api.smarttub.io"

def __init__(self, session: aiohttp.ClientSession = None):
self.logged_in = False
self._session = session or aiohttp.ClientSession()

async def login(self, username: str, password: str):
"""Authenticate to SmartTub
self._access_token: str | None = None
self._refresh_token: str | None = None
self._id_token: str | None = None
self._token_expires_at: datetime.datetime | None = None
self.account_id: str | None = None
# Store credentials for re-authentication (no refresh endpoint available)
self._username: str | None = None
self._password: str | None = None

async def login(self, username: str, password: str) -> None:
"""Authenticate to SmartTub.

This method must be called before any useful work can be done.

username -- the email address for the SmartTub account
password -- the password for the SmartTub account
"""

# https://auth0.com/docs/api-auth/tutorials/password-grant
r = await self._session.post(
self.AUTH_URL,
json={
"audience": self.AUTH_AUDIENCE,
"client_id": self.AUTH_CLIENT_ID,
"grant_type": self.AUTH_GRANT_TYPE,
"realm": self.AUTH_REALM,
"scope": self.AUTH_SCOPE,
"username": username,
"password": password,
},
)
if r.status == 403:
raise LoginFailed(r.text)

r.raise_for_status()
j = await r.json()

self._set_access_token(j["access_token"])
self.refresh_token = j["refresh_token"]
assert j["token_type"] == "Bearer"

self.account_id = self.access_token_data[self.AUTH_ACCOUNT_ID_KEY]
self.logged_in = True

logger.debug(f"login successful, username={username}")
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
body = {"username": username, "password": password}

async with self._session.post(
self.AUTH_URL, json=body, headers=headers
) as response:
try:
data = await response.json()
except Exception:
text = await response.text()
raise LoginFailed(f"Login failed: {response.status} - {text}")

if response.status != 201:
if isinstance(data, list):
error_msg = ", ".join(str(x) for x in data)
else:
error_msg = data.get("message", "Unknown error")
raise LoginFailed(f"Login failed ({response.status}): {error_msg}")

try:
token_data = data["token"]
self._access_token = token_data["access_token"]
self._refresh_token = token_data.get("refresh_token")
self._id_token = token_data.get("id_token")

# Extract account_id from ID token
if self._id_token:
parts = self._id_token.split(".")
if len(parts) > 1:
payload_b64 = parts[1]
# Fix Base64 padding
padded = payload_b64 + "=" * (-len(payload_b64) % 4)
decoded_bytes = base64.b64decode(padded)
jwt_data = json.loads(decoded_bytes)
self.account_id = jwt_data.get("custom:account_id")

expires_in = token_data.get("expires_in", 86400)
self._token_expires_at = datetime.datetime.now() + datetime.timedelta(
seconds=expires_in
)

# Store credentials for re-authentication when token expires
self._username = username
self._password = password

logger.debug(f"login successful, username={username}")

except KeyError as exc:
raise LoginFailed(
"Login successful but response format was unexpected"
) from exc

@property
def _headers(self):
return {"Authorization": f"Bearer {self.access_token}"}
return {"Authorization": f"Bearer {self._access_token}"}

async def _require_login(self):
if not self.logged_in:
"""Ensure we have a valid access token, re-authenticating if needed."""
if not self._access_token:
raise RuntimeError("not logged in")
if self.token_expires_at <= time.time():
await self._refresh_token()

def _set_access_token(self, token):
self.access_token = token
self.access_token_data = jwt.decode(
self.access_token,
algorithms=["HS256"],
options={"verify_signature": False, "verify": False},
)
self.token_expires_at = self.access_token_data["exp"]

async def _refresh_token(self):
# https://auth0.com/docs/tokens/guides/use-refresh-tokens
r = await self._session.post(
self.AUTH_URL,
json={
"grant_type": "refresh_token",
"client_id": self.AUTH_CLIENT_ID,
"refresh_token": self.refresh_token,
},
)
r.raise_for_status()
j = await r.json()
self._set_access_token(j["access_token"])
logger.debug("token refresh successful")
if self._token_expires_at and datetime.datetime.now() > self._token_expires_at:
# Token expired - re-authenticate using stored credentials
if self._username and self._password:
logger.debug("token expired, re-authenticating")
await self.login(self._username, self._password)
else:
raise RuntimeError("token expired and no credentials available")

async def request(self, method, path, body=None):
"""Generic method for making an authenticated request to the API
Expand Down
129 changes: 94 additions & 35 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import aiohttp
import time
import base64
import datetime
import json

import jwt
import aiohttp
import pytest

import smarttub
Expand All @@ -11,8 +12,32 @@
pytestmark = pytest.mark.asyncio


def make_id_token(account_id: str) -> str:
"""Create a mock ID token with the account_id claim."""
header = base64.urlsafe_b64encode(json.dumps({"alg": "HS256"}).encode()).rstrip(
b"="
)
payload = base64.urlsafe_b64encode(
json.dumps({"custom:account_id": account_id}).encode()
).rstrip(b"=")
signature = base64.urlsafe_b64encode(b"fakesignature").rstrip(b"=")
return f"{header.decode()}.{payload.decode()}.{signature.decode()}"


def make_login_response(account_id: str, expires_in: int = 86400) -> dict:
"""Create a mock login response."""
return {
"token": {
"access_token": "access_token_123",
"refresh_token": "refresh_token_123",
"id_token": make_id_token(account_id),
"expires_in": expires_in,
}
}


@pytest.fixture(name="unauthenticated_api")
async def unauthenticated_api(aresponses):
async def unauthenticated_api():
async with aiohttp.ClientSession() as session:
yield smarttub.SmartTub(session)

Expand All @@ -21,53 +46,81 @@ async def unauthenticated_api(aresponses):
async def api(unauthenticated_api, aresponses):
api = unauthenticated_api
aresponses.add(
response={
"access_token": jwt.encode(
{api.AUTH_ACCOUNT_ID_KEY: ACCOUNT_ID, "exp": time.time() + 3600},
"secret",
),
"token_type": "Bearer",
"refresh_token": "refresh1",
}
response=aresponses.Response(
body=json.dumps(make_login_response(ACCOUNT_ID)),
status=201,
content_type="application/json",
)
)
await api.login("username1", "password1")
return api


async def test_login(api, aresponses):
async def test_login(api):
assert api.account_id == ACCOUNT_ID
assert api.logged_in is True
assert api._access_token == "access_token_123"
assert api._username == "username1"
assert api._password == "password1"


async def test_login_failed(api, aresponses):
aresponses.add(response=aresponses.Response(status=403))
async def test_login_failed_400(unauthenticated_api, aresponses):
aresponses.add(
response=aresponses.Response(
body=json.dumps({"message": "Invalid credentials"}),
status=400,
content_type="application/json",
)
)
with pytest.raises(smarttub.LoginFailed):
await api.login("username", "password")
await unauthenticated_api.login("username", "password")


async def test_refresh_token(api, aresponses):
now = time.time()
api.token_expires_at = now
async def test_login_failed_401(unauthenticated_api, aresponses):
aresponses.add(
response={
"access_token": jwt.encode(
{api.AUTH_ACCOUNT_ID_KEY: ACCOUNT_ID, "exp": now + 3601},
"secret",
),
}
response=aresponses.Response(
body=json.dumps([{"description": "Bad request", "type": "ERROR"}]),
status=401,
content_type="application/json",
)
)
with pytest.raises(smarttub.LoginFailed):
await unauthenticated_api.login("username", "password")


async def test_token_reauth_on_expiry(api, aresponses):
"""Test that we re-authenticate when the token expires."""
# Expire the token
api._token_expires_at = datetime.datetime.now() - datetime.timedelta(seconds=1)

# Mock the re-login response
aresponses.add(
response=aresponses.Response(
body=json.dumps(make_login_response(ACCOUNT_ID)),
status=201,
content_type="application/json",
)
)
aresponses.add(response={"status": "OK"})
# Mock the actual API request
aresponses.add(
response=aresponses.Response(
body=json.dumps({"status": "OK"}),
status=200,
content_type="application/json",
)
)

response = await api.request("GET", "/")
assert api.token_expires_at > now
assert api._token_expires_at > datetime.datetime.now()
assert response.get("status") == "OK"


async def test_get_account(api, aresponses):
aresponses.add(
response={
"id": "id1",
"email": "email1",
}
response=aresponses.Response(
body=json.dumps({"id": "id1", "email": "email1"}),
status=200,
content_type="application/json",
)
)

account = await api.get_account()
Expand All @@ -81,12 +134,18 @@ async def test_api_error(api, aresponses):
await api.get_account()


async def test_not_logged_in(unauthenticated_api, aresponses):
async def test_not_logged_in(unauthenticated_api):
with pytest.raises(RuntimeError):
await unauthenticated_api.request("GET", "/")


async def test_request(api, aresponses):
aresponses.add(response=aresponses.Response(text=None, status=200))
async def test_request_empty_response(api, aresponses):
aresponses.add(
response=aresponses.Response(
body="",
status=200,
headers={"content-length": "0"},
)
)
response = await api.request("GET", "/")
assert response is None