#!/usr/bin/env python
# coding: utf-8

# In[1]:


# pylint: disable=line-too-long
# pylint: disable=wrong-import-order
# pylint: disable=wrong-import-position
# pylint: disable=pointless-string-statement
# pylint: disable=pointless-statement
# pylint: disable=invalid-name
# pylint: disable=duplicate-code
# pylint: disable=too-many-lines
# pylint: disable=no-self-use
# pylint: disable=too-few-public-methods


# In[2]:


import tempfile
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from directory_tree import display_tree
from qcodes import Instrument, ManualParameter, Parameter, validators
from scipy.optimize import minimize_scalar

import quantify_core.data.handling as dh
from quantify_core.analysis import base_analysis as ba
from quantify_core.analysis import cosine_analysis as ca
from quantify_core.measurement import Gettable, MeasurementControl
from quantify_core.utilities.dataset_examples import mk_2d_dataset_v1
from quantify_core.utilities.examples_support import (
    default_datadir,
    mk_cosine_instrument,
)
from quantify_core.utilities.inspect_utils import display_source_code

# Point quantify's data handling at the directory returned by ``default_datadir()``.
dh.set_datadir(default_datadir())
# One MeasurementControl instance, reused by every experiment in this script.
meas_ctrl = MeasurementControl("meas_ctrl")


# In[3]:


# Mock microwave source exposing a single manually-set frequency parameter.
mw_source1 = Instrument("mw_source1")
# NB: for brevity only; this is not the proper way of adding parameters to
# QCoDeS instruments.
mw_source1.freq = ManualParameter(
    "freq",
    label="Frequency",
    unit="Hz",
    vals=validators.Numbers(),
    initial_value=1.0,
)

# Mock readout module whose signal is the source frequency scaled by 1e-8.
pulsar_QRM = Instrument("pulsar_QRM")
# NB: for brevity only; this is not the proper way of adding parameters to
# QCoDeS instruments.
pulsar_QRM.signal = Parameter(
    name="sig_a",
    label="Signal",
    unit="V",
    get_cmd=lambda: mw_source1.freq() * 1e-8,
)


# In[4]:


# Settable: the frequency of the microwave source is the swept quantity.
meas_ctrl.settables(mw_source1.freq)
# Setpoints: start at 5.0 GHz and step by 100 kHz towards 5.2 GHz.
meas_ctrl.setpoints(np.arange(5e9, 5.2e9, 100e3))
# Gettable: the signal parameter of the pulsar QRM is acquired.
meas_ctrl.gettables(pulsar_QRM.signal)
# Run the experiment and keep the returned dataset.
dset = meas_ctrl.run("Frequency sweep")


# In[5]:


# Manually-set time parameter read by the gettable defined in this cell.
t = ManualParameter(name="time", label="Time", unit="s")


class WaveGettable:
    """Example of a minimal gettable object.

    Instances carry the ``unit``, ``label`` and ``name`` attributes and
    implement ``get``. The ``prepare`` and ``finish`` hooks are optional.
    """

    def __init__(self) -> None:
        self.unit, self.label, self.name = "V", "Amplitude", "sine"

    def get(self):
        """Return ``sin(t / pi)`` for the current value of the module-level ``t``."""
        current_time = t()
        return np.sin(current_time / np.pi)

    def prepare(self) -> None:
        """Print a notice before acquisition; this optional hook may be left undefined."""
        print("Preparing the WaveGettable for acquisition.")

    def finish(self) -> None:
        """Print a notice after the experiment; this optional hook may be left undefined."""
        print("Finishing WaveGettable to wrap up the experiment.")


# Verify compliance with the Gettable format by wrapping an instance in
# ``Gettable``.
# NOTE(review): presumably this raises when a required attribute is missing;
# confirm against the quantify_core ``Gettable`` implementation.
wave_gettable = WaveGettable()
Gettable(wave_gettable)


# In[6]:


# Rebind ``t``, this time with an explicit validator attached.
t = ManualParameter(
    name="time",
    label="Time",
    unit="s",
    # ``Numbers`` accepts a single number, e.g. a float or an integer.
    vals=validators.Numbers(),
)


class DualWave1D:
    """Gettable that yields two values (a sine and a cosine) per call to ``get``.

    ``unit``, ``label`` and ``name`` are lists with one entry per returned value.
    """

    def __init__(self) -> None:
        self.unit = ["V"] * 2
        self.label = ["Sine Amplitude", "Cosine Amplitude"]
        self.name = ["sin", "cos"]

    def get(self):
        """Return ``[sin(t / pi), cos(t / pi)]`` as an array, reading the module-level ``t``."""
        sine = np.sin(t() / np.pi)
        cosine = np.cos(t() / np.pi)
        return np.array([sine, cosine])

    # N.B. the optional prepare and finish methods are omitted in this Gettable.


