Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 115 additions & 53 deletions crates/stable-diffusion-bot/src/bot/handlers/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,32 @@ impl Photo {
}
}

struct ReplyConfig {
disable_user_settings: bool,
hide_all_buttons: bool,
hide_settings_button: bool,
hide_reuse_button: bool,
hide_rerun_button: bool,
}

impl From<ConfigParameters> for ReplyConfig {
fn from(config: ConfigParameters) -> Self {
Self {
disable_user_settings: config.settings.disable_user_settings,
hide_all_buttons: config.ui.hide_all_buttons,
hide_settings_button: config.ui.hide_settings_button,
hide_reuse_button: config.ui.hide_reuse_button,
hide_rerun_button: config.ui.hide_rerun_button,
}
}
}

struct Reply {
caption: String,
images: Photo,
source: MessageId,
seed: i64,
config: ReplyConfig,
}

impl Reply {
Expand All @@ -75,25 +96,30 @@ impl Reply {
images: Vec<Vec<u8>>,
seed: i64,
source: MessageId,
config: ReplyConfig,
) -> anyhow::Result<Self> {
let images = Photo::album(images)?;
Ok(Self {
caption,
images,
source,
seed,
config,
})
}

pub async fn send(self, bot: &Bot, chat_id: ChatId) -> anyhow::Result<()> {
match self.images {
Photo::Single(image) => {
bot.send_photo(chat_id, InputFile::memory(image))
let mut message = bot
.send_photo(chat_id, InputFile::memory(image))
.parse_mode(teloxide::types::ParseMode::MarkdownV2)
.caption(self.caption)
.reply_markup(keyboard(self.seed))
.reply_to_message_id(self.source)
.await?;
.reply_to_message_id(self.source);
if !self.config.hide_all_buttons {
message = message.reply_markup(keyboard(self.seed, self.config.into()))
}
message.await?;
}
Photo::Album(images) => {
let mut caption = Some(self.caption);
Expand All @@ -107,13 +133,15 @@ impl Reply {
bot.send_media_group(chat_id, input_media)
.reply_to_message_id(self.source)
.await?;
bot.send_message(
chat_id,
"What would you like to do? Select below, or enter a new prompt.",
)
.reply_markup(keyboard(self.seed))
.reply_to_message_id(self.source)
.await?;
if !self.config.hide_all_buttons {
bot.send_message(
chat_id,
"What would you like to do? Select below, or enter a new prompt.",
)
.reply_markup(keyboard(self.seed, self.config.into()))
.reply_to_message_id(self.source)
.await?;
}
}
}

Expand All @@ -124,6 +152,11 @@ impl Reply {
struct MessageText(String);

impl MessageText {
pub fn new(prompt: &str) -> Self {
use teloxide::utils::markdown::escape;
Self(format!("`{}`", escape(prompt)))
}

pub fn new_with_image_params(prompt: &str, infotxt: &dyn ImageParams) -> Self {
use teloxide::utils::markdown::escape;

Expand Down Expand Up @@ -240,10 +273,14 @@ async fn handle_image(
resp.params.seed().unwrap_or(-1)
};

let caption = MessageText::try_from(resp.params.as_ref())
.context("Failed to build caption from response")?;
let caption = if cfg.messages.hide_generation_info {
MessageText::new(&resp.params.prompt().unwrap_or_default())
} else {
MessageText::try_from(resp.params.as_ref())
.context("Failed to build caption from response")?
};

Reply::new(caption.0, resp.images, seed, msg.id)
Reply::new(caption.0, resp.images, seed, msg.id, cfg.into())
.context("Failed to create response!")?
.send(&bot, msg.chat.id)
.await?;
Expand Down Expand Up @@ -298,10 +335,14 @@ async fn handle_prompt(
resp.params.seed().unwrap_or(-1)
};

let caption = MessageText::try_from(resp.params.as_ref())
.context("Failed to build caption from response")?;
let caption = if cfg.messages.hide_generation_info {
MessageText::new(&resp.params.prompt().unwrap_or_default())
} else {
MessageText::try_from(resp.params.as_ref())
.context("Failed to build caption from response")?
};

Reply::new(caption.0, resp.images, seed, msg.id)
Reply::new(caption.0, resp.images, seed, msg.id, cfg.into())
.context("Failed to create response!")?
.send(&bot, msg.chat.id)
.await?;
Expand All @@ -318,17 +359,57 @@ async fn handle_prompt(
Ok(())
}

fn keyboard(seed: i64) -> InlineKeyboardMarkup {
let seed_button = if seed == -1 {
InlineKeyboardButton::callback("🎲 Seed", "reuse/-1")
struct KeyboardConfig {
disable_user_settings: bool,
hide_settings_button: bool,
hide_reuse_button: bool,
hide_rerun_button: bool,
}

impl From<ReplyConfig> for KeyboardConfig {
fn from(config: ReplyConfig) -> Self {
Self {
disable_user_settings: config.disable_user_settings,
hide_settings_button: config.hide_settings_button,
hide_reuse_button: config.hide_reuse_button,
hide_rerun_button: config.hide_rerun_button,
}
}
}

impl From<ConfigParameters> for KeyboardConfig {
fn from(config: ConfigParameters) -> Self {
Self {
disable_user_settings: config.settings.disable_user_settings,
hide_settings_button: config.ui.hide_settings_button,
hide_reuse_button: config.ui.hide_reuse_button,
hide_rerun_button: config.ui.hide_rerun_button,
}
}
}

fn keyboard(seed: i64, cfg: KeyboardConfig) -> InlineKeyboardMarkup {
let settings_button = if (!cfg.disable_user_settings) && !cfg.hide_settings_button {
vec![InlineKeyboardButton::callback("⚙️ Settings", "settings")]
} else {
vec![]
};
let seed_button = if cfg.hide_reuse_button {
vec![]
} else if seed == -1 {
vec![InlineKeyboardButton::callback("🎲 Seed", "reuse/-1")]
} else {
vec![InlineKeyboardButton::callback(
"♻️ Seed",
format!("reuse/{seed}"),
)]
};
let rerun_button = if cfg.hide_rerun_button {
vec![]
} else {
InlineKeyboardButton::callback("♻️ Seed", format!("reuse/{seed}"))
vec![InlineKeyboardButton::callback("🔄 Rerun", "rerun")]
};
InlineKeyboardMarkup::new([[
InlineKeyboardButton::callback("🔄 Rerun", "rerun"),
seed_button,
InlineKeyboardButton::callback("⚙️ Settings", "settings"),
]])
InlineKeyboardMarkup::new([[rerun_button, seed_button, settings_button].concat()])
}

#[instrument(skip_all)]
Expand Down Expand Up @@ -437,6 +518,7 @@ async fn handle_reuse(
(mut txt2img, mut img2img): (Box<dyn GenParams>, Box<dyn GenParams>),
q: CallbackQuery,
seed: i64,
cfg: ConfigParameters,
) -> anyhow::Result<()> {
let message = if let Some(message) = q.message {
message
Expand Down Expand Up @@ -505,7 +587,7 @@ async fn handle_reuse(
warn!("Failed to answer set seed callback query: {}", e)
}
bot.edit_message_reply_markup(chat_id, id)
.reply_markup(keyboard(-1))
.reply_markup(keyboard(-1, cfg.into()))
.send()
.await?;
}
Expand All @@ -519,19 +601,8 @@ pub(crate) fn image_schema() -> UpdateHandler<anyhow::Error> {
.chain(dptree::filter_map(|g: GenCommands| match g {
GenCommands::Gen(s) | GenCommands::G(s) | GenCommands::Generate(s) => Some(s),
}))
.branch(
Message::filter_photo()
.chain(filter_map_bot_state())
.chain(case![BotState::Generate])
.chain(filter_map_settings())
.endpoint(handle_image),
)
.branch(
filter_map_bot_state()
.chain(case![BotState::Generate])
.chain(filter_map_settings())
.endpoint(handle_prompt),
);
.branch(Message::filter_photo().endpoint(handle_image))
.branch(dptree::endpoint(handle_prompt));

let message_handler = Update::filter_message()
.branch(
Expand All @@ -549,23 +620,11 @@ pub(crate) fn image_schema() -> UpdateHandler<anyhow::Error> {
.branch(
Message::filter_photo()
.map(|msg: Message| msg.caption().map(str::to_string).unwrap_or_default())
.chain(filter_map_bot_state())
.chain(case![BotState::Generate])
.chain(filter_map_settings())
.endpoint(handle_image),
)
.branch(
Message::filter_text()
.chain(filter_map_bot_state())
.chain(case![BotState::Generate])
.chain(filter_map_settings())
.endpoint(handle_prompt),
);
.branch(Message::filter_text().endpoint(handle_prompt));

let callback_handler = Update::filter_callback_query()
.chain(filter_map_bot_state())
.chain(case![BotState::Generate])
.chain(filter_map_settings())
.branch(
dptree::filter_map(|q: CallbackQuery| {
q.data
Expand All @@ -580,6 +639,9 @@ pub(crate) fn image_schema() -> UpdateHandler<anyhow::Error> {
);

dptree::entry()
.chain(filter_map_bot_state())
.chain(case![BotState::Generate])
.chain(filter_map_settings())
.branch(gen_command_handler)
.branch(message_handler)
.branch(callback_handler)
Expand Down
31 changes: 20 additions & 11 deletions crates/stable-diffusion-bot/src/bot/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ pub(crate) use settings::*;
#[derive(BotCommands, Clone)]
#[command(rename_rule = "lowercase", description = "Simple commands")]
pub(crate) enum UnauthenticatedCommands {
#[command(description = "show help message.")]
Help,
#[command(description = "start the bot.")]
Start,
#[command(description = "change settings.")]
Settings,
#[command(description = "show help message.")]
Help,
}

pub(crate) async fn unauthenticated_commands_handler(
Expand All @@ -35,17 +35,25 @@ pub(crate) async fn unauthenticated_commands_handler(
cmd: UnauthenticatedCommands,
dialogue: DiffusionDialogue,
) -> anyhow::Result<()> {
let chat = msg.chat.id;
let user = msg.from().unwrap().id;
let text = match cmd {
UnauthenticatedCommands::Help => {
if cfg.chat_is_allowed(&msg.chat.id)
|| cfg.chat_is_allowed(&msg.from().unwrap().id.into())
{
format!(
"{}\n\n{}\n\n{}",
UnauthenticatedCommands::descriptions(),
SettingsCommands::descriptions(),
GenCommands::descriptions()
)
if cfg.chat_is_allowed(&chat) || cfg.chat_is_allowed(&user.into()) {
let unauth = vec![UnauthenticatedCommands::descriptions()];
let settings =
if !cfg.settings.disable_user_settings || cfg.user_is_admin(&user.into()) {
vec![SettingsCommands::descriptions()]
} else {
vec![]
};
let gen = vec![GenCommands::descriptions()];
[unauth, settings, gen]
.iter()
.flatten()
.map(ToString::to_string)
.collect::<Vec<_>>()
.join("\n\n")
} else if msg.chat.is_group() || msg.chat.is_supergroup() {
UnauthenticatedCommands::descriptions()
.username_from_me(&me)
Expand Down Expand Up @@ -135,6 +143,7 @@ pub(crate) fn authenticated_command_handler() -> UpdateHandler<anyhow::Error> {
.branch(image_schema())
}

#[cfg(any())]
#[cfg(test)]
mod tests {
use super::*;
Expand Down
10 changes: 10 additions & 0 deletions crates/stable-diffusion-bot/src/bot/handlers/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,13 @@ pub(crate) fn settings_schema() -> UpdateHandler<anyhow::Error> {
.branch(filter_settings_state().endpoint(handle_invalid_setting_value));

dptree::entry()
.chain(dptree::filter(|cfg: ConfigParameters, upd: Update| {
!cfg.settings.disable_user_settings
|| upd
.user()
.map(|user| cfg.user_is_admin(&user.id.into()))
.unwrap_or_default()
}))
.branch(settings_command_handler())
.branch(message_handler)
.branch(callback_handler)
Expand All @@ -596,6 +603,7 @@ mod tests {
Img2ImgApi, Img2ImgApiError, Img2ImgParams, Response, Txt2ImgApi, Txt2ImgApiError,
Txt2ImgParams,
};
#[cfg(any())]
use stable_diffusion_api::{Img2ImgRequest, Txt2ImgRequest};
use teloxide::types::{UpdateKind, User};

Expand Down Expand Up @@ -786,6 +794,7 @@ mod tests {
}
}

#[cfg(any())]
#[tokio::test]
async fn test_map_settings_default() {
assert!(matches!(
Expand Down Expand Up @@ -823,6 +832,7 @@ mod tests {
));
}

#[cfg(any())]
#[tokio::test]
async fn test_map_settings_ready() {
let txt2img = Txt2ImgParams {
Expand Down
Loading