TensorFlow 2.0 Transfer Learning

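What is Transfer Learning?

Transfer learning reuses a network already trained on a large dataset (for example, a ResNet trained on ImageNet) as the starting point for a new, related task. The pretrained layers are typically kept as a feature extractor (frozen or fine-tuned), and a new classification head is trained on the smaller target dataset.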

Four Steps:

  • Load data
  • Build model
  • Train and Test
  • Transfer Learning

Load Data

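  • Images and labels
    • x = ['1.png', '2.png', '3.png']   (image file paths)
    • y = [4, 9, 1]                     (integer class labels)
  • data = tf.data.Dataset.from_tensor_slices((x, y))
  • data = data.shuffle(buffer_size).map(preprocess).batch(batch_size)
    (each call returns a new Dataset, so the result must be reassigned)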
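The pipeline above can be tried without any image files. In this minimal sketch, random tensors stand in for already-decoded images (an assumption for illustration; the real pipeline maps `preprocess` over file paths), and `to_float_onehot` is a hypothetical stand-in for that mapping function.

```python
import tensorflow as tf

# Hypothetical stand-ins for decoded images: 6 random 32x32 RGB "images"
# and integer class labels in [0, 5). In the real pipeline x would hold
# file paths and map() would read and decode them.
images = tf.random.uniform([6, 32, 32, 3], maxval=255, dtype=tf.float32)
labels = tf.constant([4, 0, 1, 3, 2, 1])

def to_float_onehot(x, y):
    # Hypothetical map function: scale pixels to [0, 1], one-hot the label.
    return x / 255., tf.one_hot(y, depth=5)

data = tf.data.Dataset.from_tensor_slices((images, labels))
# shuffle/map/batch each return a NEW dataset, so reassign the result.
data = data.shuffle(buffer_size=6).map(to_float_onehot).batch(4)

for xb, yb in data:
    print(xb.shape, yb.shape)
```

With 6 examples and a batch size of 4, the last batch is smaller (2 examples); pass `drop_remainder=True` to `batch()` if every batch must have the same size.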

Preprocessing

  • Read and resize
    • 224x224 input for ResNet
  • Data Augmentation
    • Rotate/Flip
    • Crop
  • Normalize
    • Subtract the per-channel mean, divide by the per-channel std
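
import tensorflow as tf

# ImageNet per-channel statistics (assumed; normalize() was not defined in the notes)
img_mean = tf.constant([0.485, 0.456, 0.406])
img_std = tf.constant([0.229, 0.224, 0.225])

def normalize(x, mean=img_mean, std=img_std):
    # x: [224, 224, 3] in [0, 1]; mean/std broadcast over the channel axis
    return (x - mean) / std

def preprocess(x, y):
    # x: image file path, y: integer class label
    x = tf.io.read_file(x)
    # decode_image handles both PNG and JPEG; channels=3 gives RGB
    x = tf.io.decode_image(x, channels=3, expand_animations=False)
    # resize slightly larger than 224 so the random crop below has room to move
    x = tf.image.resize(x, [244, 244])

    # data augmentation
    # x = tf.image.random_flip_left_right(x)
    x = tf.image.random_flip_up_down(x)
    x = tf.image.random_crop(x, [224, 224, 3])

    # x: [0, 255] => [0, 1], then standardize with the channel mean/std
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = normalize(x)
    y = tf.convert_to_tensor(y)
    y = tf.one_hot(y, depth=5)  # 5 classes

    return x, y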
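The preprocessing steps listed above can be checked on a synthetic image, skipping file I/O. This is a sketch under assumptions: a random uint8 tensor stands in for a decoded image, the ImageNet mean/std values are assumed, and `augment_and_normalize` and `denormalize` are hypothetical helper names. The inverse transform is useful when displaying an augmented sample.

```python
import tensorflow as tf

# ImageNet channel statistics (assumed here, as is common for pretrained models).
MEAN = tf.constant([0.485, 0.456, 0.406])
STD = tf.constant([0.229, 0.224, 0.225])

def augment_and_normalize(img):
    # Hypothetical helper. img: uint8 tensor [H, W, 3] standing in for a decoded file.
    img = tf.image.resize(img, [244, 244])          # float32, values in [0, 255]
    img = tf.image.random_flip_up_down(img)
    img = tf.image.random_crop(img, [224, 224, 3])  # random 224x224 window
    img = img / 255.                                 # [0, 1]
    return (img - MEAN) / STD                        # per-channel standardization

def denormalize(img):
    # Hypothetical helper: inverse of the standardization, back to [0, 1].
    return img * STD + MEAN

fake = tf.cast(tf.random.uniform([300, 400, 3], maxval=256, dtype=tf.int32), tf.uint8)
out = augment_and_normalize(fake)
restored = denormalize(out)
print(out.shape)
```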

Build Model

  • Subclass tf.keras.Model
  • Define the forward pass in call()
  • Compile with an optimizer, a loss, and metrics
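The three steps above can be sketched with a small subclassed model. The class name `SmallCNN` and its layer sizes are hypothetical choices for illustration, not the architecture used in the original code; the loss assumes one-hot labels (as produced by `preprocess`) and raw logits from the last layer.

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class SmallCNN(keras.Model):
    """Hypothetical small classifier; name and layer sizes are illustrative."""

    def __init__(self, num_classes=5):
        super().__init__()
        self.conv1 = layers.Conv2D(16, 3, strides=2, activation="relu")
        self.conv2 = layers.Conv2D(32, 3, strides=2, activation="relu")
        self.pool = layers.GlobalAveragePooling2D()
        self.fc = layers.Dense(num_classes)  # raw logits, no softmax

    def call(self, x, training=False):
        # Forward pass: conv -> conv -> global average pool -> dense logits.
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.pool(x)
        return self.fc(x)

model = SmallCNN(num_classes=5)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),  # one-hot labels
    metrics=["accuracy"],
)
logits = model(tf.zeros([2, 224, 224, 3]))  # first call also builds the weights
print(logits.shape)
```

Keeping softmax out of the model and setting `from_logits=True` in the loss is the numerically safer combination.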

Train and Test

  • Split the data into train, validation, and test sets
  • Early stopping based on validation performance
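A minimal sketch of the split and of early stopping with the built-in `keras.callbacks.EarlyStopping`. Everything here is an assumption for illustration: the data is random, the 60/20/20 split and `patience=3` are arbitrary, and a tiny dense model replaces the real network so the example runs quickly. Setting `reshuffle_each_iteration=False` matters: otherwise `take`/`skip` would draw different examples each epoch and the splits would leak into each other.

```python
import tensorflow as tf
from tensorflow import keras

# Synthetic stand-in data (hypothetical): 100 8x8 RGB images, 5 classes.
x = tf.random.uniform([100, 8, 8, 3])
y = tf.one_hot(tf.random.uniform([100], maxval=5, dtype=tf.int32), depth=5)
data = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(
    100, seed=0, reshuffle_each_iteration=False)

# 60/20/20 split via take/skip; the fixed shuffle order keeps the splits disjoint.
train_ds = data.take(60).batch(16)
val_ds = data.skip(60).take(20).batch(16)
test_ds = data.skip(80).batch(16)

model = keras.Sequential([
    keras.Input(shape=(8, 8, 3)),
    keras.layers.Flatten(),
    keras.layers.Dense(5),  # logits
])
model.compile(optimizer="adam",
              loss=keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=["accuracy"])

# Stop once val_loss has not improved for 3 epochs, then restore the best weights.
early_stop = keras.callbacks.EarlyStopping(monitor="val_loss", patience=3,
                                           restore_best_weights=True)
history = model.fit(train_ds, validation_data=val_ds, epochs=50,
                    callbacks=[early_stop], verbose=0)

# The test set is touched only once, after training has finished.
test_loss, test_acc = model.evaluate(test_ds, verbose=0)
print(len(history.history["loss"]), test_acc)
```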

See the code for details.
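For the fourth step, transfer learning itself, a common Keras pattern is to load a pretrained backbone without its ImageNet classifier, freeze it, and train only a new head. This sketch assumes `keras.applications.ResNet50` as the backbone (the original code may use a different one), and `build_transfer_model` is a hypothetical helper name. Note that the Keras ResNet50 weights expect inputs prepared with `tf.keras.applications.resnet50.preprocess_input` on [0, 255] RGB images, rather than the mean/std normalization used in `preprocess`.

```python
import tensorflow as tf
from tensorflow import keras

def build_transfer_model(num_classes=5, weights="imagenet"):
    # Hypothetical helper. Pretrained ResNet50 backbone without its
    # 1000-class head; pooling="avg" yields a 2048-d feature vector.
    base = keras.applications.ResNet50(weights=weights, include_top=False,
                                       input_shape=(224, 224, 3), pooling="avg")
    base.trainable = False  # freeze the backbone: only the new head is trained
    model = keras.Sequential([
        base,
        keras.layers.Dense(num_classes),  # new task-specific head (logits)
    ])
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3),
                  loss=keras.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=["accuracy"])
    return model
```

Training then proceeds exactly as in the train-and-test step, e.g. `model = build_transfer_model()` followed by `model.fit(...)` with early stopping. Passing `weights=None` builds the same architecture with random weights and no download, which is handy for quick shape checks. Once the head has converged, unfreezing some of the top backbone layers and training with a lower learning rate (fine-tuning) is an optional next step.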

Reference

  1. TensorFlow-2.x-Tutorials
  2. Transfer learning
