RRT-355M โมเดลภาษาที่ลอง 'ตัด softmax' ออกจากชั้น attention เพื่อให้อ่านบริบทยาว ๆ ได้โดยกิน VRAM น้อยลง
RRT-355M คือโมเดลภาษาทดลองขนาด GPT-2 ที่ลองถอด softmax ออกจากชั้น attention แล้วใช้การเปิด-ปิดเส้นเชื่อมระหว่างคำแทน จุดที่น่าจับตาคือมันออกแบบมาเพื่ออ่านข้อความยาว ๆ โดยกินแรมการ์ดจอน้อยลง

มีโมเดลภาษาตัวหนึ่งชื่อ RRT-355M ที่ลองแตะชิ้นส่วนซึ่งโมเดลภาษาเกือบทุกตัวใช้เหมือนกัน นั่นคือ softmax ในชั้น attention มันถอดชิ้นส่วนนี้ออก แล้วเปลี่ยนวิธีคิดใหม่ทั้งหมด RRT-355M ไม่ใช่โมเดลใหญ่ระดับ GPT หรือ Claude แต่เป็นงานทดลองขนาดเท่า GPT-2 Medium (ราว 354 ล้านพารามิเตอร์) ที่นักวิจัยอิสระคนหนึ่งเทรนขึ้นมาเอง เพื่อพิสูจน์แนวคิดเดียวว่า ถ้าตัดกลไกตัวนี้ทิ้ง โมเดลจะอ่านข้อความยาว ๆ ได้โดยกินแรมการ์ดจอ (VRAM) น้อยลงหรือเปล่า
เรื่องนี้น่าสนใจไม่ใช่เพราะคะแนนสอบของมันสูง (มันไม่สูง) แต่เพราะมันแตะจุดที่หลายคนคิดว่าแตะไม่ได้ ก่อนจะเข้าใจว่ามันแปลกตรงไหน ต้องรู้ก่อนว่า softmax ในโมเดลภาษาทำงานยังไง และทำไมมันถึงเป็นต้นทางที่ทำให้ข้อความยิ่งยาวยิ่งกินแรม
attention คือการลากเส้นจากทุกคำไปทุกคำ
หัวใจของโมเดลภาษาสมัยนี้คือกลไกที่ชื่อ attention หน้าที่ของมันคือดูว่าในประโยคหนึ่ง คำไหนควร "สนใจ" คำไหน เช่น เวลาโมเดลอ่านคำว่า "มัน" โมเดลต้องรู้ว่า "มัน" หมายถึงคำไหนที่มาก่อนหน้า attention จึงเป็นกลไกที่ช่วยลากเส้นความสัมพันธ์ระหว่างคำเหล่านี้
ปัญหาอยู่ที่ว่าโดยปกติ attention จะลากเส้นจาก ทุกคำไปหาทุกคำ ถ้าข้อความมี 10 คำ ก็มีเส้นความสัมพันธ์ราว 100 เส้น แต่ถ้าข้อความยาวขึ้นเป็น 1,000 คำ จำนวนเส้นจะพุ่งขึ้นเป็นหลักล้าน เพราะมันโตแบบกำลังสองตามความยาว ยิ่งข้อความยาว แรมการ์ดจอที่ต้องใช้เก็บเส้นพวกนี้ก็ยิ่งบานปลาย นี่คือเหตุผลที่การให้โมเดลอ่านเอกสารยาว ๆ ทั้งเล่มถึงแพงและกินทรัพยากรมหาศาล
softmax คือกฎที่บังคับให้น้ำหนักทุกเส้นรวมกันได้ 100%
แล้ว softmax เข้ามาเกี่ยวตรงไหน หลังจาก attention ให้คะแนนความสัมพันธ์ของแต่ละคู่คำแล้ว softmax คือขั้นตอนที่เอาคะแนนดิบเหล่านั้นมาเกลี่ยให้กลายเป็นสัดส่วน โดยมีกฎเหล็กข้อหนึ่งคือ ผลรวมของน้ำหนักทุกเส้นในแถวเดียวกันต้องเท่ากับ 100% เสมอ
ผลพวงของกฎข้อนี้คือทุกเส้นต้องได้น้ำหนักมากกว่าศูนย์เสมอ แม้แต่คู่คำที่แทบไม่เกี่ยวกันเลยก็ยังต้องได้ส่วนแบ่งเล็ก ๆ ติดไปด้วย พูดง่าย ๆ คือ softmax ไม่ยอมให้ตัดเส้นไหนทิ้งเป็นศูนย์สนิท เส้นที่ไม่จำเป็นจึงยังกินที่ในหน่วยความจำอยู่ดี
เปลี่ยนจากเกลี่ยน้ำหนักเป็นเปิด-ปิดเส้น

