Skip to content

Commit 520fd58

Browse files
committed
Encoder Cache
1 parent 47fdba2 commit 520fd58

File tree

8 files changed

+111
-20
lines changed

8 files changed

+111
-20
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using System;
using TensorStack.Common.Tensor;

namespace TensorStack.StableDiffusion.Common
{
    /// <summary>
    /// Caches the AutoEncoder result for an input image so repeated generations
    /// with the same input image can skip the expensive encode pass.
    /// </summary>
    public record EncoderCache
    {
        /// <summary>The input image the cached result was encoded from.</summary>
        public ImageTensor InputImage { get; init; }

        /// <summary>The cached encoder output for <see cref="InputImage"/>.</summary>
        public Tensor<float> CacheResult { get; init; }


        /// <summary>
        /// Determines whether the cached result is valid for the supplied input image.
        /// </summary>
        /// <param name="input">The candidate input image (may be null).</param>
        /// <returns>true when the cached image matches the input; otherwise false.</returns>
        public bool IsValid(ImageTensor input)
        {
            if (input is null || InputImage is null)
                return false;

            // Fast path: the exact same tensor instance is trivially equal,
            // no need to compare the full pixel buffer.
            if (ReferenceEquals(InputImage, input))
                return true;

            // SequenceEqual is false for buffers of different lengths.
            // NOTE(review): two images with identical pixel data but transposed
            // dimensions would compare equal here — confirm callers always
            // resize to the same Width/Height before encoding.
            if (!InputImage.Span.SequenceEqual(input.Span))
                return false;

            return true;
        }
    }
}

TensorStack.StableDiffusion/Pipelines/Flux/FluxBase.cs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,16 +268,23 @@ protected async Task<ImageTensor> DecodeLatentsAsync(IPipelineOptions options, T
268268
/// <param name="options">The options.</param>
269269
/// <param name="image">The latents.</param>
270270
/// <param name="cancellationToken">The cancellation token.</param>
271-
private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default)
271+
private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
272272
{
273273
var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodeLatentsAsync] Begin AutoEncoder Encode");
274-
var inputTensor = image.ResizeImage(options.Width, options.Height);
274+
var cacheResult = GetEncoderCache(options);
275+
if (cacheResult is not null)
276+
{
277+
Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete, Cached Result.");
278+
return cacheResult;
279+
}
280+
281+
var inputTensor = options.InputImage.ResizeImage(options.Width, options.Height);
275282
var encoderResult = await AutoEncoder.EncodeAsync(inputTensor, cancellationToken: cancellationToken);
276283
if (options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled)
277284
await AutoEncoder.EncoderUnloadAsync();
278285

279286
Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete");
280-
return encoderResult;
287+
return SetEncoderCache(options, encoderResult);
281288
}
282289

283290

@@ -396,7 +403,7 @@ private async Task<Tensor<float>> CreateLatentInputAsync(IPipelineOptions option
396403
if (options.HasInputImage)
397404
{
398405
var timestep = scheduler.GetStartTimestep();
399-
var encoderResult = await EncodeLatentsAsync(options, options.InputImage, cancellationToken);
406+
var encoderResult = await EncodeLatentsAsync(options, cancellationToken);
400407
var noiseTensor = scheduler.CreateRandomSample(encoderResult.Dimensions);
401408
return PackLatents(scheduler.ScaleNoise(timestep, encoderResult, noiseTensor));
402409
}

TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,16 +191,23 @@ protected async Task<ImageTensor> DecodeLatentsAsync(IPipelineOptions options, T
191191
/// <param name="options">The options.</param>
192192
/// <param name="image">The latents.</param>
193193
/// <param name="cancellationToken">The cancellation token.</param>
194-
private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default)
194+
private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
195195
{
196196
var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodeLatentsAsync] Begin AutoEncoder Encode");
197-
var inputTensor = image.ResizeImage(options.Width, options.Height);
197+
var cacheResult = GetEncoderCache(options);
198+
if (cacheResult is not null)
199+
{
200+
Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete, Cached Result.");
201+
return cacheResult;
202+
}
203+
204+
var inputTensor = options.InputImage.ResizeImage(options.Width, options.Height);
198205
var encoderResult = await AutoEncoder.EncodeAsync(inputTensor, cancellationToken: cancellationToken);
199206
if (options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled)
200207
await AutoEncoder.EncoderUnloadAsync();
201208

202209
Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete");
203-
return encoderResult;
210+
return SetEncoderCache(options, encoderResult);
204211
}
205212

