数据集 API“flat_map”方法对与“map”方法一起使用的相同代码产生错误

2023-12-01

我正在尝试创建一个管道来使用 TensorFlow Dataset API 和 Pandas 读取多个 CSV 文件。然而,使用flat_map方法正在产生错误。但是,如果我使用map方法我能够构建代码并在会话中运行它。这是我正在使用的代码。我已经打开了#17415TensorFlow Github 存储库中的问题。但显然,这不是一个错误,他们要求我在这里发帖。

folder_name = './data/power_data/'
file_names = os.listdir(folder_name)
def _get_data_for_dataset(file_name,rows=100):#
    print(file_name.decode())

    df_input=pd.read_csv(os.path.join(folder_name, file_name.decode()),
                         usecols =['Wind_MWh','Actual_Load_MWh'],nrows = rows)
    X_data = df_input.as_matrix()
    X_data.astype('float32', copy=False)

    return X_data
dataset = tf.data.Dataset.from_tensor_slices(file_names)
dataset = dataset.flat_map(lambda file_name: tf.py_func(_get_data_for_dataset, 
[file_name], tf.float64))
dataset= dataset.batch(2)
fiter = dataset.make_one_shot_iterator()
get_batch = iter.get_next()

我收到以下错误:map_func must return a Dataset object。当我使用时管道工作没有错误map但它没有给出我想要的输出。例如,如果 Pandas 从每个 CSV 文件中读取 N 行,我希望管道连接 B 文件中的数据并为我提供一个形状为 (N*B, 2) 的数组。相反,它给我 (B, N,2),其中 B 是批量大小。map正在添加另一个轴而不是连接现有轴。根据我在文档中的理解flat_map应该给出平坦的输出。在文档中,两者map and flat_map返回类型数据集。那么我的代码如何与 map 一起使用而不是与 flat_map 一起使用?

如果您能向我指出将 Dataset API 与 Pandas 模块一起使用的代码,那就太好了。


As mikkola 在评论中指出, the Dataset.map() and Dataset.flat_map()期望具有不同签名的函数:Dataset.map()接受一个函数,将输入数据集的单个元素映射到单个新元素,而Dataset.flat_map()接受一个函数,将输入数据集的单个元素映射到Dataset的元素。

如果您希望返回数组的每一行_get_data_for_dataset()到 成为一个单独的元素,你应该使用Dataset.flat_map()并将输出转换为tf.py_func() to a Dataset, using Dataset.from_tensor_slices():

folder_name = './data/power_data/'
file_names = os.listdir(folder_name)

def _get_data_for_dataset(file_name, rows=100):
    df_input=pd.read_csv(os.path.join(folder_name, file_name.decode()),
                         usecols=['Wind_MWh', 'Actual_Load_MWh'], nrows=rows)
    X_data = df_input.as_matrix()
    return X_data.astype('float32', copy=False)

dataset = tf.data.Dataset.from_tensor_slices(file_names)

# Use `Dataset.from_tensor_slices()` to make a `Dataset` from the output of 
# the `tf.py_func()` op.
dataset = dataset.flat_map(lambda file_name: tf.data.Dataset.from_tensor_slices(
    tf.py_func(_get_data_for_dataset, [file_name], tf.float32)))

dataset = dataset.batch(2)

iter = dataset.make_one_shot_iterator()
get_batch = iter.get_next()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

数据集 API“flat_map”方法对与“map”方法一起使用的相同代码产生错误 的相关文章

随机推荐