diff --git a/src/qcodes/instrument/channel.py b/src/qcodes/instrument/channel.py index 95baf3d91ca5..f1df530e25a3 100644 --- a/src/qcodes/instrument/channel.py +++ b/src/qcodes/instrument/channel.py @@ -26,6 +26,7 @@ from typing_extensions import Unpack + from .instrument import Instrument from .instrument_base import InstrumentBaseKWArgs @@ -86,8 +87,14 @@ def parent(self) -> _TIB_co: return self._parent @property - def root_instrument(self) -> InstrumentBase: - return self._parent.root_instrument + def root_instrument(self) -> Instrument: + # the root instrument is the top level parent of this module, we need to + # go up the parent hierarchy until we find an object that returns itself as the parent, this should be the root instrument. We also + # this is required to be an Instrument. + # Once 3.13 is the minimum supported version + # consider replacing with a generic parameter with a default + # value of Instrument. + return cast("Instrument", self._parent.root_instrument) @property def name_parts(self) -> list[str]: diff --git a/src/qcodes/instrument/instrument_base.py b/src/qcodes/instrument/instrument_base.py index 52894a2e7284..7e673798546b 100644 --- a/src/qcodes/instrument/instrument_base.py +++ b/src/qcodes/instrument/instrument_base.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Mapping, Sequence - from typing import NotRequired + from typing import NotRequired, Self from qcodes.instrument.channel import ChannelTuple, InstrumentModule from qcodes.logger.instrument_logger import InstrumentLoggerAdapter @@ -579,7 +579,7 @@ def ancestors(self) -> tuple[InstrumentBase, ...]: return (self,) @property - def root_instrument(self) -> InstrumentBase: + def root_instrument(self) -> Self: """ The topmost parent of this module. diff --git a/src/qcodes/parameters/parameter_base.py b/src/qcodes/parameters/parameter_base.py index 1b76f56fe38a..9439ba54213c 100644 --- a/src/qcodes/parameters/parameter_base.py +++ b/src/qcodes/parameters/parameter_base.py @@ -8,7 +8,7 @@ from contextlib import contextmanager from datetime import datetime from functools import cached_property, wraps -from typing import TYPE_CHECKING, Any, ClassVar, Generic, overload +from typing import TYPE_CHECKING, Any, ClassVar, Generic, cast, overload import numpy as np from typing_extensions import TypeVar @@ -47,7 +47,7 @@ from types import TracebackType from qcodes.dataset.data_set_protocol import ValuesType - from qcodes.instrument import InstrumentBase + from qcodes.instrument import Instrument, InstrumentBase from qcodes.logger.instrument_logger import InstrumentLoggerAdapter ParameterDataTypeVar = TypeVar("ParameterDataTypeVar", default=Any) # InstrumentTypeVar_co is a covariant type variable representing the instrument @@ -1209,7 +1209,7 @@ def instrument(self) -> InstrumentTypeVar_co: return self._instrument @property - def root_instrument(self) -> InstrumentBase | None: + def root_instrument(self) -> Instrument | None: """ Return the fundamental instrument that this parameter belongs too. E.g if the parameter is bound to a channel this will return the @@ -1217,7 +1217,7 @@ def root_instrument(self) -> InstrumentBase | None: :meth:`instrument` to get the channel. """ if self._instrument is not None: - return self._instrument.root_instrument + return cast("Instrument | None", self._instrument.root_instrument) else: return None diff --git a/src/qcodes/parameters/specialized_parameters.py b/src/qcodes/parameters/specialized_parameters.py index d55868467d04..255e57f92415 100644 --- a/src/qcodes/parameters/specialized_parameters.py +++ b/src/qcodes/parameters/specialized_parameters.py @@ -8,8 +8,9 @@ import warnings from time import perf_counter -from typing import TYPE_CHECKING, Any, ClassVar, Literal +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal +from qcodes.parameters.parameter_base import InstrumentTypeVar_co from qcodes.utils import QCoDeSDeprecationWarning from qcodes.validators import Strings, Validator @@ -18,7 +19,7 @@ if TYPE_CHECKING: from collections.abc import Callable - from qcodes.instrument import InstrumentBase + from qcodes.instrument import Instrument class ElapsedTimeParameter(Parameter): @@ -96,7 +97,9 @@ def t0(self) -> float: return self._t0 -class InstrumentRefParameter(Parameter): +class InstrumentRefParameter( + Parameter[str, InstrumentTypeVar_co], Generic[InstrumentTypeVar_co] +): """ An instrument reference parameter. @@ -134,12 +137,12 @@ def __init__( self, name: str, *args: Any, - instrument: InstrumentBase | None = None, + instrument: InstrumentTypeVar_co = None, label: str | None = None, unit: str | None = None, get_cmd: str | Callable[..., Any] | Literal[False] | None = None, set_cmd: str | Callable[..., Any] | Literal[False] | None = None, - initial_value: float | str | None = None, + initial_value: str | None = None, max_val_age: float | None = None, vals: Validator[Any] | None = None, docstring: str | None = None, @@ -229,16 +232,15 @@ def __init__( **kwargs, ) - # TODO(nulinspiratie) check class works now it's subclassed from Parameter - def get_instr(self) -> InstrumentBase: + def get_instr(self) -> Instrument | None: """ Returns the instance of the instrument with the name equal to the value of this parameter. """ + # lazy import to avoid circular import + # since Instrument module depends on parameer module + from qcodes.instrument import Instrument # noqa: PLC0415 + ref_instrument_name = self.get() - # note that _instrument refers to the instrument this parameter belongs - # to, while the ref_instrument_name is the instrument that is the value - # of this parameter. - if self._instrument is None: - raise RuntimeError("InstrumentRefParameter is not bound to an instrument.") - return self._instrument.find_instrument(ref_instrument_name) + + return Instrument.find_instrument(ref_instrument_name) diff --git a/tests/drivers/keysight_b1500/b1500_driver_tests/test_b1517a_smu.py b/tests/drivers/keysight_b1500/b1500_driver_tests/test_b1517a_smu.py index b66a514cceb4..4b728241117a 100644 --- a/tests/drivers/keysight_b1500/b1500_driver_tests/test_b1517a_smu.py +++ b/tests/drivers/keysight_b1500/b1500_driver_tests/test_b1517a_smu.py @@ -433,7 +433,7 @@ def test_iv_sweep_delay(smu: KeysightB1517A) -> None: smu.iv_sweep.step_delay(0.01) smu.iv_sweep.trigger_delay(0.1) smu.iv_sweep.measure_delay(15.4) - + assert isinstance(mainframe, MagicMock) mainframe.write.assert_has_calls( [ call("WT 43.12,0.0,0.0,0.0,0.0"), diff --git a/tests/drivers/keysight_b1500/b1500_driver_tests/test_b1520a_cmu.py b/tests/drivers/keysight_b1500/b1500_driver_tests/test_b1520a_cmu.py index 4a5532f5d2c7..bdbecb2ded33 100644 --- a/tests/drivers/keysight_b1500/b1500_driver_tests/test_b1520a_cmu.py +++ b/tests/drivers/keysight_b1500/b1500_driver_tests/test_b1520a_cmu.py @@ -179,7 +179,7 @@ def test_get_post_sweep_voltage_cond(cmu: KeysightB1520A) -> None: def test_cv_sweep_delay(cmu: KeysightB1520A) -> None: mainframe = cmu.root_instrument - + assert isinstance(mainframe, MagicMock) mainframe.ask.return_value = "WTDCV0.0,0.0,0.0,0.0,0.0" cmu.cv_sweep.hold_time(1.0) @@ -192,6 +192,7 @@ def test_cv_sweep_delay(cmu: KeysightB1520A) -> None: def test_cmu_sweep_steps(cmu: KeysightB1520A) -> None: mainframe = cmu.root_instrument + assert isinstance(mainframe, MagicMock) mainframe.ask.return_value = "WDCV3,1,0.0,0.0,1" cmu.cv_sweep.sweep_start(2.0) cmu.cv_sweep.sweep_end(4.0) @@ -208,6 +209,7 @@ def test_cv_sweep_voltages(cmu: KeysightB1520A) -> None: end = 1.0 steps = 5 return_string = f"WDCV3,1,{start},{end},{steps}" + assert isinstance(mainframe, MagicMock) mainframe.ask.return_value = return_string cmu.cv_sweep.sweep_start(start) @@ -226,6 +228,7 @@ def test_sweep_modes(cmu: KeysightB1520A) -> None: steps = 5 mode = constants.SweepMode.LINEAR_TWO_WAY return_string = f"WDCV3,{mode},{start},{end},{steps}" + assert isinstance(mainframe, MagicMock) mainframe.ask.return_value = return_string cmu.cv_sweep.sweep_start(start) @@ -249,6 +252,7 @@ def test_run_sweep(cmu: KeysightB1520A) -> None: f"0.0000;WDCV3," f"1,{start},{end},{steps};ACT0,1" ) + assert isinstance(mainframe, MagicMock) mainframe.ask.return_value = return_string cmu.setup_fnc_already_run = True cmu.impedance_model(constants.IMP.MeasurementMode.G_X) diff --git a/tests/drivers/keysight_b1500/b1500_driver_tests/test_sampling_measurement.py b/tests/drivers/keysight_b1500/b1500_driver_tests/test_sampling_measurement.py index 973d29f139e7..4758b5d0388f 100644 --- a/tests/drivers/keysight_b1500/b1500_driver_tests/test_sampling_measurement.py +++ b/tests/drivers/keysight_b1500/b1500_driver_tests/test_sampling_measurement.py @@ -83,6 +83,7 @@ def test_sampling_measurement( actual_data = smu.sampling_measurement_trace.get() np.testing.assert_allclose(actual_data, data_to_return, atol=1e-3) + assert isinstance(smu.root_instrument.ask, Mock) smu.root_instrument.ask.assert_called_with("XE") diff --git a/tests/parameter/test_instrument_ref_parameter.py b/tests/parameter/test_instrument_ref_parameter.py index 3379d4981419..f3e2250bc1b6 100644 --- a/tests/parameter/test_instrument_ref_parameter.py +++ b/tests/parameter/test_instrument_ref_parameter.py @@ -9,28 +9,32 @@ from collections.abc import Generator +class DummyHolder(DummyInstrument): + def __init__(self, name: str) -> None: + super().__init__(name) + self.test = self.add_parameter( + "test", + parameter_class=InstrumentRefParameter, + initial_value=None, + ) + + @pytest.fixture(name="instrument_a") -def _make_instrument_a() -> "Generator[DummyInstrument, None, None]": - a = DummyInstrument("dummy_holder") - try: - yield a - finally: - a.close() +def _make_instrument_a() -> "Generator[DummyHolder, None, None]": + + a = DummyHolder("dummy_holder") + yield a + a.close() @pytest.fixture(name="instrument_d") def _make_instrument_d() -> "Generator[DummyInstrument, None, None]": d = DummyInstrument("dummy") - try: - yield d - finally: - d.close() + yield d + d.close() -def test_get_instr( - instrument_a: DummyInstrument, instrument_d: DummyInstrument -) -> None: - instrument_a.add_parameter("test", parameter_class=InstrumentRefParameter) +def test_get_instr(instrument_a: DummyHolder, instrument_d: DummyInstrument) -> None: instrument_a.test.set(instrument_d.name)