From 3abdbaf60ac4e71bd62945577be0be01e2cb9797 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1rton=20Kardos?= Date: Thu, 5 Dec 2024 10:29:06 +0100 Subject: [PATCH 1/3] Added custom color schemes to figures API --- topicwizard/figures/documents.py | 20 ++++++++++++++++---- topicwizard/figures/groups.py | 11 +++++++++-- topicwizard/figures/topics.py | 5 ++++- topicwizard/figures/words.py | 10 ++++++++-- topicwizard/plots/topics.py | 5 +++-- 5 files changed, 40 insertions(+), 11 deletions(-) diff --git a/topicwizard/figures/documents.py b/topicwizard/figures/documents.py index c678ae8..9cda2c6 100644 --- a/topicwizard/figures/documents.py +++ b/topicwizard/figures/documents.py @@ -1,4 +1,5 @@ """External API for creating self-contained figures for documents.""" + from typing import List, Optional, Union import numpy as np @@ -56,7 +57,10 @@ def document_map( def document_topic_distribution( - topic_data: TopicData, documents: Union[List[str], str], top_n: int = 8 + topic_data: TopicData, + documents: Union[List[str], str], + top_n: int = 8, + color_scheme: str = "Portland", ) -> go.Figure: """Displays topic distribution on a bar plot for a document or a set of documents. @@ -69,6 +73,8 @@ def document_topic_distribution( Documents to display topic distribution for. top_n: int, default 8 Number of topics to display at most. + color_scheme: str, default 'Portland' + Name of the Plotly color scheme to use for the plot. """ transform = topic_data["transform"] if transform is None: @@ -80,7 +86,7 @@ def document_topic_distribution( topic_importances = prepare.document_topic_importances(transform(documents)) topic_importances = topic_importances.groupby(["topic_id"]).sum().reset_index() n_topics = topic_data["document_topic_matrix"].shape[-1] - twilight = colors.get_colorscale("Portland") + twilight = colors.get_colorscale(color_scheme) topic_colors = colors.sample_colorscale(twilight, np.arange(n_topics) / n_topics) topic_colors = np.array(topic_colors) return plots.document_topic_barplot( @@ -89,7 +95,11 @@ def document_topic_distribution( def document_topic_timeline( - topic_data: TopicData, document: str, window_size: int = 10, step_size: int = 1 + topic_data: TopicData, + document: str, + window_size: int = 10, + step_size: int = 1, + color_scheme: str = "Portland", ) -> go.Figure: """Projects documents into 2d space and displays them on a scatter plot. @@ -103,6 +113,8 @@ def document_topic_timeline( The windows over which topic inference should be run. step_size: int, default 1 Size of the steps for the rolling window. + color_scheme: str, default 'Portland' + Name of Plotly color scheme to use for the plot. """ timeline = prepare.calculate_timeline( doc_id=0, @@ -113,7 +125,7 @@ def document_topic_timeline( ) topic_names = topic_data["topic_names"] n_topics = len(topic_names) - twilight = colors.get_colorscale("Portland") + twilight = colors.get_colorscale(color_scheme) topic_colors = colors.sample_colorscale(twilight, np.arange(n_topics) / n_topics) topic_colors = np.array(topic_colors) return plots.document_timeline(timeline, topic_names, topic_colors) diff --git a/topicwizard/figures/groups.py b/topicwizard/figures/groups.py index 9f1b3a6..aa54cb8 100644 --- a/topicwizard/figures/groups.py +++ b/topicwizard/figures/groups.py @@ -1,4 +1,5 @@ """External API for creating self-contained figures for groups.""" + from typing import List import numpy as np @@ -68,7 +69,11 @@ def group_map(topic_data: TopicData, group_labels: List[str]) -> go.Figure: def group_topic_barcharts( - topic_data: TopicData, group_labels: List[str], top_n: int = 5, n_columns: int = 4 + topic_data: TopicData, + group_labels: List[str], + top_n: int = 5, + n_columns: int = 4, + color_scheme: str = "Portland", ): """Displays the most important topics for each group. @@ -82,6 +87,8 @@ def group_topic_barcharts( Maximum number of topics to display for each group. n_columns: int, default 4 Indicates how many columns the faceted plot should have. + color_scheme: str, default 'Portland' + Name of the plotly color scheme to use for the figure. """ # Factorizing group labels group_id_labels, group_names = pd.factorize(group_labels) @@ -105,7 +112,7 @@ def group_topic_barcharts( horizontal_spacing=0.01, ) n_topics = len(topic_data["topic_names"]) - color_scheme = colors.get_colorscale("Portland") + color_scheme = colors.get_colorscale(color_scheme) topic_colors = colors.sample_colorscale( color_scheme, np.arange(n_topics) / n_topics, low=0.25, high=1.0 ) diff --git a/topicwizard/figures/topics.py b/topicwizard/figures/topics.py index 95570bf..719e86d 100644 --- a/topicwizard/figures/topics.py +++ b/topicwizard/figures/topics.py @@ -110,6 +110,7 @@ def topic_wordclouds( topic_data: TopicData, top_n: int = 30, n_columns: int = 4, + color_scheme: str = "copper", ) -> go.Figure: """Plots most relevant words as word clouds for every topic. @@ -121,6 +122,8 @@ def topic_wordclouds( Specifies the number of words to show for each topic. n_columns: int, default 4 Number of columns in the subplot grid. + color_scheme: str, default 'copper' + Matplotlib color scheme to use for the wordcloud. """ n_topics = topic_data["topic_term_matrix"].shape[0] ( @@ -147,7 +150,7 @@ def topic_wordclouds( components=topic_term_importances, vocab=topic_data["vocab"], ) - subfig = plots.wordcloud(top_words) + subfig = plots.wordcloud(top_words, color_scheme=color_scheme) row, column = (topic_id // n_columns) + 1, (topic_id % n_columns) + 1 fig.add_trace(subfig.data[0], row=row, col=column) fig.update_layout( diff --git a/topicwizard/figures/words.py b/topicwizard/figures/words.py index f7acb36..5ebffac 100644 --- a/topicwizard/figures/words.py +++ b/topicwizard/figures/words.py @@ -18,6 +18,7 @@ def word_map( topic_data: TopicData, z_threshold: float = 2.0, topic_axes: Optional[Tuple[Union[str, int], Union[str, int]]] = None, + color_scheme: str = "tempo", ) -> go.Figure: """Plots words on a scatter plot based on UMAP projections of their importances in topics into 2D space or by two topic axes. @@ -38,6 +39,8 @@ def word_map( The topic axes along which the words should be displayed. If not specified, the axes on the graph are going to be UMAP projections' dimensions. + color_scheme: str, default 'tempo' + Name of the Plotly color scheme to use for the plot. """ topic_names = topic_data["topic_names"] if topic_axes is None: @@ -58,7 +61,7 @@ def word_map( freq_z = zscore(word_frequencies) dominant_topic = prepare.dominant_topic(topic_data["topic_term_matrix"]) dominant_topic = np.array(topic_data["topic_names"])[dominant_topic] - tempo = colors.get_colorscale("tempo") + tempo = colors.get_colorscale(color_scheme) n_topics = len(topic_data["topic_names"]) topic_colors = colors.sample_colorscale(tempo, np.arange(n_topics) / n_topics) topic_colors = np.array(topic_colors) @@ -99,6 +102,7 @@ def word_association_barchart( words: Union[List[str], str], n_association: int = 0, top_n: int = 20, + color_scheme: str = "Rainbow", ): """Plots bar chart of most important topics for the given words and their closest associations in topic space. @@ -114,6 +118,8 @@ def word_association_barchart( None get displayed by default. top_n: int = 20 Top N topics to display. + color_scheme: str, default 'Rainbow' + Name of the Plotly color scheme to use for the plot. """ if isinstance(words, str): words = [words] @@ -128,7 +134,7 @@ def word_association_barchart( word_ids, topic_data["topic_term_matrix"], n_association ) n_topics = topic_data["topic_term_matrix"].shape[0] - tempo = colors.get_colorscale("Rainbow") + tempo = colors.get_colorscale(color_scheme) topic_colors = colors.sample_colorscale(tempo, np.arange(n_topics) / n_topics) topic_colors = np.array(topic_colors) top_topics = prepare.top_topics( diff --git a/topicwizard/plots/topics.py b/topicwizard/plots/topics.py index b88c101..14cd391 100644 --- a/topicwizard/plots/topics.py +++ b/topicwizard/plots/topics.py @@ -1,4 +1,5 @@ """Module containing plotting utilities for topics.""" + from typing import List import numpy as np @@ -139,7 +140,7 @@ def topic_plot(top_words: pd.DataFrame): return fig -def wordcloud(top_words: pd.DataFrame) -> go.Figure: +def wordcloud(top_words: pd.DataFrame, color_scheme: str = "copper") -> go.Figure: """Plots most relevant words for current topic as a worcloud.""" top_dict = { word: importance @@ -151,7 +152,7 @@ def wordcloud(top_words: pd.DataFrame) -> go.Figure: width=800, height=1060, background_color="white", - colormap="copper", + colormap=color_scheme, scale=4, ).generate_from_frequencies(top_dict) image = cloud.to_image() From 7fc1771a11897aa377cf43f920a55c629c4923a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1rton=20Kardos?= Date: Thu, 5 Dec 2024 10:35:21 +0100 Subject: [PATCH 2/3] Relaxed joblib version' --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 19187e1..3e4c412 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dash = "^2.7.1" dash-extensions = "^1.0.4" dash-mantine-components = "~0.12.1" dash-iconify = "~0.1.2" -joblib = "~1.2.0" +joblib = "^1.2.0" scikit-learn = "^1.2.0" scipy = ">=1.8.0" umap-learn = ">=0.5.3" From 0e18be20780b77e15ecc5a2a4fa533d54ab2323f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1rton=20Kardos?= Date: Thu, 5 Dec 2024 10:35:40 +0100 Subject: [PATCH 3/3] Version bump --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3e4c412..2644115 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "topic-wizard" -version = "1.1.1" +version = "1.1.2" description = "Pretty and opinionated topic model visualization in Python." authors = ["Márton Kardos "] license = "MIT"