Skip to content

Commit c3b27dd

Browse files
committed
adds tests and bug fixes to the ps1 deobfuscator
1 parent 77cb74c commit c3b27dd

7 files changed

Lines changed: 573 additions & 68 deletions

File tree

refinery/lib/scripts/ps1/deobfuscation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
PowerShell AST deobfuscation transforms.
33
"""
4+
from refinery.lib.scripts.ps1.deobfuscation.constants import Ps1ConstantInlining
45
from refinery.lib.scripts.ps1.deobfuscation.securestring import Ps1SecureStringDecryptor
56
from refinery.lib.scripts.ps1.deobfuscation.simplify import Ps1Simplifications
67
from refinery.lib.scripts.ps1.deobfuscation.strings import Ps1StringOperations
@@ -14,6 +15,7 @@ def deobfuscate(ast: Ps1Script) -> bool:
1415
"""
1516
transformers = [
1617
Ps1Simplifications(),
18+
Ps1ConstantInlining(),
1719
Ps1StringOperations(),
1820
Ps1TypeCasts(),
1921
Ps1SecureStringDecryptor(),
Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
"""
2+
Inline constant variable references in PowerShell scripts.
3+
"""
4+
from __future__ import annotations
5+
6+
import copy
7+
8+
from refinery.lib.scripts import Node, Transformer
9+
from refinery.lib.scripts.ps1.model import (
10+
Ps1ArrayExpression,
11+
Ps1ArrayLiteral,
12+
Ps1AssignmentExpression,
13+
Ps1ExpressionStatement,
14+
Ps1ForEachLoop,
15+
Ps1IndexExpression,
16+
Ps1IntegerLiteral,
17+
Ps1ParameterDeclaration,
18+
Ps1Pipeline,
19+
Ps1PipelineElement,
20+
Ps1RealLiteral,
21+
Ps1ScopeModifier,
22+
Ps1StringLiteral,
23+
Ps1TryCatchFinally,
24+
Ps1UnaryExpression,
25+
Ps1Variable,
26+
)
27+
28+
_CONSTANT_TYPES = (Ps1StringLiteral, Ps1IntegerLiteral, Ps1RealLiteral)
29+
30+
31+
def _is_constant(node: Node) -> bool:
32+
if isinstance(node, _CONSTANT_TYPES):
33+
return True
34+
if isinstance(node, Ps1ArrayLiteral):
35+
return all(_is_constant(e) for e in node.elements)
36+
if isinstance(node, Ps1ArrayExpression):
37+
inner = _unwrap_array_expression(node)
38+
if inner is not None:
39+
return _is_constant(inner)
40+
return False
41+
42+
43+
def _unwrap_array_expression(node: Ps1ArrayExpression) -> Ps1ArrayLiteral | None:
44+
"""Unwrap ``@(e1, e2, ...)`` to its inner ``Ps1ArrayLiteral`` if possible."""
45+
if len(node.body) == 1:
46+
stmt = node.body[0]
47+
if isinstance(stmt, Ps1ExpressionStatement) and isinstance(stmt.expression, Ps1ArrayLiteral):
48+
return stmt.expression
49+
return None
50+
51+
52+
def _get_array_literal(node: Node) -> Ps1ArrayLiteral | None:
53+
"""Return the indexable Ps1ArrayLiteral from either a bare literal or @(...)."""
54+
if isinstance(node, Ps1ArrayLiteral):
55+
return node
56+
if isinstance(node, Ps1ArrayExpression):
57+
return _unwrap_array_expression(node)
58+
return None
59+
60+
61+
def _inside_try_body(node: Node) -> bool:
62+
cursor = node.parent
63+
while cursor is not None:
64+
parent = cursor.parent
65+
if isinstance(parent, Ps1TryCatchFinally) and cursor is parent.try_block:
66+
return True
67+
cursor = parent
68+
return False
69+
70+
71+
class Ps1ConstantInlining(Transformer):
72+
73+
def __init__(self, max_inline_length: int = 64):
74+
super().__init__()
75+
self.max_inline_length = max_inline_length
76+
77+
def visit(self, node: Node):
78+
# Phase 1: collect candidates, then phase 2: substitute.
79+
# Only the top-level call triggers the two-phase approach.
80+
candidates = self._collect_candidates(node)
81+
if not candidates:
82+
return None
83+
remaining = self._substitute(node, candidates)
84+
self._remove_dead_assignments(candidates, remaining)
85+
return None
86+
87+
def _collect_candidates(self, root: Node) -> dict[str, tuple[Ps1AssignmentExpression, Node]]:
88+
"""
89+
Returns:
90+
91+
{lower_name: (assignment_node, constant_value)}
92+
93+
for variables assigned exactly once via assignment to a constant expression.
94+
"""
95+
assign_counts: dict[str, int] = {}
96+
assignments: dict[str, tuple[Ps1AssignmentExpression, Node]] = {}
97+
98+
for node in root.walk():
99+
# Explicit assignment: $x = VALUE
100+
if isinstance(node, Ps1AssignmentExpression):
101+
target = node.target
102+
if isinstance(target, Ps1Variable) and target.scope == Ps1ScopeModifier.NONE:
103+
key = target.name.lower()
104+
if node.operator == '=' and node.value is not None and _is_constant(node.value):
105+
assign_counts[key] = assign_counts.get(key, 0) + 1
106+
assignments[key] = (node, node.value)
107+
else:
108+
# Compound assignment or non-constant value
109+
assign_counts[key] = assign_counts.get(key, 0) + 1
110+
111+
# Implicit assignments
112+
elif isinstance(node, Ps1ForEachLoop):
113+
if isinstance(node.variable, Ps1Variable) and node.variable.scope == Ps1ScopeModifier.NONE:
114+
key = node.variable.name.lower()
115+
assign_counts[key] = assign_counts.get(key, 0) + 1
116+
117+
elif isinstance(node, Ps1UnaryExpression):
118+
if node.operator in ('++', '--'):
119+
operand = node.operand
120+
if isinstance(operand, Ps1Variable) and operand.scope == Ps1ScopeModifier.NONE:
121+
key = operand.name.lower()
122+
assign_counts[key] = assign_counts.get(key, 0) + 1
123+
124+
elif isinstance(node, Ps1ParameterDeclaration):
125+
if isinstance(node.variable, Ps1Variable) and node.variable.scope == Ps1ScopeModifier.NONE:
126+
key = node.variable.name.lower()
127+
assign_counts[key] = assign_counts.get(key, 0) + 1
128+
129+
return {
130+
key: val for key, val in assignments.items()
131+
if assign_counts.get(key, 0) == 1
132+
}
133+
134+
def _substitute(
135+
self,
136+
root: Node,
137+
candidates: dict[str, tuple[Ps1AssignmentExpression, Node]],
138+
) -> dict[str, int]:
139+
"""
140+
Inline constant values. Returns:
141+
142+
{lower_name: remaining_ref_count}
143+
144+
for references that could not be substituted.
145+
"""
146+
remaining: dict[str, int] = {}
147+
148+
# Pre-count references to decide whether inlining would bloat the code.
149+
# Variables referenced more than once with long values are kept as-is.
150+
ref_counts: dict[str, int] = {}
151+
for node in root.walk():
152+
if isinstance(node, Ps1IndexExpression):
153+
var = node.object
154+
if isinstance(var, Ps1Variable) and var.scope == Ps1ScopeModifier.NONE:
155+
key = var.name.lower()
156+
if key in candidates:
157+
ref_counts[key] = ref_counts.get(key, 0) + 1
158+
elif isinstance(node, Ps1Variable) and node.scope == Ps1ScopeModifier.NONE:
159+
key = node.name.lower()
160+
if key in candidates:
161+
ref_counts[key] = ref_counts.get(key, 0) + 1
162+
for key, (_, const_value) in candidates.items():
163+
use_count = ref_counts.get(key, 0) - 1 # subtract assignment target
164+
if use_count > 1 and isinstance(const_value, Ps1StringLiteral):
165+
if len(const_value.raw) > self.max_inline_length:
166+
remaining[key] = use_count
167+
168+
handled_vars: set[int] = set()
169+
170+
for node in list(root.walk()):
171+
# Indexed access: $x[2]
172+
if isinstance(node, Ps1IndexExpression):
173+
var = node.object
174+
if isinstance(var, Ps1Variable) and var.scope == Ps1ScopeModifier.NONE:
175+
key = var.name.lower()
176+
info = candidates.get(key)
177+
if info is None:
178+
continue
179+
assign_node, const_value = info
180+
if node is assign_node.target:
181+
handled_vars.add(id(var))
182+
continue
183+
if _inside_try_body(node):
184+
remaining[key] = remaining.get(key, 0) + 1
185+
handled_vars.add(id(var))
186+
continue
187+
array = _get_array_literal(const_value)
188+
if array is None:
189+
remaining[key] = remaining.get(key, 0) + 1
190+
handled_vars.add(id(var))
191+
continue
192+
if not isinstance(node.index, Ps1IntegerLiteral):
193+
remaining[key] = remaining.get(key, 0) + 1
194+
handled_vars.add(id(var))
195+
continue
196+
idx = node.index.value
197+
elements = array.elements
198+
if idx < 0 or idx >= len(elements):
199+
remaining[key] = remaining.get(key, 0) + 1
200+
continue
201+
replacement = copy.deepcopy(elements[idx])
202+
self._replace_in_parent(node, replacement)
203+
self.mark_changed()
204+
handled_vars.add(id(var))
205+
continue
206+
207+
#Simple variable reference: $x
208+
if isinstance(node, Ps1Variable) and node.scope == Ps1ScopeModifier.NONE:
209+
if id(node) in handled_vars:
210+
continue
211+
key = node.name.lower()
212+
info = candidates.get(key)
213+
if info is None:
214+
continue
215+
if key in remaining:
216+
continue
217+
assign_node, const_value = info
218+
# Don't replace the assignment target itself
219+
if node.parent is assign_node and node is assign_node.target:
220+
continue
221+
if _inside_try_body(node):
222+
remaining[key] = remaining.get(key, 0) + 1
223+
continue
224+
replacement = copy.deepcopy(const_value)
225+
self._replace_in_parent(node, replacement)
226+
self.mark_changed()
227+
228+
return remaining
229+
230+
@staticmethod
231+
def _replace_in_parent(old: Node, new: Node):
232+
parent = old.parent
233+
if parent is None:
234+
return
235+
new.parent = parent
236+
for attr_name in vars(parent):
237+
if attr_name in ('parent', 'offset'):
238+
continue
239+
value = getattr(parent, attr_name)
240+
if value is old:
241+
setattr(parent, attr_name, new)
242+
return
243+
if isinstance(value, list):
244+
for i, item in enumerate(value):
245+
if item is old:
246+
value[i] = new
247+
return
248+
if isinstance(item, tuple):
249+
lst = list(item)
250+
for j, elem in enumerate(lst):
251+
if elem is old:
252+
lst[j] = new
253+
value[i] = tuple(lst)
254+
return
255+
256+
def _remove_dead_assignments(
257+
self,
258+
candidates: dict[str, tuple[Ps1AssignmentExpression, Node]],
259+
remaining: dict[str, int],
260+
):
261+
for key, (assign_node, _) in candidates.items():
262+
if remaining.get(key, 0) > 0:
263+
continue
264+
# Find the enclosing statement to remove
265+
stmt = self._find_removable_statement(assign_node)
266+
if stmt is None:
267+
continue
268+
parent = stmt.parent
269+
if parent is None:
270+
continue
271+
for attr_name in vars(parent):
272+
if attr_name in ('parent', 'offset'):
273+
continue
274+
value = getattr(parent, attr_name)
275+
if isinstance(value, list) and stmt in value:
276+
value.remove(stmt)
277+
self.mark_changed()
278+
break
279+
280+
@staticmethod
281+
def _find_removable_statement(assign_node: Ps1AssignmentExpression) -> Node | None:
282+
cursor = assign_node
283+
while cursor.parent is not None:
284+
parent = cursor.parent
285+
if isinstance(parent, Ps1ExpressionStatement):
286+
cursor = parent
287+
continue
288+
if isinstance(parent, Ps1PipelineElement):
289+
cursor = parent
290+
continue
291+
if isinstance(parent, Ps1Pipeline):
292+
if len(parent.elements) == 1:
293+
cursor = parent
294+
continue
295+
# cursor is the statement to remove from parent's body list
296+
return cursor
297+
return None

refinery/lib/scripts/ps1/deobfuscation/strings.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727
Ps1IntegerLiteral,
2828
Ps1InvokeMember,
2929
Ps1MemberAccess,
30+
Ps1ParenExpression,
3031
Ps1StringLiteral,
3132
Ps1TypeExpression,
33+
Ps1UnaryExpression,
3234
)
3335

