博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
tensorflow中的lstm的state
阅读量:6404 次
发布时间:2019-06-23

本文共 1742 字,大约阅读时间需要 5 分钟。

   

考虑 state_is_tuple

   

Output, new_state = cell(input, state)

   

state其实是两个 一个 c state,一个m(对应下图的hidden 或者h) 其中m(hidden)其实也就是输出

   

   

   

   

new_state = (LSTMStateTuple(c, m) if self._state_is_tuple

else array_ops.concat(1, [c, m]))

return m, new_state

   

   

def basic_rnn_seq2seq(

encoder_inputs, decoder_inputs, cell, dtype=dtypes.float32, scope=None):

with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"):

_, enc_state = rnn.rnn(cell, encoder_inputs, dtype=dtype)

return rnn_decoder(decoder_inputs, enc_state, cell)

   

   

def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None,

scope=None):

with variable_scope.variable_scope(scope or "rnn_decoder"):

state = initial_state

outputs = []

prev = None

for i, inp in enumerate(decoder_inputs):

if loop_function is not None and prev is not None:

with variable_scope.variable_scope("loop_function", reuse=True):

inp = loop_function(prev, i)

if i > 0:

variable_scope.get_variable_scope().reuse_variables()

output, state = cell(inp, state)

outputs.append(output)

if loop_function is not None:

prev = output

return outputs, state

   

   

这里decoder用了encoder的最后一个state 作为输入

   

然后输出结果是decoder过程最后的state 加上所有ouput的集合(也就是hidden的集合)

注意ouputs[-1]其实数值和state里面的m是一致的

当然有可能后面outputs dynamic rnn 会补0

   

encode_feature, state = melt.rnn.encode(

cell,

inputs,

seq_length,

encode_method=0,

output_method=3)

   

encode_feature.eval()

array([[[ 4.27834410e-03, 1.45841937e-03, 1.25767402e-02,

5.00775501e-03],
[ 6.24437723e-03, 2.60074623e-03, 2.32168660e-02,
9.47457738e-03],
[ 7.59789022e-03, -5.34060055e-05, 1.64511874e-02,
-5.71310846e-03],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00]]], dtype=float32)

   

   

state[1].eval()

array([[ 7.59789022e-03, -5.34060055e-05, 1.64511874e-02,

-5.71310846e-03]], dtype=float32)

   

   

   

转载地址:http://lqjea.baihongyu.com/

你可能感兴趣的文章
《手机测试Robotium实战教程》—第2章2.4节ADT插件的安装
查看>>
《架构真经:互联网技术架构的设计》分而治之
查看>>
发展型机器人:由人类婴儿启发的机器人. 2.2 机器人学简介
查看>>
干货!2017苹果开发者大会发布了啥,看这篇就够了
查看>>
数博会第一辩:机器智能是人也不具备的智能
查看>>
《术以载道——软件过程改进实践指南》目录—导读
查看>>
SSH 使用密钥登录并禁止口令登录实践
查看>>
《统计会犯错——如何避免数据分析中的统计陷阱》—第1章功效曲线
查看>>
机器人系统设计与制作:Python语言实现1.1 什么是机器人
查看>>
《JavaScript面向对象精要》——1.2 原始类型
查看>>
《Visual C++ 开发从入门到精通》——2.2 分析C++的程序结构
查看>>
《实现模式(修订版)》—第3章3.1节价值观
查看>>
SQL数据库的终结?
查看>>
C 和 C++ 文件操作详解
查看>>
【猜代码赢大奖】又是一年四月一,代码整人别客气
查看>>
C++实践参考——静态成员应用
查看>>
PetaData · 架构体系 · PetaData第二代低成本存储体系
查看>>
Redis内存分析方法
查看>>
CPU、内存、IO虚拟化关键技术及其优化探索
查看>>
SQL优化之六脉神剑
查看>>