Initial commit

This commit is contained in:
Samuel Berthe 2022-12-22 00:04:00 +01:00
commit 8e1ed1952f
No known key found for this signature in database
GPG Key ID: 64863511FFBD0E3C
13 changed files with 51279 additions and 0 deletions

44
.github/workflows/lint.yml vendored Normal file
View File

@ -0,0 +1,44 @@
name: Lint
on:
push:
tags:
branches:
pull_request:
jobs:
golangci:
name: lint
runs-on: ubuntu-latest
steps:
- uses: actions/setup-go@v2
with:
go-version: 1.18
stable: false
- uses: actions/checkout@v2
- name: golangci-lint
uses: golangci/golangci-lint-action@v2
with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: latest
# Optional: working directory, useful for monorepos
working-directory: ./
# Optional: golangci-lint command line arguments.
args: --timeout 60s --max-same-issues 50
# Optional: show only new issues if it's a pull request. The default value is `false`.
# only-new-issues: true
# Optional: if set to true then the action will use pre-installed Go.
# skip-go-installation: true
# Optional: if set to true then the action don't cache or restore ~/go/pkg.
# skip-pkg-cache: true
# Optional: if set to true then the action don't cache or restore ~/.cache/go-build.
# skip-build-cache: true
# optionally use a specific version of Go rather than the latest one
go_version: '1.18'

37
.github/workflows/test.yml vendored Normal file
View File

@ -0,0 +1,37 @@
name: Tests
on:
push:
tags:
branches:
pull_request:
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Go
uses: actions/setup-go@v2
with:
go-version: 1.18
stable: false
- name: Build
run: make build
- name: Test
run: make test
- name: Coverage
run: make coverage
- name: Codecov
uses: codecov/codecov-action@v2
with:
token: ${{ secrets.CODECOV_TOKEN }}
file: ./cover.out
flags: unittests
verbose: true

38
.gitignore vendored Normal file
View File

@ -0,0 +1,38 @@
# Created by https://www.toptal.com/developers/gitignore/api/go
# Edit at https://www.toptal.com/developers/gitignore?templates=go
### Go ###
# If you prefer the allow list template instead of the deny list, see community template:
# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
#
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Dependency directories (remove the comment below to include it)
# vendor/
# Go workspace file
go.work
### Go Patch ###
/vendor/
/Godeps/
# End of https://www.toptal.com/developers/gitignore/api/go
cover.out
cover.html
.vscode
.idea/

21
LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2022 Samuel Berthe
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

43
Makefile Normal file
View File

@ -0,0 +1,43 @@
build:
go build -v ./...
test:
go test -v ./...
watch-test:
reflex -t 50ms -s -- sh -c 'gotest -v ./...'
bench:
	go test -benchmem -count 3 -bench=. ./...
watch-bench:
	reflex -t 50ms -s -- sh -c 'go test -benchmem -count 3 -bench=. ./...'
coverage:
go test -v -coverprofile=cover.out -covermode=atomic .
go tool cover -html=cover.out -o cover.html
tools:
go install github.com/cespare/reflex@latest
go install github.com/rakyll/gotest@latest
go install github.com/psampaz/go-mod-outdated@latest
go install github.com/jondot/goweight@latest
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
go get -t -u golang.org/x/tools/cmd/cover
go get -t -u github.com/sonatype-nexus-community/nancy@latest
go mod tidy
lint:
golangci-lint run --timeout 60s --max-same-issues 50 ./...
lint-fix:
golangci-lint run --timeout 60s --max-same-issues 50 --fix ./...
audit:
go mod tidy
go list -json -m all | nancy sleuth
outdated:
go mod tidy
go list -u -m -json all | go-mod-outdated -update -direct
weight:
goweight

47
README.md Normal file
View File

