copy_test.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
  1. package pq
  2. import (
  3. "bytes"
  4. "database/sql"
  5. "database/sql/driver"
  6. "net"
  7. "strings"
  8. "testing"
  9. )
  10. func TestCopyInStmt(t *testing.T) {
  11. stmt := CopyIn("table name")
  12. if stmt != `COPY "table name" () FROM STDIN` {
  13. t.Fatal(stmt)
  14. }
  15. stmt = CopyIn("table name", "column 1", "column 2")
  16. if stmt != `COPY "table name" ("column 1", "column 2") FROM STDIN` {
  17. t.Fatal(stmt)
  18. }
  19. stmt = CopyIn(`table " name """`, `co"lumn""`)
  20. if stmt != `COPY "table "" name """"""" ("co""lumn""""") FROM STDIN` {
  21. t.Fatal(stmt)
  22. }
  23. }
  24. func TestCopyInSchemaStmt(t *testing.T) {
  25. stmt := CopyInSchema("schema name", "table name")
  26. if stmt != `COPY "schema name"."table name" () FROM STDIN` {
  27. t.Fatal(stmt)
  28. }
  29. stmt = CopyInSchema("schema name", "table name", "column 1", "column 2")
  30. if stmt != `COPY "schema name"."table name" ("column 1", "column 2") FROM STDIN` {
  31. t.Fatal(stmt)
  32. }
  33. stmt = CopyInSchema(`schema " name """`, `table " name """`, `co"lumn""`)
  34. if stmt != `COPY "schema "" name """"""".`+
  35. `"table "" name """"""" ("co""lumn""""") FROM STDIN` {
  36. t.Fatal(stmt)
  37. }
  38. }
  39. func TestCopyInMultipleValues(t *testing.T) {
  40. db := openTestConn(t)
  41. defer db.Close()
  42. txn, err := db.Begin()
  43. if err != nil {
  44. t.Fatal(err)
  45. }
  46. defer txn.Rollback()
  47. _, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)")
  48. if err != nil {
  49. t.Fatal(err)
  50. }
  51. stmt, err := txn.Prepare(CopyIn("temp", "a", "b"))
  52. if err != nil {
  53. t.Fatal(err)
  54. }
  55. longString := strings.Repeat("#", 500)
  56. for i := 0; i < 500; i++ {
  57. _, err = stmt.Exec(int64(i), longString)
  58. if err != nil {
  59. t.Fatal(err)
  60. }
  61. }
  62. _, err = stmt.Exec()
  63. if err != nil {
  64. t.Fatal(err)
  65. }
  66. err = stmt.Close()
  67. if err != nil {
  68. t.Fatal(err)
  69. }
  70. var num int
  71. err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num)
  72. if err != nil {
  73. t.Fatal(err)
  74. }
  75. if num != 500 {
  76. t.Fatalf("expected 500 items, not %d", num)
  77. }
  78. }
  79. func TestCopyInRaiseStmtTrigger(t *testing.T) {
  80. db := openTestConn(t)
  81. defer db.Close()
  82. if getServerVersion(t, db) < 90000 {
  83. var exists int
  84. err := db.QueryRow("SELECT 1 FROM pg_language WHERE lanname = 'plpgsql'").Scan(&exists)
  85. if err == sql.ErrNoRows {
  86. t.Skip("language PL/PgSQL does not exist; skipping TestCopyInRaiseStmtTrigger")
  87. } else if err != nil {
  88. t.Fatal(err)
  89. }
  90. }
  91. txn, err := db.Begin()
  92. if err != nil {
  93. t.Fatal(err)
  94. }
  95. defer txn.Rollback()
  96. _, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)")
  97. if err != nil {
  98. t.Fatal(err)
  99. }
  100. _, err = txn.Exec(`
  101. CREATE OR REPLACE FUNCTION pg_temp.temptest()
  102. RETURNS trigger AS
  103. $BODY$ begin
  104. raise notice 'Hello world';
  105. return new;
  106. end $BODY$
  107. LANGUAGE plpgsql`)
  108. if err != nil {
  109. t.Fatal(err)
  110. }
  111. _, err = txn.Exec(`
  112. CREATE TRIGGER temptest_trigger
  113. BEFORE INSERT
  114. ON temp
  115. FOR EACH ROW
  116. EXECUTE PROCEDURE pg_temp.temptest()`)
  117. if err != nil {
  118. t.Fatal(err)
  119. }
  120. stmt, err := txn.Prepare(CopyIn("temp", "a", "b"))
  121. if err != nil {
  122. t.Fatal(err)
  123. }
  124. longString := strings.Repeat("#", 500)
  125. _, err = stmt.Exec(int64(1), longString)
  126. if err != nil {
  127. t.Fatal(err)
  128. }
  129. _, err = stmt.Exec()
  130. if err != nil {
  131. t.Fatal(err)
  132. }
  133. err = stmt.Close()
  134. if err != nil {
  135. t.Fatal(err)
  136. }
  137. var num int
  138. err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num)
  139. if err != nil {
  140. t.Fatal(err)
  141. }
  142. if num != 1 {
  143. t.Fatalf("expected 1 items, not %d", num)
  144. }
  145. }
  146. func TestCopyInTypes(t *testing.T) {
  147. db := openTestConn(t)
  148. defer db.Close()
  149. txn, err := db.Begin()
  150. if err != nil {
  151. t.Fatal(err)
  152. }
  153. defer txn.Rollback()
  154. _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER, text VARCHAR, blob BYTEA, nothing VARCHAR)")
  155. if err != nil {
  156. t.Fatal(err)
  157. }
  158. stmt, err := txn.Prepare(CopyIn("temp", "num", "text", "blob", "nothing"))
  159. if err != nil {
  160. t.Fatal(err)
  161. }
  162. _, err = stmt.Exec(int64(1234567890), "Héllö\n ☃!\r\t\\", []byte{0, 255, 9, 10, 13}, nil)
  163. if err != nil {
  164. t.Fatal(err)
  165. }
  166. _, err = stmt.Exec()
  167. if err != nil {
  168. t.Fatal(err)
  169. }
  170. err = stmt.Close()
  171. if err != nil {
  172. t.Fatal(err)
  173. }
  174. var num int
  175. var text string
  176. var blob []byte
  177. var nothing sql.NullString
  178. err = txn.QueryRow("SELECT * FROM temp").Scan(&num, &text, &blob, &nothing)
  179. if err != nil {
  180. t.Fatal(err)
  181. }
  182. if num != 1234567890 {
  183. t.Fatal("unexpected result", num)
  184. }
  185. if text != "Héllö\n ☃!\r\t\\" {
  186. t.Fatal("unexpected result", text)
  187. }
  188. if !bytes.Equal(blob, []byte{0, 255, 9, 10, 13}) {
  189. t.Fatal("unexpected result", blob)
  190. }
  191. if nothing.Valid {
  192. t.Fatal("unexpected result", nothing.String)
  193. }
  194. }
  195. func TestCopyInWrongType(t *testing.T) {
  196. db := openTestConn(t)
  197. defer db.Close()
  198. txn, err := db.Begin()
  199. if err != nil {
  200. t.Fatal(err)
  201. }
  202. defer txn.Rollback()
  203. _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)")
  204. if err != nil {
  205. t.Fatal(err)
  206. }
  207. stmt, err := txn.Prepare(CopyIn("temp", "num"))
  208. if err != nil {
  209. t.Fatal(err)
  210. }
  211. defer stmt.Close()
  212. _, err = stmt.Exec("Héllö\n ☃!\r\t\\")
  213. if err != nil {
  214. t.Fatal(err)
  215. }
  216. _, err = stmt.Exec()
  217. if err == nil {
  218. t.Fatal("expected error")
  219. }
  220. if pge := err.(*Error); pge.Code.Name() != "invalid_text_representation" {
  221. t.Fatalf("expected 'invalid input syntax for integer' error, got %s (%+v)", pge.Code.Name(), pge)
  222. }
  223. }
  224. func TestCopyOutsideOfTxnError(t *testing.T) {
  225. db := openTestConn(t)
  226. defer db.Close()
  227. _, err := db.Prepare(CopyIn("temp", "num"))
  228. if err == nil {
  229. t.Fatal("COPY outside of transaction did not return an error")
  230. }
  231. if err != errCopyNotSupportedOutsideTxn {
  232. t.Fatalf("expected %s, got %s", err, err.Error())
  233. }
  234. }
  235. func TestCopyInBinaryError(t *testing.T) {
  236. db := openTestConn(t)
  237. defer db.Close()
  238. txn, err := db.Begin()
  239. if err != nil {
  240. t.Fatal(err)
  241. }
  242. defer txn.Rollback()
  243. _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)")
  244. if err != nil {
  245. t.Fatal(err)
  246. }
  247. _, err = txn.Prepare("COPY temp (num) FROM STDIN WITH binary")
  248. if err != errBinaryCopyNotSupported {
  249. t.Fatalf("expected %s, got %+v", errBinaryCopyNotSupported, err)
  250. }
  251. // check that the protocol is in a valid state
  252. err = txn.Rollback()
  253. if err != nil {
  254. t.Fatal(err)
  255. }
  256. }
  257. func TestCopyFromError(t *testing.T) {
  258. db := openTestConn(t)
  259. defer db.Close()
  260. txn, err := db.Begin()
  261. if err != nil {
  262. t.Fatal(err)
  263. }
  264. defer txn.Rollback()
  265. _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)")
  266. if err != nil {
  267. t.Fatal(err)
  268. }
  269. _, err = txn.Prepare("COPY temp (num) TO STDOUT")
  270. if err != errCopyToNotSupported {
  271. t.Fatalf("expected %s, got %+v", errCopyToNotSupported, err)
  272. }
  273. // check that the protocol is in a valid state
  274. err = txn.Rollback()
  275. if err != nil {
  276. t.Fatal(err)
  277. }
  278. }
  279. func TestCopySyntaxError(t *testing.T) {
  280. db := openTestConn(t)
  281. defer db.Close()
  282. txn, err := db.Begin()
  283. if err != nil {
  284. t.Fatal(err)
  285. }
  286. defer txn.Rollback()
  287. _, err = txn.Prepare("COPY ")
  288. if err == nil {
  289. t.Fatal("expected error")
  290. }
  291. if pge := err.(*Error); pge.Code.Name() != "syntax_error" {
  292. t.Fatalf("expected syntax error, got %s (%+v)", pge.Code.Name(), pge)
  293. }
  294. // check that the protocol is in a valid state
  295. err = txn.Rollback()
  296. if err != nil {
  297. t.Fatal(err)
  298. }
  299. }
  300. // Tests for connection errors in copyin.resploop()
  301. func TestCopyRespLoopConnectionError(t *testing.T) {
  302. db := openTestConn(t)
  303. defer db.Close()
  304. txn, err := db.Begin()
  305. if err != nil {
  306. t.Fatal(err)
  307. }
  308. defer txn.Rollback()
  309. var pid int
  310. err = txn.QueryRow("SELECT pg_backend_pid()").Scan(&pid)
  311. if err != nil {
  312. t.Fatal(err)
  313. }
  314. _, err = txn.Exec("CREATE TEMP TABLE temp (a int)")
  315. if err != nil {
  316. t.Fatal(err)
  317. }
  318. stmt, err := txn.Prepare(CopyIn("temp", "a"))
  319. if err != nil {
  320. t.Fatal(err)
  321. }
  322. defer stmt.Close()
  323. _, err = db.Exec("SELECT pg_terminate_backend($1)", pid)
  324. if err != nil {
  325. t.Fatal(err)
  326. }
  327. if getServerVersion(t, db) < 90500 {
  328. // We have to try and send something over, since postgres before
  329. // version 9.5 won't process SIGTERMs while it's waiting for
  330. // CopyData/CopyEnd messages; see tcop/postgres.c.
  331. _, err = stmt.Exec(1)
  332. if err != nil {
  333. t.Fatal(err)
  334. }
  335. }
  336. _, err = stmt.Exec()
  337. if err == nil {
  338. t.Fatalf("expected error")
  339. }
  340. switch pge := err.(type) {
  341. case *Error:
  342. if pge.Code.Name() != "admin_shutdown" {
  343. t.Fatalf("expected admin_shutdown, got %s", pge.Code.Name())
  344. }
  345. case *net.OpError:
  346. // ignore
  347. default:
  348. if err == driver.ErrBadConn {
  349. // likely an EPIPE
  350. } else {
  351. t.Fatalf("unexpected error, got %+#v", err)
  352. }
  353. }
  354. _ = stmt.Close()
  355. }
  356. func BenchmarkCopyIn(b *testing.B) {
  357. db := openTestConn(b)
  358. defer db.Close()
  359. txn, err := db.Begin()
  360. if err != nil {
  361. b.Fatal(err)
  362. }
  363. defer txn.Rollback()
  364. _, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)")
  365. if err != nil {
  366. b.Fatal(err)
  367. }
  368. stmt, err := txn.Prepare(CopyIn("temp", "a", "b"))
  369. if err != nil {
  370. b.Fatal(err)
  371. }
  372. for i := 0; i < b.N; i++ {
  373. _, err = stmt.Exec(int64(i), "hello world!")
  374. if err != nil {
  375. b.Fatal(err)
  376. }
  377. }
  378. _, err = stmt.Exec()
  379. if err != nil {
  380. b.Fatal(err)
  381. }
  382. err = stmt.Close()
  383. if err != nil {
  384. b.Fatal(err)
  385. }
  386. var num int
  387. err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num)
  388. if err != nil {
  389. b.Fatal(err)
  390. }
  391. if num != b.N {
  392. b.Fatalf("expected %d items, not %d", b.N, num)
  393. }
  394. }