diff --git a/CHANGES.md b/CHANGES.md index 1b72f92d278c..d342e9c5f572 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -74,6 +74,7 @@ * (Python) Added exception chaining to preserve error context in CloudSQLEnrichmentHandler, processes utilities, and core transforms ([#37422](https://github.com/apache/beam/issues/37422)). * (Python) Added a pipeline option `--experiments=pip_no_build_isolation` to disable build isolation when installing dependencies in the runtime environment ([#37331](https://github.com/apache/beam/issues/37331)). +* (Go) Added OrderedListState support to the Go SDK stateful DoFn API ([#37629](https://github.com/apache/beam/pull/37629)). * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). ## Breaking Changes diff --git a/sdks/go/examples/ordered_list_state/ordered_list_state.go b/sdks/go/examples/ordered_list_state/ordered_list_state.go new file mode 100644 index 000000000000..5ff206859927 --- /dev/null +++ b/sdks/go/examples/ordered_list_state/ordered_list_state.go @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ordered_list_state is a toy pipeline demonstrating the use of OrderedListState. 
+// It creates keyed elements with timestamps, stores them in ordered list state, +// and reads back sub-ranges to emit summaries per key. +package main + +import ( + "context" + "flag" + "fmt" + + "github.com/apache/beam/sdks/v2/go/pkg/beam" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/state" + "github.com/apache/beam/sdks/v2/go/pkg/beam/log" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" + "github.com/apache/beam/sdks/v2/go/pkg/beam/x/beamx" + "github.com/apache/beam/sdks/v2/go/pkg/beam/x/debug" +) + +// eventLogFn accumulates timestamped events per key using OrderedListState +// and emits a summary of events seen so far. +type eventLogFn struct { + Events state.OrderedList[string] +} + +func (fn *eventLogFn) ProcessElement(p state.Provider, key string, ts int64, emit func(string)) error { + // Store an event using the input value as the sort key. + event := fmt.Sprintf("event@%d", ts) + fn.Events.Add(p, ts, event) + + // Read all events accumulated so far for this key. + entries, ok, err := fn.Events.Read(p) + if err != nil { + return err + } + if ok { + latest := entries[len(entries)-1] + emit(fmt.Sprintf("key=%s count=%d latest=%s (sort_key=%d)", key, len(entries), latest.Value, latest.SortKey)) + } + + return nil +} + +func init() { + register.DoFn4x1[state.Provider, string, int64, func(string), error](&eventLogFn{}) + register.Emitter1[string]() + register.Function1x2(toKeyed) +} + +// toKeyed maps an integer to a KV pair of (key, timestamp). +func toKeyed(i int) (string, int64) { + return fmt.Sprintf("user-%d", i%3), int64(i * 1000) +} + +func main() { + flag.Parse() + beam.Init() + + ctx := context.Background() + + p, s := beam.NewPipelineWithRoot() + + // Create a small set of input elements. + impulse := beam.CreateList(s, []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) + + // Key and timestamp each element. + keyed := beam.ParDo(s, toKeyed, impulse) + + // Apply the stateful DoFn with OrderedListState. 
+ summaries := beam.ParDo(s, &eventLogFn{ + Events: state.MakeOrderedListState[string]("events"), + }, keyed) + + debug.Print(s, summaries) + + if err := beamx.Run(ctx, p); err != nil { + log.Exitf(ctx, "Failed to execute job: %v", err) + } +} diff --git a/sdks/go/examples/snippets/04transforms.go b/sdks/go/examples/snippets/04transforms.go index a9c28369198d..210d1d1855c6 100644 --- a/sdks/go/examples/snippets/04transforms.go +++ b/sdks/go/examples/snippets/04transforms.go @@ -743,6 +743,38 @@ func combineState(s beam.Scope, input beam.PCollection) beam.PCollection { return combined } +// [START ordered_list_state] + +// orderedListStateFn tracks timestamped events per key and reads a sub-range. +type orderedListStateFn struct { + Events state.OrderedList[string] +} + +func (s *orderedListStateFn) ProcessElement(p state.Provider, key string, event string, emit func(string)) error { + // Add the event with the current timestamp as the sort key. + now := time.Now().UnixMilli() + s.Events.Add(p, now, event) + + // Read a sub-range of events (e.g. the last hour). + oneHourAgo := now - 3600000 + entries, ok, err := s.Events.ReadRange(p, oneHourAgo, now+1) + if err != nil { + return err + } + if ok { + for _, e := range entries { + emit(fmt.Sprintf("%s@%d", e.Value, e.SortKey)) + } + } + + // Clear events older than one hour. 
+ s.Events.ClearRange(p, 0, oneHourAgo) + + return nil +} + +// [END ordered_list_state] + // [START event_time_timer] type eventTimerDoFn struct { diff --git a/sdks/go/pkg/beam/core/graph/fn.go b/sdks/go/pkg/beam/core/graph/fn.go index 80f647abf5e6..64a2268dd1b0 100644 --- a/sdks/go/pkg/beam/core/graph/fn.go +++ b/sdks/go/pkg/beam/core/graph/fn.go @@ -1368,10 +1368,10 @@ func validateState(fn *DoFn, numIn mainInputs) error { "unique per DoFn", k, orig, s) } t := s.StateType() - if t != state.TypeValue && t != state.TypeBag && t != state.TypeCombining && t != state.TypeSet && t != state.TypeMap { + if t != state.TypeValue && t != state.TypeBag && t != state.TypeCombining && t != state.TypeSet && t != state.TypeMap && t != state.TypeOrderedList { err := errors.Errorf("Unrecognized state type %v for state %v", t, s) return errors.SetTopLevelMsgf(err, "Unrecognized state type %v for state %v. Currently the only supported state"+ - "types are state.Value, state.Combining, state.Bag, state.Set, and state.Map", t, s) + " types are state.Value, state.Combining, state.Bag, state.Set, state.Map, and state.OrderedList", t, s) } stateKeys[k] = s } diff --git a/sdks/go/pkg/beam/core/runtime/exec/data.go b/sdks/go/pkg/beam/core/runtime/exec/data.go index 71954819a748..88d4668e6653 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/data.go +++ b/sdks/go/pkg/beam/core/runtime/exec/data.go @@ -89,6 +89,12 @@ type StateReader interface { OpenMultimapKeysUserStateReader(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte) (io.ReadCloser, error) // OpenMultimapKeysUserStateClearer opens a byte stream for clearing all keys of user multimap state. OpenMultimapKeysUserStateClearer(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte) (io.Writer, error) + // OpenOrderedListUserStateReader opens a byte stream for reading user ordered list state in the range [start, end). 
+ OpenOrderedListUserStateReader(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte, start, end int64) (io.ReadCloser, error) + // OpenOrderedListUserStateAppender opens a byte stream for appending user ordered list state. + OpenOrderedListUserStateAppender(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte) (io.Writer, error) + // OpenOrderedListUserStateClearer opens a byte stream for clearing user ordered list state in the range [start, end). + OpenOrderedListUserStateClearer(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte, start, end int64) (io.Writer, error) // GetSideInputCache returns the SideInputCache being used at the harness level. GetSideInputCache() SideCache } diff --git a/sdks/go/pkg/beam/core/runtime/exec/sideinput_test.go b/sdks/go/pkg/beam/core/runtime/exec/sideinput_test.go index ad329006ccdf..ab64a7b8bfa2 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/sideinput_test.go +++ b/sdks/go/pkg/beam/core/runtime/exec/sideinput_test.go @@ -173,6 +173,21 @@ func (t *testStateReader) OpenMultimapKeysUserStateClearer(ctx context.Context, return nil, nil } +// OpenOrderedListUserStateReader for the testStateReader is a no-op. +func (t *testStateReader) OpenOrderedListUserStateReader(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte, start, end int64) (io.ReadCloser, error) { + return nil, nil +} + +// OpenOrderedListUserStateAppender for the testStateReader is a no-op. +func (t *testStateReader) OpenOrderedListUserStateAppender(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte) (io.Writer, error) { + return nil, nil +} + +// OpenOrderedListUserStateClearer for the testStateReader is a no-op. 
+func (t *testStateReader) OpenOrderedListUserStateClearer(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte, start, end int64) (io.Writer, error) { + return nil, nil +} + func (t *testStateReader) GetSideInputCache() SideCache { return &testSideCache{} } diff --git a/sdks/go/pkg/beam/core/runtime/exec/translate.go b/sdks/go/pkg/beam/core/runtime/exec/translate.go index b74ede228fd9..13b40ea0d1c6 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/translate.go +++ b/sdks/go/pkg/beam/core/runtime/exec/translate.go @@ -563,6 +563,8 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) { kcID = ms.KeyCoderId } else if ss := spec.GetSetSpec(); ss != nil { kcID = ss.ElementCoderId + } else if ols := spec.GetOrderedListSpec(); ols != nil { + cID = ols.ElementCoderId } else { return nil, errors.Errorf("Unrecognized state type %v", spec) } diff --git a/sdks/go/pkg/beam/core/runtime/exec/userstate.go b/sdks/go/pkg/beam/core/runtime/exec/userstate.go index ea723b18e3a7..75c92538f2df 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/userstate.go +++ b/sdks/go/pkg/beam/core/runtime/exec/userstate.go @@ -20,12 +20,14 @@ import ( "context" "fmt" "io" + "math" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/state" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" + "google.golang.org/protobuf/encoding/protowire" ) type stateProvider struct { @@ -41,6 +43,7 @@ type stateProvider struct { blindBagWriteCountsByKey map[string]int // Tracks blind writes to bags before a read. 
initialMapValuesByKey map[string]map[string]any initialMapKeysByKey map[string][]any + initialOrderedListByKey map[string][]any readersByKey map[string]io.ReadCloser appendersByKey map[string]io.Writer clearersByKey map[string]io.Writer @@ -466,6 +469,152 @@ func (s *stateProvider) getMultiMapKeyReader(userStateID string) (io.ReadCloser, return s.readersByKey[userStateID], nil } +// ReadOrderedListState reads an ordered list state from the State API. +// It fetches the full range on first access and caches the result. +func (s *stateProvider) ReadOrderedListState(userStateID string) ([]any, []state.Transaction, error) { + initialValue, ok := s.initialOrderedListByKey[userStateID] + if !ok { + initialValue = []any{} + rw, err := s.getOrderedListReader(userStateID, math.MinInt64, math.MaxInt64) + if err != nil { + return nil, nil, err + } + for { + entry, err := decodeOrderedListEntry(rw, s.codersByKey[userStateID]) + if err == io.EOF { + break + } + if err != nil { + return nil, nil, err + } + initialValue = append(initialValue, entry) + } + s.initialOrderedListByKey[userStateID] = initialValue + } + + transactions, ok := s.transactionsByKey[userStateID] + if !ok { + transactions = []state.Transaction{} + } + + return initialValue, transactions, nil +} + +// WriteOrderedListState writes a single entry to the ordered list state. +// The wire format is: varint(sortKey) || coder_encoded(value). 
+func (s *stateProvider) WriteOrderedListState(val state.Transaction) error { + ap, err := s.getOrderedListAppender(val.Key) + if err != nil { + return err + } + + sortKey := val.MapKey.(int64) + if err := encodeOrderedListEntry(sortKey, val.Val, ap, s.codersByKey[val.Key]); err != nil { + return err + } + + if transactions, ok := s.transactionsByKey[val.Key]; ok { + s.transactionsByKey[val.Key] = append(transactions, val) + } else { + s.transactionsByKey[val.Key] = []state.Transaction{val} + } + + return nil +} + +// ClearOrderedListState clears entries in a range from the ordered list state. +func (s *stateProvider) ClearOrderedListState(val state.Transaction) error { + r := val.MapKey.([2]int64) + cl, err := s.getOrderedListClearer(val.Key, r[0], r[1]) + if err != nil { + return err + } + _, err = cl.Write([]byte{}) + if err != nil { + return err + } + + if transactions, ok := s.transactionsByKey[val.Key]; ok { + s.transactionsByKey[val.Key] = append(transactions, val) + } else { + s.transactionsByKey[val.Key] = []state.Transaction{val} + } + + return nil +} + +func (s *stateProvider) getOrderedListReader(userStateID string, start, end int64) (io.ReadCloser, error) { + r, err := s.sr.OpenOrderedListUserStateReader(s.ctx, s.SID, userStateID, s.elementKey, s.window, start, end) + if err != nil { + return nil, err + } + return r, nil +} + +func (s *stateProvider) getOrderedListAppender(userStateID string) (io.Writer, error) { + w, err := s.sr.OpenOrderedListUserStateAppender(s.ctx, s.SID, userStateID, s.elementKey, s.window) + if err != nil { + return nil, err + } + return w, nil +} + +func (s *stateProvider) getOrderedListClearer(userStateID string, start, end int64) (io.Writer, error) { + w, err := s.sr.OpenOrderedListUserStateClearer(s.ctx, s.SID, userStateID, s.elementKey, s.window, start, end) + if err != nil { + return nil, err + } + return w, nil +} + +// encodeOrderedListEntry writes varint(uint64(sortKey)) || coder_encoded(value) to w. 
+// The entire entry is buffered before writing so that each w.Write call +// delivers a complete entry (important when w is a stateKeyWriter that +// sends each Write as a separate gRPC Append request). +func encodeOrderedListEntry(sortKey int64, val any, w io.Writer, c *coder.Coder) error { + var buf bytes.Buffer + b := protowire.AppendVarint(nil, uint64(sortKey)) + buf.Write(b) + fv := FullValue{Elm: val} + enc := MakeElementEncoder(coder.SkipW(c)) + if err := enc.Encode(&fv, &buf); err != nil { + return err + } + _, err := w.Write(buf.Bytes()) + return err +} + +// decodeOrderedListEntry reads varint(sortKey) || coder_encoded(value) from r. +func decodeOrderedListEntry(r io.Reader, c *coder.Coder) (state.OrderedListEntry, error) { + // Read varint byte-by-byte. + var buf [10]byte // max varint size + var n int + for n = 0; n < len(buf); n++ { + _, err := r.Read(buf[n : n+1]) + if err != nil { + if n == 0 { + return state.OrderedListEntry{}, err + } + return state.OrderedListEntry{}, fmt.Errorf("unexpected error reading varint: %w", err) + } + if buf[n]&0x80 == 0 { + n++ + break + } + } + sortKey, consumed := protowire.ConsumeVarint(buf[:n]) + if consumed < 0 { + return state.OrderedListEntry{}, fmt.Errorf("invalid varint in ordered list entry") + } + + dec := MakeElementDecoder(coder.SkipW(c)) + fv, err := dec.Decode(r) + if err != nil { + return state.OrderedListEntry{}, err + } + return state.OrderedListEntry{SortKey: int64(sortKey), Value: fv.Elm}, nil +} + func (s *stateProvider) encodeKey(userStateID string, key any) ([]byte, error) { fv := FullValue{Elm: key} enc := MakeElementEncoder(coder.SkipW(s.keyCodersByID[userStateID])) @@ -533,6 +682,7 @@ func (s *userStateAdapter) NewStateProvider(ctx context.Context, reader StateRea blindBagWriteCountsByKey: make(map[string]int), initialMapValuesByKey: make(map[string]map[string]any), initialMapKeysByKey: make(map[string][]any), + initialOrderedListByKey: make(map[string][]any), readersByKey: 
make(map[string]io.ReadCloser), appendersByKey: make(map[string]io.Writer), clearersByKey: make(map[string]io.Writer), diff --git a/sdks/go/pkg/beam/core/runtime/exec/userstate_test.go b/sdks/go/pkg/beam/core/runtime/exec/userstate_test.go index e25e4019562c..d463b089a909 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/userstate_test.go +++ b/sdks/go/pkg/beam/core/runtime/exec/userstate_test.go @@ -77,18 +77,19 @@ func TestReadValueState(t *testing.T) { func buildStateProvider() stateProvider { return stateProvider{ - ctx: context.Background(), - sr: &testStateReader{}, - elementKey: []byte{1}, - window: []byte{1}, - transactionsByKey: make(map[string][]state.Transaction), - initialValueByKey: make(map[string]any), - initialBagByKey: make(map[string][]any), - readersByKey: make(map[string]io.ReadCloser), - appendersByKey: make(map[string]io.Writer), - clearersByKey: make(map[string]io.Writer), - combineFnsByKey: make(map[string]*graph.CombineFn), // Each test can specify coders as needed - codersByKey: make(map[string]*coder.Coder), // Each test can specify coders as needed + ctx: context.Background(), + sr: &testStateReader{}, + elementKey: []byte{1}, + window: []byte{1}, + transactionsByKey: make(map[string][]state.Transaction), + initialValueByKey: make(map[string]any), + initialBagByKey: make(map[string][]any), + initialOrderedListByKey: make(map[string][]any), + readersByKey: make(map[string]io.ReadCloser), + appendersByKey: make(map[string]io.Writer), + clearersByKey: make(map[string]io.Writer), + combineFnsByKey: make(map[string]*graph.CombineFn), // Each test can specify coders as needed + codersByKey: make(map[string]*coder.Coder), // Each test can specify coders as needed } } diff --git a/sdks/go/pkg/beam/core/runtime/graphx/translate.go b/sdks/go/pkg/beam/core/runtime/graphx/translate.go index 3bbb6c70dcf5..3994397e7ba5 100644 --- a/sdks/go/pkg/beam/core/runtime/graphx/translate.go +++ b/sdks/go/pkg/beam/core/runtime/graphx/translate.go @@ -95,8 +95,9 @@ 
const ( URNEnvDocker = "beam:env:docker:v1" // Userstate URNs. - URNBagUserState = "beam:user_state:bag:v1" - URNMultiMapUserState = "beam:user_state:multimap:v1" + URNBagUserState = "beam:user_state:bag:v1" + URNMultiMapUserState = "beam:user_state:multimap:v1" + URNOrderedListUserState = "beam:user_state:ordered_list:v1" // Base version URNs are to allow runners to make distinctions between different releases // in a way that won't change based on actual releases, in particular for FnAPI behaviors. @@ -601,6 +602,17 @@ func (m *marshaller) addMultiEdge(edge NamedEdge) ([]string, error) { Urn: URNMultiMapUserState, }, } + case state.TypeOrderedList: + stateSpecs[ps.StateKey()] = &pipepb.StateSpec{ + Spec: &pipepb.StateSpec_OrderedListSpec{ + OrderedListSpec: &pipepb.OrderedListStateSpec{ + ElementCoderId: coderID, + }, + }, + Protocol: &pipepb.FunctionSpec{ + Urn: URNOrderedListUserState, + }, + } default: return nil, errors.Errorf("State type %v not recognized for state %v", ps.StateKey(), ps) } diff --git a/sdks/go/pkg/beam/core/runtime/harness/statemgr.go b/sdks/go/pkg/beam/core/runtime/harness/statemgr.go index 4c1bc0b55fe3..269ded372998 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/statemgr.go +++ b/sdks/go/pkg/beam/core/runtime/harness/statemgr.go @@ -147,6 +147,30 @@ func (s *ScopedStateReader) OpenMultimapKeysUserStateClearer(ctx context.Context return wr, err } +// OpenOrderedListUserStateReader opens a byte stream for reading user ordered list state in the range [start, end). +func (s *ScopedStateReader) OpenOrderedListUserStateReader(ctx context.Context, id exec.StreamID, userStateID string, key []byte, w []byte, start, end int64) (io.ReadCloser, error) { + rw, err := s.openReader(ctx, id, func(ch *StateChannel) *stateKeyReader { + return newOrderedListUserStateReader(ch, id, s.instID, userStateID, key, w, start, end) + }) + return rw, err +} + +// OpenOrderedListUserStateAppender opens a byte stream for appending user ordered list state. 
+func (s *ScopedStateReader) OpenOrderedListUserStateAppender(ctx context.Context, id exec.StreamID, userStateID string, key []byte, w []byte) (io.Writer, error) { + wr, err := s.openWriter(ctx, id, func(ch *StateChannel) *stateKeyWriter { + return newOrderedListUserStateWriter(ch, id, s.instID, userStateID, key, w, writeTypeAppend) + }) + return wr, err +} + +// OpenOrderedListUserStateClearer opens a byte stream for clearing user ordered list state in the range [start, end). +func (s *ScopedStateReader) OpenOrderedListUserStateClearer(ctx context.Context, id exec.StreamID, userStateID string, key []byte, w []byte, start, end int64) (io.Writer, error) { + wr, err := s.openWriter(ctx, id, func(ch *StateChannel) *stateKeyWriter { + return newOrderedListUserStateClearer(ch, id, s.instID, userStateID, key, w, start, end) + }) + return wr, err +} + // GetSideInputCache returns a pointer to the SideInputCache being used by the SDK harness. func (s *ScopedStateReader) GetSideInputCache() exec.SideCache { return s.cache @@ -391,6 +415,64 @@ func newMultimapKeysUserStateWriter(ch *StateChannel, id exec.StreamID, instID i } } +func newOrderedListUserStateReader(ch *StateChannel, id exec.StreamID, instID instructionID, userStateID string, k []byte, w []byte, start, end int64) *stateKeyReader { + key := &fnpb.StateKey{ + Type: &fnpb.StateKey_OrderedListUserState_{ + OrderedListUserState: &fnpb.StateKey_OrderedListUserState{ + TransformId: id.PtransformID, + UserStateId: userStateID, + Window: w, + Key: k, + Range: &fnpb.OrderedListRange{Start: start, End: end}, + }, + }, + } + return &stateKeyReader{ + instID: instID, + key: key, + ch: ch, + } +} + +func newOrderedListUserStateWriter(ch *StateChannel, id exec.StreamID, instID instructionID, userStateID string, k []byte, w []byte, wt writeTypeEnum) *stateKeyWriter { + key := &fnpb.StateKey{ + Type: &fnpb.StateKey_OrderedListUserState_{ + OrderedListUserState: &fnpb.StateKey_OrderedListUserState{ + TransformId: id.PtransformID, 
+ UserStateId: userStateID, + Window: w, + Key: k, + }, + }, + } + return &stateKeyWriter{ + instID: instID, + key: key, + ch: ch, + writeType: wt, + } +} + +func newOrderedListUserStateClearer(ch *StateChannel, id exec.StreamID, instID instructionID, userStateID string, k []byte, w []byte, start, end int64) *stateKeyWriter { + key := &fnpb.StateKey{ + Type: &fnpb.StateKey_OrderedListUserState_{ + OrderedListUserState: &fnpb.StateKey_OrderedListUserState{ + TransformId: id.PtransformID, + UserStateId: userStateID, + Window: w, + Key: k, + Range: &fnpb.OrderedListRange{Start: start, End: end}, + }, + }, + } + return &stateKeyWriter{ + instID: instID, + key: key, + ch: ch, + writeType: writeTypeClear, + } +} + func (r *stateKeyReader) Read(buf []byte) (int, error) { if r.buf == nil { if r.eof { diff --git a/sdks/go/pkg/beam/core/state/state.go b/sdks/go/pkg/beam/core/state/state.go index bdec84f6f656..143840d9e006 100644 --- a/sdks/go/pkg/beam/core/state/state.go +++ b/sdks/go/pkg/beam/core/state/state.go @@ -17,8 +17,11 @@ package state import ( + "cmp" "fmt" + "math" "reflect" + "slices" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" ) @@ -46,6 +49,8 @@ const ( TypeMap TypeEnum = 3 // TypeSet represents a set state TypeSet TypeEnum = 4 + // TypeOrderedList represents an ordered list state + TypeOrderedList TypeEnum = 5 ) var ( @@ -84,6 +89,9 @@ type Provider interface { WriteMapState(val Transaction) error ClearMapStateKey(val Transaction) error ClearMapState(val Transaction) error + ReadOrderedListState(userStateID string) ([]any, []Transaction, error) + WriteOrderedListState(val Transaction) error + ClearOrderedListState(val Transaction) error } // PipelineState is an interface representing different kinds of PipelineState (currently just state.Value). @@ -684,3 +692,128 @@ func MakeSetState[K comparable](k string) Set[K] { Key: k, } } + +// OrderedListEntry is an untyped sort-key/value pair used at the Provider boundary. 
+type OrderedListEntry struct { + SortKey int64 + Value any +} + +// OrderedListValue is a typed sort-key/value pair returned to the user. +type OrderedListValue[T any] struct { + SortKey int64 + Value T +} + +// OrderedList is used to read and write global pipeline state representing an ordered list of elements. +// Elements are ordered by a sort key (int64, typically representing a timestamp in milliseconds). +// Key represents the key used to lookup this state. +type OrderedList[T any] struct { + Key string +} + +// Add appends a value with the given sort key to this ordered list state. +func (s *OrderedList[T]) Add(p Provider, sortKey int64, val T) error { + return p.WriteOrderedListState(Transaction{ + Key: s.Key, + Type: TransactionTypeAppend, + MapKey: sortKey, + Val: val, + }) +} + +// Read returns all elements in this ordered list state, sorted by sort key. +func (s *OrderedList[T]) Read(p Provider) ([]OrderedListValue[T], bool, error) { + return s.ReadRange(p, math.MinInt64, math.MaxInt64) +} + +// ReadRange returns elements in the half-open interval [start, end), sorted by sort key. +func (s *OrderedList[T]) ReadRange(p Provider, start, end int64) ([]OrderedListValue[T], bool, error) { + initialValue, bufferedTransactions, err := p.ReadOrderedListState(s.Key) + if err != nil { + return nil, false, err + } + + // Collect initial entries that fall in [start, end). + var entries []OrderedListEntry + for _, v := range initialValue { + e := v.(OrderedListEntry) + if e.SortKey >= start && e.SortKey < end { + entries = append(entries, e) + } + } + + // Replay transactions. 
+ for _, t := range bufferedTransactions { + switch t.Type { + case TransactionTypeAppend: + sk := t.MapKey.(int64) + if sk >= start && sk < end { + entries = append(entries, OrderedListEntry{SortKey: sk, Value: t.Val}) + } + case TransactionTypeClear: + r := t.MapKey.([2]int64) + cStart, cEnd := r[0], r[1] + entries = slices.DeleteFunc(entries, func(e OrderedListEntry) bool { + return e.SortKey >= cStart && e.SortKey < cEnd + }) + } + } + + if len(entries) == 0 { + return nil, false, nil + } + + // Stable sort by sort key. + slices.SortStableFunc(entries, func(a, b OrderedListEntry) int { + return cmp.Compare(a.SortKey, b.SortKey) + }) + + result := make([]OrderedListValue[T], len(entries)) + for i, e := range entries { + result[i] = OrderedListValue[T]{SortKey: e.SortKey, Value: e.Value.(T)} + } + return result, true, nil +} + +// Clear removes all elements from this ordered list state. +func (s *OrderedList[T]) Clear(p Provider) error { + return s.ClearRange(p, math.MinInt64, math.MaxInt64) +} + +// ClearRange removes elements in the half-open interval [start, end). +func (s *OrderedList[T]) ClearRange(p Provider, start, end int64) error { + return p.ClearOrderedListState(Transaction{ + Key: s.Key, + Type: TransactionTypeClear, + MapKey: [2]int64{start, end}, + }) +} + +// StateKey returns the key for this pipeline state entry. +func (s OrderedList[T]) StateKey() string { + return s.Key +} + +// KeyCoderType returns nil since OrderedList types aren't keyed. +func (s OrderedList[T]) KeyCoderType() reflect.Type { + return nil +} + +// CoderType returns the element type which should be used for a coder. +func (s OrderedList[T]) CoderType() reflect.Type { + var t T + return reflect.TypeOf(t) +} + +// StateType returns the type of the state (in this case always OrderedList). +func (s OrderedList[T]) StateType() TypeEnum { + return TypeOrderedList +} + +// MakeOrderedListState is a factory function to create an instance of OrderedListState with the given key. 
+func MakeOrderedListState[T any](k string) OrderedList[T] { + return OrderedList[T]{ + Key: k, + } +} diff --git a/sdks/go/pkg/beam/core/state/state_test.go b/sdks/go/pkg/beam/core/state/state_test.go index 61057c05b639..af48a9ab1339 100644 --- a/sdks/go/pkg/beam/core/state/state_test.go +++ b/sdks/go/pkg/beam/core/state/state_test.go @@ -17,6 +17,7 @@ package state import ( "errors" + "math" "testing" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" @@ -27,15 +28,16 @@ var ( ) type fakeProvider struct { - initialState map[string]any - initialBagState map[string][]any - initialMapState map[string]map[string]any - transactions map[string][]Transaction - err map[string]error - createAccumForKey map[string]bool - addInputForKey map[string]bool - mergeAccumForKey map[string]bool - extractOutForKey map[string]bool + initialState map[string]any + initialBagState map[string][]any + initialMapState map[string]map[string]any + initialOrderedListState map[string][]any + transactions map[string][]Transaction + err map[string]error + createAccumForKey map[string]bool + addInputForKey map[string]bool + mergeAccumForKey map[string]bool + extractOutForKey map[string]bool } func (s *fakeProvider) ReadValueState(userStateID string) (any, []Transaction, error) { @@ -177,6 +179,36 @@ func (s *fakeProvider) ClearMapState(val Transaction) error { return nil } +func (s *fakeProvider) ReadOrderedListState(userStateID string) ([]any, []Transaction, error) { + if err, ok := s.err[userStateID]; ok { + return nil, nil, err + } + base := s.initialOrderedListState[userStateID] + trans, ok := s.transactions[userStateID] + if !ok { + trans = []Transaction{} + } + return base, trans, nil +} + +func (s *fakeProvider) WriteOrderedListState(val Transaction) error { + if transactions, ok := s.transactions[val.Key]; ok { + s.transactions[val.Key] = append(transactions, val) + } else { + s.transactions[val.Key] = []Transaction{val} + } + return nil +} + +func (s *fakeProvider) 
ClearOrderedListState(val Transaction) error { + if transactions, ok := s.transactions[val.Key]; ok { + s.transactions[val.Key] = append(transactions, val) + } else { + s.transactions[val.Key] = []Transaction{val} + } + return nil +} + func TestValueRead(t *testing.T) { is := make(map[string]any) ts := make(map[string][]Transaction) @@ -1200,3 +1232,196 @@ func TestSetClear(t *testing.T) { } } } + +func TestOrderedListRead(t *testing.T) { + il := make(map[string][]any) + ts := make(map[string][]Transaction) + es := make(map[string]error) + il["no_transactions"] = []any{ + OrderedListEntry{SortKey: 100, Value: 1}, + OrderedListEntry{SortKey: 200, Value: 2}, + } + ts["no_transactions"] = nil + il["basic_append"] = []any{} + ts["basic_append"] = []Transaction{ + {Key: "basic_append", Type: TransactionTypeAppend, MapKey: int64(50), Val: 5}, + } + il["multi_append"] = []any{ + OrderedListEntry{SortKey: 100, Value: 1}, + } + ts["multi_append"] = []Transaction{ + {Key: "multi_append", Type: TransactionTypeAppend, MapKey: int64(50), Val: 5}, + {Key: "multi_append", Type: TransactionTypeAppend, MapKey: int64(150), Val: 15}, + } + il["basic_clear"] = []any{ + OrderedListEntry{SortKey: 100, Value: 1}, + OrderedListEntry{SortKey: 200, Value: 2}, + } + ts["basic_clear"] = []Transaction{ + {Key: "basic_clear", Type: TransactionTypeClear, MapKey: [2]int64{math.MinInt64, math.MaxInt64}}, + } + il["clear_range"] = []any{ + OrderedListEntry{SortKey: 100, Value: 1}, + OrderedListEntry{SortKey: 200, Value: 2}, + OrderedListEntry{SortKey: 300, Value: 3}, + } + ts["clear_range"] = []Transaction{ + {Key: "clear_range", Type: TransactionTypeClear, MapKey: [2]int64{150, 250}}, + } + il["err"] = []any{} + es["err"] = errFake + + f := fakeProvider{ + initialOrderedListState: il, + transactions: ts, + err: es, + } + + tests := []struct { + name string + vs OrderedList[int] + start int64 + end int64 + val []OrderedListValue[int] + ok bool + err error + }{ + {"no_transactions", 
MakeOrderedListState[int]("no_transactions"), math.MinInt64, math.MaxInt64, []OrderedListValue[int]{{100, 1}, {200, 2}}, true, nil}, + {"basic_append", MakeOrderedListState[int]("basic_append"), math.MinInt64, math.MaxInt64, []OrderedListValue[int]{{50, 5}}, true, nil}, + {"multi_append_sorted", MakeOrderedListState[int]("multi_append"), math.MinInt64, math.MaxInt64, []OrderedListValue[int]{{50, 5}, {100, 1}, {150, 15}}, true, nil}, + {"basic_clear", MakeOrderedListState[int]("basic_clear"), math.MinInt64, math.MaxInt64, nil, false, nil}, + {"clear_range", MakeOrderedListState[int]("clear_range"), math.MinInt64, math.MaxInt64, []OrderedListValue[int]{{100, 1}, {300, 3}}, true, nil}, + {"read_range", MakeOrderedListState[int]("no_transactions"), 150, 250, []OrderedListValue[int]{{200, 2}}, true, nil}, + {"err", MakeOrderedListState[int]("err"), math.MinInt64, math.MaxInt64, nil, false, errFake}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, ok, err := tt.vs.ReadRange(&f, tt.start, tt.end) + if err != nil && tt.err == nil { + t.Errorf("OrderedList.ReadRange() returned error %v when it shouldn't have", err) + } else if err == nil && tt.err != nil { + t.Errorf("OrderedList.ReadRange() returned no error when it should have returned %v", tt.err) + } else if ok != tt.ok { + t.Errorf("OrderedList.ReadRange() ok=%v, want %v", ok, tt.ok) + } else if len(val) != len(tt.val) { + t.Errorf("OrderedList.ReadRange()=%v, want %v", val, tt.val) + } else { + for i, v := range val { + if v != tt.val[i] { + t.Errorf("OrderedList.ReadRange()[%d]=%v, want %v", i, v, tt.val[i]) + } + } + } + }) + } +} + +func TestOrderedListAdd(t *testing.T) { + tests := []struct { + name string + adds []OrderedListValue[int] + val []OrderedListValue[int] + ok bool + }{ + {"empty", nil, nil, false}, + {"single", []OrderedListValue[int]{{100, 1}}, []OrderedListValue[int]{{100, 1}}, true}, + {"sorted", []OrderedListValue[int]{{200, 2}, {100, 1}}, 
[]OrderedListValue[int]{{100, 1}, {200, 2}}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := fakeProvider{ + initialOrderedListState: make(map[string][]any), + transactions: make(map[string][]Transaction), + err: make(map[string]error), + } + vs := MakeOrderedListState[int]("vs") + for _, a := range tt.adds { + if err := vs.Add(&f, a.SortKey, a.Value); err != nil { + t.Fatalf("OrderedList.Add() returned error %v", err) + } + } + val, ok, err := vs.Read(&f) + if err != nil { + t.Fatalf("OrderedList.Read() returned error %v", err) + } + if ok != tt.ok { + t.Errorf("OrderedList.Read() ok=%v, want %v", ok, tt.ok) + } + if len(val) != len(tt.val) { + t.Fatalf("OrderedList.Read()=%v, want %v", val, tt.val) + } + for i, v := range val { + if v != tt.val[i] { + t.Errorf("OrderedList.Read()[%d]=%v, want %v", i, v, tt.val[i]) + } + } + }) + } +} + +func TestOrderedListClear(t *testing.T) { + tests := []struct { + name string + adds []OrderedListValue[int] + }{ + {"empty", nil}, + {"with_data", []OrderedListValue[int]{{100, 1}, {200, 2}}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := fakeProvider{ + initialOrderedListState: make(map[string][]any), + transactions: make(map[string][]Transaction), + err: make(map[string]error), + } + vs := MakeOrderedListState[int]("vs") + for _, a := range tt.adds { + vs.Add(&f, a.SortKey, a.Value) + } + if err := vs.Clear(&f); err != nil { + t.Fatalf("OrderedList.Clear() returned error %v", err) + } + _, ok, err := vs.Read(&f) + if err != nil { + t.Fatalf("OrderedList.Read() returned error %v", err) + } + if ok { + t.Error("OrderedList.Read() returned ok=true after Clear()") + } + }) + } +} + +func TestOrderedListClearRange(t *testing.T) { + f := fakeProvider{ + initialOrderedListState: make(map[string][]any), + transactions: make(map[string][]Transaction), + err: make(map[string]error), + } + vs := MakeOrderedListState[int]("vs") + vs.Add(&f, 100, 1) + vs.Add(&f, 200, 
2) + vs.Add(&f, 300, 3) + if err := vs.ClearRange(&f, 150, 250); err != nil { + t.Fatalf("OrderedList.ClearRange() returned error %v", err) + } + val, ok, err := vs.Read(&f) + if err != nil { + t.Fatalf("OrderedList.Read() returned error %v", err) + } + if !ok { + t.Fatal("OrderedList.Read() returned ok=false, want true") + } + want := []OrderedListValue[int]{{100, 1}, {300, 3}} + if len(val) != len(want) { + t.Fatalf("OrderedList.Read()=%v, want %v", val, want) + } + for i, v := range val { + if v != want[i] { + t.Errorf("OrderedList.Read()[%d]=%v, want %v", i, v, want[i]) + } + } +} diff --git a/sdks/go/test/integration/primitives/state.go b/sdks/go/test/integration/primitives/state.go index 6b672acc27bd..911ebae9d460 100644 --- a/sdks/go/test/integration/primitives/state.go +++ b/sdks/go/test/integration/primitives/state.go @@ -40,6 +40,7 @@ func init() { register.DoFn3x1[state.Provider, string, int, string](&mapStateClearFn{}) register.DoFn3x1[state.Provider, string, int, string](&setStateFn{}) register.DoFn3x1[state.Provider, string, int, string](&setStateClearFn{}) + register.DoFn3x1[state.Provider, string, int, string](&orderedListStateFn{}) register.Function2x0(pairWithOne) register.Emitter2[string, int]() register.Combiner1[int](&combine1{}) @@ -560,3 +561,40 @@ func SetStateParDoClear(s beam.Scope) { counts := beam.ParDo(s, &setStateClearFn{State1: state.MakeSetState[string]("key1")}, keyed) passert.Equals(s, counts, "apple: [apple]", "pear: [pear]", "peach: [peach]", "apple: [apple1 apple2 apple3]", "apple: []", "pear: [pear1 pear2 pear3]") } + +type orderedListStateFn struct { + State1 state.OrderedList[int] +} + +func (f *orderedListStateFn) ProcessElement(s state.Provider, w string, c int) string { + // Read current list. + cur, ok, err := f.State1.Read(s) + if err != nil { + panic(err) + } + if !ok { + cur = []state.OrderedListValue[int]{} + } + + // Add element with sort key = count * 100. 
+ sortKey := int64(len(cur)+1) * 100 + err = f.State1.Add(s, sortKey, c) + if err != nil { + panic(err) + } + + // Build output summarizing what we read. + vals := make([]int, len(cur)) + for i, tv := range cur { + vals[i] = tv.Value + } + return fmt.Sprintf("%s: %v", w, vals) +} + +// OrderedListStateParDo tests a DoFn that uses ordered list state. +func OrderedListStateParDo(s beam.Scope) { + in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear") + keyed := beam.ParDo(s, pairWithOne, in) + counts := beam.ParDo(s, &orderedListStateFn{State1: state.MakeOrderedListState[int]("key1")}, keyed) + passert.Equals(s, counts, "apple: []", "pear: []", "peach: []", "apple: [1]", "apple: [1 1]", "pear: [1]") +} diff --git a/sdks/go/test/integration/primitives/state_test.go b/sdks/go/test/integration/primitives/state_test.go index 1d1d4860e8f9..e0076fb53963 100644 --- a/sdks/go/test/integration/primitives/state_test.go +++ b/sdks/go/test/integration/primitives/state_test.go @@ -76,3 +76,8 @@ func TestSetStateClear(t *testing.T) { integration.CheckFilters(t) ptest.BuildAndRun(t, SetStateParDoClear) } + +func TestOrderedListState(t *testing.T) { + integration.CheckFilters(t) + ptest.BuildAndRun(t, OrderedListStateParDo) +} diff --git a/website/www/site/content/en/documentation/programming-guide.md b/website/www/site/content/en/documentation/programming-guide.md index 13900f3a7ceb..343fb128b3ef 100644 --- a/website/www/site/content/en/documentation/programming-guide.md +++ b/website/www/site/content/en/documentation/programming-guide.md @@ -6717,6 +6717,10 @@ _ = (p | 'Read per user' >> ReadPerUser() | 'Set state pardo' >> beam.ParDo(OrderedListStateDoFn())) {{< /highlight >}} +{{< highlight go >}} +{{< code_sample "sdks/go/examples/snippets/04transforms.go" ordered_list_state >}} +{{< /highlight >}} + #### MultimapState {#multimap-state} `MultimapState` allow one key mapped to different values but the key value could be unordered.