博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
tf keras SimpleRNN源码解析
阅读量:2134 次
发布时间:2019-04-30

本文共 20648 字,大约阅读时间需要 68 分钟。

环境

| package | version |
| --- | --- |
| tensorflow | 2.3.0 |
| keras | 2.4.3 |
源码

部分主要源码

class RNN(Layer):
  """Recurrent layer that iterates a `cell` over the time dimension.

  The layer handles the bookkeeping around the cell: input/state specs,
  initial-state resolution, masking, optional statefulness and the call into
  `K.rnn`, which runs the per-time-step `step` function built in `call()`.

  Args:
    cell: A cell object, or a list/tuple of cells (wrapped in
      `StackedRNNCells`).
    return_sequences: If True, `call()` returns the full output sequence;
      otherwise only the last output.
    return_state: If True, `call()` also returns the final states.
    go_backwards: Forwarded to `K.rnn` as `go_backwards`.
    stateful: If True, final states are written back into `self.states`
      after each call and reused as the next call's initial state.
    unroll: Forwarded to `K.rnn`; requires a known time dimension.
    time_major: If True, inputs are indexed `(time, batch, ...)`; otherwise
      `(batch, time, ...)`.
    **kwargs: Passed to `Layer.__init__`. `zero_output_for_mask`,
      `input_dim` and `input_length` are consumed here.
  """

  def __init__(self,
               cell,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               unroll=False,
               time_major=False,
               **kwargs):
    if isinstance(cell, (list, tuple)):
      cell = StackedRNNCells(cell)
    # If True, the output for masked timestep will be zeros, whereas in the
    # False case, output from previous timestep is returned for masked timestep.
    self.zero_output_for_mask = kwargs.pop('zero_output_for_mask', False)

    # Translate `input_dim` / `input_length` into the `input_shape` kwarg
    # understood by the base Layer.
    if 'input_shape' not in kwargs and (
        'input_dim' in kwargs or 'input_length' in kwargs):
      input_shape = (kwargs.pop('input_length', None),
                     kwargs.pop('input_dim', None))
      kwargs['input_shape'] = input_shape

    super(RNN, self).__init__(**kwargs)
    self.cell = cell
    self.return_sequences = return_sequences
    self.return_state = return_state
    self.go_backwards = go_backwards
    self.stateful = stateful
    self.unroll = unroll
    self.time_major = time_major

    self.supports_masking = True
    # Populated in build() (possibly as a nested structure).
    self.input_spec = None
    self.state_spec = None
    self._states = None
    self.constants_spec = None
    self._num_constants = 0

    if stateful:
      if ds_context.has_strategy():
        raise ValueError('RNNs with stateful=True not yet supported with '
                         'tf.distribute.Strategy.')

  @property
  def states(self):
    """The recorded states, or a placeholder structure of `None`s.

    Until states have been assigned, returns a structure mirroring
    `cell.state_size` filled with `None` (wrapped in a list when
    `state_size` is not a sequence).
    """
    if self._states is None:
      state = nest.map_structure(lambda _: None, self.cell.state_size)
      return state if nest.is_sequence(self.cell.state_size) else [state]
    return self._states

  # `@states.setter` must wrap this definition; without it the second
  # `def states` replaces the property above with a plain method.
  @states.setter
  @trackable.no_automatic_dependency_tracking
  def states(self, states):
    self._states = states

  def compute_mask(self, inputs, mask):
    """Returns the output mask, plus one `None` per state if returning state."""
    # Time step masks must be the same for each input.
    # This is because the mask for an RNN is of size [batch, time_steps, 1],
    # and specifies which time steps should be skipped, and a time step
    # must be skipped for all inputs.
    # TODO(scottzhu): Should we accept multiple different masks?
    mask = nest.flatten(mask)[0]
    output_mask = mask if self.return_sequences else None
    if self.return_state:
      state_mask = [None for _ in self.states]
      return [output_mask] + state_mask
    else:
      return output_mask

  def build(self, input_shape):
    """Builds `input_spec`, the wrapped cell and `state_spec`.

    Flow: input_shape -> input_spec (batch/time dims relaxed) and
    step_input_shape (time dim removed) -> cell.build(step_input_shape) ->
    cell.state_size -> state_spec (created, or validated if `initial_state`
    was passed in `__call__`). Stateful layers also reset their states.
    """
    if isinstance(input_shape, list):
      input_shape = input_shape[0]
      # The input_shape here could be a nest structure.

    # do the tensor_shape to shapes here. The input could be single tensor, or a
    # nested structure of tensors.
    def get_input_spec(shape):
      """Convert input shape to InputSpec."""
      if isinstance(shape, tensor_shape.TensorShape):
        input_spec_shape = shape.as_list()
      else:
        input_spec_shape = list(shape)
      batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
      if not self.stateful:
        input_spec_shape[batch_index] = None
      input_spec_shape[time_step_index] = None
      return InputSpec(shape=tuple(input_spec_shape))

    def get_step_input_shape(shape):
      """Return `shape` with the time dimension removed."""
      if isinstance(shape, tensor_shape.TensorShape):
        shape = tuple(shape.as_list())
      # remove the timestep from the input_shape
      return shape[1:] if self.time_major else (shape[0],) + shape[2:]

    # Check whether the input shape contains any nested shapes. It could be
    # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy
    # inputs.
    try:
      input_shape = tensor_shape.as_shape(input_shape)
    except (ValueError, TypeError):
      # A nested tensor input
      pass

    if not nest.is_sequence(input_shape):
      # This indicates there is only one input.
      if self.input_spec is not None:
        self.input_spec[0] = get_input_spec(input_shape)
      else:
        self.input_spec = [get_input_spec(input_shape)]
      step_input_shape = get_step_input_shape(input_shape)
    else:
      if self.input_spec is not None:
        self.input_spec[0] = nest.map_structure(get_input_spec, input_shape)
      else:
        self.input_spec = generic_utils.to_list(
            nest.map_structure(get_input_spec, input_shape))
      step_input_shape = nest.map_structure(get_step_input_shape, input_shape)

    # allow cell (if layer) to build before we set or validate state_spec.
    if isinstance(self.cell, Layer) and not self.cell.built:
      with K.name_scope(self.cell.name):
        self.cell.build(step_input_shape)
        self.cell.built = True

    # set or validate state_spec
    if _is_multiple_state(self.cell.state_size):
      state_size = list(self.cell.state_size)
    else:
      state_size = [self.cell.state_size]

    if self.state_spec is not None:
      # initial_state was passed in call, check compatibility
      self._validate_state_spec(state_size, self.state_spec)
    else:
      self.state_spec = [
          InputSpec(shape=[None] + tensor_shape.as_shape(dim).as_list())
          for dim in state_size
      ]
    if self.stateful:
      self.reset_states()
    self.built = True

  @staticmethod
  def _validate_state_spec(cell_state_sizes, init_state_specs):
    """Validate the state spec between the initial_state and the state_size.

    Args:
      cell_state_sizes: list, the `state_size` attribute from the cell.
      init_state_specs: list, the `state_spec` from the initial_state that is
        passed in `call()`.

    Raises:
      ValueError: When initial state spec is not compatible with the state size.
    """
    validation_error = ValueError(
        'An `initial_state` was passed that is not compatible with '
        '`cell.state_size`. Received `state_spec`={}; '
        'however `cell.state_size` is '
        '{}'.format(init_state_specs, cell_state_sizes))
    flat_cell_state_sizes = nest.flatten(cell_state_sizes)
    flat_state_specs = nest.flatten(init_state_specs)

    if len(flat_cell_state_sizes) != len(flat_state_specs):
      raise validation_error
    for cell_state_spec, cell_state_size in zip(flat_state_specs,
                                                flat_cell_state_sizes):
      if not tensor_shape.TensorShape(
          # Ignore the first axis for init_state which is for batch
          cell_state_spec.shape[1:]).is_compatible_with(
              tensor_shape.TensorShape(cell_state_size)):
        raise validation_error

  @doc_controls.do_not_doc_inheritable
  def get_initial_state(self, inputs):
    """Builds the initial state list for `inputs`.

    Uses `cell.get_initial_state` when the cell defines it, otherwise
    `_generate_zero_filled_state`. Batch size and dtype are read from
    `inputs` (from its first flattened element when it is nested).

    Returns:
      A list of state tensors.
    """
    get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)

    if nest.is_sequence(inputs):
      # The input are nested sequences. Use the first element in the seq to get
      # batch size and dtype.
      inputs = nest.flatten(inputs)[0]

    input_shape = array_ops.shape(inputs)
    batch_size = input_shape[1] if self.time_major else input_shape[0]
    dtype = inputs.dtype
    if get_initial_state_fn:
      init_state = get_initial_state_fn(
          inputs=None, batch_size=batch_size, dtype=dtype)
    else:
      init_state = _generate_zero_filled_state(batch_size, self.cell.state_size,
                                               dtype)
    # Keras RNN expect the states in a list, even if it's a single state tensor.
    if not nest.is_sequence(init_state):
      init_state = [init_state]
    # Force the state to be a list in case it is a namedtuple eg LSTMStateTuple.
    return list(init_state)

  def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
    """Calls the layer, routing `initial_state` / `constants`.

    Keras-tensor `initial_state` / `constants` are appended to the layer
    inputs, with `input_spec` temporarily extended to cover them. Non-Keras
    tensors are forwarded to `call()` as keyword arguments.

    Raises:
      ValueError: If `initial_state` / `constants` mix Keras and non-Keras
        tensors.
    """
    inputs, initial_state, constants = _standardize_args(inputs,
                                                         initial_state,
                                                         constants,
                                                         self._num_constants)

    if initial_state is None and constants is None:
      return super(RNN, self).__call__(inputs, **kwargs)

    # If any of `initial_state` or `constants` are specified and are Keras
    # tensors, then add them to the inputs and temporarily modify the
    # input_spec to include them.

    additional_inputs = []
    additional_specs = []
    if initial_state is not None:
      additional_inputs += initial_state
      self.state_spec = nest.map_structure(
          lambda s: InputSpec(shape=K.int_shape(s)), initial_state)
      additional_specs += self.state_spec
    if constants is not None:
      additional_inputs += constants
      self.constants_spec = [
          InputSpec(shape=K.int_shape(constant)) for constant in constants
      ]
      self._num_constants = len(constants)
      additional_specs += self.constants_spec
    # additional_inputs can be empty if initial_state or constants are provided
    # but empty (e.g. the cell is stateless).
    flat_additional_inputs = nest.flatten(additional_inputs)
    is_keras_tensor = K.is_keras_tensor(
        flat_additional_inputs[0]) if flat_additional_inputs else True
    for tensor in flat_additional_inputs:
      if K.is_keras_tensor(tensor) != is_keras_tensor:
        raise ValueError('The initial state or constants of an RNN'
                         ' layer cannot be specified with a mix of'
                         ' Keras tensors and non-Keras tensors'
                         ' (a "Keras tensor" is a tensor that was'
                         ' returned by a Keras layer, or by `Input`)')

    if is_keras_tensor:
      # Compute the full input spec, including state and constants
      full_input = [inputs] + additional_inputs
      if self.built:
        # Keep the input_spec since it has been populated in build() method.
        full_input_spec = self.input_spec + additional_specs
      else:
        # The original input_spec is None since there could be a nested tensor
        # input. Update the input_spec to match the inputs.
        full_input_spec = generic_utils.to_list(
            nest.map_structure(lambda _: None, inputs)) + additional_specs
      # Perform the call with temporarily replaced input_spec
      self.input_spec = full_input_spec
      try:
        output = super(RNN, self).__call__(full_input, **kwargs)
      finally:
        # Remove the additional_specs from input spec and keep the rest. It is
        # important to keep since the input spec was populated by build(), and
        # will be reused in the stateful=True. Skip when nothing was appended:
        # `spec[:-0]` is `spec[:0]` and would discard the whole spec.
        if additional_specs:
          self.input_spec = self.input_spec[:-len(additional_specs)]
      return output
    else:
      if initial_state is not None:
        kwargs['initial_state'] = initial_state
      if constants is not None:
        kwargs['constants'] = constants
      return super(RNN, self).__call__(inputs, **kwargs)

  def call(self,
           inputs,
           mask=None,
           training=None,
           initial_state=None,
           constants=None):
    """Runs the cell over the time dimension with `K.rnn`.

    Returns:
      The output sequence (if `return_sequences`) or the last output; when
      `return_state` is set, a list of that output followed by the final
      states.

    Raises:
      ValueError: If `unroll` is set but the time dimension is unknown, or if
        `constants` are given to a cell whose `call` does not accept them.
    """
    # The input should be dense, padded with zeros. If a ragged input is fed
    # into the layer, it is padded and the row lengths are used for masking.
    inputs, row_lengths = K.convert_inputs_if_ragged(inputs)
    is_ragged_input = (row_lengths is not None)
    self._validate_args_if_ragged(is_ragged_input, mask)

    inputs, initial_state, constants = self._process_inputs(
        inputs, initial_state, constants)

    # NOTE(review): `_maybe_reset_cell_dropout_mask` is not defined in this
    # excerpt; presumably provided by the full class source — confirm.
    self._maybe_reset_cell_dropout_mask(self.cell)
    if isinstance(self.cell, StackedRNNCells):
      for cell in self.cell.cells:
        self._maybe_reset_cell_dropout_mask(cell)

    if mask is not None:
      # Time step masks must be the same for each input.
      # TODO(scottzhu): Should we accept multiple different masks?
      mask = nest.flatten(mask)[0]

    if nest.is_sequence(inputs):
      # In the case of nested input, use the first element for shape check.
      input_shape = K.int_shape(nest.flatten(inputs)[0])
    else:
      input_shape = K.int_shape(inputs)
    timesteps = input_shape[0] if self.time_major else input_shape[1]
    if self.unroll and timesteps is None:
      raise ValueError('Cannot unroll a RNN if the '
                       'time dimension is undefined. \n'
                       '- If using a Sequential model, '
                       'specify the time dimension by passing '
                       'an `input_shape` or `batch_input_shape` '
                       'argument to your first layer. If your '
                       'first layer is an Embedding, you can '
                       'also use the `input_length` argument.\n'
                       '- If using the functional API, specify '
                       'the time dimension by passing a `shape` '
                       'or `batch_shape` argument to your Input layer.')

    kwargs = {}
    if generic_utils.has_arg(self.cell.call, 'training'):
      kwargs['training'] = training

    # TF RNN cells expect single tensor as state instead of list wrapped tensor.
    is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None
    # Use the __call__ function for callable objects, eg layers, so that it
    # will have the proper name scopes for the ops, etc.
    cell_call_fn = self.cell.__call__ if callable(self.cell) else self.cell.call
    if constants:
      if not generic_utils.has_arg(self.cell.call, 'constants'):
        raise ValueError('RNN cell does not support constants')

      def step(inputs, states):
        # The trailing `_num_constants` entries of `states` are the constants.
        constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
        states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type

        states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
        output, new_states = cell_call_fn(
            inputs, states, constants=constants, **kwargs)
        if not nest.is_sequence(new_states):
          new_states = [new_states]
        return output, new_states
    else:

      def step(inputs, states):
        states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
        output, new_states = cell_call_fn(inputs, states, **kwargs)
        if not nest.is_sequence(new_states):
          new_states = [new_states]
        return output, new_states

    last_output, outputs, states = K.rnn(
        step,
        inputs,
        initial_state,
        constants=constants,
        go_backwards=self.go_backwards,
        mask=mask,
        unroll=self.unroll,
        input_length=row_lengths if row_lengths is not None else timesteps,
        time_major=self.time_major,
        zero_output_for_mask=self.zero_output_for_mask)

    if self.stateful:
      updates = [
          state_ops.assign(self_state, state) for self_state, state in zip(
              nest.flatten(self.states), nest.flatten(states))
      ]
      self.add_update(updates)

    if self.return_sequences:
      output = K.maybe_convert_to_ragged(is_ragged_input, outputs, row_lengths)
    else:
      output = last_output

    if self.return_state:
      if not isinstance(states, (list, tuple)):
        states = [states]
      else:
        states = list(states)
      return generic_utils.to_list(output) + states
    else:
      return output

  def _process_inputs(self, inputs, initial_state, constants):
    """Unpacks list inputs and resolves the initial state to use.

    Returns:
      `(inputs, initial_state, constants)`.

    Raises:
      ValueError: If the number of initial states differs from
        `len(self.states)`.
    """
    # input shape: `(samples, time (padded with zeros), input_dim)`
    # note that the .build() method of subclasses MUST define
    # self.input_spec and self.state_spec with complete input shapes.
    if (isinstance(inputs, collections_abc.Sequence)
        and not isinstance(inputs, tuple)):
      # get initial_state from full input spec
      # as they could be copied to multiple GPU.
      if not self._num_constants:
        initial_state = inputs[1:]
      else:
        initial_state = inputs[1:-self._num_constants]
        constants = inputs[-self._num_constants:]
      if len(initial_state) == 0:
        initial_state = None
      inputs = inputs[0]

    if self.stateful:
      if initial_state is not None:
        # When layer is stateful and initial_state is provided, check if the
        # recorded state is same as the default value (zeros). Use the recorded
        # state if it is not same as the default.
        non_zero_count = math_ops.add_n([math_ops.count_nonzero_v2(s)
                                         for s in nest.flatten(self.states)])
        # Bind the caller's state explicitly so the branch below does not
        # depend on late binding of the name being reassigned.
        provided_state = initial_state
        # Set strict = True to keep the original structure of the state.
        initial_state = control_flow_ops.cond(non_zero_count > 0,
                                              true_fn=lambda: self.states,
                                              false_fn=lambda: provided_state,
                                              strict=True)
      else:
        initial_state = self.states
    elif initial_state is None:
      initial_state = self.get_initial_state(inputs)

    if len(initial_state) != len(self.states):
      raise ValueError('Layer has ' + str(len(self.states)) +
                       ' states but was passed ' + str(len(initial_state)) +
                       ' initial states.')
    return inputs, initial_state, constants

  def _validate_args_if_ragged(self, is_ragged_input, mask):
    """Rejects a mask or `unroll=True` when the input was ragged."""
    if not is_ragged_input:
      return

    if mask is not None:
      raise ValueError('The mask that was passed in was ' + str(mask) +
                       ' and cannot be applied to RaggedTensor inputs. Please '
                       'make sure that there is no mask passed in by upstream '
                       'layers.')
    if self.unroll:
      raise ValueError('The input received contains RaggedTensors and does '
                       'not support unrolling. Disable unrolling by passing '
                       '`unroll=False` in the RNN Layer constructor.')

  def reset_states(self, states=None):
    """Reset the recorded states for the stateful RNN layer.

    Can only be used when RNN layer is constructed with `stateful` = `True`.

    Args:
      states: Numpy arrays that contain the value for the initial state, which
        will be fed to the cell at the first time step. When the value is
        None, zero filled numpy arrays will be created based on the cell state
        size.

    Raises:
      AttributeError: When the RNN layer is not stateful.
      ValueError: When the batch size of the RNN layer is unknown.
      ValueError: When the input numpy array is not compatible with the RNN
        layer state, either size wise or dtype wise.
    """
    if not self.stateful:
      raise AttributeError('Layer must be stateful.')
    spec_shape = None
    if self.input_spec is not None:
      spec_shape = nest.flatten(self.input_spec[0])[0].shape
    if spec_shape is None:
      # It is possible to have spec shape to be None, eg when construct a RNN
      # with a custom cell, or standard RNN layers (LSTM/GRU) which we only know
      # it has 3 dim input, but not its full shape spec before build().
      batch_size = None
    else:
      batch_size = spec_shape[1] if self.time_major else spec_shape[0]
    if not batch_size:
      raise ValueError('If a RNN is stateful, it needs to know '
                       'its batch size. Specify the batch size '
                       'of your input tensors: \n'
                       '- If using a Sequential model, '
                       'specify the batch size by passing '
                       'a `batch_input_shape` '
                       'argument to your first layer.\n'
                       '- If using the functional API, specify '
                       'the batch size by passing a '
                       '`batch_shape` argument to your Input layer.')

    if nest.flatten(self.states)[0] is None:
      # No state variables yet: create zero-filled ones of shape
      # [batch_size] + state_size for each state.
      def create_state_variable(state):
        return K.zeros([batch_size] + tensor_shape.as_shape(state).as_list())

      self.states = nest.map_structure(
          create_state_variable, self.cell.state_size)
      if not nest.is_sequence(self.states):
        self.states = [self.states]
      if states is None:
        return
      # Otherwise fall through and load the caller-provided values instead of
      # silently ignoring them.
    elif states is None:
      for state, size in zip(nest.flatten(self.states),
                             nest.flatten(self.cell.state_size)):
        K.set_value(state, np.zeros([batch_size] +
                                    tensor_shape.as_shape(size).as_list()))
      return

    flat_states = nest.flatten(self.states)
    flat_input_states = nest.flatten(states)
    if len(flat_input_states) != len(flat_states):
      raise ValueError('Layer ' + self.name + ' expects ' +
                       str(len(flat_states)) + ' states, '
                       'but it received ' + str(len(flat_input_states)) +
                       ' state values. Input received: ' + str(states))
    set_value_tuples = []
    for i, (value, state) in enumerate(zip(flat_input_states, flat_states)):
      if value.shape != state.shape:
        # Report the expected *shape*, not the state variable itself.
        raise ValueError(
            'State ' + str(i) + ' is incompatible with layer ' +
            self.name + ': expected shape=' + str(state.shape) +
            ', found shape=' + str(value.shape))
      set_value_tuples.append((state, value))
    K.batch_set_value(set_value_tuples)