# Verify compliance with the Gettable format by wrapping an instance in
# ``Gettable``; this rebinds ``wave_gettable`` to the dual-valued example.
# NOTE(review): presumably this raises when a required attribute is missing;
# confirm against the quantify_core ``Gettable`` implementation.
wave_gettable = DualWave1D()
Gettable(wave_gettable)


# In[7]:


# Settable whose validator accepts an array of values.
time = ManualParameter(
    "time",
    label="Time",
    unit="s",
    vals=validators.Arrays(),
)
# Gettable computing the cosine of whatever ``time`` currently holds.
signal = Parameter(
    name="sig_a",
    label="Signal",
    unit="V",
    get_cmd=lambda: np.cos(time()),
)

# Mark both parameters as batched, with batch sizes of 5 and 10 respectively.
time.batched, time.batch_size = True, 5
signal.batched, signal.batch_size = True, 10

# Configure the measurement loop and run it over 23 points between 0 and 7.
meas_ctrl.settables(time)
meas_ctrl.gettables(signal)
meas_ctrl.setpoints(np.linspace(0, 7, num=23))
dset = meas_ctrl.run(name="my experiment")
dset_grid = dh.to_gridded_dataset(dset)

dset_grid["y0"].plot()


# In[8]:


# Build a throw-away data directory, populate it with a dummy dataset and
# print its layout. The temporary directory is removed when the ``with``
# block exits.
with tempfile.TemporaryDirectory() as tmpdir:
    # Remember the active data directory so it can be reinstated afterwards.
    old_dir = dh.get_datadir()
    dh.set_datadir(Path(tmpdir) / "quantify-data")
    try:
        # we generate a dummy dataset and a few empty dirs for pretty printing
        (Path(dh.get_datadir()) / "20210301").mkdir()
        (Path(dh.get_datadir()) / "20210428").mkdir()

        quantify_dataset = mk_2d_dataset_v1()
        ba.BasicAnalysis(dataset=quantify_dataset).run()
        # to make sure the full path is displayed
        print(display_tree(dh.get_datadir(), string_rep=True), end="")
    finally:
        # Restore the previous data directory even if one of the steps above
        # fails; otherwise the global setting would keep pointing at a
        # directory that is deleted when the ``with`` block exits.
        dh.set_datadir(old_dir)


# In[9]:


# plot the columns of the dataset
_, axs = plt.subplots(3, 1, sharex=True)
xr.plot.line(quantify_dataset.x0[:54], label="x0", ax=axs[0], marker=".")
xr.plot.line(quantify_dataset.x1[:54], label="x1", ax=axs[1], color="C1", marker=".")
xr.plot.line(quantify_dataset.y0[:54], label="y0", ax=axs[2], color="C2", marker=".")
tuple(ax.legend() for ax in axs)
# return the dataset
quantify_dataset


# In[10]:


# Convert the dataset to its gridded form and plot the ``y0`` variable.
gridded_dset = dh.to_gridded_dataset(quantify_dataset)
gridded_dset["y0"].plot()
# Bare expression kept from the notebook cell, where it displays the dataset.
gridded_dset


# In[11]:


# Show the source code of the helper that builds the cosine example instrument.
display_source_code(mk_cosine_instrument)


# In[12]:


# Build the example instrument; its ``t`` parameter is swept while ``sig``
# is acquired.
pars = mk_cosine_instrument()
meas_ctrl.settables(pars.t)
meas_ctrl.setpoints(np.linspace(0, 2, num=50))
meas_ctrl.gettables(pars.sig)
dataset = meas_ctrl.run(name="Cosine experiment")
# Bare expression kept from the notebook cell, where it displays the dataset.
dataset


# In[13]:


# Create a cosine analysis for the label "Cosine experiment".
# NOTE(review): presumably the label selects the most recent matching
# experiment in the data directory; confirm against CosineAnalysis.
a_obj = ca.CosineAnalysis(label="Cosine experiment")
a_obj.run()  # execute the analysis.
a_obj.display_figs_mpl()  # displays the figures created in previous step.


# In[14]:


# Look up the experiment container for this dataset's TUID and print its
# directory tree (``end=""`` avoids adding an extra trailing newline).
experiment_container_path = dh.locate_experiment_container(tuid=dataset.tuid)
print(display_tree(experiment_container_path, string_rep=True), end="")


# In[15]:


# for example, the fitted frequency and amplitude are stored
freq = a_obj.quantities_of_interest["frequency"]
amp = a_obj.quantities_of_interest["amplitude"]
print(f"frequency {freq}")
print(f"amplitude {amp}")


