Skip to content

Commit

Permalink
Implement Index<NodeIndex> for DAGCircuit (Qiskit#13683)
Browse files Browse the repository at this point in the history
This removes the syntax noise of the `dag.dag()` calls when indexing by
`NodeIndex`.  As it happens, this is _almost_ all of the reason we even
use the underlying graph object in `accelerate`. The only exceptions are
some needless defensive programming in
`RemoveDiagonalGatesBeforeMeasure` (which really is the same thing
underneath anyway), and in the graphviz utilities, which is legitimate.
  • Loading branch information
jakelishman authored Jan 17, 2025
1 parent 3150351 commit 524fb47
Show file tree
Hide file tree
Showing 15 changed files with 47 additions and 41 deletions.
4 changes: 2 additions & 2 deletions crates/accelerate/src/barrier_before_final_measurement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub fn barrier_before_final_measurements(
}
dag.bfs_successors(node)
.all(|(_, child_successors)| {
child_successors.iter().all(|suc| match dag.dag()[*suc] {
child_successors.iter().all(|suc| match dag[*suc] {
NodeType::Operation(ref suc_inst) => is_exactly_final(suc_inst),
_ => true,
})
Expand All @@ -57,7 +57,7 @@ pub fn barrier_before_final_measurements(
let final_packed_ops: Vec<PackedInstruction> = ordered_node_indices
.into_iter()
.map(|node| {
let NodeType::Operation(ref inst) = dag.dag()[node] else {
let NodeType::Operation(ref inst) = dag[node] else {
unreachable!()
};
let res = inst.clone();
Expand Down
6 changes: 3 additions & 3 deletions crates/accelerate/src/basis/basis_translator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ fn apply_translation(
let mut is_updated = false;
let mut out_dag = dag.copy_empty_like(py, "alike")?;
for node in dag.topological_op_nodes()? {
let node_obj = dag.dag()[node].unwrap_operation();
let node_obj = dag[node].unwrap_operation();
let node_qarg = dag.get_qargs(node_obj.qubits);
let node_carg = dag.get_cargs(node_obj.clbits);
let qubit_set: HashSet<Qubit> = HashSet::from_iter(node_qarg.iter().copied());
Expand Down Expand Up @@ -606,7 +606,7 @@ fn replace_node(
}
if node.params_view().is_empty() {
for inner_index in target_dag.topological_op_nodes()? {
let inner_node = &target_dag.dag()[inner_index].unwrap_operation();
let inner_node = &target_dag[inner_index].unwrap_operation();
let old_qargs = dag.get_qargs(node.qubits);
let old_cargs = dag.get_cargs(node.clbits);
let new_qubits: Vec<Qubit> = target_dag
Expand Down Expand Up @@ -667,7 +667,7 @@ fn replace_node(
.zip(node.params_view())
.into_py_dict_bound(py);
for inner_index in target_dag.topological_op_nodes()? {
let inner_node = &target_dag.dag()[inner_index].unwrap_operation();
let inner_node = &target_dag[inner_index].unwrap_operation();
let old_qargs = dag.get_qargs(node.qubits);
let old_cargs = dag.get_cargs(node.clbits);
let new_qubits: Vec<Qubit> = target_dag
Expand Down
2 changes: 1 addition & 1 deletion crates/accelerate/src/commutation_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ pub(crate) fn analyze_commutations_inner(
// if the node is an input/output node, they do not commute, so we only
// continue if the nodes are operation nodes
if let (NodeType::Operation(packed_inst0), NodeType::Operation(packed_inst1)) =
(&dag.dag()[current_gate_idx], &dag.dag()[*prev_gate_idx])
(&dag[current_gate_idx], &dag[*prev_gate_idx])
{
let op1 = packed_inst0.op.view();
let op2 = packed_inst1.op.view();
Expand Down
6 changes: 3 additions & 3 deletions crates/accelerate/src/commutation_cancellation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,14 @@ pub(crate) fn cancel_commutations(
if let Some(wire_commutation_set) = commutation_set.get(&Wire::Qubit(wire)) {
for (com_set_idx, com_set) in wire_commutation_set.iter().enumerate() {
if let Some(&nd) = com_set.first() {
if !matches!(dag.dag()[nd], NodeType::Operation(_)) {
if !matches!(dag[nd], NodeType::Operation(_)) {
continue;
}
} else {
continue;
}
for node in com_set.iter() {
let instr = match &dag.dag()[*node] {
let instr = match &dag[*node] {
NodeType::Operation(instr) => instr,
_ => panic!("Unexpected type in commutation set."),
};
Expand Down Expand Up @@ -198,7 +198,7 @@ pub(crate) fn cancel_commutations(
let mut total_angle: f64 = 0.0;
let mut total_phase: f64 = 0.0;
for current_node in cancel_set {
let node_op = match &dag.dag()[*current_node] {
let node_op = match &dag[*current_node] {
NodeType::Operation(instr) => instr,
_ => panic!("Unexpected type in commutation set run."),
};
Expand Down
10 changes: 5 additions & 5 deletions crates/accelerate/src/consolidate_blocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ pub(crate) fn consolidate_blocks(
block_qargs.clear();
if block.len() == 1 {
let inst_node = block[0];
let inst = dag.dag()[inst_node].unwrap_operation();
let inst = dag[inst_node].unwrap_operation();
if !is_supported(
target,
basis_gates.as_ref(),
Expand All @@ -123,7 +123,7 @@ pub(crate) fn consolidate_blocks(
let mut basis_count: usize = 0;
let mut outside_basis = false;
for node in &block {
let inst = dag.dag()[*node].unwrap_operation();
let inst = dag[*node].unwrap_operation();
block_qargs.extend(dag.get_qargs(inst.qubits));
all_block_gates.insert(*node);
if inst.op.name() == basis_gate_name {
Expand Down Expand Up @@ -151,7 +151,7 @@ pub(crate) fn consolidate_blocks(
block_qargs.len() as u32,
0,
block.iter().map(|node| {
let inst = dag.dag()[*node].unwrap_operation();
let inst = dag[*node].unwrap_operation();

Ok((
inst.op.clone(),
Expand Down Expand Up @@ -242,7 +242,7 @@ pub(crate) fn consolidate_blocks(
continue;
}
let first_inst_node = run[0];
let first_inst = dag.dag()[first_inst_node].unwrap_operation();
let first_inst = dag[first_inst_node].unwrap_operation();
let first_qubits = dag.get_qargs(first_inst.qubits);

if run.len() == 1
Expand Down Expand Up @@ -272,7 +272,7 @@ pub(crate) fn consolidate_blocks(
if all_block_gates.contains(node) {
already_in_block = true;
}
let gate = dag.dag()[*node].unwrap_operation();
let gate = dag[*node].unwrap_operation();
let operator = match get_matrix_from_inst(py, gate) {
Ok(mat) => mat,
Err(_) => {
Expand Down
2 changes: 1 addition & 1 deletion crates/accelerate/src/convert_2q_block_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ pub fn blocks_to_matrix(
let mut one_qubit_components_modified = false;
let mut output_matrix: Option<Array2<Complex64>> = None;
for node in op_list {
let inst = dag.dag()[*node].unwrap_operation();
let inst = dag[*node].unwrap_operation();
let op_matrix = get_matrix_from_inst(py, inst)?;
match dag
.get_qargs(inst.qubits)
Expand Down
2 changes: 1 addition & 1 deletion crates/accelerate/src/elide_permutations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ fn run(py: Python, dag: &mut DAGCircuit) -> PyResult<Option<(DAGCircuit, Vec<usi
// note that DAGCircuit::copy_empty_like clones the interners
let mut new_dag = dag.copy_empty_like(py, "alike")?;
for node_index in dag.topological_op_nodes()? {
if let NodeType::Operation(inst) = &dag.dag()[node_index] {
if let NodeType::Operation(inst) = &dag[node_index] {
match (inst.op.name(), inst.condition()) {
("swap", None) => {
let qargs = dag.get_qargs(inst.qubits);
Expand Down
6 changes: 3 additions & 3 deletions crates/accelerate/src/euler_one_qubit_decomposer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1089,7 +1089,7 @@ pub(crate) fn optimize_1q_gates_decomposition(
Some(_) => 1.,
None => raw_run.len() as f64,
};
let qubit: PhysicalQubit = if let NodeType::Operation(inst) = &dag.dag()[raw_run[0]] {
let qubit: PhysicalQubit = if let NodeType::Operation(inst) = &dag[raw_run[0]] {
PhysicalQubit::new(dag.get_qargs(inst.qubits)[0].0)
} else {
unreachable!("nodes in runs will always be op nodes")
Expand Down Expand Up @@ -1175,7 +1175,7 @@ pub(crate) fn optimize_1q_gates_decomposition(
let operator = raw_run
.iter()
.map(|node_index| {
let node = &dag.dag()[*node_index];
let node = &dag[*node_index];
if let NodeType::Operation(inst) = node {
if let Some(target) = target {
error *= compute_error_term_from_target(inst.op.name(), target, qubit);
Expand Down Expand Up @@ -1218,7 +1218,7 @@ pub(crate) fn optimize_1q_gates_decomposition(
let mut outside_basis = false;
if let Some(basis) = basis_gates {
for node in &raw_run {
if let NodeType::Operation(inst) = &dag.dag()[*node] {
if let NodeType::Operation(inst) = &dag[*node] {
if !basis.contains(inst.op.name()) {
outside_basis = true;
break;
Expand Down
2 changes: 1 addition & 1 deletion crates/accelerate/src/gate_direction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ where
}

for (node, op_blocks) in ops_to_replace {
let packed_inst = dag.dag()[node].unwrap_operation();
let packed_inst = dag[node].unwrap_operation();
let OperationRef::Instruction(py_inst) = packed_inst.op.view() else {
panic!("PyInstruction is expected");
};
Expand Down
19 changes: 9 additions & 10 deletions crates/accelerate/src/inverse_cancellation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ fn run_on_self_inverse(
let mut chunk: Vec<NodeIndex> = Vec::new();
let max_index = gate_cancel_run.len() - 1;
for (i, cancel_gate) in gate_cancel_run.iter().enumerate() {
let node = &dag.dag()[*cancel_gate];
let node = &dag[*cancel_gate];
if let NodeType::Operation(inst) = node {
if gate_eq(py, inst, &gate)? {
chunk.push(*cancel_gate);
Expand All @@ -78,13 +78,12 @@ fn run_on_self_inverse(
if i == max_index {
partitions.push(std::mem::take(&mut chunk));
} else {
let next_qargs = if let NodeType::Operation(next_inst) =
&dag.dag()[gate_cancel_run[i + 1]]
{
next_inst.qubits
} else {
panic!("Not an op node")
};
let next_qargs =
if let NodeType::Operation(next_inst) = &dag[gate_cancel_run[i + 1]] {
next_inst.qubits
} else {
panic!("Not an op node")
};
if inst.qubits != next_qargs {
partitions.push(std::mem::take(&mut chunk));
}
Expand Down Expand Up @@ -132,8 +131,8 @@ fn run_on_inverse_pairs(
for nodes in runs {
let mut i = 0;
while i < nodes.len() - 1 {
if let NodeType::Operation(inst) = &dag.dag()[nodes[i]] {
if let NodeType::Operation(next_inst) = &dag.dag()[nodes[i + 1]] {
if let NodeType::Operation(inst) = &dag[nodes[i]] {
if let NodeType::Operation(next_inst) = &dag[nodes[i + 1]] {
if inst.qubits == next_inst.qubits
&& ((gate_eq(py, inst, &gate_0)? && gate_eq(py, next_inst, &gate_1)?)
|| (gate_eq(py, inst, &gate_1)?
Expand Down
5 changes: 2 additions & 3 deletions crates/accelerate/src/remove_diagonal_gates_before_measure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ fn run_remove_diagonal_before_measure(dag: &mut DAGCircuit) -> PyResult<()> {
.next()
.expect("index is an operation node, so it must have a predecessor.");

match &dag.dag()[predecessor] {
match &dag[predecessor] {
NodeType::Operation(pred_inst) => match pred_inst.standard_gate() {
Some(gate) => {
if DIAGONAL_1Q_GATES.contains(&gate) {
Expand All @@ -64,8 +64,7 @@ fn run_remove_diagonal_before_measure(dag: &mut DAGCircuit) -> PyResult<()> {
let successors = dag.quantum_successors(predecessor);
let remove_s = successors
.map(|s| {
let node_s = &dag.dag()[s];
if let NodeType::Operation(inst_s) = node_s {
if let NodeType::Operation(inst_s) = &dag[s] {
inst_s.op.name() == "measure"
} else {
false
Expand Down
2 changes: 1 addition & 1 deletion crates/accelerate/src/split_2q_unitaries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub fn split_2q_unitaries(
let nodes: Vec<NodeIndex> = dag.op_node_indices(false).collect();

for node in nodes {
if let NodeType::Operation(inst) = &dag.dag()[node] {
if let NodeType::Operation(inst) = &dag[node] {
let qubits = dag.get_qargs(inst.qubits).to_vec();
// We only attempt to split UnitaryGate objects, but this could be extended in future
// -- however we need to ensure that we can compile the resulting single-qubit unitaries
Expand Down
10 changes: 5 additions & 5 deletions crates/accelerate/src/unitary_synthesis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ fn apply_synth_dag(
synth_dag: &DAGCircuit,
) -> PyResult<()> {
for out_node in synth_dag.topological_op_nodes()? {
let mut out_packed_instr = synth_dag.dag()[out_node].unwrap_operation().clone();
let mut out_packed_instr = synth_dag[out_node].unwrap_operation().clone();
let synth_qargs = synth_dag.get_qargs(out_packed_instr.qubits);
let mapped_qargs: Vec<Qubit> = synth_qargs
.iter()
Expand Down Expand Up @@ -237,7 +237,7 @@ fn py_run_main_loop(

// Iterate over dag nodes and determine unitary synthesis approach
for node in dag.topological_op_nodes()? {
let mut packed_instr = dag.dag()[node].unwrap_operation().clone();
let mut packed_instr = dag[node].unwrap_operation().clone();

if packed_instr.op.control_flow() {
let OperationRef::Instruction(py_instr) = packed_instr.op.view() else {
Expand Down Expand Up @@ -486,7 +486,7 @@ fn run_2q_unitary_synthesis(
.topological_op_nodes()
.expect("Unexpected error in dag.topological_op_nodes()")
.map(|node| {
let NodeType::Operation(inst) = &synth_dag.dag()[node] else {
let NodeType::Operation(inst) = &synth_dag[node] else {
unreachable!("DAG node must be an instruction")
};
let inst_qubits = synth_dag
Expand Down Expand Up @@ -1002,7 +1002,7 @@ fn synth_su4_dag(
Some(preferred_dir) => {
let mut synth_direction: Option<Vec<u32>> = None;
for node in synth_dag.topological_op_nodes()? {
let inst = &synth_dag.dag()[node].unwrap_operation();
let inst = &synth_dag[node].unwrap_operation();
if inst.op.num_qubits() == 2 {
let qargs = synth_dag.get_qargs(inst.qubits);
synth_direction = Some(vec![qargs[0].0, qargs[1].0]);
Expand Down Expand Up @@ -1066,7 +1066,7 @@ fn reversed_synth_su4_dag(
let mut target_dag = synth_dag.copy_empty_like(py, "alike")?;
let flip_bits: [Qubit; 2] = [Qubit(1), Qubit(0)];
for node in synth_dag.topological_op_nodes()? {
let mut inst = synth_dag.dag()[node].unwrap_operation().clone();
let mut inst = synth_dag[node].unwrap_operation().clone();
let qubits: Vec<Qubit> = synth_dag
.qargs_interner()
.get(inst.qubits)
Expand Down
4 changes: 2 additions & 2 deletions crates/circuit/src/converters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
#[cfg(feature = "cache_pygates")]
use std::sync::OnceLock;

use ::pyo3::prelude::*;
use hashbrown::HashMap;
use pyo3::prelude::*;
use pyo3::{
intern,
types::{PyDict, PyList},
Expand Down Expand Up @@ -106,7 +106,7 @@ pub fn dag_to_circuit(
dag.qargs_interner().clone(),
dag.cargs_interner().clone(),
dag.topological_op_nodes()?.map(|node_index| {
let NodeType::Operation(ref instr) = dag.dag()[node_index] else {
let NodeType::Operation(ref instr) = dag[node_index] else {
unreachable!(
"The received node from topological_op_nodes() is not an Operation node."
)
Expand Down
8 changes: 8 additions & 0 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6967,6 +6967,14 @@ impl DAGCircuit {
}
}

impl ::std::ops::Index<NodeIndex> for DAGCircuit {
type Output = NodeType;

fn index(&self, index: NodeIndex) -> &Self::Output {
self.dag.index(index)
}
}

/// Add to global phase. Global phase can only be Float or ParameterExpression so this
/// does not handle the full possibility of parameter values.
pub(crate) fn add_global_phase(py: Python, phase: &Param, other: &Param) -> PyResult<Param> {
Expand Down

0 comments on commit 524fb47

Please sign in to comment.