// go-gpt-3-encoder/encoder.go
package gpt3encoder
import (
"bytes"
"embed"
"encoding/json"
"math"
"strings"
"sync"
"github.com/dlclark/regexp2"
"github.com/samber/lo"
)
// files embeds the GPT-2 vocabulary assets shipped with this package:
// encoder.json (BPE token -> id map) and vocab.bpe (ordered merge rules).
//
//go:embed encoder.json
//go:embed vocab.bpe
var files embed.FS
var pat = regexp2.MustCompile(`/'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/`, 0)
// loadFiles reads the two embedded vocabulary assets and returns them as
// (vocab.bpe contents, encoder.json contents, error).
func loadFiles() ([]byte, []byte, error) {
	vocab, err := files.ReadFile("vocab.bpe")
	if err != nil {
		return nil, nil, err
	}
	enc, err := files.ReadFile("encoder.json")
	if err != nil {
		return nil, nil, err
	}
	return vocab, enc, nil
}
// bytesToUnicode builds the GPT-2 byte-to-unicode table: every byte value
// 0..255 maps to a printable unicode character, so token strings never
// contain whitespace or control bytes. Printable bytes ('!'..'~', '¡'..'¬',
// '®'..'ÿ') map to themselves; the remaining bytes are shifted, in order,
// to code points 256, 257, ... (Ā, ā, ...).
func bytesToUnicode() map[rune]string {
	table := make(map[rune]string, 256)
	shifted := 0
	for b := 0; b < 256; b++ {
		printable := (b >= '!' && b <= '~') ||
			(b >= 0xA1 && b <= 0xAC) ||
			(b >= 0xAE && b <= 0xFF)
		if printable {
			table[rune(b)] = string(rune(b))
		} else {
			// Non-printable bytes are remapped to U+0100 + running index.
			table[rune(b)] = string(rune(256 + shifted))
			shifted++
		}
	}
	return table
}
// getPairs returns the set (order-preserving, deduplicated) of adjacent
// symbol pairs in word. A word of length < 2 yields an empty slice.
func getPairs(word []string) []lo.Tuple2[string, string] {
	pairs := []lo.Tuple2[string, string]{}
	prev := word[0]
	for i := 1; i < len(word); i++ {
		pairs = append(pairs, lo.T2(prev, word[i]))
		prev = word[i]
	}
	return lo.Uniq(pairs)
}
// Encoder performs GPT-2/GPT-3 byte-pair encoding and decoding. The only
// mutable state is the BPE memoization cache, a sync.Map, so a single
// Encoder may be shared across goroutines.
type Encoder struct {
	bpeRank     map[lo.Tuple2[string, string]]int // merge pair -> rank (lower merges first)
	encoder     map[string]int                    // BPE token string -> token id
	decoder     map[int]string                    // token id -> BPE token string (inverse of encoder)
	byteEncoder map[rune]string                   // raw byte value (0..255) -> printable unicode char
	byteDecoder map[string]rune                   // printable unicode char -> raw byte value
	cache       sync.Map                          // word -> space-joined BPE result (memoizes bpe)
}
// NewEncoder builds an Encoder from the vocabulary files embedded in this
// package (vocab.bpe and encoder.json).
func NewEncoder() (*Encoder, error) {
	vocab, encJSON, err := loadFiles()
	if err != nil {
		return nil, err
	}
	return NewEncoderWithVocab(vocab, encJSON)
}
// NewEncoderWithVocab constructs an Encoder from caller-supplied data:
// bpeData is the contents of a vocab.bpe merge file (first line is a header
// comment, trailing newline expected) and jsonEncoder is the token->id JSON
// object.
func NewEncoderWithVocab(bpeData []byte, jsonEncoder []byte) (*Encoder, error) {
	tokenToID := map[string]int{}
	if err := json.Unmarshal(jsonEncoder, &tokenToID); err != nil {
		return nil, err
	}

	// Skip the header (first line) and the empty final line; every remaining
	// line is "left right".
	lines := bytes.Split(bpeData, []byte{'\n'})
	merges := make([]lo.Tuple2[[]byte, []byte], 0, len(lines))
	for _, line := range lines[1 : len(lines)-1] {
		halves := bytes.SplitN(line, []byte{' '}, 2)
		merges = append(merges, lo.T2(halves[0], halves[1]))
	}

	b2u := bytesToUnicode()
	// translated to a tuple of string, because []byte is not comparable (for maps)
	return &Encoder{
		bpeRank:     dictTuple(merges),
		encoder:     tokenToID,
		decoder:     lo.Invert(tokenToID),
		byteEncoder: b2u,
		byteDecoder: lo.Invert(b2u),
		cache:       sync.Map{},
	}, nil
}
// cachedBpe memoizes e.bpe: concurrent callers may both compute the value,
// but the result is deterministic so a duplicate Store is harmless.
func (e *Encoder) cachedBpe(token string) string {
	if hit, ok := e.cache.Load(token); ok {
		return hit.(string)
	}
	result := e.bpe(token)
	e.cache.Store(token, result)
	return result
}
// bpe runs the byte-pair-encoding merge loop on a single pre-tokenized word
// (already mapped through byteEncoder) and returns its final sub-tokens
// joined by single spaces.
//
// Mirrors the GPT-2 reference algorithm: start from individual characters,
// then repeatedly merge every adjacent occurrence of the lowest-ranked pair
// until no rankable pair remains or the word collapses to one symbol.
func (e *Encoder) bpe(token string) string {
	word := lo.ChunkString(token, 1)
	pairs := getPairs(word)
	if len(pairs) == 0 {
		// Zero- or one-character token: nothing to merge.
		return token
	}
	for {
		// Rank every adjacent pair; pairs absent from the merge table get
		// MaxInt so they can never win the min selection below.
		minPairs := lo.Map(pairs, func(item lo.Tuple2[string, string], index int) lo.Tuple3[int, string, string] {
			pair := lo.T2(item.A, item.B)
			rank, ok := e.bpeRank[pair]
			if !ok {
				rank = math.MaxInt
			}
			return lo.T3(rank, item.A, item.B)
		})
		bigram := lo.MinBy(minPairs, func(a lo.Tuple3[int, string, string], b lo.Tuple3[int, string, string]) bool {
			return a.A < b.A
		})
		first := bigram.B
		second := bigram.C
		if _, ok := e.bpeRank[lo.T2(first, second)]; !ok {
			// Even the best remaining pair is unmergeable: done.
			break
		}
		// Rebuild the word, replacing each adjacent (first, second)
		// occurrence with the merged symbol first+second.
		newWord := []string{}
		for i := 0; i < len(word); {
			_, j, ok := lo.FindIndexOf(word[i:], func(str string) bool { return str == first })
			if !ok {
				// No further occurrence of `first`: keep the tail verbatim.
				newWord = append(newWord, word[i:]...)
				break
			}
			j += i
			// Copy everything before the occurrence unchanged.
			newWord = append(newWord, word[i:j]...)
			i = j
			if word[i] == first && i < len(word)-1 && word[i+1] == second {
				newWord = append(newWord, first+second)
				i = i + 2
			} else {
				// `first` not followed by `second` here: keep it as-is.
				newWord = append(newWord, word[i])
				i = i + 1
			}
		}
		word = newWord
		if len(word) == 1 {
			break
		}
		pairs = getPairs(word)
	}
	return strings.Join(word, " ")
}
// splitToken applies the GPT-2 pre-tokenization regex to token and returns
// every match in order. Errors from the regex engine are returned as-is; a
// text with no matches yields a nil slice.
func (e *Encoder) splitToken(token string) ([]string, error) {
	var parts []string
	m, err := pat.FindStringMatch(token)
	for {
		if err != nil {
			return nil, err
		}
		if m == nil {
			return parts, nil
		}
		parts = append(parts, m.String())
		m, err = pat.FindNextMatch(m)
	}
}
// Encode converts text into GPT-2/GPT-3 BPE token ids: the text is split
// with the pre-tokenization regex, each piece is mapped byte-by-byte through
// the byte encoder, BPE-merged, and each resulting sub-token is looked up in
// the vocabulary. Sub-tokens missing from the vocabulary map to id 0.
func (e *Encoder) Encode(text string) ([]int, error) {
	matches, err := e.splitToken(text)
	if err != nil {
		return nil, err
	}
	bpeTokens := []int{}
	for _, match := range matches {
		// The reference implementation maps the UTF-8 *bytes* of each piece
		// through byteEncoder, not its runes: byteEncoder only covers values
		// 0..255, so iterating runes silently dropped every code point
		// >= U+0100 (any non-Latin-1 text).
		var sb strings.Builder
		for i := 0; i < len(match); i++ {
			sb.WriteString(e.byteEncoder[rune(match[i])])
		}
		bpe := e.cachedBpe(sb.String())
		for _, t := range strings.Split(bpe, " ") {
			bpeTokens = append(bpeTokens, e.encoder[t])
		}
	}
	return bpeTokens, nil
}
// Decode is the inverse of Encode: token ids are mapped back to their BPE
// token strings, and each character of the concatenation is mapped through
// byteDecoder to recover the original UTF-8 byte stream.
func (e *Encoder) Decode(tokens []int) string {
	var sb strings.Builder
	for _, token := range tokens {
		sb.WriteString(e.decoder[token])
	}
	merged := sb.String()
	// byteDecoder yields raw byte values (0..255). They must be collected as
	// *bytes* and decoded once: converting each value to a rune re-encoded
	// every byte >= 0x80 as two UTF-8 bytes, garbling non-ASCII text.
	out := make([]byte, 0, len(merged))
	for _, ch := range merged {
		out = append(out, byte(e.byteDecoder[string(ch)]))
	}
	return string(out)
}