From abf3bba2d561c9f475409ab79f66b2cc88a19d2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Mon, 10 Jun 2019 17:42:31 +0800 Subject: [PATCH] Clarify DenseTranspose --- 17_autoencoders_and_gans.ipynb | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/17_autoencoders_and_gans.ipynb b/17_autoencoders_and_gans.ipynb index c18e320..16b2151 100644 --- a/17_autoencoders_and_gans.ipynb +++ b/17_autoencoders_and_gans.ipynb @@ -380,8 +380,6 @@ "metadata": {}, "outputs": [], "source": [ - "K = keras.backend\n", - "\n", "class DenseTranspose(keras.layers.Layer):\n", " def __init__(self, dense, activation=None, **kwargs):\n", " self.dense = dense\n", @@ -393,8 +391,8 @@ " initializer=\"zeros\")\n", " super().build(batch_input_shape)\n", " def call(self, inputs):\n", - " z = inputs @ K.transpose(self.dense.weights[0]) + self.biases\n", - " return self.activation(z)" + " z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)\n", + " return self.activation(z + self.biases)" ] }, { @@ -403,20 +401,27 @@ "metadata": {}, "outputs": [], "source": [ + "keras.backend.clear_session()\n", "tf.random.set_seed(42)\n", "np.random.seed(42)\n", "\n", + "dense_1 = keras.layers.Dense(100, activation=\"selu\")\n", + "dense_2 = keras.layers.Dense(30, activation=\"selu\")\n", + "\n", "tied_encoder = keras.models.Sequential([\n", " keras.layers.Flatten(input_shape=[28, 28]),\n", - " keras.layers.Dense(100, activation=\"selu\"),\n", - " keras.layers.Dense(30, activation=\"selu\"),\n", + " dense_1,\n", + " dense_2\n", "])\n", + "\n", "tied_decoder = keras.models.Sequential([\n", - " DenseTranspose(tied_encoder.layers[2], activation=\"selu\"),\n", - " DenseTranspose(tied_encoder.layers[1], activation=\"sigmoid\"),\n", + " DenseTranspose(dense_2, activation=\"selu\"),\n", + " DenseTranspose(dense_1, activation=\"sigmoid\"),\n", " keras.layers.Reshape([28, 28])\n", "])\n", + "\n", "tied_ae = keras.models.Sequential([tied_encoder, tied_decoder])\n", + "\n", "tied_ae.compile(loss=\"binary_crossentropy\",\n", " optimizer=keras.optimizers.SGD(lr=1.5), metrics=[rounded_accuracy])\n", "history = tied_ae.fit(X_train, X_train, epochs=10,\n", @@ -473,6 +478,7 @@ "tf.random.set_seed(42)\n", "np.random.seed(42)\n", "\n", + "K = keras.backend\n", "X_train_flat = K.batch_flatten(X_train) # equivalent to .reshape(-1, 28 * 28)\n", "X_valid_flat = K.batch_flatten(X_valid)\n", "enc1, dec1, X_train_enc1, X_valid_enc1 = train_autoencoder(\n",