# Copyright (C) 2023 qBraid
#
# This file is part of the qBraid-SDK
#
# The qBraid-SDK is free software released under the GNU General Public License v3
# or later. You can redistribute and/or modify it under the terms of the GPL v3.
# See the LICENSE file in the project root or <https://www.gnu.org/licenses/gpl-3.0.html>.
#
# THERE IS NO WARRANTY for the qBraid-SDK, as per Section 15 of the GPL v3.
"""
Module for plotting histograms of measurement counts against quantum states.
"""
from typing import Callable, Dict, List, Optional, Union
import matplotlib.pyplot as plt
from matplotlib import colormaps
# pylint: disable=too-many-arguments,unnecessary-lambda
def _counts_to_decimal(counts: dict) -> dict:
"""
Converts a dictionary of counts to decimal form.
Args:
counts_dict (dict): A dictionary where keys are strings representing states and
values are integers representing counts.
Returns:
dict: A dictionary with the same keys as the input dictionary, but with values
converted to their respective decimal proportions.
Raises:
ValueError: If the total count is zero.
TypeError: If the input dictionary contains non-integer values
Example:
>>> counts_to_decimal({"00": 10, "01": 15, "10": 20, "11": 5})
{"00": 0.2, "01": 0.3, "10": 0.4, "11": 0.1}
"""
try:
total_count = sum(counts.values())
except TypeError as err:
raise TypeError("Counts values must be integers.") from err
if total_count == 0:
raise ValueError("Total count cannot be zero.")
decimal_dict = {key: value / total_count for key, value in counts.items()}
return decimal_dict
def _plot_data(
counts: Union[List[Dict], dict],
legend: Optional[Union[List[str], str]] = None,
colors: Optional[Union[List[str], str]] = None,
title: Optional[str] = None,
x_label: Optional[str] = None,
y_label: Optional[str] = None,
show_plot: Optional[bool] = True,
save_path: Optional[str] = None,
transform_fn: Optional[Callable[[dict], dict]] = None,
label_format_fn: Optional[Callable[[float], str]] = lambda x: str(x),
):
if not isinstance(counts, list):
counts = [counts]
if isinstance(colors, str):
colors = [colors]
if isinstance(legend, str):
legend = [legend]
if transform_fn:
counts = [transform_fn(counts_dict) for counts_dict in counts]
all_states = sorted(set(state for counts_dict in counts for state in counts_dict.keys()))
num_dicts = len(counts)
bar_width = 0.8 / num_dicts
x_positions = range(len(all_states))
if colors is None:
cmap = colormaps.get_cmap("tab10")
colors = [cmap(i / 10) for i in range(num_dicts)]
if len(colors) != len(counts):
raise ValueError("Number of colors must match number of datasets")
if isinstance(legend, list) and len(legend) != len(counts):
raise ValueError("Number of legend labels must match number of datasets")
for i, counts_dict in enumerate(counts):
counts_iter = [counts_dict.get(state, 0) for state in all_states]
default_label = f"Job {i}" if num_dicts > 1 else None
label = legend[i] if legend and i < len(legend) else default_label
bars = plt.bar(
[x + (i * bar_width) for x in x_positions],
counts_iter,
width=bar_width,
color=colors[i % num_dicts],
label=label,
align="center",
)
y_min, y_max = plt.gca().get_ylim()
y_range = y_max - y_min
offset_percentage = 0.02 # Adjust this value to get the desired offset
y_offset = y_range * offset_percentage
for hbar, count in zip(bars, counts_iter):
plt.text(
hbar.get_x() + hbar.get_width() / 2,
hbar.get_height() + +y_offset,
label_format_fn(count),
ha="center",
va="bottom",
color="black",
fontsize=8,
)
plt.xticks(x_positions, all_states, rotation=45)
y_ticks = plt.gca().get_yticks()
plt.yticks(y_ticks)
plt.grid(axis="y", linestyle="--", linewidth=0.7, alpha=0.7)
if y_label:
plt.ylabel(y_label)
if x_label:
plt.xlabel(x_label)
if title:
plt.title(title)
if legend or num_dicts > 1:
plt.legend()
if save_path:
plt.savefig(save_path)
if show_plot:
plt.show()
[docs]
def plot_distribution(
    counts: Union[List[Dict], Dict],
    legend: Optional[Union[List[str], str]] = None,
    colors: Optional[Union[List[str], str]] = None,
    title: Optional[str] = None,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    show_plot: Optional[bool] = True,
    save_path: Optional[str] = None,
):
    """
    Plots the measurement counts of quantum states as a normalized distribution.

    Every counts dictionary is converted to proportions of its own total before being
    drawn, and each bar is annotated with its value formatted to three decimal places.

    Args:
        counts (Union[List[Dict], Dict]): A dictionary, or a list of dictionaries, mapping
                                          quantum states to their respective counts.
        legend (Optional[Union[List[str], str]]): A string or list of strings labelling the
                                                  datasets. Defaults to None.
        colors (Optional[Union[List[str], str]]): A string or list of strings giving the color
                                                  of each dataset. Defaults to None.
        title (Optional[str]): Title of the plot. Defaults to None.
        x_label (Optional[str]): Label for the x-axis. Defaults to None.
        y_label (Optional[str]): Label for the y-axis. When None or empty,
                                 "Quasi-probabilities" is used.
        show_plot (Optional[bool]): Whether to show the plot. Defaults to True.
        save_path (Optional[str]): Path to save the plot to. Defaults to None.

    Returns:
        None: This function does not return a value; it displays a plot.

    Raises:
        ValueError: Raises an error if input arguments do not match the expected types or formats.

    Example:

    .. code-block:: python

        counts_dict1 = {'00': 50, '01': 30, '10': 10, '11': 10}
        counts_dict2 = {'00': 20, '01': 40, '10': 30, '11': 10}

        plot_distribution(
            counts=[counts_dict1, counts_dict2],
            legend=['First Execution', 'Second Execution'],
            colors=['crimson', 'midnightblue'],
            title="Quantum State Probability Distribution",
            x_label="Quantum States",
            y_label="Probabilities"
        )

    """
    axis_label = y_label or "Quasi-probabilities"
    three_decimals = "{:.3f}".format  # pylint: disable=consider-using-f-string

    _plot_data(
        counts,
        legend=legend,
        colors=colors,
        title=title,
        x_label=x_label,
        y_label=axis_label,
        show_plot=show_plot,
        save_path=save_path,
        transform_fn=_counts_to_decimal,
        label_format_fn=three_decimals,
    )
