mirror of
https://github.com/blackboxprogramming/quantum-math-lab.git
synced 2026-03-17 08:57:24 -05:00
Fix lint, most_likely() tie-breaking, add 8 tests (36 total, 99% coverage)
Co-authored-by: blackboxprogramming <118287761+blackboxprogramming@users.noreply.github.com>
This commit is contained in:
@@ -25,7 +25,7 @@ probability distribution for a subset of qubits.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Iterable, List, Mapping, Optional, Sequence
|
from typing import Dict, Mapping, Optional, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -53,7 +53,8 @@ class MeasurementResult:
|
|||||||
returning the lexicographically smallest string.
|
returning the lexicographically smallest string.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return max(self.counts.items(), key=lambda item: (item[1], item[0]))[0]
|
best_count = max(self.counts.values())
|
||||||
|
return min(k for k, v in self.counts.items() if v == best_count)
|
||||||
|
|
||||||
def total_shots(self) -> int:
|
def total_shots(self) -> int:
|
||||||
"""Return the total number of measurement shots."""
|
"""Return the total number of measurement shots."""
|
||||||
|
|||||||
@@ -133,10 +133,10 @@ class TestCustomUnitary:
|
|||||||
assert probs["1"] == pytest.approx(1.0, abs=1e-8)
|
assert probs["1"] == pytest.approx(1.0, abs=1e-8)
|
||||||
|
|
||||||
def test_identity_gate(self):
|
def test_identity_gate(self):
|
||||||
I = np.eye(2, dtype=np.complex128)
|
identity = np.eye(2, dtype=np.complex128)
|
||||||
circuit = QuantumCircuit(1)
|
circuit = QuantumCircuit(1)
|
||||||
circuit.hadamard(0)
|
circuit.hadamard(0)
|
||||||
circuit.apply_custom(I, [0])
|
circuit.apply_custom(identity, [0])
|
||||||
probs = circuit.probabilities()
|
probs = circuit.probabilities()
|
||||||
assert probs["0"] == pytest.approx(0.5, abs=1e-8)
|
assert probs["0"] == pytest.approx(0.5, abs=1e-8)
|
||||||
|
|
||||||
@@ -247,3 +247,86 @@ class TestProbabilityNormalization:
|
|||||||
circuit.hadamard(1)
|
circuit.hadamard(1)
|
||||||
probs = circuit.probabilities(qubits=[0])
|
probs = circuit.probabilities(qubits=[0])
|
||||||
assert sum(probs.values()) == pytest.approx(1.0)
|
assert sum(probs.values()) == pytest.approx(1.0)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Additional coverage ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestEmptyQubits:
|
||||||
|
def test_probabilities_empty_list_returns_full_distribution(self):
|
||||||
|
"""probabilities(qubits=[]) takes the early-return path and matches full dist."""
|
||||||
|
circuit = QuantumCircuit(2)
|
||||||
|
circuit.hadamard(0)
|
||||||
|
probs_empty = circuit.probabilities(qubits=[])
|
||||||
|
probs_full = circuit.probabilities()
|
||||||
|
for key in probs_full:
|
||||||
|
assert probs_empty[key] == pytest.approx(probs_full[key], abs=1e-8)
|
||||||
|
|
||||||
|
def test_measure_empty_qubits_collapses_to_single_basis_state(self):
|
||||||
|
"""measure(qubits=[]) collapses the full state; _collapse_state empty path."""
|
||||||
|
rng = np.random.default_rng(seed=0)
|
||||||
|
circuit = QuantumCircuit(2)
|
||||||
|
circuit.hadamard(0)
|
||||||
|
circuit.cnot(0, 1)
|
||||||
|
result = circuit.measure(qubits=[], shots=1, rng=rng)
|
||||||
|
assert result.total_shots() == 1
|
||||||
|
probs = circuit.probabilities()
|
||||||
|
nonzero = [v for v in probs.values() if v > 1e-9]
|
||||||
|
assert len(nonzero) == 1
|
||||||
|
assert nonzero[0] == pytest.approx(1.0, abs=1e-8)
|
||||||
|
|
||||||
|
|
||||||
|
class TestNonSquareUnitary:
|
||||||
|
def test_non_square_matrix_raises(self):
|
||||||
|
"""A non-square matrix must raise ValueError (covers the ndim/shape check)."""
|
||||||
|
circuit = QuantumCircuit(1)
|
||||||
|
with pytest.raises(ValueError, match="square"):
|
||||||
|
circuit.apply_custom(np.zeros((2, 3), dtype=complex), [0])
|
||||||
|
|
||||||
|
|
||||||
|
class TestMostLikelyTieBreaking:
|
||||||
|
def test_tie_broken_by_lexicographically_smallest(self):
|
||||||
|
"""Ties in counts must resolve to the lexicographically smallest string."""
|
||||||
|
result = MeasurementResult(counts={"01": 5, "10": 5})
|
||||||
|
assert result.most_likely() == "01"
|
||||||
|
|
||||||
|
def test_tie_three_way(self):
|
||||||
|
result = MeasurementResult(counts={"11": 3, "00": 3, "01": 3})
|
||||||
|
assert result.most_likely() == "00"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdditionalCircuits:
|
||||||
|
def test_pauli_y_gate(self):
|
||||||
|
"""Pauli-Y: Y|0⟩ = i|1⟩, so probability of measuring |1⟩ is 1."""
|
||||||
|
Y = np.array([[0, -1j], [1j, 0]], dtype=np.complex128)
|
||||||
|
circuit = QuantumCircuit(1)
|
||||||
|
circuit.apply_custom(Y, [0])
|
||||||
|
probs = circuit.probabilities()
|
||||||
|
assert probs["0"] == pytest.approx(0.0, abs=1e-8)
|
||||||
|
assert probs["1"] == pytest.approx(1.0, abs=1e-8)
|
||||||
|
|
||||||
|
def test_five_qubit_ghz_state(self):
|
||||||
|
"""5-qubit GHZ state: equal superposition of |00000⟩ and |11111⟩."""
|
||||||
|
circuit = QuantumCircuit(5)
|
||||||
|
circuit.hadamard(0)
|
||||||
|
for i in range(1, 5):
|
||||||
|
circuit.cnot(0, i)
|
||||||
|
probs = circuit.probabilities()
|
||||||
|
assert probs["00000"] == pytest.approx(0.5, abs=1e-8)
|
||||||
|
assert probs["11111"] == pytest.approx(0.5, abs=1e-8)
|
||||||
|
for key, val in probs.items():
|
||||||
|
if key not in ("00000", "11111"):
|
||||||
|
assert val == pytest.approx(0.0, abs=1e-8)
|
||||||
|
|
||||||
|
def test_bell_state_qubit1_collapses_qubit0(self):
|
||||||
|
"""Measuring qubit 0 of a Bell state collapses qubit 1 to the same value."""
|
||||||
|
circuit = QuantumCircuit(2)
|
||||||
|
circuit.hadamard(0)
|
||||||
|
circuit.cnot(0, 1)
|
||||||
|
rng = np.random.default_rng(seed=999)
|
||||||
|
result = circuit.measure(qubits=[0], shots=1, rng=rng)
|
||||||
|
outcome = result.most_likely()
|
||||||
|
probs = circuit.probabilities()
|
||||||
|
if outcome == "0":
|
||||||
|
assert probs["00"] == pytest.approx(1.0, abs=1e-8)
|
||||||
|
else:
|
||||||
|
assert probs["11"] == pytest.approx(1.0, abs=1e-8)
|
||||||
|
|||||||
Reference in New Issue
Block a user