ValueError:输入“input_2”缺少数据。您传递了一个带有键 [‘y’, ‘x’] 的数据字典。预期以下键:[‘input_2’]

乘风 tensorflow 397

原文标题ValueError: Missing data for input “input_2”. You passed a data dictionary with keys [‘y’, ‘x’]. Expected the following keys: [‘input_2’]

按照前面的代码here我正在评估联邦学习模型,我遇到了几个问题。这是评估代码

central_test = test.create_tf_dataset_from_all_clients()
test_data = central_test.map(reshape_data)

# function that accepts a server state, and uses 
#Keras to evaluate on the test dataset.
def evaluate(server_state):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  
  )
  keras_model.set_weights(server_state)
  keras_model.evaluate(central_test)

server_state = federated_algorithm.initialize()
evaluate(server_state)

这是错误信息

ValueError: Missing data for input "input_2". You passed a data dictionary with keys ['y', 'x']. Expected the following keys: ['input_2']

那么这里的问题是什么?方法create_tf_dataset_from_all_clients的使用是否在正确的位置?因为 – 正如教程中所写的那样 – 用于创建集中评估数据集。为什么我们需要使用集中式数据集?

原文链接:https://stackoverflow.com//questions/71506975/valueerror-missing-data-for-input-input-2-you-passed-a-data-dictionary-with

回复

我来回复
  • AloneTogether的头像
    AloneTogether 评论

    test数据集在评估期间具有不同的格式。尝试:

    test_data = test.create_tf_dataset_from_all_clients().map(reshape_data).batch(2)
    test_data = test_data.map(lambda x: (x['x'], x['y']))
    
    def evaluate(server_state):
      keras_model = create_keras_model()
      keras_model.compile(
          loss=tf.keras.losses.SparseCategoricalCrossentropy(),
          metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  
      )
      keras_model.set_weights(server_state)
      keras_model.evaluate(test_data)
    
    server_state = federated_algorithm.initialize()
    evaluate(server_state)
    
    2年前 0条评论