493 lines
7.7 KiB
Go
493 lines
7.7 KiB
Go
package gpt3encoder
|
||
|
||
import (
|
||
"bytes"
|
||
"embed"
|
||
"encoding/json"
|
||
"math"
|
||
"strings"
|
||
"sync"
|
||
|
||
"github.com/dlclark/regexp2"
|
||
"github.com/samber/lo"
|
||
)
|
||
|
||
//go:embed encoder.json
|
||
//go:embed vocab.bpe
|
||
var files embed.FS
|
||
|
||
var pat = regexp2.MustCompile(`/'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/`, 0)
|
||
|
||
func loadFiles() (bpeData []byte, encoder []byte, err error) {
|
||
bpeData, err = files.ReadFile("vocab.bpe")
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
encoder, err = files.ReadFile("encoder.json")
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
// hardcoded
|
||
func bytesToUnicode() map[rune]string {
|
||
return map[rune]string{
|
||
0: "Ā",
|
||
1: "ā",
|
||
2: "Ă",
|
||
3: "ă",
|
||
4: "Ą",
|
||
5: "ą",
|
||
6: "Ć",
|
||
7: "ć",
|
||
8: "Ĉ",
|
||
9: "ĉ",
|
||
10: "Ċ",
|
||
11: "ċ",
|
||
12: "Č",
|
||
13: "č",
|
||
14: "Ď",
|
||
15: "ď",
|
||
16: "Đ",
|
||
17: "đ",
|
||
18: "Ē",
|
||
19: "ē",
|
||
20: "Ĕ",
|
||
21: "ĕ",
|
||
22: "Ė",
|
||
23: "ė",
|
||
24: "Ę",
|
||
25: "ę",
|
||
26: "Ě",
|
||
27: "ě",
|
||
28: "Ĝ",
|
||
29: "ĝ",
|
||
30: "Ğ",
|
||
31: "ğ",
|
||
32: "Ġ",
|
||
33: "!",
|
||
34: "\"",
|
||
35: "#",
|
||
36: "$",
|
||
37: "%",
|
||
38: "&",
|
||
39: "'",
|
||
40: "(",
|
||
41: ")",
|
||
42: "*",
|
||
43: "+",
|
||
44: ",",
|
||
45: "-",
|
||
46: ".",
|
||
47: "/",
|
||
48: "0",
|
||
49: "1",
|
||
50: "2",
|
||
51: "3",
|
||
52: "4",
|
||
53: "5",
|
||
54: "6",
|
||
55: "7",
|
||
56: "8",
|
||
57: "9",
|
||
58: ":",
|
||
59: ";",
|
||
60: "<",
|
||
61: "=",
|
||
62: ">",
|
||
63: "?",
|
||
64: "@",
|
||
65: "A",
|
||
66: "B",
|
||
67: "C",
|
||
68: "D",
|
||
69: "E",
|
||
70: "F",
|
||
71: "G",
|
||
72: "H",
|
||
73: "I",
|
||
74: "J",
|
||
75: "K",
|
||
76: "L",
|
||
77: "M",
|
||
78: "N",
|
||
79: "O",
|
||
80: "P",
|
||
81: "Q",
|
||
82: "R",
|
||
83: "S",
|
||
84: "T",
|
||
85: "U",
|
||
86: "V",
|
||
87: "W",
|
||
88: "X",
|
||
89: "Y",
|
||
90: "Z",
|
||
91: "[",
|
||
92: "\\",
|
||
93: "]",
|
||
94: "^",
|
||
95: "_",
|
||
96: "`",
|
||
97: "a",
|
||
98: "b",
|
||
99: "c",
|
||
100: "d",
|
||
101: "e",
|
||
102: "f",
|
||
103: "g",
|
||
104: "h",
|
||
105: "i",
|
||
106: "j",
|
||
107: "k",
|
||
108: "l",
|
||
109: "m",
|
||
110: "n",
|
||
111: "o",
|
||
112: "p",
|
||
113: "q",
|
||
114: "r",
|
||
115: "s",
|
||
116: "t",
|
||
117: "u",
|
||
118: "v",
|
||
119: "w",
|
||
120: "x",
|
||
121: "y",
|
||
122: "z",
|
||
123: "{",
|
||
124: "|",
|
||
125: "}",
|
||
126: "~",
|
||
127: "ġ",
|
||
128: "Ģ",
|
||
129: "ģ",
|
||
130: "Ĥ",
|
||
131: "ĥ",
|
||
132: "Ħ",
|
||
133: "ħ",
|
||
134: "Ĩ",
|
||
135: "ĩ",
|
||
136: "Ī",
|
||
137: "ī",
|
||
138: "Ĭ",
|
||
139: "ĭ",
|
||
140: "Į",
|
||
141: "į",
|
||
142: "İ",
|
||
143: "ı",
|
||
144: "IJ",
|
||
145: "ij",
|
||
146: "Ĵ",
|
||
147: "ĵ",
|
||
148: "Ķ",
|
||
149: "ķ",
|
||
150: "ĸ",
|
||
151: "Ĺ",
|
||
152: "ĺ",
|
||
153: "Ļ",
|
||
154: "ļ",
|
||
155: "Ľ",
|
||
156: "ľ",
|
||
157: "Ŀ",
|
||
158: "ŀ",
|
||
159: "Ł",
|
||
160: "ł",
|
||
161: "¡",
|
||
162: "¢",
|
||
163: "£",
|
||
164: "¤",
|
||
165: "¥",
|
||
166: "¦",
|
||
167: "§",
|
||
168: "¨",
|
||
169: "©",
|
||
170: "ª",
|
||
171: "«",
|
||
172: "¬",
|
||
173: "Ń",
|
||
174: "®",
|
||
175: "¯",
|
||
176: "°",
|
||
177: "±",
|
||
178: "²",
|
||
179: "³",
|
||
180: "´",
|
||
181: "µ",
|
||
182: "¶",
|
||
183: "·",
|
||
184: "¸",
|
||
185: "¹",
|
||
186: "º",
|
||
187: "»",
|
||
188: "¼",
|
||
189: "½",
|
||
190: "¾",
|
||
191: "¿",
|
||
192: "À",
|
||
193: "Á",
|
||
194: "Â",
|
||
195: "Ã",
|
||
196: "Ä",
|
||
197: "Å",
|
||
198: "Æ",
|
||
199: "Ç",
|
||
200: "È",
|
||
201: "É",
|
||
202: "Ê",
|
||
203: "Ë",
|
||
204: "Ì",
|
||
205: "Í",
|
||
206: "Î",
|
||
207: "Ï",
|
||
208: "Ð",
|
||
209: "Ñ",
|
||
210: "Ò",
|
||
211: "Ó",
|
||
212: "Ô",
|
||
213: "Õ",
|
||
214: "Ö",
|
||
215: "×",
|
||
216: "Ø",
|
||
217: "Ù",
|
||
218: "Ú",
|
||
219: "Û",
|
||
220: "Ü",
|
||
221: "Ý",
|
||
222: "Þ",
|
||
223: "ß",
|
||
224: "à",
|
||
225: "á",
|
||
226: "â",
|
||
227: "ã",
|
||
228: "ä",
|
||
229: "å",
|
||
230: "æ",
|
||
231: "ç",
|
||
232: "è",
|
||
233: "é",
|
||
234: "ê",
|
||
235: "ë",
|
||
236: "ì",
|
||
237: "í",
|
||
238: "î",
|
||
239: "ï",
|
||
240: "ð",
|
||
241: "ñ",
|
||
242: "ò",
|
||
243: "ó",
|
||
244: "ô",
|
||
245: "õ",
|
||
246: "ö",
|
||
247: "÷",
|
||
248: "ø",
|
||
249: "ù",
|
||
250: "ú",
|
||
251: "û",
|
||
252: "ü",
|
||
253: "ý",
|
||
254: "þ",
|
||
255: "ÿ",
|
||
}
|
||
}
|
||
|
||
func getPairs(word []string) []lo.Tuple2[string, string] {
|
||
all := []lo.Tuple2[string, string]{}
|
||
prevWord := word[0]
|
||
for _, char := range word[1:] {
|
||
all = append(all, lo.T2(prevWord, char))
|
||
prevWord = char
|
||
}
|
||
|
||
return lo.Uniq(all)
|
||
}
|
||
|
||
type Encoder struct {
|
||
bpeRank map[lo.Tuple2[string, string]]int
|
||
encoder map[string]int
|
||
decoder map[int]string
|
||
byteEncoder map[rune]string
|
||
byteDecoder map[string]rune
|
||
|
||
cache sync.Map
|
||
}
|
||
|
||
func NewEncoder() (*Encoder, error) {
|
||
bpeData, rawEncoder, err := loadFiles()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return NewEncoderWithVocab(bpeData, rawEncoder)
|
||
}
|
||
|
||
func NewEncoderWithVocab(bpeData []byte, jsonEncoder []byte) (*Encoder, error) {
|
||
encoder := map[string]int{}
|
||
|
||
err := json.Unmarshal(jsonEncoder, &encoder)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
decoder := lo.Invert(encoder)
|
||
|
||
bpeLines := bytes.Split(bpeData, []byte{'\n'})
|
||
bpeMerges := lo.Map(bpeLines[1:len(bpeLines)-1], func(line []byte, _ int) lo.Tuple2[[]byte, []byte] {
|
||
parts := bytes.SplitN(line, []byte{' '}, 2)
|
||
return lo.T2(parts[0], parts[1])
|
||
})
|
||
|
||
byteEncoder := bytesToUnicode()
|
||
byteDecoder := lo.Invert(byteEncoder)
|
||
|
||
// translated to a tuple of string, because []byte is not comparable (for maps)
|
||
bpeRank := dictTuple(bpeMerges)
|
||
|
||
enc := Encoder{
|
||
bpeRank: bpeRank,
|
||
encoder: encoder,
|
||
decoder: decoder,
|
||
byteEncoder: byteEncoder,
|
||
byteDecoder: byteDecoder,
|
||
|
||
cache: sync.Map{},
|
||
}
|
||
|
||
return &enc, nil
|
||
}
|
||
|
||
func (e *Encoder) cachedBpe(token string) string {
|
||
cached, ok := e.cache.Load(token)
|
||
if ok {
|
||
return cached.(string)
|
||
}
|
||
|
||
output := e.bpe(token)
|
||
e.cache.Store(token, output)
|
||
return output
|
||
}
|
||
|
||
func (e *Encoder) bpe(token string) string {
|
||
word := lo.ChunkString(token, 1)
|
||
|
||
pairs := getPairs(word)
|
||
if len(pairs) == 0 {
|
||
return token
|
||
}
|
||
|
||
for {
|
||
minPairs := lo.Map(pairs, func(item lo.Tuple2[string, string], index int) lo.Tuple3[int, string, string] {
|
||
pair := lo.T2(item.A, item.B)
|
||
|
||
rank, ok := e.bpeRank[pair]
|
||
if !ok {
|
||
rank = math.MaxInt
|
||
}
|
||
|
||
return lo.T3(rank, item.A, item.B)
|
||
})
|
||
|
||
bigram := lo.MinBy(minPairs, func(a lo.Tuple3[int, string, string], b lo.Tuple3[int, string, string]) bool {
|
||
return a.A < b.A
|
||
})
|
||
|
||
first := bigram.B
|
||
second := bigram.C
|
||
|
||
if _, ok := e.bpeRank[lo.T2(first, second)]; !ok {
|
||
break
|
||
}
|
||
|
||
newWord := []string{}
|
||
for i := 0; i < len(word); {
|
||
_, j, ok := lo.FindIndexOf(word[i:], func(str string) bool { return str == first })
|
||
if !ok {
|
||
newWord = append(newWord, word[i:]...)
|
||
break
|
||
}
|
||
|
||
j += i
|
||
|
||
newWord = append(newWord, word[i:j]...)
|
||
i = j
|
||
|
||
if word[i] == first && i < len(word)-1 && word[i+1] == second {
|
||
newWord = append(newWord, first+second)
|
||
i = i + 2
|
||
} else {
|
||
newWord = append(newWord, word[i])
|
||
i = i + 1
|
||
}
|
||
}
|
||
|
||
word = newWord
|
||
if len(word) == 1 {
|
||
break
|
||
}
|
||
|
||
pairs = getPairs(word)
|
||
}
|
||
|
||
return strings.Join(word, " ")
|
||
}
|
||
|
||
func (e *Encoder) splitToken(token string) ([]string, error) {
|
||
var matches []string
|
||
|
||
m, err := pat.FindStringMatch(token)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
for m != nil {
|
||
matches = append(matches, m.String())
|
||
|
||
m, err = pat.FindNextMatch(m)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
return matches, nil
|
||
}
|
||
|
||
func (e *Encoder) Encode(text string) ([]int, error) {
|
||
bpeTokens := []int{}
|
||
|
||
matches, err := e.splitToken(text)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
for _, match := range matches {
|
||
runes := []rune(match)
|
||
|
||
token := strings.Join(lo.Map(runes, func(item rune, _ int) string {
|
||
return e.byteEncoder[item]
|
||
}), "")
|
||
|
||
bpe := e.cachedBpe(token)
|
||
for _, t := range strings.Split(bpe, " ") {
|
||
bpeTokens = append(bpeTokens, e.encoder[t])
|
||
}
|
||
}
|
||
|
||
return bpeTokens, nil
|
||
}
|
||
|
||
func (e *Encoder) Decode(tokens []int) string {
|
||
parts := lo.Map(tokens, func(token int, _ int) string {
|
||
return e.decoder[token]
|
||
})
|
||
|
||
parts = lo.ChunkString(strings.Join(parts, ""), 1)
|
||
|
||
text := lo.Map(parts, func(item string, _ int) rune {
|
||
return e.byteDecoder[item]
|
||
})
|
||
|
||
return string(text)
|
||
}
|