2012年12月17日

Threano scan function

http://deeplearning.net/software/theano/tutorial/loop.html 使用scan的好處在連結裡面已交代清楚,但是用法仍需再解釋清楚。首先從第一個範例開始,下面兩個程式都是計算A的k次方之值:
result = 1
for i in xrange(k):
    result = result * A

import theano
import theano.tensor as T
theano.config.warn.subtensor_merge_bug = False

k = T.iscalar("k")
A = T.vector("A")

def inner_fct(prior_result, A):
    return prior_result * A

# Symbolic description of the result
result, updates = theano.scan(fn=inner_fct,
                            outputs_info=T.ones_like(A),
                            non_sequences=A, n_steps=k)

'''
Scan has provided us with A ** 1 through A ** k.  
Keep only the last value. Scan notices this and 
does not waste memory saving them.
'''
final_result = result[-1]

power = theano.function(inputs=[A, k], outputs=final_result,
                      updates=updates)

print power(range(10),2)
#[  0.   1.   4.   9.  16.  25.  36.  49.  64.  81.]

  • scan當中的fn為所要執行的函數,也可使用lambda的方式來定義。
  • 第二個param outputs_info設定為大小與 A相同的矩陣,且矩陣內之值全部為1。
  • non_sequences為在scan當中不會變動之值,在此A在整個loop當中均不會變化。
  • steps為所要執行次數。