206213

@@ -274,7 +281,7 @@ private async Task<Tensor<float>> CreateLatentInputAsync(IPipelineOptions option
274281
if (options.HasInputImage)
275282
{
276283
var timestep = scheduler.GetStartTimestep();
277-
var encoderResult = await EncodeLatentsAsync(options, options.InputImage, cancellationToken);
284+
var encoderResult = await EncodeLatentsAsync(options, cancellationToken);
278285
return scheduler.ScaleNoise(timestep, encoderResult, noiseTensor);
279286
}
280287
return noiseTensor;

TensorStack.StableDiffusion/Pipelines/PipelineBase.cs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ namespace TensorStack.StableDiffusion.Pipelines
1616
public abstract class PipelineBase : IDisposable
1717
{
1818
private PromptCache _promptCache;
19+
private EncoderCache _encoderCache;
1920
private GenerateOptions _defaultOptions;
2021
private IReadOnlyList<SchedulerType> _schedulers;
2122

@@ -180,12 +181,45 @@ protected PromptResult SetPromptCache(IPipelineOptions options, PromptResult pro
180181
}
181182

182183

184+
/// <summary>
185+
/// Gets the encoder cache.
186+
/// </summary>
187+
/// <param name="options">The options.</param>
188+
protected Tensor<float> GetEncoderCache(IPipelineOptions options)
189+
{
190+
if (_encoderCache is null)
191+
return default;
192+
193+
if (!_encoderCache.IsValid(options.InputImage))
194+
return default;
195+
196+
return _encoderCache.CacheResult;
197+
}
198+
199+
200+
/// <summary>
201+
/// Sets the encoder cache.
202+
/// </summary>
203+
/// <param name="options">The options.</param>
204+
/// <param name="encoded">The encoded.</param>
205+
protected Tensor<float> SetEncoderCache(IPipelineOptions options, Tensor<float> encoded)
206+
{
207+
_encoderCache = new EncoderCache
208+
{
209+
InputImage = options.InputImage,
210+
CacheResult = encoded
211+
};
212+
return encoded;
213+
}
214+
215+
183216
/// <summary>
184217
/// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
185218
/// </summary>
186219
public void Dispose()
187220
{
188-
_promptCache = default;
221+
_promptCache = null;
222+
_encoderCache = null;
189223
Dispose(disposing: true);
190224
GC.SuppressFinalize(this);
191225
}

TensorStack.StableDiffusion/Pipelines/StableCascade/StableCascadeBase.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ protected async Task<Tensor<float>> RunPriorAsync(GenerateOptions options, Tenso
221221
// Create latent sample
222222
var latents = await CreatePriorLatentInputAsync(options, scheduler, cancellationToken);
223223

224-
var image = await EncodeLatentsAsync(options, options.InputImage, cancellationToken);
224+
var image = await EncodeLatentsAsync(options, cancellationToken);
225225

226226
// Get Model metadata
227227
var metadata = await PriorUnet.LoadAsync(cancellationToken: cancellationToken);
@@ -391,7 +391,7 @@ private Task<Tensor<float>> CreateDecoderLatentsAsync(GenerateOptions options, I
391391
/// <param name="options">The options.</param>
392392
/// <param name="image">The latents.</param>
393393
/// <param name="cancellationToken">The cancellation token.</param>
394-
private Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default)
394+
private Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
395395
{
396396
return Task.FromResult(new Tensor<float>([1, 1, ImageEncoder.HiddenSize]));
397397
}

TensorStack.StableDiffusion/Pipelines/StableDiffusion/StableDiffusionBase.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,13 +221,20 @@ protected async Task<ImageTensor> DecodeLatentsAsync(IPipelineOptions options, T
221221
private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default)
222222
{
223223
var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodeLatentsAsync] Begin AutoEncoder Encode");
224+
var cacheResult = GetEncoderCache(options);
225+
if (cacheResult is not null)
226+
{
227+
Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete, Cached Result.");
228+
return cacheResult;
229+
}
230+
224231
var inputTensor = image.ResizeImage(options.Width, options.Height);
225232
var encoderResult = await AutoEncoder.EncodeAsync(inputTensor, cancellationToken: cancellationToken);
226233
if (options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled)
227234
await AutoEncoder.EncoderUnloadAsync();
228235

229236
Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete");
230-
return encoderResult;
237+
return SetEncoderCache(options, encoderResult);
231238
}
232239

233240

TensorStack.StableDiffusion/Pipelines/StableDiffusion3/StableDiffusion3Base.cs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -342,16 +342,23 @@ protected async Task<ImageTensor> DecodeLatentsAsync(IPipelineOptions options, T
342342
/// <param name="options">The options.</param>
343343
/// <param name="image">The latents.</param>
344344
/// <param name="cancellationToken">The cancellation token.</param>
345-
private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default)
345+
private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
346346
{
347347
var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodeLatentsAsync] Begin AutoEncoder Encode");
348-
var inputTensor = image.ResizeImage(options.Width, options.Height);
348+
var cacheResult = GetEncoderCache(options);
349+
if (cacheResult is not null)
350+
{
351+
Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete, Cached Result.");
352+
return cacheResult;
353+
}
354+
355+
var inputTensor = options.InputImage.ResizeImage(options.Width, options.Height);
349356
var encoderResult = await AutoEncoder.EncodeAsync(inputTensor, cancellationToken: cancellationToken);
350357
if (options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled)
351358
await AutoEncoder.EncoderUnloadAsync();
352359

353360
Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete");
354-
return encoderResult;
361+
return SetEncoderCache(options, encoderResult);
355362
}
356363

357364

@@ -525,7 +532,7 @@ private async Task<Tensor<float>> CreateLatentInputAsync(IPipelineOptions option
525532
if (options.HasInputImage)
526533
{
527534
var timestep = scheduler.GetStartTimestep();
528-
var encoderResult = await EncodeLatentsAsync(options, options.InputImage, cancellationToken);
535+
var encoderResult = await EncodeLatentsAsync(options, cancellationToken);
529536
return scheduler.ScaleNoise(timestep, encoderResult, noiseTensor);
530537
}
531538
return noiseTensor;

TensorStack.StableDiffusion/Pipelines/StableDiffusionXL/StableDiffusionXLBase.cs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,16 +271,23 @@ protected async Task<ImageTensor> DecodeLatentsAsync(IPipelineOptions options, T
271271
/// <param name="options">The options.</param>
272272
/// <param name="image">The latents.</param>
273273
/// <param name="cancellationToken">The cancellation token.</param>
274-
private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default)
274+
private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
275275
{
276276
var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodeLatentsAsync] Begin AutoEncoder Encode");
277-
var inputTensor = image.ResizeImage(options.Width, options.Height);
277+
var cacheResult = GetEncoderCache(options);
278+
if (cacheResult is not null)
279+
{
280+
Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete, Cached Result.");
281+
return cacheResult;
282+
}
283+
284+
var inputTensor = options.InputImage.ResizeImage(options.Width, options.Height);
278285
var encoderResult = await AutoEncoder.EncodeAsync(inputTensor, cancellationToken: cancellationToken);
279286
if (options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled)
280287
await AutoEncoder.EncoderUnloadAsync();
281288

282289
Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete");
283-
return encoderResult;
290+
return SetEncoderCache(options, encoderResult);
284291
}
285292

286293

@@ -462,7 +469,7 @@ private async Task<Tensor<float>> CreateLatentInputAsync(IPipelineOptions option
462469
if (options.HasInputImage)
463470
{
464471
var timestep = scheduler.GetStartTimestep();
465-
var encoderResult = await EncodeLatentsAsync(options, options.InputImage, cancellationToken);
472+
var encoderResult = await EncodeLatentsAsync(options, cancellationToken);
466473
return scheduler.ScaleNoise(timestep, encoderResult, noiseTensor);
467474
}
468475
return noiseTensor.Multiply(scheduler.StartSigma);

0 commit comments

Comments
 (0)