อธิบาย Transformer : Attention is all you need
บทความนี้จะมาอธิบาย paper Attention Is All You Need ครับ
ก่อนจะเข้าที่ต้ว paper โดยตรงจะต้องมาอธิบายคร่าวๆก่อนว่าทำไมถึงได้มี paper นี้ออกมาครับ การแปลภาษาจากภาษานึง ไปอีกภาษานึงเป็นงานที่สามารถใช้ ML เข้ามาแก้ปัญหาได้
โดยก่อนหน้านี้ส่วนใหญ่จะใช้ RNN ในรูปเเบบ Encoder-Decoder architecture ครับ โดยจะใช้ RNN ในฝั่ง Encoder เปลี่ยนคำศัพท์ไปเรื่อยๆจนในทีสุดได้ 1 vector ที่จะเป็นตัว representation ของทั้งประโยคที่เราจะแปล จากนั้นก็ส่ง vector นี้ให้ Decoder แปลเป็นภาษาที่เราต้องการครับ หากอยากให้ model ได้ผลดีขึ้นก็สามารถทำได้หลากหลายท่า ไม่ว่าจะใช้ Deep RNN, LSTM, GRU, word embedding, attention ก็ได้ครับ
model เก่าก็มีอยู่เเล้ว ทำไมต้องออก model ใหม่มาด้วย?
ปัญหาหลักๆจาก model เก่ามีดังนี้ครับ
- RNN ไม่ว่าจะ vanilla rnn,lstm,gru นั้นเป็นเเบบ weight-tied หรือก็คือทุกคำใช้ weight เดียวกันหมด ทำให้ model ไม่ได้parameter หลากหลายพอที่จะเข้าใจภาษาได้ทั้งหมด
- Encoder-Decoder architecture นั้นไม่สามารถแปลภาษาเเบบ parallel ได้ เช่นหากผมอยากจะแปลคำว่า it ที่เป็นคำที่ 22 ในประโยค ผมก็ต้องให้ RNN อ่านตั้งเเต่คำเเรกยันคำที่ 21 ก่อนถึงจะแปลได้ ซึ่งช้ามากครับ ไม่ parallel ; นอกจากนั้นยังอาจจะทำให้ context ของคำว่า it หายไปด้วย เพราะ rnn ได้ encode มาเเล้วหลายคำเเล้วจนอาจจะลืมได้
- Encoder-Decoder architecture นั้นจะเปลี่ยนทั้งประโยคเป็นเพียง 1 vector ตรงนี้อาจจะไม่มีปัญหาสำหรับประโยคสั้นๆ เเต่ถ้าประโยคที่ยาวๆเเล้ว มันยากมากๆที่เราจะบีบความหมายของทั้งประโยคให้อยู่ใน vector เดียว
จากปัญหาเหล่านี้จึงมี idea ที่จะเอา RNN ใน Encoder-Decoder architecture ออกเเล้วใช้เเค่ attention mechanism เเทนครับ
เเล้ว attention mechanism คืออะไร?
attention ถูกนำเสนอครั้งเเรกใน paper นี้ครับ ดังนั้นผมจึงขออธิบาย paper นี้อย่างรวดเร็วซะก่อน ใน paper ยังคงเป็นงานแปลภาษาเเละใช้RNNครับ โดยจะนำคำในภาษาที่อยากแปลมาเข้า model ที่เป็น BRNN-LSTM ทำให้ได้ h-left ซึ่งเป็น vector จาก encoder ของคำเวลาอ่านจากซ้ายไปขวา เเละได้ h-right ซึ่งเป็นvector จาก encoder ของคำเวลาอ่านจากขวามาซ้ายครับ
จากนั้นนำทั้งสอง vector มาต่อกัน แล้วใช้ RNN ฝั่งแปลภาษาสร้างคำในอีกภาษาทีละคำครับ สิ่งที่เพิ่มเข้ามาคือเวลาจะแปล ตัวแปลภาษาจะไม่ได้ดูเเค่ vector เดียว เเต่จะดูมันทุก vector เเล้วจะใช้น้ำหนักในเเต่ละ vector ไม่เท่ากันครับ ตัวน้ำหนักนี้เองที่เรียกว่า attention score ดังนั้นจะไม่ต้องมากังวลว่าจะใช้บีบทั้งประโยคเข้าไปที่ 1 vector เเบบ encoder-decoder architecture ยังไง ใครที่สนใจสามารถอ่าน paper หรือไม่ก็ดู vdo นี้ก็ได้ครับ
Deeplearning.ai: https://www.youtube.com/watch?v=SysgYptB198
2110594 NLP(Chula) : https://www.youtube.com/watch?v=5SH9bpQ33Xk)
หลักจาก paper นี้ออกมาก็มีการพัฒนาตัว attention ไปมากมายครับเเละขอสรุปที่ต้องใช้เพื่ออธิบาย tranformer เป็นเเค่ 4 รูปเเบบนี้ครับ
Additive attention : เป็นการคิด attention score โดยนำคำที่แปลไปก่อนหน้า เเละ encoderของคำ มาต่อกัน ->คูณweight -> ผ่าน tanh เพื่อให้ได้ attention score สำหรับคำนั้นๆออกมาครับ
Multiplicative attention : เหมือน additive attention เลยครับ เเต่ว่าไม่ได้ผ่าน activation function ซึ่งจะทำให้คำนวณไวกว่าเเต่ก็ผลการทดลองก็ไม่ได้ดีไปกว่าadditive ครับ (ต้องเลือกว่าจะคำนวณเร็ว หรือ เเม่นยำ)
Self-attention : อันนี้ถูกเสนอใน paperนี้ครับ; เนื่องจากจะเห็นว่าที่ผ่านมา attention ถูกใช้กับ model แปลภาษาหมดเลย เเล้วถ้าไม่ได้แปลภาษาละ จะเอาไปใช้ยังไง? คำถามนี้จึงเป็นจุดกำเนิดของ self-attention ที่ผู้เขียนได้ใช้ attention เพื่อ classification task จากประโยค เช่น comment นี้คนเขียนน่าจะให้รีวิวสักกี่ดาวเป็นต้น โดยผมจะอธิบายสั้นๆไว้เหมือนเดิม หากใครต้องการอ่านเพิ่มเติมก็กดที่ link ได้เลยครับ; ในงาน classification นั้นผู้เขียนจะใช้ BRNN-LSTM ในการ encode เหมือนกันครับ จากนั้นก็นำ encoded-vector จากฝั่งซ้ายเเละขวามาต่อกันจนได้ vector H ออกมา จากนั้นนำ H นี้ไปคูณ weight เเล้วผ่าน tanh เเล้วจึงนำไปคูณ weight อีกรอบเเละผ่าน softmax เเละจะได้ attention score ของเเต่ละคำมาครับ นำ score นี้ไปคูณกับ H อีกรอบจะได้ embedding ของทั้งประโยคออกมาเเละเข้า Feed forward network เพื่อเดาคะเเนนเป็นอันจบครับ จะเห็นว่าไม่ต้องมี decoder ก็ใช้ attention ได้
key-value attention: ที่ผ่านมา attention score ก็คิดคล้ายๆกันหมดคือเอา hidden-state มาคูณ weight เเล้วเข้า softmax เพื่อให้ได้ attention score ออกมา เเต่ว่าการทำแบบนี้ก็เกิดปัญหาครับ เพราะว่า decoder ที่สร้าง attention score จะต้องทำหลายหน้าที่มาก ทั้งเดาคำต่อไป, สร้าง attention score เเละก็จำ context จากคำก่อนหน้า
จึงมี paper ที่อยากเเยกหน้าที่เหล่านี้ออกมาครับ (ผมจะอธิบายสั้นๆเหมือนเดิม อ่านเพิ่มเติมได้ที่นี่เลยครับ);จะมีการสร้าง key ขึ้นมาโดยนำ hidden state ของชั้นก่อนหน้า L คำมาคูณด้วย weightเเละบวก hidden ของคำปัจจุบันคูณด้วย weight คนละตัว เเล้วนำค่าที่บวกเเล้วเข้า tanh ครับ เมื่อผ่าน tanh มาเเล้วจะได้ vector M ออกมา เราจะเอา M ไปคูณ weight เเล้วเอาเข้า softmax อีกครั้งเพื่อสร้าง attention score ครับ; มาต่อที่ value ที่จะหา context-representation ใน window L ครับ โดยจะนำ attention score มาคูณด้วย value ของเเต่ละคำจนได้ vector r ออกมาครับ (value คือ hidden state ของคำก่อนหน้าครับ); ในขั้นสุดท้ายก็จะนำ vector r กับ value ของคำในปัจจุบันมาคิดเเบบ additive attention ครับ
หลังจากรู้จัก attention ไปเเล้ว เราจะไปกันต่อกับ transformer เลยครับ
โดยก่อนที่จะอ่านผมมี vdo เบื้องต้นไว้ให้ ซึ่งอธิบายไว้ได้ดีเลยทีเดียวครับ
stanford: https://www.youtube.com/watch?v=5vcj8kSwBCY
CodeEmporium : https://www.youtube.com/watch?v=TQQlZhbC5ps
Yannic Kilcher : https://www.youtube.com/watch?v=iDulhoQ2pro
สำหรับตัว transformer ให้ลองจินตนาการถึง encoder-decoder architecture ที่เอา rnn ออกไปหมดเลยครับ ก่อนอื่นมาลองรู้จักหน้าตาของมันกันก่อน
สมมติว่าผมอยากแปลจากภาษา eng ไปเป็น french นี่คือสิ่งที่ tranformer จะทำในเเต่ละ step ครับ
เริ่มจาก Encoder กันก่อน
- เราจะมี word embedding ของภาษา eng ในเเต่ละคำ ซึ่งเป็น pre-train มาครับ
- ผมจะนำ embedding นี้ไปบวกด้วย positional encoding ที่เป็นกราฟ sin,cos ที่ความถี่จะแปรผันไปตามมิติของ word embedding ในเเต่ละตำเเหน่งครับ เหตุผลที่ต้องบวกเข้าไปเพราะว่าเราเอา RNN ออกเเล้ว ทำให้ model ไม่สามารถรู้ได้ว่าคำไหนมาก่อนมาหลัง โดยสมการของกราฟนี้คือ
อาจจะมองภาพไม่ออกเเต่ลองคิดว่าถ้า word embedding เป็น vector, positional encoding ก็เป็นเเค่ vector ที่มีตัวเลขเเสดงความถี่ในเเต่ละตำเเหน่งมาบวกครับ ลองดูตัวอย่างด้านล่างดู
ตัวอย่างค่า positional embedding สำหรับ 20 คำเเรก โดยมิติของ word embedding คือ 512
ที่ใช้กราฟ sin-cos เพราะว่ากราฟพวกนี้มันสามารถสร้างได้เรื่อยๆ ไม่ว่าจำนวนคำจะมีกี่คำครับ ถ้าผมใช้เเค่ one-hot vector ที่มี 1 เเทนเเต่ละตำเเหน่งก็ยังสามารถใช้งานได้ เเต่ถ้างานที่เราจะนำไปใช้มีความยาวของประโยคมากกว่าที่ผม train model มาก็จะใช้ไม่ได้เลย ดังนั้นการใช้กราฟ sin-cos เป็นการการันตีได้ว่าไม่ว่าความยาวของประโยคที่จะแปลจะยาวเท่าไหร่ model เราก็ยังมี positional encoding ให้เสมอครับ
- หลังจากได้ embedding ที่มี positional encoding เเล้วเราจะเอา vector อันใหม่นี้ไปเข้า self-attention head โดย attention-head จะนำ vector z ไปคูณกับ W-q,W-v,W-k เพื่อสร้าง query,value,key ของเเต่ละคำออกมาครับ เเนวคิดของ query,value,key คือ
query จะเป็นตัวเเทนของ decoder-hidden-state
value เป็นตัวเเทนของ encoder-hidden-state ของคำอื่นๆ
key เป็นตัวเเทนของ encoder-hidden-state ของคำที่จะแปล
ผมจะเอา query ของคำที่จะแปลไปคูณกับ key ของคำอื่นๆทุกๆคำ เพื่อสร้าง score ออกมาครับ จากนั้นนำคำนั้นไปหารด้วย 1/sqrt(dk) ;โดย dk คือมิติของkey; จากนั้นนำทุกคะเเนนเข้า softmax จะได้ attention score ออกมาครับ
เราจะนำ attention score นี้ไปคูณกับ value ของเเต่ละคำ ถ้าคำไหน attention score เยอะ valueก็จะมีค่ามาก ถ้ามี attention score น้อย valueก็จะมีค่าน้อย เราจะนำทุก value ที่ผ่านการคูณของเเต่ละคำมาบวกกัน จนได้ vector z ที่เป็นตัวเเทนของคำนั้นๆ โดยคิดถึงอิทธิพลของคำอื่นๆในประโยคไว้เเล้วนั่นเองครับ
ในการทำงานจริง เราจะไม่ทำทีละ vector เเบบนี้นะครับ เพราะว่ามันช้า เราจะอัดมันเป็น matrix ของคำ เเล้วคูณ matrix นี้กับ weight ต่างๆ เพื่อให้ได้ attention score ของเเต่ละคำในคราวเดียวกันไปเลยครับ นอกจากนี้ transformer จะไม่ได้มี self-attention เเค่หัวเดียว เเต่ว่ามีถึง 8 หัวครับ หลักการของเเต่ละหัวจะเหมือนกันครับ คือนำ vector z ไปคูณกับ W-q,W-v,W-k เเต่ว่า weight ในเเต่ละหัวจะไม่เท่ากัน ทำให้เราได้ vector z ออกมา 8 ตัวที่ไม่เหมือนกัน…..
เเล้วจะเอา 8 ตัวนี้ไปใช้ยังไง?
ก่อนที่จะออกจาก multi head self-attention ตัว model จะเอา z ทั้ง 8 ตัวมาต่อกันยาวๆเเล้วคูณด้วย weight W-o เพื่อให้เหลือเเค่ z ตัวเดียว เเล้วจึงนำ z นั้นไปใช้ครับ
ที่ต้องมีหลายๆหัว เพราะว่าปกติเเล้ว self-attention มันก็จะ focus เเถวเเต่คำตัวเองนั่นเเหละครับ(ยกเว้นว่าจะมีการเพิ่ม penalty score เข้าไปที่ loss function — อ่านเพิ่มได้ที่ paper self-attention ครับ) ดังนั้นการมีหลายๆหัวก็ช่วย vary ค่า attention ให้ไปในคำอื่นๆบ้าง เเละนำค่าเหล่านั้นมารวมกันตอนท้าย
หลังจากได้มหากาฬ z ออกมาเเล้วจะนำ z นี้ไปเข้า residual network เเละ layer normalize ครับ ตัว residual network มีมาเพื่อไม่ให้ positional encoding หายไปครับหลังจากผ่านการ transform มาเเล้วครับ
จะเห็นว่า attention score ในส่วนที่มี residual จะยังคงเส้นเฉียงๆไว้ได้(แปลว่า focus ที่เเถวๆคำตัวเอง) เเต่ถ้าไม่มี residual จะมี attention score ที่มั่วมากๆ ส่วนการเพิ่ม timing signal ก็พอช่วยได้บ้าง เเต่ไม่เท่า residual ครับ
หลังจากนั้นก็จะเข้า feed forward network (ffn) ปกติ 1 รอบครับ ใน paper ไม่ได้อธิบายไว้ว่าเอาเข้าทำไม เเต่ผมเดาเล่นๆว่าเพื่อ non-linear เฉยๆ
พอออกจาก ffn ก็ผ่าน residual กับ layer normalize อีกรอบครับ จากนั้นเราจะได้ vector z ที่ผ่านร้อนผ่านหนาวมาบ้างเเล้ว เราก็จะให้มันไปผ่านร้อนผ่านหนาวอีก 5 รอบครับ คือเอามันเข้า encoder ใหม่อีก 5 รอบ สุทธิรวมเเล้วคือเข้า encoder ทั้งหมด 6 รอบครับ
หลังจากผ่านไป 6 รอบเราจะได้ vector z ที่พร้อมไปใช้งานต่อในขา Decoder เเล้วครับ
มาฝั่ง Decoder กันบ้าง
ฝั่ง decoder จะเหมือน encoder เลยครับ เเต่ว่าจะมี masked-self-attention เพิ่มเข้ามา ในส่วนของ masked-self-attention จะทำการ masked ไว้ด้วย โดยการ masked คือ หากผมอยากหา attention score ของคำที่สาม model จะเห็นเเค่คำก่อนหน้า หรือก็คือคำที่ 1,2 เท่านั้นครับ เพื่อไม่ให้ไปอิงกับคำทางขวาที่เป็นคำในอนาคต หลังจากได้ attention score ของเเต่ละคำมาเเล้วจะเอามาเข้า self-attention อีกรอบครับ เเต่คราวนี้จะใช้ key,query เป็น vector z ที่ได้มาจาก encoder ครับ
จากน้้นก็เหมือนเดิมเลย คือเข้า residual+layer normalization เเล้วจึงเข้า ffw เเล้ว residual+layer normalization อีกรอบ วนเเบบนี้ทั้งหมด 6 รอบ โดย key,query จะใช้จาก encoderเสมอครับ
หลังจากผ่าน decoder ทั้งหมดมาจะเข้า linear 1 รอบเเละเข้า softmax เพื่อหาคำแปลครับ
ก็จบเพียงเท่านี้สำหรับ transformer ที่ได้นำ rnn ออกเเล้วใช้เเค่ attention mechanism เพื่อแปลภาษาเเบบ parallel ได้ครับ
อ้างอิงเพิ่มเติม :