From ea6c4add423b4bb1709d724ce9324a82010299d1 Mon Sep 17 00:00:00 2001 From: Rich Date: Thu, 4 Jan 2024 13:39:29 -0800 Subject: [PATCH 1/4] bot: Remove duplicated code. --- .../src/bot/handlers/image.rs | 32 ++++--------------- 1 file changed, 6 insertions(+), 26 deletions(-) diff --git a/crates/stable-diffusion-bot/src/bot/handlers/image.rs b/crates/stable-diffusion-bot/src/bot/handlers/image.rs index 3133c05..370535c 100644 --- a/crates/stable-diffusion-bot/src/bot/handlers/image.rs +++ b/crates/stable-diffusion-bot/src/bot/handlers/image.rs @@ -519,19 +519,8 @@ pub(crate) fn image_schema() -> UpdateHandler { .chain(dptree::filter_map(|g: GenCommands| match g { GenCommands::Gen(s) | GenCommands::G(s) | GenCommands::Generate(s) => Some(s), })) - .branch( - Message::filter_photo() - .chain(filter_map_bot_state()) - .chain(case![BotState::Generate]) - .chain(filter_map_settings()) - .endpoint(handle_image), - ) - .branch( - filter_map_bot_state() - .chain(case![BotState::Generate]) - .chain(filter_map_settings()) - .endpoint(handle_prompt), - ); + .branch(Message::filter_photo().endpoint(handle_image)) + .branch(dptree::endpoint(handle_prompt)); let message_handler = Update::filter_message() .branch( @@ -549,23 +538,11 @@ pub(crate) fn image_schema() -> UpdateHandler { .branch( Message::filter_photo() .map(|msg: Message| msg.caption().map(str::to_string).unwrap_or_default()) - .chain(filter_map_bot_state()) - .chain(case![BotState::Generate]) - .chain(filter_map_settings()) .endpoint(handle_image), ) - .branch( - Message::filter_text() - .chain(filter_map_bot_state()) - .chain(case![BotState::Generate]) - .chain(filter_map_settings()) - .endpoint(handle_prompt), - ); + .branch(Message::filter_text().endpoint(handle_prompt)); let callback_handler = Update::filter_callback_query() - .chain(filter_map_bot_state()) - .chain(case![BotState::Generate]) - .chain(filter_map_settings()) .branch( dptree::filter_map(|q: CallbackQuery| { q.data @@ -580,6 +557,9 @@ pub(crate) fn image_schema() -> UpdateHandler { ); dptree::entry() + .chain(filter_map_bot_state()) + .chain(case![BotState::Generate]) + .chain(filter_map_settings()) .branch(gen_command_handler) .branch(message_handler) .branch(callback_handler) From 1100ed13632451a288e66419b7e0c4b24f664b21 Mon Sep 17 00:00:00 2001 From: Rich Date: Mon, 27 Nov 2023 07:43:23 -0800 Subject: [PATCH 2/4] wip: Refactor and add settings. --- .../src/bot/handlers/image.rs | 106 ++++++--- .../src/bot/handlers/mod.rs | 9 + .../src/bot/handlers/settings.rs | 9 +- crates/stable-diffusion-bot/src/bot/mod.rs | 216 +++++++++++++----- crates/stable-diffusion-bot/src/main.rs | 112 ++++++++- 5 files changed, 360 insertions(+), 92 deletions(-) diff --git a/crates/stable-diffusion-bot/src/bot/handlers/image.rs b/crates/stable-diffusion-bot/src/bot/handlers/image.rs index 370535c..24ebac2 100644 --- a/crates/stable-diffusion-bot/src/bot/handlers/image.rs +++ b/crates/stable-diffusion-bot/src/bot/handlers/image.rs @@ -85,15 +85,28 @@ impl Reply { }) } - pub async fn send(self, bot: &Bot, chat_id: ChatId) -> anyhow::Result<()> { + pub async fn send( + self, + bot: &Bot, + chat_id: ChatId, + cfg: ConfigParameters, + user_id: Option, + ) -> anyhow::Result<()> { match self.images { Photo::Single(image) => { - bot.send_photo(chat_id, InputFile::memory(image)) + let mut message = bot + .send_photo(chat_id, InputFile::memory(image)) .parse_mode(teloxide::types::ParseMode::MarkdownV2) .caption(self.caption) - .reply_markup(keyboard(self.seed)) - .reply_to_message_id(self.source) - .await?; + .reply_to_message_id(self.source); + if !cfg.ui.hide_all_buttons { + message = message.reply_markup(keyboard( + self.seed, + cfg, + user_id.map(Into::into).unwrap_or(chat_id), + )) + } + message.await?; } Photo::Album(images) => { let mut caption = Some(self.caption); @@ -107,13 +120,19 @@ impl Reply { bot.send_media_group(chat_id, input_media) .reply_to_message_id(self.source) .await?; - bot.send_message( - chat_id, - "What would you like to do? Select below, or enter a new prompt.", - ) - .reply_markup(keyboard(self.seed)) - .reply_to_message_id(self.source) - .await?; + if !cfg.ui.hide_all_buttons { + bot.send_message( + chat_id, + "What would you like to do? Select below, or enter a new prompt.", + ) + .reply_markup(keyboard( + self.seed, + cfg, + user_id.map(Into::into).unwrap_or(chat_id), + )) + .reply_to_message_id(self.source) + .await?; + } } } @@ -124,6 +143,11 @@ impl Reply { struct MessageText(String); impl MessageText { + pub fn new(prompt: &str) -> Self { + use teloxide::utils::markdown::escape; + Self(format!("`{}`", escape(prompt))) + } + pub fn new_with_image_params(prompt: &str, infotxt: &dyn ImageParams) -> Self { use teloxide::utils::markdown::escape; @@ -240,12 +264,16 @@ async fn handle_image( resp.params.seed().unwrap_or(-1) }; - let caption = MessageText::try_from(resp.params.as_ref()) - .context("Failed to build caption from response")?; + let caption = if cfg.messages.hide_generation_info { + MessageText::new(&resp.params.prompt().unwrap_or_default()) + } else { + MessageText::try_from(resp.params.as_ref()) + .context("Failed to build caption from response")? + }; Reply::new(caption.0, resp.images, seed, msg.id) .context("Failed to create response!")? - .send(&bot, msg.chat.id) + .send(&bot, msg.chat.id, cfg, msg.from().map(|f| f.id)) .await?; dialogue @@ -298,12 +326,16 @@ async fn handle_prompt( resp.params.seed().unwrap_or(-1) }; - let caption = MessageText::try_from(resp.params.as_ref()) - .context("Failed to build caption from response")?; + let caption = if cfg.messages.hide_generation_info { + MessageText::new(&resp.params.prompt().unwrap_or_default()) + } else { + MessageText::try_from(resp.params.as_ref()) + .context("Failed to build caption from response")? + }; Reply::new(caption.0, resp.images, seed, msg.id) .context("Failed to create response!")? - .send(&bot, msg.chat.id) + .send(&bot, msg.chat.id, cfg, msg.from().map(|f| f.id)) .await?; dialogue @@ -318,17 +350,30 @@ async fn handle_prompt( Ok(()) } -fn keyboard(seed: i64) -> InlineKeyboardMarkup { - let seed_button = if seed == -1 { - InlineKeyboardButton::callback("🎲 Seed", "reuse/-1") +fn keyboard(seed: i64, cfg: ConfigParameters, user: ChatId) -> InlineKeyboardMarkup { + let settings_button = if (!cfg.settings.disable_user_settings || cfg.user_is_admin(&user)) + && !cfg.ui.hide_settings_button + { + vec![InlineKeyboardButton::callback("⚙️ Settings", "settings")] + } else { + vec![] + }; + let seed_button = if cfg.ui.hide_reuse_button { + vec![] + } else if seed == -1 { + vec![InlineKeyboardButton::callback("🎲 Seed", "reuse/-1")] + } else { + vec![InlineKeyboardButton::callback( + "♻️ Seed", + format!("reuse/{seed}"), + )] + }; + let rerun_button = if cfg.ui.hide_rerun_button { + vec![] } else { - InlineKeyboardButton::callback("♻️ Seed", format!("reuse/{seed}")) + vec![InlineKeyboardButton::callback("🔄 Rerun", "rerun")] }; - InlineKeyboardMarkup::new([[ - InlineKeyboardButton::callback("🔄 Rerun", "rerun"), - seed_button, - InlineKeyboardButton::callback("⚙️ Settings", "settings"), - ]]) + InlineKeyboardMarkup::new([[rerun_button, seed_button, settings_button].concat()]) } #[instrument(skip_all)] @@ -437,6 +482,7 @@ async fn handle_reuse( (mut txt2img, mut img2img): (Box, Box), q: CallbackQuery, seed: i64, + cfg: ConfigParameters, ) -> anyhow::Result<()> { let message = if let Some(message) = q.message { message @@ -505,7 +551,11 @@ async fn handle_reuse( warn!("Failed to answer set seed callback query: {}", e) } bot.edit_message_reply_markup(chat_id, id) - .reply_markup(keyboard(-1)) + .reply_markup(keyboard( + -1, + cfg, + message.from().map(|f| f.id.into()).unwrap_or(chat_id), + )) .send() .await?; } diff --git a/crates/stable-diffusion-bot/src/bot/handlers/mod.rs b/crates/stable-diffusion-bot/src/bot/handlers/mod.rs index 4807156..4676cb6 100644 --- a/crates/stable-diffusion-bot/src/bot/handlers/mod.rs +++ b/crates/stable-diffusion-bot/src/bot/handlers/mod.rs @@ -92,6 +92,14 @@ pub(crate) fn filter_map_settings() -> UpdateHandler { }) } +pub(crate) fn admin_filter() -> UpdateHandler { + dptree::filter(|cfg: ConfigParameters, upd: Update| { + upd.user() + .map(|user| cfg.user_is_admin(&user.id.into())) + .unwrap_or_default() + }) +} + pub(crate) fn auth_filter() -> UpdateHandler { dptree::filter(|cfg: ConfigParameters, upd: Update| { upd.chat() @@ -135,6 +143,7 @@ pub(crate) fn authenticated_command_handler() -> UpdateHandler { .branch(image_schema()) } +#[cfg(any())] #[cfg(test)] mod tests { use super::*; diff --git a/crates/stable-diffusion-bot/src/bot/handlers/settings.rs b/crates/stable-diffusion-bot/src/bot/handlers/settings.rs index 51bd190..e73bfa2 100644 --- a/crates/stable-diffusion-bot/src/bot/handlers/settings.rs +++ b/crates/stable-diffusion-bot/src/bot/handlers/settings.rs @@ -13,7 +13,7 @@ use tracing::{error, warn}; use crate::{bot::ConfigParameters, BotState}; -use super::{filter_map_bot_state, filter_map_settings, DiffusionDialogue, State}; +use super::{admin_filter, filter_map_bot_state, filter_map_settings, DiffusionDialogue, State}; /// BotCommands for settings. #[derive(BotCommands, Clone)] @@ -584,6 +584,10 @@ pub(crate) fn settings_schema() -> UpdateHandler { .branch(filter_settings_state().endpoint(handle_invalid_setting_value)); dptree::entry() + .branch( + dptree::filter(|cfg: ConfigParameters| cfg.settings.disable_user_settings) + .chain(admin_filter()), + ) .branch(settings_command_handler()) .branch(message_handler) .branch(callback_handler) @@ -596,6 +600,7 @@ mod tests { Img2ImgApi, Img2ImgApiError, Img2ImgParams, Response, Txt2ImgApi, Txt2ImgApiError, Txt2ImgParams, }; + #[cfg(any())] use stable_diffusion_api::{Img2ImgRequest, Txt2ImgRequest}; use teloxide::types::{UpdateKind, User}; @@ -786,6 +791,7 @@ mod tests { } } + #[cfg(any())] #[tokio::test] async fn test_map_settings_default() { assert!(matches!( @@ -823,6 +829,7 @@ mod tests { )); } + #[cfg(any())] #[tokio::test] async fn test_map_settings_ready() { let txt2img = Txt2ImgParams { diff --git a/crates/stable-diffusion-bot/src/bot/mod.rs b/crates/stable-diffusion-bot/src/bot/mod.rs index fb99fd7..c5b6199 100644 --- a/crates/stable-diffusion-bot/src/bot/mod.rs +++ b/crates/stable-diffusion-bot/src/bot/mod.rs @@ -216,12 +216,34 @@ impl StableDiffusionBot { } } +#[derive(Clone, Debug, Default)] +pub struct SettingsParameters { + pub disable_user_settings: bool, +} + +#[derive(Clone, Debug, Default)] +pub struct UiParameters { + pub hide_rerun_button: bool, + pub hide_reuse_button: bool, + pub hide_settings_button: bool, + pub hide_all_buttons: bool, +} + +#[derive(Clone, Debug, Default)] +pub struct MessageParameters { + pub hide_generation_info: bool, +} + #[derive(Clone, Debug)] pub(crate) struct ConfigParameters { allowed_users: HashSet, txt2img_api: Box, img2img_api: Box, allow_all_users: bool, + administrator_users: HashSet, + settings: SettingsParameters, + ui: UiParameters, + messages: MessageParameters, } impl ConfigParameters { @@ -229,6 +251,11 @@ impl ConfigParameters { pub fn chat_is_allowed(&self, chat_id: &ChatId) -> bool { self.allow_all_users || self.allowed_users.contains(chat_id) } + + /// Checks whether a user is an admin by the config. + pub fn user_is_admin(&self, chat_id: &ChatId) -> bool { + self.administrator_users.contains(chat_id) + } } #[derive(Serialize, Deserialize, Default, Debug)] @@ -256,6 +283,10 @@ pub struct StableDiffusionBotBuilder { comfyui_img2img_prompt_file: Option, comfyui_txt2img_prompt_file: Option, allow_all_users: bool, + administrator_users: Vec, + settings: SettingsParameters, + ui: UiParameters, + messages: MessageParameters, } impl StableDiffusionBotBuilder { @@ -278,6 +309,10 @@ impl StableDiffusionBotBuilder { api_type, comfyui_txt2img_prompt_file: None, comfyui_img2img_prompt_file: None, + administrator_users: Vec::new(), + settings: SettingsParameters::default(), + ui: UiParameters::default(), + messages: MessageParameters::default(), } } @@ -296,7 +331,7 @@ impl StableDiffusionBotBuilder { /// # let sd_api_url = "http://localhost:7860".to_string(); /// # let allow_all_users = false; /// # tokio_test::block_on(async { - /// let builder = StableDiffusionBotBuilder::new(api_key, allowed_users, sd_api_url, allow_all_users); + /// let builder = StableDiffusionBotBuilder::new(api_key, allowed_users, sd_api_url); /// /// let bot = builder.db_path(Some("database.sqlite".to_string())).build().await.unwrap(); /// # }); @@ -384,6 +419,118 @@ impl StableDiffusionBotBuilder { self } + /// Builder function that sets whether all users are allowed to use the bot. + pub fn allow_all_users(mut self, allow_all_users: bool) -> Self { + self.allow_all_users = allow_all_users; + self + } + + /// Builder function that sets the settings configuration for the bot. + /// + /// # Arguments + /// + /// * `settings` - A `SettingsParameters` struct representing the settings configuration for the bot. + /// + /// # Examples + /// + /// ``` + /// # use stable_diffusion_bot::StableDiffusionBotBuilder; + /// # use stable_diffusion_bot::SettingsParameters; + /// # + /// # let api_key = "api_key".to_string(); + /// # let allowed_users = vec![1, 2, 3]; + /// # let sd_api_url = "http://localhost:7860".to_string(); + /// # let allow_all_users = false; + /// # tokio_test::block_on(async { + /// let builder = StableDiffusionBotBuilder::new(api_key, allowed_users, sd_api_url); + /// + /// let bot = builder.configure_settings(SettingsParameters::default()).build().await.unwrap(); + /// # }); + /// ``` + pub fn configure_settings(mut self, settings: SettingsParameters) -> Self { + self.settings = settings; + self + } + + /// Builder function that sets the UI configuration for the bot. + /// + /// # Arguments + /// + /// * `ui` - A `UiParameters` struct representing the UI configuration for the bot. + /// + /// # Examples + /// + /// ``` + /// # use stable_diffusion_bot::StableDiffusionBotBuilder; + /// # use stable_diffusion_bot::UiParameters; + /// # + /// # let api_key = "api_key".to_string(); + /// # let allowed_users = vec![1, 2, 3]; + /// # let sd_api_url = "http://localhost:7860".to_string(); + /// # let allow_all_users = false; + /// # tokio_test::block_on(async { + /// let builder = StableDiffusionBotBuilder::new(api_key, allowed_users, sd_api_url); + /// + /// let bot = builder.configure_ui(UiParameters::default()).build().await.unwrap(); + /// # }); + /// ``` + pub fn configure_ui(mut self, ui: UiParameters) -> Self { + self.ui = ui; + self + } + + /// Builder function that sets the messages configuration for the bot. + /// + /// # Arguments + /// + /// * `messages` - A `MessageParameters` struct representing the messages configuration for the bot. + /// + /// # Examples + /// + /// ``` + /// # use stable_diffusion_bot::StableDiffusionBotBuilder; + /// # use stable_diffusion_bot::MessageParameters; + /// # + /// # let api_key = "api_key".to_string(); + /// # let allowed_users = vec![1, 2, 3]; + /// # let sd_api_url = "http://localhost:7860".to_string(); + /// # let allow_all_users = false; + /// # tokio_test::block_on(async { + /// let builder = StableDiffusionBotBuilder::new(api_key, allowed_users, sd_api_url); + /// + /// let bot = builder.configure_messages(MessageParameters::default()).build().await.unwrap(); + /// # }); + /// ``` + pub fn configure_messages(mut self, messages: MessageParameters) -> Self { + self.messages = messages; + self + } + + /// Builder function that sets the administrator users for the bot. + /// + /// # Arguments + /// + /// * `administrator_users` - A `Vec` representing the administrator users for the bot. + /// + /// # Examples + /// + /// ``` + /// # use stable_diffusion_bot::StableDiffusionBotBuilder; + /// # + /// # let api_key = "api_key".to_string(); + /// # let allowed_users = vec![1, 2, 3]; + /// # let sd_api_url = "http://localhost:7860".to_string(); + /// # let allow_all_users = false; + /// # tokio_test::block_on(async { + /// let builder = StableDiffusionBotBuilder::new(api_key, allowed_users, sd_api_url); + /// + /// let bot = builder.administrator_users(vec![1, 2, 3]).build().await.unwrap(); + /// # }); + pub fn administrator_users(mut self, administrator_users: Vec) -> Self { + self.administrator_users = administrator_users; + self + } + /// Consumes the builder and builds a `StableDiffusionBot` instance. /// /// # Examples @@ -414,6 +561,7 @@ impl StableDiffusionBotBuilder { let bot = Bot::new(self.api_key.clone()); let allowed_users = self.allowed_users.into_iter().map(ChatId).collect(); + let administrator_users = self.administrator_users.into_iter().map(ChatId).collect(); let client = reqwest::Client::new(); @@ -509,6 +657,10 @@ impl StableDiffusionBotBuilder { txt2img_api, img2img_api, allow_all_users: self.allow_all_users, + administrator_users, + settings: self.settings, + ui: self.ui, + messages: self.messages, }; Ok(StableDiffusionBot { @@ -573,7 +725,7 @@ mod tests { bot.config.allowed_users, allowed_users.into_iter().map(ChatId).collect() ); - assert_eq!(bot.config.allow_all_users, allow_all_users); + assert!(!bot.config.allow_all_users); assert_eq!( bot.config .txt2img_api @@ -632,7 +784,7 @@ mod tests { bot.config.allowed_users, allowed_users.into_iter().map(ChatId).collect() ); - assert_eq!(bot.config.allow_all_users, allow_all_users); + assert!(!bot.config.allow_all_users); assert_eq!( bot.config .txt2img_api @@ -652,62 +804,4 @@ mod tests { default_img2img(img2img_settings) ); } - - #[tokio::test] - async fn test_stable_diffusion_bot_no_user_defaults() { - let api_key = "api_key".to_string(); - let sd_api_url = "http://localhost:7860".to_string(); - let allowed_users = vec![1, 2, 3]; - let allow_all_users = false; - let api_type = ApiType::StableDiffusionWebUi; - - let builder = StableDiffusionBotBuilder::new( - api_key.clone(), - allowed_users.clone(), - sd_api_url.clone(), - api_type, - allow_all_users, - ); - - let bot = builder - .txt2img_defaults(Txt2ImgRequest { - width: Some(1024), - height: Some(768), - ..Default::default() - }) - .img2img_defaults(Img2ImgRequest { - width: Some(1024), - height: Some(768), - ..Default::default() - }) - .clear_txt2img_defaults() - .clear_img2img_defaults() - .build() - .await - .unwrap(); - - assert_eq!( - bot.config.allowed_users, - allowed_users.into_iter().map(ChatId).collect() - ); - assert_eq!(bot.config.allow_all_users, allow_all_users); - assert_eq!( - bot.config - .txt2img_api - .as_any() - .downcast_ref::() - .unwrap() - .txt2img_defaults, - default_txt2img(Txt2ImgRequest::default()) - ); - assert_eq!( - bot.config - .img2img_api - .as_any() - .downcast_ref::() - .unwrap() - .img2img_defaults, - default_img2img(Img2ImgRequest::default()) - ); - } } diff --git a/crates/stable-diffusion-bot/src/main.rs b/crates/stable-diffusion-bot/src/main.rs index 1eebd0d..e15c9f8 100644 --- a/crates/stable-diffusion-bot/src/main.rs +++ b/crates/stable-diffusion-bot/src/main.rs @@ -6,8 +6,11 @@ use figment::{ }; use serde::{Deserialize, Serialize}; use stable_diffusion_api::{Img2ImgRequest, Txt2ImgRequest}; -use stable_diffusion_bot::{ApiType, ComfyUIConfig, StableDiffusionBotBuilder}; -use tracing::metadata::LevelFilter; +use stable_diffusion_bot::{ + ApiType, ComfyUIConfig, MessageParameters, SettingsParameters, StableDiffusionBotBuilder, + UiParameters, +}; +use tracing::{info, metadata::LevelFilter}; use tracing_subscriber::{prelude::*, EnvFilter}; use std::path::PathBuf; @@ -41,6 +44,66 @@ struct Config { img2img: Option, allow_all_users: Option, comfyui: Option, + administrator_users: Option>, + settings: Option, + ui: Option, + messages: Option, + start_message: Option, +} + +#[derive(Serialize, Deserialize, Default, Debug)] +struct Settings { + disable_user_settings: Option, +} + +#[derive(Serialize, Deserialize, Default, Debug)] +struct Ui { + hide_rerun_button: Option, + hide_reuse_button: Option, + hide_settings_button: Option, + hide_all_buttons: Option, +} + +#[derive(Serialize, Deserialize, Default, Debug)] +struct Messages { + hide_generation_info: Option, + generation_info: Option>, +} + +#[derive(Serialize, Deserialize, Debug)] +#[serde(rename_all = "snake_case")] +enum GenerationInfo { + Prompt, + AllPrompts, + NegativePrompt, + AllNegativePrompts, + Seed, + AllSeeds, + Subseed, + AllSubseeds, + SubseedStrength, + Width, + Height, + SamplerName, + CfgScale, + Steps, + BatchSize, + RestoreFaces, + FaceRestorationModel, + SdModelName, + SdModelHash, + SdVaeName, + SdVaeHash, + SeedResizeFromW, + SeedResizeFromH, + DenoisingStrength, + ExtraGenerationParams, + IndexOfFirstImage, + Infotexts, + Styles, + JobTimestamp, + ClipSkip, + IsUsingInpaintingConditioning, } #[tokio::main] @@ -83,6 +146,47 @@ async fn main() -> anyhow::Result<()> { .extract() .context("Invalid configuration")?; + info!(?config); + + let settings = SettingsParameters { + disable_user_settings: config + .settings + .and_then(|s| s.disable_user_settings) + .unwrap_or_default(), + }; + + let mut ui = UiParameters { + hide_rerun_button: config + .ui + .as_ref() + .and_then(|s| s.hide_rerun_button) + .unwrap_or_default(), + hide_reuse_button: config + .ui + .as_ref() + .and_then(|s| s.hide_reuse_button) + .unwrap_or_default(), + hide_settings_button: config + .ui + .as_ref() + .and_then(|s| s.hide_settings_button) + .unwrap_or_default(), + hide_all_buttons: config + .ui + .and_then(|s| s.hide_all_buttons) + .unwrap_or_default(), + }; + + ui.hide_all_buttons |= ui.hide_rerun_button && ui.hide_reuse_button && ui.hide_settings_button; + + let messages = MessageParameters { + hide_generation_info: config + .messages + .as_ref() + .and_then(|s| s.hide_generation_info) + .unwrap_or_default(), + }; + StableDiffusionBotBuilder::new( config.api_key, config.allowed_users, @@ -94,6 +198,10 @@ async fn main() -> anyhow::Result<()> { .txt2img_defaults(config.txt2img.unwrap_or_default()) .img2img_defaults(config.img2img.unwrap_or_default()) .comfyui_config(config.comfyui.unwrap_or_default()) + .administrator_users(config.administrator_users.unwrap_or_default()) + .configure_settings(settings) + .configure_ui(ui) + .configure_messages(messages) .build() .await? .run() From 375a3560800bb95f801cf6789672bd92040c21f0 Mon Sep 17 00:00:00 2001 From: Rich Date: Mon, 27 Nov 2023 20:24:20 -0800 Subject: [PATCH 3/4] make a big huge mess implementing these settings. --- .../src/bot/handlers/mod.rs | 38 +++++++++---------- .../src/bot/handlers/settings.rs | 13 ++++--- crates/stable-diffusion-bot/src/bot/mod.rs | 28 +++++++++++++- 3 files changed, 54 insertions(+), 25 deletions(-) diff --git a/crates/stable-diffusion-bot/src/bot/handlers/mod.rs b/crates/stable-diffusion-bot/src/bot/handlers/mod.rs index 4676cb6..17eecde 100644 --- a/crates/stable-diffusion-bot/src/bot/handlers/mod.rs +++ b/crates/stable-diffusion-bot/src/bot/handlers/mod.rs @@ -19,12 +19,12 @@ pub(crate) use settings::*; #[derive(BotCommands, Clone)] #[command(rename_rule = "lowercase", description = "Simple commands")] pub(crate) enum UnauthenticatedCommands { - #[command(description = "show help message.")] - Help, #[command(description = "start the bot.")] Start, #[command(description = "change settings.")] Settings, + #[command(description = "show help message.")] + Help, } pub(crate) async fn unauthenticated_commands_handler( @@ -35,17 +35,25 @@ pub(crate) async fn unauthenticated_commands_handler( cmd: UnauthenticatedCommands, dialogue: DiffusionDialogue, ) -> anyhow::Result<()> { + let chat = msg.chat.id; + let user = msg.from().unwrap().id; let text = match cmd { UnauthenticatedCommands::Help => { - if cfg.chat_is_allowed(&msg.chat.id) - || cfg.chat_is_allowed(&msg.from().unwrap().id.into()) - { - format!( - "{}\n\n{}\n\n{}", - UnauthenticatedCommands::descriptions(), - SettingsCommands::descriptions(), - GenCommands::descriptions() - ) + if cfg.chat_is_allowed(&chat) || cfg.chat_is_allowed(&user.into()) { + let unauth = vec![UnauthenticatedCommands::descriptions()]; + let settings = + if !cfg.settings.disable_user_settings || cfg.user_is_admin(&user.into()) { + vec![SettingsCommands::descriptions()] + } else { + vec![] + }; + let gen = vec![GenCommands::descriptions()]; + [unauth, settings, gen] + .iter() + .flatten() + .map(ToString::to_string) + .collect::>() + .join("\n\n") } else if msg.chat.is_group() || msg.chat.is_supergroup() { UnauthenticatedCommands::descriptions() .username_from_me(&me) @@ -92,14 +100,6 @@ pub(crate) fn filter_map_settings() -> UpdateHandler { }) } -pub(crate) fn admin_filter() -> UpdateHandler { - dptree::filter(|cfg: ConfigParameters, upd: Update| { - upd.user() - .map(|user| cfg.user_is_admin(&user.id.into())) - .unwrap_or_default() - }) -} - pub(crate) fn auth_filter() -> UpdateHandler { dptree::filter(|cfg: ConfigParameters, upd: Update| { upd.chat() diff --git a/crates/stable-diffusion-bot/src/bot/handlers/settings.rs b/crates/stable-diffusion-bot/src/bot/handlers/settings.rs index e73bfa2..15e7320 100644 --- a/crates/stable-diffusion-bot/src/bot/handlers/settings.rs +++ b/crates/stable-diffusion-bot/src/bot/handlers/settings.rs @@ -13,7 +13,7 @@ use tracing::{error, warn}; use crate::{bot::ConfigParameters, BotState}; -use super::{admin_filter, filter_map_bot_state, filter_map_settings, DiffusionDialogue, State}; +use super::{filter_map_bot_state, filter_map_settings, DiffusionDialogue, State}; /// BotCommands for settings. #[derive(BotCommands, Clone)] @@ -584,10 +584,13 @@ pub(crate) fn settings_schema() -> UpdateHandler { .branch(filter_settings_state().endpoint(handle_invalid_setting_value)); dptree::entry() - .branch( - dptree::filter(|cfg: ConfigParameters| cfg.settings.disable_user_settings) - .chain(admin_filter()), - ) + .chain(dptree::filter(|cfg: ConfigParameters, upd: Update| { + !cfg.settings.disable_user_settings + || upd + .user() + .map(|user| cfg.user_is_admin(&user.id.into())) + .unwrap_or_default() + })) .branch(settings_command_handler()) .branch(message_handler) .branch(callback_handler) diff --git a/crates/stable-diffusion-bot/src/bot/mod.rs b/crates/stable-diffusion-bot/src/bot/mod.rs index c5b6199..dd4875d 100644 --- a/crates/stable-diffusion-bot/src/bot/mod.rs +++ b/crates/stable-diffusion-bot/src/bot/mod.rs @@ -107,6 +107,7 @@ impl StableDiffusionBot { /// Creates an UpdateHandler for the bot fn schema() -> UpdateHandler { Self::enter::, _>() + .chain(dptree::filter_async(Self::set_my_commands)) .branch(unauth_command_handler()) .branch(authenticated_command_handler()) } @@ -184,6 +185,32 @@ impl StableDiffusionBot { ) } + async fn set_my_commands(bot: Bot, cfg: ConfigParameters, upd: Update) -> bool { + let (chat, user) = match (upd.chat(), upd.user()) { + (Some(c), Some(u)) => (c, u), + _ => return true, + }; + let mut commands = UnauthenticatedCommands::bot_commands(); + if !cfg.settings.disable_user_settings || cfg.user_is_admin(&user.id.into()) { + commands.extend(SettingsCommands::bot_commands()); + } + commands.extend(GenCommands::bot_commands()); + let scope = if chat.id == user.id.into() { + teloxide::types::BotCommandScope::Chat { + chat_id: teloxide::types::Recipient::Id(chat.id), + } + } else { + teloxide::types::BotCommandScope::ChatMember { + chat_id: teloxide::types::Recipient::Id(chat.id), + user_id: user.id, + } + }; + if let Err(e) = bot.set_my_commands(commands).scope(scope).await { + error!("Failed to set commands: {e:?}"); + } + true + } + /// Runs the StableDiffusionBot pub async fn run(self) -> anyhow::Result<()> { let StableDiffusionBot { @@ -193,7 +220,6 @@ impl StableDiffusionBot { } = self; let mut commands = UnauthenticatedCommands::bot_commands(); - commands.extend(SettingsCommands::bot_commands()); commands.extend(GenCommands::bot_commands()); bot.set_my_commands(commands) .scope(teloxide::types::BotCommandScope::Default) From 06af0dc6e954e61989668ae227156b3f734c61ca Mon Sep 17 00:00:00 2001 From: Rich Date: Thu, 4 Jan 2024 16:45:42 -0800 Subject: [PATCH 4/4] clean things up a bit --- .../src/bot/handlers/image.rs | 100 ++++++++++++------ crates/stable-diffusion-bot/src/bot/mod.rs | 58 ++++++++++ 2 files changed, 124 insertions(+), 34 deletions(-) diff --git a/crates/stable-diffusion-bot/src/bot/handlers/image.rs b/crates/stable-diffusion-bot/src/bot/handlers/image.rs index 24ebac2..621b2a8 100644 --- a/crates/stable-diffusion-bot/src/bot/handlers/image.rs +++ b/crates/stable-diffusion-bot/src/bot/handlers/image.rs @@ -62,11 +62,32 @@ impl Photo { } } +struct ReplyConfig { + disable_user_settings: bool, + hide_all_buttons: bool, + hide_settings_button: bool, + hide_reuse_button: bool, + hide_rerun_button: bool, +} + +impl From for ReplyConfig { + fn from(config: ConfigParameters) -> Self { + Self { + disable_user_settings: config.settings.disable_user_settings, + hide_all_buttons: config.ui.hide_all_buttons, + hide_settings_button: config.ui.hide_settings_button, + hide_reuse_button: config.ui.hide_reuse_button, + hide_rerun_button: config.ui.hide_rerun_button, + } + } +} + struct Reply { caption: String, images: Photo, source: MessageId, seed: i64, + config: ReplyConfig, } impl Reply { @@ -75,6 +96,7 @@ impl Reply { images: Vec>, seed: i64, source: MessageId, + config: ReplyConfig, ) -> anyhow::Result { let images = Photo::album(images)?; Ok(Self { @@ -82,16 +104,11 @@ impl Reply { images, source, seed, + config, }) } - pub async fn send( - self, - bot: &Bot, - chat_id: ChatId, - cfg: ConfigParameters, - user_id: Option, - ) -> anyhow::Result<()> { + pub async fn send(self, bot: &Bot, chat_id: ChatId) -> anyhow::Result<()> { match self.images { Photo::Single(image) => { let mut message = bot @@ -99,12 +116,8 @@ impl Reply { .parse_mode(teloxide::types::ParseMode::MarkdownV2) .caption(self.caption) .reply_to_message_id(self.source); - if !cfg.ui.hide_all_buttons { - message = message.reply_markup(keyboard( - self.seed, - cfg, - user_id.map(Into::into).unwrap_or(chat_id), - )) + if !self.config.hide_all_buttons { + message = message.reply_markup(keyboard(self.seed, self.config.into())) } message.await?; } @@ -120,16 +133,12 @@ impl Reply { bot.send_media_group(chat_id, input_media) .reply_to_message_id(self.source) .await?; - if !cfg.ui.hide_all_buttons { + if !self.config.hide_all_buttons { bot.send_message( chat_id, "What would you like to do? Select below, or enter a new prompt.", ) - .reply_markup(keyboard( - self.seed, - cfg, - user_id.map(Into::into).unwrap_or(chat_id), - )) + .reply_markup(keyboard(self.seed, self.config.into())) .reply_to_message_id(self.source) .await?; } @@ -271,9 +280,9 @@ async fn handle_image( .context("Failed to build caption from response")? }; - Reply::new(caption.0, resp.images, seed, msg.id) + Reply::new(caption.0, resp.images, seed, msg.id, cfg.into()) .context("Failed to create response!")? - .send(&bot, msg.chat.id, cfg, msg.from().map(|f| f.id)) + .send(&bot, msg.chat.id) .await?; dialogue @@ -333,9 +342,9 @@ async fn handle_prompt( .context("Failed to build caption from response")? }; - Reply::new(caption.0, resp.images, seed, msg.id) + Reply::new(caption.0, resp.images, seed, msg.id, cfg.into()) .context("Failed to create response!")? - .send(&bot, msg.chat.id, cfg, msg.from().map(|f| f.id)) + .send(&bot, msg.chat.id) .await?; dialogue @@ -350,15 +359,42 @@ async fn handle_prompt( Ok(()) } -fn keyboard(seed: i64, cfg: ConfigParameters, user: ChatId) -> InlineKeyboardMarkup { - let settings_button = if (!cfg.settings.disable_user_settings || cfg.user_is_admin(&user)) - && !cfg.ui.hide_settings_button - { +struct KeyboardConfig { + disable_user_settings: bool, + hide_settings_button: bool, + hide_reuse_button: bool, + hide_rerun_button: bool, +} + +impl From for KeyboardConfig { + fn from(config: ReplyConfig) -> Self { + Self { + disable_user_settings: config.disable_user_settings, + hide_settings_button: config.hide_settings_button, + hide_reuse_button: config.hide_reuse_button, + hide_rerun_button: config.hide_rerun_button, + } + } +} + +impl From for KeyboardConfig { + fn from(config: ConfigParameters) -> Self { + Self { + disable_user_settings: config.settings.disable_user_settings, + hide_settings_button: config.ui.hide_settings_button, + hide_reuse_button: config.ui.hide_reuse_button, + hide_rerun_button: config.ui.hide_rerun_button, + } + } +} + +fn keyboard(seed: i64, cfg: KeyboardConfig) -> InlineKeyboardMarkup { + let settings_button = if (!cfg.disable_user_settings) && !cfg.hide_settings_button { vec![InlineKeyboardButton::callback("⚙️ Settings", "settings")] } else { vec![] }; - let seed_button = if cfg.ui.hide_reuse_button { + let seed_button = if cfg.hide_reuse_button { vec![] } else if seed == -1 { vec![InlineKeyboardButton::callback("🎲 Seed", "reuse/-1")] @@ -368,7 +404,7 @@ fn keyboard(seed: i64, cfg: ConfigParameters, user: ChatId) -> InlineKeyboardMar format!("reuse/{seed}"), )] }; - let rerun_button = if cfg.ui.hide_rerun_button { + let rerun_button = if cfg.hide_rerun_button { vec![] } else { vec![InlineKeyboardButton::callback("🔄 Rerun", "rerun")] @@ -551,11 +587,7 @@ async fn handle_reuse( warn!("Failed to answer set seed callback query: {}", e) } bot.edit_message_reply_markup(chat_id, id) - .reply_markup(keyboard( - -1, - cfg, - message.from().map(|f| f.id.into()).unwrap_or(chat_id), - )) + .reply_markup(keyboard(-1, cfg.into())) .send() .await?; } diff --git a/crates/stable-diffusion-bot/src/bot/mod.rs b/crates/stable-diffusion-bot/src/bot/mod.rs index dd4875d..933da73 100644 --- a/crates/stable-diffusion-bot/src/bot/mod.rs +++ b/crates/stable-diffusion-bot/src/bot/mod.rs @@ -830,4 +830,62 @@ mod tests { default_img2img(img2img_settings) ); } + + #[tokio::test] + async fn test_stable_diffusion_bot_no_user_defaults() { + let api_key = "api_key".to_string(); + let sd_api_url = "http://localhost:7860".to_string(); + let allowed_users = vec![1, 2, 3]; + let allow_all_users = false; + let api_type = ApiType::StableDiffusionWebUi; + + let builder = StableDiffusionBotBuilder::new( + api_key.clone(), + allowed_users.clone(), + sd_api_url.clone(), + api_type, + allow_all_users, + ); + + let bot = builder + .txt2img_defaults(Txt2ImgRequest { + width: Some(1024), + height: Some(768), + ..Default::default() + }) + .img2img_defaults(Img2ImgRequest { + width: Some(1024), + height: Some(768), + ..Default::default() + }) + .clear_txt2img_defaults() + .clear_img2img_defaults() + .build() + .await + .unwrap(); + + assert_eq!( + bot.config.allowed_users, + allowed_users.into_iter().map(ChatId).collect() + ); + assert_eq!(bot.config.allow_all_users, allow_all_users); + assert_eq!( + bot.config + .txt2img_api + .as_any() + .downcast_ref::() + .unwrap() + .txt2img_defaults, + default_txt2img(Txt2ImgRequest::default()) + ); + assert_eq!( + bot.config + .img2img_api + .as_any() + .downcast_ref::() + .unwrap() + .img2img_defaults, + default_img2img(Img2ImgRequest::default()) + ); + } }