Skip to content

Commit 2625015

Browse files
gh-34: Add LAMBDA.
1 parent cbda87d commit 2625015

File tree

7 files changed

+178
-66
lines changed

7 files changed

+178
-66
lines changed

src/ast.c

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,17 @@ Expr* expr_wildcard(int line, int column) {
126126
return expr;
127127
}
128128

129+
Expr* expr_lambda(ParamList params, DeclType return_type, Stmt* body, int line, int column) {
130+
Expr* expr = ast_alloc(sizeof(Expr));
131+
expr->type = EXPR_LAMBDA;
132+
expr->line = line;
133+
expr->column = column;
134+
expr->as.lambda.params = params;
135+
expr->as.lambda.return_type = return_type;
136+
expr->as.lambda.body = body;
137+
return expr;
138+
}
139+
129140
Expr* expr_async(Stmt* block, int line, int column) {
130141
Expr* expr = ast_alloc(sizeof(Expr));
131142
expr->type = EXPR_ASYNC;
@@ -403,6 +414,14 @@ void free_expr(Expr* expr) {
403414
free_expr(expr->as.range.start);
404415
free_expr(expr->as.range.end);
405416
break;
417+
case EXPR_LAMBDA:
418+
for (size_t i = 0; i < expr->as.lambda.params.count; i++) {
419+
free(expr->as.lambda.params.items[i].name);
420+
free_expr(expr->as.lambda.params.items[i].default_value);
421+
}
422+
free(expr->as.lambda.params.items);
423+
free_stmt(expr->as.lambda.body);
424+
break;
406425
case EXPR_IDENT:
407426
free(expr->as.ident);
408427
break;

src/ast.h

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,18 @@ typedef enum {
1717
typedef struct Expr Expr;
1818
typedef struct Stmt Stmt;
1919

20+
typedef struct Param {
21+
DeclType type;
22+
char* name;
23+
Expr* default_value; // optional
24+
} Param;
25+
26+
typedef struct ParamList {
27+
Param* items;
28+
size_t count;
29+
size_t capacity;
30+
} ParamList;
31+
2032
typedef enum {
2133
EXPR_INT,
2234
EXPR_FLT,
@@ -29,7 +41,8 @@ typedef enum {
2941
EXPR_MAP,
3042
EXPR_INDEX,
3143
EXPR_RANGE,
32-
EXPR_WILDCARD
44+
EXPR_WILDCARD,
45+
EXPR_LAMBDA
3346
} ExprType;
3447

3548
typedef struct {
@@ -69,6 +82,11 @@ struct Expr {
6982
ExprList keys;
7083
ExprList values;
7184
} map_items;
85+
struct {
86+
ParamList params;
87+
DeclType return_type;
88+
Stmt* body;
89+
} lambda;
7290
ExprList tns_items;
7391
} as;
7492
};
@@ -100,18 +118,6 @@ typedef struct {
100118
size_t capacity;
101119
} StmtList;
102120

103-
typedef struct {
104-
DeclType type;
105-
char* name;
106-
Expr* default_value; // optional
107-
} Param;
108-
109-
typedef struct {
110-
Param* items;
111-
size_t count;
112-
size_t capacity;
113-
} ParamList;
114-
115121
struct Stmt {
116122
StmtType type;
117123
int line;
@@ -157,6 +163,7 @@ Expr* expr_map(int line, int column);
157163
Expr* expr_index(Expr* target, int line, int column);
158164
Expr* expr_range(Expr* start, Expr* end, int line, int column);
159165
Expr* expr_wildcard(int line, int column);
166+
Expr* expr_lambda(ParamList params, DeclType return_type, Stmt* body, int line, int column);
160167
void expr_list_add(ExprList* list, Expr* expr);
161168

162169
Stmt* stmt_block(int line, int column);

src/env.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,18 @@ static void* env_alloc(size_t size) {
104104
Env* env_create(Env* parent) {
105105
Env* env = env_alloc(sizeof(Env));
106106
env->parent = parent;
107+
env->refcount = 1;
107108
return env;
108109
}
109110

111+
void env_retain(Env* env) {
112+
if (!env) return;
113+
env->refcount++;
114+
}
115+
110116
void env_free(Env* env) {
111117
if (!env) return;
118+
if (--env->refcount > 0) return;
112119
for (size_t i = 0; i < env->count; i++) {
113120
free(env->entries[i].name);
114121
if (env->entries[i].initialized) {

src/env.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@ typedef struct Env {
1919
EnvEntry* entries;
2020
size_t count;
2121
size_t capacity;
22+
int refcount;
2223
} Env;
2324

2425
Env* env_create(Env* parent);
26+
void env_retain(Env* env);
2527
void env_free(Env* env);
2628

2729
bool env_define(Env* env, const char* name, DeclType type);

src/interpreter.c

Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,42 @@ static void* safe_malloc(size_t size) {
358358
return ptr;
359359
}
360360

361+
static Func* create_runtime_function(const char* name,
362+
DeclType return_type,
363+
ParamList* src_params,
364+
Stmt* body,
365+
Env* closure) {
366+
Func* f = safe_malloc(sizeof(Func));
367+
f->name = name ? strdup(name) : NULL;
368+
f->return_type = return_type;
369+
f->body = body;
370+
f->params.count = src_params ? src_params->count : 0;
371+
f->params.items = NULL;
372+
f->params.capacity = src_params ? src_params->capacity : 0;
373+
if (src_params && src_params->count > 0) {
374+
f->params.items = safe_malloc(sizeof(Param) * src_params->count);
375+
for (size_t i = 0; i < src_params->count; i++) {
376+
f->params.items[i].type = src_params->items[i].type;
377+
f->params.items[i].name = strdup(src_params->items[i].name);
378+
f->params.items[i].default_value = src_params->items[i].default_value;
379+
}
380+
}
381+
f->closure = closure;
382+
env_retain(f->closure);
383+
return f;
384+
}
385+
386+
static void free_runtime_function(Func* f) {
387+
if (!f) return;
388+
if (f->name) free(f->name);
389+
for (size_t i = 0; i < f->params.count; i++) {
390+
free(f->params.items[i].name);
391+
}
392+
free(f->params.items);
393+
env_free(f->closure);
394+
free(f);
395+
}
396+
361397
static int builtin_param_index(BuiltinFunction* builtin, const char* kw) {
362398
if (!builtin || !kw) return -1;
363399
if (!builtin->param_names || builtin->param_count <= 0) return -1;
@@ -666,6 +702,15 @@ Value eval_expr(Interpreter* interp, Expr* expr, Env* env) {
666702
}
667703
return v;
668704
}
705+
706+
case EXPR_LAMBDA: {
707+
Func* f = create_runtime_function(NULL,
708+
expr->as.lambda.return_type,
709+
&expr->as.lambda.params,
710+
expr->as.lambda.body,
711+
env);
712+
return value_func(f);
713+
}
669714

670715
case EXPR_CALL: {
671716
// Get the callee
@@ -1927,40 +1972,20 @@ static ExecResult exec_stmt(Interpreter* interp, Stmt* stmt, Env* env, LabelMap*
19271972

19281973
case STMT_FUNC: {
19291974
// Register user-defined function in the interpreter
1930-
Func* f = safe_malloc(sizeof(Func));
1931-
f->name = strdup(stmt->as.func_stmt.name);
1932-
f->return_type = stmt->as.func_stmt.return_type;
1933-
f->body = stmt->as.func_stmt.body;
1934-
// Copy parameters
1935-
ParamList* src = &stmt->as.func_stmt.params;
1936-
f->params.count = src->count;
1937-
f->params.items = NULL;
1938-
f->params.capacity = src->capacity;
1939-
if (src->count > 0) {
1940-
f->params.items = safe_malloc(sizeof(Param) * src->count);
1941-
for (size_t i = 0; i < src->count; i++) {
1942-
f->params.items[i].type = src->items[i].type;
1943-
f->params.items[i].name = strdup(src->items[i].name);
1944-
f->params.items[i].default_value = src->items[i].default_value; // share AST node
1945-
}
1946-
}
1947-
// Closure is current environment
1948-
f->closure = env;
1975+
Func* f = create_runtime_function(stmt->as.func_stmt.name,
1976+
stmt->as.func_stmt.return_type,
1977+
&stmt->as.func_stmt.params,
1978+
stmt->as.func_stmt.body,
1979+
env);
19491980

19501981
if (builtin_lookup(f->name)) {
1951-
free(f->name);
1952-
for (size_t i = 0; i < f->params.count; i++) free(f->params.items[i].name);
1953-
free(f->params.items);
1954-
free(f);
1982+
free_runtime_function(f);
19551983
return make_error("Function name conflicts with built-in", stmt->line, stmt->column);
19561984
}
19571985

19581986
EnvEntry* prior = env_get_entry(env, f->name);
19591987
if (prior && prior->decl_type != TYPE_FUNC) {
1960-
free(f->name);
1961-
for (size_t i = 0; i < f->params.count; i++) free(f->params.items[i].name);
1962-
free(f->params.items);
1963-
free(f);
1988+
free_runtime_function(f);
19641989
return make_error("Function name conflicts with existing symbol", stmt->line, stmt->column);
19651990
}
19661991

src/parser.c

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,38 @@ static Expr* parse_expression(Parser* parser);
7575
static Stmt* parse_statement(Parser* parser);
7676
static Stmt* parse_block(Parser* parser);
7777

78+
static bool is_type_token(PTokenType type) {
79+
return type == TOKEN_IDENT || type == TOKEN_FUNC || type == TOKEN_THR;
80+
}
81+
82+
static bool parse_param_list(Parser* parser, ParamList* params) {
83+
if (parser->current_token.type == TOKEN_RPAREN) return true;
84+
do {
85+
if (!is_type_token(parser->current_token.type)) {
86+
report_error(parser, "Expected parameter type");
87+
return false;
88+
}
89+
DeclType ptype = parse_type_name(parser->current_token.literal);
90+
advance(parser);
91+
consume(parser, TOKEN_COLON, "Expected ':' after parameter type");
92+
if (parser->current_token.type != TOKEN_IDENT) {
93+
report_error(parser, "Expected parameter name");
94+
return false;
95+
}
96+
Param param;
97+
param.type = ptype;
98+
param.name = parser->current_token.literal;
99+
param.default_value = NULL;
100+
advance(parser);
101+
if (match(parser, TOKEN_EQUALS)) {
102+
param.default_value = parse_expression(parser);
103+
if (!param.default_value) return false;
104+
}
105+
param_list_add(params, param);
106+
} while (match(parser, TOKEN_COMMA));
107+
return true;
108+
}
109+
78110
static Expr* parse_primary(Parser* parser) {
79111
Token token = parser->current_token;
80112
if (parser->current_token.type == TOKEN_ASYNC) {
@@ -125,6 +157,24 @@ static Expr* parse_primary(Parser* parser) {
125157
advance(parser);
126158
return expr_ptr(id.literal, id.line, id.column);
127159
}
160+
if (match(parser, TOKEN_LAMBDA)) {
161+
Token lambda_tok = token;
162+
consume(parser, TOKEN_LPAREN, "Expected '(' after LAMBDA");
163+
164+
ParamList params = {0};
165+
if (!parse_param_list(parser, &params)) return NULL;
166+
167+
consume(parser, TOKEN_RPAREN, "Expected ')' after parameters");
168+
consume(parser, TOKEN_COLON, "Expected ':' before return type");
169+
if (!is_type_token(parser->current_token.type)) {
170+
report_error(parser, "Expected return type");
171+
return NULL;
172+
}
173+
DeclType ret = parse_type_name(parser->current_token.literal);
174+
advance(parser);
175+
Stmt* body = parse_block(parser);
176+
return expr_lambda(params, ret, body, lambda_tok.line, lambda_tok.column);
177+
}
128178
if (match(parser, TOKEN_IDENT)) {
129179
return expr_ident(token.literal, token.line, token.column);
130180
}
@@ -419,34 +469,10 @@ static Stmt* parse_func(Parser* parser) {
419469
consume(parser, TOKEN_LPAREN, "Expected '(' after function name");
420470

421471
ParamList params = {0};
422-
if (parser->current_token.type != TOKEN_RPAREN) {
423-
do {
424-
if (parser->current_token.type != TOKEN_IDENT) {
425-
report_error(parser, "Expected parameter type");
426-
break;
427-
}
428-
DeclType ptype = parse_type_name(parser->current_token.literal);
429-
advance(parser);
430-
consume(parser, TOKEN_COLON, "Expected ':' after parameter type");
431-
if (parser->current_token.type != TOKEN_IDENT) {
432-
report_error(parser, "Expected parameter name");
433-
break;
434-
}
435-
Param param;
436-
param.type = ptype;
437-
param.name = parser->current_token.literal;
438-
param.default_value = NULL;
439-
advance(parser);
440-
if (match(parser, TOKEN_EQUALS)) {
441-
param.default_value = parse_expression(parser);
442-
if (!param.default_value) return NULL;
443-
}
444-
param_list_add(&params, param);
445-
} while (match(parser, TOKEN_COMMA));
446-
}
472+
if (!parse_param_list(parser, &params)) return NULL;
447473
consume(parser, TOKEN_RPAREN, "Expected ')' after parameters");
448474
consume(parser, TOKEN_COLON, "Expected ':' before return type");
449-
if (!(parser->current_token.type == TOKEN_IDENT || parser->current_token.type == TOKEN_FUNC || parser->current_token.type == TOKEN_THR)) {
475+
if (!is_type_token(parser->current_token.type)) {
450476
report_error(parser, "Expected return type");
451477
return NULL;
452478
}

tests/test2.pre

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,32 @@ FUNC GREET(STR: name, STR: prefix = "Hello"):STR{
604604
}
605605
ASSERT(EQ(GREET("World"), "Hello World"))
606606
ASSERT(EQ(GREET("World", prefix = "Hi"), "Hi World"))
607+
608+
INT: base = 1
609+
FUNC: add_base = LAMBDA(INT: x):INT{ RETURN(ADD(x, base)) }
610+
ASSERT(EQ(add_base(1), 10))
611+
base = 10
612+
ASSERT(EQ(add_base(1), 11))
613+
614+
FUNC: greet_lambda = LAMBDA(STR: name, STR: prefix = "Hi"):STR{ RETURN(JOIN(" ", prefix, name)) }
615+
ASSERT(EQ(greet_lambda("World"), "Hi World"))
616+
ASSERT(EQ(greet_lambda("World", prefix = "Hello"), "Hello World"))
617+
618+
FUNC MAKE_ADDER(INT: seed):FUNC{
619+
INT: local = seed
620+
RETURN(LAMBDA(INT: delta):INT{ RETURN(ADD(local, delta)) })
621+
}
622+
FUNC: escaped_lambda = MAKE_ADDER(10)
623+
ASSERT(EQ(escaped_lambda(1), 11))
624+
625+
TNS: func_slots = [add_base]
626+
ASSERT(EQ(func_slots[1](1), 11))
627+
DEL(func_slots)
628+
DEL(greet_lambda)
629+
DEL(escaped_lambda)
630+
DEL(MAKE_ADDER)
631+
DEL(add_base)
632+
DEL(base)
607633
PRINT("Functions: PASS\n")
608634

609635
PRINT("Testing SIGNATURE...")

0 commit comments

Comments
 (0)