Skip to content

Conversation

@isabelfaulds
Copy link
Contributor

closes #2350

This PR is to test a new method for handling shift network lowerings where multiple sources map to the same target, for example with ciphertexts of repeated content. The current method uses minimal distance between target and potential sources to select sources.

This code is still under development as it isn't running in reasonable time for the example MLIR and needs to be tested more if it improves current method. I'm submitting to see if there's any input on approach or implementation

The method used here is:
Shuffle target slots to fill.
Fill a target slot by sorting potential sources by 1) the count of targets the source has already been mapped to, and 2) the number of slots it would have a rotation conflict with ascending.
Any source with the minimal count can be picked at random.
Fill all the targets and track number of rotations needed and random seed used.
Repeat the process with a different random seed, here 50 times, and keep the best mapping.

This was implemented with these changes:
In ImplementShiftNetwork::findShiftScheme(), targetToSources is checked for multiple sources and proceeds if any.
All the potential sources are added as vertices to an all conflicts graph.
The edges needed to come from a strategy where all potential sources are included.
ShiftStrategy::evaluate() runs with all targetToSources instead of targetToSource
Sourceshifts is filled with each potential source and caculated shift to a potential target.
The collisions across rounds were calculated, this is where a bottleneck is for the original 4x4 duplicates 64 times ciphertext
The edges are added to the all conflicts graph.
Tests are run 50 times - shuffle targets and for each choose a minimal source, calculate new conflict graph & number of rotation groups, keep the mapping and conflict graph with lowest number rotation groups

To run use

(time bazel run //tools:heir-opt -- \
--debug-only=implement-shift-network,shift-scheme \
--implement-shift-network \
"$PWD/tests/Dialect/TensorExt/Transforms/implement_shift_network_issue_2350.mlir" \
> implement_shift_network_2350_output.mlir \  
2> implement_shift_network_2350_debug.log) 2> time.log

This is the shorter example I used for tests/Dialect/TensorExt/Transforms/implement_shift_network_issue_2350.mlir

// RUN: heir-opt --implement-shift-network --canonicalize %s | FileCheck %s

#layout = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and ct = i1 and (2i2 - i3 + slot) mod 4 = 0 and 0 <= i1 <= 1 and 0 <= i2 <= 1 and 0 <= i3 <= 1 and 0 <= slot <= 63 }">
#layout1 = #tensor_ext.layout<"{ [i0, i1] -> [ct, slot] : exists (e0, e1, e2: i0 = 0 and ct = 0 and 16e2 = -i1 + slot + 16e0 and 0 <= i1 <= 63 and 0 <= slot <= 63 and 0 <= e1 <= 3 and -3 + i1 - 16e0 <= 4e1 <= i1 - 16e0) }">
#original_type = #tensor_ext.original_type<originalType = tensor<1x2x2x2xf64>, layout = #layout>
module {
  // CHECK: func.func @main
  // CHECK-SAME: (%[[arg0:.*]]: !secret.secret<tensor<1x64xf64>>
  // CHECK: return %[[arg0]]
  func.func @main(%arg0: !secret.secret<tensor<1x64xf64>> {tensor_ext.original_type = #tensor_ext.original_type<originalType = tensor<1x1x4x4xf64>, layout = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and i1 = 0 and ct = 0 and (-4i2 - i3 + slot) mod 16 = 0 and 0 <= i2 <= 3 and 0 <= i3 <= 3 and 0 <= slot <= 63 }">>}, %arg1: tensor<2x1x3x3xf64>) -> (!secret.secret<tensor<1x64xf64>> {tensor_ext.original_type = #original_type}) {
    %0 = secret.generic(%arg0: !secret.secret<tensor<1x64xf64>>) {
    ^body(%input0: tensor<1x64xf64>):
      %1 = tensor_ext.remap %input0 {permutation = #layout1} : tensor<1x64xf64>
      secret.yield %1 : tensor<1x64xf64>
    } -> !secret.secret<tensor<1x64xf64>>
    return %0 : !secret.secret<tensor<1x64xf64>>
  }
}

@isabelfaulds
Copy link
Contributor Author

I noticed the CI failures for formatting and the mapping const change, I'll fix these and push an updated commit

@isabelfaulds isabelfaulds force-pushed the shift-networks-source-conflicts branch from 61bc180 to 0c4f829 Compare January 9, 2026 04:09
@isabelfaulds isabelfaulds force-pushed the shift-networks-source-conflicts branch from 0c4f829 to 4d19111 Compare January 9, 2026 22:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

ImplementShiftNetwork: Reduce the number of conflicts when there are multiple sources that map to a target slot

1 participant