I'm not sure if it's appropriate to create issues like this; feel free to close it without warning. Otherwise, I'd ask for it to stay open for some time in case somebody is interested.
Why?: The two existing backends for pixelfly use either huggingface blocksparse or triton. However, these are not always available, e.g. when training on TPUs or when using custom parameters (triton supports only a few block sizes).
What?: Below you can find a (limited) re-implementation of pixelfly in pure pytorch. Instead of block-sparse kernels, this implementation exploits the fact that the butterfly layout has an equal number of nonzero blocks in each row. This allows a two-stage procedure (sketched after the list):
- compute all candidate blocks with a regular (dense) matmul, using `[in_features, (block_size * blocks_per_input)]` weights
- aggregate blocks according to the butterfly layout using `F.embedding_bag(..., mode='sum')`
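
Here's a minimal, self-contained sketch of that two-stage idea (not the gist itself). All names and shapes are illustrative, and the block layout below is a random placeholder with equal block counts per output, not a real butterfly pattern:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes (not from the gist): square layout, equal block counts everywhere.
batch, block_size = 3, 4
num_in_blocks = num_out_blocks = 8
blocks_per_input = 2                      # nonzero output blocks per input block
in_features = num_in_blocks * block_size
out_features = num_out_blocks * block_size
num_candidates = num_in_blocks * blocks_per_input

# Placeholder layout: each candidate block goes to exactly one output block, so every
# output block sums the same number of candidates (like butterfly, but chosen randomly).
layout = torch.randperm(num_candidates).reshape(num_out_blocks, blocks_per_input)

x = torch.randn(batch, in_features)
weight = torch.randn(num_in_blocks, block_size, blocks_per_input * block_size) * 0.02

# Stage 1: one dense (batched) matmul computes every candidate output block.
x_blocks = x.reshape(batch, num_in_blocks, block_size)
candidates = torch.einsum("bni,nio->bno", x_blocks, weight)
candidates = candidates.reshape(batch, num_candidates, block_size)

# Stage 2: sum candidates into output blocks with embedding_bag(mode='sum');
# each row of `layout` is a fixed-size bag of candidate indices for one output block.
flat = candidates.permute(1, 0, 2).reshape(num_candidates, batch * block_size)
out = F.embedding_bag(layout, flat, mode="sum")   # [num_out_blocks, batch * block_size]
out = out.reshape(num_out_blocks, batch, block_size).permute(1, 0, 2).reshape(batch, out_features)
print(out.shape)  # torch.Size([3, 32])
```

In the actual gist the indices come from the butterfly pattern and the layer handles weight initialization, bias, autocast, etc.; the snippet above only shows the dense-matmul + embedding_bag aggregation trick.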
Here's the implementation: https://gist.github.com/justheuristic/9e4fb81381451a4bc8cbfee0a5100eba
It's heavily inspired by the original code and re-uses parts of blocksparse_linear.py.
It's a single file, requires only pytorch and einops, and is compatible with TPUs. The speed-ups are comparable (see example_and_tests), and it supports custom block sizes, tf32, autocast, etc. You can also easily re-write this in tensorflow using tfa.EmbeddingBag.
Feel free to use for whatever :)