493 lines
7.7 KiB
Go
493 lines
7.7 KiB
Go
|
package gpt3encoder
|
|||
|
|
|||
|
import (
|
|||
|
"bytes"
|
|||
|
"embed"
|
|||
|
"encoding/json"
|
|||
|
"math"
|
|||
|
"strings"
|
|||
|
"sync"
|
|||
|
|
|||
|
"github.com/dlclark/regexp2"
|
|||
|
"github.com/samber/lo"
|
|||
|
)
|
|||
|
|
|||
|
//go:embed encoder.json
|
|||
|
//go:embed vocab.bpe
|
|||
|
var files embed.FS
|
|||
|
|
|||
|
var pat = regexp2.MustCompile(`/'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/`, 0)
|
|||
|
|
|||
|
func loadFiles() (bpeData []byte, encoder []byte, err error) {
|
|||
|
bpeData, err = files.ReadFile("vocab.bpe")
|
|||
|
if err != nil {
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
encoder, err = files.ReadFile("encoder.json")
|
|||
|
if err != nil {
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
// hardcoded
|
|||
|
func bytesToUnicode() map[rune]string {
|
|||
|
return map[rune]string{
|
|||
|
0: "Ā",
|
|||
|
1: "ā",
|
|||
|
2: "Ă",
|
|||
|
3: "ă",
|
|||
|
4: "Ą",
|
|||
|
5: "ą",
|
|||
|
6: "Ć",
|
|||
|
7: "ć",
|
|||
|
8: "Ĉ",
|
|||
|
9: "ĉ",
|
|||
|
10: "Ċ",
|
|||
|
11: "ċ",
|
|||
|
12: "Č",
|
|||
|
13: "č",
|
|||
|
14: "Ď",
|
|||
|
15: "ď",
|
|||
|
16: "Đ",
|
|||
|
17: "đ",
|
|||
|
18: "Ē",
|
|||
|
19: "ē",
|
|||
|
20: "Ĕ",
|
|||
|
21: "ĕ",
|
|||
|
22: "Ė",
|
|||
|
23: "ė",
|
|||
|
24: "Ę",
|
|||
|
25: "ę",
|
|||
|
26: "Ě",
|
|||
|
27: "ě",
|
|||
|
28: "Ĝ",
|
|||
|
29: "ĝ",
|
|||
|
30: "Ğ",
|
|||
|
31: "ğ",
|
|||
|
32: "Ġ",
|
|||
|
33: "!",
|
|||
|
34: "\"",
|
|||
|
35: "#",
|
|||
|
36: "$",
|
|||
|
37: "%",
|
|||
|
38: "&",
|
|||
|
39: "'",
|
|||
|
40: "(",
|
|||
|
41: ")",
|
|||
|
42: "*",
|
|||
|
43: "+",
|
|||
|
44: ",",
|
|||
|
45: "-",
|
|||
|
46: ".",
|
|||
|
47: "/",
|
|||
|
48: "0",
|
|||
|
49: "1",
|
|||
|
50: "2",
|
|||
|
51: "3",
|
|||
|
52: "4",
|
|||
|
53: "5",
|
|||
|
54: "6",
|
|||
|
55: "7",
|
|||
|
56: "8",
|
|||
|
57: "9",
|
|||
|
58: ":",
|
|||
|
59: ";",
|
|||
|
60: "<",
|
|||
|
61: "=",
|
|||
|
62: ">",
|
|||
|
63: "?",
|
|||
|
64: "@",
|
|||
|
65: "A",
|
|||
|
66: "B",
|
|||
|
67: "C",
|
|||
|
68: "D",
|
|||
|
69: "E",
|
|||
|
70: "F",
|
|||
|
71: "G",
|
|||
|
72: "H",
|
|||
|
73: "I",
|
|||
|
74: "J",
|
|||
|
75: "K",
|
|||
|
76: "L",
|
|||
|
77: "M",
|
|||
|
78: "N",
|
|||
|
79: "O",
|
|||
|
80: "P",
|
|||
|
81: "Q",
|
|||
|
82: "R",
|
|||
|
83: "S",
|
|||
|
84: "T",
|
|||
|
85: "U",
|
|||
|
86: "V",
|
|||
|
87: "W",
|
|||
|
88: "X",
|
|||
|
89: "Y",
|
|||
|
90: "Z",
|
|||
|
91: "[",
|
|||
|
92: "\\",
|
|||
|
93: "]",
|
|||
|
94: "^",
|
|||
|
95: "_",
|
|||
|
96: "`",
|
|||
|
97: "a",
|
|||
|
98: "b",
|
|||
|
99: "c",
|
|||
|
100: "d",
|
|||
|
101: "e",
|
|||
|
102: "f",
|
|||
|
103: "g",
|
|||
|
104: "h",
|
|||
|
105: "i",
|
|||
|
106: "j",
|
|||
|
107: "k",
|
|||
|
108: "l",
|
|||
|
109: "m",
|
|||
|
110: "n",
|
|||
|
111: "o",
|
|||
|
112: "p",
|
|||
|
113: "q",
|
|||
|
114: "r",
|
|||
|
115: "s",
|
|||
|
116: "t",
|
|||
|
117: "u",
|
|||
|
118: "v",
|
|||
|
119: "w",
|
|||
|
120: "x",
|
|||
|
121: "y",
|
|||
|
122: "z",
|
|||
|
123: "{",
|
|||
|
124: "|",
|
|||
|
125: "}",
|
|||
|
126: "~",
|
|||
|
127: "ġ",
|
|||
|
128: "Ģ",
|
|||
|
129: "ģ",
|
|||
|
130: "Ĥ",
|
|||
|
131: "ĥ",
|
|||
|
132: "Ħ",
|
|||
|
133: "ħ",
|
|||
|
134: "Ĩ",
|
|||
|
135: "ĩ",
|
|||
|
136: "Ī",
|
|||
|
137: "ī",
|
|||
|
138: "Ĭ",
|
|||
|
139: "ĭ",
|
|||
|
140: "Į",
|
|||
|
141: "į",
|
|||
|
142: "İ",
|
|||
|
143: "ı",
|
|||
|
144: "IJ",
|
|||
|
145: "ij",
|
|||
|
146: "Ĵ",
|
|||
|
147: "ĵ",
|
|||
|
148: "Ķ",
|
|||
|
149: "ķ",
|
|||
|
150: "ĸ",
|
|||
|
151: "Ĺ",
|
|||
|
152: "ĺ",
|
|||
|
153: "Ļ",
|
|||
|
154: "ļ",
|
|||
|
155: "Ľ",
|
|||
|
156: "ľ",
|
|||
|
157: "Ŀ",
|
|||
|
158: "ŀ",
|
|||
|
159: "Ł",
|
|||
|
160: "ł",
|
|||
|
161: "¡",
|
|||
|
162: "¢",
|
|||
|
163: "£",
|
|||
|
164: "¤",
|
|||
|
165: "¥",
|
|||
|
166: "¦",
|
|||
|
167: "§",
|
|||
|
168: "¨",
|
|||
|
169: "©",
|
|||
|
170: "ª",
|
|||
|
171: "«",
|
|||
|
172: "¬",
|
|||
|
173: "Ń",
|
|||
|
174: "®",
|
|||
|
175: "¯",
|
|||
|
176: "°",
|
|||
|
177: "±",
|
|||
|
178: "²",
|
|||
|
179: "³",
|
|||
|
180: "´",
|
|||
|
181: "µ",
|
|||
|
182: "¶",
|
|||
|
183: "·",
|
|||
|
184: "¸",
|
|||
|
185: "¹",
|
|||
|
186: "º",
|
|||
|
187: "»",
|
|||
|
188: "¼",
|
|||
|
189: "½",
|
|||
|
190: "¾",
|
|||
|
191: "¿",
|
|||
|
192: "À",
|
|||
|
193: "Á",
|
|||
|
194: "Â",
|
|||
|
195: "Ã",
|
|||
|
196: "Ä",
|
|||
|
197: "Å",
|
|||
|
198: "Æ",
|
|||
|
199: "Ç",
|
|||
|
200: "È",
|
|||
|
201: "É",
|
|||
|
202: "Ê",
|
|||
|
203: "Ë",
|
|||
|
204: "Ì",
|
|||
|
205: "Í",
|
|||
|
206: "Î",
|
|||
|
207: "Ï",
|
|||
|
208: "Ð",
|
|||
|
209: "Ñ",
|
|||
|
210: "Ò",
|
|||
|
211: "Ó",
|
|||
|
212: "Ô",
|
|||
|
213: "Õ",
|
|||
|
214: "Ö",
|
|||
|
215: "×",
|
|||
|
216: "Ø",
|
|||
|
217: "Ù",
|
|||
|
218: "Ú",
|
|||
|
219: "Û",
|
|||
|
220: "Ü",
|
|||
|
221: "Ý",
|
|||
|
222: "Þ",
|
|||
|
223: "ß",
|
|||
|
224: "à",
|
|||
|
225: "á",
|
|||
|
226: "â",
|
|||
|
227: "ã",
|
|||
|
228: "ä",
|
|||
|
229: "å",
|
|||
|
230: "æ",
|
|||
|
231: "ç",
|
|||
|
232: "è",
|
|||
|
233: "é",
|
|||
|
234: "ê",
|
|||
|
235: "ë",
|
|||
|
236: "ì",
|
|||
|
237: "í",
|
|||
|
238: "î",
|
|||
|
239: "ï",
|
|||
|
240: "ð",
|
|||
|
241: "ñ",
|
|||
|
242: "ò",
|
|||
|
243: "ó",
|
|||
|
244: "ô",
|
|||
|
245: "õ",
|
|||
|
246: "ö",
|
|||
|
247: "÷",
|
|||
|
248: "ø",
|
|||
|
249: "ù",
|
|||
|
250: "ú",
|
|||
|
251: "û",
|
|||
|
252: "ü",
|
|||
|
253: "ý",
|
|||
|
254: "þ",
|
|||
|
255: "ÿ",
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func getPairs(word []string) []lo.Tuple2[string, string] {
|
|||
|
all := []lo.Tuple2[string, string]{}
|
|||
|
prevWord := word[0]
|
|||
|
for _, char := range word[1:] {
|
|||
|
all = append(all, lo.T2(prevWord, char))
|
|||
|
prevWord = char
|
|||
|
}
|
|||
|
|
|||
|
return lo.Uniq(all)
|
|||
|
}
|
|||
|
|
|||
|
type Encoder struct {
|
|||
|
bpeRank map[lo.Tuple2[string, string]]int
|
|||
|
encoder map[string]int
|
|||
|
decoder map[int]string
|
|||
|
byteEncoder map[rune]string
|
|||
|
byteDecoder map[string]rune
|
|||
|
|
|||
|
cache sync.Map
|
|||
|
}
|
|||
|
|
|||
|
func NewEncoder() (*Encoder, error) {
|
|||
|
bpeData, rawEncoder, err := loadFiles()
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
|
|||
|
return NewEncoderWithVocab(bpeData, rawEncoder)
|
|||
|
}
|
|||
|
|
|||
|
func NewEncoderWithVocab(bpeData []byte, jsonEncoder []byte) (*Encoder, error) {
|
|||
|
encoder := map[string]int{}
|
|||
|
|
|||
|
err := json.Unmarshal(jsonEncoder, &encoder)
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
|
|||
|
decoder := lo.Invert(encoder)
|
|||
|
|
|||
|
bpeLines := bytes.Split(bpeData, []byte{'\n'})
|
|||
|
bpeMerges := lo.Map(bpeLines[1:len(bpeLines)-1], func(line []byte, _ int) lo.Tuple2[[]byte, []byte] {
|
|||
|
parts := bytes.SplitN(line, []byte{' '}, 2)
|
|||
|
return lo.T2(parts[0], parts[1])
|
|||
|
})
|
|||
|
|
|||
|
byteEncoder := bytesToUnicode()
|
|||
|
byteDecoder := lo.Invert(byteEncoder)
|
|||
|
|
|||
|
// translated to a tuple of string, because []byte is not comparable (for maps)
|
|||
|
bpeRank := dictTuple(bpeMerges)
|
|||
|
|
|||
|
enc := Encoder{
|
|||
|
bpeRank: bpeRank,
|
|||
|
encoder: encoder,
|
|||
|
decoder: decoder,
|
|||
|
byteEncoder: byteEncoder,
|
|||
|
byteDecoder: byteDecoder,
|
|||
|
|
|||
|
cache: sync.Map{},
|
|||
|
}
|
|||
|
|
|||
|
return &enc, nil
|
|||
|
}
|
|||
|
|
|||
|
func (e *Encoder) cachedBpe(token string) string {
|
|||
|
cached, ok := e.cache.Load(token)
|
|||
|
if ok {
|
|||
|
return cached.(string)
|
|||
|
}
|
|||
|
|
|||
|
output := e.bpe(token)
|
|||
|
e.cache.Store(token, output)
|
|||
|
return output
|
|||
|
}
|
|||
|
|
|||
|
func (e *Encoder) bpe(token string) string {
|
|||
|
word := lo.ChunkString(token, 1)
|
|||
|
|
|||
|
pairs := getPairs(word)
|
|||
|
if len(pairs) == 0 {
|
|||
|
return token
|
|||
|
}
|
|||
|
|
|||
|
for {
|
|||
|
minPairs := lo.Map(pairs, func(item lo.Tuple2[string, string], index int) lo.Tuple3[int, string, string] {
|
|||
|
pair := lo.T2(item.A, item.B)
|
|||
|
|
|||
|
rank, ok := e.bpeRank[pair]
|
|||
|
if !ok {
|
|||
|
rank = math.MaxInt
|
|||
|
}
|
|||
|
|
|||
|
return lo.T3(rank, item.A, item.B)
|
|||
|
})
|
|||
|
|
|||
|
bigram := lo.MinBy(minPairs, func(a lo.Tuple3[int, string, string], b lo.Tuple3[int, string, string]) bool {
|
|||
|
return a.A < b.A
|
|||
|
})
|
|||
|
|
|||
|
first := bigram.B
|
|||
|
second := bigram.C
|
|||
|
|
|||
|
if _, ok := e.bpeRank[lo.T2(first, second)]; !ok {
|
|||
|
break
|
|||
|
}
|
|||
|
|
|||
|
newWord := []string{}
|
|||
|
for i := 0; i < len(word); {
|
|||
|
_, j, ok := lo.FindIndexOf(word[i:], func(str string) bool { return str == first })
|
|||
|
if !ok {
|
|||
|
newWord = append(newWord, word[i:]...)
|
|||
|
break
|
|||
|
}
|
|||
|
|
|||
|
j += i
|
|||
|
|
|||
|
newWord = append(newWord, word[i:j]...)
|
|||
|
i = j
|
|||
|
|
|||
|
if word[i] == first && i < len(word)-1 && word[i+1] == second {
|
|||
|
newWord = append(newWord, first+second)
|
|||
|
i = i + 2
|
|||
|
} else {
|
|||
|
newWord = append(newWord, word[i])
|
|||
|
i = i + 1
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
word = newWord
|
|||
|
if len(word) == 1 {
|
|||
|
break
|
|||
|
}
|
|||
|
|
|||
|
pairs = getPairs(word)
|
|||
|
}
|
|||
|
|
|||
|
return strings.Join(word, " ")
|
|||
|
}
|
|||
|
|
|||
|
func (e *Encoder) splitToken(token string) ([]string, error) {
|
|||
|
var matches []string
|
|||
|
|
|||
|
m, err := pat.FindStringMatch(token)
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
|
|||
|
for m != nil {
|
|||
|
matches = append(matches, m.String())
|
|||
|
|
|||
|
m, err = pat.FindNextMatch(m)
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
}
|
|||
|
return matches, nil
|
|||
|
}
|
|||
|
|
|||
|
func (e *Encoder) Encode(text string) ([]int, error) {
|
|||
|
bpeTokens := []int{}
|
|||
|
|
|||
|
matches, err := e.splitToken(text)
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
|
|||
|
for _, match := range matches {
|
|||
|
runes := []rune(match)
|
|||
|
|
|||
|
token := strings.Join(lo.Map(runes, func(item rune, _ int) string {
|
|||
|
return e.byteEncoder[item]
|
|||
|
}), "")
|
|||
|
|
|||
|
bpe := e.cachedBpe(token)
|
|||
|
for _, t := range strings.Split(bpe, " ") {
|
|||
|
bpeTokens = append(bpeTokens, e.encoder[t])
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
return bpeTokens, nil
|
|||
|
}
|
|||
|
|
|||
|
func (e *Encoder) Decode(tokens []int) string {
|
|||
|
parts := lo.Map(tokens, func(token int, _ int) string {
|
|||
|
return e.decoder[token]
|
|||
|
})
|
|||
|
|
|||
|
parts = lo.ChunkString(strings.Join(parts, ""), 1)
|
|||
|
|
|||
|
text := lo.Map(parts, func(item string, _ int) rune {
|
|||
|
return e.byteDecoder[item]
|
|||
|
})
|
|||
|
|
|||
|
return string(text)
|
|||
|
}
|