python - 从unittest.TestCase切换到tf.test.TestCase后的幻像测试

标签 python unit-testing tensorflow python-unittest

以下代码:

class BoxListOpsTest(unittest.TestCase):                                                                                                                                                                                                                              
    """Tests for common bounding box operations."""                                                                                                                                                                                                                   

    def test_area(self):                                                                                                                                                                                                                                              
        corners = tf.constant([[0.0, 0.0, 10.0, 20.0], [1.0, 2.0, 3.0, 4.0]])                                                                                                                                                                                         
        exp_output = [200.0, 4.0]                                                                                                                                                                                                                                     
        boxes = box_list.BoxList(corners)                                                                                                                                                                                                                             
        areas = box_list_ops.area(boxes)                                                                                                                                                                                                                              

        with tf.Session() as sess:                                                                                                                                                                                                                                    
            areas_output = sess.run(areas)                                                                                                                                                                                                                            
            np.testing.assert_allclose(areas_output, exp_output)                                                                                                                                                                                                      


if __name__ == '__main__':                                                                                                                                                                                                                                            
    unittest.main()

被解释为具有单个测试的测试用例:

.
----------------------------------------------------------------------
Ran 1 test in 0.471s

OK

但是,切换到tf.test.TestCase:

class BoxListOpsTest(tf.test.TestCase):                                                                                                                                                                                                                               
    """Tests for common bounding box operations."""                                                                                                                                                                                                                   

    def test_area(self):                                                                                                                                                                                                                                              
        corners = tf.constant([[0.0, 0.0, 10.0, 20.0], [1.0, 2.0, 3.0, 4.0]])                                                                                                                                                                                         
        exp_output = [200.0, 4.0]                                                                                                                                                                                                                                     
        boxes = box_list.BoxList(corners)                                                                                                                                                                                                                             
        areas = box_list_ops.area(boxes)                                                                                                                                                                                                                              
        # with self.session() as sess:                                                                                                                                                                                                                                
        with tf.Session() as sess:                                                                                                                                                                                                                                    
            areas_output = sess.run(areas)                                                                                                                                                                                                                            
            np.testing.assert_allclose(areas_output, exp_output)                                                                                                                                                                                                      


if __name__ == '__main__':                                                                                                                                                                                                                                            
    tf.test.main()

引入了一些第二个测试,已被跳过:

.s
----------------------------------------------------------------------
Ran 2 tests in 0.524s

OK (skipped=1)

第二个测试的起源是什么?我应该担心它吗?

我使用的是 TensorFlow 1.13。

最佳答案

这是 tf.test.TestCase.test_session方法。由于命名不吉利,unittest 认为 test_session 方法是一个测试并将其添加到测试套件中。为了防止将 test_session 作为测试运行,Tensorflow 必须在内部跳过它,因此会导致“跳过”测试:

def test_session(self,
                 graph=None,
                 config=None,
                 use_gpu=False,
                 force_gpu=False):
    if self.id().endswith(".test_session"):
        self.skipTest("Not a test.")

通过使用 --verbose 标志运行测试来验证跳过的测试是 test_session。您应该看到与此类似的输出:

...
test_session (BoxListOpsTest)
Use cached_session instead. (deprecated) ... skipped 'Not a test.'

尽管 test_session 自 1.11 起已弃用,并应替换为 cached_session ( related commit ),但截至目前,尚未计划在 2.0 中删除它。为了摆脱它,您可以对收集的测试应用自定义过滤器。

单元测试

您可以定义自定义 load_tests功能:

test_cases = (BoxListOpsTest, )

def load_tests(loader, tests, pattern):
    suite = unittest.TestSuite()
    for test_class in test_cases:
        tests = loader.loadTestsFromTestCase(test_class)
        filtered_tests = [t for t in tests if not t.id().endswith('.test_session')]
        suite.addTests(filtered_tests)
    return suite

pytest

添加自定义pytest_collection_modifyitems Hook 您的 conftest.py:

def pytest_collection_modifyitems(session, config, items):
    items[:] = [item for item in items if item.name != 'test_session']

关于python - 从unittest.TestCase切换到tf.test.TestCase后的幻像测试,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55417214/

相关文章:

c# - 带有响应 header 的单元测试 webapi Controller

tensorflow - 如何使用Tflearn构建词嵌入模型?

python - 使用word2vec预训练向量,如何生成句子的id作为tensorflow中tf.nn.embedding_lookup函数的输入?

python - Anaconda Python Conda pipbuild 因 WindowsError 找不到文件而失败

python - scipy odeint 结果取决于输入时间数组

unit-testing - Corda 流单元测试中各种 verifySignatures 函数之间的区别

android - 对有延迟的 Rxjava 可观察对象进行单元测试

python - 打印表中的所有行

c# - 图形布局和重新排列

python-3.x - 为什么导入 "from tensorflow.train import Feature"不起作用