349 lines
7.5 KiB
Go
349 lines
7.5 KiB
Go
|
package pq
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"database/sql/driver"
|
||
|
"encoding/binary"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"sync"
|
||
|
)
|
||
|
|
||
|
var (
|
||
|
errCopyInClosed = errors.New("pq: copyin statement has already been closed")
|
||
|
errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY")
|
||
|
errCopyToNotSupported = errors.New("pq: COPY TO is not supported")
|
||
|
errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction")
|
||
|
errCopyInProgress = errors.New("pq: COPY in progress")
|
||
|
)
|
||
|
|
||
|
// CopyIn creates a COPY FROM statement which can be prepared with
|
||
|
// Tx.Prepare(). The target table should be visible in search_path.
|
||
|
func CopyIn(table string, columns ...string) string {
|
||
|
buffer := bytes.NewBufferString("COPY ")
|
||
|
BufferQuoteIdentifier(table, buffer)
|
||
|
buffer.WriteString(" (")
|
||
|
makeStmt(buffer, columns...)
|
||
|
return buffer.String()
|
||
|
}
|
||
|
|
||
|
// MakeStmt makes the stmt string for CopyIn and CopyInSchema.
|
||
|
func makeStmt(buffer *bytes.Buffer, columns ...string) {
|
||
|
//s := bytes.NewBufferString()
|
||
|
for i, col := range columns {
|
||
|
if i != 0 {
|
||
|
buffer.WriteString(", ")
|
||
|
}
|
||
|
BufferQuoteIdentifier(col, buffer)
|
||
|
}
|
||
|
buffer.WriteString(") FROM STDIN")
|
||
|
}
|
||
|
|
||
|
// CopyInSchema creates a COPY FROM statement which can be prepared with
|
||
|
// Tx.Prepare().
|
||
|
func CopyInSchema(schema, table string, columns ...string) string {
|
||
|
buffer := bytes.NewBufferString("COPY ")
|
||
|
BufferQuoteIdentifier(schema, buffer)
|
||
|
buffer.WriteRune('.')
|
||
|
BufferQuoteIdentifier(table, buffer)
|
||
|
buffer.WriteString(" (")
|
||
|
makeStmt(buffer, columns...)
|
||
|
return buffer.String()
|
||
|
}
|
||
|
|
||
|
type copyin struct {
|
||
|
cn *conn
|
||
|
buffer []byte
|
||
|
rowData chan []byte
|
||
|
done chan bool
|
||
|
|
||
|
closed bool
|
||
|
|
||
|
mu struct {
|
||
|
sync.Mutex
|
||
|
err error
|
||
|
driver.Result
|
||
|
}
|
||
|
}
|
||
|
|
||
|
const ciBufferSize = 64 * 1024
|
||
|
|
||
|
// flush buffer before the buffer is filled up and needs reallocation
|
||
|
const ciBufferFlushSize = 63 * 1024
|
||
|
|
||
|
func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) {
|
||
|
if !cn.isInTransaction() {
|
||
|
return nil, errCopyNotSupportedOutsideTxn
|
||
|
}
|
||
|
|
||
|
ci := ©in{
|
||
|
cn: cn,
|
||
|
buffer: make([]byte, 0, ciBufferSize),
|
||
|
rowData: make(chan []byte),
|
||
|
done: make(chan bool, 1),
|
||
|
}
|
||
|
// add CopyData identifier + 4 bytes for message length
|
||
|
ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0)
|
||
|
|
||
|
b := cn.writeBuf('Q')
|
||
|
b.string(q)
|
||
|
cn.send(b)
|
||
|
|
||
|
awaitCopyInResponse:
|
||
|
for {
|
||
|
t, r := cn.recv1()
|
||
|
switch t {
|
||
|
case 'G':
|
||
|
if r.byte() != 0 {
|
||
|
err = errBinaryCopyNotSupported
|
||
|
break awaitCopyInResponse
|
||
|
}
|
||
|
go ci.resploop()
|
||
|
return ci, nil
|
||
|
case 'H':
|
||
|
err = errCopyToNotSupported
|
||
|
break awaitCopyInResponse
|
||
|
case 'E':
|
||
|
err = parseError(r)
|
||
|
case 'Z':
|
||
|
if err == nil {
|
||
|
ci.setBad(driver.ErrBadConn)
|
||
|
errorf("unexpected ReadyForQuery in response to COPY")
|
||
|
}
|
||
|
cn.processReadyForQuery(r)
|
||
|
return nil, err
|
||
|
default:
|
||
|
ci.setBad(driver.ErrBadConn)
|
||
|
errorf("unknown response for copy query: %q", t)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// something went wrong, abort COPY before we return
|
||
|
b = cn.writeBuf('f')
|
||
|
b.string(err.Error())
|
||
|
cn.send(b)
|
||
|
|
||
|
for {
|
||
|
t, r := cn.recv1()
|
||
|
switch t {
|
||
|
case 'c', 'C', 'E':
|
||
|
case 'Z':
|
||
|
// correctly aborted, we're done
|
||
|
cn.processReadyForQuery(r)
|
||
|
return nil, err
|
||
|
default:
|
||
|
ci.setBad(driver.ErrBadConn)
|
||
|
errorf("unknown response for CopyFail: %q", t)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (ci *copyin) flush(buf []byte) {
|
||
|
// set message length (without message identifier)
|
||
|
binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1))
|
||
|
|
||
|
_, err := ci.cn.c.Write(buf)
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (ci *copyin) resploop() {
|
||
|
for {
|
||
|
var r readBuf
|
||
|
t, err := ci.cn.recvMessage(&r)
|
||
|
if err != nil {
|
||
|
ci.setBad(driver.ErrBadConn)
|
||
|
ci.setError(err)
|
||
|
ci.done <- true
|
||
|
return
|
||
|
}
|
||
|
switch t {
|
||
|
case 'C':
|
||
|
// complete
|
||
|
res, _ := ci.cn.parseComplete(r.string())
|
||
|
ci.setResult(res)
|
||
|
case 'N':
|
||
|
if n := ci.cn.noticeHandler; n != nil {
|
||
|
n(parseError(&r))
|
||
|
}
|
||
|
case 'Z':
|
||
|
ci.cn.processReadyForQuery(&r)
|
||
|
ci.done <- true
|
||
|
return
|
||
|
case 'E':
|
||
|
err := parseError(&r)
|
||
|
ci.setError(err)
|
||
|
default:
|
||
|
ci.setBad(driver.ErrBadConn)
|
||
|
ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t))
|
||
|
ci.done <- true
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (ci *copyin) setBad(err error) {
|
||
|
ci.cn.err.set(err)
|
||
|
}
|
||
|
|
||
|
func (ci *copyin) getBad() error {
|
||
|
return ci.cn.err.get()
|
||
|
}
|
||
|
|
||
|
func (ci *copyin) err() error {
|
||
|
ci.mu.Lock()
|
||
|
err := ci.mu.err
|
||
|
ci.mu.Unlock()
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// setError() sets ci.err if one has not been set already. Caller must not be
|
||
|
// holding ci.Mutex.
|
||
|
func (ci *copyin) setError(err error) {
|
||
|
ci.mu.Lock()
|
||
|
if ci.mu.err == nil {
|
||
|
ci.mu.err = err
|
||
|
}
|
||
|
ci.mu.Unlock()
|
||
|
}
|
||
|
|
||
|
func (ci *copyin) setResult(result driver.Result) {
|
||
|
ci.mu.Lock()
|
||
|
ci.mu.Result = result
|
||
|
ci.mu.Unlock()
|
||
|
}
|
||
|
|
||
|
func (ci *copyin) getResult() driver.Result {
|
||
|
ci.mu.Lock()
|
||
|
result := ci.mu.Result
|
||
|
ci.mu.Unlock()
|
||
|
if result == nil {
|
||
|
return driver.RowsAffected(0)
|
||
|
}
|
||
|
return result
|
||
|
}
|
||
|
|
||
|
func (ci *copyin) NumInput() int {
|
||
|
return -1
|
||
|
}
|
||
|
|
||
|
func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
|
||
|
return nil, ErrNotSupported
|
||
|
}
|
||
|
|
||
|
// Exec inserts values into the COPY stream. The insert is asynchronous
|
||
|
// and Exec can return errors from previous Exec calls to the same
|
||
|
// COPY stmt.
|
||
|
//
|
||
|
// You need to call Exec(nil) to sync the COPY stream and to get any
|
||
|
// errors from pending data, since Stmt.Close() doesn't return errors
|
||
|
// to the user.
|
||
|
func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
|
||
|
if ci.closed {
|
||
|
return nil, errCopyInClosed
|
||
|
}
|
||
|
|
||
|
if err := ci.getBad(); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
defer ci.cn.errRecover(&err)
|
||
|
|
||
|
if err := ci.err(); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
if len(v) == 0 {
|
||
|
if err := ci.Close(); err != nil {
|
||
|
return driver.RowsAffected(0), err
|
||
|
}
|
||
|
|
||
|
return ci.getResult(), nil
|
||
|
}
|
||
|
|
||
|
numValues := len(v)
|
||
|
for i, value := range v {
|
||
|
ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value)
|
||
|
if i < numValues-1 {
|
||
|
ci.buffer = append(ci.buffer, '\t')
|
||
|
}
|
||
|
}
|
||
|
|
||
|
ci.buffer = append(ci.buffer, '\n')
|
||
|
|
||
|
if len(ci.buffer) > ciBufferFlushSize {
|
||
|
ci.flush(ci.buffer)
|
||
|
// reset buffer, keep bytes for message identifier and length
|
||
|
ci.buffer = ci.buffer[:5]
|
||
|
}
|
||
|
|
||
|
return driver.RowsAffected(0), nil
|
||
|
}
|
||
|
|
||
|
// CopyData inserts a raw string into the COPY stream. The insert is
|
||
|
// asynchronous and CopyData can return errors from previous CopyData calls to
|
||
|
// the same COPY stmt.
|
||
|
//
|
||
|
// You need to call Exec(nil) to sync the COPY stream and to get any
|
||
|
// errors from pending data, since Stmt.Close() doesn't return errors
|
||
|
// to the user.
|
||
|
func (ci *copyin) CopyData(ctx context.Context, line string) (r driver.Result, err error) {
|
||
|
if ci.closed {
|
||
|
return nil, errCopyInClosed
|
||
|
}
|
||
|
|
||
|
if finish := ci.cn.watchCancel(ctx); finish != nil {
|
||
|
defer finish()
|
||
|
}
|
||
|
|
||
|
if err := ci.getBad(); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
defer ci.cn.errRecover(&err)
|
||
|
|
||
|
if err := ci.err(); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
ci.buffer = append(ci.buffer, []byte(line)...)
|
||
|
ci.buffer = append(ci.buffer, '\n')
|
||
|
|
||
|
if len(ci.buffer) > ciBufferFlushSize {
|
||
|
ci.flush(ci.buffer)
|
||
|
// reset buffer, keep bytes for message identifier and length
|
||
|
ci.buffer = ci.buffer[:5]
|
||
|
}
|
||
|
|
||
|
return driver.RowsAffected(0), nil
|
||
|
}
|
||
|
|
||
|
func (ci *copyin) Close() (err error) {
|
||
|
if ci.closed { // Don't do anything, we're already closed
|
||
|
return nil
|
||
|
}
|
||
|
ci.closed = true
|
||
|
|
||
|
if err := ci.getBad(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer ci.cn.errRecover(&err)
|
||
|
|
||
|
if len(ci.buffer) > 0 {
|
||
|
ci.flush(ci.buffer)
|
||
|
}
|
||
|
// Avoid touching the scratch buffer as resploop could be using it.
|
||
|
err = ci.cn.sendSimpleMessage('c')
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
<-ci.done
|
||
|
ci.cn.inCopy = false
|
||
|
|
||
|
if err := ci.err(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return nil
|
||
|
}
|