diff --git a/arrow/cdata/cdata.go b/arrow/cdata/cdata.go index 4085ed3d..314cc7a0 100644 --- a/arrow/cdata/cdata.go +++ b/arrow/cdata/cdata.go @@ -407,7 +407,9 @@ func (imp *cimporter) doImportChildren() error { st := imp.dt.(*arrow.StructType) for i, c := range children { imp.children[i].dt = st.Field(i).Type - imp.children[i].importChild(imp, c) + if err := imp.children[i].importChild(imp, c); err != nil { + return err + } } case arrow.RUN_END_ENCODED: // import run-ends and values st := imp.dt.(*arrow.RunEndEncodedType) @@ -428,13 +430,17 @@ func (imp *cimporter) doImportChildren() error { dt := imp.dt.(*arrow.DenseUnionType) for i, c := range children { imp.children[i].dt = dt.Fields()[i].Type - imp.children[i].importChild(imp, c) + if err := imp.children[i].importChild(imp, c); err != nil { + return err + } } case arrow.SPARSE_UNION: dt := imp.dt.(*arrow.SparseUnionType) for i, c := range children { imp.children[i].dt = dt.Fields()[i].Type - imp.children[i].importChild(imp, c) + if err := imp.children[i].importChild(imp, c); err != nil { + return err + } } } @@ -455,6 +461,10 @@ func (imp *cimporter) doImportArr(src *CArrowArray) error { imp.alloc = &importAllocator{arr: imp.arr} } + if err := imp.doImport(); err != nil { + return err + } + // we tie the releasing of the array to when the buffers are // cleaned up, so if there are no buffers that we've imported // such as for a null array or a nested array with no bitmap @@ -468,26 +478,19 @@ func (imp *cimporter) doImportArr(src *CArrowArray) error { } }() - return imp.doImport() + return nil } // import is called recursively as needed for importing an array and its children // in order to generate array.Data objects func (imp *cimporter) doImport() error { - // move the array from the src object passed in to the one referenced by - // this importer. That way we can set up a finalizer on the created - // arrow.ArrayData object so we clean up our Array's memory when garbage collected. - defer func(arr *CArrowArray) { - // this should only occur in the case of an error happening - // during import, at which point we need to clean up the - // ArrowArray struct we allocated. - if imp.data == nil { - C.free(unsafe.Pointer(arr)) - } - }(imp.arr) - // import any children if err := imp.doImportChildren(); err != nil { + for _, c := range imp.children { + if c.data != nil { + c.data.Release() + } + } return err } diff --git a/arrow/cdata/cdata_test.go b/arrow/cdata/cdata_test.go index 170a5151..4c3d29f0 100644 --- a/arrow/cdata/cdata_test.go +++ b/arrow/cdata/cdata_test.go @@ -669,8 +669,8 @@ func createTestDenseUnion() arrow.Array { func createTestUnionArr(mode arrow.UnionMode) arrow.Array { fields := []arrow.Field{ - arrow.Field{Name: "u0", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, - arrow.Field{Name: "u1", Type: arrow.PrimitiveTypes.Uint8, Nullable: true}, + {Name: "u0", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + {Name: "u1", Type: arrow.PrimitiveTypes.Uint8, Nullable: true}, } typeCodes := []arrow.UnionTypeCode{5, 10} bld := array.NewBuilder(memory.DefaultAllocator, arrow.UnionOf(mode, fields, typeCodes)).(array.UnionBuilder) @@ -785,6 +785,124 @@ func TestRecordBatch(t *testing.T) { assert.True(t, array.RecordEqual(rb, rec)) } +func TestImportStructWithInvalidSchema(t *testing.T) { + mem := mallocator.NewMallocator() + defer mem.AssertSize(t, 0) + + arr := createTestStructArr() + defer arr.Release() + + carr := createCArr(arr, mem) + defer freeTestMallocatorArr(carr, mem) + + sc := testStruct([]string{"+s", "c", "l"}, []string{"", "a", "b"}, []int64{0, flagIsNullable, flagIsNullable}) + defer freeMallocedSchemas(sc) + + top := (*[1]*CArrowSchema)(unsafe.Pointer(sc))[0] + _, err := ImportCRecordBatch(carr, top) + assert.Error(t, err) +} + +func TestImportDenseUnionWithInvalidSchema(t *testing.T) { + mem := mallocator.NewMallocator() + defer mem.AssertSize(t, 0) + + unionArr := createTestDenseUnion() + defer unionArr.Release() + + structBld := array.NewStructBuilder(memory.DefaultAllocator, arrow.StructOf( + arrow.Field{Name: "union_field", Type: unionArr.DataType(), Nullable: false}, + )) + defer structBld.Release() + + unionBld := structBld.FieldBuilder(0).(*array.DenseUnionBuilder) + structBld.Append(true) + du := unionArr.(*array.DenseUnion) + for i := 0; i < du.Len(); i++ { + unionBld.Append(du.TypeCode(i)) + if du.TypeCode(i) == 5 { + unionBld.Child(0).(*array.Int32Builder).Append(du.Field(0).(*array.Int32).Value(int(du.ValueOffset(i)))) + } else { + unionBld.Child(1).(*array.Uint8Builder).Append(du.Field(1).(*array.Uint8).Value(int(du.ValueOffset(i)))) + } + } + + structArr := structBld.NewArray() + defer structArr.Release() + + carr := createCArr(structArr, mem) + defer freeTestMallocatorArr(carr, mem) + + unionSc := testUnion([]string{"+ud:5,10", "i", "u"}, []string{"", "u0", "u1"}, []int64{0, flagIsNullable, flagIsNullable}) + + structSc := testStruct([]string{"+s", "+ud:5,10"}, []string{"", "union_field"}, []int64{0, 0}) + defer freeMallocedSchemas(structSc) + + structTop := (*[1]*CArrowSchema)(unsafe.Pointer(structSc))[0] + unionTop := (*[1]*CArrowSchema)(unsafe.Pointer(unionSc))[0] + + children := unsafe.Slice(structTop.children, 1) + oldChild := children[0] + children[0] = unionTop + + _, err := ImportCRecordBatch(carr, structTop) + + children[0] = oldChild + + assert.Error(t, err) +} + +func TestImportSPARSEUnionWithInvalidSchema(t *testing.T) { + mem := mallocator.NewMallocator() + defer mem.AssertSize(t, 0) + + unionArr := createTestSparseUnion() + defer unionArr.Release() + + structBld := array.NewStructBuilder(memory.DefaultAllocator, arrow.StructOf( + arrow.Field{Name: "union_field", Type: unionArr.DataType(), Nullable: false}, + )) + defer structBld.Release() + + unionBld := structBld.FieldBuilder(0).(*array.SparseUnionBuilder) + structBld.Append(true) + su := unionArr.(*array.SparseUnion) + for i := 0; i < su.Len(); i++ { + unionBld.Append(su.TypeCode(i)) + if su.TypeCode(i) == 5 { + unionBld.Child(0).(*array.Int32Builder).Append(su.Field(0).(*array.Int32).Value(i)) + unionBld.Child(1).(*array.Uint8Builder).AppendNull() + } else { + unionBld.Child(0).(*array.Int32Builder).AppendNull() + unionBld.Child(1).(*array.Uint8Builder).Append(su.Field(1).(*array.Uint8).Value(i)) + } + } + + structArr := structBld.NewArray() + defer structArr.Release() + + carr := createCArr(structArr, mem) + defer freeTestMallocatorArr(carr, mem) + + unionSc := testUnion([]string{"+us:5,10", "i", "u"}, []string{"", "u0", "u1"}, []int64{0, flagIsNullable, flagIsNullable}) + + structSc := testStruct([]string{"+s", "+us:5,10"}, []string{"", "union_field"}, []int64{0, 0}) + defer freeMallocedSchemas(structSc) + + structTop := (*[1]*CArrowSchema)(unsafe.Pointer(structSc))[0] + unionTop := (*[1]*CArrowSchema)(unsafe.Pointer(unionSc))[0] + + children := unsafe.Slice(structTop.children, 1) + oldChild := children[0] + children[0] = unionTop + + _, err := ImportCRecordBatch(carr, structTop) + + children[0] = oldChild + + assert.Error(t, err) +} + func TestRecordReaderStream(t *testing.T) { stream := arrayStreamTest() defer releaseStreamTest(stream) @@ -1006,17 +1124,21 @@ func (r *failingReader) Schema() *arrow.Schema { } return arrdata.Records["primitives"][0].Schema() } + func (r *failingReader) Next() bool { r.opCount -= 1 return r.opCount > 0 } + func (r *failingReader) RecordBatch() arrow.RecordBatch { arrdata.Records["primitives"][0].Retain() return arrdata.Records["primitives"][0] } + func (r *failingReader) Record() arrow.Record { return r.RecordBatch() } + func (r *failingReader) Err() error { if r.opCount == 0 { return fmt.Errorf("Expected error message")