From 52c49fa54d890e4fc390825e5936539b76e16657 Mon Sep 17 00:00:00 2001 From: Alejo Gradau Date: Wed, 17 Dec 2025 23:37:42 +0000 Subject: [PATCH 1/3] Add Kochen-Specker inequality implementation --- recirq/contextuality/kochen_specker.py | 889 +++++++++++++++++++++++++ 1 file changed, 889 insertions(+) create mode 100644 recirq/contextuality/kochen_specker.py diff --git a/recirq/contextuality/kochen_specker.py b/recirq/contextuality/kochen_specker.py new file mode 100644 index 00000000..a2d0a409 --- /dev/null +++ b/recirq/contextuality/kochen_specker.py @@ -0,0 +1,889 @@ +# Copyright 2025 Google +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +import warnings +import collections +import random +from collections.abc import Iterator, Mapping, Sequence +from typing import Literal + +import attrs +import cirq +import matplotlib.pyplot as plt +import numpy as np + +type Axis = Literal["row", "col"] +type QndStrategy = Literal["map_all_paulis_to_ancilla", "map_2q_paulis_to_ancilla"] +type CircuitStyle = Literal["flat", "jt", "jt_pauli", "jt_pauli_pivot"] + + +_CIRCUIT_DD_SUPPORTED = hasattr(cirq, "add_dynamical_decoupling") + +if _CIRCUIT_DD_SUPPORTED: + _ADD_DD_FUNC = cirq.add_dynamical_decoupling +else: + _ADD_DD_FUNC = None + + +@attrs.frozen +class MagicSquare: + """A Mermin-Peres magic square.""" + + paulis: Sequence[Sequence[str]] + + def build_circuit_for_context( + self, + context: tuple[Axis, int], + trios: Sequence[tuple[cirq.Qid, cirq.Qid, cirq.Qid]], + qnd_strategy: QndStrategy = "map_all_paulis_to_ancilla", + virtual_z_half_turns: float = 0.0, + add_ancilla_to_measure_key: bool = False, + ) -> cirq.FrozenCircuit: + """Circuit to measure the idx'th col or row of this magic square. + + Args: + context: Tuple of (Axis, idx). Axis corresponds to either a row or a column, while idx + determines which row or col to measure from the Mermin-Peres square. + trios: Sequence of three-element tuples of the form (q0, q1, ancilla). + qnd_strategy: Whether to use ancillae for measuring single-qubit paulis. + virtual_z_half_turns: This applied only for Paulis mapped to an ancilla. If not zero, + amount of half turns around the Z axis to apply on the ancilla during the mapping of + the paulis from the data qubits (q0, q1). This virtualZ is applied before the last + Hadamard during the mapping of the Pauli to the ancilla. + add_ancilla_to_measure_key: Whether to include the ancilla qubit name in the measure + key or not. + """ + axis, idx = context + if axis == "row": + paulis = [self.paulis[idx][c] for c in range(3)] + else: + paulis = [self.paulis[r][idx] for r in range(3)] + return cirq.FrozenCircuit( + cirq.FrozenCircuit.zip( + *[ + generate_qnd_pauli_circuit( + q0, + q1, + anc, + pauli, + measure_key=( + f"{pauli}_{axis}_{anc}" + if add_ancilla_to_measure_key + else f"{pauli}_{axis}" + ), + qnd_strategy=qnd_strategy, + virtual_z_half_turns=virtual_z_half_turns, + ) + for q0, q1, anc in trios + ] + ) + for pauli in paulis + ) + + def kochen_specker_expt( + self, + q0: cirq.Qid, + q1: cirq.Qid, + anc: cirq.Qid, + num_contexts: int = 36, + seed: int = 0, + qnd_strategy: QndStrategy = "map_all_paulis_to_ancilla", + virtual_z_half_turns: float = 0.0, + ) -> MagicSquareExperiment: + """Builds a Kochen-Specker experiment. + + Args: + num_contexts: Number of rows+cols measured out of the 6 + possible row/col. + seed: Seed for random number generator. Specifying this will give + reproducible results for the random axes and indices. + qnd_strategy: Whether to map all paulis to an ancilla, or only 2q paulis. + virtual_z_half_turns: This applied only for Paulis mapped to an ancilla. If not zero, + amount of half turns around the Z axis to apply on the ancilla during the mapping of + the paulis from the data qubits (q0, q1). This virtualZ is applied before the last + Hadamard during the mapping of the Pauli to the ancilla. + """ + + axes: list[Axis] = ["col", "row"] + indices = [0, 1, 2] + + random.seed(seed) + random_axes = random.choices(axes, k=num_contexts) + random_idx = random.choices(indices, k=num_contexts) + + context_to_circuit: dict[tuple[str, int], cirq.FrozenCircuit] = { + (axis, idx): self.build_circuit_for_context( + trios=[(q0, q1, anc)], + context=(axis, idx), + qnd_strategy=qnd_strategy, + virtual_z_half_turns=virtual_z_half_turns, + ) + for axis in axes + for idx in indices + } + + def pauli_key(row: int, col: int, axis: Axis) -> PauliKey: + return PauliKey(f"{self.paulis[row][col]}_{axis}") + + contexts: dict[tuple[Axis, int], Sequence[PauliKey]] = { + **{ + ("row", r): [pauli_key(r, c, "row") for c in range(3)] for r in range(3) + }, + **{ + ("col", c): [pauli_key(r, c, "col") for r in range(3)] for c in range(3) + }, + } + + ks_circuits = [ + context_to_circuit[axis, idx] for axis, idx in zip(random_axes, random_idx) + ] + circuit = transform_circuit_for_jump_table( + ks_circuits=ks_circuits, + ) + return MagicSquareExperiment(self, circuit, contexts) + + +EASY_SQUARE = MagicSquare( + [["1Z", "Z1", "ZZ"], ["X1", "1X", "XX"], ["-XZ", "-ZX", "YY"]] +) +HARD_SQUARE = MagicSquare([["ZX", "YY", "XZ"], ["XY", "ZZ", "YX"], ["YZ", "XX", "ZY"]]) + + +def transform_circuit_for_jump_table( + ks_circuits: list[cirq.FrozenCircuit], +) -> cirq.FrozenCircuit: + """Transform the circuit. + + Args: + ks_circuits: Sequence of circuits to measure a sequence of contexts. + """ + return cirq.FrozenCircuit(ks_circuits) + + +def _copy_circuit(circuit: cirq.FrozenCircuit) -> cirq.FrozenCircuit: + return cirq.FrozenCircuit.from_moments(*circuit).with_tags(*circuit.tags) + + +@attrs.frozen +class MagicSquareExperiment: + square: MagicSquare + circuit: cirq.FrozenCircuit + contexts: dict[tuple[Axis, int], Sequence[PauliKey]] + + def run(self, sampler: cirq.Sampler, repetitions: int = 1000) -> MagicSquareResult: + result = sampler.run(self.circuit, repetitions=repetitions) + return MagicSquareResult(self, result) + + +@attrs.frozen +class PauliKey: + key: str + slice: list[int] | slice = slice(None) + + +@attrs.frozen(slots=False) +class MagicSquareResult: + expt: MagicSquareExperiment + result: cirq.Result + + def estimate_chi(self) -> tuple[float, float]: + """Estimate chi out of the 6 parity averages. + + This method takes the cirq.Result output by the Kochen-Specker + experiment and returns an estimate of chi. We assume that the length of the + circuit ran was long enough to contain at least one instance of each of the + six 3-observable parities, e.g. (1X X1 XX). + """ + + chi_mean, chi_error = estimate_chi(self.squeezed_records(), self.expt.contexts) + + return chi_mean, chi_error + + def squeezed_records(self) -> Mapping[str, np.ndarray]: + """Squeeze the ResultDict.records.values() ndarrays to 2D. + + Since there is only one qubit per measurement key in our experiments, the third index is + a dummy index. This helps us process data more easily. + """ + squeezed_records = { + pauli_key: np.squeeze(records_values) + for pauli_key, records_values in self.result.records.items() + } + return squeezed_records + + def estimate_parities(self) -> dict[tuple[Axis, int], tuple[float, float]]: + """Estimate average parity and error of each of the 6 contexts. + + Useful to get more nuanced information on each row/col. We first evaluate the parity for + each time every context has appeared in the circuit. We then average the parities for + each context and estimate the error of this average. For references on the error estimation, + check the error of a binomial distribution. + + Returns: + context_averages: A dictionary mapping a context (axis, idx) to a tuple + (mean_parity, parity_error) + """ + context_averages = estimate_parities( + self.squeezed_records(), self.expt.contexts + ) + return context_averages + + def estimate_coincidences( + self, + ) -> tuple[ + dict[str, list[float]], dict[str, list[float]], dict[str, list[tuple[str, str]]] + ]: + """Each PauliString, get coincidence probability for compatible/incompatible measurements. + + A compatible pair of measurements of a given PauliString P_i (i = 0, 1, ..., 8), say ZZ, + is a pair of outcomes of ZZ where all other PauliStrings P_j (j != i) measured in between + these two outcomes are compatible with ZZ. Essentially, this happens when each and every + P_j commutes with the P_i in question. An incompatible pair of measurements of a + PauliString P_i happens when in between a pair of outcomes of P_i, there is at least one + P_k that does not commute with P_i, scrambling its measurement outcome. + + For example, say the order of the contexts measured (i.e. measurement dna) of the given + circuit is, in the EASY_SQUARE: + + dna: + ('1Z_row', 'Z1_row', 'ZZ_row') - First context + ('ZZ_col', 'XX_col', 'YY_col') - Second context + ('X1_row', '1X_row', 'XX_row') - Third context + ('ZZ_col', 'XX_col', 'YY_col') - Fourth context + ('ZZ_col', 'XX_col', 'YY_col') - Fifth context + ('X1_row', '1X_row', 'XX_row') - Sixth context + ('ZZ_col', 'XX_col', 'YY_col') - Seventh context + + ... + + When considering the PauliString ZZ, a compatible pair of ZZ measurements happens between + (first, second) contexts and between (fourth, fifth) contexts, since ZZ commutes with + every element of each tuple of contexts. Nonetheless, pair of ZZ measurements that + happens between (second, fourth) contexts and (fifth, seventh) is an incompatible pair of + measurements of ZZ. For example, since the third context does not contain a ZZ + measurement and no element in that context commutes with ZZ, measuring that third + context will scramble the measurement outcome of ZZ once we measure this PauliString in the + fourth context. Same happens when measuring the sixth context. + + In theory, for compatible measurements of a given PauliString, we expect the measurement + outcomes to have a probability of coincidence P=1, while for incompatible measurements + we expect the probability of coincidence P=0.5 (since there are measurements in between + that do not commute with such Pauli, so its outcome gets scrambled). + + Args: + circuit: a cirq.FrozenCircuit. Need to pass a circuit without using the jump table + (flatten=True). + + Returns: + - pauli_to_compatible_coincidences: dict + - pauli_to_incompatible_coincidences: dict + - pauli_to_compatible_contexts: dict + + Explanations: + - pauli_to_compatible_coincidences and pauli_to_incompatible_coincidences: + Two dictionaries. They map each PauliString in the MagicSquare to a list + of 0's and 1's for compatible and incompatible pairs of measurements, respectively. + 1: indicates coincidence between the pair of measurements. + 0: indicates no coincidence between the pair of measurements. + + Given the dna of PauliStrings presented above, ZZ was measured 5 times in the + circuit. The pairs contexts that contain ZZ are (1st, 2nd), (2nd, 4th), (4th, 5th), + and (5th, 7th). Say that the coincidences of each pair gives the array + + coincidences_array = [1, 0, 1, 1]. + + We know that, in this circuit dna, the odd elements correspond to compatible + contexts [(1st, 2nd), (4th, 5th)], while the even elements correspond to + incompatible contexts [(2nd, 4th), (5th, 7th)]. Then the values for the ZZ key in + each dictionary would be: + + pauli_to_compatible_coincidences[ZZ] = [1, 1] + pauli_to_incompatible_coincidences[ZZ] = [0, 1] + + - pauli_to_compatible_contexts: + This dictionary is useful for debugging purposes. + Each PauliString can show up in a circuit as part of a context that + corresponds to either a row or a column. For the compatible pairs of + measurements in a Kochen Specker circuit (and only the compatible pairs), + this dictionary is a mapping between each PauliString in the MagicSquare to a + list of tuples containing whether each tuples of previous context, where each + tuple contains the axis context corresponds to a row or column context. For the + example we are analysing, this would give: + + pauli_to_compatible_contexts[ZZ] = [('row', 'col'), ('col', 'col')] + """ + circuit = self.expt.circuit + repetitions = self.result.repetitions + dna = get_paulis_dna_for_circuit(circuit) + paulis_set = set(dna) # length = 18 + paulis = [pauli for row in self.expt.square.paulis for pauli in row] + + # Map each pauli_context, e.g. 1Z_row, to a list of N iterators. N = repetitions and + # each nested list has length equal to the number of times such 1Z_row shows up in the + # circuit. + pauli_context_outcomes = { + pauli_context: [ + iter(self.squeezed_records()[pauli_context][repetition].flatten()) + for repetition in range(repetitions) + ] + for pauli_context in paulis_set + } + + # Build a dictionary pauli_outcomes that maps pauliStrings (e.g. ZZ) to a 2D list of + # lists of shape(repetitions, num_measurements). The elements of the nested lists have the + # outcomes (0's and 1's) of such pauli in the order they were measured and recorded in + # the circuit. + + # Create this map once before the pauli_outcomes loop + pauli_to_dna_indices = collections.defaultdict(list) + for i, op in enumerate(dna): + pauli = op[:-4] # e.g., 'ZZ_row' -> 'ZZ' + pauli_to_dna_indices[pauli].append(i) + + pauli_outcomes: dict[str, list[list[int]]] = { + pauli: [ + [ + next(pauli_context_outcomes[dna[i]][rep]) + for i in pauli_to_dna_indices[pauli] + ] + for rep in range(repetitions) + ] + for pauli in paulis + } + + pauli_average_coincidences = {} + for pauli, outcome_lists in pauli_outcomes.items(): + # Convert to a 2D NumPy array of shape (repetitions, num_measurements) + outcomes_array = np.array(outcome_lists, dtype=np.int8) + + # Vectorized comparison: compare each column with the next one + # This creates a boolean array of shape (repetitions, num_measurements - 1) + coincidences_array = outcomes_array[:, :-1] == outcomes_array[:, 1:] + + # Calculate the mean directly along the repetitions axis (axis=0) + pauli_average_coincidences[pauli] = np.mean(coincidences_array, axis=0) + + # pauli_compatible_mask maps paulis to a list of True's and False's. + # True in the indices when such pauli was also measured in the prev dna triplet (compatible) + # False otherwise (incompatible). + + # We expect the True's indices to correlate with the indices in pauli_coincidences with + # probability == 1 + # We expect the False's indices to correlate with the indices in pauli_coincidences with + # probability == 0.5 + pauli_compatible_mask: dict[str, list[bool]] = {} + + # Maps paulis to a list of tuples of contexts, e.g. ("row", "row") or ("col", + # "row). There are 4 different combinations of tuples of contexts. Will be useful to debug + # coincidence errors within compatible measurements. + pauli_to_compatible_contexts: dict[str, list[tuple[str, str]]] = {} + + prev_dna_triplet: list[str] = ["", "", ""] + prev_context_axis: str = "" + contexts_are_compatible = False + for ii in range(int(len(dna) / 3)): + current_dna_triplet = dna[(ii * 3) : (ii * 3 + 3)] + current_context_axis = dna[(ii * 3)][-3:] # e.g. "row" or "col" + for dna_elem in current_dna_triplet: + pauli = dna_elem[:-4] + # when it first appears, just create the lists + if pauli not in pauli_compatible_mask.keys(): + pauli_compatible_mask[pauli] = [] + pauli_to_compatible_contexts[pauli] = [] + else: # if list already exist, compare with previous_triplet + contexts_are_compatible = any( + pauli in prev_pauli for prev_pauli in prev_dna_triplet + ) + pauli_compatible_mask[pauli].append(contexts_are_compatible) + if contexts_are_compatible: + pauli_to_compatible_contexts[pauli].append( + ( + prev_context_axis, + current_context_axis, + ) + ) + prev_dna_triplet = dna[(ii * 3) : (ii * 3 + 3)] + prev_context_axis = dna[(ii * 3)][-3:] + + # Build compatible coincidences and incompatible coincidences dictionaries + pauli_compatible = { + pauli: pauli_average_coincidences[pauli][pauli_compatible_mask[pauli]] + for pauli in paulis + } + + pauli_incompatible = { + pauli: pauli_average_coincidences[pauli][ + [not elem for elem in pauli_compatible_mask[pauli]] + ] + for pauli in paulis + } + + return pauli_compatible, pauli_incompatible, pauli_to_compatible_contexts + + def get_pauli_coincidences_per_context( + self, + ) -> dict[str, dict[tuple[str, str], list[float]]]: + """Nested mapping, Paulis to compatible measurement context pairs to coincidence probs. + + See the docstrings of the third dictionary returned by self.estimate_coincidences for + more info on the meaning of "compatible measurement context pairs". An example of an + entry of the dictionary that self.get_pauli_coincidences_per_context returns is: + + {ZZ: {(row, row): 0.98, (row, col): 0.97, (col, row): 0.93, (col, col): 0.95}. + + Useful for debugging purposes. + """ + pauli_compatible_per_context: dict[str, dict[tuple[str, str], list[float]]] = {} + pauli_compatible, _, pauli_to_compatible_contexts = self.estimate_coincidences() + + for pauli in pauli_compatible: + for probability, context_tuple in zip( + pauli_compatible[pauli], pauli_to_compatible_contexts[pauli] + ): + pauli_compatible_per_context.setdefault(pauli, {}).setdefault( + context_tuple, [] + ).append(probability) + + return pauli_compatible_per_context + + def plot_coincidences_timeline(self, include_legend: bool = True) -> None: + """Plot coincidences coincidences over time. + + The coincidences over time show the evolution of the coincidence probability over time, + with time ticks given by the position of each Pauli along the circuit. + """ + pauli_compatible, pauli_incompatible, _ = self.estimate_coincidences() + paulis = [pauli for row in self.expt.square.paulis for pauli in row] + + # Plot coincidence probabilities over time + f, axs = plt.subplots(1, 2) + + for pauli in paulis: + axs[0].plot(pauli_compatible[pauli], label=pauli, marker="o") + axs[1].plot(pauli_incompatible[pauli], label=pauli, marker="o") + + for ax in axs: + ax.set_ylim(0, 1.1) + if include_legend: + ax.legend() + ax.set_xlabel("Pauli iteration within circuit") + ax.set_ylabel("Probability of coincidence") + + axs[0].set_title("Compatible contexts") + axs[1].set_title("Incompatible contexts") + f.tight_layout() + + def plot_coincidences_histograms_per_pauli(self) -> None: + """Plots histograms for each Pauli measurement in pauli_compatible dict.""" + pauli_compatible, _, _ = self.estimate_coincidences() + + keys = list(pauli_compatible.keys()) + num_keys = len(keys) + + # Calculate number of rows and columns for a roughly square grid + num_cols = 3 + num_rows = int(np.ceil(num_keys / num_cols)) + + fig, axes = plt.subplots( + num_rows, num_cols, figsize=(num_cols * 5, num_rows * 4) + ) + axes = axes.flatten() # Flatten the 2D array of axes for easy iteration + + for i, key in enumerate(keys): + ax = axes[i] + data = pauli_compatible[key] + + ax.hist(data, bins=20, edgecolor="black", alpha=0.7) + ax.set_title(f"Pauli: {key}") + ax.set_xlabel("P_agree") + ax.set_ylabel("Counts") + ax.grid(axis="y", alpha=0.75) + + # Hide any unused subplots + for j in range(num_keys, len(axes)): + fig.delaxes(axes[j]) + + plt.tight_layout() + + def plot_coincidences_histogram_combined_paulis( + self, ax: plt.Axes | None = None + ) -> None: + """Plots a single histogram combining all Pauli measurements from pauli_compatible.""" + pauli_compatible, _, _ = self.estimate_coincidences() + all_data = np.concatenate(list(pauli_compatible.values())) + if ax is None: + f, ax = plt.subplots(figsize=(10, 6)) + ax.hist(all_data, bins=20, edgecolor="black", alpha=0.7) + ax.set_title("Combined Pauli Coincidence Rates") + ax.set_xlabel("P_agree") + ax.set_ylabel("Counts") + ax.grid(axis="y", alpha=0.75) + + +def get_paulis_dna_for_circuit(circuit: cirq.FrozenCircuit) -> list[str]: + """Extract the string of all Paulis measured in the circuit in the order they appear. + + Returns: + paulis_dna: list of strings (Paulis). + For example: + dna = [ + "1Z_row", "Z1_row", "ZZ_row", + "X1_row", "1X_row", "XX_row", + ... + "ZZ_col", "XX_col", "YY_col" + ] + """ + + def meas_keys(circuit: cirq.FrozenCircuit) -> Iterator[str]: + """Yields all non-circuit operations from a possibly nested circuit. + + Recurses into subcircuits to yield their non-circuit operations. + """ + for moment in circuit.moments: + if len(moment) > 1 and any( + isinstance(op.untagged, cirq.CircuitOperation) for op in moment + ): + raise ValueError( + f"CircuitOperation must be the only operation in a moment: {moment}" + ) + for op in moment: + if isinstance(circuit_op := op.untagged, cirq.CircuitOperation): + if (reps := circuit_op.repetitions) == 1: + yield from meas_keys(circuit_op.circuit) + else: + yield from list(meas_keys(circuit_op.circuit)) * reps + elif cirq.is_measurement(op): + yield cirq.measurement_key_name(op) + + return list(meas_keys(circuit)) + + +def get_rot_gate_from_pauli(pauli: cirq.Gate | str) -> tuple[cirq.Gate, cirq.Gate]: + """Get the change of basis rotation and its inverse to measure a given Pauli. If the Pauli + doesn't require a basis rotation we add a WaitLike gate with the duration of cirq.rx to + guarantee that each context has the same length. + + Args: + pauli: Which Pauli we want to measure + + Returns: + The pair of gates to perform the rotation and its inverse. + + Raises: + ValueError: If Pauli is not X, Y, or Z. + """ + if pauli == "X" or pauli == cirq.X: + return cirq.ry(-np.pi / 2), cirq.ry(np.pi / 2) + elif pauli == "-X": + return cirq.ry(np.pi / 2), cirq.ry(-np.pi / 2) + elif pauli == "Y" or pauli == cirq.Y: + return cirq.rx(np.pi / 2), cirq.rx(-np.pi / 2) + elif pauli == "-Y": + return cirq.rx(-np.pi / 2), cirq.rx(np.pi / 2) + elif pauli == "Z" or pauli == cirq.Z: + return cirq.I, cirq.I + elif pauli == "-Z": + return cirq.X, cirq.X + elif pauli == "1": + return cirq.I, cirq.I + else: + raise ValueError(f"pauli must be X, Y, or Z: {pauli}") + + +def _generate_qnd_pauli_circuit( + q0: cirq.Qid, + q1: cirq.Qid, + anc: cirq.Qid, + paulis: tuple[str, str], + measure_key: str | None, + virtual_z_half_turns: float = 0.0, +) -> cirq.FrozenCircuit: + """Make a circuit to perform QND measurement of the specified Pauli. + + If len(pauli)==1, we do not map the qubit to an ancilla, and just perform a QND measurement on + the qubit itself. + + Args: + q0: Data qubit 0 + q1: Data qubit 1 + anc: Ancilla qubit + paulis: The Paulis to measure on each qubit. + measure_key: A key to associate to the measurement. + virtual_z_half_turns: This applied only for Paulis mapped to an ancilla. If not zero, + amount of half turns around the Z axis to apply on the ancilla during the mapping of + the paulis from the data qubits (q0, q1). This virtualZ is applied twice, once before + each Hadamard during the mapping of the Pauli to the ancilla. + + Returns: + A circuit to measure the indicated Pauli operator + """ + if "1" in paulis: + if paulis[0] == "1": + pauli, qubit, other = paulis[1], q1, q0 + else: + pauli, qubit, other = paulis[0], q0, q1 + pre, post = get_rot_gate_from_pauli(pauli) + circuit = cirq.FrozenCircuit.from_moments( + [pre(qubit), cirq.I(other)], + cirq.M(qubit, key=measure_key), + post(qubit), + ) + if _CIRCUIT_DD_SUPPORTED: + circuit = _ADD_DD_FUNC(circuit) + else: + warnings.warn( + "Skipping Dynamical Decoupling: This feature requires Cirq 1.4 or later. " + "Please upgrade your Cirq version to enable noise mitigation.", + UserWarning, + ) + return circuit + return _generate_qnd_pauli_circuit_with_ancillas( + q0, q1, anc, paulis, measure_key, virtual_z_half_turns=virtual_z_half_turns + ) + + +def _generate_qnd_pauli_circuit_with_ancillas( + q0: cirq.Qid, + q1: cirq.Qid, + anc: cirq.Qid, + paulis: tuple[str, str], + measure_key: str | None, + virtual_z_half_turns: float = 0.0, +) -> cirq.FrozenCircuit: + """Make a circuit to perform a QND measurement of the specified Pauli. + + This method is different from generate_qnd_pauli_circuit since we always map the observables to + an ancilla, regardless if the observable is a 1-body or a 2-body Pauli. + + Args: + q0: Data qubit 0 + q1: Data qubit 1 + anc: Ancilla qubit + paulis: The Paulis to measure on each qubit. + measure_key: A key to associate to the measurement. + virtual_z_half_turns: This applied only for Paulis mapped to an ancilla. If not zero, + amount of half turns around the Z axis to apply on the ancilla during the mapping of + the paulis from the data qubits (q0, q1). This virtualZ is applied twice, once before + each Hadamard during the mapping of the Pauli to the ancilla. + """ + # Get pre- and post- rotations for the pauli to measure. + pre0, post0 = (rot(q0) for rot in get_rot_gate_from_pauli(paulis[0])) + pre1, post1 = (rot(q1) for rot in get_rot_gate_from_pauli(paulis[1])) + return cirq.FrozenCircuit.from_moments( + [pre0, pre1, cirq.H(anc)], + *(cirq.CZ(q, anc) for q, pauli in zip((q0, q1), paulis) if pauli != "1"), + cirq.Z(anc) ** virtual_z_half_turns, + [post0, post1, cirq.H(anc)], + *measure_reset_circuit(q0, q1, anc, measure_key), + ) + + +def measure_reset_circuit( + q0: cirq.Qid, q1: cirq.Qid, anc: cirq.Qid, measure_key: str | None +) -> cirq.FrozenCircuit: + m = cirq.M(anc, key=measure_key) + circuit = cirq.FrozenCircuit.from_moments( + [ + m, + cirq.I(q0), + cirq.I(q1), + ], + cirq.R(anc), + ) + if _CIRCUIT_DD_SUPPORTED: + circuit = _ADD_DD_FUNC(circuit) + else: + warnings.warn( + "Skipping Dynamical Decoupling: This feature requires Cirq 1.4 or later. " + "Please upgrade your Cirq version to enable noise mitigation.", + UserWarning, + ) + return circuit + + +def generate_qnd_pauli_circuit( + q0: cirq.Qid, + q1: cirq.Qid, + anc: cirq.Qid, + pauli: str, + measure_key: str | None, + qnd_strategy: QndStrategy = "map_all_paulis_to_ancilla", + virtual_z_half_turns: float = 0.0, +) -> cirq.FrozenCircuit: + """Generate a circuit to measure the one-qubit or two-qubit pauli (PauliString). + + Args: + q0: First qubit. + q1: Second qubit. + anc: Measure/ancilla qubit. + pauli: PauliString, can be one-qubit or two-qubit Pauli. + measure_key: Key for the cirq.M gate. + qnd_strategy: Whether to map all PauliStrings to the ancilla or only map two-qubit Paulis. + virtual_z_half_turns: This applied only for Paulis mapped to an ancilla. If not zero, + amount of half turns around the Z axis to apply on the ancilla during the mapping of + the paulis from the data qubits (q0, q1). This virtualZ is applied twice, once before + each Hadamard during the mapping of the Pauli to the ancilla. + """ + # Split into signed single-qubit paulis. Include the sign on an X or Y term if possible. + unsigned = pauli.removeprefix("-") + p0, p1 = unsigned[0], unsigned[1] + if pauli.startswith("-"): + if p0 in ("X", "Y"): + p0 = f"-{p0}" + elif p1 in ("X", "Y"): + p1 = f"-{p1}" + elif p0 == "Z": + p0 = "-Z" + elif p1 == "Z": + p1 = "-Z" + else: + raise ValueError(f"invalid pauli string: {pauli}") + if qnd_strategy == "map_2q_paulis_to_ancilla": + return _generate_qnd_pauli_circuit( + q0, + q1, + anc, + (p0, p1), + measure_key, + virtual_z_half_turns=virtual_z_half_turns, + ) + elif qnd_strategy == "map_all_paulis_to_ancilla": + return _generate_qnd_pauli_circuit_with_ancillas( + q0, + q1, + anc, + (p0, p1), + measure_key, + virtual_z_half_turns=virtual_z_half_turns, + ) + else: + raise ValueError( + 'qnd_strategy can only be "map_all_paulis_to_ancilla" or "map_2q_paulis_to_ancilla"' + f" got {qnd_strategy=}." + ) + + +def estimate_parities( + records: Mapping[str, np.ndarray], + contexts: Mapping[tuple[Axis, int], Sequence[PauliKey]], +) -> dict[tuple[Axis, int], tuple[float, float]]: + """Estimate average parity and error of each of the 6 contexts. + + Useful to get more nuanced information on each row/col. We first evaluate the parity for + each time every context has appeared in the circuit. We then average the parities for + each context and estimate the error of this average. For references on the error estimation, + check the error of a binomial distribution. + + Args: + records: The value for each key is a 2-D array of booleans, with the first index running + over circuit repetitions, the second index running over instances of the measurement + key in the circuit. + contexts: # TODO + + Returns: + context_averages: A dictionary mapping a context (axis, idx) to a tuple + (mean_parity, parity_error) + """ + context_parities = {} + for context, pauli_keys in contexts.items(): + # Sum all outcome arrays for this context first. Summing the arrays first and doing + # one exponentiation is much faster than doing multiple exponentiations and then a + # product. + outcome_sum = np.sum( + [ + records[pauli_key.key][:, pauli_key.slice].astype(int) + for pauli_key in pauli_keys + ], + axis=0, + ) + # Now calculate parity in one vectorized operation + context_parities[context] = (-1) ** outcome_sum + + context_averages = {} + for context, parities in context_parities.items(): + mean_parity = np.mean(parities) + + # Arithmetically identical to np.mean([(parity + 1) / 2 and much faster + p_mean = abs((mean_parity + 1) / 2) + + parity_error = 2 * np.sqrt(p_mean * (1 - p_mean) / parities.size) + + context_averages[context] = (float(mean_parity), float(parity_error)) + + return context_averages + + +def estimate_chi( + records: Mapping[str, np.ndarray], + contexts: Mapping[tuple[Axis, int], Sequence[PauliKey]], +) -> tuple[float, float]: + """Estimate chi out of the 6 parity averages. + + This method takes the output of the Kochen-Specker experiment and returns an estimate of chi. + We assume that the length of the circuit ran was long enough to contain at least one instance of + each of the six 3-observable parities, e.g. (1X X1 XX). + """ + context_averages = estimate_parities(records, contexts) + + chi_mean = sum( + avg * (1 if axis == "row" else -1) + for (axis, _idx), (avg, _unc) in context_averages.items() + ) + chi_error = float(np.sqrt(sum(unc**2 for _avg, unc in context_averages.values()))) + + return chi_mean, chi_error + + +# Utils +def get_parities_from_single_context_experiments( + context_to_kochen_specker_result: dict[tuple[Axis, int], MagicSquareResult], +) -> dict[tuple[Axis, int], tuple[float, float]]: + """Get the parities (with uncertainties) for each context. + + Assumes that we ran 6 single-context Kochen Specker experiments. Returns a mapping from + context to their parities (with uncertainties). + + Args: + context_to_kochen_specker_result: maps context to a MagicSquareResult instance. Object is + usually returned by "run_kochen_specker_per_context". + """ + axes: list[Axis] = ["col", "row"] + indices = [0, 1, 2] + all_contexts: list[tuple[Axis, int]] = [ + (axis, idx) for axis in axes for idx in indices + ] + context_to_parities = {} + context_to_parities.update( + { + context: context_to_kochen_specker_result[context].estimate_parities()[ + context + ] + for context in all_contexts + } + ) + return context_to_parities + + +def get_chi_from_single_context_experiment( + context_to_parities: dict[tuple[Axis, int], tuple[float, float]], +) -> tuple[float, float]: + """Estimate chi (with uncertainty) from a context_to_parities mapping.""" + chi_mean = sum( + avg * (1 if axis == "row" else -1) + for (axis, _idx), (avg, _unc) in context_to_parities.items() + ) + chi_error = np.sqrt(sum(unc**2 for _avg, unc in context_to_parities.values())) + + return chi_mean, chi_error From e4a1a61ccc325b6ac724249357a751f09791e1da Mon Sep 17 00:00:00 2001 From: Alejo Gradau Date: Wed, 17 Dec 2025 23:46:00 +0000 Subject: [PATCH 2/3] fix Literal types --- recirq/contextuality/kochen_specker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/recirq/contextuality/kochen_specker.py b/recirq/contextuality/kochen_specker.py index a2d0a409..e93601b0 100644 --- a/recirq/contextuality/kochen_specker.py +++ b/recirq/contextuality/kochen_specker.py @@ -24,9 +24,9 @@ import matplotlib.pyplot as plt import numpy as np -type Axis = Literal["row", "col"] -type QndStrategy = Literal["map_all_paulis_to_ancilla", "map_2q_paulis_to_ancilla"] -type CircuitStyle = Literal["flat", "jt", "jt_pauli", "jt_pauli_pivot"] +Axis = Literal["row", "col"] +QndStrategy = Literal["map_all_paulis_to_ancilla", "map_2q_paulis_to_ancilla"] +CircuitStyle = Literal["flat", "jt", "jt_pauli", "jt_pauli_pivot"] _CIRCUIT_DD_SUPPORTED = hasattr(cirq, "add_dynamical_decoupling") From 43220c91ed6ae57ce610b288d24a43dfb27e864b Mon Sep 17 00:00:00 2001 From: Alejo Gradau Date: Mon, 22 Dec 2025 16:56:02 +0000 Subject: [PATCH 3/3] Add docstrings --- recirq/contextuality/kochen_specker.py | 66 ++++++++++++++++++++++++-- 1 file changed, 63 insertions(+), 3 deletions(-) diff --git a/recirq/contextuality/kochen_specker.py b/recirq/contextuality/kochen_specker.py index e93601b0..a7270c5f 100644 --- a/recirq/contextuality/kochen_specker.py +++ b/recirq/contextuality/kochen_specker.py @@ -39,7 +39,18 @@ @attrs.frozen class MagicSquare: - """A Mermin-Peres magic square.""" + """Represents a MagicSquare. + + A MagicSquare is a 3x3 grid of two-qubit Pauli operators used to demonstrate quantum + contextuality, specifically the Kochen-Specker theorem. + + The operators in each row and each column mutually commute. + The product of operators in each row is +I. + The product of operators in each column is -I. + + Attributes: + paulis: A 3x3 sequence of strings representing the Pauli operators in the square. + """ paulis: Sequence[Sequence[str]] @@ -102,9 +113,14 @@ def kochen_specker_expt( qnd_strategy: QndStrategy = "map_all_paulis_to_ancilla", virtual_z_half_turns: float = 0.0, ) -> MagicSquareExperiment: - """Builds a Kochen-Specker experiment. + """Builds a Kochen-Specker experiment, which measures state-independent contextuality. + + See https://arxiv.org/pdf/2512.02284 for more details. Args: + q0: First data qubit. + q1: Second data qubit. + anc: Ancilla qubit. num_contexts: Number of rows+cols measured out of the 6 possible row/col. seed: Seed for random number generator. Specifying this will give @@ -114,6 +130,9 @@ def kochen_specker_expt( amount of half turns around the Z axis to apply on the ancilla during the mapping of the paulis from the data qubits (q0, q1). This virtualZ is applied before the last Hadamard during the mapping of the Pauli to the ancilla. + + Returns: + A `MagicSquareExperiment` object containing the constructed circuit and context info. """ axes: list[Axis] = ["col", "row"] @@ -178,23 +197,64 @@ def _copy_circuit(circuit: cirq.FrozenCircuit) -> cirq.FrozenCircuit: @attrs.frozen class MagicSquareExperiment: + """An experiment to test for contextuality using a magic square. + + Attributes: + square: The `MagicSquare` instance used for the experiment. + circuit: The `cirq.FrozenCircuit` that will be run. This circuit + is composed of measurements for a sequence of randomly chosen + contexts (rows or columns of the magic square). + contexts: A dictionary mapping each possible context (row or column) to + the sequence of `PauliKey`s that identify the measurements for that + context. + """ + square: MagicSquare circuit: cirq.FrozenCircuit contexts: dict[tuple[Axis, int], Sequence[PauliKey]] def run(self, sampler: cirq.Sampler, repetitions: int = 1000) -> MagicSquareResult: + """Runs the experiment circuit on a sampler. + + Args: + sampler: The `cirq.Sampler` to run the experiment on. + repetitions: The number of times to repeat the circuit. + + Returns: + A `MagicSquareResult` object containing the results of the experiment. + """ result = sampler.run(self.circuit, repetitions=repetitions) return MagicSquareResult(self, result) @attrs.frozen class PauliKey: + """A key to identify a specific Pauli measurement in a circuit. + + Attributes: + key: The measurement key string for the Pauli measurement. + slice: A slice or list of indices to select specific measurement instances + if the key appears multiple times in a circuit for a given context. + Defaults to `slice(None)` to select all instances. + """ + key: str slice: list[int] | slice = slice(None) @attrs.frozen(slots=False) class MagicSquareResult: + """The results of a `MagicSquareExperiment`. + + This class holds the `cirq.Result` from running the experiment and provides + methods to analyze the results, such as estimating the Kochen-Specker value (chi) + and other statistical properties. + + Attributes: + expt: The `MagicSquareExperiment` that was run. + result: The `cirq.Result` object from the sampler. + """ + expt: MagicSquareExperiment result: cirq.Result @@ -790,7 +850,7 @@ def estimate_parities( records: The value for each key is a 2-D array of booleans, with the first index running over circuit repetitions, the second index running over instances of the measurement key in the circuit. - contexts: # TODO + contexts: Mapping contexts to their corresponding pauli keys. Returns: context_averages: A dictionary mapping a context (axis, idx) to a tuple