3436
_ENCODING_MAP = {
@@ -76,8 +78,31 @@ def _is_static_encoding_chain(node: Ps1InvokeMember) -> tuple[str, bool] | None:
7678
return encoding_name, True
7779

7880

81+
def _unwrap_to_array_literal(node: Expression) -> Ps1ArrayLiteral | None:
82+
"""
83+
Unwrap parentheses to find an inner Ps1ArrayLiteral.
84+
"""
85+
while isinstance(node, Ps1ParenExpression) and node.expression is not None:
86+
node = node.expression
87+
if isinstance(node, Ps1ArrayLiteral):
88+
return node
89+
return None
90+
91+
7992
class Ps1StringOperations(Transformer):
8093

94+
def visit_Ps1UnaryExpression(self, node: Ps1UnaryExpression):
95+
self.generic_visit(node)
96+
if node.operator.lower() != '-join' or node.operand is None:
97+
return None
98+
array = _unwrap_to_array_literal(node.operand)
99+
if array is None:
100+
return None
101+
args = _collect_string_arguments(array)
102+
if args is None:
103+
return None
104+
return _make_string_literal(''.join(args))
105+
81106
def visit_Ps1InvokeMember(self, node: Ps1InvokeMember):
82107
self.generic_visit(node)
83108
if isinstance(node.member, Ps1StringLiteral):
@@ -160,6 +185,8 @@ def visit_Ps1BinaryExpression(self, node: Ps1BinaryExpression):
160185
return self._handle_format(node)
161186
if op == '+':
162187
return self._handle_concat(node)
188+
if op == '-join':
189+
return self._handle_binary_join(node)
163190
if op in ('-replace', '-creplace', '-ireplace'):
164191
return self._handle_binary_replace(node, op)
165192
return None
@@ -193,6 +220,18 @@ def _handle_concat(self, node: Ps1BinaryExpression) -> Expression | None:
193220
return node.left
194221
return None
195222

223+
def _handle_binary_join(self, node: Ps1BinaryExpression) -> Expression | None:
224+
separator = _string_value(node.right) if node.right else None
225+
if separator is None or node.left is None:
226+
return None
227+
array = _unwrap_to_array_literal(node.left)
228+
if array is None:
229+
return None
230+
args = _collect_string_arguments(array)
231+
if args is None:
232+
return None
233+
return _make_string_literal(separator.join(args))
234+
196235
def _handle_binary_replace(
197236
self, node: Ps1BinaryExpression, op: str,
198237
) -> Expression | None:

0 commit comments

Comments
 (0)