diff --git a/README.md b/README.md index 70a728f..fd4f2e9 100644 --- a/README.md +++ b/README.md @@ -20,9 +20,10 @@ To run the bot, follow these steps: ```sh pip install -r requirements.txt ``` -3. Run the bot: +3. Create a config file __config.yml__ for secrets like database access and bot token, see [config_example.yml](config_example.yml) as reference. +4. Run the bot: ```sh - python main.py --bot_token YOUR_SECRET_BOT_TOKEN --bot_backup $(pwd)/backup --log_level 10 + python main.py --log_level 10 ``` ## Adding a new location diff --git a/bot.py b/bot.py index 6232fce..96c4444 100644 --- a/bot.py +++ b/bot.py @@ -1,5 +1,4 @@ -import asyncio - +import yaml from telegram import ReplyKeyboardMarkup, Update, ReplyKeyboardRemove from telegram.ext import (CommandHandler, MessageHandler, Application, filters, ConversationHandler, CallbackContext, ContextTypes) @@ -15,15 +14,12 @@ class PlotBot: - def __init__(self, - token, - station_config, - db=None, - admin_id=None, - ecmwf=None): + def __init__(self, config_file, station_config, db=None, ecmwf=None): - self._admin_id = admin_id - self.app = Application.builder().token(token).build() + self._config = yaml.safe_load(open(config_file)) + self._admin_ids = self._config['bot'].get('admin_ids', []) + self.app = Application.builder().token( + self._config['bot']['token']).build() self._db = db self._ecmwf = ecmwf self._station_names = sorted( @@ -176,7 +172,7 @@ async def _error(self, update: Update, context: CallbackContext): async def _stats(self, update: Update, context: CallbackContext): user_id = update.message.chat_id - if user_id != self._admin_id: + if user_id not in self._admin_ids: await update.message.reply_text( "You are not authorized to view stats.") return diff --git a/config_example.yml b/config_example.yml new file mode 100644 index 0000000..cb145d9 --- /dev/null +++ b/config_example.yml @@ -0,0 +1,11 @@ +# store this file as config.yml in the same directory as your bot script +db: + host: "your-database-host.com" + user: "your_database_user" + password: "your_secure_password" + database: "your_database_name" + port: 5432 + table_suffix: "your_table_suffix" # e.g. "dev", "prod", etc., to differentiate environments within the same database +bot: + token: "123456789:ABCDEF1234567890abcdef1234567890" + admin_ids: [123456789, 987654321] # List of admin user IDs, can use admin-only /stats command \ No newline at end of file diff --git a/main.py b/main.py index 06e0fa0..3e70316 100644 --- a/main.py +++ b/main.py @@ -13,16 +13,6 @@ def main(): parser = argparse.ArgumentParser() - parser.add_argument('--bot_token', \ - dest='bot_token', \ - type=str, \ - help='unique token of bot (KEEP PRIVATE!)') - - parser.add_argument('--admin_id', - dest='admin_id', \ - type=int, \ - help='Telegram ID of the admin') - parser.add_argument( '--log_level', dest='log_level', @@ -41,16 +31,14 @@ def main(): ecmwf = EcmwfApi(station_config) - db = Database('config.yml') + config_file = 'config.yml' + + db = Database(config_file) - bot = PlotBot(args.bot_token, - station_config, - admin_id=args.admin_id, - db=db, - ecmwf=ecmwf) + bot = PlotBot(config_file, station_config, db=db, ecmwf=ecmwf) bot.start() - # we should not be here + # we only end up here if the bot had an error sys.exit(1) diff --git a/test/test_bot.py b/test/test_bot.py index 4695781..ea355f0 100644 --- a/test/test_bot.py +++ b/test/test_bot.py @@ -9,14 +9,26 @@ from bot import PlotBot -@pytest.fixture +@pytest.fixture(scope="module") def bot(station_config): - from bot import PlotBot - token = '9999999999:BBBBBBBRBBBBBBBBBBBBBBBBBBBBBBBBBBB' - return PlotBot(token, station_config) + config_file = "test_config.yml" + config = { + "bot": { + "token": '9999999999:BBBBBBBRBBBBBBBBBBBBBBBBBBBBBBBBBBB', + "admin_ids": [123456789, 987654321], + } + } + # Write the config to a file + with open(config_file, "w") as f: + yaml.dump(config, f) + yield PlotBot(config_file, station_config) -@pytest.fixture + # Clear the test config file after the test + os.remove(config_file) + + +@pytest.fixture(scope="module") def station_config(): stations = """ - name: Zürich