[docs]
def plot_histogram(
    counts: Union[List[Dict], Dict],
    legend: Optional[Union[List[str], str]] = None,
    colors: Optional[Union[List[str], str]] = None,
    title: Optional[str] = None,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    show_plot: Optional[bool] = True,
    save_path: Optional[str] = None,
):
    """
    Plots the raw measurement counts of quantum states as a histogram.

    The counts are drawn as given (no normalization), and each bar is annotated with
    the string form of its value.

    Args:
        counts (Union[List[Dict], Dict]): A dictionary, or a list of dictionaries, mapping
                                          quantum states to their respective counts.
        legend (Optional[Union[List[str], str]]): A string or list of strings labelling the
                                                  datasets. Defaults to None.
        colors (Optional[Union[List[str], str]]): A string or list of strings giving the color
                                                  of each dataset. Defaults to None.
        title (Optional[str]): Title of the plot. Defaults to None.
        x_label (Optional[str]): Label for the x-axis. Defaults to None.
        y_label (Optional[str]): Label for the y-axis. When None or empty, "Counts" is used.
        show_plot (Optional[bool]): Whether to show the plot. Defaults to True.
        save_path (Optional[str]): Path to save the plot to. Defaults to None.

    Returns:
        None: This function does not return a value; it displays a plot.

    Raises:
        ValueError: Raises an error if input arguments do not match the expected types or formats.

    Example:

    .. code-block:: python

        counts_dict1 = {'00': 50, '01': 30, '10': 10, '11': 10}
        counts_dict2 = {'00': 20, '01': 40, '10': 30, '11': 10}

        plot_histogram(
            counts=[counts_dict1, counts_dict2],
            legend=['First Execution', 'Second Execution'],
            colors=['crimson', 'midnightblue'],
            title="Quantum State Measurement Counts",
            x_label="Quantum States",
            y_label="Counts"
        )

    """
    axis_label = y_label or "Counts"

    _plot_data(
        counts,
        legend=legend,
        colors=colors,
        title=title,
        x_label=x_label,
        y_label=axis_label,
        show_plot=show_plot,
        save_path=save_path,
        label_format_fn=str,
    )