# In[16]:


# Show the source code of the BasicAnalysis class.
display_source_code(ba.BasicAnalysis)


# In[17]:


# Scalar settable: the validator accepts one number at a time.
time = ManualParameter(
    "time",
    label="Time",
    unit="s",
    vals=validators.Numbers(),
    initial_value=1,
)
# Gettable computing the cosine of the current ``time`` value.
signal = Parameter(
    name="sig_a",
    label="Signal",
    unit="V",
    get_cmd=lambda: np.cos(time()),
)

# Sweep ``time`` over 20 points between 0 and 7 and record ``signal``.
meas_ctrl.settables(time)
meas_ctrl.gettables(signal)
meas_ctrl.setpoints(np.linspace(0, 7, num=20))
dset = meas_ctrl.run(name="my experiment")
dset_grid = dh.to_gridded_dataset(dset)

dset_grid["y0"].plot(marker="o")
# Bare expression kept from the notebook cell, where it displays the dataset.
dset_grid


# In[18]:


# Two independent scalar settables.
time_a = ManualParameter(
    "time_a",
    label="Time A",
    unit="s",
    vals=validators.Numbers(),
    initial_value=1,
)
time_b = ManualParameter(
    "time_b",
    label="Time B",
    unit="s",
    vals=validators.Numbers(),
    initial_value=1,
)
# Gettable combining both settables: exp(time_a) + 0.5 * exp(time_b).
signal = Parameter(
    name="sig_a",
    label="Signal A",
    unit="V",
    get_cmd=lambda: np.exp(time_a()) + 0.5 * np.exp(time_b()),
)

meas_ctrl.settables([time_a, time_b])
meas_ctrl.gettables(signal)
# One axis per settable: 10 values from 0 to 5 and 12 values from 5 down to 0.
meas_ctrl.setpoints_grid(
    [
        np.linspace(0, 5, num=10),
        np.linspace(5, 0, num=12),
    ]
)
dset = meas_ctrl.run(name="my experiment")
dset_grid = dh.to_gridded_dataset(dset)

dset_grid["y0"].plot(cmap="viridis")
# Bare expression kept from the notebook cell, where it displays the dataset.
dset_grid


# In[19]:


# Scalar settable and a cosine gettable for the adaptive example.
time = ManualParameter(
    "time",
    label="Time",
    unit="s",
    vals=validators.Numbers(),
    initial_value=1,
)
signal = Parameter(
    name="sig_a",
    label="Signal",
    unit="V",
    get_cmd=lambda: np.cos(time()),
)
meas_ctrl.settables(time)
meas_ctrl.gettables(signal)
# No setpoints are given here: scipy's ``minimize_scalar`` is passed as the
# adaptive function for the run.
dset = meas_ctrl.run_adaptive(
    "1D minimizer",
    {"adaptive_function": minimize_scalar},
)

dset_ad = dh.to_gridded_dataset(dset)
# Reference curve: a grey dashed cosine spanning the sampled x-range.
x = np.linspace(np.min(dset_ad["x0"]), np.max(dset_ad["x0"]), num=101)
y = np.cos(x)
plt.plot(x, y, color="grey", linestyle="--")
_ = dset_ad["y0"].plot(marker="o")


# In[20]:


# Two scalar settables used by the gettables of this cell.
time_a = ManualParameter(
    "time_a",
    label="Time A",
    unit="s",
    vals=validators.Numbers(),
    initial_value=1,
)
time_b = ManualParameter(
    "time_b",
    label="Time B",
    unit="s",
    vals=validators.Numbers(),
    initial_value=1,
)

# Single-valued gettable: exp(time_a) + 0.5 * exp(time_b).
signal = Parameter(
    name="sig_a",
    label="Signal A",
    unit="V",
    get_cmd=lambda: np.exp(time_a()) + 0.5 * np.exp(time_b()),
)


class DualWave2D:
    """Two-valued gettable whose outputs depend on two different settables.

    ``unit``, ``label`` and ``name`` are lists with one entry per returned value.
    """

    def __init__(self) -> None:
        self.unit = ["V"] * 2
        self.label = ["Sine Amplitude", "Cosine Amplitude"]
        self.name = ["sin", "cos"]

    def get(self):
        """Return ``[sin(time_a * pi), cos(time_b * pi)]`` as an array.

        ``time_a`` and ``time_b`` are read from module scope.
        """
        sine = np.sin(time_a() * np.pi)
        cosine = np.cos(time_b() * np.pi)
        return np.array([sine, cosine])


