Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 19 additions & 16 deletions arrow/cdata/cdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,9 @@ func (imp *cimporter) doImportChildren() error {
st := imp.dt.(*arrow.StructType)
for i, c := range children {
imp.children[i].dt = st.Field(i).Type
imp.children[i].importChild(imp, c)
if err := imp.children[i].importChild(imp, c); err != nil {
return err
}
}
case arrow.RUN_END_ENCODED: // import run-ends and values
st := imp.dt.(*arrow.RunEndEncodedType)
Expand All @@ -428,13 +430,17 @@ func (imp *cimporter) doImportChildren() error {
dt := imp.dt.(*arrow.DenseUnionType)
for i, c := range children {
imp.children[i].dt = dt.Fields()[i].Type
imp.children[i].importChild(imp, c)
if err := imp.children[i].importChild(imp, c); err != nil {
return err
}
}
case arrow.SPARSE_UNION:
dt := imp.dt.(*arrow.SparseUnionType)
for i, c := range children {
imp.children[i].dt = dt.Fields()[i].Type
imp.children[i].importChild(imp, c)
if err := imp.children[i].importChild(imp, c); err != nil {
return err
}
}
}

Expand All @@ -455,6 +461,10 @@ func (imp *cimporter) doImportArr(src *CArrowArray) error {
imp.alloc = &importAllocator{arr: imp.arr}
}

if err := imp.doImport(); err != nil {
return err
}

// we tie the releasing of the array to when the buffers are
// cleaned up, so if there are no buffers that we've imported
// such as for a null array or a nested array with no bitmap
Expand All @@ -468,26 +478,19 @@ func (imp *cimporter) doImportArr(src *CArrowArray) error {
}
}()

return imp.doImport()
return nil
}

