Skip to content

Commit 94027ac

Browse files
committed
Prompt cache
1 parent 794ad03 commit 94027ac

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using System;
2+
using TensorStack.StableDiffusion.Pipelines;
3+
4+
namespace TensorStack.StableDiffusion.Common
5+
{
6+
public record PromptCache
7+
{
8+
public string Conditional { get; init; }
9+
public string Unconditional { get; init; }
10+
public PromptResult CacheResult { get; init; }
11+
12+
public bool IsValid(IPipelineOptions options)
13+
{
14+
return string.Equals(Conditional, options.Prompt, StringComparison.OrdinalIgnoreCase)
15+
&& string.Equals(Unconditional, options.NegativePrompt, StringComparison.OrdinalIgnoreCase);
16+
}
17+
}
18+
}

TensorStack.StableDiffusion/Pipelines/PipelineBase.cs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ namespace TensorStack.StableDiffusion.Pipelines
1515
{
1616
public abstract class PipelineBase : IDisposable
1717
{
18+
private PromptCache _promptCache;
1819
private GenerateOptions _defaultOptions;
1920
private IReadOnlyList<SchedulerType> _schedulers;
2021

@@ -149,11 +150,42 @@ protected Tensor<float> ApplyGuidance(Tensor<float> conditional, Tensor<float> u
149150
}
150151

151152

153+
/// <summary>
154+
/// Gets the prompt cache.
155+
/// </summary>
156+
/// <param name="options">The options.</param>
157+
protected PromptResult GetPromptCache(IPipelineOptions options)
158+
{
159+
if (_promptCache is null || !_promptCache.IsValid(options))
160+
return default;
161+
162+
return _promptCache.CacheResult;
163+
}
164+
165+
166+
/// <summary>
167+
/// Sets the prompt cache.
168+
/// </summary>
169+
/// <param name="options">The options.</param>
170+
/// <param name="promptResult">The prompt result to cache.</param>
171+
protected PromptResult SetPromptCache(IPipelineOptions options, PromptResult promptResult)
172+
{
173+
_promptCache = new PromptCache
174+
{
175+
CacheResult = promptResult,
176+
Conditional = options.Prompt,
177+
Unconditional = options.NegativePrompt,
178+
};
179+
return promptResult;
180+
}
181+
182+
152183
/// <summary>
153184
/// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
154185
/// </summary>
155186
public void Dispose()
156187
{
188+
_promptCache = default;
157189
Dispose(disposing: true);
158190
GC.SuppressFinalize(this);
159191
}

0 commit comments

Comments
 (0)