मैं वर्तमान में TensorFlow के साथ CIFAR10 डेटासेट के साथ काम कर रहा हूँ। विभिन्न कारणों से मुझे एक पूर्वनिर्धारित नियम द्वारा लेबल बदलने की आवश्यकता है, उदाहरण के लिए। प्रत्येक उदाहरण, जिसमें 4 का लेबल है, को 3 में बदला जाना चाहिए या प्रत्येक जिसमें 1 है, को 6 में बदला जाना चाहिए।

मैंने निम्नलिखित विधि की कोशिश की है:

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

builder = tfds.image.Cifar10()
builder.download_and_prepare()
ds_train: tf.data.Dataset = builder.as_dataset(split='train')

def relabel_map(l):
    return {0: 0, 1: 6, 2: 1, 3: 2, 4: 3, 5: 4, 6: 9, 7: 5, 8: 7, 9: 8}[l]

ds_train = ds_train.map(lambda example: (example['image'], tf.py_function(relabel_map, [example['label']], [tf.int64])))

for ex in ds_train.take(1):
    plt.imshow(np.array(ex[0], dtype=np.uint8))
    plt.show()
    print(ex[1])

जब मैं इसे चलाने का प्रयास करता हूं, तो मुझे for ex in ds_train.take(1): के साथ निम्न त्रुटि मिलती है:

लेखन त्रुटि: टेंसर हैश करने योग्य नहीं है। इसके बजाय, कुंजी के रूप में tensor.ref() का उपयोग करें।

मेरा अजगर संस्करण 3.8.12 है और TensorFlow संस्करण 2.7.0 है।

पीएस: हो सकता है कि मैं इस परिवर्तन को एक-गर्म में परिवर्तित करके और इसे मैट्रिक्स के साथ बदलकर कर सकता हूं, लेकिन यह कोड में बहुत कम सरल दिखाई देगा।

0
leevii 27 नवम्बर 2021, 02:15

1 उत्तर

सबसे बढ़िया उत्तर

मैं आपके मामले के लिए tf.lookup.StaticHashTable का उपयोग करने का सुझाव दूंगा:

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

builder = tfds.image.Cifar10()
builder.download_and_prepare()
ds_train: tf.data.Dataset = builder.as_dataset(split='train')

table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=tf.int64),
        values=tf.constant([0, 6, 1, 2, 3, 4, 9, 5, 7, 8],  dtype=tf.int64),
    ),
    default_value= tf.constant(0,  dtype=tf.int64)
)

def relabel_map(example):
    example['label'] = table.lookup(example['label'])
    return example

ds_train = ds_train.map(relabel_map)

for ex in ds_train.take(1):
    plt.imshow(np.array(ex['image'], dtype=np.uint8))
    plt.show()
    print(ex['label'])

enter image description here

tf.Tensor(5, shape=(), dtype=int64)
1
AloneTogether 27 नवम्बर 2021, 10:39
धन्यवाद, इसने काम किया
 – 
leevii
27 नवम्बर 2021, 13:22