@ -0,0 +1,47 @@
# go-gpt-3-encoder
Go BPE tokenizer (Encoder+Decoder) for GPT2 and GPT3.
## About
GPT2 and GPT3 use byte pair encoding to turn text into a series of integers to feed into the model. This is a Go implementation of OpenAI's original Python encoder/decoder which can be found [here](https://github.com/openai/gpt-2/blob/master/src/encoder.py).
This code was inspired by the [JavaScript implementation](https://github.com/latitudegames/GPT-3-Encoder) and was partially generated by OpenAI's models themselves!
## Install
```bash
go get github.com/samber/go-gpt-3-encoder
```
## Usage
Compatible with Go >= 1.18
```go
import "github.com/samber/go-gpt-3-encoder"
encoder, err := NewEncoder()
if err != nil {
log.Fatal(err)
}
str := "This is an example sentence to try encoding out on!"
encoded, err := encoder.Encode(str)
if err != nil {
log.Fatal(err)
}
fmt.Println("We can look at each token and what it represents:")
for _, token := range encoded {
    fmt.Printf("%d -- %s\n", token, encoder.Decode([]int{token}))
}
}
decoded := encoder.Decode(encoded)
fmt.Printf("We can decode it back into: %s\n", decoded)
```
## Contribute
Some corner cases are not covered by this library. See `@TODO` in tests.

492
encoder.go Normal file
View File

@ -0,0 +1,492 @@
package gpt3encoder
import (
"bytes"
"embed"
"encoding/json"
"math"
"strings"
"sync"
"github.com/dlclark/regexp2"
"github.com/samber/lo"
)
//go:embed encoder.json
//go:embed vocab.bpe
var files embed.FS
var pat = regexp2.MustCompile(`/'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/`, 0)
// loadFiles reads the two embedded data files backing the tokenizer:
// vocab.bpe (the ordered BPE merge rules) and encoder.json (token -> id map).
func loadFiles() (bpeData []byte, encoder []byte, err error) {
	if bpeData, err = files.ReadFile("vocab.bpe"); err != nil {
		return bpeData, encoder, err
	}
	if encoder, err = files.ReadFile("encoder.json"); err != nil {
		return bpeData, encoder, err
	}
	return bpeData, encoder, nil
}
// hardcoded
func bytesToUnicode() map[rune]string {
return map[rune]string{
0: "Ā",
1: "ā",
2: "Ă",
3: "ă",
4: "Ą",
5: "ą",
6: "Ć",
7: "ć",
8: "Ĉ",
9: "ĉ",
10: "Ċ",
11: "ċ",
12: "Č",
13: "č",
14: "Ď",
15: "ď",
16: "Đ",
17: "đ",
18: "Ē",
19: "ē",
20: "Ĕ",
21: "ĕ",
22: "Ė",
23: "ė",
24: "Ę",
25: "ę",
26: "Ě",
27: "ě",
28: "Ĝ",
29: "ĝ",
30: "Ğ",
31: "ğ",
32: "Ġ",
33: "!",
34: "\"",
35: "#",
36: "$",
37: "%",
38: "&",
39: "'",
40: "(",
41: ")",
42: "*",
43: "+",
44: ",",
45: "-",
46: ".",
47: "/",
48: "0",
49: "1",
50: "2",
51: "3",
52: "4",
53: "5",
54: "6",
55: "7",
56: "8",
57: "9",
58: ":",
59: ";",
60: "<",
61: "=",
62: ">",
63: "?",
64: "@",
65: "A",
66: "B",
67: "C",
68: "D",
69: "E",
70: "F",
71: "G",
72: "H",
73: "I",
74: "J",
75: "K",
76: "L",
77: "M",
78: "N",
79: "O",
80: "P",
81: "Q",
82: "R",
83: "S",
84: "T",
85: "U",
86: "V",
87: "W",
88: "X",
89: "Y",
90: "Z",
91: "[",
92: "\\",
93: "]",
94: "^",
95: "_",
96: "`",
97: "a",
98: "b",
99: "c",
100: "d",
101: "e",
102: "f",
103: "g",
104: "h",
105: "i",
106: "j",
107: "k",
108: "l",
109: "m",
110: "n",
111: "o",
112: "p",
113: "q",
114: "r",
115: "s",
116: "t",
117: "u",
118: "v",
119: "w",
120: "x",
121: "y",
122: "z",
123: "{",
124: "|",
125: "}",
126: "~",
127: "ġ",
128: "Ģ",
129: "ģ",
130: "Ĥ",
131: "ĥ",
132: "Ħ",
133: "ħ",
134: "Ĩ",
135: "ĩ",
136: "Ī",
137: "ī",
138: "Ĭ",
139: "ĭ",
140: "Į",
141: "į",
142: "İ",
143: "ı",
144: "IJ",
145: "ij",
146: "Ĵ",
147: "ĵ",
148: "Ķ",
149: "ķ",
150: "ĸ",
151: "Ĺ",
152: "ĺ",
153: "Ļ",
154: "ļ",
155: "Ľ",
156: "ľ",
157: "Ŀ",
158: "ŀ",
159: "Ł",
160: "ł",
161: "¡",
162: "¢",
163: "£",
164: "¤",
165: "¥",
166: "¦",
167: "§",
168: "¨",
169: "©",
170: "ª",
171: "«",
172: "¬",
173: "Ń",
174: "®",
175: "¯",
176: "°",
177: "±",
178: "²",
179: "³",
180: "´",
181: "µ",
182: "¶",
183: "·",
184: "¸",
185: "¹",
186: "º",
187: "»",
188: "¼",
189: "½",
190: "¾",
191: "¿",
192: "À",
193: "Á",
194: "Â",
195: "Ã",
196: "Ä",
197: "Å",
198: "Æ",
199: "Ç",
200: "È",
201: "É",
202: "Ê",
203: "Ë",
204: "Ì",
205: "Í",
206: "Î",
207: "Ï",
208: "Ð",
209: "Ñ",
210: "Ò",
211: "Ó",
212: "Ô",
213: "Õ",
214: "Ö",
215: "×",
216: "Ø",
217: "Ù",
218: "Ú",
219: "Û",
220: "Ü",
221: "Ý",
222: "Þ",
223: "ß",
224: "à",
225: "á",
226: "â",
227: "ã",
228: "ä",
229: "å",
230: "æ",
231: "ç",
232: "è",
233: "é",
234: "ê",
235: "ë",
236: "ì",
237: "í",
238: "î",
239: "ï",
240: "ð",
241: "ñ",
242: "ò",
243: "ó",
244: "ô",
245: "õ",
246: "ö",
247: "÷",
248: "ø",
249: "ù",
250: "ú",
251: "û",
252: "ü",
253: "ý",
254: "þ",
255: "ÿ",
}
}
// getPairs returns the unique adjacent symbol pairs of word, e.g.
// ["h","e","l","l","o"] -> (h,e) (e,l) (l,l) (l,o).
//
// Fix: a word with fewer than two symbols has no pairs; the previous version
// indexed word[0] unconditionally and panicked on an empty slice.
func getPairs(word []string) []lo.Tuple2[string, string] {
	pairs := []lo.Tuple2[string, string]{}
	if len(word) < 2 {
		return pairs
	}
	prev := word[0]
	for _, symbol := range word[1:] {
		pairs = append(pairs, lo.T2(prev, symbol))
		prev = symbol
	}
	return lo.Uniq(pairs)
}
// Encoder is a GPT-2/GPT-3 byte pair encoding tokenizer: Encode turns text
// into token ids, Decode turns ids back into text. The lookup tables are
// only written during construction; the BPE cache is a sync.Map, so an
// Encoder can be shared across goroutines.
type Encoder struct {
	bpeRank     map[lo.Tuple2[string, string]]int // merge pair -> priority (lower rank merges first)
	encoder     map[string]int                    // BPE token string -> token id
	decoder     map[int]string                    // token id -> BPE token string (inverse of encoder)
	byteEncoder map[rune]string                   // raw byte (0-255) -> printable unicode stand-in
	byteDecoder map[string]rune                   // printable stand-in -> raw byte (inverse of byteEncoder)
	cache       sync.Map                          // memoizes bpe() results per token string
}
// NewEncoder builds an Encoder from the embedded vocab.bpe and encoder.json.
func NewEncoder() (*Encoder, error) {
	merges, vocab, err := loadFiles()
	if err != nil {
		return nil, err
	}
	return NewEncoderWithVocab(merges, vocab)
}
// NewEncoderWithVocab builds an Encoder from caller-supplied data: bpeData is
// the contents of a vocab.bpe merge file (version header on the first line,
// then one space-separated merge pair per line), jsonEncoder is the JSON
// token -> id vocabulary.
func NewEncoderWithVocab(bpeData []byte, jsonEncoder []byte) (*Encoder, error) {
	encoder := map[string]int{}
	err := json.Unmarshal(jsonEncoder, &encoder)
	if err != nil {
		return nil, err
	}

	decoder := lo.Invert(encoder)

	// Drop the first line (version header) and the last line (presumably the
	// empty remainder after the trailing newline).
	// NOTE(review): a vocab.bpe without a trailing newline would lose its
	// final merge rule here — confirm against the shipped file.
	bpeLines := bytes.Split(bpeData, []byte{'\n'})
	bpeMerges := lo.Map(bpeLines[1:len(bpeLines)-1], func(line []byte, _ int) lo.Tuple2[[]byte, []byte] {
		// Each line is "<first> <second>"; SplitN(…, 2) keeps any extra
		// spaces inside the second part.
		// NOTE(review): a line with no space would panic on parts[1].
		parts := bytes.SplitN(line, []byte{' '}, 2)
		return lo.T2(parts[0], parts[1])
	})

	byteEncoder := bytesToUnicode()
	byteDecoder := lo.Invert(byteEncoder)

	// translated to a tuple of string, because []byte is not comparable (for maps)
	bpeRank := dictTuple(bpeMerges)

	enc := Encoder{
		bpeRank:     bpeRank,
		encoder:     encoder,
		decoder:     decoder,
		byteEncoder: byteEncoder,
		byteDecoder: byteDecoder,
		cache:       sync.Map{},
	}

	return &enc, nil
}
// cachedBpe memoizes bpe results per token in the encoder's sync.Map.
func (e *Encoder) cachedBpe(token string) string {
	if hit, ok := e.cache.Load(token); ok {
		return hit.(string)
	}
	merged := e.bpe(token)
	e.cache.Store(token, merged)
	return merged
}
// bpe applies the byte pair encoding merge rules to a single pre-tokenized
// chunk (already mapped through byteEncoder) and returns the merged symbols
// joined by single spaces.
func (e *Encoder) bpe(token string) string {
	// Start from individual characters and repeatedly fuse the
	// best-ranked (lowest rank value) adjacent pair.
	word := lo.ChunkString(token, 1)
	pairs := getPairs(word)

	if len(pairs) == 0 {
		return token
	}

	for {
		// Attach a rank to every current pair; pairs absent from the merge
		// table get MaxInt so they are never chosen as the minimum.
		minPairs := lo.Map(pairs, func(item lo.Tuple2[string, string], index int) lo.Tuple3[int, string, string] {
			pair := lo.T2(item.A, item.B)
			rank, ok := e.bpeRank[pair]
			if !ok {
				rank = math.MaxInt
			}
			return lo.T3(rank, item.A, item.B)
		})
		bigram := lo.MinBy(minPairs, func(a lo.Tuple3[int, string, string], b lo.Tuple3[int, string, string]) bool {
			return a.A < b.A
		})

		first := bigram.B
		second := bigram.C

		// Best candidate is not a known merge: nothing left to fuse.
		if _, ok := e.bpeRank[lo.T2(first, second)]; !ok {
			break
		}

		// Rebuild the word, replacing each adjacent (first, second)
		// occurrence with the fused symbol first+second.
		newWord := []string{}
		for i := 0; i < len(word); {
			_, j, ok := lo.FindIndexOf(word[i:], func(str string) bool { return str == first })
			if !ok {
				// "first" does not occur again: keep the tail untouched.
				newWord = append(newWord, word[i:]...)
				break
			}

			j += i // FindIndexOf searched word[i:]; rebase to word indices.

			newWord = append(newWord, word[i:j]...)
			i = j

			if word[i] == first && i < len(word)-1 && word[i+1] == second {
				newWord = append(newWord, first+second)
				i = i + 2
			} else {
				newWord = append(newWord, word[i])
				i = i + 1
			}
		}

		word = newWord

		if len(word) == 1 {
			break
		}

		pairs = getPairs(word)
	}

	return strings.Join(word, " ")
}
// splitToken pre-tokenizes text with the GPT-2 regex, yielding the chunks
// (contractions, words, numbers, punctuation, whitespace) that are fed to
// the BPE merger one at a time.
func (e *Encoder) splitToken(token string) ([]string, error) {
	var chunks []string
	match, err := pat.FindStringMatch(token)
	for err == nil && match != nil {
		chunks = append(chunks, match.String())
		match, err = pat.FindNextMatch(match)
	}
	if err != nil {
		return nil, err
	}
	return chunks, nil
}
// Encode converts text into a sequence of GPT-2/GPT-3 token ids.
// NOTE(review): runes above 255 are absent from byteEncoder and silently map
// to "", and BPE strings missing from the vocabulary map to id 0 — this is
// why multi-byte emoji round-trips fail (see @TODO in the tests).
func (e *Encoder) Encode(text string) ([]int, error) {
	bpeTokens := []int{}

	matches, err := e.splitToken(text)
	if err != nil {
		return nil, err
	}

	for _, match := range matches {
		// Map each rune of the chunk to its printable byte stand-in.
		runes := []rune(match)
		token := strings.Join(lo.Map(runes, func(item rune, _ int) string {
			return e.byteEncoder[item]
		}), "")

		// Merge with BPE, then look up each merged symbol's id.
		bpe := e.cachedBpe(token)
		for _, t := range strings.Split(bpe, " ") {
			bpeTokens = append(bpeTokens, e.encoder[t])
		}
	}

	return bpeTokens, nil
}
// Decode converts a sequence of token ids back into text.
// NOTE(review): ids missing from the vocabulary decode to "" and stand-in
// characters missing from byteDecoder become rune 0 — inputs are assumed to
// come from Encode.
func (e *Encoder) Decode(tokens []int) string {
	parts := lo.Map(tokens, func(token int, _ int) string {
		return e.decoder[token]
	})
	// Re-split the concatenated BPE string into single stand-in characters,
	// then map each one back to the raw byte it represents.
	parts = lo.ChunkString(strings.Join(parts, ""), 1)
	text := lo.Map(parts, func(item string, _ int) rune {
		return e.byteDecoder[item]
	})
	return string(text)
}

1
encoder.json Normal file

File diff suppressed because one or more lines are too long

503
encoder_test.go Normal file
View File

@ -0,0 +1,503 @@
package gpt3encoder
import (
"testing"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
)
// TestBytesToUnicode pins the complete 256-entry byte -> printable-unicode
// table against a literal copy, guarding any refactor of bytesToUnicode.
func TestBytesToUnicode(t *testing.T) {
	is := assert.New(t)

	// Most useful test E.V.E.R ^^
	want := map[rune]string{
		0:   "Ā",
		1:   "ā",
		2:   "Ă",
		3:   "ă",
		4:   "Ą",
		5:   "ą",
		6:   "Ć",
		7:   "ć",
		8:   "Ĉ",
		9:   "ĉ",
		10:  "Ċ",
		11:  "ċ",
		12:  "Č",
		13:  "č",
		14:  "Ď",
		15:  "ď",
		16:  "Đ",
		17:  "đ",
		18:  "Ē",
		19:  "ē",
		20:  "Ĕ",
		21:  "ĕ",
		22:  "Ė",
		23:  "ė",
		24:  "Ę",
		25:  "ę",
		26:  "Ě",
		27:  "ě",
		28:  "Ĝ",
		29:  "ĝ",
		30:  "Ğ",
		31:  "ğ",
		32:  "Ġ",
		33:  "!",
		34:  "\"",
		35:  "#",
		36:  "$",
		37:  "%",
		38:  "&",
		39:  "'",
		40:  "(",
		41:  ")",
		42:  "*",
		43:  "+",
		44:  ",",
		45:  "-",
		46:  ".",
		47:  "/",
		48:  "0",
		49:  "1",
		50:  "2",
		51:  "3",
		52:  "4",
		53:  "5",
		54:  "6",
		55:  "7",
		56:  "8",
		57:  "9",
		58:  ":",
		59:  ";",
		60:  "<",
		61:  "=",
		62:  ">",
		63:  "?",
		64:  "@",
		65:  "A",
		66:  "B",
		67:  "C",
		68:  "D",
		69:  "E",
		70:  "F",
		71:  "G",
		72:  "H",
		73:  "I",
		74:  "J",
		75:  "K",
		76:  "L",
		77:  "M",
		78:  "N",
		79:  "O",
		80:  "P",
		81:  "Q",
		82:  "R",
		83:  "S",
		84:  "T",
		85:  "U",
		86:  "V",
		87:  "W",
		88:  "X",
		89:  "Y",
		90:  "Z",
		91:  "[",
		92:  "\\",
		93:  "]",
		94:  "^",
		95:  "_",
		96:  "`",
		97:  "a",
		98:  "b",
		99:  "c",
		100: "d",
		101: "e",
		102: "f",
		103: "g",
		104: "h",
		105: "i",
		106: "j",
		107: "k",
		108: "l",
		109: "m",
		110: "n",
		111: "o",
		112: "p",
		113: "q",
		114: "r",
		115: "s",
		116: "t",
		117: "u",
		118: "v",
		119: "w",
		120: "x",
		121: "y",
		122: "z",
		123: "{",
		124: "|",
		125: "}",
		126: "~",
		127: "ġ",
		128: "Ģ",
		129: "ģ",
		130: "Ĥ",
		131: "ĥ",
		132: "Ħ",
		133: "ħ",
		134: "Ĩ",
		135: "ĩ",
		136: "Ī",
		137: "ī",
		138: "Ĭ",
		139: "ĭ",
		140: "Į",
		141: "į",
		142: "İ",
		143: "ı",
		144: "IJ",
		145: "ij",
		146: "Ĵ",
		147: "ĵ",
		148: "Ķ",
		149: "ķ",
		150: "ĸ",
		151: "Ĺ",
		152: "ĺ",
		153: "Ļ",
		154: "ļ",
		155: "Ľ",
		156: "ľ",
		157: "Ŀ",
		158: "ŀ",
		159: "Ł",
		160: "ł",
		161: "¡",
		162: "¢",
		163: "£",
		164: "¤",
		165: "¥",
		166: "¦",
		167: "§",
		168: "¨",
		169: "©",
		170: "ª",
		171: "«",
		172: "¬",
		173: "Ń",
		174: "®",
		175: "¯",
		176: "°",
		177: "±",
		178: "²",
		179: "³",
		180: "´",
		181: "µ",
		182: "¶",
		183: "·",
		184: "¸",
		185: "¹",
		186: "º",
		187: "»",
		188: "¼",
		189: "½",
		190: "¾",
		191: "¿",
		192: "À",
		193: "Á",
		194: "Â",
		195: "Ã",
		196: "Ä",
		197: "Å",
		198: "Æ",
		199: "Ç",
		200: "È",
		201: "É",
		202: "Ê",
		203: "Ë",
		204: "Ì",
		205: "Í",
		206: "Î",
		207: "Ï",
		208: "Ð",
		209: "Ñ",
		210: "Ò",
		211: "Ó",
		212: "Ô",
		213: "Õ",
		214: "Ö",
		215: "×",
		216: "Ø",
		217: "Ù",
		218: "Ú",
		219: "Û",
		220: "Ü",
		221: "Ý",
		222: "Þ",
		223: "ß",
		224: "à",
		225: "á",
		226: "â",
		227: "ã",
		228: "ä",
		229: "å",
		230: "æ",
		231: "ç",
		232: "è",
		233: "é",
		234: "ê",
		235: "ë",
		236: "ì",
		237: "í",
		238: "î",
		239: "ï",
		240: "ð",
		241: "ñ",
		242: "ò",
		243: "ó",
		244: "ô",
		245: "õ",
		246: "ö",
		247: "÷",
		248: "ø",
		249: "ù",
		250: "ú",
		251: "û",
		252: "ü",
		253: "ý",
		254: "þ",
		255: "ÿ",
	}

	got := bytesToUnicode()

	is.EqualValues(want, got)
}
// TestNewEncoder_bpeRank spot-checks merge ranks loaded from vocab.bpe.
func TestNewEncoder_bpeRank(t *testing.T) {
	is := assert.New(t)

	enc, err := NewEncoder()
	is.Nil(err)

	cases := map[lo.Tuple2[string, string]]int{
		lo.T2("c", "rim"):        49830,
		lo.T2("Ġdispens", "ary"): 49880,
		lo.T2("ĠAm", "p"):        49905,
	}
	for pair, rank := range cases {
		is.EqualValues(rank, enc.bpeRank[pair])
	}
}
// TestNewEncoder_encoder spot-checks token -> id lookups from encoder.json.
func TestNewEncoder_encoder(t *testing.T) {
	is := assert.New(t)

	enc, err := NewEncoder()
	is.Nil(err)

	for token, id := range map[string]int{
		"Ġreclaimed":    50225,
		"headers":       50145,
		"<|endoftext|>": 50256,
	} {
		is.EqualValues(id, enc.encoder[token])
	}
}
// TestNewEncoder_decoder spot-checks id -> token lookups (inverse vocabulary).
func TestNewEncoder_decoder(t *testing.T) {
	is := assert.New(t)

	enc, err := NewEncoder()
	is.Nil(err)

	for id, token := range map[int]string{
		50225: "Ġreclaimed",
		50145: "headers",
		50256: "<|endoftext|>",
	} {
		is.EqualValues(token, enc.decoder[id])
	}
}
// TestNewEncoder_splitToken checks the regex pre-tokenizer: each word keeps
// its leading space, and multi-codepoint emoji become their own chunks.
func TestNewEncoder_splitToken(t *testing.T) {
	is := assert.New(t)

	encoder, err := NewEncoder()
	is.Nil(err)

	want := []string{
		"hello",
		" 👋",
		" world",
		" 🌍",
		" This",
		" is",
		" a",
		" long",
		" string",
		" to",
		" test",
		" whether",
		" or",
		" not",
		" the",
		" emoji",
		" issue",
		" was",
		" fixed",
		"!",
	}
	got, err := encoder.splitToken("hello 👋 world 🌍 This is a long string to test whether or not the emoji issue was fixed!")

	is.EqualValues(want, got)
	is.Nil(err)
}
// TestGetPairs checks adjacent-pair extraction over a table of symbol slices,
// including single-symbol words (which must yield an empty pair set).
func TestGetPairs(t *testing.T) {
	is := assert.New(t)

	words := [][]string{
		{"h", "e", "l", "l", "o"},
		{"he", "l", "l", "o"},
		{"hel", "l", "o"},
		{"hell", "o"},
		{"hello"},
		{"Ġ", "ð", "Ł", "ij", "ĭ"},
		{"Ġ", "w", "o", "r", "l", "d"},
		{"Ġ", "ð", "Ł", "Į", "į"},
		{"Ġ", "T", "h", "i", "s"},
		{"Ġ", "i", "s"},
		{"Ġ", "a"},
		{"Ġ", "l", "o", "n", "g"},
		{"Ġ", "s", "t", "r", "i", "n", "g"},
		{"Ġ", "t", "o"},
		{"Ġ", "t", "e", "s", "t"},
		{"Ġ", "w", "h", "e", "t", "h", "e", "r"},
		{"Ġ", "o", "r"},
		{"Ġ", "n", "o", "t"},
		{"Ġ", "t", "h", "e"},
		{"Ġ", "e", "m", "o", "j", "i"},
		{"Ġ", "i", "s", "s", "u", "e"},
		{"Ġ", "w", "a", "s"},
		{"Ġ", "f", "i", "x", "e", "d"},
		{"!"},
	}
	// wants[i] is the expected pair list for words[i]; duplicates are removed
	// (e.g. "whether" lists (h,e) only once).
	wants := [][]lo.Tuple2[string, string]{
		{lo.T2("h", "e"), lo.T2("e", "l"), lo.T2("l", "l"), lo.T2("l", "o")},
		{lo.T2("he", "l"), lo.T2("l", "l"), lo.T2("l", "o")},
		{lo.T2("hel", "l"), lo.T2("l", "o")},
		{lo.T2("hell", "o")},
		{},
		{lo.T2("Ġ", "ð"), lo.T2("ð", "Ł"), lo.T2("Ł", "ij"), lo.T2("ij", "ĭ")},
		{lo.T2("Ġ", "w"), lo.T2("w", "o"), lo.T2("o", "r"), lo.T2("r", "l"), lo.T2("l", "d")},
		{lo.T2("Ġ", "ð"), lo.T2("ð", "Ł"), lo.T2("Ł", "Į"), lo.T2("Į", "į")},
		{lo.T2("Ġ", "T"), lo.T2("T", "h"), lo.T2("h", "i"), lo.T2("i", "s")},
		{lo.T2("Ġ", "i"), lo.T2("i", "s")},
		{lo.T2("Ġ", "a")},
		{lo.T2("Ġ", "l"), lo.T2("l", "o"), lo.T2("o", "n"), lo.T2("n", "g")},
		{lo.T2("Ġ", "s"), lo.T2("s", "t"), lo.T2("t", "r"), lo.T2("r", "i"), lo.T2("i", "n"), lo.T2("n", "g")},
		{lo.T2("Ġ", "t"), lo.T2("t", "o")},
		{lo.T2("Ġ", "t"), lo.T2("t", "e"), lo.T2("e", "s"), lo.T2("s", "t")},
		{lo.T2("Ġ", "w"), lo.T2("w", "h"), lo.T2("h", "e"), lo.T2("e", "t"), lo.T2("t", "h"), lo.T2("e", "r")},
		{lo.T2("Ġ", "o"), lo.T2("o", "r")},
		{lo.T2("Ġ", "n"), lo.T2("n", "o"), lo.T2("o", "t")},
		{lo.T2("Ġ", "t"), lo.T2("t", "h"), lo.T2("h", "e")},
		{lo.T2("Ġ", "e"), lo.T2("e", "m"), lo.T2("m", "o"), lo.T2("o", "j"), lo.T2("j", "i")},
		{lo.T2("Ġ", "i"), lo.T2("i", "s"), lo.T2("s", "s"), lo.T2("s", "u"), lo.T2("u", "e")},
		{lo.T2("Ġ", "w"), lo.T2("w", "a"), lo.T2("a", "s")},
		{lo.T2("Ġ", "f"), lo.T2("f", "i"), lo.T2("i", "x"), lo.T2("x", "e"), lo.T2("e", "d")},
		{},
	}

	for i := range words {
		want := wants[i]
		got := getPairs(words[i])
		is.EqualValues(want, got, i)
	}

	got := getPairs([]string{"hello"})
	is.EqualValues([]lo.Tuple2[string, string]{}, got)
}
// TestNewEncoder_bpe checks the merge loop on individual chunks: common words
// collapse to a single symbol, rarer byte sequences (emoji stand-ins) stay
// split into several space-separated symbols.
func TestNewEncoder_bpe(t *testing.T) {
	is := assert.New(t)

	encoder, err := NewEncoder()
	is.Nil(err)

	// (input chunk, expected space-joined merge result)
	cases := []lo.Tuple2[string, string]{
		lo.T2("hello", "hello"),
		lo.T2("ĠðŁijĭ", "ĠðŁij ĭ"),
		lo.T2("Ġworld", "Ġworld"),
		lo.T2("ĠðŁĮį", "ĠðŁ Į į"),
		lo.T2("ĠThis", "ĠThis"),
		lo.T2("Ġis", "Ġis"),
		lo.T2("Ġa", "Ġa"),
		lo.T2("Ġlong", "Ġlong"),
		lo.T2("Ġstring", "Ġstring"),
		lo.T2("Ġto", "Ġto"),
		lo.T2("Ġtest", "Ġtest"),
		lo.T2("Ġwhether", "Ġwhether"),
		lo.T2("Ġor", "Ġor"),
		lo.T2("Ġnot", "Ġnot"),
		lo.T2("Ġthe", "Ġthe"),
		lo.T2("Ġemoji", "Ġemoji"),
		lo.T2("Ġissue", "Ġissue"),
		lo.T2("Ġwas", "Ġwas"),
		lo.T2("Ġfixed", "Ġfixed"),
	}

	for _, c := range cases {
		got := encoder.bpe(c.A)
		want := c.B
		is.EqualValues(want, got)
	}
}
// TestNewEncoder_encode checks end-to-end encoding of an ASCII sentence.
// The emoji variant is disabled (@TODO): runes above 255 are not handled yet.
func TestNewEncoder_encode(t *testing.T) {
	is := assert.New(t)

	encoder, err := NewEncoder()
	is.Nil(err)

	want := []int{31373, 995, 770, 318, 257, 890, 4731, 284, 1332, 1771, 393, 407, 262, 44805, 2071, 373, 5969, 0}
	got, err := encoder.Encode("hello world This is a long string to test whether or not the emoji issue was fixed!")

	is.EqualValues(want, got)
	is.Nil(err)

	// @TODO
	// want = []int{31373, 50169, 233, 995, 12520, 234, 235, 770, 318, 257, 890, 4731, 284, 1332, 1771, 393, 407, 262, 44805, 2071, 373, 5969, 0}
	// got, err = encoder.Encode("hello 👋 world 🌍 This is a long string to test whether or not the emoji issue was fixed!")

	// is.EqualValues(want, got)
	// is.Nil(err)
}
// TestNewEncoder_decode checks end-to-end decoding of an ASCII sentence.
// The emoji variant is disabled (@TODO), mirroring the Encode limitation.
func TestNewEncoder_decode(t *testing.T) {
	is := assert.New(t)

	encoder, err := NewEncoder()
	is.Nil(err)

	want := "hello world This is a long string to test whether or not the emoji issue was fixed!"
	got := encoder.Decode([]int{31373, 995, 770, 318, 257, 890, 4731, 284, 1332, 1771, 393, 407, 262, 44805, 2071, 373, 5969, 0})

	is.EqualValues(want, got)

	// @TODO
	// want = "hello 👋 world 🌍 This is a long string to test whether or not the emoji issue was fixed!"
	// got = encoder.Decode([]int{31373, 50169, 233, 995, 12520, 234, 235, 770, 318, 257, 890, 4731, 284, 1332, 1771, 393, 407, 262, 44805, 2071, 373, 5969, 0})

	// is.EqualValues(want, got)
}
// TestNewEncoder_e2e round-trips Encode -> Decode over a table of inputs and
// expects the original text back. Empty string and emoji cases are disabled
// (@TODO): they do not round-trip yet.
func TestNewEncoder_e2e(t *testing.T) {
	is := assert.New(t)

	encoder, err := NewEncoder()
	is.Nil(err)

	// (input text, expected token ids)
	cases := []lo.Tuple2[string, []int]{
		// lo.T2("", []int{}), // @TODO
		lo.T2(" ", []int{220}),
		lo.T2("\t", []int{197}),
		lo.T2("This is some text", []int{1212, 318, 617, 2420}),
		lo.T2("indivisible", []int{521, 452, 12843}),
		// lo.T2("hello 👋 world 🌍", []int{31373, 50169, 233, 995, 12520, 234, 235}), // @TODO
	}

	for _, c := range cases {
		encoded, err := encoder.Encode(c.A)
		is.Nil(err)
		is.EqualValues(c.B, encoded, c.A)

		result := encoder.Decode(encoded)
		is.EqualValues(c.A, result, c.A)
	}
}

16
go.mod Normal file
View File

@ -0,0 +1,16 @@
module github.com/samber/go-gpt-3-encoder
go 1.18
require (
github.com/dlclark/regexp2 v1.7.0
github.com/samber/lo v1.37.0
github.com/stretchr/testify v1.8.1
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

23
go.sum Normal file
View File

@ -0,0 +1,23 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.7.0 h1:7lJfhqlPssTb1WQx4yvTHN0uElPEv52sbaECrAQxjAo=
github.com/dlclark/regexp2 v1.7.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/samber/lo v1.37.0 h1:XjVcB8g6tgUp8rsPsJ2CvhClfImrpL04YpQHXeHPhRw=
github.com/samber/lo v1.37.0/go.mod h1:9vaz2O4o8oOnK23pd2TrXufcbdbJIa3b6cstBWKpopA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

13
utils.go Normal file
View File

@ -0,0 +1,13 @@
package gpt3encoder
import (
"github.com/samber/lo"
)
// dictTuple builds the BPE rank table: each ([]byte, []byte) merge pair is
// converted to a (string, string) tuple — strings are comparable, so usable
// as map keys — and mapped to its position in the input slice, i.e. its
// merge priority (lower merges first).
//
// Replaces the previous lo.SliceToMap + mutable closure counter with a plain
// indexed loop and a pre-sized map; duplicate pairs still keep the last
// (highest) index, matching the old behavior.
func dictTuple(tuples []lo.Tuple2[[]byte, []byte]) map[lo.Tuple2[string, string]]int {
	ranks := make(map[lo.Tuple2[string, string]]int, len(tuples))
	for rank, pair := range tuples {
		ranks[lo.T2(string(pair.A), string(pair.B))] = rank
	}
	return ranks
}

50001
vocab.bpe Normal file

File diff suppressed because it is too large Load Diff