流程

build 方法中的形状推导:input_shape → step_input_shape(去掉时间步维度后传给 cell.build)→ state_size(取自 cell.state_size,用于构造或校验 state_spec)

转载地址:http://nlugf.baihongyu.com/

你可能感兴趣的文章
[Jmeter]jmeter之脚本录制与回放,优化(windows下的jmeter)
查看>>
Jmeter之正则
查看>>
【JMeter】1.9上考试jmeter测试调试
查看>>
【虫师】【selenium】参数化
查看>>
【Python练习】文件引用用户名密码登录系统
查看>>
学习网站汇总
查看>>
【Python】用Python打开csv和xml文件
查看>>
【Loadrunner】性能测试报告实战
查看>>
【面试】一份自我介绍模板
查看>>
【自动化测试】自动化测试需要了解的一些事情。
查看>>
【selenium】selenium ide的安装过程
查看>>
【手机自动化测试】monkey测试
查看>>
【英语】软件开发常用英语词汇
查看>>
Fiddler 抓包工具总结
查看>>
【雅思】雅思需要购买和准备的学习资料
查看>>
【雅思】雅思写作作业(1)
查看>>
【雅思】【大作文】【审题作业】关于同不同意的审题作业(重点)
查看>>
【Loadrunner】通过loadrunner录制时候有事件但是白页无法出来登录页怎么办?
查看>>
【English】【托业】【四六级】写译高频词汇
查看>>
【托业】【新东方全真模拟】01~02-----P5~6
查看>>