diff --git a/bot.py b/bot.py index 96c4444..a973341 100644 --- a/bot.py +++ b/bot.py @@ -178,13 +178,29 @@ async def _stats(self, update: Update, context: CallbackContext): return activity_summary_text = [] - activity_summary_text.append('*Activity summary*') + + # Query activity summary for each interval + activity_summary_text.append('*Activity*') for interval in VALID_SUMMARY_INTERVALS: activity_summary = self._db.get_activity_summary(interval) activity_summary_text.append(f"_{interval.lower()}_") for activity in activity_summary: activity_summary_text.append(f"- {activity}") activity_summary_text.append('') + + # Query subscription summary + activity_summary_text.append('*Subscriptions*') + subscription_summary = self._db.get_subscription_summary() + for station in subscription_summary: + activity_summary_text.append(f"- {station}") + + # Query unique subscribers + unique_subscribers = self._db.count_unique_subscribers() + activity_summary_text.append('') + activity_summary_text.append( + f"_Unique subscribers: {unique_subscribers}_") + activity_summary_text.append('') + activity_summary_text = "\n".join(activity_summary_text) await update.message.reply_markdown(activity_summary_text) diff --git a/db.py b/db.py index 8e067b1..a44a78a 100644 --- a/db.py +++ b/db.py @@ -23,34 +23,6 @@ def _get_db_connection(self): cursor_factory=DictCursor) return connection - def _create_tables(self): - connection = self._get_db_connection() - try: - with connection.cursor() as cursor: - cursor.execute(f""" - CREATE TABLE IF NOT EXISTS activity_{self._table_suffix} ( - id SERIAL PRIMARY KEY, - activity_type VARCHAR(50) NOT NULL, - user_id VARCHAR(50) NOT NULL, - station TEXT, - timestamp TIMESTAMP NOT NULL - ) - """) - - # subscriptions table - cursor.execute(f""" - CREATE TABLE IF NOT EXISTS subscriptions_{self._table_suffix} ( - id SERIAL PRIMARY KEY, - station TEXT NOT NULL, - user_id VARCHAR(50) NOT NULL - ) - """) - connection.commit() - except Exception as e: - logger.error(f"{e} while creating tables") - finally: - connection.close() - def _create_tables(self): connection = self._get_db_connection() try: @@ -129,6 +101,23 @@ def get_subscriptions_by_station(self, station) -> list[int]: else: return [] + def count_unique_subscribers(self) -> list[int]: + sql = f""" + SELECT DISTINCT user_id + FROM subscriptions_{self._table_suffix} + """ + subscribers = self._select(sql) + return len([subscriber['user_id'] for subscriber in subscribers]) + + def get_subscription_summary(self) -> list[str]: + stations = self.stations_with_subscribers() + summary = [] + for station in stations: + summary.append( + f"{station}: {len(self.get_subscriptions_by_station(station))}" + ) + return summary + def _select_with_values(self, sql, values): connection = self._get_db_connection() try: diff --git a/ecmwf.py b/ecmwf.py index de97ef8..4b456e4 100644 --- a/ecmwf.py +++ b/ecmwf.py @@ -122,13 +122,10 @@ def _get_with_request(self, link, raise_on_error=True): if not result.ok and raise_on_error: raise ValueError('Request failed for {}'.format(get)) else: - if result.status_code == 403: - raise ValueError('403 Forbidden for {}'.format(get)) - else: - try: - return result.json() - except json.decoder.JSONDecodeError: - raise ValueError('JSONDecodeError for {}'.format(get)) + try: + return result.json() + except json.decoder.JSONDecodeError: + raise ValueError('JSONDecodeError for {}'.format(get)) def _get_API_data_for_epsgram(self, station, diff --git a/test/test_db.py b/test/test_db.py index a3d472e..30fda1c 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -107,3 +107,25 @@ def test_get_subscriptions_by_station(db_instance): # Test for a station with no subscriptions users = db_instance.get_subscriptions_by_station("station3") assert users == [] + + +def test_subscription_summary(db_instance): + # Add test data + db_instance.add_subscription("station1", 12345) + db_instance.add_subscription("station2", 67890) + db_instance.add_subscription("station1", 54321) + + summary = db_instance.get_subscription_summary() + + assert summary == ["station1: 2", "station2: 1"] + + +def test_get_unique_subscribers(db_instance): + # Add test data + db_instance.add_subscription("station1", 12345) + db_instance.add_subscription("station2", 67890) + db_instance.add_subscription("station1", 54321) + + unique_subscribers = db_instance.count_unique_subscribers() + + assert unique_subscribers == 3