// import is called recursively as needed for importing an array and its children
// in order to generate array.Data objects
func (imp *cimporter) doImport() error {
// move the array from the src object passed in to the one referenced by
// this importer. That way we can set up a finalizer on the created
// arrow.ArrayData object so we clean up our Array's memory when garbage collected.
defer func(arr *CArrowArray) {
// this should only occur in the case of an error happening
// during import, at which point we need to clean up the
// ArrowArray struct we allocated.
if imp.data == nil {
C.free(unsafe.Pointer(arr))
}
}(imp.arr)

// import any children
if err := imp.doImportChildren(); err != nil {
for _, c := range imp.children {
if c.data != nil {
c.data.Release()
}
}
return err
}

Expand Down
126 changes: 124 additions & 2 deletions arrow/cdata/cdata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -669,8 +669,8 @@ func createTestDenseUnion() arrow.Array {

func createTestUnionArr(mode arrow.UnionMode) arrow.Array {
fields := []arrow.Field{
arrow.Field{Name: "u0", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
arrow.Field{Name: "u1", Type: arrow.PrimitiveTypes.Uint8, Nullable: true},
{Name: "u0", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
{Name: "u1", Type: arrow.PrimitiveTypes.Uint8, Nullable: true},
}
typeCodes := []arrow.UnionTypeCode{5, 10}
bld := array.NewBuilder(memory.DefaultAllocator, arrow.UnionOf(mode, fields, typeCodes)).(array.UnionBuilder)
Expand Down Expand Up @@ -785,6 +785,124 @@ func TestRecordBatch(t *testing.T) {
assert.True(t, array.RecordEqual(rb, rec))
}

// TestImportStructWithInvalidSchema exports the array produced by
// createTestStructArr to a C array and then imports it through
// ImportCRecordBatch together with a hand-built struct schema
// (formats "+s", "c", "l"). The import is expected to return an error.
func TestImportStructWithInvalidSchema(t *testing.T) {
	// Deferred first so the size assertion runs after every other cleanup.
	alloc := mallocator.NewMallocator()
	defer alloc.AssertSize(t, 0)

	structArr := createTestStructArr()
	defer structArr.Release()

	cArr := createCArr(structArr, alloc)
	defer freeTestMallocatorArr(cArr, alloc)

	var (
		formats = []string{"+s", "c", "l"}
		names   = []string{"", "a", "b"}
		flags   = []int64{0, flagIsNullable, flagIsNullable}
	)
	schemas := testStruct(formats, names, flags)
	defer freeMallocedSchemas(schemas)

	// The first entry of the returned schema list is the top-level struct schema.
	root := (*[1]*CArrowSchema)(unsafe.Pointer(schemas))[0]

	_, importErr := ImportCRecordBatch(cArr, root)
	assert.Error(t, importErr)
}

// TestImportDenseUnionWithInvalidSchema wraps a dense union array in a
// single-field struct, exports that struct to a C array, and asserts that
// ImportCRecordBatch returns an error when it is given a struct schema whose
// only child has been temporarily replaced by a separately built union schema.
func TestImportDenseUnionWithInvalidSchema(t *testing.T) {
mem := mallocator.NewMallocator()
// Deferred first, so the size assertion runs after all other deferred cleanups.
defer mem.AssertSize(t, 0)

unionArr := createTestDenseUnion()
defer unionArr.Release()

// Struct type with one non-nullable field carrying the union's data type.
structBld := array.NewStructBuilder(memory.DefaultAllocator, arrow.StructOf(
arrow.Field{Name: "union_field", Type: unionArr.DataType(), Nullable: false},
))
defer structBld.Release()

unionBld := structBld.FieldBuilder(0).(*array.DenseUnionBuilder)
// NOTE(review): the struct builder receives a single Append while the union
// child below receives du.Len() values — confirm this length difference is intended.
structBld.Append(true)
du := unionArr.(*array.DenseUnion)
// Copy every element of the source union into the union child builder.
for i := 0; i < du.Len(); i++ {
unionBld.Append(du.TypeCode(i))
// Type code 5 is routed to the int32 child (index 0); anything else goes to
// the uint8 child (index 1). The source value is looked up via ValueOffset(i).
if du.TypeCode(i) == 5 {
unionBld.Child(0).(*array.Int32Builder).Append(du.Field(0).(*array.Int32).Value(int(du.ValueOffset(i))))
} else {
unionBld.Child(1).(*array.Uint8Builder).Append(du.Field(1).(*array.Uint8).Value(int(du.ValueOffset(i))))
}
}

structArr := structBld.NewArray()
defer structArr.Release()

carr := createCArr(structArr, mem)
defer freeTestMallocatorArr(carr, mem)

// NOTE(review): unionSc is never passed to freeMallocedSchemas in this
// function — confirm who owns and frees it.
unionSc := testUnion([]string{"+ud:5,10", "i", "u"}, []string{"", "u0", "u1"}, []int64{0, flagIsNullable, flagIsNullable})

structSc := testStruct([]string{"+s", "+ud:5,10"}, []string{"", "union_field"}, []int64{0, 0})
defer freeMallocedSchemas(structSc)

// Take the first schema pointer out of each returned list.
structTop := (*[1]*CArrowSchema)(unsafe.Pointer(structSc))[0]
unionTop := (*[1]*CArrowSchema)(unsafe.Pointer(unionSc))[0]

// Swap the struct schema's only child for the union schema for the duration
// of the import call.
children := unsafe.Slice(structTop.children, 1)
oldChild := children[0]
children[0] = unionTop

_, err := ImportCRecordBatch(carr, structTop)

// Put the original child back before the deferred freeMallocedSchemas(structSc) runs.
children[0] = oldChild

assert.Error(t, err)
}

// TestImportSPARSEUnionWithInvalidSchema wraps a sparse union array in a
// single-field struct, exports that struct to a C array, and asserts that
// ImportCRecordBatch returns an error when it is given a struct schema whose
// only child has been temporarily replaced by a separately built union schema.
func TestImportSPARSEUnionWithInvalidSchema(t *testing.T) {
mem := mallocator.NewMallocator()
// Deferred first, so the size assertion runs after all other deferred cleanups.
defer mem.AssertSize(t, 0)

unionArr := createTestSparseUnion()
defer unionArr.Release()

// Struct type with one non-nullable field carrying the union's data type.
structBld := array.NewStructBuilder(memory.DefaultAllocator, arrow.StructOf(
arrow.Field{Name: "union_field", Type: unionArr.DataType(), Nullable: false},
))
defer structBld.Release()

unionBld := structBld.FieldBuilder(0).(*array.SparseUnionBuilder)
// NOTE(review): the struct builder receives a single Append while the union
// child below receives su.Len() values — confirm this length difference is intended.
structBld.Append(true)
su := unionArr.(*array.SparseUnion)
// Copy every element of the source union into the union child builder.
for i := 0; i < su.Len(); i++ {
unionBld.Append(su.TypeCode(i))
// Both children are appended to on every iteration: the child selected by
// the type code gets the source value at index i, the other one gets a null.
if su.TypeCode(i) == 5 {
unionBld.Child(0).(*array.Int32Builder).Append(su.Field(0).(*array.Int32).Value(i))
unionBld.Child(1).(*array.Uint8Builder).AppendNull()
} else {
unionBld.Child(0).(*array.Int32Builder).AppendNull()
unionBld.Child(1).(*array.Uint8Builder).Append(su.Field(1).(*array.Uint8).Value(i))
}
}

structArr := structBld.NewArray()
defer structArr.Release()

carr := createCArr(structArr, mem)
defer freeTestMallocatorArr(carr, mem)

// NOTE(review): unionSc is never passed to freeMallocedSchemas in this
// function — confirm who owns and frees it.
unionSc := testUnion([]string{"+us:5,10", "i", "u"}, []string{"", "u0", "u1"}, []int64{0, flagIsNullable, flagIsNullable})

structSc := testStruct([]string{"+s", "+us:5,10"}, []string{"", "union_field"}, []int64{0, 0})
defer freeMallocedSchemas(structSc)

// Take the first schema pointer out of each returned list.
structTop := (*[1]*CArrowSchema)(unsafe.Pointer(structSc))[0]
unionTop := (*[1]*CArrowSchema)(unsafe.Pointer(unionSc))[0]

// Swap the struct schema's only child for the union schema for the duration
// of the import call.
children := unsafe.Slice(structTop.children, 1)
oldChild := children[0]
children[0] = unionTop

_, err := ImportCRecordBatch(carr, structTop)

// Put the original child back before the deferred freeMallocedSchemas(structSc) runs.
children[0] = oldChild

assert.Error(t, err)
}

func TestRecordReaderStream(t *testing.T) {
stream := arrayStreamTest()
defer releaseStreamTest(stream)
Expand Down Expand Up @@ -1006,17 +1124,21 @@ func (r *failingReader) Schema() *arrow.Schema {
}
return arrdata.Records["primitives"][0].Schema()
}

// Next decrements the reader's opCount and reports whether the counter is
// still positive afterwards.
func (r *failingReader) Next() bool {
	r.opCount--
	remaining := r.opCount
	return remaining > 0
}

// RecordBatch returns the first entry of arrdata.Records["primitives"],
// calling Retain on it before handing it back.
func (r *failingReader) RecordBatch() arrow.RecordBatch {
	first := arrdata.Records["primitives"][0]
	first.Retain()
	return first
}

// Record returns whatever RecordBatch returns; it simply forwards the call.
func (r *failingReader) Record() arrow.Record {
	batch := r.RecordBatch()
	return batch
}

func (r *failingReader) Err() error {
if r.opCount == 0 {
return fmt.Errorf("Expected error message")
Expand Down
Loading