แทนที่จะใช้ softmax เกลี่ยน้ำหนักให้ทุกเส้น ชั้น attention ของ RRT-355M ใช้กลไกที่เรียกว่า deterministic gate ซึ่งทำงานแบบเปิด-ปิดเส้นแทน หลักการคิดเป็นลำดับสั้น ๆ ได้ว่า
- คำนวณคะแนนความเข้ากันได้ของคู่คำตามปกติ
- หักคะแนนตามระยะห่าง คู่คำที่อยู่ห่างกันมากจะโดนหักหนัก เพราะส่วนใหญ่มักไม่ได้เกี่ยวข้องกันจริง
- ถ้าคะแนนสุทธิไม่ผ่านเกณฑ์ ก็ ปิดเส้นนั้นทิ้ง ให้น้ำหนักเป็นศูนย์สนิท
- เส้นที่ผ่านเกณฑ์เท่านั้นที่จะเอาไปคิดต่อ
หัวใจอยู่ที่การหักคะแนนตามระยะห่างนี่เอง เพราะมันทำให้เส้นที่เชื่อมคำไกล ๆ ถูกตัดทิ้งไปเกือบหมด เหลือไว้แต่เส้นที่สำคัญจริง ผลที่ผู้สร้างวัดได้ตอนเทรนคือมีเส้นถูกตัดทิ้งถึง 99.66% หรือพูดอีกแบบคือ จากเส้นทั้งหมด เหลือเส้นที่ยัง "เปิด" อยู่จริงไม่ถึง 1% ช่องว่างตรงนี้แหละคือที่มาของการประหยัด เพราะเมื่อเส้นส่วนใหญ่เป็นศูนย์ ก็ไม่ต้องเสียแรมและกำลังประมวลผลไปกับมัน
เคอร์เนลที่กล้าข้ามช่องว่าง
รู้ว่าเส้นส่วนใหญ่เป็นศูนย์เป็นเรื่องหนึ่ง แต่การทำให้คอมพิวเตอร์ใช้ประโยชน์จากช่องว่างนั้นได้จริงเป็นอีกเรื่อง เพราะโดยปกติ การ์ดจอจะไล่คำนวณทุกช่องไม่ว่าจะมีค่าหรือไม่ ผู้สร้างจึงเขียนโปรแกรมสั่งงานการ์ดจอระดับล่างด้วย Triton (เครื่องมือสำหรับเขียนงานที่รันบน GPU) ให้ตรวจรอบแรกก่อนว่าบล็อกไหนไม่มีเส้นที่เปิดอยู่เลย แล้ว ข้ามบล็อกนั้นไปทั้งบล็อก ไม่ต้องเสียเวลาคำนวณ
ตัวเลขที่ผู้สร้างรายงานคือ ที่ความยาวข้อความ 2,048 token เคอร์เนลข้ามได้ 34% ของบล็อก และเมื่อข้อความยาวขึ้นเป็น 8,192 token ตัวเลขขยับขึ้นเป็น 55% ยิ่งข้อความยาว ช่องว่างยิ่งเยอะ การข้ามก็ยิ่งคุ้ม นี่คือเหตุผลที่แนวคิดนี้ถูกออกแบบมาเพื่องานบริบทยาวโดยเฉพาะ
เก่งน้อยลงเท่าไร แลกกับอะไร
คำถามที่สำคัญที่สุดของวิธีนี้คือ พอตัดเส้นทิ้งเกือบหมด โมเดลจะโง่ลงไหม ผู้สร้างเอา RRT-355M ไปวัดด้วยชุดทดสอบมาตรฐาน 22 งานที่ชื่อ CORE แล้วเทียบกับโมเดลขนาดใกล้กัน
| โมเดล | คะแนน CORE |
|---|---|
| GPT-2 124M (ฐานเทียบล่างสุด) | 0.1211 |
| RRT-355M (งานนี้) | 0.1558 |
| GPT-2 medium (ขนาดเท่ากัน) | 0.1770 |
| Pythia 410M (รุ่นใหม่กว่า) | 0.1895 |
ผลที่ออกมาคือ RRT-355M ทำได้ดีกว่า GPT-2 รุ่นเล็ก แต่ยังตามหลัง GPT-2 medium ที่ขนาดเท่ากันอยู่ราว 0.02 คะแนน ผู้สร้างเองเรียกสิ่งนี้ว่า "การแลกที่วัดได้ ไม่ใช่ความสามารถพังทลาย" ซึ่งตรงไปตรงมาดี บางงานมันทำได้ดีกว่า GPT-2 medium ด้วยซ้ำ เช่นโจทย์ถาม-ตอบเชิงเหตุผล แต่บางงานที่ต้องจำบริบทให้เป๊ะกลับแย่ลงชัดเจน เรื่องนี้พอเข้าใจได้ เพราะวิธีตัดเส้นไกลทิ้งย่อมทำให้โยงหาคำที่อยู่ห่างกันมากได้ยากขึ้น
นี่คือความซื่อสัตย์ของงานชิ้นนี้ มันไม่ได้บอกว่าตัด softmax แล้วได้ของฟรี แต่บอกว่ายอมเสียความแม่นนิดหน่อย เพื่อแลกกับการประหยัดแรมตอนอ่านข้อความยาว
ทำไมงานเล็ก ๆ ชิ้นนี้ถึงน่าจับตา
ต้องพูดให้ชัดก่อนว่า RRT-355M ไม่ใช่ของที่จะเอามาแทน GPT, Claude หรือ Gemini ได้ มันคืองานพิสูจน์แนวคิดของคนคนเดียว เทรนด้วยข้อมูลราว 11,500 ล้าน token บนการ์ด H100 สี่ใบ จบแล้วก็จบเลย ไม่มีแผนทำตัวใหญ่กว่านี้ต่อ คนในชุมชน r/MachineLearning เองก็ตั้งคำถามกับมันเยอะ โดยเฉพาะประเด็นที่ว่าในเมื่อ LM head ยังใช้ softmax อยู่ จะเรียกว่า softmax-free ได้เต็มปากแค่ไหน
แต่คุณค่าของมันไม่ได้อยู่ที่คะแนน มันอยู่ที่การกล้าตั้งคำถามกับชิ้นส่วนที่ทุกคนคิดว่าขยับไม่ได้ และลงมือพิสูจน์จนเห็นตัวเลขจริงว่าโมเดลที่ไม่บังคับให้น้ำหนักรวมกันเป็น 100% ก็ยังเรียนรู้ภาษาได้ ที่สำคัญคือมีของให้ลองจริง ทั้ง weights ที่เปิดให้โหลดและโค้ดที่เปิดทั้งหมด ใครอยากลองส่องว่ากลไกเปิด-ปิดเส้นนี้หน้าตาเป็นยังไง ก็เข้าไปโหลด weights ได้ที่ RRT-Foundation บน HuggingFace (เว็บศูนย์รวมโมเดลและ weights แบบเปิด) หรือดูโค้ดที่ RRT-LLM-FOUNDATION ได้เลย
คำถามที่ค้างไว้ให้คิดต่อคือ ถ้างานทดลองขนาดเล็กจากคนคนเดียวยังพิสูจน์ได้ว่าเส้นความสัมพันธ์ส่วนใหญ่ในโมเดลภาษาเป็นเส้นที่ตัดทิ้งได้ แล้วในโมเดลยักษ์ที่เราใช้กันทุกวัน ยังมีเส้นที่เปล่าประโยชน์ซ่อนอยู่อีกมากแค่ไหน
ที่มา: โพสต์ I released a softmax-free attention model at GPT-2 scale จาก r/MachineLearning



