Skip to content

Commit 57527ca

Browse files
committed
Add support for OAuth
Support for OAuth authentication got introduced in Postgres 18.
1 parent b7ffbd3 commit 57527ca

File tree

1 file changed

+119
-38
lines changed

1 file changed

+119
-38
lines changed

conn.go

Lines changed: 119 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"database/sql"
1010
"database/sql/driver"
1111
"encoding/binary"
12+
"encoding/json"
1213
"errors"
1314
"fmt"
1415
"io"
@@ -1143,6 +1144,10 @@ func isDriverSetting(key string) bool {
11431144
return true
11441145
case "password":
11451146
return true
1147+
case "oauth_token":
1148+
return true
1149+
case "oauth_token_file":
1150+
return true
11461151
case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline", "sslsni":
11471152
return true
11481153
case "fallback_application_name":
@@ -1290,59 +1295,135 @@ func (cn *conn) auth(r *readBuf, o values) {
12901295
// from the server..
12911296

12921297
case 10:
1293-
sc := scram.NewClient(sha256.New, o["user"], o["password"])
1294-
sc.Step(nil)
1295-
if sc.Err() != nil {
1296-
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1298+
switch saslMethod := r.string(); saslMethod {
1299+
case "SCRAM-SHA-256":
1300+
cn.saslScram(o)
1301+
case "OAUTHBEARER":
1302+
cn.saslOAuth(o)
12971303
}
1298-
scOut := sc.Out()
12991304

1300-
w := cn.writeBuf('p')
1301-
w.string("SCRAM-SHA-256")
1302-
w.int32(len(scOut))
1303-
w.bytes(scOut)
1304-
cn.send(w)
1305+
default:
1306+
errorf("unknown authentication response: %d", code)
1307+
}
1308+
}
13051309

1306-
t, r := cn.recv()
1307-
if t != 'R' {
1308-
errorf("unexpected password response: %q", t)
1309-
}
1310+
func (cn *conn) saslScram(o values) {
1311+
sc := scram.NewClient(sha256.New, o["user"], o["password"])
1312+
sc.Step(nil)
1313+
if sc.Err() != nil {
1314+
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1315+
}
1316+
scOut := sc.Out()
13101317

1311-
if r.int32() != 11 {
1312-
errorf("unexpected authentication response: %q", t)
1313-
}
1318+
w := cn.writeBuf('p')
1319+
w.string("SCRAM-SHA-256")
1320+
w.int32(len(scOut))
1321+
w.bytes(scOut)
1322+
cn.send(w)
13141323

1315-
nextStep := r.next(len(*r))
1316-
sc.Step(nextStep)
1317-
if sc.Err() != nil {
1318-
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1319-
}
1324+
t, r := cn.recv()
1325+
if t != 'R' {
1326+
errorf("unexpected password response: %q", t)
1327+
}
13201328

1321-
scOut = sc.Out()
1322-
w = cn.writeBuf('p')
1323-
w.bytes(scOut)
1324-
cn.send(w)
1329+
if r.int32() != 11 {
1330+
errorf("unexpected authentication response: %q", t)
1331+
}
13251332

1326-
t, r = cn.recv()
1327-
if t != 'R' {
1328-
errorf("unexpected password response: %q", t)
1329-
}
1333+
nextStep := r.next(len(*r))
1334+
sc.Step(nextStep)
1335+
if sc.Err() != nil {
1336+
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1337+
}
13301338

1331-
if r.int32() != 12 {
1332-
errorf("unexpected authentication response: %q", t)
1339+
scOut = sc.Out()
1340+
w = cn.writeBuf('p')
1341+
w.bytes(scOut)
1342+
cn.send(w)
1343+
1344+
t, r = cn.recv()
1345+
if t != 'R' {
1346+
errorf("unexpected password response: %q", t)
1347+
}
1348+
1349+
if r.int32() != 12 {
1350+
errorf("unexpected authentication response: %q", t)
1351+
}
1352+
1353+
nextStep = r.next(len(*r))
1354+
sc.Step(nextStep)
1355+
if sc.Err() != nil {
1356+
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1357+
}
1358+
}
1359+
1360+
func (cn *conn) saslOAuth(o values) {
1361+
// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.1
1362+
w := cn.writeBuf('p')
1363+
w.string("OAUTHBEARER")
1364+
1365+
token, err := getOAuthToken(o)
1366+
if err != nil {
1367+
errorf("failed to obtain oauth token: %s", err)
1368+
}
1369+
initialResponse := []byte("n,,\x01auth=Bearer " + token + "\x01\x01")
1370+
w.int32(len(initialResponse))
1371+
w.bytes(initialResponse)
1372+
cn.send(w)
1373+
1374+
t, r := cn.recv()
1375+
if t != 'R' {
1376+
errorf("unexpected oauth response: %q", t)
1377+
}
1378+
1379+
if code := r.int32(); code != 0 {
1380+
// usually on an authentication error we should get a
1381+
// AuthenticationSASLContinue message
1382+
if code != 11 {
1383+
errorf("unexpected oauth response: %q %d", t, code)
13331384
}
13341385

1335-
nextStep = r.next(len(*r))
1336-
sc.Step(nextStep)
1337-
if sc.Err() != nil {
1338-
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1386+
// the AuthenticationSASLContinue does have an error payload
1387+
// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.2
1388+
errResponse := struct {
1389+
Status string `json:"status"`
1390+
Scope string `json:"scope"`
1391+
OpenIDConfiguration string `json:"openid-configuration"`
1392+
}{}
1393+
err := json.Unmarshal(*r, &errResponse)
1394+
if err != nil {
1395+
errorf("invalid oauth error response")
13391396
}
13401397

1341-
default:
1342-
errorf("unknown authentication response: %d", code)
1398+
errorf("oauth authentication failed '%s'", errResponse.Status)
1399+
1400+
// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.3
1401+
// we deliberately don't complete the error messaging sequence as described
1402+
// in 3.2.3 as we're going to close the connection either way
1403+
// w = cn.writeBuf('p')
1404+
// w.int32(1)
1405+
// w.bytes([]byte{0x01})
1406+
// cn.send(w)
13431407
}
13441408
}
13451409

1410+
func getOAuthToken(o values) (string, error) {
1411+
if token, ok := o["oauth_token"]; ok {
1412+
return token, nil
1413+
}
1414+
1415+
if tokenFile, ok := o["oauth_token_file"]; ok {
1416+
rawToken, err := os.ReadFile(tokenFile)
1417+
if err != nil {
1418+
return "", err
1419+
}
1420+
rawToken = bytes.TrimSuffix(rawToken, []byte("\n"))
1421+
return string(rawToken), nil
1422+
}
1423+
1424+
return "", fmt.Errorf("no oauth token configured")
1425+
}
1426+
13461427
type format int
13471428

13481429
const formatText format = 0

0 commit comments

Comments
 (0)