x = x.reshape([-1, 28, 28, 1])
x = nn.Conv(features=16, kernel_size=(5, 5))(x)
x = nn.relu(x)
x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape([x.shape[0], -1])  # Flatten
x = nn.Dense(features=1024)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
if get_logits:
    return x
x = nn.softmax(x)
return x
条件分岐（例: `if get_logits:`）を記述可能