Skip to content

Commit f625d55

Browse files
committed
check error support pgx
1 parent 5ccc9a7 commit f625d55

File tree

2 files changed

+73
-26
lines changed

2 files changed

+73
-26
lines changed

error.go

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package pgsql
22

33
import (
44
"errors"
5+
"reflect"
56
"regexp"
67

78
"github.com/lib/pq"
@@ -41,14 +42,13 @@ func IsErrorClass(err error, class string) bool {
4142
// IsUniqueViolation checks is error an unique_violation with given constraint,
4243
// constraint can be empty to ignore constraint name checks
4344
func IsUniqueViolation(err error, constraint ...string) bool {
44-
var pqErr *pq.Error
45-
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
46-
if len(constraint) == 0 {
47-
return true
48-
}
49-
return contains(constraint, extractConstraint(pqErr))
45+
if !IsErrorCode(err, "23505") { // for drivers that implement sqlState
46+
return false
5047
}
51-
return false
48+
if len(constraint) == 0 {
49+
return true
50+
}
51+
return contains(constraint, extractConstraint(err))
5252
}
5353

5454
// IsInvalidTextRepresentation checks is error an invalid_text_representation
@@ -61,16 +61,15 @@ func IsCharacterNotInRepertoire(err error) bool {
6161
return IsErrorCode(err, "22021")
6262
}
6363

64-
// IsForeignKeyViolation checks is error an foreign_key_violation
64+
// IsForeignKeyViolation checks is error a foreign_key_violation
6565
func IsForeignKeyViolation(err error, constraint ...string) bool {
66-
var pqErr *pq.Error
67-
if errors.As(err, &pqErr) && pqErr.Code == "23503" {
68-
if len(constraint) == 0 {
69-
return true
70-
}
71-
return contains(constraint, extractConstraint(pqErr))
66+
if !IsErrorCode(err, "23503") { // for drivers that implement sqlState
67+
return false
7268
}
73-
return false
69+
if len(constraint) == 0 {
70+
return true
71+
}
72+
return contains(constraint, extractConstraint(err))
7473
}
7574

7675
// IsQueryCanceled checks is error an query_canceled error
@@ -85,19 +84,39 @@ func IsSerializationFailure(err error) bool {
8584
return IsErrorCode(err, "40001")
8685
}
8786

88-
func extractConstraint(err *pq.Error) string {
89-
if err.Constraint != "" {
90-
return err.Constraint
91-
}
92-
if err.Message == "" {
93-
return ""
94-
}
95-
if s := extractCRDBKey(err.Message); s != "" {
96-
return s
87+
func extractConstraint(err error) string {
88+
{ // pq
89+
var pqErr *pq.Error
90+
if errors.As(err, &pqErr) {
91+
if pqErr.Constraint != "" {
92+
return pqErr.Constraint
93+
}
94+
if pqErr.Message == "" {
95+
return ""
96+
}
97+
if s := extractCRDBKey(pqErr.Message); s != "" {
98+
return s
99+
}
100+
if s := extractLastQuote(pqErr.Message); s != "" {
101+
return s
102+
}
103+
return ""
104+
}
97105
}
98-
if s := extractLastQuote(err.Message); s != "" {
99-
return s
106+
107+
{ // pgx
108+
v := reflect.ValueOf(err)
109+
if v.Kind() == reflect.Ptr {
110+
v = v.Elem()
111+
}
112+
if v.Kind() != reflect.Struct {
113+
return ""
114+
}
115+
if f := v.FieldByName("ConstraintName"); f.IsValid() {
116+
return f.String()
117+
}
100118
}
119+
101120
return ""
102121
}
103122

error_test.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,19 @@ import (
1111
"github.com/acoshift/pgsql"
1212
)
1313

14+
type pgxError struct {
15+
Code string
16+
ConstraintName string
17+
}
18+
19+
func (e *pgxError) Error() string {
20+
return "pgxError"
21+
}
22+
23+
func (e *pgxError) SQLState() string {
24+
return e.Code
25+
}
26+
1427
func TestIsUniqueViolation(t *testing.T) {
1528
t.Parallel()
1629

@@ -45,6 +58,21 @@ func TestIsUniqueViolation(t *testing.T) {
4558
Table: "users",
4659
Constraint: "users_email_key",
4760
}))
61+
62+
assert.True(t, pgsql.IsUniqueViolation(&pgxError{
63+
Code: "23505",
64+
ConstraintName: "users_email_key",
65+
}))
66+
67+
assert.True(t, pgsql.IsUniqueViolation(&pgxError{
68+
Code: "23505",
69+
ConstraintName: "users_email_key",
70+
}, "users_email_key"))
71+
72+
assert.False(t, pgsql.IsUniqueViolation(&pgxError{
73+
Code: "23505",
74+
ConstraintName: "users_email_key",
75+
}, "pkey"))
4876
}
4977

5078
func TestIsForeignKeyViolation(t *testing.T) {

0 commit comments

Comments
 (0)