// go-gpt-3-encoder/encoder.go (Go, 493 lines, 7.7 KiB)
//
// NOTE(review): this capture included web-viewer chrome ("Raw Permalink
// Blame History" and an ambiguous-Unicode warning) above the package
// clause; it is preserved here as a comment so the file remains valid Go.
// The "ambiguous Unicode characters" are intentional: the byte-to-unicode
// table below deliberately uses Latin-Extended glyphs as byte stand-ins.
package gpt3encoder
import (
"bytes"
"embed"
"encoding/json"
"math"
"strings"
"sync"
"github.com/dlclark/regexp2"
"github.com/samber/lo"
)
//go:embed encoder.json
//go:embed vocab.bpe
var files embed.FS
// pat is the GPT-2/3 pre-tokenization pattern: it splits text into
// contractions ('s, 't, 're, 've, 'm, 'll, 'd), runs of letters, runs of
// digits, runs of other symbols (each optionally preceded by one space),
// and whitespace. regexp2 is required because the (?!\S) lookahead is not
// supported by the standard library's regexp.
//
// Bug fix: the previous pattern embedded the JavaScript regex literal's
// "/" delimiters inside the Go string, so the first alternative matched
// the literal text "/'s" (never a bare "'s") and the last required a
// trailing "/" after whitespace — diverging from OpenAI's tokenizer.
var pat = regexp2.MustCompile(`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`, 0)
// loadFiles reads the BPE merge table (vocab.bpe) and the token
// vocabulary (encoder.json) embedded in this package.
func loadFiles() ([]byte, []byte, error) {
	bpe, err := files.ReadFile("vocab.bpe")
	if err != nil {
		return nil, nil, err
	}
	vocab, err := files.ReadFile("encoder.json")
	if err != nil {
		return nil, nil, err
	}
	return bpe, vocab, nil
}
// bytesToUnicode returns the GPT-2/3 byte-to-unicode table: every byte
// value 0-255 is assigned a distinct printable character so raw bytes can
// travel through the string-keyed BPE machinery.
//
// Printable bytes ('!'..'~', '¡'..'¬', '®'..'ÿ') map to themselves; the
// remaining 68 byte values are shifted to code point 256+n in order of
// appearance (e.g. space 0x20 -> 'Ġ', 0x0A -> 'Ċ'). The result is
// identical to the table previously hardcoded here and to the reference
// bytes_to_unicode() in OpenAI's GPT-2 encoder.
func bytesToUnicode() map[rune]string {
	printable := func(b rune) bool {
		return (b >= '!' && b <= '~') || (b >= '¡' && b <= '¬') || (b >= '®' && b <= 'ÿ')
	}
	table := make(map[rune]string, 256)
	shift := rune(0) // next offset above 255 for non-printable bytes
	for b := rune(0); b < 256; b++ {
		if printable(b) {
			table[b] = string(b)
		} else {
			table[b] = string(256 + shift)
			shift++
		}
	}
	return table
}
// getPairs returns the set of adjacent symbol pairs in word, in
// first-occurrence order: ["a","b","c"] -> (a,b), (b,c).
func getPairs(word []string) []lo.Tuple2[string, string] {
	all := []lo.Tuple2[string, string]{}
	if len(word) < 2 {
		// Robustness fix: the previous code read word[0] unconditionally
		// and panicked on an empty slice; zero or one symbol has no pairs.
		return all
	}
	prevWord := word[0]
	for _, char := range word[1:] {
		all = append(all, lo.T2(prevWord, char))
		prevWord = char
	}
	return lo.Uniq(all)
}
// Encoder translates text to and from GPT-2/3 BPE token ids.
// The bpe cache is a sync.Map, so cachedBpe may be called concurrently;
// the remaining maps appear read-only after construction (built in
// NewEncoderWithVocab) — NOTE(review): confirm nothing mutates them.
type Encoder struct {
	// bpeRank maps a symbol pair to its merge priority from vocab.bpe
	// (lower rank = learned earlier = merged first).
	bpeRank map[lo.Tuple2[string, string]]int
	// encoder maps a merged BPE token string to its vocabulary id;
	// decoder is its inverse (id -> token string).
	encoder map[string]int
	decoder map[int]string
	// byteEncoder maps each raw byte value (as a rune 0-255) to its
	// printable stand-in character; byteDecoder is its inverse.
	byteEncoder map[rune]string
	byteDecoder map[string]rune
	// cache memoizes bpe() results keyed by the pre-token string.
	cache sync.Map
}
// NewEncoder builds an Encoder from the vocab files embedded in this
// package (vocab.bpe and encoder.json).
func NewEncoder() (*Encoder, error) {
	bpe, vocab, err := loadFiles()
	if err != nil {
		return nil, err
	}
	return NewEncoderWithVocab(bpe, vocab)
}
// NewEncoderWithVocab builds an Encoder from caller-supplied vocab data:
// bpeData is the contents of a vocab.bpe merge file (version header on
// the first line, one "left right" merge per line) and jsonEncoder is
// the token->id table as JSON.
func NewEncoderWithVocab(bpeData []byte, jsonEncoder []byte) (*Encoder, error) {
	tokenIDs := map[string]int{}
	if err := json.Unmarshal(jsonEncoder, &tokenIDs); err != nil {
		return nil, err
	}

	// Skip the version header (first line) and the trailing empty line.
	lines := bytes.Split(bpeData, []byte{'\n'})
	merges := make([]lo.Tuple2[[]byte, []byte], 0, len(lines))
	for _, line := range lines[1 : len(lines)-1] {
		halves := bytes.SplitN(line, []byte{' '}, 2)
		merges = append(merges, lo.T2(halves[0], halves[1]))
	}

	byteEnc := bytesToUnicode()
	// dictTuple re-keys the merges as string tuples, because []byte is
	// not comparable and cannot be a map key.
	return &Encoder{
		bpeRank:     dictTuple(merges),
		encoder:     tokenIDs,
		decoder:     lo.Invert(tokenIDs),
		byteEncoder: byteEnc,
		byteDecoder: lo.Invert(byteEnc),
		// cache's zero value is ready to use.
	}, nil
}
// cachedBpe memoizes bpe per input token; the sync.Map makes it safe to
// call from multiple goroutines (a rare duplicate computation is fine).
func (e *Encoder) cachedBpe(token string) string {
	if hit, found := e.cache.Load(token); found {
		return hit.(string)
	}
	merged := e.bpe(token)
	e.cache.Store(token, merged)
	return merged
}
// bpe applies the byte-pair-merge algorithm to one pre-token. token is a
// string of byte-encoder symbols; the result is the merged symbols
// joined by single spaces (the format cachedBpe caches and Encode
// splits on).
func (e *Encoder) bpe(token string) string {
	// Start with one symbol per character.
	word := lo.ChunkString(token, 1)
	pairs := getPairs(word)
	if len(pairs) == 0 {
		return token
	}
	for {
		// Rank every adjacent pair; pairs absent from the merge table get
		// MaxInt so they are never selected as the minimum.
		minPairs := lo.Map(pairs, func(item lo.Tuple2[string, string], index int) lo.Tuple3[int, string, string] {
			pair := lo.T2(item.A, item.B)
			rank, ok := e.bpeRank[pair]
			if !ok {
				rank = math.MaxInt
			}
			return lo.T3(rank, item.A, item.B)
		})
		// Lowest rank = earliest-learned merge = highest priority.
		bigram := lo.MinBy(minPairs, func(a lo.Tuple3[int, string, string], b lo.Tuple3[int, string, string]) bool {
			return a.A < b.A
		})
		first := bigram.B
		second := bigram.C
		// If even the best-ranked pair is unknown, no merge applies.
		if _, ok := e.bpeRank[lo.T2(first, second)]; !ok {
			break
		}
		// Rebuild the word, fusing every adjacent (first, second) pair.
		newWord := []string{}
		for i := 0; i < len(word); {
			_, j, ok := lo.FindIndexOf(word[i:], func(str string) bool { return str == first })
			if !ok {
				// No further occurrence of first: keep the tail as-is.
				newWord = append(newWord, word[i:]...)
				break
			}
			j += i // FindIndexOf searched a sub-slice; rebase the index.
			newWord = append(newWord, word[i:j]...)
			i = j
			if word[i] == first && i < len(word)-1 && word[i+1] == second {
				// Fuse the pair into one symbol and skip both positions.
				newWord = append(newWord, first+second)
				i = i + 2
			} else {
				// first without a following second: keep it unmerged.
				newWord = append(newWord, word[i])
				i = i + 1
			}
		}
		word = newWord
		// Fully merged into a single symbol: nothing left to pair.
		if len(word) == 1 {
			break
		}
		pairs = getPairs(word)
	}
	return strings.Join(word, " ")
}
// splitToken runs the pre-tokenization pattern over token and returns
// every match in order of appearance.
func (e *Encoder) splitToken(token string) ([]string, error) {
	var pieces []string
	m, err := pat.FindStringMatch(token)
	for {
		if err != nil {
			return nil, err
		}
		if m == nil {
			return pieces, nil
		}
		pieces = append(pieces, m.String())
		m, err = pat.FindNextMatch(m)
	}
}
// Encode converts text into its sequence of BPE token ids. Each regexp
// chunk is mapped byte-by-byte through the byte-encoder table (so
// arbitrary UTF-8 survives as printable symbols), BPE-merged, and looked
// up in the vocabulary.
func (e *Encoder) Encode(text string) ([]int, error) {
	bpeTokens := []int{}
	matches, err := e.splitToken(text)
	if err != nil {
		return nil, err
	}
	for _, match := range matches {
		// Bug fix: iterate the chunk's UTF-8 *bytes*, not its runes.
		// byteEncoder is keyed by byte values 0-255, so ranging over
		// runes mapped every code point above U+00FF to "" and silently
		// dropped non-Latin-1 characters.
		var sym strings.Builder
		for i := 0; i < len(match); i++ {
			sym.WriteString(e.byteEncoder[rune(match[i])])
		}
		bpe := e.cachedBpe(sym.String())
		for _, t := range strings.Split(bpe, " ") {
			bpeTokens = append(bpeTokens, e.encoder[t])
		}
	}
	return bpeTokens, nil
}
// Decode maps token ids back to text. Each id's vocab string is a
// sequence of byte-encoder symbols, and each symbol decodes to one
// original byte (0-255); the collected bytes are the UTF-8 text.
func (e *Encoder) Decode(tokens []int) string {
	var symbols strings.Builder
	for _, token := range tokens {
		symbols.WriteString(e.decoder[token])
	}
	joined := symbols.String()
	// Bug fix: byteDecoder yields byte values, not code points. The old
	// code collected them as []rune and string([]rune) re-encoded each
	// value 0x80-0xFF as a 2-byte code point, corrupting every original
	// non-ASCII character. Append raw bytes and decode the slice once.
	buf := make([]byte, 0, len(joined))
	for _, sym := range joined {
		buf = append(buf, byte(e.byteDecoder[string(sym)]))
	}
	return string(buf)
}