|
9 | 9 | "database/sql" |
10 | 10 | "database/sql/driver" |
11 | 11 | "encoding/binary" |
| 12 | + "encoding/json" |
12 | 13 | "errors" |
13 | 14 | "fmt" |
14 | 15 | "io" |
@@ -1143,6 +1144,10 @@ func isDriverSetting(key string) bool { |
1143 | 1144 | return true |
1144 | 1145 | case "password": |
1145 | 1146 | return true |
| 1147 | + case "oauth_token": |
| 1148 | + return true |
| 1149 | + case "oauth_token_file": |
| 1150 | + return true |
1146 | 1151 | case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline", "sslsni": |
1147 | 1152 | return true |
1148 | 1153 | case "fallback_application_name": |
@@ -1290,59 +1295,135 @@ func (cn *conn) auth(r *readBuf, o values) { |
1290 | 1295 | // from the server.. |
1291 | 1296 |
|
1292 | 1297 | case 10: |
1293 | | - sc := scram.NewClient(sha256.New, o["user"], o["password"]) |
1294 | | - sc.Step(nil) |
1295 | | - if sc.Err() != nil { |
1296 | | - errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) |
| 1298 | + switch saslMethod := r.string(); saslMethod { |
| 1299 | + case "SCRAM-SHA-256": |
| 1300 | + cn.saslScram(o) |
| 1301 | + case "OAUTHBEARER": |
| 1302 | + cn.saslOAuth(o) |
1297 | 1303 | } |
1298 | | - scOut := sc.Out() |
1299 | 1304 |
|
1300 | | - w := cn.writeBuf('p') |
1301 | | - w.string("SCRAM-SHA-256") |
1302 | | - w.int32(len(scOut)) |
1303 | | - w.bytes(scOut) |
1304 | | - cn.send(w) |
| 1305 | + default: |
| 1306 | + errorf("unknown authentication response: %d", code) |
| 1307 | + } |
| 1308 | +} |
1305 | 1309 |
|
1306 | | - t, r := cn.recv() |
1307 | | - if t != 'R' { |
1308 | | - errorf("unexpected password response: %q", t) |
1309 | | - } |
| 1310 | +func (cn *conn) saslScram(o values) { |
| 1311 | + sc := scram.NewClient(sha256.New, o["user"], o["password"]) |
| 1312 | + sc.Step(nil) |
| 1313 | + if sc.Err() != nil { |
| 1314 | + errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) |
| 1315 | + } |
| 1316 | + scOut := sc.Out() |
1310 | 1317 |
|
1311 | | - if r.int32() != 11 { |
1312 | | - errorf("unexpected authentication response: %q", t) |
1313 | | - } |
| 1318 | + w := cn.writeBuf('p') |
| 1319 | + w.string("SCRAM-SHA-256") |
| 1320 | + w.int32(len(scOut)) |
| 1321 | + w.bytes(scOut) |
| 1322 | + cn.send(w) |
1314 | 1323 |
|
1315 | | - nextStep := r.next(len(*r)) |
1316 | | - sc.Step(nextStep) |
1317 | | - if sc.Err() != nil { |
1318 | | - errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) |
1319 | | - } |
| 1324 | + t, r := cn.recv() |
| 1325 | + if t != 'R' { |
| 1326 | + errorf("unexpected password response: %q", t) |
| 1327 | + } |
1320 | 1328 |
|
1321 | | - scOut = sc.Out() |
1322 | | - w = cn.writeBuf('p') |
1323 | | - w.bytes(scOut) |
1324 | | - cn.send(w) |
| 1329 | + if r.int32() != 11 { |
| 1330 | + errorf("unexpected authentication response: %q", t) |
| 1331 | + } |
1325 | 1332 |
|
1326 | | - t, r = cn.recv() |
1327 | | - if t != 'R' { |
1328 | | - errorf("unexpected password response: %q", t) |
1329 | | - } |
| 1333 | + nextStep := r.next(len(*r)) |
| 1334 | + sc.Step(nextStep) |
| 1335 | + if sc.Err() != nil { |
| 1336 | + errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) |
| 1337 | + } |
1330 | 1338 |
|
1331 | | - if r.int32() != 12 { |
1332 | | - errorf("unexpected authentication response: %q", t) |
| 1339 | + scOut = sc.Out() |
| 1340 | + w = cn.writeBuf('p') |
| 1341 | + w.bytes(scOut) |
| 1342 | + cn.send(w) |
| 1343 | + |
| 1344 | + t, r = cn.recv() |
| 1345 | + if t != 'R' { |
| 1346 | + errorf("unexpected password response: %q", t) |
| 1347 | + } |
| 1348 | + |
| 1349 | + if r.int32() != 12 { |
| 1350 | + errorf("unexpected authentication response: %q", t) |
| 1351 | + } |
| 1352 | + |
| 1353 | + nextStep = r.next(len(*r)) |
| 1354 | + sc.Step(nextStep) |
| 1355 | + if sc.Err() != nil { |
| 1356 | + errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) |
| 1357 | + } |
| 1358 | +} |
| 1359 | + |
| 1360 | +func (cn *conn) saslOAuth(o values) { |
| 1361 | + // https://www.rfc-editor.org/rfc/rfc7628.html#section-3.1 |
| 1362 | + w := cn.writeBuf('p') |
| 1363 | + w.string("OAUTHBEARER") |
| 1364 | + |
| 1365 | + token, err := getOAuthToken(o) |
| 1366 | + if err != nil { |
| 1367 | + errorf("failed to obtain oauth token: %s", err) |
| 1368 | + } |
| 1369 | + initialResponse := []byte("n,,\x01auth=Bearer " + token + "\x01\x01") |
| 1370 | + w.int32(len(initialResponse)) |
| 1371 | + w.bytes(initialResponse) |
| 1372 | + cn.send(w) |
| 1373 | + |
| 1374 | + t, r := cn.recv() |
| 1375 | + if t != 'R' { |
| 1376 | + errorf("unexpected oauth response: %q", t) |
| 1377 | + } |
| 1378 | + |
| 1379 | + if code := r.int32(); code != 0 { |
| 1380 | + // usually on an authentication error we should get a |
| 1381 | + // AuthenticationSASLContinue message |
| 1382 | + if code != 11 { |
| 1383 | + errorf("unexpected oauth response: %q %d", t, code) |
1333 | 1384 | } |
1334 | 1385 |
|
1335 | | - nextStep = r.next(len(*r)) |
1336 | | - sc.Step(nextStep) |
1337 | | - if sc.Err() != nil { |
1338 | | - errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) |
| 1386 | + // the AuthenticationSASLContinue does have an error payload |
| 1387 | + // https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.2 |
| 1388 | + errResponse := struct { |
| 1389 | + Status string `json:"status"` |
| 1390 | + Scope string `json:"scope"` |
| 1391 | + OpenIDConfiguration string `json:"openid-configuration"` |
| 1392 | + }{} |
| 1393 | + err := json.Unmarshal(*r, &errResponse) |
| 1394 | + if err != nil { |
| 1395 | + errorf("invalid oauth error response") |
1339 | 1396 | } |
1340 | 1397 |
|
1341 | | - default: |
1342 | | - errorf("unknown authentication response: %d", code) |
| 1398 | + errorf("oauth authentication failed '%s'", errResponse.Status) |
| 1399 | + |
| 1400 | + // https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.3 |
| 1401 | + // we deliberately don't complete the error messaging sequence as described |
| 1402 | + // in 3.2.3 as we're going to close the connection either way |
| 1403 | + // w = cn.writeBuf('p') |
| 1404 | + // w.int32(1) |
| 1405 | + // w.bytes([]byte{0x01}) |
| 1406 | + // cn.send(w) |
1343 | 1407 | } |
1344 | 1408 | } |
1345 | 1409 |
|
| 1410 | +func getOAuthToken(o values) (string, error) { |
| 1411 | + if token, ok := o["oauth_token"]; ok { |
| 1412 | + return token, nil |
| 1413 | + } |
| 1414 | + |
| 1415 | + if tokenFile, ok := o["oauth_token_file"]; ok { |
| 1416 | + rawToken, err := os.ReadFile(tokenFile) |
| 1417 | + if err != nil { |
| 1418 | + return "", err |
| 1419 | + } |
| 1420 | + rawToken = bytes.TrimSuffix(rawToken, []byte("\n")) |
| 1421 | + return string(rawToken), nil |
| 1422 | + } |
| 1423 | + |
| 1424 | + return "", fmt.Errorf("no oauth token configured") |
| 1425 | +} |
| 1426 | + |
1346 | 1427 | type format int |
1347 | 1428 |
|
1348 | 1429 | const formatText format = 0 |
|
0 commit comments