你当前正在访问 Microsoft Azure Global Edition 技术文档网站。 如果需要访问由世纪互联运营的 Microsoft Azure 中国技术文档网站,请访问 https://docs.azure.cn。
Clone 函数
将模型的一部分网络复制到 BrainScript 函数中。
BS.Network.CloneFunction (inputNodes, outputNodes,
parameters="learnable" /*|"constant"|"shared"*/)
parameters
inputNodes
是 1 个或多个输入的数组。 它列出源网络的节点,这些节点是要提取的函数的输入。 调用生成的 BrainScript 函数时,克隆函数的参数将替换为这些节点。outputNodes
是单个输出节点或多个输出节点的记录。 这些表示原始网络中哪个节点是克隆函数的输出。 生成的 BrainScript 函数将返回这些函数。parameters
确定克隆节内可学习参数的处理方式。 可以理解以下值:"learnable"
:克隆函数中的每个可学习参数都将获取其自己的副本,此后会像任何其他参数一样通过训练进行更新。 这是默认值。"constant"
:可学习参数被复制,但随后被冻结。 克隆的函数在后续训练期间不会获取任何更新,例如,如果要在较小的自定义集的后续训练中使用在大型标准训练集上训练的功能提取器。"shared"
:原始可学习参数将继续以共享方式使用。 在后续训练期间,它们将从其原始用途和克隆用途更新。 如果多次调用由CloneFunction()
其返回的 BrainScript 函数,则所有克隆都将共享参数。
返回值
使用任意数量的输入参数inputNodes
的 BrainScript 函数,如果为标量,则返回标量;如果outputNodes
outputNodes
为记录,则返回具有匹配名称的记录。
说明
CloneFunction()
是用于 编辑 和创建模型的函数。 它将模型的一部分网络复制到 BrainScript 函数中,以便可以重复使用该网络的这一部分。 结果是一个 BrainScript 函数,可以像在常规 BrainScript 函数中定义网络此部分一样。
原始网络可以是单独的网络。 这样,就可以导入) 已针对不同数据训练的外部网络的一部分 (部分。 CloneFunction()
允许冻结克隆的模型参数。 这样,外部网络就可以用作固定特征提取器,或在适应设置中充当正则器。
原始网络也可以是当前定义的一部分,克隆可与原始网络共享其参数。 这允许通过网络对不同数据进行操作的多个相同路径,例如,用于对称比较两个输入的相似性的设置,其中功能提取层共享 (并共同学习两个输入的) 。 但是,如果原始网络部分包含循环循环,则当前不起作用。
要复制的节由其输入和输出节点定义。 Imagine要克隆的子节周围绘制线条的网络图。 然后,通过传递跨线的所有连接以输入标记区域作为参数,然后指定该行所表示的 inputNodes
此部分,并指定 outputNodes
所有连接。 CloneFunction()
将本部分提取到一个 BrainScript 函数中,其参数数等于数目 inputNodes
,输出为单个节点或节点字典。
也可以将可学习参数表示为 inputNodes
。 在这种情况下,可以将新参数替换为创建的 BrainScript 函数 CloneFunction()
的参数。 如果要复制函数,但从头开始学习参数,请执行此操作。 在这种情况下,还可以更改维度。
示例用例:
- 适应 (KL) :起始模型的冻结只读副本用作 KL-正则器
- 适应 (FDLR) :在网络固定时训练注入的输入转换
- 图像:ImageNet 网络的下层充当另一个图像任务的不可变功能提取器
- DSSM:将同一网络子部分应用于两个输入
节点名称的问题.
[
]
若要引用包含或[
或或的网络]
中的.
节点,请将这些字符替换为_
。
例如,如果 network
包含调用 result.z
的节点, network.result.z
将失败;而是说 network.result_z
。
实现说明
CloneFunction()
实际上不会在后台创建 BrainScript 代码。 而是创建一个类似于 BrainScript 函数的 C++ 对象。 CloneFunction()
本身也不克隆原始网络。 它只保留引用。 调用返回的 CloneFunction()
函数时,会发生实际的克隆。
示例
基本用法:
# create a BS function by copying a piece of an existing network loaded from disk
network = BS.Network.Load ("some.dnn")
net = BS.Network.CloneFunction (network.features, network.logP)
# apply the copy to a new input
out = net (myFeatures)
# This will create a copy of the subsection from network.features to network.logP
# where all links to network.features get replaced by links to myFeatures.
包含多个输入和输出节点的示例:
# This specific example passes two input nodes --> the resulting BS function will have 2 inputs;
# and it passes a record of output nodes --> the BS function will return a record with the same member names
network = BS.Network.Load ("some.dnn")
net = BS.Network.CloneFunction ((network.features:network.labels), [ ce = network.ce ; errs = network.errs ])
# 'net' is now a BrainScript function with this signature:
# CloneFunction (input1, input2) --> [ ce = ... ; errs = ... ]
# now create a network from the BS function
myFeatures = Input (13)
myLabels = Input (42)
out = net (myFeatures, myLabels) # e.g. myFeatures substitutes the original 'features' node
criterionNodes = (out.ce) # and the return value is a record with members 'ce' and 'errs'
evaluationNodes = (out.errs)
具体示例:调整网络,同时将原始网络用作正则器 (KLD) :
# load network
network = BS.Network.Load ("some.dnn")
# create a trainable clone and a read-only reference clone
adaptNet = BS.Network.CloneFunction (network.features, [ z = network.z ], parameters="learnable")
refNet = BS.Network.CloneFunction (network.features, [ z = network.z ], parameters="constant")
# create the main network
features = Input (42)
labels = Input (9000)
z = adaptNet (features).z
zRef = refNet (features).z
# training criterion
# In KL adaptation, labels are a linear interpolation of the one-hot targets
# and the posteriors produced by the reference network.
refWeight = 0.9
kldLabels = labels * (1-refWeight) + Softmax (zRef) * refWeight # interpolate with ref output
ce = CrossEntropyWithSoftmax (kldLabels, z) # the CE criterion is taken against these interpolated soft labels
errs = ErrorPrediction (labels, z) # errors are of course still counted against the actual labels
criterionNodes = (ce)
evaluationNodes = (errs)