From f4b357c5d684600313ba1b9d78dc32fa271dd364 Mon Sep 17 00:00:00 2001 From: rabbitstack Date: Thu, 2 Jan 2025 17:54:47 +0100 Subject: [PATCH] feat(instrumentation,telemetry): Incorporate thread pool telemetry The thread pool telemetry is captured from the dedicated thread pool provider. Presently, we collect three different event types including the call stacks. More specifically, SubmitThreadpoolWork, SubmitThreadpoolCallback, and SetThreadpoolTimer events. Additionally, the callback addresses are symbolized to derive to function name, and when the callback is the ZwContinue or RtlCaputreContext function call, we also try to decode the CONTEXT structure and symbolize the instruction pointer address. --- configs/fibratus.yml | 3 + internal/etw/source.go | 5 ++ internal/etw/stackext.go | 9 +++ internal/etw/trace.go | 13 ++- pkg/config/config_windows.go | 1 + pkg/config/filters.go | 29 ++++--- pkg/config/kstream.go | 34 ++++---- pkg/config/schema_windows.go | 31 ++++---- pkg/filter/rules.go | 2 + pkg/kevent/kevent_windows.go | 10 +++ pkg/kevent/kparam_windows.go | 24 ++++++ pkg/kevent/kparams/fields_windows.go | 33 ++++++++ pkg/kevent/ktypes/category.go | 3 + pkg/kevent/ktypes/eventset.go | 2 + pkg/kevent/ktypes/ktypes_windows.go | 37 ++++++++- pkg/kevent/ktypes/metainfo_windows.go | 9 +++ pkg/ps/snapshotter.go | 2 +- pkg/ps/snapshotter_windows.go | 1 + pkg/symbolize/symbolizer.go | 66 ++++++++++++++-- pkg/symbolize/symbolizer_test.go | 105 +++++++++++++++++++++++++ pkg/sys/etw/types.go | 13 ++- pkg/util/threadcontext/context.go | 104 ++++++++++++++++++++++++ pkg/util/threadcontext/context_test.go | 73 +++++++++++++++++ 23 files changed, 546 insertions(+), 63 deletions(-) create mode 100644 pkg/util/threadcontext/context.go create mode 100644 pkg/util/threadcontext/context_test.go diff --git a/configs/fibratus.yml b/configs/fibratus.yml index 766d5fde1..cb8223c9f 100644 --- a/configs/fibratus.yml +++ b/configs/fibratus.yml @@ -229,6 +229,9 @@ kstream: # Determines whether DNS client events are collected #enable-dns: true + # Determines whether thread pool events are collected + #enable-threadpool: true + # Indicates if stack enrichment is enabled for eligible events #stack-enrichment: true diff --git a/internal/etw/source.go b/internal/etw/source.go index c0abb7949..0a9e513bd 100644 --- a/internal/etw/source.go +++ b/internal/etw/source.go @@ -147,6 +147,7 @@ func (e *EventSource) Open(config *config.Config) error { config.Kstream.EnableMemKevents = config.Kstream.EnableMemKevents && (e.r.HasMemEvents || (config.Yara.Enabled && !config.Yara.SkipAllocs)) config.Kstream.EnableDNSEvents = config.Kstream.EnableDNSEvents && e.r.HasDNSEvents config.Kstream.EnableAuditAPIEvents = config.Kstream.EnableAuditAPIEvents && e.r.HasAuditAPIEvents + config.Kstream.EnableThreadpoolEvents = config.Kstream.EnableThreadpoolEvents && e.r.HasThreadpoolEvents for _, ktype := range ktypes.All() { if ktype == ktypes.CreateProcess || ktype == ktypes.TerminateProcess || ktype == ktypes.LoadImage || ktype == ktypes.UnloadImage { @@ -189,6 +190,9 @@ func (e *EventSource) Open(config *config.Config) error { if config.Kstream.EnableAuditAPIEvents { e.addTrace(etw.KernelAuditAPICallsSession, etw.KernelAuditAPICallsGUID) } + if config.Kstream.EnableThreadpoolEvents { + e.addTrace(etw.ThreadpoolSession, etw.ThreadpoolGUID) + } for _, trace := range e.traces { err := trace.Start() @@ -226,6 +230,7 @@ func (e *EventSource) Open(config *config.Config) error { // Init consumer and open the trace for processing consumer := NewConsumer(e.psnap, e.hsnap, config, e.sequencer, e.evts) consumer.SetFilter(e.filter) + // Attach event listeners for _, lis := range e.listeners { consumer.q.RegisterListener(lis) diff --git a/internal/etw/stackext.go b/internal/etw/stackext.go index 396df8617..03c348c1a 100644 --- a/internal/etw/stackext.go +++ b/internal/etw/stackext.go @@ -100,3 +100,12 @@ func (s *StackExtensions) EnableMemoryCallstack() { s.AddStackTracing(ktypes.VirtualAlloc) } } + +// EnableThreadpoolCallstack enables stack tracing for thread pool events. +func (s *StackExtensions) EnableThreadpoolCallstack() { + if s.config.EnableThreadpoolEvents { + s.AddStackTracing(ktypes.SubmitThreadpoolWork) + s.AddStackTracing(ktypes.SubmitThreadpoolCallback) + s.AddStackTracing(ktypes.SetThreadpoolTimer) + } +} diff --git a/internal/etw/trace.go b/internal/etw/trace.go index 3f268e8d4..f60610807 100644 --- a/internal/etw/trace.go +++ b/internal/etw/trace.go @@ -150,6 +150,10 @@ func (t *Trace) enableCallstacks() { if t.IsSystemRegistryTrace() { t.stackExtensions.EnableRegistryCallstack() } + + if t.IsThreadpoolTrace() { + t.stackExtensions.EnableThreadpoolCallstack() + } } // Start registers and starts an event tracing session. @@ -202,7 +206,9 @@ func (t *Trace) Start() error { log.Warnf("unable to set empty system flags: %v", err) return nil } + sysTraceFlags[0] = flags + // enable object manager tracking if cfg.EnableHandleKevents { sysTraceFlags[4] = etw.Handle @@ -225,13 +231,14 @@ func (t *Trace) Start() error { // enrichment is enabled, it is necessary to instruct the provider // to emit stack addresses in the extended data item section when // writing events to the session buffers - if cfg.StackEnrichment && !t.IsSystemProvider() { + if cfg.StackEnrichment && !t.IsSystemProvider() && !t.IsThreadpoolTrace() { return etw.EnableTraceWithOpts(t.GUID, t.startHandle, t.Keywords, etw.EnableTraceOpts{WithStacktrace: true}) } else if cfg.StackEnrichment && len(t.stackExtensions.EventIds()) > 0 { if err := etw.EnableStackTracing(t.startHandle, t.stackExtensions.EventIds()); err != nil { return fmt.Errorf("fail to enable system events callstack tracing: %v", err) } } + if t.IsSystemRegistryTrace() { if err := etw.EnableTrace(t.GUID, t.startHandle, t.Keywords); err != nil { return err @@ -249,6 +256,7 @@ func (t *Trace) Start() error { sysTraceFlags[0] = etw.Registry return etw.SetTraceSystemFlags(handle, sysTraceFlags) } + return etw.EnableTrace(t.GUID, t.startHandle, t.Keywords) } @@ -343,6 +351,9 @@ func (t *Trace) IsKernelTrace() bool { return t.GUID == etw.KernelTraceControlGU // IsSystemRegistryTrace determines if this is the system registry logger trace. func (t *Trace) IsSystemRegistryTrace() bool { return t.GUID == etw.SystemRegistryProviderID } +// IsThreadpoolTrace determines if this is the thread pool logger trace. +func (t *Trace) IsThreadpoolTrace() bool { return t.GUID == etw.ThreadpoolGUID } + // IsSystemProvider determines if this is one of the granular system provider traces. func (t *Trace) IsSystemProvider() bool { return t.GUID == etw.SystemIOProviderID || t.GUID == etw.SystemRegistryProviderID || t.GUID == etw.SystemProcessProviderID || t.GUID == etw.SystemMemoryProviderID diff --git a/pkg/config/config_windows.go b/pkg/config/config_windows.go index 339a6d693..b960f9a76 100644 --- a/pkg/config/config_windows.go +++ b/pkg/config/config_windows.go @@ -411,6 +411,7 @@ func (c *Config) addFlags() { c.flags.Bool(enableMemKevents, true, "Determines whether memory manager kernel events are collected by Kernel Logger provider") c.flags.Bool(enableAuditAPIEvents, true, "Determines whether kernel audit API calls events are published") c.flags.Bool(enableDNSEvents, true, "Determines whether DNS client events are enabled") + c.flags.Bool(enableThreadpoolEvents, true, "Determines whether thread pool events are published") c.flags.Bool(stackEnrichment, true, "Indicates if stack enrichment is enabled for eligible events") c.flags.Int(bufferSize, int(maxBufferSize), "Represents the amount of memory allocated for each event tracing session buffer, in kilobytes. The buffer size affects the rate at which buffers fill and must be flushed (small buffer size requires less memory but it increases the rate at which buffers must be flushed)") c.flags.Int(minBuffers, int(defaultMinBuffers), "Determines the minimum number of buffers allocated for the event tracing session's buffer pool") diff --git a/pkg/config/filters.go b/pkg/config/filters.go index 02a86a145..0a4434dcc 100644 --- a/pkg/config/filters.go +++ b/pkg/config/filters.go @@ -170,19 +170,20 @@ func (ctx *ActionContext) UniquePids() []uint32 { // enabling/disabling event providers/types // dynamically. type RulesCompileResult struct { - HasProcEvents bool - HasThreadEvents bool - HasImageEvents bool - HasFileEvents bool - HasNetworkEvents bool - HasRegistryEvents bool - HasHandleEvents bool - HasMemEvents bool - HasVAMapEvents bool - HasDNSEvents bool - HasAuditAPIEvents bool - UsedEvents []ktypes.Ktype - NumberRules int + HasProcEvents bool + HasThreadEvents bool + HasImageEvents bool + HasFileEvents bool + HasNetworkEvents bool + HasRegistryEvents bool + HasHandleEvents bool + HasMemEvents bool + HasVAMapEvents bool + HasDNSEvents bool + HasAuditAPIEvents bool + HasThreadpoolEvents bool + UsedEvents []ktypes.Ktype + NumberRules int } func (r RulesCompileResult) ContainsEvent(ktype ktypes.Ktype) bool { @@ -217,6 +218,7 @@ func (r RulesCompileResult) String() string { HasVAMapEvents: %t HasAuditAPIEvents: %t HasDNSEvents: %t + HasThreadpoolEvents: %t Events: %s`, r.HasProcEvents, r.HasThreadEvents, @@ -229,6 +231,7 @@ func (r RulesCompileResult) String() string { r.HasVAMapEvents, r.HasAuditAPIEvents, r.HasDNSEvents, + r.HasThreadpoolEvents, strings.Join(events, ", "), ) } diff --git a/pkg/config/kstream.go b/pkg/config/kstream.go index b1e002393..f9fd33019 100644 --- a/pkg/config/kstream.go +++ b/pkg/config/kstream.go @@ -32,21 +32,22 @@ import ( ) const ( - enableThreadKevents = "kstream.enable-thread" - enableRegistryKevents = "kstream.enable-registry" - enableNetKevents = "kstream.enable-net" - enableFileIOKevents = "kstream.enable-fileio" - enableVAMapKevents = "kstream.enable-vamap" - enableImageKevents = "kstream.enable-image" - enableHandleKevents = "kstream.enable-handle" - enableMemKevents = "kstream.enable-mem" - enableAuditAPIEvents = "kstream.enable-audit-api" - enableDNSEvents = "kstream.enable-dns" - stackEnrichment = "kstream.stack-enrichment" - bufferSize = "kstream.buffer-size" - minBuffers = "kstream.min-buffers" - maxBuffers = "kstream.max-buffers" - flushInterval = "kstream.flush-interval" + enableThreadKevents = "kstream.enable-thread" + enableRegistryKevents = "kstream.enable-registry" + enableNetKevents = "kstream.enable-net" + enableFileIOKevents = "kstream.enable-fileio" + enableVAMapKevents = "kstream.enable-vamap" + enableImageKevents = "kstream.enable-image" + enableHandleKevents = "kstream.enable-handle" + enableMemKevents = "kstream.enable-mem" + enableAuditAPIEvents = "kstream.enable-audit-api" + enableDNSEvents = "kstream.enable-dns" + enableThreadpoolEvents = "kstream.enable-threadpool" + stackEnrichment = "kstream.stack-enrichment" + bufferSize = "kstream.buffer-size" + minBuffers = "kstream.min-buffers" + maxBuffers = "kstream.max-buffers" + flushInterval = "kstream.flush-interval" excludedEvents = "kstream.blacklist.events" excludedImages = "kstream.blacklist.images" @@ -82,6 +83,8 @@ type KstreamConfig struct { EnableAuditAPIEvents bool `json:"enable-audit-api" yaml:"enable-audit-api"` // EnableDNSEvents indicates if DNS client events are enabled EnableDNSEvents bool `json:"enable-dns" yaml:"enable-dns"` + // EnableThreadpoolEvents indicates if thread pool events are enabled + EnableThreadpoolEvents bool `json:"enable-threadpool" yaml:"enable-threadpool"` // StackEnrichment indicates if stack enrichment is enabled for eligible events. StackEnrichment bool `json:"stack-enrichment" yaml:"stack-enrichment"` // BufferSize represents the amount of memory allocated for each event tracing session buffer, in kilobytes. @@ -115,6 +118,7 @@ func (c *KstreamConfig) initFromViper(v *viper.Viper) { c.EnableMemKevents = v.GetBool(enableMemKevents) c.EnableAuditAPIEvents = v.GetBool(enableAuditAPIEvents) c.EnableDNSEvents = v.GetBool(enableDNSEvents) + c.EnableThreadpoolEvents = v.GetBool(enableThreadpoolEvents) c.StackEnrichment = v.GetBool(stackEnrichment) c.BufferSize = uint32(v.GetInt(bufferSize)) c.MinBuffers = uint32(v.GetInt(minBuffers)) diff --git a/pkg/config/schema_windows.go b/pkg/config/schema_windows.go index 3ded274df..347e393d2 100644 --- a/pkg/config/schema_windows.go +++ b/pkg/config/schema_windows.go @@ -179,25 +179,26 @@ var schema = ` "kstream": { "type": "object", "properties": { - "enable-thread": {"type": "boolean"}, - "enable-image": {"type": "boolean"}, - "enable-registry": {"type": "boolean"}, - "enable-fileio": {"type": "boolean"}, - "enable-vamap": {"type": "boolean"}, - "enable-handle": {"type": "boolean"}, - "enable-net": {"type": "boolean"}, - "enable-mem": {"type": "boolean"}, - "enable-audit-api": {"type": "boolean"}, - "enable-dns": {"type": "boolean"}, - "stack-enrichment": {"type": "boolean"}, - "min-buffers": {"type": "integer", "minimum": 1, "maximum": {{ .MinBuffers }}}, - "max-buffers": {"type": "integer", "minimum": 2, "maximum": {{ .MaxBuffers }}}, - "buffer-size": {"type": "integer", "maximum": {{ .MaxBufferSize }}}, + "enable-thread": {"type": "boolean"}, + "enable-image": {"type": "boolean"}, + "enable-registry": {"type": "boolean"}, + "enable-fileio": {"type": "boolean"}, + "enable-vamap": {"type": "boolean"}, + "enable-handle": {"type": "boolean"}, + "enable-net": {"type": "boolean"}, + "enable-mem": {"type": "boolean"}, + "enable-audit-api": {"type": "boolean"}, + "enable-dns": {"type": "boolean"}, + "enable-threadpool": {"type": "boolean"}, + "stack-enrichment": {"type": "boolean"}, + "min-buffers": {"type": "integer", "minimum": 1, "maximum": {{ .MinBuffers }}}, + "max-buffers": {"type": "integer", "minimum": 2, "maximum": {{ .MaxBuffers }}}, + "buffer-size": {"type": "integer", "maximum": {{ .MaxBufferSize }}}, "flush-interval": {"type": "string", "minLength": 2, "pattern": "[0-9]+s"}, "blacklist": { "type": "object", "properties": { - "events": {"type": "array", "items": {"type": "string", "enum": ["CreateThread", "TerminateThread", "OpenProcess", "OpenThread", "SetThreadContext", "LoadImage", "UnloadImage", "CreateFile", "CloseFile", "ReadFile", "WriteFile", "DeleteFile", "RenameFile", "SetFileInformation", "EnumDirectory", "MapViewFile", "UnmapViewFile", "RegCreateKey", "RegOpenKey", "RegSetValue", "RegQueryValue", "RegQueryKey", "RegDeleteKey", "RegDeleteValue", "RegCloseKey", "Accept", "Send", "Recv", "Connect", "Disconnect", "Reconnect", "Retransmit", "CreateHandle", "CloseHandle", "DuplicateHandle", "QueryDns", "ReplyDns", "VirtualAlloc", "VirtualFree", "CreateSymbolicLinkObject"]}}, + "events": {"type": "array", "items": {"type": "string", "enum": ["CreateThread", "TerminateThread", "OpenProcess", "OpenThread", "SetThreadContext", "LoadImage", "UnloadImage", "CreateFile", "CloseFile", "ReadFile", "WriteFile", "DeleteFile", "RenameFile", "SetFileInformation", "EnumDirectory", "MapViewFile", "UnmapViewFile", "RegCreateKey", "RegOpenKey", "RegSetValue", "RegQueryValue", "RegQueryKey", "RegDeleteKey", "RegDeleteValue", "RegCloseKey", "Accept", "Send", "Recv", "Connect", "Disconnect", "Reconnect", "Retransmit", "CreateHandle", "CloseHandle", "DuplicateHandle", "QueryDns", "ReplyDns", "VirtualAlloc", "VirtualFree", "CreateSymbolicLinkObject", "SubmitThreadpoolWork", "SubmitThreadpoolCallback", "SetThreadpoolTimer"]}}, "images": {"type": "array", "items": {"type": "string", "minLength": 1}} }, "additionalProperties": false diff --git a/pkg/filter/rules.go b/pkg/filter/rules.go index 65231bfbb..d5eef89e9 100644 --- a/pkg/filter/rules.go +++ b/pkg/filter/rules.go @@ -647,6 +647,8 @@ func (r *Rules) buildCompileResult() *config.RulesCompileResult { rs.HasMemEvents = true case ktypes.Handle: rs.HasHandleEvents = true + case ktypes.Threadpool: + rs.HasThreadpoolEvents = true } if typ == ktypes.MapViewFile || typ == ktypes.UnmapViewFile { rs.HasVAMapEvents = true diff --git a/pkg/kevent/kevent_windows.go b/pkg/kevent/kevent_windows.go index 0d0fdde27..4c800520a 100644 --- a/pkg/kevent/kevent_windows.go +++ b/pkg/kevent/kevent_windows.go @@ -556,6 +556,16 @@ func (e *Kevent) Summary() string { case ktypes.ReplyDNS: dnsName := e.GetParamAsString(kparams.DNSName) return printSummary(e, fmt.Sprintf("received DNS response for %s query", dnsName)) + case ktypes.CreateSymbolicLinkObject: + src := e.GetParamAsString(kparams.LinkSource) + target := e.GetParamAsString(kparams.LinkTarget) + return printSummary(e, fmt.Sprintf("created symbolic link from %s to %s", src, target)) + case ktypes.SubmitThreadpoolWork: + return printSummary(e, "enqueued the work item to the thread pool") + case ktypes.SubmitThreadpoolCallback: + return printSummary(e, "Submitted the thread pool callback for execution within the work item") + case ktypes.SetThreadpoolTimer: + return printSummary(e, "set thread pool timer object") } return "" } diff --git a/pkg/kevent/kparam_windows.go b/pkg/kevent/kparam_windows.go index be298aa55..5907d9224 100644 --- a/pkg/kevent/kparam_windows.go +++ b/pkg/kevent/kparam_windows.go @@ -763,6 +763,30 @@ func (e *Kevent) produceParams(evt *etw.EventRecord) { if evt.HasStackTrace() { e.AppendParam(kparams.Callstack, kparams.Slice, evt.Callstack()) } + case ktypes.SubmitThreadpoolWork, ktypes.SubmitThreadpoolCallback: + poolID := evt.ReadUint64(0) + taskID := evt.ReadUint64(8) + callback := evt.ReadUint64(16) + ctx := evt.ReadUint64(24) + tag := evt.ReadUint64(32) + e.AppendParam(kparams.ThreadpoolPoolID, kparams.Address, poolID) + e.AppendParam(kparams.ThreadpoolTaskID, kparams.Address, taskID) + e.AppendParam(kparams.ThreadpoolCallback, kparams.Address, callback) + e.AppendParam(kparams.ThreadpoolContext, kparams.Address, ctx) + e.AppendParam(kparams.ThreadpoolSubprocessTag, kparams.Address, tag) + case ktypes.SetThreadpoolTimer: + duetime := evt.ReadUint64(0) + subqueue := evt.ReadUint64(8) + timer := evt.ReadUint64(16) + period := evt.ReadUint32(24) + window := evt.ReadUint32(28) + absolute := evt.ReadUint32(32) + e.AppendParam(kparams.ThreadpoolTimerDuetime, kparams.Uint64, duetime) + e.AppendParam(kparams.ThreadpoolTimerSubqueue, kparams.Address, subqueue) + e.AppendParam(kparams.ThreadpoolTimer, kparams.Address, timer) + e.AppendParam(kparams.ThreadpoolTimerPeriod, kparams.Uint32, period) + e.AppendParam(kparams.ThreadpoolTimerWindow, kparams.Uint32, window) + e.AppendParam(kparams.ThreadpoolTimerAbsolute, kparams.Bool, absolute > 0) } } diff --git a/pkg/kevent/kparams/fields_windows.go b/pkg/kevent/kparams/fields_windows.go index c08afb36a..b807181b4 100644 --- a/pkg/kevent/kparams/fields_windows.go +++ b/pkg/kevent/kparams/fields_windows.go @@ -250,4 +250,37 @@ const ( LinkSource = "source" // LinkTarget identifies the parameter that represents the target symbolic link object or other kernel object LinkTarget = "target" + + // ThreadpoolPoolID represents the thread pool identifier. + ThreadpoolPoolID = "pool_id" + // ThreadpoolTaskID represents the thread pool task identifier. + ThreadpoolTaskID = "task_id" + // ThreadpoolCallback represents the address of the callback function. + ThreadpoolCallback = "callback" + // ThreadpoolCallbackSymbol represents the callback symbol. + ThreadpoolCallbackSymbol = "callback_symbol" + // ThreadpoolCallbackModule represents the module containing the callback symbol. + ThreadpoolCallbackModule = "callback_module" + // ThreadpoolContext represents the address of the callback context. + ThreadpoolContext = "context" + // ThreadpoolContextRip represents the value of instruction pointer contained in the callback context. + ThreadpoolContextRip = "context_rip" + // ThreadpoolContextRipSymbol represents the symbol name associated with the instruction pointer in callback context. + ThreadpoolContextRipSymbol = "context_rip_symbol" + // ThreadpoolContextRipModule represents the module name associated with the instruction pointer in callback context. + ThreadpoolContextRipModule = "context_rip_module" + // ThreadpoolSubprocessTag represents the service identifier associated with the thread pool. + ThreadpoolSubprocessTag = "subprocess_tag" + // ThreadpoolTimerDuetime represents the timer due time. + ThreadpoolTimerDuetime = "duetime" + // ThreadpoolTimerSubqueue represents the memory address of the timer subqueue. + ThreadpoolTimerSubqueue = "subqueue" + // ThreadpoolTimer represents the memory address of the timer object. + ThreadpoolTimer = "timer" + // ThreadpoolTimerPeriod represents the period of the timer + ThreadpoolTimerPeriod = "period" + // ThreadpoolTimerWindow represents the timer tolerate period. + ThreadpoolTimerWindow = "window" + // ThreadpoolTimerAbsolute indicates if the timer is absolute or relative. + ThreadpoolTimerAbsolute = "absolute" ) diff --git a/pkg/kevent/ktypes/category.go b/pkg/kevent/ktypes/category.go index 6ce413aeb..08f0c5d52 100644 --- a/pkg/kevent/ktypes/category.go +++ b/pkg/kevent/ktypes/category.go @@ -49,6 +49,8 @@ const ( Mem Category = "mem" // Object the category for object manager events Object Category = "object" + // Threadpool is the category for thread pool events + Threadpool Category = "threadpool" // Other is the category for uncategorized events Other Category = "other" // Unknown is the category for events that couldn't match any of the previous categories @@ -82,5 +84,6 @@ func Categories() []string { string(Other), string(Unknown), string(Object), + string(Threadpool), } } diff --git a/pkg/kevent/ktypes/eventset.go b/pkg/kevent/ktypes/eventset.go index 3cab529ee..a8e62333f 100644 --- a/pkg/kevent/ktypes/eventset.go +++ b/pkg/kevent/ktypes/eventset.go @@ -86,6 +86,8 @@ func (e *EventsetMasks) bitsetIndex(guid windows.GUID) int { return 9 case DNSEventGUID: return 10 + case ThreadpoolGUID: + return 11 default: return -1 } diff --git a/pkg/kevent/ktypes/ktypes_windows.go b/pkg/kevent/ktypes/ktypes_windows.go index 981e9dcb5..20ac8ab01 100644 --- a/pkg/kevent/ktypes/ktypes_windows.go +++ b/pkg/kevent/ktypes/ktypes_windows.go @@ -27,7 +27,7 @@ import ( // ProvidersCount designates the number of interesting providers. // Remember to increment if a new event source is introduced. -const ProvidersCount = 11 +const ProvidersCount = 12 // EventSource is the type that designates the provenance of the event type EventSource uint8 @@ -39,6 +39,8 @@ const ( AuditAPICallsLogger // DNSLogger event is emitted by DNS provider DNSLogger + // ThreadpoolLogger event is emitted by thread pool provider + ThreadpoolLogger ) // Ktype identifies an event type. It comprises the event GUID + hook ID to uniquely identify the event @@ -67,6 +69,8 @@ var ( AuditAPIEventGUID = windows.GUID{Data1: 0xe02a841c, Data2: 0x75a3, Data3: 0x4fa7, Data4: [8]byte{0xaf, 0xc8, 0xae, 0x09, 0xcf, 0x9b, 0x7f, 0x23}} // DNSEventGUID represents DNS provider event GUID DNSEventGUID = windows.GUID{Data1: 0x1c95126e, Data2: 0x7eea, Data3: 0x49a9, Data4: [8]byte{0xa3, 0xfe, 0xa3, 0x78, 0xb0, 0x3d, 0xdb, 0x4d}} + // ThreadpoolGUID represents the thread pool event GUID + ThreadpoolGUID = windows.GUID{Data1: 0xc861d0e2, Data2: 0xa2c1, Data3: 0x4d36, Data4: [8]byte{0x9f, 0x9c, 0x97, 0x0b, 0xab, 0x94, 0x3a, 0x12}} ) var ( @@ -210,6 +214,13 @@ var ( // CreateSymbolicLinkObject represents the event emitted by the object manager when the new symbolic link is created within the object manager directory CreateSymbolicLinkObject = pack(AuditAPIEventGUID, 3) + // SubmitThreadpoolWork represents the event that enqueues the work item to the thread pool + SubmitThreadpoolWork = pack(ThreadpoolGUID, 32) + //SubmitThreadpoolCallback represents the event that submits the thread pool callback for execution within the work item + SubmitThreadpoolCallback = pack(ThreadpoolGUID, 34) + // SetThreadpoolTimer represents the event that sets the thread pool timer object + SetThreadpoolTimer = pack(ThreadpoolGUID, 44) + // UnknownKtype designates unknown kernel event type UnknownKtype = pack(windows.GUID{}, 0) ) @@ -327,6 +338,12 @@ func (k Ktype) String() string { return "StackWalk" case CreateSymbolicLinkObject: return "CreateSymbolicLinkObject" + case SubmitThreadpoolWork: + return "SubmitThreadpoolWork" + case SubmitThreadpoolCallback: + return "SubmitThreadpoolCallback" + case SetThreadpoolTimer: + return "SetThreadpoolTimer" default: return "" } @@ -362,6 +379,8 @@ func (k Ktype) Category() Category { return Mem case CreateSymbolicLinkObject: return Object + case SubmitThreadpoolWork, SubmitThreadpoolCallback, SetThreadpoolTimer: + return Threadpool default: return Unknown } @@ -464,6 +483,12 @@ func (k Ktype) Description() string { return "Receives the response from the DNS server" case CreateSymbolicLinkObject: return "Creates the symbolic link within the object manager directory" + case SubmitThreadpoolWork: + return "Enqueues the work item to the thread pool" + case SubmitThreadpoolCallback: + return "Submits the thread pool callback for execution within the work item" + case SetThreadpoolTimer: + return "Sets the thread pool timer object" default: return "" } @@ -514,7 +539,10 @@ func (k Ktype) CanEnrichStack() bool { RegDeleteValue, DeleteFile, RenameFile, - VirtualAlloc: + VirtualAlloc, + SubmitThreadpoolWork, + SubmitThreadpoolCallback, + SetThreadpoolTimer: return true default: return false @@ -554,6 +582,8 @@ func (k Ktype) Source() EventSource { return AuditAPICallsLogger case QueryDNS, ReplyDNS: return DNSLogger + case SubmitThreadpoolWork, SubmitThreadpoolCallback, SetThreadpoolTimer: + return ThreadpoolLogger default: return SystemLogger } @@ -565,7 +595,8 @@ func (k Ktype) Source() EventSource { // events, but it appears first on the consumer callback // before other events published before it. func (k Ktype) CanArriveOutOfOrder() bool { - return k.Category() == Registry || k.Subcategory() == DNS || k == OpenProcess || k == OpenThread || k == SetThreadContext || k == CreateSymbolicLinkObject + return k.Category() == Registry || k.Category() == Threadpool || k.Subcategory() == DNS || + k == OpenProcess || k == OpenThread || k == SetThreadContext || k == CreateSymbolicLinkObject } // FromParts builds ktype from provider GUID and hook ID. diff --git a/pkg/kevent/ktypes/metainfo_windows.go b/pkg/kevent/ktypes/metainfo_windows.go index cbfcd94dc..db2c7cc22 100644 --- a/pkg/kevent/ktypes/metainfo_windows.go +++ b/pkg/kevent/ktypes/metainfo_windows.go @@ -88,6 +88,9 @@ var kevents = map[Ktype]KeventInfo{ QueryDNS: {"QueryDns", Net, "Sends a DNS query to the name server"}, ReplyDNS: {"ReplyDNS", Net, "Receives the response from the DNS server"}, CreateSymbolicLinkObject: {"CreateSymbolicLinkObject", Object, "Creates the symbolic link within the object manager directory"}, + SubmitThreadpoolWork: {"SubmitThreadpoolWork", Threadpool, "Enqueues the work item to the thread pool"}, + SubmitThreadpoolCallback: {"SubmitThreadpoolCallback", Threadpool, "Submits the thread pool callback for execution within the work item"}, + SetThreadpoolTimer: {"SetThreadpoolTimer", Threadpool, "Sets the thread pool timer object"}, } var ktypes = map[string]Ktype{ @@ -144,6 +147,9 @@ var ktypes = map[string]Ktype{ "QueryDns": QueryDNS, "ReplyDns": ReplyDNS, "CreateSymbolicLinkObject": CreateSymbolicLinkObject, + "SubmitThreadpoolWork": SubmitThreadpoolWork, + "SubmitThreadpoolCallback": SubmitThreadpoolCallback, + "SetThreadpoolTimer": SetThreadpoolTimer, } // indexedKevents keeps the slice of event infos. When the @@ -203,6 +209,9 @@ var indexedKevents = []KeventInfo{ kevents[QueryDNS], kevents[ReplyDNS], kevents[CreateSymbolicLinkObject], + kevents[SubmitThreadpoolWork], + kevents[SubmitThreadpoolCallback], + kevents[SetThreadpoolTimer], } // All returns all event types. diff --git a/pkg/ps/snapshotter.go b/pkg/ps/snapshotter.go index 0dcb60f60..17a93279c 100644 --- a/pkg/ps/snapshotter.go +++ b/pkg/ps/snapshotter.go @@ -43,7 +43,7 @@ type Snapshotter interface { // AddMmap adds a new memory mapping (data memory-mapped file, image, or pagefile) to this process state. AddMmap(*kevent.Kevent) error // RemoveMmap removes memory mapping at the given base address. - RemoveMmap(pid uint32, address va.Address) error + RemoveMmap(pid uint32, addr va.Address) error // WriteFromKcap appends a new process state to the snapshotter from the captured kernel event. WriteFromKcap(kevt *kevent.Kevent) error // Remove deletes process's state from the snapshotter. diff --git a/pkg/ps/snapshotter_windows.go b/pkg/ps/snapshotter_windows.go index 3ee582e65..3db36c9be 100644 --- a/pkg/ps/snapshotter_windows.go +++ b/pkg/ps/snapshotter_windows.go @@ -314,6 +314,7 @@ func (s *snapshotter) newProcState(pid, ppid uint32, e *kevent.Kevent) (*pstypes e.Kparams.MustGetSID(), e.Kparams.MustGetUint32(kparams.SessionID), ) + proc.Parent = s.procs[ppid] proc.StartTime, _ = e.Kparams.GetTime(kparams.StartTime) proc.IsWOW64 = (e.Kparams.MustGetUint32(kparams.ProcessFlags) & kevent.PsWOW64) != 0 diff --git a/pkg/symbolize/symbolizer.go b/pkg/symbolize/symbolizer.go index c0f830ecf..8abdfaf3a 100644 --- a/pkg/symbolize/symbolizer.go +++ b/pkg/symbolize/symbolizer.go @@ -24,11 +24,13 @@ import ( "github.com/rabbitstack/fibratus/pkg/config" "github.com/rabbitstack/fibratus/pkg/kevent" "github.com/rabbitstack/fibratus/pkg/kevent/kparams" + "github.com/rabbitstack/fibratus/pkg/kevent/ktypes" "github.com/rabbitstack/fibratus/pkg/pe" "github.com/rabbitstack/fibratus/pkg/ps" pstypes "github.com/rabbitstack/fibratus/pkg/ps/types" "github.com/rabbitstack/fibratus/pkg/sys" "github.com/rabbitstack/fibratus/pkg/util/convert" + "github.com/rabbitstack/fibratus/pkg/util/threadcontext" "github.com/rabbitstack/fibratus/pkg/util/va" log "github.com/sirupsen/logrus" "golang.org/x/sys/windows" @@ -311,18 +313,66 @@ func (s *Symbolizer) processCallstack(e *kevent.Kevent) error { defer s.mu.Unlock() if e.PS != nil { - // symbolize thread start address - if e.IsCreateThread() && !e.IsSystemPid() { - addr := e.Kparams.TryGetAddress(kparams.StartAddress) + var ( + addr va.Address + pid uint32 + ) + + // get the address that we want to symbolize + switch e.Type { + case ktypes.CreateThread: + pid = e.Kparams.MustGetPid() + addr = e.Kparams.TryGetAddress(kparams.StartAddress) + case ktypes.SubmitThreadpoolWork, ktypes.SubmitThreadpoolCallback: + pid = e.PID + addr = e.Kparams.TryGetAddress(kparams.ThreadpoolCallback) + } + // symbolize thread start or thread pool callback address + // and resolve the module name that contains the function + if addr != 0 { mod := e.PS.FindModuleByVa(addr) - symbol := s.symbolizeAddress(e.Kparams.MustGetPid(), addr, mod) - - if symbol != "" { - e.Kparams.Append(kparams.StartAddressSymbol, kparams.UnicodeString, symbol) + symbol := s.symbolizeAddress(pid, addr, mod) + + if symbol != "" && symbol != "?" { + switch e.Type { + case ktypes.CreateThread: + e.Kparams.Append(kparams.StartAddressSymbol, kparams.UnicodeString, symbol) + case ktypes.SubmitThreadpoolWork, ktypes.SubmitThreadpoolCallback: + e.Kparams.Append(kparams.ThreadpoolCallbackSymbol, kparams.UnicodeString, symbol) + + ctx := e.Kparams.TryGetAddress(kparams.ThreadpoolContext) + + // if the callback resolves to one of the functions + // that receive the CONTEXT structure as a parameter + // try to read the thread context and resolve the + // function address stored in the instruction pointer + if ctx != 0 && threadcontext.IsParamOfFunc(symbol) { + rip := threadcontext.Rip(pid, ctx) + if rip != 0 { + e.Kparams.Append(kparams.ThreadpoolContextRip, kparams.Address, rip.Uint64()) + + m := e.PS.FindModuleByVa(rip) + if m != nil { + e.Kparams.Append(kparams.ThreadpoolContextRipModule, kparams.UnicodeString, m.Name) + } + + sym := s.symbolizeAddress(pid, rip, m) + if sym != "" && sym != "?" { + e.Kparams.Append(kparams.ThreadpoolContextRipSymbol, kparams.UnicodeString, sym) + } + } + } + } } + if mod != nil { - e.Kparams.Append(kparams.StartAddressModule, kparams.UnicodeString, mod.Name) + switch e.Type { + case ktypes.CreateThread: + e.Kparams.Append(kparams.StartAddressModule, kparams.UnicodeString, mod.Name) + case ktypes.SubmitThreadpoolWork, ktypes.SubmitThreadpoolCallback: + e.Kparams.Append(kparams.ThreadpoolCallbackModule, kparams.UnicodeString, mod.Name) + } } } diff --git a/pkg/symbolize/symbolizer_test.go b/pkg/symbolize/symbolizer_test.go index f10066600..e52f6fc5d 100644 --- a/pkg/symbolize/symbolizer_test.go +++ b/pkg/symbolize/symbolizer_test.go @@ -385,6 +385,111 @@ func TestProcessCallstackFullMode(t *testing.T) { assert.Equal(t, 0, s.procsSize()) } +func TestSymbolizeEventParamAddress(t *testing.T) { + r := new(MockResolver) + c := &config.Config{} + + psnap := new(ps.SnapshotterMock) + + opts := uint32(sys.SymUndname | sys.SymCaseInsensitive | sys.SymAutoPublics | sys.SymOmapFindNearest | sys.SymDeferredLoads) + r.On("Initialize", mock.Anything, opts).Return(nil) + r.On("LoadModule", windows.CurrentProcess(), mock.Anything).Return(nil) + + r.On("GetModuleName", mock.Anything, mock.Anything).Return("C:\\WINDOWS\\System32\\KERNEL32.DLL").Once() + r.On("GetModuleName", mock.Anything, mock.Anything).Return("C:\\WINDOWS\\System32\\KERNELBASE.dll").Once() + r.On("GetModuleName", mock.Anything, mock.Anything).Return("C:\\WINDOWS\\System32\\ntdll.dll").Times(3) + + r.On("GetSymbolNameAndOffset", mock.Anything, mock.Anything).Return("CreateProcessW", 0x54).Times(2) + r.On("GetSymbolNameAndOffset", mock.Anything, mock.Anything).Return("CreateProcessW", 0x66).Once() + r.On("GetSymbolNameAndOffset", mock.Anything, mock.Anything).Return("NtCreateProcess", 0x3a2).Once() + r.On("GetSymbolNameAndOffset", mock.Anything, mock.Anything).Return("NtCreateProcessEx", 0x3a2).Times(2) + + r.On("Cleanup", mock.Anything) + + s := NewSymbolizer(r, psnap, c, false) + require.NotNil(t, s) + + parsePeFile = func(name string, option ...pe.Option) (*pe.PE, error) { + exports := map[uint32]string{ + 8192: "RtlSetSearchPathMode", + 9344: "CreateProcessW", + 20352: "LoadKeyboardLayoutW", + } + px := &pe.PE{ + Exports: exports, + } + return px, nil + } + + proc := &pstypes.PS{ + Name: "notepad.exe", + PID: 23234, + Ppid: 2434, + Exe: `C:\Windows\notepad.exe`, + Cmdline: `C:\Windows\notepad.exe`, + SID: "S-1-1-18", + Cwd: `C:\Windows\`, + SessionID: 1, + Threads: map[uint32]pstypes.Thread{ + 3453: {Tid: 3453, StartAddress: va.Address(140729524944768), IOPrio: 2, PagePrio: 5, KstackBase: va.Address(18446677035730165760), KstackLimit: va.Address(18446677035730137088), UstackLimit: va.Address(86376448), UstackBase: va.Address(86372352)}, + 3455: {Tid: 3455, StartAddress: va.Address(140729524944768), IOPrio: 3, PagePrio: 5, KstackBase: va.Address(18446677035730165760), KstackLimit: va.Address(18446677035730137088), UstackLimit: va.Address(86376448), UstackBase: va.Address(86372352)}, + }, + Envs: map[string]string{"ProgramData": "C:\\ProgramData", "COMPUTRENAME": "archrabbit"}, + Modules: []pstypes.Module{ + {Name: "C:\\Windows\\System32\\ntdll.dll", Size: 32358, Checksum: 23123343, BaseAddress: va.Address(0x7ffb313833a3), DefaultBaseAddress: va.Address(0x7ffb313833a3)}, + {Name: "C:\\Windows\\System32\\kernel32.dll", Size: 12354, Checksum: 23123343, BaseAddress: va.Address(0x7ffb5c1d0126), DefaultBaseAddress: va.Address(0x7ffb5c1d0126)}, + {Name: "C:\\Windows\\System32\\user32.dll", Size: 212354, Checksum: 33123343, BaseAddress: va.Address(0x7ffb5d8e11c4), DefaultBaseAddress: va.Address(0x7ffb5d8e11c4)}, + }, + } + e := &kevent.Kevent{ + Type: ktypes.CreateThread, + Tid: 2484, + PID: uint32(os.Getpid()), + CPU: 1, + Seq: 2, + Name: "CreateThread", + Timestamp: time.Now(), + Category: ktypes.Thread, + Host: "archrabbit", + Kparams: kevent.Kparams{ + kparams.Callstack: {Name: kparams.Callstack, Type: kparams.Slice, Value: []va.Address{0x7ffb5c1d0396, 0x7ffb5d8e61f4, 0x7ffb3138592e, 0x7ffb313853b2, 0x2638e59e0a5}}, + kparams.StartAddress: {Name: kparams.StartAddress, Type: kparams.Address, Value: uint64(0x7ffb3138592e)}, + kparams.ProcessID: {Name: kparams.ProcessID, Type: kparams.PID, Value: uint32(os.Getpid())}, + }, + PS: proc, + } + + _, err := s.ProcessEvent(e) + require.NoError(t, err) + + assert.Equal(t, "CreateProcessW", e.GetParamAsString(kparams.StartAddressSymbol)) + assert.Equal(t, "C:\\Windows\\System32\\ntdll.dll", e.GetParamAsString(kparams.StartAddressModule)) + + e1 := &kevent.Kevent{ + Type: ktypes.SubmitThreadpoolCallback, + Tid: 2484, + PID: uint32(os.Getpid()), + CPU: 1, + Seq: 2, + Name: "SubmitThreadpoolCallback", + Timestamp: time.Now(), + Category: ktypes.Threadpool, + Host: "archrabbit", + Kparams: kevent.Kparams{ + kparams.Callstack: {Name: kparams.Callstack, Type: kparams.Slice, Value: []va.Address{0x7ffb5c1d0396}}, + kparams.ThreadpoolCallback: {Name: kparams.ThreadpoolCallback, Type: kparams.Address, Value: uint64(0x7ffb3138592e)}, + kparams.ThreadpoolContext: {Name: kparams.ThreadpoolContext, Type: kparams.Address, Value: uint64(0)}, + }, + PS: proc, + } + + _, err = s.ProcessEvent(e1) + require.NoError(t, err) + + assert.Equal(t, "CreateProcessW", e1.GetParamAsString(kparams.ThreadpoolCallbackSymbol)) + assert.Equal(t, "C:\\Windows\\System32\\ntdll.dll", e1.GetParamAsString(kparams.ThreadpoolCallbackModule)) +} + func init() { procTTL = time.Second } diff --git a/pkg/sys/etw/types.go b/pkg/sys/etw/types.go index b87c50499..04d7863eb 100644 --- a/pkg/sys/etw/types.go +++ b/pkg/sys/etw/types.go @@ -38,9 +38,12 @@ var KernelTraceControlGUID = windows.GUID{Data1: 0x9e814aad, Data2: 0x3204, Data // KernelAuditAPICallsGUID represents the GUID for the kernel audit API provider var KernelAuditAPICallsGUID = windows.GUID{Data1: 0xe02a841c, Data2: 0x75a3, Data3: 0x4fa7, Data4: [8]byte{0xaf, 0xc8, 0xae, 0x09, 0xcf, 0x9b, 0x7f, 0x23}} -// DNSClientGUID presents the GUID for the Windows DNS Client provider +// DNSClientGUID represents the GUID for the Windows DNS Client provider var DNSClientGUID = windows.GUID{Data1: 0x1c95126e, Data2: 0x7eea, Data3: 0x49a9, Data4: [8]byte{0xa3, 0xfe, 0xa3, 0x78, 0xb0, 0x3d, 0xdb, 0x4d}} +// ThreadpoolGUID represents the GUID for the thread pool provider +var ThreadpoolGUID = windows.GUID{Data1: 0xc861d0e2, Data2: 0xa2c1, Data3: 0x4d36, Data4: [8]byte{0x9f, 0x9c, 0x97, 0x0b, 0xab, 0x94, 0x3a, 0x12}} + const ( // TraceStackTracingInfo controls call stack tracing for kernel events TraceStackTracingInfo = uint8(3) @@ -60,15 +63,11 @@ const ( KernelAuditAPICallsSession = "Kernel Audit API Calls Logger" // DNSClientSession represents the session name for the DNS client logger DNSClientSession = "DNS Client Logger" + // ThreadpoolSession represents the session name for the thread pool logger + ThreadpoolSession = "Threadpool Logger" - // SystemProcessSession represents system process provider logger - SystemProcessSession = "System Process Logger" - // SystemIOSession represents system I/O provider logger - SystemIOSession = "System I/O Logger" // SystemRegistrySession represents system registry logger SystemRegistrySession = "System Registry Logger" - // SystemMemorySession represents system memory logger - SystemMemorySession = "System Memory Logger" // WnodeTraceFlagGUID indicates that the structure contains event tracing information WnodeTraceFlagGUID = 0x00020000 diff --git a/pkg/util/threadcontext/context.go b/pkg/util/threadcontext/context.go new file mode 100644 index 000000000..2769b92cf --- /dev/null +++ b/pkg/util/threadcontext/context.go @@ -0,0 +1,104 @@ +/* + * Copyright 2021-present by Nedim Sabic Sabic + * https://www.fibratus.io + * All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package threadcontext + +import ( + "github.com/rabbitstack/fibratus/pkg/util/va" + "golang.org/x/sys/windows" + "unsafe" +) + +// Context contains processor-specific register data. +type Context struct { + P1 uint64 + P2 uint64 + P3 uint64 + P4 uint64 + P5 uint64 + P6 uint64 + ContextFlags uint32 + MxCsr uint32 + SegCs uint16 + SegDs uint16 + SegEs uint16 + SegFs uint16 + SegGs uint16 + SegSs uint16 + EFlags uint32 + Dr0 uint64 + Dr1 uint64 + Dr2 uint64 + Dr3 uint64 + Dr6 uint64 + Dr7 uint64 + Rax uint64 + Rcx uint64 + Rdx uint64 + Rbx uint64 + Rsp uint64 + Rbp uint64 + Rsi uint64 + Rdi uint64 + R8 uint64 + R9 uint64 + R10 uint64 + R11 uint64 + R12 uint64 + R13 uint64 + R14 uint64 + R15 uint64 + Rip uint64 +} + +// Decode reads the thread context structure from +// the given process memory and at the specified +// base address. Returns the decoded Context struct +// or nil if the data cannot be read from the remote +// process address space. +func Decode(pid uint32, addr va.Address) *Context { + proc, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION|windows.PROCESS_VM_READ, false, pid) + if err != nil { + return nil + } + defer windows.Close(proc) + + size := uint(unsafe.Sizeof(Context{})) + ctx := va.ReadArea(proc, addr.Uintptr(), size, size, false) + if !va.Zeroed(ctx) { + return (*Context)(unsafe.Pointer(&ctx[0])) + } + + return nil +} + +// Rip returns the address stored in the instruction pointer register. +func Rip(pid uint32, addr va.Address) va.Address { + ctx := Decode(pid, addr) + if ctx != nil { + return va.Address(ctx.Rip) + } + return 0 +} + +// IsParamOfFunc returns true if the CONTEXT +// structure is supplied as a single parameter +// to the well-known API functions. +func IsParamOfFunc(f string) bool { + return f == "NtContinue" || f == "ZwContinue" || f == "RtlCaptureContext" +} diff --git a/pkg/util/threadcontext/context_test.go b/pkg/util/threadcontext/context_test.go new file mode 100644 index 000000000..a0e3f7833 --- /dev/null +++ b/pkg/util/threadcontext/context_test.go @@ -0,0 +1,73 @@ +/* + * Copyright 2021-present by Nedim Sabic Sabic + * https://www.fibratus.io + * All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package threadcontext + +import ( + "github.com/rabbitstack/fibratus/pkg/util/va" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows" + "os" + "testing" + "unsafe" +) + +func TestDecode(t *testing.T) { + ntdll, err := windows.LoadLibrary("kernel32.dll") + require.NoError(t, err) + + fn, err := windows.GetProcAddress(ntdll, "VirtualProtect") + require.NoError(t, err) + + ctx := Context{ + Rip: uint64(fn), + } + + const sz = int(unsafe.Sizeof(Context{})) + b := (*(*[sz]byte)(unsafe.Pointer(&ctx)))[:] + + addr, err := windows.VirtualAlloc(0, uintptr(sz), windows.MEM_COMMIT, windows.PAGE_EXECUTE_READWRITE) + require.NoError(t, err) + + var n uintptr + require.NoError(t, windows.WriteProcessMemory(windows.CurrentProcess(), addr, &b[0], uintptr(sz), &n)) + + c := Decode(uint32(os.Getpid()), va.Address(addr)) + + require.NotNil(t, c) + require.Equal(t, fn, uintptr(c.Rip)) +} + +func TestIsParamOfFunc(t *testing.T) { + var tests = []struct { + f string + ok bool + }{ + {"ZwContinue", true}, + {"RtlCaptureContext", true}, + {"CreateFile", false}, + {"CreateThread", false}, + } + + for _, tt := range tests { + t.Run(tt.f, func(t *testing.T) { + assert.Equal(t, tt.ok, IsParamOfFunc(tt.f)) + }) + } +}