Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 44 additions & 5 deletions go/cmd/api-devserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,19 @@ import (
"errors"
"flag"
"fmt"
"log/slog"
"os"
"os/exec"
"os/signal"
"path/filepath"
"syscall"

"cloud.google.com/go/datastore"
"cloud.google.com/go/storage"
"github.com/google/osv.dev/go/internal/api"
db "github.com/google/osv.dev/go/internal/database/datastore"
"github.com/google/osv.dev/go/logger"
"github.com/google/osv.dev/go/osv/clients"
)

const (
Expand Down Expand Up @@ -65,11 +70,7 @@ func run() error {

if !*noBackend {
logger.InfoContext(ctx, "Starting Go API backend natively", "port", *backendPort)
go func() {
if err := api.RunServer(ctx, *backendPort); err != nil {
logger.ErrorContext(ctx, "Go API server exited", "error", err)
}
}()
go runBackend(ctx, *backendPort)
}

logger.InfoContext(ctx, "Starting ESPv2 container", "port", *espPort, "backendPort", *backendPort)
Expand Down Expand Up @@ -149,3 +150,41 @@ func runCmdAsync(cmd *exec.Cmd) <-chan error {

return out
}

func runBackend(ctx context.Context, port int) {
project := os.Getenv("GOOGLE_CLOUD_PROJECT")
if project == "" {
logger.ErrorContext(ctx, "GOOGLE_CLOUD_PROJECT environment variable is not set")
return
}
datastoreID := os.Getenv("DATASTORE_DATABASE_ID") // empty string is the (default) database
dbClient, err := datastore.NewClientWithDatabase(ctx, project, datastoreID)
if err != nil {
logger.ErrorContext(ctx, "failed to create datastore client", "error", err)
return
}
defer dbClient.Close()
gcsClient, err := storage.NewClient(ctx)
if err != nil {
logger.ErrorContext(ctx, "Failed to create storage client", slog.Any("error", err))
return
}
defer gcsClient.Close()
vulnBucket := os.Getenv("OSV_VULNERABILITIES_BUCKET")
if vulnBucket == "" {
logger.ErrorContext(ctx, "OSV_VULNERABILITIES_BUCKET environment variable is not set")
return
}
vulnStore := db.NewVulnerabilityStore(db.VulnStoreConfig{
Client: dbClient,
GCS: clients.NewGCSClient(gcsClient, vulnBucket),
})
relationsStore := db.NewRelationsStore(dbClient)
if err := api.RunServer(ctx, api.ServerOptions{
Port: port,
VulnStore: vulnStore,
RelationsStore: relationsStore,
}); err != nil {
logger.ErrorContext(ctx, "Go API server exited", "error", err)
}
}
41 changes: 40 additions & 1 deletion go/cmd/api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@ package main

import (
"context"
"errors"
"flag"
"log/slog"
"os"
"os/signal"
"syscall"

"cloud.google.com/go/datastore"
"cloud.google.com/go/storage"
"github.com/google/osv.dev/go/internal/api"
db "github.com/google/osv.dev/go/internal/database/datastore"
"github.com/google/osv.dev/go/logger"
"github.com/google/osv.dev/go/osv/clients"
)

func main() {
Expand All @@ -28,5 +34,38 @@ func run() error {
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()

return api.RunServer(ctx, *port)
project := os.Getenv("GOOGLE_CLOUD_PROJECT")
if project == "" {
logger.ErrorContext(ctx, "GOOGLE_CLOUD_PROJECT environment variable is not set")
return errors.New("GOOGLE_CLOUD_PROJECT environment variable is not set")
}
datastoreID := os.Getenv("DATASTORE_DATABASE_ID") // empty string is the (default) database
dbClient, err := datastore.NewClientWithDatabase(ctx, project, datastoreID)
if err != nil {
logger.ErrorContext(ctx, "Failed to create datastore client", slog.Any("error", err))
return err
}
defer dbClient.Close()
gcsClient, err := storage.NewClient(ctx)
if err != nil {
logger.ErrorContext(ctx, "Failed to create storage client", slog.Any("error", err))
return err
}
defer gcsClient.Close()
vulnBucket := os.Getenv("OSV_VULNERABILITIES_BUCKET")
if vulnBucket == "" {
logger.ErrorContext(ctx, "OSV_VULNERABILITIES_BUCKET environment variable is not set")
return errors.New("OSV_VULNERABILITIES_BUCKET environment variable is not set")
}
vulnStore := db.NewVulnerabilityStore(db.VulnStoreConfig{
Client: dbClient,
GCS: clients.NewGCSClient(gcsClient, vulnBucket),
})
relationsStore := db.NewRelationsStore(dbClient)

return api.RunServer(ctx, api.ServerOptions{
Port: *port,
VulnStore: vulnStore,
RelationsStore: relationsStore,
})
}
59 changes: 59 additions & 0 deletions go/internal/api/get_vuln_by_id.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package api

import (
"context"
"errors"
"log/slog"
"strings"

"github.com/google/osv.dev/go/internal/models"
"github.com/google/osv.dev/go/logger"
"github.com/ossf/osv-schema/bindings/go/osvschema"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
pb "osv.dev/bindings/go/api"
)

//nolint:revive // complains about 'Id' instead of 'ID', but that matches the API (the proto).
func (s *server) GetVulnById(ctx context.Context, params *pb.GetVulnByIdParameters) (*osvschema.Vulnerability, error) {
id := params.GetId()
if len(id) == 0 {
return nil, status.Error(codes.InvalidArgument, "ID is required")
}
// Datastore has a limit of how large indexed properties can be (1500 bytes).
// Vulnerability IDs are not going to be over 100 characters.
if len(id) > 100 {
return nil, status.Error(codes.InvalidArgument, "ID is too long")
}
vulnerability, err := s.vulnStore.Get(ctx, id)
if err == nil {
return vulnerability, nil
}
if !errors.Is(err, models.ErrNotFound) {
logger.ErrorContext(ctx, "failed to get vulnerability from store",
slog.String("id", id),
slog.Any("error", err),
)

return nil, status.Errorf(codes.Internal, "error getting vulnerability: %v", err)
}

// Check for aliases
aliases, err := s.relationsStore.GetAliases(ctx, id)
if err != nil {
if errors.Is(err, models.ErrNotFound) {
return nil, status.Error(codes.NotFound, "Vulnerability not found")
}

logger.ErrorContext(ctx, "failed to check aliases for vulnerability",
slog.String("id", id),
slog.Any("error", err),
)

return nil, status.Errorf(codes.Internal, "error getting vulnerability: %v", err)
}

aliasStrs := strings.Join(aliases.Aliases, " ")

return nil, status.Errorf(codes.NotFound, "Vulnerability not found, but the following aliases were: %s", aliasStrs)
}
183 changes: 183 additions & 0 deletions go/internal/api/get_vuln_by_id_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package api

import (
"context"
"errors"
"iter"
"strings"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/osv.dev/go/internal/models"
"github.com/ossf/osv-schema/bindings/go/osvschema"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/testing/protocmp"

pb "osv.dev/bindings/go/api"
)

type mockVulnerabilityStore struct {
vuln *osvschema.Vulnerability
err error
}

func (m *mockVulnerabilityStore) Get(_ context.Context, _ string) (*osvschema.Vulnerability, error) {
if m.err != nil {
return nil, m.err
}
if m.vuln == nil {
return nil, models.ErrNotFound
}

return m.vuln, nil
}

func (m *mockVulnerabilityStore) ListBySource(_ context.Context, _ string, _ bool) iter.Seq2[*models.VulnSourceRef, error] {
panic("unimplemented")
}

func (m *mockVulnerabilityStore) GetSourceModified(_ context.Context, _ string) (time.Time, error) {
panic("unimplemented")
}

func (m *mockVulnerabilityStore) GetWithMetadata(_ context.Context, _ string) (*osvschema.Vulnerability, *models.VulnSourceRef, error) {
panic("unimplemented")
}

func (m *mockVulnerabilityStore) Write(_ context.Context, _ models.WriteRequest) error {
panic("unimplemented")
}

type mockRelationsStore struct {
aliases *models.GetAliasResult
err error
}

func (m *mockRelationsStore) GetAliases(_ context.Context, _ string) (*models.GetAliasResult, error) {
if m.err != nil {
return nil, m.err
}
if m.aliases == nil {
return nil, models.ErrNotFound
}

return m.aliases, nil
}

func (m *mockRelationsStore) GetRelated(_ context.Context, _ string) (*models.GetRelatedResult, error) {
panic("unimplemented")
}

func (m *mockRelationsStore) GetUpstream(_ context.Context, _ string) (*models.GetUpstreamResult, error) {
panic("unimplemented")
}

func TestGetVulnById(t *testing.T) {
ctx := context.Background()

testVuln := &osvschema.Vulnerability{
Id: "TEST-1",
}

tests := []struct {
name string
id string
mockVuln *osvschema.Vulnerability
mockVulnErr error
mockAliases *models.GetAliasResult
mockAliasesErr error
want *osvschema.Vulnerability
wantErrCode codes.Code
wantErrMsg string
}{
{
name: "Success",
id: "TEST-1",
mockVuln: testVuln,
want: testVuln,
},
{
name: "Empty ID",
id: "",
wantErrCode: codes.InvalidArgument,
wantErrMsg: "ID is required",
},
{
name: "Too Long ID",
id: string(make([]byte, 101)),
wantErrCode: codes.InvalidArgument,
wantErrMsg: "ID is too long",
},
{
name: "Not Found - No Aliases",
id: "TEST-1",
wantErrCode: codes.NotFound,
wantErrMsg: "Vulnerability not found",
},
{
name: "Not Found - With Aliases",
id: "TEST-1",
mockAliases: &models.GetAliasResult{
Aliases: []string{"ALIAS-1", "ALIAS-2"},
},
wantErrCode: codes.NotFound,
wantErrMsg: "Vulnerability not found, but the following aliases were: ALIAS-1 ALIAS-2",
},
{
name: "VulnStore Error",
id: "TEST-1",
mockVulnErr: errors.New("internal GCS error"),
wantErrCode: codes.Internal,
wantErrMsg: "error getting vulnerability",
},
{
name: "RelationsStore Error",
id: "TEST-1",
mockAliasesErr: errors.New("internal Datastore error"),
wantErrCode: codes.Internal,
wantErrMsg: "error getting vulnerability",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &server{
vulnStore: &mockVulnerabilityStore{
vuln: tt.mockVuln,
err: tt.mockVulnErr,
},
relationsStore: &mockRelationsStore{
aliases: tt.mockAliases,
err: tt.mockAliasesErr,
},
}

got, err := s.GetVulnById(ctx, &pb.GetVulnByIdParameters{Id: tt.id})

if tt.wantErrCode != codes.OK {
if err == nil {
t.Fatalf("GetVulnById() expected error, got nil")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("GetVulnById() expected gRPC status error, got %v", err)
}
if st.Code() != tt.wantErrCode {
t.Errorf("GetVulnById() error code = %v, want %v", st.Code(), tt.wantErrCode)
}
if tt.wantErrMsg != "" && !strings.Contains(st.Message(), tt.wantErrMsg) {
t.Errorf("GetVulnById() error message = %q, want to contain %q", st.Message(), tt.wantErrMsg)
}
} else {
if err != nil {
t.Fatalf("GetVulnById() unexpected error: %v", err)
}
if diff := cmp.Diff(tt.want, got, protocmp.Transform()); diff != "" {
t.Errorf("GetVulnById() mismatch (-want +got):\n%s", diff)
}
}
})
}
}
Loading
Loading