调试 TensorFlow 模型

正如我们在本系列文章中所学到的,TensorFlow程序用于构建和训练可用于各种任务预测的模型。在训练模型时,您可以构建计算图,然后运行计算图以进行训练,并评估计算图以进行预测。重复执行这些任务,直到您对模型的质量感到满意为止,然后将计算图与学习的参数一起保存。在生产中,计算图是从文件构建或恢复的,且使用参数填充。

构建深度学习模型是一项复杂的技术,TensorFlow API及其生态系统同样复杂。当我们在TensorFlow中构建和训练模型时,有时我们会得到不同类型的错误,或者模型不能按预期工作。例如,您经常看到自己陷入以下一种或多种情况:

  • 损失或度量输出了NaN(非数字)
  • 即使经过多次迭代,损失或其他指标也没有改善

在这种情况下,我们需要调试使用TensorFlow API编写的代码。

要修复代码以使其正常工作,可以使用调试器或平台提供的其他方法和工具,例如Python中的Python调试器(pdb)和Linux OS中的GNU调试器(gdb)。当出现问题时,TensorFlow API还提供一些额外的支持来修复代码。

在本文中,我们将学习TensorFlow中可用的其他工具和技术,以帮助调试:

  • 使用tf.Session.run()获取张量值
  • 使用tf.Print()打印张量值
  • 用tf.Assert()断言条件是否成立
  • 使用TensorFlow调试器(tfdbg)进行调试

使用tf.Session.run()获取张量值

您可以使用tf.Session.run()获取要打印的张量值。这些值作为NumPy数组返回,可以使用Python语句打印或记录。这是最易于理解和最简便易行的方法,最大的缺点是计算图执行所有依赖路径,即从获取的张量开始,如果这些路径包括训练操作,则它只会前进一步或一个周期。

因此,大多数情况下你不会调用tf.Session.run()来获取计算图中间的张量,但是你会执行整个计算图并获取所有张量,有些张量需要你调试,同时有些张量则不需要。

函数tf.Session.partial_run()也适用于您可能想要执行计算图的一部分的情况,但它是一个高度实验性的API,尚未准备好用于生产。

使用tf.Print()打印张量值

以调试为目的打印值的另一种做法是使用tf.Print()。当执行包含tf.Print()节点的路径时,可以在tf.Print()中包装张量以在标准错误控制台中打印其值。tf.Print()函数具有以下签名:

该函数的参数如下:

  • input_ 是一个从函数返回的张量,没有任何操作。
  • data 是要打印的张量列表。
  • message 是一个字符串,它作为打印输出的前缀打印出来。
  • first_n 表示打印输出的步骤数;如果此值为负,则只要执行路径,就始终打印该值。
  • summarize 表示从张量中打印的元素的数量;默认情况下,仅打印三个元素。

您可以按照Jupyter笔记本ch-18_TensorFlow_Debugging中的代码进行操作。

让我们修改之前创建的MNIST MLP模型来添加print语句:

当我们运行代码时,我们在Jupyter的控制台中获得以下内容:

使用tf.Print()的唯一缺点是该函数提供有限的格式化功能。

用tf.Assert()断言条件是否成立

调试TensorFlow模型的另一种方法是插入条件断言。tf.Assert()函数接受一个条件,如果条件为假,则它打印给定张量的列表并抛出tf.errors.InvalidArgumentError。

  1. tf.Assert()函数具有以下签名:
  1. 断言操作不像tf.Print()函数那样落入计算图的路径中。为了确保执行tf.Assert()操作,我们需要将它添加到依赖项中。例如,让我们定义一个断言来检查所有输入是否为正:
  1. 在定义模型时将assert_op添加到依赖项,如下所示:
  1. 为了测试这段代码,我们在5个周期之后引入了一个杂质,如下:
  1. 代码正常运行了五个周期后抛出错误:

除了可以采用任何有效条件表达式的tf.Assert()函数之外,TensorFlow还提供以下断言操作,以检查特定条件并使用简单语法:

  • assert_equal
  • assert_greater
  • assert_greater_equal
  • assert_integer
  • assert_less
  • assert_less_equal
  • assert_negative
  • assert_none_equal
  • assert_non_negative
  • assert_non_positive
  • assert_positive
  • assert_proper_iterable
  • assert_rank
  • assert_rank_at_least
  • assert_rank_in
  • assert_same_float_dtype
  • assert_scalar
  • assert_type
  • assert_variables_initialized

作为示例,前面提到的示例断言操作也可以写成如下形式:

使用TensorFlow调试器(tfdbg)进行调试

TensorFlow调试器(tfdbg)与其他流行的调试器(如pdb和gdb)在高级别上的工作方式相同。要使用调试器,该过程通常如下:

  1. 在代码中的断点处设置要中断的位置并检查变量;
  2. 在调试模式下运行代码;
  3. 当代码在断点处中断时,检查该断点然后继续下一步;

一些调试器还允许您在代码执行时以交互方式观察变量,而不仅仅是在断点处:

  1. 为了使用tfdbg,首先导入所需的模块并将会话包装在调试器包装器中:
  1. 接下来,将过滤器附加到会话对象。附加过滤器与在其他调试器中设置断点相同。例如,以下代码附加了一个tfdbg.has_inf_or_nan过滤器,如果任何中间张量具有nan或inf值,则该过滤器会中断:
  1. 现在,当代码执行tfs.run()时,调试器将在控制台中启动调试器接口,您可以在其中运行各种调试器命令来监视张量值。
  2. 我们提供了在ch-18_mnist_tfdbg.py文件中试用tfdbg的代码。当我们用python3执行代码文件时,我们看到了tfdbg控制台:
  1. 在tfdbg>提示符处输入命令run -f has_inf_or_nan。代码在第一个周期后中断,因为我们使用np.inf值填充数据:
  1. 现在,您可以使用tfdbg控制台或可单击的界面来检查各种张量的值。例如,我们查看其中一个渐变的值:

您可以在以下链接中找到有关使用tfdbg控制台和检查变量的更多信息:

https://www.tensorflow.org/guide/debugger

https://tensorflow.google.cn/guide/debugger

总结

在本文中,我们学习了如何在TensorFlow中调试用于构建和训练模型的代码。我们了解到我们可以使用tf.Session.run()将张量作为NumPy数组获取。我们还可以通过在计算图中添加tf.Print()操作来打印张量的值。我们还学习了如何在使用tf.Assert()和其他tf.assert_ *操作执行期间某些条件无法保持时引发错误。我们通过对TensorFlow调试器(tfdbg)的介绍作为文章结尾,该调试器用于设置断点和观察张量值,就像我们在Python调试器(pdb)或GNU调试器(gdb)中调试代码一样。

本文将我们的旅程带入一个新的里程碑。我们不希望旅程在此结束,但我们相信旅程刚刚开始,您将进一步扩展和应用本系列文章中获得的知识和技能。

我们非常期待听到您的经验,反馈和建议。

发表评论

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据