我有一个简单的单元测试,我在其中检查我是否可以使用略有不同的参数来实例化我的 Tensorflow 类。这似乎是 @pytest.mark.parametrize
的一个很好的用例。
但是,我发现如果我的单元测试是 tf.test.TestCase
的方法,则 parametrize
会被忽略。
例如,当我对以下代码运行 pytest
时:
class TestBasicRewardNet(tf.test.TestCase):
@pytest.mark.parametrize("env", ['FrozenLake-v0', 'CartPole-v1',
'CarRacing-v0', 'LunarLander-v2'])
def test_init_no_crash(self, env):
for i in range(3):
x = BasicRewardNet(env)
我收到错误 TypeError: test_init_no_crash() missing 1 required positional argument: 'env'
。
为了解决这个问题,我尝试去掉类包装器,但这让我错过了一些自动 Tensorflow 测试初始化。特别是,现在每个 BasicRewardNet
都构建在同一个 TensorFlow 图中,所以我需要做一些事情,比如添加一个变量范围来避免
冲突。添加这个变量范围似乎很麻烦。
@pytest.mark.parametrize("env", ['FrozenLake-v0', 'CartPole-v1',
'CarRacing-v0', 'LunarLander-v2'])
def test_init_no_crash(env):
for i in range(3):
with tf.variable_scope(env+str(i)):
x = BasicRewardNet(env)
我想知道是否有人知道我可以干净利落地兼顾两全其美的方法?我希望能够使用 parametrize
并同时获得 tf.test.TestCase
的自动 Tensorflow 初始化。
最佳答案
正如 hoefling 的评论中提到的那样,可以使用 tf.test.TestCase.subTest
解决。
class TestBasicRewardNet(tf.test.TestCase):
@staticmethod
def my_sub_test(env):
for i in range(3):
with tf.variable_scope(env+str(i)):
x = BasicRewardNet(env)
def test_init_no_crash(env):
for env in ['FrozenLake-v0', 'CartPole-v1','CarRacing-v0', 'LunarLander-v2']:
with self.subTest(env):
self.my_sub_test(env)
要在使用 pytest
运行时能够使用 subTest
功能,您应该添加 pytest-subtests在要求中,否则你将没有它们!
关于python - 同时使用 pytest 和 tf.test.TestCase 的问题,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53824687/