tensorflow에서 일부 변수 / 범위를 "고정": stop_gradient 대 변수 전달을 최소화
교대 훈련 미니 배치 중에 그래프의 하나 또는 다른 부분을 '고정'해야하는 Adversarial NN 을 구현하려고합니다 . 즉 G와 D라는 두 개의 하위 네트워크가 있습니다.
G( Z ) -> Xz
D( X ) -> Y
의 경우 손실 기능 G
에 따라 달라집니다 D[G(Z)], D[X]
.
먼저 모든 G 매개 변수가 고정 된 D의 매개 변수를 훈련 한 다음 D의 매개 변수가 고정 된 G의 매개 변수를 훈련해야합니다. 첫 번째 경우 손실 함수는 두 번째 경우 부정적인 손실 함수가되며 업데이트는 첫 번째 또는 두 번째 서브 네트워크의 매개 변수에 적용되어야합니다.
나는 tensorflow가 tf.stop_gradient
기능을 가지고 있음을 보았다 . D (다운 스트림) 서브 네트워크를 훈련하기 위해이 함수를 사용하여 기울기 흐름을 차단할 수 있습니다.
Z -> [ G ] -> tf.stop_gradient(Xz) -> [ D ] -> Y
에 tf.stop_gradient
인라인 예제없이 매우 간결하게 주석을 달았 seq2seq.py
지만 (예제 가 너무 길고 읽기 쉽지 않음) 그래프 생성 중에 호출해야하는 것처럼 보입니다. 교대 배치에서 그래디언트 흐름을 차단 / 차단 해제하려면 그래프 모델을 다시 만들고 다시 초기화해야 함을 의미합니까?
또한 G (업스트림) 네트워크를 통해 흐르는 그래디언트를을 통해 차단할 수없는 것 같습니다 .tf.stop_gradient
대안으로 변수 목록을 옵티 마이저 호출에으로 전달할 수 있다는 것을 알았습니다 opt_op = opt.minimize(cost, <list of variables>)
. 각 서브 네트워크의 범위에서 모든 변수를 얻을 수 있다면 쉬운 솔루션이 될 것입니다. tf.scope에 대해 얻을 수 있습니까 <list of variables>
?
질문에서 언급했듯이이를 달성하는 가장 쉬운 방법은 별도의 호출을 사용하여 두 개의 최적화 작업을 만드는 것 opt.minimize(cost, ...)
입니다. 기본적으로 최적화 프로그램은의 모든 변수를 사용합니다 tf.trainable_variables()
. 특정 범위로 변수를 필터링하려는 경우 다음과 같이 선택적 scope
인수를 사용할 수 있습니다 tf.get_collection()
.
optimizer = tf.train.AdagradOptimzer(0.01)
first_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
"scope/prefix/for/first/vars")
first_train_op = optimizer.minimize(cost, var_list=first_train_vars)
second_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
"scope/prefix/for/second/vars")
second_train_op = optimizer.minimize(cost, var_list=second_train_vars)
@mrry의 대답은 완전히 옳고 아마도 내가 제안하려는 것보다 더 일반적입니다. 그러나 그것을 달성하는 더 간단한 방법은 파이썬 참조를 직접 전달하는 것입니다 var_list
.
W = tf.Variable(...)
C = tf.Variable(...)
Y_est = tf.matmul(W,C)
loss = tf.reduce_sum((data-Y_est)**2)
optimizer = tf.train.AdamOptimizer(0.001)
# You can pass the python object directly
train_W = optimizer.minimize(loss, var_list=[W])
train_C = optimizer.minimize(loss, var_list=[C])
여기에 자체 포함 된 예가 있습니다. https://gist.github.com/ahwillia/8cedc710352eb919b684d8848bc2df3a
고려할 수있는 또 다른 옵션은 변수에 trainable = False를 설정할 수 있다는 것입니다. 즉, 교육으로 수정되지 않습니다.
tf.Variable(my_weights, trainable=False)
내 접근 방식에 단점이 있는지는 모르겠지만이 구성을 사용하여이 문제를 해결했습니다.
do_gradient = <Tensor that evaluates to 0 or 1>
no_gradient = 1 - do_gradient
wrapped_op = do_gradient * original + no_gradient * tf.stop_gradient(original)
따라서이면 do_gradient = 1
값과 그래디언트가 잘 흐르지 만 인 경우 do_gradient = 0
값은 stop_gradient op를 통해서만 흐르고 그래디언트가 역류하는 것을 중지합니다.
내 시나리오에서 do_gradient를 random_shuffle 텐서의 인덱스에 연결하면 네트워크의 다른 부분을 무작위로 훈련시킬 수 있습니다.
'programing' 카테고리의 다른 글
SQL Server에서 현재 날짜에 시간을 추가하는 방법은 무엇입니까? (0) | 2021.01.18 |
---|---|
대상 주체 이름이 올바르지 않습니다. (0) | 2021.01.18 |
PHP는 null이 0이라고 간주합니다. (0) | 2021.01.18 |
동일한 모델에서 여러 연결이있는 Rails Polymorphic Association (0) | 2021.01.18 |
두 개의 IEnumerable을 동시에 반복하는 방법은 무엇입니까? (0) | 2021.01.18 |