dual_wave = DualWave2D()
meas_ctrl.settables([time_a, time_b])
# Two gettables are combined: the scalar ``signal`` and the two-valued
# ``dual_wave``. The plots below read the variables ``y0``, ``y1`` and ``y2``.
meas_ctrl.gettables([signal, dual_wave])
meas_ctrl.setpoints_grid(
    [
        np.linspace(0, 3, num=21),
        np.linspace(4, 0, num=20),
    ]
)
dset = meas_ctrl.run(name="my experiment")
dset_grid = dh.to_gridded_dataset(dset)

# One figure per data variable, each with its own colormap.
for yi, cmap in {"y0": "viridis", "y1": "inferno", "y2": "plasma"}.items():
    dset_grid[yi].plot(cmap=cmap)
    plt.show()
# Bare expression kept from the notebook cell, where it displays the dataset.
dset_grid


# In[21]:


# Array-valued settable: the validator accepts arrays.
time = ManualParameter(
    "time",
    label="Time",
    unit="s",
    vals=validators.Arrays(),
    initial_value=np.array([1, 2, 3]),
)
# Gettable computing the element-wise cosine of the current ``time`` array.
signal = Parameter(
    name="sig_a",
    label="Signal",
    unit="V",
    get_cmd=lambda: np.cos(time()),
)

# Flag both parameters as batched; no explicit batch size is set here.
time.batched = signal.batched = True

meas_ctrl.settables(time)
meas_ctrl.gettables(signal)
meas_ctrl.setpoints(np.linspace(0, 7, num=20))
dset = meas_ctrl.run(name="my experiment")
dset_grid = dh.to_gridded_dataset(dset)

dset_grid["y0"].plot(marker="o")
print(f"\nNOTE: The gettable returns an array:\n\n{signal.get()}")
# Bare expression kept from the notebook cell, where it displays the dataset.
dset_grid


# In[22]:


# ``time_a`` takes one number at a time; ``time_b`` takes arrays.
time_a = ManualParameter(
    "time_a",
    label="Time A",
    unit="s",
    vals=validators.Numbers(),
    initial_value=1,
)
time_b = ManualParameter(
    "time_b",
    label="Time B",
    unit="s",
    vals=validators.Arrays(),
    initial_value=np.array([1, 2, 3]),
)
# Gettable combining both settables: exp(time_a) + 0.5 * exp(time_b).
signal = Parameter(
    name="sig_a",
    label="Signal A",
    unit="V",
    get_cmd=lambda: np.exp(time_a()) + 0.5 * np.exp(time_b()),
)

# Only ``time_b`` and ``signal`` are flagged as batched.
time_b.batched, time_b.batch_size = True, 12
signal.batched = True

meas_ctrl.settables([time_a, time_b])
meas_ctrl.gettables(signal)
# `setpoints_grid` will take into account the `.batched` attribute
meas_ctrl.setpoints_grid(
    [
        np.linspace(0, 5, num=10),
        np.linspace(4, 0, num=time_b.batch_size),
    ]
)
dset = meas_ctrl.run(name="my experiment")
dset_grid = dh.to_gridded_dataset(dset)

dset_grid["y0"].plot(cmap="viridis")
# Bare expression kept from the notebook cell, where it displays the dataset.
dset_grid


# In[23]:


# Array-valued settable read by the batched dual gettable of this cell.
time = ManualParameter(
    "time",
    label="Time",
    unit="s",
    vals=validators.Arrays(),
    initial_value=np.array([1, 2, 3]),
)


class DualWaveBatched:
    """Two-valued gettable that declares itself as batched.

    ``unit``, ``label`` and ``name`` hold one entry per returned value.
    ``batched`` is ``True`` and ``batch_size`` is 100.
    """

    def __init__(self) -> None:
        self.unit = ["V"] * 2
        self.label = ["Amplitude W1", "Amplitude W2"]
        self.name = ["sine", "cosine"]
        self.batched, self.batch_size = True, 100

    def get(self):
        """Return ``[sin(time * pi), cos(time * pi)]`` as an array.

        ``time`` is read from module scope.
        """
        sine = np.sin(time() * np.pi)
        cosine = np.cos(time() * np.pi)
        return np.array([sine, cosine])


# The settable is flagged as batched; the gettable sets its own flag in
# ``__init__``.
time.batched = True
dual_wave = DualWaveBatched()

meas_ctrl.settables(time)
meas_ctrl.gettables(dual_wave)
meas_ctrl.setpoints(np.linspace(0, 7, num=100))
dset = meas_ctrl.run(name="my experiment")
dset_grid = dh.to_gridded_dataset(dset)

# Draw both returned variables on one shared axis and label them.
_, ax = plt.subplots()
dset_grid["y0"].plot(ax=ax, marker="o", label="y0")
dset_grid["y1"].plot(ax=ax, marker="s", label="y1")
_ = ax.legend()

